llmservice_flows/
audio.rs

1use reqwest::multipart;
2use serde::Deserialize;
3
4use crate::LLMApi;
5use crate::Retry;
6
/// Request parameters for [`crate::LLMServiceFlows::transcribe`].
///
/// Every field is forwarded as a multipart form field to the service's
/// `/audio/transcriptions` endpoint; the `Option` fields are only sent when `Some`.
pub struct TranscribeInput {
    /// Raw bytes of the audio file, uploaded as the `file` part.
    pub audio: Vec<u8>,
    /// Audio file extension such as `"wav"`; the upload is named `audio.<audio_format>`.
    pub audio_format: String,
    /// Language code sent as the `language` field (e.g. `"en"`).
    pub language: String,
    /// Sent as the `max_len` field when set.
    /// NOTE(review): meaning (presumably a max segment length) is defined by the server — confirm.
    pub max_len: Option<u64>,
    /// Sent as the `max_context` field when set.
    /// NOTE(review): semantics (e.g. of `-1`) are defined by the server — confirm.
    pub max_context: Option<i32>,
    /// Sent as the `split_on_word` field (`"true"`/`"false"`) when set.
    pub split_on_word: Option<bool>,
}

16impl LLMApi for TranscribeInput {
17    type Output = TranscriptionOutput;
18    async fn api(&self, endpoint: &str, api_key: &str) -> Retry<Self::Output> {
19        transcribe_inner(endpoint, api_key, &self).await
20    }
21}
22
23#[derive(Debug, Deserialize)]
24pub struct TranscriptionOutput {
25    pub text: String,
26}
/// Request parameters for [`crate::LLMServiceFlows::translate`].
///
/// Every field is forwarded as a multipart form field to the service's
/// `/audio/translations` endpoint; the `Option` fields are only sent when `Some`.
pub struct TranslateInput {
    /// Raw bytes of the audio file, uploaded as the `file` part.
    pub audio: Vec<u8>,
    /// Audio file extension such as `"wav"`; the upload is named `audio.<audio_format>`.
    pub audio_format: String,
    /// Language code sent as the `language` field (the doc example uses `"zh"`);
    /// presumably the language spoken in the audio — confirm against the server.
    pub language: String,
    /// Sent as the `max_len` field when set.
    /// NOTE(review): meaning (presumably a max segment length) is defined by the server — confirm.
    pub max_len: Option<u64>,
    /// Sent as the `max_context` field when set.
    /// NOTE(review): semantics (e.g. of `-1`) are defined by the server — confirm.
    pub max_context: Option<i32>,
    /// Sent as the `split_on_word` field (`"true"`/`"false"`) when set.
    pub split_on_word: Option<bool>,
}

36impl LLMApi for TranslateInput {
37    type Output = TranslationOutput;
38    async fn api(&self, endpoint: &str, api_key: &str) -> Retry<Self::Output> {
39        translate_inner(endpoint, api_key, &self).await
40    }
41}
42
43#[derive(Debug, Deserialize)]
44pub struct TranslationOutput {
45    pub text: String,
46}
47
48impl<'a> crate::LLMServiceFlows<'a> {
49    /// Transcribe audio into the input language.
50    ///
51    /// `input` is an [TranscribeInput] object.
52    ///
53    ///```rust
54    ///   // This code snippet transcribe input audio into English, the audio is collected in previous step.
55    ///   // Prepare the TranscribeInput struct.
56    ///   let input = TranscribeInput {
57    ///      audio: audio,
58    ///      audio_format: "wav".to_string(),
59    ///      language: "en".to_string(),
60    ///      max_len: Some(0),
61    ///      max_context: Some(-1),
62    ///      split_on_word: Some(false),
63    ///   };
64    ///   // Call the transcribe function.
65    ///   let transcription = match llm.transcribe(input).await {
66    ///       Ok(r) => r.text,
67    ///       Err(e) => {your error handling},
68    ///   };
69    /// ```
70    pub async fn transcribe(&self, input: TranscribeInput) -> Result<TranscriptionOutput, String> {
71        self.keep_trying(input).await
72    }
73
74    /// Translate audio into English.
75    ///
76    /// `input` is an [TranslateInput] object.
77    ///
78    ///```rust
79    ///   // This code snippet translate input audio into English, the audio is collected in previous step.
80    ///   // Prepare the TranslateInput struct.
81    ///   let input = TranslateInput {
82    ///      audio: audio,
83    ///      audio_format: "wav".to_string(),
84    ///      language: "zh".to_string(),
85    ///      max_len: Some(0),
86    ///      max_context: Some(-1),
87    ///      split_on_word: Some(false),
88    ///   };
89    ///   // Call the translate function.
90    ///   let translation = match llm.translate(input).await {
91    ///       Ok(r) => r.text,
92    ///       Err(e) => {your error handling},
93    ///   };
94    /// ```
95    pub async fn translate(&self, input: TranslateInput) -> Result<TranslationOutput, String> {
96        self.keep_trying(input).await
97    }
98}
99
100async fn transcribe_inner(
101    endpoint: &str,
102    _api_key: &str,
103    input: &TranscribeInput,
104) -> Retry<TranscriptionOutput> {
105    let uri = format!("{}/audio/transcriptions", endpoint);
106
107    let mut form = multipart::Form::new()
108        .part(
109            "file",
110            multipart::Part::bytes(input.audio.clone())
111                .file_name(format!("audio.{}", input.audio_format)),
112        )
113        .part("language", multipart::Part::text(input.language.clone()));
114    if input.max_len.is_some() {
115        form = form.part(
116            "max_len",
117            multipart::Part::text(input.max_len.unwrap().to_string()),
118        );
119    }
120    if input.max_context.is_some() {
121        form = form.part(
122            "max_context",
123            multipart::Part::text(input.max_context.unwrap().to_string()),
124        );
125    }
126    if input.split_on_word.is_some() {
127        form = form.part(
128            "split_on_word",
129            multipart::Part::text(input.split_on_word.unwrap().to_string()),
130        );
131    }
132
133    match reqwest::Client::new()
134        .post(uri)
135        .multipart(form)
136        .send()
137        .await
138    {
139        Ok(res) => {
140            let status = res.status();
141            let body = res.bytes().await.unwrap();
142            match status.is_success() {
143                true => Retry::No(
144                    serde_json::from_slice::<TranscriptionOutput>(&body.as_ref())
145                        .or(Err(String::from("Unexpected error"))),
146                ),
147                false => {
148                    match status.into() {
149                        409 | 429 | 503 => {
150                            // 409 TryAgain 429 RateLimitError
151                            // 503 ServiceUnavailable
152                            Retry::Yes(String::from_utf8_lossy(&body.as_ref()).into_owned())
153                        }
154                        _ => Retry::No(Err(String::from_utf8_lossy(&body.as_ref()).into_owned())),
155                    }
156                }
157            }
158        }
159        Err(e) => Retry::No(Err(e.to_string())),
160    }
161}
162
163async fn translate_inner(
164    endpoint: &str,
165    _api_key: &str,
166    input: &TranslateInput,
167) -> Retry<TranslationOutput> {
168    let uri = format!("{}/audio/translations", endpoint);
169
170    let mut form = multipart::Form::new()
171        .part(
172            "file",
173            multipart::Part::bytes(input.audio.clone())
174                .file_name(format!("audio.{}", input.audio_format)),
175        )
176        .part("language", multipart::Part::text(input.language.clone()));
177    if input.max_len.is_some() {
178        form = form.part(
179            "max_len",
180            multipart::Part::text(input.max_len.unwrap().to_string()),
181        );
182    }
183    if input.max_context.is_some() {
184        form = form.part(
185            "max_context",
186            multipart::Part::text(input.max_context.unwrap().to_string()),
187        );
188    }
189    if input.split_on_word.is_some() {
190        form = form.part(
191            "split_on_word",
192            multipart::Part::text(input.split_on_word.unwrap().to_string()),
193        );
194    }
195
196    match reqwest::Client::new()
197        .post(uri)
198        .multipart(form)
199        .send()
200        .await
201    {
202        Ok(res) => {
203            let status = res.status();
204            let body = res.bytes().await.unwrap();
205            match status.is_success() {
206                true => Retry::No(
207                    serde_json::from_slice::<TranslationOutput>(&body.as_ref())
208                        .or(Err(String::from("Unexpected error"))),
209                ),
210                false => {
211                    match status.into() {
212                        409 | 429 | 503 => {
213                            // 409 TryAgain 429 RateLimitError
214                            // 503 ServiceUnavailable
215                            Retry::Yes(String::from_utf8_lossy(&body.as_ref()).into_owned())
216                        }
217                        _ => Retry::No(Err(String::from_utf8_lossy(&body.as_ref()).into_owned())),
218                    }
219                }
220            }
221        }
222        Err(e) => Retry::No(Err(e.to_string())),
223    }
224}