rig/
transcription.rs

1//! This module provides functionality for working with audio transcription models.
2//! It provides traits, structs, and enums for generating audio transcription requests,
3//! handling transcription responses, and defining transcription models.
4#[allow(deprecated)]
5use crate::client::transcription::TranscriptionModelHandle;
6use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
7use crate::{http_client, json_utils};
8use std::sync::Arc;
9use std::{fs, path::Path};
10use thiserror::Error;
11
12// Errors
13#[derive(Debug, Error)]
14#[non_exhaustive]
15pub enum TranscriptionError {
16    /// Http error (e.g.: connection error, timeout, etc.)
17    #[error("HttpError: {0}")]
18    HttpError(#[from] http_client::Error),
19
20    /// Json error (e.g.: serialization, deserialization)
21    #[error("JsonError: {0}")]
22    JsonError(#[from] serde_json::Error),
23
24    #[cfg(not(target_family = "wasm"))]
25    /// Error building the transcription request
26    #[error("RequestError: {0}")]
27    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
28
29    #[cfg(target_family = "wasm")]
30    /// Error building the transcription request
31    #[error("RequestError: {0}")]
32    RequestError(#[from] Box<dyn std::error::Error + 'static>),
33
34    /// Error parsing the transcription response
35    #[error("ResponseError: {0}")]
36    ResponseError(String),
37
38    /// Error returned by the transcription model provider
39    #[error("ProviderError: {0}")]
40    ProviderError(String),
41}
42
43/// Trait defining a low-level LLM transcription interface
44pub trait Transcription<M>
45where
46    M: TranscriptionModel,
47{
48    /// Generates a transcription request builder for the given `file`.
49    /// This function is meant to be called by the user to further customize the
50    /// request at transcription time before sending it.
51    ///
52    /// ❗IMPORTANT: The type that implements this trait might have already
53    /// populated fields in the builder (the exact fields depend on the type).
54    /// For fields that have already been set by the model, calling the corresponding
55    /// method on the builder will overwrite the value set by the model.
56    fn transcription(
57        &self,
58        filename: &str,
59        data: &[u8],
60    ) -> impl std::future::Future<
61        Output = Result<TranscriptionRequestBuilder<M>, TranscriptionError>,
62    > + WasmCompatSend;
63}
64
/// General transcription response containing the transcribed text and the
/// model's raw response.
///
/// `T` is the model-specific raw response type (see [`TranscriptionModel::Response`]).
pub struct TranscriptionResponse<T> {
    /// The transcribed text.
    pub text: String,
    /// The raw, model-specific response the text was extracted from.
    pub response: T,
}
71
72/// Trait defining a transcription model that can be used to generate transcription requests.
73/// This trait is meant to be implemented by the user to define a custom transcription model,
74/// either from a third-party provider (e.g: OpenAI) or a local model.
75pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync {
76    /// The raw response type returned by the underlying model.
77    type Response: WasmCompatSend + WasmCompatSync;
78    type Client;
79
80    fn make(client: &Self::Client, model: impl Into<String>) -> Self;
81
82    /// Generates a completion response for the given transcription model
83    fn transcription(
84        &self,
85        request: TranscriptionRequest,
86    ) -> impl std::future::Future<
87        Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
88    > + WasmCompatSend;
89
90    /// Generates a transcription request builder for the given `file`
91    fn transcription_request(&self) -> TranscriptionRequestBuilder<Self> {
92        TranscriptionRequestBuilder::new(self.clone())
93    }
94}
95
96#[allow(deprecated)]
97#[deprecated(
98    since = "0.25.0",
99    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `TranscriptionModel` instead."
100)]
101pub trait TranscriptionModelDyn: WasmCompatSend + WasmCompatSync {
102    fn transcription(
103        &self,
104        request: TranscriptionRequest,
105    ) -> WasmBoxedFuture<'_, Result<TranscriptionResponse<()>, TranscriptionError>>;
106
107    fn transcription_request(&self) -> TranscriptionRequestBuilder<TranscriptionModelHandle<'_>>;
108}
109
110#[allow(deprecated)]
111impl<T> TranscriptionModelDyn for T
112where
113    T: TranscriptionModel,
114{
115    fn transcription(
116        &self,
117        request: TranscriptionRequest,
118    ) -> WasmBoxedFuture<'_, Result<TranscriptionResponse<()>, TranscriptionError>> {
119        Box::pin(async move {
120            let resp = self.transcription(request).await?;
121
122            Ok(TranscriptionResponse {
123                text: resp.text,
124                response: (),
125            })
126        })
127    }
128
129    fn transcription_request(&self) -> TranscriptionRequestBuilder<TranscriptionModelHandle<'_>> {
130        TranscriptionRequestBuilder::new(TranscriptionModelHandle {
131            inner: Arc::new(self.clone()),
132        })
133    }
134}
135
136/// Struct representing a general transcription request that can be sent to a transcription model provider.
137pub struct TranscriptionRequest {
138    /// The file data to be sent to the transcription model provider
139    pub data: Vec<u8>,
140    /// The file name to be used in the request
141    pub filename: String,
142    /// The language used in the response from the transcription model provider
143    pub language: Option<String>,
144    /// The prompt to be sent to the transcription model provider
145    pub prompt: Option<String>,
146    /// The temperature sent to the transcription model provider
147    pub temperature: Option<f64>,
148    /// Additional parameters to be sent to the transcription model provider
149    pub additional_params: Option<serde_json::Value>,
150}
151
152/// Builder struct for a transcription request
153///
154/// Example usage:
155/// ```rust
156/// use rig::{
157///     providers::openai::{Client, self},
158///     transcription::TranscriptionRequestBuilder,
159/// };
160///
161/// let openai = Client::new("your-openai-api-key");
162/// let model = openai.transcription_model(openai::WHISPER_1).build();
163///
164/// // Create the completion request and execute it separately
165/// let request = TranscriptionRequestBuilder::new(model, "~/audio.mp3".to_string())
166///     .temperature(0.5)
167///     .build();
168///
169/// let response = model.transcription(request)
170///     .await
171///     .expect("Failed to get transcription response");
172/// ```
173///
174/// Alternatively, you can execute the transcription request directly from the builder:
175/// ```rust
176/// use rig::{
177///     providers::openai::{Client, self},
178///     transcription::TranscriptionRequestBuilder,
179/// };
180///
181/// let openai = Client::new("your-openai-api-key");
182/// let model = openai.transcription_model(openai::WHISPER_1).build();
183///
184/// // Create the completion request and execute it directly
185/// let response = TranscriptionRequestBuilder::new(model, "~/audio.mp3".to_string())
186///     .temperature(0.5)
187///     .send()
188///     .await
189///     .expect("Failed to get transcription response");
190/// ```
191///
192/// Note: It is usually unnecessary to create a completion request builder directly.
193/// Instead, use the [TranscriptionModel::transcription_request] method.
194pub struct TranscriptionRequestBuilder<M>
195where
196    M: TranscriptionModel,
197{
198    model: M,
199    data: Vec<u8>,
200    filename: Option<String>,
201    language: Option<String>,
202    prompt: Option<String>,
203    temperature: Option<f64>,
204    additional_params: Option<serde_json::Value>,
205}
206
207impl<M> TranscriptionRequestBuilder<M>
208where
209    M: TranscriptionModel,
210{
211    pub fn new(model: M) -> Self {
212        TranscriptionRequestBuilder {
213            model,
214            data: vec![],
215            filename: None,
216            language: None,
217            prompt: None,
218            temperature: None,
219            additional_params: None,
220        }
221    }
222
223    pub fn filename(mut self, filename: Option<String>) -> Self {
224        self.filename = filename;
225        self
226    }
227
228    /// Sets the data for the request
229    pub fn data(mut self, data: Vec<u8>) -> Self {
230        self.data = data;
231        self
232    }
233
234    /// Load the specified file into data
235    pub fn load_file<P>(self, path: P) -> Self
236    where
237        P: AsRef<Path>,
238    {
239        let path = path.as_ref();
240        let data = fs::read(path).expect("Failed to load audio file, file did not exist");
241
242        self.filename(Some(
243            path.file_name()
244                .expect("Path was not a file")
245                .to_str()
246                .expect("Failed to convert filename to ascii")
247                .to_string(),
248        ))
249        .data(data)
250    }
251
252    /// Sets the output language for the transcription request
253    pub fn language(mut self, language: String) -> Self {
254        self.language = Some(language);
255        self
256    }
257
258    /// Sets the prompt to be sent in the transcription request
259    pub fn prompt(mut self, prompt: String) -> Self {
260        self.prompt = Some(prompt);
261        self
262    }
263
264    /// Set the temperature to be sent in the transcription request
265    pub fn temperature(mut self, temperature: f64) -> Self {
266        self.temperature = Some(temperature);
267        self
268    }
269
270    /// Adds additional parameters to the transcription request.
271    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
272        match self.additional_params {
273            Some(params) => {
274                self.additional_params = Some(json_utils::merge(params, additional_params));
275            }
276            None => {
277                self.additional_params = Some(additional_params);
278            }
279        }
280        self
281    }
282
283    /// Sets the additional parameters for the transcription request.
284    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
285        self.additional_params = additional_params;
286        self
287    }
288
289    /// Builds the transcription request
290    /// Panics if data is empty.
291    pub fn build(self) -> TranscriptionRequest {
292        if self.data.is_empty() {
293            panic!("Data cannot be empty!")
294        }
295
296        TranscriptionRequest {
297            data: self.data,
298            filename: self.filename.unwrap_or("file".to_string()),
299            language: self.language,
300            prompt: self.prompt,
301            temperature: self.temperature,
302            additional_params: self.additional_params,
303        }
304    }
305
306    /// Sends the transcription request to the transcription model provider and returns the transcription response
307    pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
308        let model = self.model.clone();
309
310        model.transcription(self.build()).await
311    }
312}