1use crate::client::transcription::TranscriptionModelHandle;
6use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
7use crate::{http_client, json_utils};
8use std::sync::Arc;
9use std::{fs, path::Path};
10use thiserror::Error;
11
12#[derive(Debug, Error)]
14#[non_exhaustive]
15pub enum TranscriptionError {
16 #[error("HttpError: {0}")]
18 HttpError(#[from] http_client::Error),
19
20 #[error("JsonError: {0}")]
22 JsonError(#[from] serde_json::Error),
23
24 #[cfg(not(target_family = "wasm"))]
25 #[error("RequestError: {0}")]
27 RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
28
29 #[cfg(target_family = "wasm")]
30 #[error("RequestError: {0}")]
32 RequestError(#[from] Box<dyn std::error::Error + 'static>),
33
34 #[error("ResponseError: {0}")]
36 ResponseError(String),
37
38 #[error("ProviderError: {0}")]
40 ProviderError(String),
41}
42
43pub trait Transcription<M>
45where
46 M: TranscriptionModel,
47{
48 fn transcription(
57 &self,
58 filename: &str,
59 data: &[u8],
60 ) -> impl std::future::Future<
61 Output = Result<TranscriptionRequestBuilder<M>, TranscriptionError>,
62 > + WasmCompatSend;
63}
64
/// Result of a transcription call: the transcribed text plus a
/// caller-chosen payload of type `T`.
pub struct TranscriptionResponse<T> {
    /// The transcribed text.
    pub text: String,
    /// Additional response payload; its meaning is determined by whoever
    /// picks `T` (the type-erased API in this module uses `()`).
    pub response: T,
}
71
72pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync {
76 type Response: WasmCompatSend + WasmCompatSync;
78
79 fn transcription(
81 &self,
82 request: TranscriptionRequest,
83 ) -> impl std::future::Future<
84 Output = Result<TranscriptionResponse<Self::Response>, TranscriptionError>,
85 > + WasmCompatSend;
86
87 fn transcription_request(&self) -> TranscriptionRequestBuilder<Self> {
89 TranscriptionRequestBuilder::new(self.clone())
90 }
91}
92
93pub trait TranscriptionModelDyn: WasmCompatSend + WasmCompatSync {
94 fn transcription(
95 &self,
96 request: TranscriptionRequest,
97 ) -> WasmBoxedFuture<'_, Result<TranscriptionResponse<()>, TranscriptionError>>;
98
99 fn transcription_request(&self) -> TranscriptionRequestBuilder<TranscriptionModelHandle<'_>>;
100}
101
102impl<T> TranscriptionModelDyn for T
103where
104 T: TranscriptionModel,
105{
106 fn transcription(
107 &self,
108 request: TranscriptionRequest,
109 ) -> WasmBoxedFuture<'_, Result<TranscriptionResponse<()>, TranscriptionError>> {
110 Box::pin(async move {
111 let resp = self.transcription(request).await?;
112
113 Ok(TranscriptionResponse {
114 text: resp.text,
115 response: (),
116 })
117 })
118 }
119
120 fn transcription_request(&self) -> TranscriptionRequestBuilder<TranscriptionModelHandle<'_>> {
121 TranscriptionRequestBuilder::new(TranscriptionModelHandle {
122 inner: Arc::new(self.clone()),
123 })
124 }
125}
126
127pub struct TranscriptionRequest {
129 pub data: Vec<u8>,
131 pub filename: String,
133 pub language: String,
135 pub prompt: Option<String>,
137 pub temperature: Option<f64>,
139 pub additional_params: Option<serde_json::Value>,
141}
142
143pub struct TranscriptionRequestBuilder<M>
186where
187 M: TranscriptionModel,
188{
189 model: M,
190 data: Vec<u8>,
191 filename: Option<String>,
192 language: String,
193 prompt: Option<String>,
194 temperature: Option<f64>,
195 additional_params: Option<serde_json::Value>,
196}
197
198impl<M> TranscriptionRequestBuilder<M>
199where
200 M: TranscriptionModel,
201{
202 pub fn new(model: M) -> Self {
203 TranscriptionRequestBuilder {
204 model,
205 data: vec![],
206 filename: None,
207 language: "en".to_string(),
208 prompt: None,
209 temperature: None,
210 additional_params: None,
211 }
212 }
213
214 pub fn filename(mut self, filename: Option<String>) -> Self {
215 self.filename = filename;
216 self
217 }
218
219 pub fn data(mut self, data: Vec<u8>) -> Self {
221 self.data = data;
222 self
223 }
224
225 pub fn load_file<P>(self, path: P) -> Self
227 where
228 P: AsRef<Path>,
229 {
230 let path = path.as_ref();
231 let data = fs::read(path).expect("Failed to load audio file, file did not exist");
232
233 self.filename(Some(
234 path.file_name()
235 .expect("Path was not a file")
236 .to_str()
237 .expect("Failed to convert filename to ascii")
238 .to_string(),
239 ))
240 .data(data)
241 }
242
243 pub fn language(mut self, language: String) -> Self {
245 self.language = language;
246 self
247 }
248
249 pub fn prompt(mut self, prompt: String) -> Self {
251 self.prompt = Some(prompt);
252 self
253 }
254
255 pub fn temperature(mut self, temperature: f64) -> Self {
257 self.temperature = Some(temperature);
258 self
259 }
260
261 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
263 match self.additional_params {
264 Some(params) => {
265 self.additional_params = Some(json_utils::merge(params, additional_params));
266 }
267 None => {
268 self.additional_params = Some(additional_params);
269 }
270 }
271 self
272 }
273
274 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
276 self.additional_params = additional_params;
277 self
278 }
279
280 pub fn build(self) -> TranscriptionRequest {
283 if self.data.is_empty() {
284 panic!("Data cannot be empty!")
285 }
286
287 TranscriptionRequest {
288 data: self.data,
289 filename: self.filename.unwrap_or("file".to_string()),
290 language: self.language,
291 prompt: self.prompt,
292 temperature: self.temperature,
293 additional_params: self.additional_params,
294 }
295 }
296
297 pub async fn send(self) -> Result<TranscriptionResponse<M::Response>, TranscriptionError> {
299 let model = self.model.clone();
300
301 model.transcription(self.build()).await
302 }
303}