1use crate::PluginError;
8
9use super::cloud_stt::CloudSttProvider;
10use super::cloud_tts::CloudTtsProvider;
11
/// Confidence threshold installed by [`SttFallbackChain::new`].
///
/// A local transcription whose confidence is greater than or equal to the chain's
/// threshold is returned directly; anything lower makes the chain consult its cloud
/// provider, if one is configured.
const DEFAULT_CONFIDENCE_THRESHOLD: f32 = 0.60;
15
/// Local speech-to-text engine; [`SttFallbackChain`] always tries it before any cloud
/// provider.
#[async_trait::async_trait]
pub trait LocalSttEngine: Send + Sync {
    /// Transcribes `audio_data` and reports the text together with a confidence score.
    ///
    /// `language` is passed through unchanged from the caller of
    /// [`SttFallbackChain::transcribe`]. Note that the chain does not forward the
    /// audio MIME type to the local engine, only to the cloud provider.
    // NOTE(review): `language` is presumably a language-code hint (the chain reports
    // "en" when it is `None`) — confirm the expected format with implementors.
    ///
    /// # Errors
    /// Returns a [`PluginError`] when transcription fails; the chain then falls back
    /// to its cloud provider if one is configured.
    async fn transcribe(
        &self,
        audio_data: &[u8],
        language: Option<&str>,
    ) -> Result<LocalSttResult, PluginError>;
}
30
/// Outcome of a successful [`LocalSttEngine::transcribe`] call.
#[derive(Debug, Clone)]
pub struct LocalSttResult {
    /// Transcribed text.
    pub text: String,
    /// Engine-reported confidence; [`SttFallbackChain`] compares it against its
    /// threshold and against the cloud provider's confidence.
    // NOTE(review): values look like they are meant to lie in 0.0..=1.0 (default
    // threshold is 0.60) — nothing here enforces that range; confirm.
    pub confidence: f32,
}
39
/// Local text-to-speech engine; [`TtsFallbackChain`] always tries it before any cloud
/// provider.
#[async_trait::async_trait]
pub trait LocalTtsEngine: Send + Sync {
    /// Synthesizes speech for `text`.
    ///
    /// On success returns `(audio, mime_type)`: the chain stores the first element as
    /// `TtsFallbackResult::audio_data` and the second as `TtsFallbackResult::mime_type`.
    ///
    /// # Errors
    /// Returns a [`PluginError`] when synthesis fails; the chain then falls back to
    /// its cloud provider if one is configured.
    async fn synthesize(&self, text: &str) -> Result<(Vec<u8>, String), PluginError>;
}
46
/// Identifies which engine produced the transcription in an [`SttFallbackResult`].
///
/// Implements `PartialEq`/`Eq` so callers and tests can compare sources directly
/// (for example with `assert_eq!`) instead of pattern matching.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SttSource {
    /// The local engine's transcription was used.
    Local,
    /// A cloud provider's transcription was used; carries the provider's name.
    Cloud(String),
}
60
/// Transcription returned by [`SttFallbackChain::transcribe`].
#[derive(Debug, Clone)]
pub struct SttFallbackResult {
    /// Transcribed text from the engine named in `source`.
    pub text: String,
    /// Confidence reported by the engine named in `source`.
    pub confidence: f32,
    /// Which engine produced `text`.
    pub source: SttSource,
    /// For cloud results, the language reported by the provider. For local results,
    /// the language requested by the caller, or `"en"` when none was given.
    pub language: String,
}
73
/// Speech-to-text pipeline that tries a local engine first and consults an optional
/// cloud provider when the local engine fails or reports low confidence.
pub struct SttFallbackChain {
    /// Engine that is always tried first.
    local: Box<dyn LocalSttEngine>,
    /// Optional cloud provider; `None` until `with_cloud` is called.
    cloud: Option<Box<dyn CloudSttProvider>>,
    /// Minimum local confidence (inclusive) that is accepted without asking the cloud.
    confidence_threshold: f32,
}
88
89impl SttFallbackChain {
90 pub fn new(local: Box<dyn LocalSttEngine>) -> Self {
92 Self {
93 local,
94 cloud: None,
95 confidence_threshold: DEFAULT_CONFIDENCE_THRESHOLD,
96 }
97 }
98
99 pub fn with_cloud(mut self, provider: Box<dyn CloudSttProvider>) -> Self {
101 self.cloud = Some(provider);
102 self
103 }
104
105 pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
107 self.confidence_threshold = threshold;
108 self
109 }
110
111 pub async fn transcribe(
114 &self,
115 audio_data: &[u8],
116 mime_type: &str,
117 language: Option<&str>,
118 ) -> Result<SttFallbackResult, PluginError> {
119 match self.local.transcribe(audio_data, language).await {
120 Ok(local_result) if local_result.confidence >= self.confidence_threshold => {
121 Ok(SttFallbackResult {
123 text: local_result.text,
124 confidence: local_result.confidence,
125 source: SttSource::Local,
126 language: language.unwrap_or("en").to_string(),
127 })
128 }
129 Ok(low_confidence) => {
130 if let Some(cloud) = &self.cloud {
132 match cloud.transcribe(audio_data, mime_type, language).await {
133 Ok(cloud_result) if cloud_result.confidence > low_confidence.confidence => {
134 Ok(SttFallbackResult {
135 text: cloud_result.text,
136 confidence: cloud_result.confidence,
137 source: SttSource::Cloud(cloud.name().to_string()),
138 language: cloud_result.language,
139 })
140 }
141 Ok(_) => {
142 Ok(SttFallbackResult {
144 text: low_confidence.text,
145 confidence: low_confidence.confidence,
146 source: SttSource::Local,
147 language: language.unwrap_or("en").to_string(),
148 })
149 }
150 Err(_) => {
151 Ok(SttFallbackResult {
153 text: low_confidence.text,
154 confidence: low_confidence.confidence,
155 source: SttSource::Local,
156 language: language.unwrap_or("en").to_string(),
157 })
158 }
159 }
160 } else {
161 Ok(SttFallbackResult {
163 text: low_confidence.text,
164 confidence: low_confidence.confidence,
165 source: SttSource::Local,
166 language: language.unwrap_or("en").to_string(),
167 })
168 }
169 }
170 Err(local_err) => {
171 if let Some(cloud) = &self.cloud {
173 let cloud_result = cloud.transcribe(audio_data, mime_type, language).await?;
174 Ok(SttFallbackResult {
175 text: cloud_result.text,
176 confidence: cloud_result.confidence,
177 source: SttSource::Cloud(cloud.name().to_string()),
178 language: cloud_result.language,
179 })
180 } else {
181 Err(local_err)
182 }
183 }
184 }
185 }
186}
187
/// Identifies which engine produced the audio in a [`TtsFallbackResult`].
///
/// Implements `PartialEq`/`Eq` so callers and tests can compare sources directly
/// (for example with `assert_eq!`) instead of pattern matching.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TtsSource {
    /// The local engine's audio was used.
    Local,
    /// A cloud provider's audio was used; carries the provider's name.
    Cloud(String),
}
201
/// Synthesized speech returned by [`TtsFallbackChain::synthesize`].
#[derive(Debug, Clone)]
pub struct TtsFallbackResult {
    /// Audio bytes produced by the engine named in `source`.
    pub audio_data: Vec<u8>,
    /// MIME type of `audio_data`, as reported by that engine.
    pub mime_type: String,
    /// Which engine produced the audio.
    pub source: TtsSource,
}
212
/// Text-to-speech pipeline that tries a local engine first and falls back to an
/// optional cloud provider when the local engine fails.
pub struct TtsFallbackChain {
    /// Engine that is always tried first.
    local: Box<dyn LocalTtsEngine>,
    /// Optional cloud provider; `None` until `with_cloud` is called.
    cloud: Option<Box<dyn CloudTtsProvider>>,
}
218
219impl TtsFallbackChain {
220 pub fn new(local: Box<dyn LocalTtsEngine>) -> Self {
222 Self { local, cloud: None }
223 }
224
225 pub fn with_cloud(mut self, provider: Box<dyn CloudTtsProvider>) -> Self {
227 self.cloud = Some(provider);
228 self
229 }
230
231 pub async fn synthesize(
233 &self,
234 text: &str,
235 voice_id: Option<&str>,
236 ) -> Result<TtsFallbackResult, PluginError> {
237 match self.local.synthesize(text).await {
238 Ok((audio, mime)) => Ok(TtsFallbackResult {
239 audio_data: audio,
240 mime_type: mime,
241 source: TtsSource::Local,
242 }),
243 Err(local_err) => {
244 if let Some(cloud) = &self.cloud {
245 let voice = voice_id.unwrap_or("alloy");
246 let result = cloud.synthesize(text, voice).await?;
247 Ok(TtsFallbackResult {
248 audio_data: result.audio_data,
249 mime_type: result.mime_type,
250 source: TtsSource::Cloud(cloud.name().to_string()),
251 })
252 } else {
253 Err(local_err)
254 }
255 }
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    // Unit tests for the STT and TTS fallback chains, driven by in-memory mocks.

    use super::super::cloud_stt::CloudSttResult;
    use super::super::cloud_tts::{CloudTtsResult, VoiceInfo};
    use super::*;

    /// Placeholder audio payload; the mocks never inspect it.
    const AUDIO: &[u8] = b"audio";
    /// Placeholder MIME type passed along with `AUDIO`.
    const MIME: &str = "audio/wav";

    /// Local STT mock returning a fixed result, or an error when `should_fail` is set.
    struct MockLocalStt {
        text: String,
        confidence: f32,
        should_fail: bool,
    }

    impl MockLocalStt {
        /// Mock that succeeds with the given text and confidence.
        fn ok(text: &str, confidence: f32) -> Box<Self> {
            Box::new(Self {
                text: text.into(),
                confidence,
                should_fail: false,
            })
        }

        /// Mock that always fails.
        fn failing() -> Box<Self> {
            Box::new(Self {
                text: "".into(),
                confidence: 0.0,
                should_fail: true,
            })
        }
    }

    #[async_trait::async_trait]
    impl LocalSttEngine for MockLocalStt {
        async fn transcribe(
            &self,
            _audio_data: &[u8],
            _language: Option<&str>,
        ) -> Result<LocalSttResult, PluginError> {
            if self.should_fail {
                return Err(PluginError::ExecutionFailed("local STT failed".into()));
            }
            Ok(LocalSttResult {
                text: self.text.clone(),
                confidence: self.confidence,
            })
        }
    }

    /// Cloud STT mock returning a fixed result, or an error when `should_fail` is set.
    struct MockCloudStt {
        text: String,
        confidence: f32,
        should_fail: bool,
    }

    impl MockCloudStt {
        /// Mock that succeeds with the given text and confidence.
        fn ok(text: &str, confidence: f32) -> Box<Self> {
            Box::new(Self {
                text: text.into(),
                confidence,
                should_fail: false,
            })
        }

        /// Mock that always fails.
        fn failing() -> Box<Self> {
            Box::new(Self {
                text: "".into(),
                confidence: 0.0,
                should_fail: true,
            })
        }
    }

    #[async_trait::async_trait]
    impl CloudSttProvider for MockCloudStt {
        fn name(&self) -> &str {
            "mock-cloud"
        }

        async fn transcribe(
            &self,
            _audio_data: &[u8],
            _mime_type: &str,
            _language: Option<&str>,
        ) -> Result<CloudSttResult, PluginError> {
            if self.should_fail {
                return Err(PluginError::ExecutionFailed("cloud STT failed".into()));
            }
            Ok(CloudSttResult {
                text: self.text.clone(),
                confidence: self.confidence,
                language: "en".into(),
                duration_ms: 1000,
            })
        }
    }

    /// Local TTS mock producing fixed WAV bytes, or an error when `should_fail` is set.
    struct MockLocalTts {
        should_fail: bool,
    }

    #[async_trait::async_trait]
    impl LocalTtsEngine for MockLocalTts {
        async fn synthesize(&self, _text: &str) -> Result<(Vec<u8>, String), PluginError> {
            if self.should_fail {
                return Err(PluginError::ExecutionFailed("local TTS failed".into()));
            }
            Ok((vec![1, 2, 3], "audio/wav".into()))
        }
    }

    /// Cloud TTS mock producing fixed MP3 bytes, or an error when `should_fail` is set.
    struct MockCloudTts {
        should_fail: bool,
    }

    #[async_trait::async_trait]
    impl CloudTtsProvider for MockCloudTts {
        fn name(&self) -> &str {
            "mock-cloud-tts"
        }

        fn available_voices(&self) -> Vec<VoiceInfo> {
            vec![]
        }

        async fn synthesize(
            &self,
            _text: &str,
            _voice_id: &str,
        ) -> Result<CloudTtsResult, PluginError> {
            if self.should_fail {
                return Err(PluginError::ExecutionFailed("cloud TTS failed".into()));
            }
            Ok(CloudTtsResult {
                audio_data: vec![4, 5, 6],
                mime_type: "audio/mp3".into(),
                duration_ms: Some(1000),
            })
        }
    }

    // ---- STT chain ----

    #[tokio::test]
    async fn stt_local_success_high_confidence() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("hello local", 0.90));

        let result = chain.transcribe(AUDIO, MIME, None).await.unwrap();
        assert_eq!(result.text, "hello local");
        assert!(matches!(result.source, SttSource::Local));
        assert!((result.confidence - 0.90).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn stt_local_low_confidence_cloud_fallback() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("helo lcal", 0.30))
            .with_cloud(MockCloudStt::ok("hello local", 0.95));

        let result = chain.transcribe(AUDIO, MIME, None).await.unwrap();
        assert_eq!(result.text, "hello local");
        assert!(matches!(&result.source, SttSource::Cloud(name) if name == "mock-cloud"));
        assert!((result.confidence - 0.95).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn stt_local_low_confidence_cloud_worse_keeps_local() {
        // Cloud confidence (0.40) is below the local one (0.50), so local is kept.
        let chain = SttFallbackChain::new(MockLocalStt::ok("helo", 0.50))
            .with_cloud(MockCloudStt::ok("hello", 0.40));

        let result = chain.transcribe(AUDIO, MIME, None).await.unwrap();
        assert_eq!(result.text, "helo");
        assert!(matches!(result.source, SttSource::Local));
    }

    #[tokio::test]
    async fn stt_cloud_equal_confidence_keeps_local() {
        // The cloud result must be strictly better to replace the local one.
        let chain = SttFallbackChain::new(MockLocalStt::ok("local", 0.50))
            .with_cloud(MockCloudStt::ok("cloud", 0.50));

        let result = chain.transcribe(AUDIO, MIME, None).await.unwrap();
        assert_eq!(result.text, "local");
        assert!(matches!(result.source, SttSource::Local));
    }

    #[tokio::test]
    async fn stt_local_error_cloud_fallback() {
        let chain = SttFallbackChain::new(MockLocalStt::failing())
            .with_cloud(MockCloudStt::ok("cloud result", 0.90));

        let result = chain.transcribe(AUDIO, MIME, None).await.unwrap();
        assert_eq!(result.text, "cloud result");
        assert!(matches!(result.source, SttSource::Cloud(_)));
    }

    #[tokio::test]
    async fn stt_both_fail_returns_error() {
        let chain =
            SttFallbackChain::new(MockLocalStt::failing()).with_cloud(MockCloudStt::failing());

        let result = chain.transcribe(AUDIO, MIME, None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn stt_local_error_without_cloud_returns_error() {
        let chain = SttFallbackChain::new(MockLocalStt::failing());

        let result = chain.transcribe(AUDIO, MIME, None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn stt_local_low_confidence_cloud_also_fails_returns_local() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("low conf", 0.30))
            .with_cloud(MockCloudStt::failing());

        let result = chain.transcribe(AUDIO, MIME, None).await.unwrap();
        assert_eq!(result.text, "low conf");
        assert!(matches!(result.source, SttSource::Local));
    }

    #[tokio::test]
    async fn stt_local_low_confidence_without_cloud_returns_local() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("low conf", 0.30));

        let result = chain.transcribe(AUDIO, MIME, None).await.unwrap();
        assert_eq!(result.text, "low conf");
        assert!(matches!(result.source, SttSource::Local));
        assert!((result.confidence - 0.30).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn stt_confidence_equal_to_threshold_stays_local() {
        // The threshold comparison is inclusive, so the cloud must not be preferred
        // even though it would report a higher confidence.
        let chain = SttFallbackChain::new(MockLocalStt::ok("local", DEFAULT_CONFIDENCE_THRESHOLD))
            .with_cloud(MockCloudStt::ok("cloud", 0.95));

        let result = chain.transcribe(AUDIO, MIME, None).await.unwrap();
        assert_eq!(result.text, "local");
        assert!(matches!(result.source, SttSource::Local));
    }

    #[tokio::test]
    async fn stt_custom_threshold() {
        // 0.80 would pass the default threshold but not the raised one.
        let chain = SttFallbackChain::new(MockLocalStt::ok("local", 0.80))
            .with_confidence_threshold(0.90)
            .with_cloud(MockCloudStt::ok("cloud", 0.95));

        let result = chain.transcribe(AUDIO, MIME, None).await.unwrap();
        assert_eq!(result.text, "cloud");
        assert!(matches!(result.source, SttSource::Cloud(_)));
    }

    #[tokio::test]
    async fn stt_local_result_reports_requested_language_or_en() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("bonjour", 0.90));

        let hinted = chain.transcribe(AUDIO, MIME, Some("fr")).await.unwrap();
        assert_eq!(hinted.language, "fr");

        let unhinted = chain.transcribe(AUDIO, MIME, None).await.unwrap();
        assert_eq!(unhinted.language, "en");
    }

    // ---- TTS chain ----

    #[tokio::test]
    async fn tts_local_success() {
        let chain = TtsFallbackChain::new(Box::new(MockLocalTts { should_fail: false }));

        let result = chain.synthesize("hello", None).await.unwrap();
        assert_eq!(result.audio_data, vec![1, 2, 3]);
        assert_eq!(result.mime_type, "audio/wav");
        assert!(matches!(result.source, TtsSource::Local));
    }

    #[tokio::test]
    async fn tts_local_error_cloud_fallback() {
        let chain = TtsFallbackChain::new(Box::new(MockLocalTts { should_fail: true }))
            .with_cloud(Box::new(MockCloudTts { should_fail: false }));

        let result = chain.synthesize("hello", None).await.unwrap();
        assert_eq!(result.audio_data, vec![4, 5, 6]);
        assert_eq!(result.mime_type, "audio/mp3");
        assert!(matches!(&result.source, TtsSource::Cloud(name) if name == "mock-cloud-tts"));
    }

    #[tokio::test]
    async fn tts_both_fail_returns_error() {
        let chain = TtsFallbackChain::new(Box::new(MockLocalTts { should_fail: true }))
            .with_cloud(Box::new(MockCloudTts { should_fail: true }));

        let result = chain.synthesize("hello", None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn tts_local_error_without_cloud_returns_error() {
        let chain = TtsFallbackChain::new(Box::new(MockLocalTts { should_fail: true }));

        let result = chain.synthesize("hello", None).await;
        assert!(result.is_err());
    }
}