Skip to main content

clawft_plugin/voice/
fallback.rs

1//! STT/TTS fallback chains: local-first with cloud fallback.
2//!
3//! [`SttFallbackChain`] tries local STT first, falling back to a cloud
4//! provider when the local result has low confidence or errors.
5//! [`TtsFallbackChain`] tries local TTS first, falling back on error.
6
7use crate::PluginError;
8
9use super::cloud_stt::CloudSttProvider;
10use super::cloud_tts::CloudTtsProvider;
11
/// Lowest confidence at which a local STT result is accepted as-is.
///
/// Local results scoring below this value trigger a cloud fallback attempt.
const DEFAULT_CONFIDENCE_THRESHOLD: f32 = 0.6;
15
16// ---------------------------------------------------------------------------
17// Local engine traits
18// ---------------------------------------------------------------------------
19
20/// Trait for the local STT engine (e.g., sherpa-rs based).
21#[async_trait::async_trait]
22pub trait LocalSttEngine: Send + Sync {
23    /// Transcribe audio data using the local engine.
24    async fn transcribe(
25        &self,
26        audio_data: &[u8],
27        language: Option<&str>,
28    ) -> Result<LocalSttResult, PluginError>;
29}
30
/// Outcome of a single local STT transcription.
///
/// Carries no language field; only the transcript and its score.
#[derive(Clone, Debug)]
pub struct LocalSttResult {
    /// The recognized text.
    pub text: String,
    /// Engine-reported confidence in `text` (0.0-1.0).
    pub confidence: f32,
}
39
40/// Trait for the local TTS engine (e.g., sherpa-rs piper).
41#[async_trait::async_trait]
42pub trait LocalTtsEngine: Send + Sync {
43    /// Synthesize text to audio. Returns `(audio_data, mime_type)`.
44    async fn synthesize(&self, text: &str) -> Result<(Vec<u8>, String), PluginError>;
45}
46
47// ---------------------------------------------------------------------------
48// STT fallback
49// ---------------------------------------------------------------------------
50
/// Identifies which engine produced an STT transcription.
///
/// Implements `PartialEq`/`Eq` so callers can compare sources directly
/// (e.g. `result.source == SttSource::Local`) instead of pattern matching.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SttSource {
    /// Result came from the local STT engine.
    Local,
    /// Result came from a cloud provider; the payload is the provider name.
    Cloud(String),
}
60
61/// Combined result from the STT fallback chain.
62#[derive(Debug, Clone)]
63pub struct SttFallbackResult {
64    /// Transcribed text.
65    pub text: String,
66    /// Confidence score (0.0-1.0).
67    pub confidence: f32,
68    /// Which engine produced the result.
69    pub source: SttSource,
70    /// Language code.
71    pub language: String,
72}
73
74/// Fallback chain for STT: local -> cloud on failure or low confidence.
75///
76/// Decision logic:
77/// 1. Try local engine.
78/// 2. If local succeeds with confidence >= threshold: return local result.
79/// 3. If local succeeds with low confidence: try cloud, return whichever
80///    has higher confidence. On cloud error, return the local result.
81/// 4. If local fails: try cloud. On cloud error too, propagate the local
82///    error (unless no cloud provider is configured).
83pub struct SttFallbackChain {
84    local: Box<dyn LocalSttEngine>,
85    cloud: Option<Box<dyn CloudSttProvider>>,
86    confidence_threshold: f32,
87}
88
89impl SttFallbackChain {
90    /// Create a new chain with only a local engine.
91    pub fn new(local: Box<dyn LocalSttEngine>) -> Self {
92        Self {
93            local,
94            cloud: None,
95            confidence_threshold: DEFAULT_CONFIDENCE_THRESHOLD,
96        }
97    }
98
99    /// Add a cloud provider for fallback.
100    pub fn with_cloud(mut self, provider: Box<dyn CloudSttProvider>) -> Self {
101        self.cloud = Some(provider);
102        self
103    }
104
105    /// Override the confidence threshold (default: 0.60).
106    pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
107        self.confidence_threshold = threshold;
108        self
109    }
110
111    /// Transcribe audio, falling back to cloud if local fails or
112    /// confidence is below the threshold.
113    pub async fn transcribe(
114        &self,
115        audio_data: &[u8],
116        mime_type: &str,
117        language: Option<&str>,
118    ) -> Result<SttFallbackResult, PluginError> {
119        match self.local.transcribe(audio_data, language).await {
120            Ok(local_result) if local_result.confidence >= self.confidence_threshold => {
121                // Local succeeded with good confidence.
122                Ok(SttFallbackResult {
123                    text: local_result.text,
124                    confidence: local_result.confidence,
125                    source: SttSource::Local,
126                    language: language.unwrap_or("en").to_string(),
127                })
128            }
129            Ok(low_confidence) => {
130                // Local succeeded but confidence is too low -- try cloud.
131                if let Some(cloud) = &self.cloud {
132                    match cloud.transcribe(audio_data, mime_type, language).await {
133                        Ok(cloud_result) if cloud_result.confidence > low_confidence.confidence => {
134                            Ok(SttFallbackResult {
135                                text: cloud_result.text,
136                                confidence: cloud_result.confidence,
137                                source: SttSource::Cloud(cloud.name().to_string()),
138                                language: cloud_result.language,
139                            })
140                        }
141                        Ok(_) => {
142                            // Cloud confidence was not better; keep local.
143                            Ok(SttFallbackResult {
144                                text: low_confidence.text,
145                                confidence: low_confidence.confidence,
146                                source: SttSource::Local,
147                                language: language.unwrap_or("en").to_string(),
148                            })
149                        }
150                        Err(_) => {
151                            // Cloud also failed -- return low-confidence local.
152                            Ok(SttFallbackResult {
153                                text: low_confidence.text,
154                                confidence: low_confidence.confidence,
155                                source: SttSource::Local,
156                                language: language.unwrap_or("en").to_string(),
157                            })
158                        }
159                    }
160                } else {
161                    // No cloud provider configured.
162                    Ok(SttFallbackResult {
163                        text: low_confidence.text,
164                        confidence: low_confidence.confidence,
165                        source: SttSource::Local,
166                        language: language.unwrap_or("en").to_string(),
167                    })
168                }
169            }
170            Err(local_err) => {
171                // Local failed entirely -- try cloud.
172                if let Some(cloud) = &self.cloud {
173                    let cloud_result = cloud.transcribe(audio_data, mime_type, language).await?;
174                    Ok(SttFallbackResult {
175                        text: cloud_result.text,
176                        confidence: cloud_result.confidence,
177                        source: SttSource::Cloud(cloud.name().to_string()),
178                        language: cloud_result.language,
179                    })
180                } else {
181                    Err(local_err)
182                }
183            }
184        }
185    }
186}
187
188// ---------------------------------------------------------------------------
189// TTS fallback
190// ---------------------------------------------------------------------------
191
/// Identifies which engine produced a TTS synthesis result.
///
/// Implements `PartialEq`/`Eq` so callers can compare sources directly
/// (e.g. `result.source == TtsSource::Local`) instead of pattern matching.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TtsSource {
    /// Result came from the local TTS engine.
    Local,
    /// Result came from a cloud provider; the payload is the provider name.
    Cloud(String),
}
201
202/// Combined result from the TTS fallback chain.
203#[derive(Debug, Clone)]
204pub struct TtsFallbackResult {
205    /// Raw audio data.
206    pub audio_data: Vec<u8>,
207    /// MIME type of the audio.
208    pub mime_type: String,
209    /// Which engine produced the result.
210    pub source: TtsSource,
211}
212
213/// Fallback chain for TTS: local -> cloud on error.
214pub struct TtsFallbackChain {
215    local: Box<dyn LocalTtsEngine>,
216    cloud: Option<Box<dyn CloudTtsProvider>>,
217}
218
219impl TtsFallbackChain {
220    /// Create a new chain with only a local engine.
221    pub fn new(local: Box<dyn LocalTtsEngine>) -> Self {
222        Self { local, cloud: None }
223    }
224
225    /// Add a cloud provider for fallback.
226    pub fn with_cloud(mut self, provider: Box<dyn CloudTtsProvider>) -> Self {
227        self.cloud = Some(provider);
228        self
229    }
230
231    /// Synthesize text to speech, falling back to cloud on local error.
232    pub async fn synthesize(
233        &self,
234        text: &str,
235        voice_id: Option<&str>,
236    ) -> Result<TtsFallbackResult, PluginError> {
237        match self.local.synthesize(text).await {
238            Ok((audio, mime)) => Ok(TtsFallbackResult {
239                audio_data: audio,
240                mime_type: mime,
241                source: TtsSource::Local,
242            }),
243            Err(local_err) => {
244                if let Some(cloud) = &self.cloud {
245                    let voice = voice_id.unwrap_or("alloy");
246                    let result = cloud.synthesize(text, voice).await?;
247                    Ok(TtsFallbackResult {
248                        audio_data: result.audio_data,
249                        mime_type: result.mime_type,
250                        source: TtsSource::Cloud(cloud.name().to_string()),
251                    })
252                } else {
253                    Err(local_err)
254                }
255            }
256        }
257    }
258}
259
#[cfg(test)]
mod tests {
    // Unit tests for both fallback chains, driven by in-memory mock engines.

    use super::*;

    // -- Mock implementations --

    /// Local STT mock: yields a fixed transcript, or fails when `should_fail`.
    struct MockLocalStt {
        text: String,
        confidence: f32,
        should_fail: bool,
    }

    impl MockLocalStt {
        /// Mock that succeeds with the given transcript and confidence.
        fn ok(text: &str, confidence: f32) -> Box<Self> {
            Box::new(Self {
                text: text.into(),
                confidence,
                should_fail: false,
            })
        }

        /// Mock whose `transcribe` always returns an error.
        fn failing() -> Box<Self> {
            Box::new(Self {
                text: String::new(),
                confidence: 0.0,
                should_fail: true,
            })
        }
    }

    #[async_trait::async_trait]
    impl LocalSttEngine for MockLocalStt {
        async fn transcribe(
            &self,
            _audio_data: &[u8],
            _language: Option<&str>,
        ) -> Result<LocalSttResult, PluginError> {
            if self.should_fail {
                return Err(PluginError::ExecutionFailed("local STT failed".into()));
            }
            Ok(LocalSttResult {
                text: self.text.clone(),
                confidence: self.confidence,
            })
        }
    }

    /// Cloud STT mock named "mock-cloud"; always reports language "en".
    struct MockCloudStt {
        text: String,
        confidence: f32,
        should_fail: bool,
    }

    impl MockCloudStt {
        /// Mock that succeeds with the given transcript and confidence.
        fn ok(text: &str, confidence: f32) -> Box<Self> {
            Box::new(Self {
                text: text.into(),
                confidence,
                should_fail: false,
            })
        }

        /// Mock whose `transcribe` always returns an error.
        fn failing() -> Box<Self> {
            Box::new(Self {
                text: String::new(),
                confidence: 0.0,
                should_fail: true,
            })
        }
    }

    #[async_trait::async_trait]
    impl CloudSttProvider for MockCloudStt {
        fn name(&self) -> &str {
            "mock-cloud"
        }

        async fn transcribe(
            &self,
            _audio_data: &[u8],
            _mime_type: &str,
            _language: Option<&str>,
        ) -> Result<super::super::cloud_stt::CloudSttResult, PluginError> {
            if self.should_fail {
                return Err(PluginError::ExecutionFailed("cloud STT failed".into()));
            }
            Ok(super::super::cloud_stt::CloudSttResult {
                text: self.text.clone(),
                confidence: self.confidence,
                language: "en".into(),
                duration_ms: 1000,
            })
        }
    }

    /// Local TTS mock: yields bytes `[1, 2, 3]` as "audio/wav", or fails.
    struct MockLocalTts {
        should_fail: bool,
    }

    #[async_trait::async_trait]
    impl LocalTtsEngine for MockLocalTts {
        async fn synthesize(&self, _text: &str) -> Result<(Vec<u8>, String), PluginError> {
            if self.should_fail {
                return Err(PluginError::ExecutionFailed("local TTS failed".into()));
            }
            Ok((vec![1, 2, 3], "audio/wav".into()))
        }
    }

    /// Cloud TTS mock named "mock-cloud-tts": yields `[4, 5, 6]` as "audio/mp3".
    struct MockCloudTts {
        should_fail: bool,
    }

    #[async_trait::async_trait]
    impl CloudTtsProvider for MockCloudTts {
        fn name(&self) -> &str {
            "mock-cloud-tts"
        }

        fn available_voices(&self) -> Vec<super::super::cloud_tts::VoiceInfo> {
            vec![]
        }

        async fn synthesize(
            &self,
            _text: &str,
            _voice_id: &str,
        ) -> Result<super::super::cloud_tts::CloudTtsResult, PluginError> {
            if self.should_fail {
                return Err(PluginError::ExecutionFailed("cloud TTS failed".into()));
            }
            Ok(super::super::cloud_tts::CloudTtsResult {
                audio_data: vec![4, 5, 6],
                mime_type: "audio/mp3".into(),
                duration_ms: Some(1000),
            })
        }
    }

    /// Run `chain` on a fixed dummy payload with the given language hint.
    async fn transcribe(
        chain: &SttFallbackChain,
        language: Option<&str>,
    ) -> Result<SttFallbackResult, PluginError> {
        chain.transcribe(b"audio", "audio/wav", language).await
    }

    // -- STT fallback tests --

    #[tokio::test]
    async fn stt_local_success_high_confidence() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("hello local", 0.90));
        let result = transcribe(&chain, None).await.unwrap();
        assert_eq!(result.text, "hello local");
        assert!(matches!(result.source, SttSource::Local));
        assert!((result.confidence - 0.90).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn stt_confidence_equal_to_threshold_stays_local() {
        // The threshold comparison is inclusive, so the better cloud result
        // must not be consulted here.
        let chain =
            SttFallbackChain::new(MockLocalStt::ok("local", DEFAULT_CONFIDENCE_THRESHOLD))
                .with_cloud(MockCloudStt::ok("cloud", 0.99));
        let result = transcribe(&chain, None).await.unwrap();
        assert_eq!(result.text, "local");
        assert!(matches!(result.source, SttSource::Local));
    }

    #[tokio::test]
    async fn stt_local_result_echoes_language_hint() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("hallo", 0.90));
        let hinted = transcribe(&chain, Some("de")).await.unwrap();
        assert_eq!(hinted.language, "de");
        let unhinted = transcribe(&chain, None).await.unwrap();
        assert_eq!(unhinted.language, "en");
    }

    #[tokio::test]
    async fn stt_local_low_confidence_cloud_fallback() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("helo lcal", 0.30))
            .with_cloud(MockCloudStt::ok("hello local", 0.95));

        let result = transcribe(&chain, None).await.unwrap();
        assert_eq!(result.text, "hello local");
        assert!(matches!(&result.source, SttSource::Cloud(name) if name == "mock-cloud"));
        assert!((result.confidence - 0.95).abs() < f32::EPSILON);
        assert_eq!(result.language, "en");
    }

    #[tokio::test]
    async fn stt_local_low_confidence_cloud_worse_keeps_local() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("helo", 0.50))
            .with_cloud(MockCloudStt::ok("hello", 0.40)); // Worse than local

        let result = transcribe(&chain, None).await.unwrap();
        assert_eq!(result.text, "helo");
        assert!(matches!(result.source, SttSource::Local));
    }

    #[tokio::test]
    async fn stt_local_low_confidence_no_cloud_returns_local() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("low conf", 0.30));

        let result = transcribe(&chain, None).await.unwrap();
        assert_eq!(result.text, "low conf");
        assert!(matches!(result.source, SttSource::Local));
        assert!((result.confidence - 0.30).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn stt_local_error_cloud_fallback() {
        let chain = SttFallbackChain::new(MockLocalStt::failing())
            .with_cloud(MockCloudStt::ok("cloud result", 0.90));

        let result = transcribe(&chain, None).await.unwrap();
        assert_eq!(result.text, "cloud result");
        assert!(matches!(result.source, SttSource::Cloud(_)));
    }

    #[tokio::test]
    async fn stt_local_error_no_cloud_returns_local_error() {
        let chain = SttFallbackChain::new(MockLocalStt::failing());

        let err = transcribe(&chain, None).await.unwrap_err();
        assert!(format!("{:?}", err).contains("local STT failed"));
    }

    #[tokio::test]
    async fn stt_both_fail_returns_error() {
        let chain =
            SttFallbackChain::new(MockLocalStt::failing()).with_cloud(MockCloudStt::failing());

        // When both engines fail, the cloud error is the one that surfaces.
        let err = transcribe(&chain, None).await.unwrap_err();
        assert!(format!("{:?}", err).contains("cloud STT failed"));
    }

    #[tokio::test]
    async fn stt_local_low_confidence_cloud_also_fails_returns_local() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("low conf", 0.30))
            .with_cloud(MockCloudStt::failing());

        let result = transcribe(&chain, None).await.unwrap();
        assert_eq!(result.text, "low conf");
        assert!(matches!(result.source, SttSource::Local));
    }

    #[tokio::test]
    async fn stt_custom_threshold() {
        let chain = SttFallbackChain::new(MockLocalStt::ok("local", 0.80))
            .with_confidence_threshold(0.90) // Higher threshold
            .with_cloud(MockCloudStt::ok("cloud", 0.95));

        let result = transcribe(&chain, None).await.unwrap();
        // 0.80 < 0.90 threshold, so cloud should be tried
        assert_eq!(result.text, "cloud");
        assert!(matches!(result.source, SttSource::Cloud(_)));
    }

    // -- TTS fallback tests --

    #[tokio::test]
    async fn tts_local_success() {
        let chain = TtsFallbackChain::new(Box::new(MockLocalTts { should_fail: false }));
        let result = chain.synthesize("hello", None).await.unwrap();
        assert_eq!(result.audio_data, vec![1, 2, 3]);
        assert_eq!(result.mime_type, "audio/wav");
        assert!(matches!(result.source, TtsSource::Local));
    }

    #[tokio::test]
    async fn tts_local_error_cloud_fallback() {
        let chain = TtsFallbackChain::new(Box::new(MockLocalTts { should_fail: true }))
            .with_cloud(Box::new(MockCloudTts { should_fail: false }));
        let result = chain.synthesize("hello", None).await.unwrap();
        assert_eq!(result.audio_data, vec![4, 5, 6]);
        assert_eq!(result.mime_type, "audio/mp3");
        assert!(matches!(&result.source, TtsSource::Cloud(name) if name == "mock-cloud-tts"));
    }

    #[tokio::test]
    async fn tts_local_error_no_cloud_returns_local_error() {
        let chain = TtsFallbackChain::new(Box::new(MockLocalTts { should_fail: true }));
        let err = chain.synthesize("hello", None).await.unwrap_err();
        assert!(format!("{:?}", err).contains("local TTS failed"));
    }

    #[tokio::test]
    async fn tts_both_fail_returns_error() {
        let chain = TtsFallbackChain::new(Box::new(MockLocalTts { should_fail: true }))
            .with_cloud(Box::new(MockCloudTts { should_fail: true }));
        // When both engines fail, the cloud error is the one that surfaces.
        let err = chain.synthesize("hello", None).await.unwrap_err();
        assert!(format!("{:?}", err).contains("cloud TTS failed"));
    }
}