rig/
transcription.rs

1//! This module provides functionality for working with audio transcription models.
2//! It provides traits, structs, and enums for generating audio transcription requests,
3//! handling transcription responses, and defining transcription models.
4
5use crate::client::transcription::TranscriptionModelHandle;
6use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
7use crate::{http_client, json_utils};
8use std::sync::Arc;
9use std::{fs, path::Path};
10use thiserror::Error;
11
12// Errors
13#[derive(Debug, Error)]
14#[non_exhaustive]
15pub enum TranscriptionError {
16    /// Http error (e.g.: connection error, timeout, etc.)
17    #[error("HttpError: {0}")]
18    HttpError(#[from] http_client::Error),
19
20    /// Json error (e.g.: serialization, deserialization)
21    #[error("JsonError: {0}")]
22    JsonError(#[from] serde_json::Error),
23
24    #[cfg(not(target_family = "wasm"))]
25    /// Error building the transcription request
26    #[error("RequestError: {0}")]
27    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
28
29    #[cfg(target_family = "wasm")]
30    /// Error building the transcription request
31    #[error("RequestError: {0}")]
32    RequestError(#[from] Box<dyn std::error::Error + 'static>),
33
34    /// Error parsing the transcription response
35    #[error("ResponseError: {0}")]
36    ResponseError(String),
37
38    /// Error returned by the transcription model provider
39    #[error("ProviderError: {0}")]
40    ProviderError(String),
41}
42
43/// Trait defining a low-level LLM transcription interface
44pub trait Transcription<M>
45where
46    M: TranscriptionModel,
47{
48    /// Generates a transcription request builder for the given `file`.
49    /// This function is meant to be called by the user to further customize the
50    /// request at transcription time before sending it.
51    ///
52    /// ❗IMPORTANT: The type that implements this trait might have already
53    /// populated fields in the builder (the exact fields depend on the type).
54    /// For fields that have already been set by the model, calling the corresponding
55    /// method on the builder will overwrite the value set by the model.
56    fn transcription(
57        &self,
58        filename: &str,
59        data: &[u8],
60    ) -> impl std::future::Future<
61        Output = Result<TranscriptionRequestBuilder<M>, TranscriptionError>,
62    > + WasmCompatSend;
63}
64
/// Provider-agnostic transcription result: the transcribed text plus the raw
/// response returned by the underlying model.
///
/// `Debug` and `Clone` are derived, so they are available whenever `T`
/// implements them.
#[derive(Debug, Clone)]
pub struct TranscriptionResponse<T> {
    /// The transcribed text.
    pub text: String,
    /// The raw, model-specific response.
    pub response: T,
}
71
72/// Trait defining a transcription model that can be used to generate transcription requests.
73/// This trait is meant to be implemented by the user to define a custom transcription model,
74/// either from a third-party provider (e.g: OpenAI) or a local model.
75pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync {
76    /// The raw response type returned by the underlying model.
77    type Response: WasmCompatSend + WasmCompatSync;
78
79    /// Generates a completion response for the given transcription model
80    fn transcription(
81        &self,
82        request: TranscriptionRequest,
83    ) -> impl std::future::Future<
84        Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
85    > + WasmCompatSend;
86
87    /// Generates a transcription request builder for the given `file`
88    fn transcription_request(&self) -> TranscriptionRequestBuilder<Self> {
89        TranscriptionRequestBuilder::new(self.clone())
90    }
91}
92
93pub trait TranscriptionModelDyn: WasmCompatSend + WasmCompatSync {
94    fn transcription(
95        &self,
96        request: TranscriptionRequest,
97    ) -> WasmBoxedFuture<'_, Result<TranscriptionResponse<()>, TranscriptionError>>;
98
99    fn transcription_request(&self) -> TranscriptionRequestBuilder<TranscriptionModelHandle<'_>>;
100}
101
102impl<T> TranscriptionModelDyn for T
103where
104    T: TranscriptionModel,
105{
106    fn transcription(
107        &self,
108        request: TranscriptionRequest,
109    ) -> WasmBoxedFuture<'_, Result<TranscriptionResponse<()>, TranscriptionError>> {
110        Box::pin(async move {
111            let resp = self.transcription(request).await?;
112
113            Ok(TranscriptionResponse {
114                text: resp.text,
115                response: (),
116            })
117        })
118    }
119
120    fn transcription_request(&self) -> TranscriptionRequestBuilder<TranscriptionModelHandle<'_>> {
121        TranscriptionRequestBuilder::new(TranscriptionModelHandle {
122            inner: Arc::new(self.clone()),
123        })
124    }
125}
126
127/// Struct representing a general transcription request that can be sent to a transcription model provider.
128pub struct TranscriptionRequest {
129    /// The file data to be sent to the transcription model provider
130    pub data: Vec<u8>,
131    /// The file name to be used in the request
132    pub filename: String,
133    /// The language used in the response from the transcription model provider
134    pub language: String,
135    /// The prompt to be sent to the transcription model provider
136    pub prompt: Option<String>,
137    /// The temperature sent to the transcription model provider
138    pub temperature: Option<f64>,
139    /// Additional parameters to be sent to the transcription model provider
140    pub additional_params: Option<serde_json::Value>,
141}
142
143/// Builder struct for a transcription request
144///
145/// Example usage:
146/// ```rust
147/// use rig::{
148///     providers::openai::{Client, self},
149///     transcription::TranscriptionRequestBuilder,
150/// };
151///
152/// let openai = Client::new("your-openai-api-key");
153/// let model = openai.transcription_model(openai::WHISPER_1).build();
154///
155/// // Create the completion request and execute it separately
156/// let request = TranscriptionRequestBuilder::new(model, "~/audio.mp3".to_string())
157///     .temperature(0.5)
158///     .build();
159///
160/// let response = model.transcription(request)
161///     .await
162///     .expect("Failed to get transcription response");
163/// ```
164///
165/// Alternatively, you can execute the transcription request directly from the builder:
166/// ```rust
167/// use rig::{
168///     providers::openai::{Client, self},
169///     transcription::TranscriptionRequestBuilder,
170/// };
171///
172/// let openai = Client::new("your-openai-api-key");
173/// let model = openai.transcription_model(openai::WHISPER_1).build();
174///
175/// // Create the completion request and execute it directly
176/// let response = TranscriptionRequestBuilder::new(model, "~/audio.mp3".to_string())
177///     .temperature(0.5)
178///     .send()
179///     .await
180///     .expect("Failed to get transcription response");
181/// ```
182///
183/// Note: It is usually unnecessary to create a completion request builder directly.
184/// Instead, use the [TranscriptionModel::transcription_request] method.
185pub struct TranscriptionRequestBuilder<M>
186where
187    M: TranscriptionModel,
188{
189    model: M,
190    data: Vec<u8>,
191    filename: Option<String>,
192    language: String,
193    prompt: Option<String>,
194    temperature: Option<f64>,
195    additional_params: Option<serde_json::Value>,
196}
197
198impl<M> TranscriptionRequestBuilder<M>
199where
200    M: TranscriptionModel,
201{
202    pub fn new(model: M) -> Self {
203        TranscriptionRequestBuilder {
204            model,
205            data: vec![],
206            filename: None,
207            language: "en".to_string(),
208            prompt: None,
209            temperature: None,
210            additional_params: None,
211        }
212    }
213
214    pub fn filename(mut self, filename: Option<String>) -> Self {
215        self.filename = filename;
216        self
217    }
218
219    /// Sets the data for the request
220    pub fn data(mut self, data: Vec<u8>) -> Self {
221        self.data = data;
222        self
223    }
224
225    /// Load the specified file into data
226    pub fn load_file<P>(self, path: P) -> Self
227    where
228        P: AsRef<Path>,
229    {
230        let path = path.as_ref();
231        let data = fs::read(path).expect("Failed to load audio file, file did not exist");
232
233        self.filename(Some(
234            path.file_name()
235                .expect("Path was not a file")
236                .to_str()
237                .expect("Failed to convert filename to ascii")
238                .to_string(),
239        ))
240        .data(data)
241    }
242
243    /// Sets the output language for the transcription request
244    pub fn language(mut self, language: String) -> Self {
245        self.language = language;
246        self
247    }
248
249    /// Sets the prompt to be sent in the transcription request
250    pub fn prompt(mut self, prompt: String) -> Self {
251        self.prompt = Some(prompt);
252        self
253    }
254
255    /// Set the temperature to be sent in the transcription request
256    pub fn temperature(mut self, temperature: f64) -> Self {
257        self.temperature = Some(temperature);
258        self
259    }
260
261    /// Adds additional parameters to the transcription request.
262    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
263        match self.additional_params {
264            Some(params) => {
265                self.additional_params = Some(json_utils::merge(params, additional_params));
266            }
267            None => {
268                self.additional_params = Some(additional_params);
269            }
270        }
271        self
272    }
273
274    /// Sets the additional parameters for the transcription request.
275    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
276        self.additional_params = additional_params;
277        self
278    }
279
280    /// Builds the transcription request
281    /// Panics if data is empty.
282    pub fn build(self) -> TranscriptionRequest {
283        if self.data.is_empty() {
284            panic!("Data cannot be empty!")
285        }
286
287        TranscriptionRequest {
288            data: self.data,
289            filename: self.filename.unwrap_or("file".to_string()),
290            language: self.language,
291            prompt: self.prompt,
292            temperature: self.temperature,
293            additional_params: self.additional_params,
294        }
295    }
296
297    /// Sends the transcription request to the transcription model provider and returns the transcription response
298    pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
299        let model = self.model.clone();
300
301        model.transcription(self.build()).await
302    }
303}