1#[allow(deprecated)]
5use crate::client::transcription::TranscriptionModelHandle;
6use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
7use crate::{http_client, json_utils};
8use std::sync::Arc;
9use std::{fs, path::Path};
10use thiserror::Error;
11
12#[derive(Debug, Error)]
14#[non_exhaustive]
15pub enum TranscriptionError {
16 #[error("HttpError: {0}")]
18 HttpError(#[from] http_client::Error),
19
20 #[error("JsonError: {0}")]
22 JsonError(#[from] serde_json::Error),
23
24 #[cfg(not(target_family = "wasm"))]
25 #[error("RequestError: {0}")]
27 RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
28
29 #[cfg(target_family = "wasm")]
30 #[error("RequestError: {0}")]
32 RequestError(#[from] Box<dyn std::error::Error + 'static>),
33
34 #[error("ResponseError: {0}")]
36 ResponseError(String),
37
38 #[error("ProviderError: {0}")]
40 ProviderError(String),
41}
42
43pub trait Transcription<M>
45where
46 M: TranscriptionModel,
47{
48 fn transcription(
57 &self,
58 filename: &str,
59 data: &[u8],
60 ) -> impl std::future::Future<
61 Output = Result<TranscriptionRequestBuilder<M>, TranscriptionError>,
62 > + WasmCompatSend;
63}
64
/// Result of a transcription call: the transcribed text plus a response value
/// of type `T`.
///
/// [`TranscriptionModel::transcription`] fills `T` with the model's
/// `Response` associated type.
pub struct TranscriptionResponse<T> {
    /// The transcribed text.
    pub text: String,
    /// The accompanying response value.
    pub response: T,
}
71
72pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync {
76 type Response: WasmCompatSend + WasmCompatSync;
78 type Client;
79
80 fn make(client: &Self::Client, model: impl Into<String>) -> Self;
81
82 fn transcription(
84 &self,
85 request: TranscriptionRequest,
86 ) -> impl std::future::Future<
87 Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
88 > + WasmCompatSend;
89
90 fn transcription_request(&self) -> TranscriptionRequestBuilder<Self> {
92 TranscriptionRequestBuilder::new(self.clone())
93 }
94}
95
96#[allow(deprecated)]
97#[deprecated(
98 since = "0.25.0",
99 note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `TranscriptionModel` instead."
100)]
101pub trait TranscriptionModelDyn: WasmCompatSend + WasmCompatSync {
102 fn transcription(
103 &self,
104 request: TranscriptionRequest,
105 ) -> WasmBoxedFuture<'_, Result<TranscriptionResponse<()>, TranscriptionError>>;
106
107 fn transcription_request(&self) -> TranscriptionRequestBuilder<TranscriptionModelHandle<'_>>;
108}
109
110#[allow(deprecated)]
111impl<T> TranscriptionModelDyn for T
112where
113 T: TranscriptionModel,
114{
115 fn transcription(
116 &self,
117 request: TranscriptionRequest,
118 ) -> WasmBoxedFuture<'_, Result<TranscriptionResponse<()>, TranscriptionError>> {
119 Box::pin(async move {
120 let resp = self.transcription(request).await?;
121
122 Ok(TranscriptionResponse {
123 text: resp.text,
124 response: (),
125 })
126 })
127 }
128
129 fn transcription_request(&self) -> TranscriptionRequestBuilder<TranscriptionModelHandle<'_>> {
130 TranscriptionRequestBuilder::new(TranscriptionModelHandle {
131 inner: Arc::new(self.clone()),
132 })
133 }
134}
135
136pub struct TranscriptionRequest {
138 pub data: Vec<u8>,
140 pub filename: String,
142 pub language: Option<String>,
144 pub prompt: Option<String>,
146 pub temperature: Option<f64>,
148 pub additional_params: Option<serde_json::Value>,
150}
151
152pub struct TranscriptionRequestBuilder<M>
195where
196 M: TranscriptionModel,
197{
198 model: M,
199 data: Vec<u8>,
200 filename: Option<String>,
201 language: Option<String>,
202 prompt: Option<String>,
203 temperature: Option<f64>,
204 additional_params: Option<serde_json::Value>,
205}
206
207impl<M> TranscriptionRequestBuilder<M>
208where
209 M: TranscriptionModel,
210{
211 pub fn new(model: M) -> Self {
212 TranscriptionRequestBuilder {
213 model,
214 data: vec![],
215 filename: None,
216 language: None,
217 prompt: None,
218 temperature: None,
219 additional_params: None,
220 }
221 }
222
223 pub fn filename(mut self, filename: Option<String>) -> Self {
224 self.filename = filename;
225 self
226 }
227
228 pub fn data(mut self, data: Vec<u8>) -> Self {
230 self.data = data;
231 self
232 }
233
234 pub fn load_file<P>(self, path: P) -> Self
236 where
237 P: AsRef<Path>,
238 {
239 let path = path.as_ref();
240 let data = fs::read(path).expect("Failed to load audio file, file did not exist");
241
242 self.filename(Some(
243 path.file_name()
244 .expect("Path was not a file")
245 .to_str()
246 .expect("Failed to convert filename to ascii")
247 .to_string(),
248 ))
249 .data(data)
250 }
251
252 pub fn language(mut self, language: String) -> Self {
254 self.language = Some(language);
255 self
256 }
257
258 pub fn prompt(mut self, prompt: String) -> Self {
260 self.prompt = Some(prompt);
261 self
262 }
263
264 pub fn temperature(mut self, temperature: f64) -> Self {
266 self.temperature = Some(temperature);
267 self
268 }
269
270 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
272 match self.additional_params {
273 Some(params) => {
274 self.additional_params = Some(json_utils::merge(params, additional_params));
275 }
276 None => {
277 self.additional_params = Some(additional_params);
278 }
279 }
280 self
281 }
282
283 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
285 self.additional_params = additional_params;
286 self
287 }
288
289 pub fn build(self) -> TranscriptionRequest {
292 if self.data.is_empty() {
293 panic!("Data cannot be empty!")
294 }
295
296 TranscriptionRequest {
297 data: self.data,
298 filename: self.filename.unwrap_or("file".to_string()),
299 language: self.language,
300 prompt: self.prompt,
301 temperature: self.temperature,
302 additional_params: self.additional_params,
303 }
304 }
305
306 pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
308 let model = self.model.clone();
309
310 model.transcription(self.build()).await
311 }
312}