Skip to main content

rig_core/
transcription.rs

1//! This module provides functionality for working with audio transcription models.
2//! It provides traits, structs, and enums for generating audio transcription requests,
3//! handling transcription responses, and defining transcription models.
4use crate::markers::{Missing, Provided};
5use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
6use crate::{http_client, json_utils};
7use std::io;
8use std::{fs, path::Path};
9use thiserror::Error;
10
11// Errors
12#[derive(Debug, Error)]
13#[non_exhaustive]
14pub enum TranscriptionError {
15    /// Http error (e.g.: connection error, timeout, etc.)
16    #[error("HttpError: {0}")]
17    HttpError(#[from] http_client::Error),
18
19    /// Json error (e.g.: serialization, deserialization)
20    #[error("JsonError: {0}")]
21    JsonError(#[from] serde_json::Error),
22
23    #[cfg(not(target_family = "wasm"))]
24    /// Error building the transcription request
25    #[error("RequestError: {0}")]
26    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
27
28    #[cfg(target_family = "wasm")]
29    /// Error building the transcription request
30    #[error("RequestError: {0}")]
31    RequestError(#[from] Box<dyn std::error::Error + 'static>),
32
33    /// Error parsing the transcription response
34    #[error("ResponseError: {0}")]
35    ResponseError(String),
36
37    /// Error returned by the transcription model provider
38    #[error("ProviderError: {0}")]
39    ProviderError(String),
40}
41
42/// Trait defining a low-level LLM transcription interface
43pub trait Transcription<M>
44where
45    M: TranscriptionModel,
46{
47    /// Generates a transcription request builder for the given `file`.
48    /// This function is meant to be called by the user to further customize the
49    /// request at transcription time before sending it.
50    ///
51    /// ❗IMPORTANT: The type that implements this trait might have already
52    /// populated fields in the builder (the exact fields depend on the type).
53    /// For fields that have already been set by the model, calling the corresponding
54    /// method on the builder will overwrite the value set by the model.
55    fn transcription(
56        &self,
57        filename: &str,
58        data: &[u8],
59    ) -> impl std::future::Future<
60        Output = Result<TranscriptionRequestBuilder<M, Provided<Vec<u8>>>, TranscriptionError>,
61    > + WasmCompatSend;
62}
63
/// Outcome of a transcription call: the transcribed text together with the
/// provider's raw response.
pub struct TranscriptionResponse<T> {
    /// The transcribed text.
    pub text: String,
    /// The raw response value, of whatever type the model reports.
    pub response: T,
}
70
71/// Trait defining a transcription model that can be used to generate transcription requests.
72/// This trait is meant to be implemented by the user to define a custom transcription model,
73/// either from a third-party provider (e.g: OpenAI) or a local model.
74pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync {
75    /// The raw response type returned by the underlying model.
76    type Response: WasmCompatSend + WasmCompatSync;
77    type Client;
78
79    fn make(client: &Self::Client, model: impl Into<String>) -> Self;
80
81    /// Generates a completion response for the given transcription model
82    fn transcription(
83        &self,
84        request: TranscriptionRequest,
85    ) -> impl std::future::Future<
86        Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
87    > + WasmCompatSend;
88
89    /// Generates a transcription request builder for the given `file`
90    fn transcription_request(&self) -> TranscriptionRequestBuilder<Self, Missing> {
91        TranscriptionRequestBuilder::new(self.clone())
92    }
93}
94/// Struct representing a general transcription request that can be sent to a transcription model provider.
95pub struct TranscriptionRequest {
96    /// The file data to be sent to the transcription model provider
97    pub data: Vec<u8>,
98    /// The file name to be used in the request
99    pub filename: String,
100    /// The language used in the response from the transcription model provider
101    pub language: Option<String>,
102    /// The prompt to be sent to the transcription model provider
103    pub prompt: Option<String>,
104    /// The temperature sent to the transcription model provider
105    pub temperature: Option<f64>,
106    /// Additional parameters to be sent to the transcription model provider
107    pub additional_params: Option<serde_json::Value>,
108}
109
110/// Builder struct for a transcription request
111///
112/// Example usage:
113/// ```no_run
114/// use rig_core::{
115///     prelude::TranscriptionClient,
116///     providers::openai::{Client, self},
117///     transcription::{TranscriptionModel, TranscriptionRequestBuilder},
118/// };
119///
120/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
121/// let openai = Client::new("your-openai-api-key")?;
122/// let model = openai.transcription_model(openai::WHISPER_1);
123///
124/// // Create the transcription request and execute it separately.
125/// let request = TranscriptionRequestBuilder::new(model.clone())
126///     .data(vec![0; 16])
127///     .filename(Some("audio.mp3".to_string()))
128///     .temperature(0.5)
129///     .build();
130///
131/// let response = model.transcription(request).await?;
132/// # Ok(())
133/// # }
134/// ```
135///
136/// Alternatively, you can execute the transcription request directly from the builder:
137/// ```no_run
138/// use rig_core::{
139///     prelude::TranscriptionClient,
140///     providers::openai::{Client, self},
141///     transcription::TranscriptionRequestBuilder,
142/// };
143///
144/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
145/// let openai = Client::new("your-openai-api-key")?;
146/// let model = openai.transcription_model(openai::WHISPER_1);
147///
148/// // Create the transcription request and execute it directly.
149/// let response = TranscriptionRequestBuilder::new(model)
150///     .data(vec![0; 16])
151///     .filename(Some("audio.mp3".to_string()))
152///     .temperature(0.5)
153///     .send()
154///     .await?;
155/// # Ok(())
156/// # }
157/// ```
158///
159/// Note: It is usually unnecessary to create a completion request builder directly.
160/// Instead, use the [TranscriptionModel::transcription_request] method.
161pub struct TranscriptionRequestBuilder<M, D>
162where
163    M: TranscriptionModel,
164{
165    model: M,
166    data: D, // starts Missing, becomes Provided<Vec<u8>> after data is set or load_file is called
167    filename: Option<String>,
168    language: Option<String>,
169    prompt: Option<String>,
170    temperature: Option<f64>,
171    additional_params: Option<serde_json::Value>,
172}
173
174impl<M> TranscriptionRequestBuilder<M, Missing>
175where
176    M: TranscriptionModel,
177{
178    pub fn new(model: M) -> Self {
179        TranscriptionRequestBuilder {
180            model,
181            data: Missing,
182            filename: None,
183            language: None,
184            prompt: None,
185            temperature: None,
186            additional_params: None,
187        }
188    }
189}
190
191impl<M, D> TranscriptionRequestBuilder<M, D>
192where
193    M: TranscriptionModel,
194{
195    pub fn filename(mut self, filename: Option<String>) -> Self {
196        self.filename = filename;
197        self
198    }
199
200    /// Sets the data for the request and transitions the builder to the next state where data is provided.
201    pub fn data(self, data: Vec<u8>) -> TranscriptionRequestBuilder<M, Provided<Vec<u8>>> {
202        TranscriptionRequestBuilder {
203            model: self.model,
204            data: Provided(data),
205            filename: self.filename,
206            language: self.language,
207            prompt: self.prompt,
208            temperature: self.temperature,
209            additional_params: self.additional_params,
210        }
211    }
212
213    /// Load the specified file into data and transitions the builder to the next state where data is provided.
214    pub fn load_file<P>(
215        self,
216        path: P,
217    ) -> io::Result<TranscriptionRequestBuilder<M, Provided<Vec<u8>>>>
218    where
219        P: AsRef<Path>,
220    {
221        let path = path.as_ref();
222        let data = fs::read(path)?;
223
224        let filename = path.file_name().map(|n| n.to_string_lossy().into_owned());
225
226        Ok(TranscriptionRequestBuilder {
227            model: self.model,
228            data: Provided(data),
229            filename: filename.or(self.filename),
230            language: self.language,
231            prompt: self.prompt,
232            temperature: self.temperature,
233            additional_params: self.additional_params,
234        })
235    }
236
237    /// Sets the output language for the transcription request
238    pub fn language(mut self, language: String) -> Self {
239        self.language = Some(language);
240        self
241    }
242
243    /// Sets the prompt to be sent in the transcription request
244    pub fn prompt(mut self, prompt: String) -> Self {
245        self.prompt = Some(prompt);
246        self
247    }
248
249    /// Set the temperature to be sent in the transcription request
250    pub fn temperature(mut self, temperature: f64) -> Self {
251        self.temperature = Some(temperature);
252        self
253    }
254
255    /// Adds additional parameters to the transcription request.
256    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
257        match self.additional_params {
258            Some(params) => {
259                self.additional_params = Some(json_utils::merge(params, additional_params));
260            }
261            None => {
262                self.additional_params = Some(additional_params);
263            }
264        }
265        self
266    }
267
268    /// Sets the additional parameters for the transcription request.
269    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
270        self.additional_params = additional_params;
271        self
272    }
273}
274
275/// The build and send methods are only available when data is provided, ensuring that the request cannot be sent without the required data.
276impl<M> TranscriptionRequestBuilder<M, Provided<Vec<u8>>>
277where
278    M: TranscriptionModel,
279{
280    /// Builds the transcription request
281    /// Panics if data is empty.
282    pub fn build(self) -> TranscriptionRequest {
283        TranscriptionRequest {
284            data: self.data.0,
285            filename: self.filename.unwrap_or("file".to_string()),
286            language: self.language,
287            prompt: self.prompt,
288            temperature: self.temperature,
289            additional_params: self.additional_params,
290        }
291    }
292
293    /// Sends the transcription request to the transcription model provider and returns the transcription response
294    pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
295        let model = self.model.clone();
296        model.transcription(self.build()).await
297    }
298}