rig/transcription.rs
1//! This module provides functionality for working with audio transcription models.
2//! It provides traits, structs, and enums for generating audio transcription requests,
3//! handling transcription responses, and defining transcription models.
4use crate::markers::{Missing, Provided};
5use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
6use crate::{http_client, json_utils};
7use std::io;
8use std::{fs, path::Path};
9use thiserror::Error;
10
11// Errors
12#[derive(Debug, Error)]
13#[non_exhaustive]
14pub enum TranscriptionError {
15 /// Http error (e.g.: connection error, timeout, etc.)
16 #[error("HttpError: {0}")]
17 HttpError(#[from] http_client::Error),
18
19 /// Json error (e.g.: serialization, deserialization)
20 #[error("JsonError: {0}")]
21 JsonError(#[from] serde_json::Error),
22
23 #[cfg(not(target_family = "wasm"))]
24 /// Error building the transcription request
25 #[error("RequestError: {0}")]
26 RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
27
28 #[cfg(target_family = "wasm")]
29 /// Error building the transcription request
30 #[error("RequestError: {0}")]
31 RequestError(#[from] Box<dyn std::error::Error + 'static>),
32
33 /// Error parsing the transcription response
34 #[error("ResponseError: {0}")]
35 ResponseError(String),
36
37 /// Error returned by the transcription model provider
38 #[error("ProviderError: {0}")]
39 ProviderError(String),
40}
41
42/// Trait defining a low-level LLM transcription interface
43pub trait Transcription<M>
44where
45 M: TranscriptionModel,
46{
47 /// Generates a transcription request builder for the given `file`.
48 /// This function is meant to be called by the user to further customize the
49 /// request at transcription time before sending it.
50 ///
51 /// ❗IMPORTANT: The type that implements this trait might have already
52 /// populated fields in the builder (the exact fields depend on the type).
53 /// For fields that have already been set by the model, calling the corresponding
54 /// method on the builder will overwrite the value set by the model.
55 fn transcription(
56 &self,
57 filename: &str,
58 data: &[u8],
59 ) -> impl std::future::Future<
60 Output = Result<TranscriptionRequestBuilder<M, Provided<Vec<u8>>>, TranscriptionError>,
61 > + WasmCompatSend;
62}
63
/// General transcription response struct that contains the transcription text
/// and the raw response.
///
/// `Debug` and `Clone` are available whenever the raw response type `T` implements them.
#[derive(Debug, Clone)]
pub struct TranscriptionResponse<T> {
    /// The transcribed text.
    pub text: String,
    /// The raw, provider-specific response the text was taken from.
    pub response: T,
}
70
71/// Trait defining a transcription model that can be used to generate transcription requests.
72/// This trait is meant to be implemented by the user to define a custom transcription model,
73/// either from a third-party provider (e.g: OpenAI) or a local model.
74pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync {
75 /// The raw response type returned by the underlying model.
76 type Response: WasmCompatSend + WasmCompatSync;
77 type Client;
78
79 fn make(client: &Self::Client, model: impl Into<String>) -> Self;
80
81 /// Generates a completion response for the given transcription model
82 fn transcription(
83 &self,
84 request: TranscriptionRequest,
85 ) -> impl std::future::Future<
86 Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
87 > + WasmCompatSend;
88
89 /// Generates a transcription request builder for the given `file`
90 fn transcription_request(&self) -> TranscriptionRequestBuilder<Self, Missing> {
91 TranscriptionRequestBuilder::new(self.clone())
92 }
93}
94/// Struct representing a general transcription request that can be sent to a transcription model provider.
95pub struct TranscriptionRequest {
96 /// The file data to be sent to the transcription model provider
97 pub data: Vec<u8>,
98 /// The file name to be used in the request
99 pub filename: String,
100 /// The language used in the response from the transcription model provider
101 pub language: Option<String>,
102 /// The prompt to be sent to the transcription model provider
103 pub prompt: Option<String>,
104 /// The temperature sent to the transcription model provider
105 pub temperature: Option<f64>,
106 /// Additional parameters to be sent to the transcription model provider
107 pub additional_params: Option<serde_json::Value>,
108}
109
110/// Builder struct for a transcription request
111///
112/// Example usage:
113/// ```rust
114/// use rig::{
115/// providers::openai::{Client, self},
116/// transcription::TranscriptionRequestBuilder,
117/// };
118///
119/// let openai = Client::new("your-openai-api-key");
120/// let model = openai.transcription_model(openai::WHISPER_1).build();
121///
122/// // Create the completion request and execute it separately
123/// let request = TranscriptionRequestBuilder::new(model, "~/audio.mp3".to_string())
124/// .temperature(0.5)
125/// .build();
126///
127/// let response = model.transcription(request)
128/// .await
129/// .expect("Failed to get transcription response");
130/// ```
131///
132/// Alternatively, you can execute the transcription request directly from the builder:
133/// ```rust
134/// use rig::{
135/// providers::openai::{Client, self},
136/// transcription::TranscriptionRequestBuilder,
137/// };
138///
139/// let openai = Client::new("your-openai-api-key");
140/// let model = openai.transcription_model(openai::WHISPER_1).build();
141///
142/// // Create the completion request and execute it directly
143/// let response = TranscriptionRequestBuilder::new(model, "~/audio.mp3".to_string())
144/// .temperature(0.5)
145/// .send()
146/// .await
147/// .expect("Failed to get transcription response");
148/// ```
149///
150/// Note: It is usually unnecessary to create a completion request builder directly.
151/// Instead, use the [TranscriptionModel::transcription_request] method.
152pub struct TranscriptionRequestBuilder<M, D>
153where
154 M: TranscriptionModel,
155{
156 model: M,
157 data: D, // starts Missing, becomes Provided<Vec<u8>> after data is set or load_file is called
158 filename: Option<String>,
159 language: Option<String>,
160 prompt: Option<String>,
161 temperature: Option<f64>,
162 additional_params: Option<serde_json::Value>,
163}
164
165impl<M> TranscriptionRequestBuilder<M, Missing>
166where
167 M: TranscriptionModel,
168{
169 pub fn new(model: M) -> Self {
170 TranscriptionRequestBuilder {
171 model,
172 data: Missing,
173 filename: None,
174 language: None,
175 prompt: None,
176 temperature: None,
177 additional_params: None,
178 }
179 }
180}
181
182impl<M, D> TranscriptionRequestBuilder<M, D>
183where
184 M: TranscriptionModel,
185{
186 pub fn filename(mut self, filename: Option<String>) -> Self {
187 self.filename = filename;
188 self
189 }
190
191 /// Sets the data for the request and transitions the builder to the next state where data is provided.
192 pub fn data(self, data: Vec<u8>) -> TranscriptionRequestBuilder<M, Provided<Vec<u8>>> {
193 TranscriptionRequestBuilder {
194 model: self.model,
195 data: Provided(data),
196 filename: self.filename,
197 language: self.language,
198 prompt: self.prompt,
199 temperature: self.temperature,
200 additional_params: self.additional_params,
201 }
202 }
203
204 /// Load the specified file into data and transitions the builder to the next state where data is provided.
205 pub fn load_file<P>(
206 self,
207 path: P,
208 ) -> io::Result<TranscriptionRequestBuilder<M, Provided<Vec<u8>>>>
209 where
210 P: AsRef<Path>,
211 {
212 let path = path.as_ref();
213 let data = fs::read(path)?;
214
215 let filename = path.file_name().map(|n| n.to_string_lossy().into_owned());
216
217 Ok(TranscriptionRequestBuilder {
218 model: self.model,
219 data: Provided(data),
220 filename: filename.or(self.filename),
221 language: self.language,
222 prompt: self.prompt,
223 temperature: self.temperature,
224 additional_params: self.additional_params,
225 })
226 }
227
228 /// Sets the output language for the transcription request
229 pub fn language(mut self, language: String) -> Self {
230 self.language = Some(language);
231 self
232 }
233
234 /// Sets the prompt to be sent in the transcription request
235 pub fn prompt(mut self, prompt: String) -> Self {
236 self.prompt = Some(prompt);
237 self
238 }
239
240 /// Set the temperature to be sent in the transcription request
241 pub fn temperature(mut self, temperature: f64) -> Self {
242 self.temperature = Some(temperature);
243 self
244 }
245
246 /// Adds additional parameters to the transcription request.
247 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
248 match self.additional_params {
249 Some(params) => {
250 self.additional_params = Some(json_utils::merge(params, additional_params));
251 }
252 None => {
253 self.additional_params = Some(additional_params);
254 }
255 }
256 self
257 }
258
259 /// Sets the additional parameters for the transcription request.
260 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
261 self.additional_params = additional_params;
262 self
263 }
264}
265
266/// The build and send methods are only available when data is provided, ensuring that the request cannot be sent without the required data.
267impl<M> TranscriptionRequestBuilder<M, Provided<Vec<u8>>>
268where
269 M: TranscriptionModel,
270{
271 /// Builds the transcription request
272 /// Panics if data is empty.
273 pub fn build(self) -> TranscriptionRequest {
274 TranscriptionRequest {
275 data: self.data.0,
276 filename: self.filename.unwrap_or("file".to_string()),
277 language: self.language,
278 prompt: self.prompt,
279 temperature: self.temperature,
280 additional_params: self.additional_params,
281 }
282 }
283
284 /// Sends the transcription request to the transcription model provider and returns the transcription response
285 pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
286 let model = self.model.clone();
287 model.transcription(self.build()).await
288 }
289}