rig_core/transcription.rs
//! Audio transcription support.
//!
//! This module defines the traits, structs, and enums used to build audio transcription
//! requests, send them to transcription models, and handle the resulting responses.
4use crate::markers::{Missing, Provided};
5use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
6use crate::{http_client, json_utils};
7use std::io;
8use std::{fs, path::Path};
9use thiserror::Error;
10
11// Errors
12#[derive(Debug, Error)]
13#[non_exhaustive]
14pub enum TranscriptionError {
15 /// Http error (e.g.: connection error, timeout, etc.)
16 #[error("HttpError: {0}")]
17 HttpError(#[from] http_client::Error),
18
19 /// Json error (e.g.: serialization, deserialization)
20 #[error("JsonError: {0}")]
21 JsonError(#[from] serde_json::Error),
22
23 #[cfg(not(target_family = "wasm"))]
24 /// Error building the transcription request
25 #[error("RequestError: {0}")]
26 RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
27
28 #[cfg(target_family = "wasm")]
29 /// Error building the transcription request
30 #[error("RequestError: {0}")]
31 RequestError(#[from] Box<dyn std::error::Error + 'static>),
32
33 /// Error parsing the transcription response
34 #[error("ResponseError: {0}")]
35 ResponseError(String),
36
37 /// Error returned by the transcription model provider
38 #[error("ProviderError: {0}")]
39 ProviderError(String),
40}
41
42/// Trait defining a low-level LLM transcription interface
43pub trait Transcription<M>
44where
45 M: TranscriptionModel,
46{
47 /// Generates a transcription request builder for the given `file`.
48 /// This function is meant to be called by the user to further customize the
49 /// request at transcription time before sending it.
50 ///
51 /// ❗IMPORTANT: The type that implements this trait might have already
52 /// populated fields in the builder (the exact fields depend on the type).
53 /// For fields that have already been set by the model, calling the corresponding
54 /// method on the builder will overwrite the value set by the model.
55 fn transcription(
56 &self,
57 filename: &str,
58 data: &[u8],
59 ) -> impl std::future::Future<
60 Output = Result<TranscriptionRequestBuilder<M, Provided<Vec<u8>>>, TranscriptionError>,
61 > + WasmCompatSend;
62}
63
/// General transcription response struct that contains the transcription text
/// and the raw response.
pub struct TranscriptionResponse<T> {
    /// The transcribed text.
    pub text: String,
    /// The raw response returned by the underlying model
    /// (see [`TranscriptionModel::Response`]).
    pub response: T,
}
70
71/// Trait defining a transcription model that can be used to generate transcription requests.
72/// This trait is meant to be implemented by the user to define a custom transcription model,
73/// either from a third-party provider (e.g: OpenAI) or a local model.
74pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync {
75 /// The raw response type returned by the underlying model.
76 type Response: WasmCompatSend + WasmCompatSync;
77 type Client;
78
79 fn make(client: &Self::Client, model: impl Into<String>) -> Self;
80
81 /// Generates a completion response for the given transcription model
82 fn transcription(
83 &self,
84 request: TranscriptionRequest,
85 ) -> impl std::future::Future<
86 Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
87 > + WasmCompatSend;
88
89 /// Generates a transcription request builder for the given `file`
90 fn transcription_request(&self) -> TranscriptionRequestBuilder<Self, Missing> {
91 TranscriptionRequestBuilder::new(self.clone())
92 }
93}
94/// Struct representing a general transcription request that can be sent to a transcription model provider.
95pub struct TranscriptionRequest {
96 /// The file data to be sent to the transcription model provider
97 pub data: Vec<u8>,
98 /// The file name to be used in the request
99 pub filename: String,
100 /// The language used in the response from the transcription model provider
101 pub language: Option<String>,
102 /// The prompt to be sent to the transcription model provider
103 pub prompt: Option<String>,
104 /// The temperature sent to the transcription model provider
105 pub temperature: Option<f64>,
106 /// Additional parameters to be sent to the transcription model provider
107 pub additional_params: Option<serde_json::Value>,
108}
109
110/// Builder struct for a transcription request
111///
112/// Example usage:
113/// ```no_run
114/// use rig_core::{
115/// prelude::TranscriptionClient,
116/// providers::openai::{Client, self},
117/// transcription::{TranscriptionModel, TranscriptionRequestBuilder},
118/// };
119///
120/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
121/// let openai = Client::new("your-openai-api-key")?;
122/// let model = openai.transcription_model(openai::WHISPER_1);
123///
124/// // Create the transcription request and execute it separately.
125/// let request = TranscriptionRequestBuilder::new(model.clone())
126/// .data(vec![0; 16])
127/// .filename(Some("audio.mp3".to_string()))
128/// .temperature(0.5)
129/// .build();
130///
131/// let response = model.transcription(request).await?;
132/// # Ok(())
133/// # }
134/// ```
135///
136/// Alternatively, you can execute the transcription request directly from the builder:
137/// ```no_run
138/// use rig_core::{
139/// prelude::TranscriptionClient,
140/// providers::openai::{Client, self},
141/// transcription::TranscriptionRequestBuilder,
142/// };
143///
144/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
145/// let openai = Client::new("your-openai-api-key")?;
146/// let model = openai.transcription_model(openai::WHISPER_1);
147///
148/// // Create the transcription request and execute it directly.
149/// let response = TranscriptionRequestBuilder::new(model)
150/// .data(vec![0; 16])
151/// .filename(Some("audio.mp3".to_string()))
152/// .temperature(0.5)
153/// .send()
154/// .await?;
155/// # Ok(())
156/// # }
157/// ```
158///
159/// Note: It is usually unnecessary to create a completion request builder directly.
160/// Instead, use the [TranscriptionModel::transcription_request] method.
161pub struct TranscriptionRequestBuilder<M, D>
162where
163 M: TranscriptionModel,
164{
165 model: M,
166 data: D, // starts Missing, becomes Provided<Vec<u8>> after data is set or load_file is called
167 filename: Option<String>,
168 language: Option<String>,
169 prompt: Option<String>,
170 temperature: Option<f64>,
171 additional_params: Option<serde_json::Value>,
172}
173
174impl<M> TranscriptionRequestBuilder<M, Missing>
175where
176 M: TranscriptionModel,
177{
178 pub fn new(model: M) -> Self {
179 TranscriptionRequestBuilder {
180 model,
181 data: Missing,
182 filename: None,
183 language: None,
184 prompt: None,
185 temperature: None,
186 additional_params: None,
187 }
188 }
189}
190
191impl<M, D> TranscriptionRequestBuilder<M, D>
192where
193 M: TranscriptionModel,
194{
195 pub fn filename(mut self, filename: Option<String>) -> Self {
196 self.filename = filename;
197 self
198 }
199
200 /// Sets the data for the request and transitions the builder to the next state where data is provided.
201 pub fn data(self, data: Vec<u8>) -> TranscriptionRequestBuilder<M, Provided<Vec<u8>>> {
202 TranscriptionRequestBuilder {
203 model: self.model,
204 data: Provided(data),
205 filename: self.filename,
206 language: self.language,
207 prompt: self.prompt,
208 temperature: self.temperature,
209 additional_params: self.additional_params,
210 }
211 }
212
213 /// Load the specified file into data and transitions the builder to the next state where data is provided.
214 pub fn load_file<P>(
215 self,
216 path: P,
217 ) -> io::Result<TranscriptionRequestBuilder<M, Provided<Vec<u8>>>>
218 where
219 P: AsRef<Path>,
220 {
221 let path = path.as_ref();
222 let data = fs::read(path)?;
223
224 let filename = path.file_name().map(|n| n.to_string_lossy().into_owned());
225
226 Ok(TranscriptionRequestBuilder {
227 model: self.model,
228 data: Provided(data),
229 filename: filename.or(self.filename),
230 language: self.language,
231 prompt: self.prompt,
232 temperature: self.temperature,
233 additional_params: self.additional_params,
234 })
235 }
236
237 /// Sets the output language for the transcription request
238 pub fn language(mut self, language: String) -> Self {
239 self.language = Some(language);
240 self
241 }
242
243 /// Sets the prompt to be sent in the transcription request
244 pub fn prompt(mut self, prompt: String) -> Self {
245 self.prompt = Some(prompt);
246 self
247 }
248
249 /// Set the temperature to be sent in the transcription request
250 pub fn temperature(mut self, temperature: f64) -> Self {
251 self.temperature = Some(temperature);
252 self
253 }
254
255 /// Adds additional parameters to the transcription request.
256 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
257 match self.additional_params {
258 Some(params) => {
259 self.additional_params = Some(json_utils::merge(params, additional_params));
260 }
261 None => {
262 self.additional_params = Some(additional_params);
263 }
264 }
265 self
266 }
267
268 /// Sets the additional parameters for the transcription request.
269 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
270 self.additional_params = additional_params;
271 self
272 }
273}
274
275/// The build and send methods are only available when data is provided, ensuring that the request cannot be sent without the required data.
276impl<M> TranscriptionRequestBuilder<M, Provided<Vec<u8>>>
277where
278 M: TranscriptionModel,
279{
280 /// Builds the transcription request
281 /// Panics if data is empty.
282 pub fn build(self) -> TranscriptionRequest {
283 TranscriptionRequest {
284 data: self.data.0,
285 filename: self.filename.unwrap_or("file".to_string()),
286 language: self.language,
287 prompt: self.prompt,
288 temperature: self.temperature,
289 additional_params: self.additional_params,
290 }
291 }
292
293 /// Sends the transcription request to the transcription model provider and returns the transcription response
294 pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
295 let model = self.model.clone();
296 model.transcription(self.build()).await
297 }
298}