Skip to main content

rig/
transcription.rs

1//! This module provides functionality for working with audio transcription models.
2//! It provides traits, structs, and enums for generating audio transcription requests,
3//! handling transcription responses, and defining transcription models.
4use crate::markers::{Missing, Provided};
5use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
6use crate::{http_client, json_utils};
7use std::io;
8use std::{fs, path::Path};
9use thiserror::Error;
10
11// Errors
12#[derive(Debug, Error)]
13#[non_exhaustive]
14pub enum TranscriptionError {
15    /// Http error (e.g.: connection error, timeout, etc.)
16    #[error("HttpError: {0}")]
17    HttpError(#[from] http_client::Error),
18
19    /// Json error (e.g.: serialization, deserialization)
20    #[error("JsonError: {0}")]
21    JsonError(#[from] serde_json::Error),
22
23    #[cfg(not(target_family = "wasm"))]
24    /// Error building the transcription request
25    #[error("RequestError: {0}")]
26    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
27
28    #[cfg(target_family = "wasm")]
29    /// Error building the transcription request
30    #[error("RequestError: {0}")]
31    RequestError(#[from] Box<dyn std::error::Error + 'static>),
32
33    /// Error parsing the transcription response
34    #[error("ResponseError: {0}")]
35    ResponseError(String),
36
37    /// Error returned by the transcription model provider
38    #[error("ProviderError: {0}")]
39    ProviderError(String),
40}
41
42/// Trait defining a low-level LLM transcription interface
43pub trait Transcription<M>
44where
45    M: TranscriptionModel,
46{
47    /// Generates a transcription request builder for the given `file`.
48    /// This function is meant to be called by the user to further customize the
49    /// request at transcription time before sending it.
50    ///
51    /// ❗IMPORTANT: The type that implements this trait might have already
52    /// populated fields in the builder (the exact fields depend on the type).
53    /// For fields that have already been set by the model, calling the corresponding
54    /// method on the builder will overwrite the value set by the model.
55    fn transcription(
56        &self,
57        filename: &str,
58        data: &[u8],
59    ) -> impl std::future::Future<
60        Output = Result<TranscriptionRequestBuilder<M, Provided<Vec<u8>>>, TranscriptionError>,
61    > + WasmCompatSend;
62}
63
/// Provider-agnostic result of a transcription call.
///
/// Pairs the transcribed text with the provider's raw response so callers can
/// reach provider-specific details when they need them.
pub struct TranscriptionResponse<T> {
    /// The transcribed text.
    pub text: String,
    /// The raw response as returned by the underlying model.
    pub response: T,
}
70
71/// Trait defining a transcription model that can be used to generate transcription requests.
72/// This trait is meant to be implemented by the user to define a custom transcription model,
73/// either from a third-party provider (e.g: OpenAI) or a local model.
74pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync {
75    /// The raw response type returned by the underlying model.
76    type Response: WasmCompatSend + WasmCompatSync;
77    type Client;
78
79    fn make(client: &Self::Client, model: impl Into<String>) -> Self;
80
81    /// Generates a completion response for the given transcription model
82    fn transcription(
83        &self,
84        request: TranscriptionRequest,
85    ) -> impl std::future::Future<
86        Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
87    > + WasmCompatSend;
88
89    /// Generates a transcription request builder for the given `file`
90    fn transcription_request(&self) -> TranscriptionRequestBuilder<Self, Missing> {
91        TranscriptionRequestBuilder::new(self.clone())
92    }
93}
94/// Struct representing a general transcription request that can be sent to a transcription model provider.
95pub struct TranscriptionRequest {
96    /// The file data to be sent to the transcription model provider
97    pub data: Vec<u8>,
98    /// The file name to be used in the request
99    pub filename: String,
100    /// The language used in the response from the transcription model provider
101    pub language: Option<String>,
102    /// The prompt to be sent to the transcription model provider
103    pub prompt: Option<String>,
104    /// The temperature sent to the transcription model provider
105    pub temperature: Option<f64>,
106    /// Additional parameters to be sent to the transcription model provider
107    pub additional_params: Option<serde_json::Value>,
108}
109
110/// Builder struct for a transcription request
111///
112/// Example usage:
113/// ```rust
114/// use rig::{
115///     providers::openai::{Client, self},
116///     transcription::TranscriptionRequestBuilder,
117/// };
118///
119/// let openai = Client::new("your-openai-api-key");
120/// let model = openai.transcription_model(openai::WHISPER_1).build();
121///
122/// // Create the completion request and execute it separately
123/// let request = TranscriptionRequestBuilder::new(model, "~/audio.mp3".to_string())
124///     .temperature(0.5)
125///     .build();
126///
127/// let response = model.transcription(request)
128///     .await
129///     .expect("Failed to get transcription response");
130/// ```
131///
132/// Alternatively, you can execute the transcription request directly from the builder:
133/// ```rust
134/// use rig::{
135///     providers::openai::{Client, self},
136///     transcription::TranscriptionRequestBuilder,
137/// };
138///
139/// let openai = Client::new("your-openai-api-key");
140/// let model = openai.transcription_model(openai::WHISPER_1).build();
141///
142/// // Create the completion request and execute it directly
143/// let response = TranscriptionRequestBuilder::new(model, "~/audio.mp3".to_string())
144///     .temperature(0.5)
145///     .send()
146///     .await
147///     .expect("Failed to get transcription response");
148/// ```
149///
150/// Note: It is usually unnecessary to create a completion request builder directly.
151/// Instead, use the [TranscriptionModel::transcription_request] method.
152pub struct TranscriptionRequestBuilder<M, D>
153where
154    M: TranscriptionModel,
155{
156    model: M,
157    data: D, // starts Missing, becomes Provided<Vec<u8>> after data is set or load_file is called
158    filename: Option<String>,
159    language: Option<String>,
160    prompt: Option<String>,
161    temperature: Option<f64>,
162    additional_params: Option<serde_json::Value>,
163}
164
165impl<M> TranscriptionRequestBuilder<M, Missing>
166where
167    M: TranscriptionModel,
168{
169    pub fn new(model: M) -> Self {
170        TranscriptionRequestBuilder {
171            model,
172            data: Missing,
173            filename: None,
174            language: None,
175            prompt: None,
176            temperature: None,
177            additional_params: None,
178        }
179    }
180}
181
182impl<M, D> TranscriptionRequestBuilder<M, D>
183where
184    M: TranscriptionModel,
185{
186    pub fn filename(mut self, filename: Option<String>) -> Self {
187        self.filename = filename;
188        self
189    }
190
191    /// Sets the data for the request and transitions the builder to the next state where data is provided.
192    pub fn data(self, data: Vec<u8>) -> TranscriptionRequestBuilder<M, Provided<Vec<u8>>> {
193        TranscriptionRequestBuilder {
194            model: self.model,
195            data: Provided(data),
196            filename: self.filename,
197            language: self.language,
198            prompt: self.prompt,
199            temperature: self.temperature,
200            additional_params: self.additional_params,
201        }
202    }
203
204    /// Load the specified file into data and transitions the builder to the next state where data is provided.
205    pub fn load_file<P>(
206        self,
207        path: P,
208    ) -> io::Result<TranscriptionRequestBuilder<M, Provided<Vec<u8>>>>
209    where
210        P: AsRef<Path>,
211    {
212        let path = path.as_ref();
213        let data = fs::read(path)?;
214
215        let filename = path.file_name().map(|n| n.to_string_lossy().into_owned());
216
217        Ok(TranscriptionRequestBuilder {
218            model: self.model,
219            data: Provided(data),
220            filename: filename.or(self.filename),
221            language: self.language,
222            prompt: self.prompt,
223            temperature: self.temperature,
224            additional_params: self.additional_params,
225        })
226    }
227
228    /// Sets the output language for the transcription request
229    pub fn language(mut self, language: String) -> Self {
230        self.language = Some(language);
231        self
232    }
233
234    /// Sets the prompt to be sent in the transcription request
235    pub fn prompt(mut self, prompt: String) -> Self {
236        self.prompt = Some(prompt);
237        self
238    }
239
240    /// Set the temperature to be sent in the transcription request
241    pub fn temperature(mut self, temperature: f64) -> Self {
242        self.temperature = Some(temperature);
243        self
244    }
245
246    /// Adds additional parameters to the transcription request.
247    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
248        match self.additional_params {
249            Some(params) => {
250                self.additional_params = Some(json_utils::merge(params, additional_params));
251            }
252            None => {
253                self.additional_params = Some(additional_params);
254            }
255        }
256        self
257    }
258
259    /// Sets the additional parameters for the transcription request.
260    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
261        self.additional_params = additional_params;
262        self
263    }
264}
265
266/// The build and send methods are only available when data is provided, ensuring that the request cannot be sent without the required data.
267impl<M> TranscriptionRequestBuilder<M, Provided<Vec<u8>>>
268where
269    M: TranscriptionModel,
270{
271    /// Builds the transcription request
272    /// Panics if data is empty.
273    pub fn build(self) -> TranscriptionRequest {
274        TranscriptionRequest {
275            data: self.data.0,
276            filename: self.filename.unwrap_or("file".to_string()),
277            language: self.language,
278            prompt: self.prompt,
279            temperature: self.temperature,
280            additional_params: self.additional_params,
281        }
282    }
283
284    /// Sends the transcription request to the transcription model provider and returns the transcription response
285    pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
286        let model = self.model.clone();
287        model.transcription(self.build()).await
288    }
289}