//! Provider traits for the `autoagents_speech` crate (provider.rs):
//! capability traits for text-to-speech (TTS) and speech-to-text (STT) backends.
1use crate::{
2    AudioChunk, AudioFormat, ModelInfo, STTResult, SpeechRequest, SpeechResponse, TTSResult,
3    TextChunk, TranscriptionRequest, TranscriptionResponse,
4};
5use async_trait::async_trait;
6use futures::Stream;
7use std::pin::Pin;
8
9/// Marker Trait for TTS providers
10///
11/// This trait combines all TTS capabilities into a single provider interface.
12/// Providers should implement this marker trait along with the specific capability traits.
13pub trait TTSProvider: TTSSpeechProvider + TTSModelsProvider + Send + Sync {}
14
15/// Trait for TTS speech generation capabilities
16#[async_trait]
17pub trait TTSSpeechProvider: Send + Sync {
18    /// Generate speech from text (required)
19    ///
20    /// # Arguments
21    /// * `request` - Speech generation request with text, voice, and format
22    ///
23    /// # Returns
24    /// Speech response with audio data and metadata
25    async fn generate_speech(&self, request: SpeechRequest) -> TTSResult<SpeechResponse>;
26
27    /// Generate speech as a stream (optional)
28    ///
29    /// # Arguments
30    /// * `request` - Speech generation request
31    ///
32    /// # Returns
33    /// Stream of audio chunks
34    async fn generate_speech_stream<'a>(
35        &'a self,
36        _request: SpeechRequest,
37    ) -> TTSResult<Pin<Box<dyn Stream<Item = TTSResult<AudioChunk>> + Send + 'a>>> {
38        // Default implementation: not supported
39        Err(crate::error::TTSError::StreamingNotSupported(
40            "Not Supported".to_string(),
41        ))
42    }
43
44    /// Check if streaming is supported (default: false)
45    fn supports_streaming(&self) -> bool {
46        false
47    }
48
49    /// Get supported audio formats (default: WAV only)
50    fn supported_formats(&self) -> Vec<AudioFormat> {
51        vec![AudioFormat::Wav]
52    }
53
54    /// Get default sample rate
55    fn default_sample_rate(&self) -> u32 {
56        24000
57    }
58}
59
60/// Trait for TTS model management capabilities
61#[async_trait]
62pub trait TTSModelsProvider: Send + Sync {
63    /// List available models (optional)
64    ///
65    /// # Returns
66    /// List of available model information
67    async fn list_models(&self) -> TTSResult<Vec<ModelInfo>> {
68        Ok(vec![])
69    }
70
71    /// Get current model information (required)
72    ///
73    /// # Returns
74    /// Current model information
75    fn get_current_model(&self) -> ModelInfo;
76
77    /// Get supported languages
78    fn supported_languages(&self) -> Vec<String> {
79        vec!["en".to_string()]
80    }
81}
82
83/// Marker trait for STT providers
84///
85/// This trait combines all STT capabilities into a single provider interface.
86/// Providers should implement this marker trait along with the specific capability traits.
87pub trait STTProvider: STTSpeechProvider + STTModelsProvider + Send + Sync {}
88
89/// Trait for STT transcription capabilities
90#[async_trait]
91pub trait STTSpeechProvider: Send + Sync {
92    /// Transcribe audio to text (required)
93    ///
94    /// # Arguments
95    /// * `request` - Transcription request with audio and options
96    ///
97    /// # Returns
98    /// Transcription response with text and optional timestamps
99    async fn transcribe(&self, request: TranscriptionRequest) -> STTResult<TranscriptionResponse>;
100
101    /// Transcribe audio as a stream (optional)
102    ///
103    /// # Arguments
104    /// * `request` - Transcription request
105    ///
106    /// # Returns
107    /// Stream of text chunks
108    async fn transcribe_stream<'a>(
109        &'a self,
110        _request: TranscriptionRequest,
111    ) -> STTResult<Pin<Box<dyn Stream<Item = STTResult<TextChunk>> + Send + 'a>>> {
112        Err(crate::error::STTError::StreamingNotSupported(
113            "Not Supported".to_string(),
114        ))
115    }
116
117    /// Check if streaming is supported (default: false)
118    fn supports_streaming(&self) -> bool {
119        false
120    }
121
122    /// Get supported sample rate (default: 16000Hz)
123    fn supported_sample_rate(&self) -> u32 {
124        16000
125    }
126
127    /// Get supported number of channels (default: 1 for mono)
128    fn supported_channels(&self) -> u16 {
129        1
130    }
131
132    /// Check if timestamps are supported (default: false)
133    fn supports_timestamps(&self) -> bool {
134        false
135    }
136}
137
138/// Trait for STT model management capabilities
139#[async_trait]
140pub trait STTModelsProvider: Send + Sync {
141    /// List available models (optional)
142    ///
143    /// # Returns
144    /// List of available model information
145    async fn list_models(&self) -> STTResult<Vec<ModelInfo>> {
146        Ok(vec![])
147    }
148
149    /// Get current model information (required)
150    ///
151    /// # Returns
152    /// Current model information
153    fn get_current_model(&self) -> ModelInfo;
154
155    /// Get supported languages
156    fn supported_languages(&self) -> Vec<String> {
157        vec!["en".to_string()]
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AudioData, AudioFormat, ModelInfo, SharedAudioData, SpeechRequest, SpeechResponse,
        TranscriptionRequest, TranscriptionResponse, VoiceIdentifier,
    };
    use async_trait::async_trait;

    /// TTS stub: implements only the required methods, inherits every default.
    #[derive(Debug)]
    struct DummyProvider;

    #[async_trait]
    impl TTSSpeechProvider for DummyProvider {
        async fn generate_speech(&self, request: SpeechRequest) -> TTSResult<SpeechResponse> {
            let sample_rate = request.sample_rate.unwrap_or(24000);
            let audio = AudioData {
                samples: vec![0.0],
                channels: 1,
                sample_rate,
            };
            Ok(SpeechResponse {
                audio,
                text: request.text,
                duration_ms: 0,
            })
        }
    }

    #[async_trait]
    impl TTSModelsProvider for DummyProvider {
        fn get_current_model(&self) -> ModelInfo {
            ModelInfo {
                id: String::from("dummy"),
                name: String::from("Dummy"),
                description: None,
                languages: vec![String::from("en")],
            }
        }
    }

    impl TTSProvider for DummyProvider {}

    #[tokio::test]
    async fn test_default_streaming_not_supported() {
        let provider = DummyProvider;
        let request = SpeechRequest {
            text: "hello".to_string(),
            voice: VoiceIdentifier::new("test"),
            format: AudioFormat::Wav,
            sample_rate: None,
        };

        // The inherited default must refuse with StreamingNotSupported.
        let error = match provider.generate_speech_stream(request).await {
            Err(error) => error,
            Ok(_) => panic!("expected streaming not supported"),
        };
        assert!(matches!(
            error,
            crate::error::TTSError::StreamingNotSupported(_)
        ));
        assert!(!provider.supports_streaming());
    }

    #[test]
    fn test_default_provider_formats_and_languages() {
        let provider = DummyProvider;
        assert_eq!(provider.supported_formats(), vec![AudioFormat::Wav]);
        assert_eq!(provider.default_sample_rate(), 24000);
        assert_eq!(provider.supported_languages(), vec!["en".to_string()]);
    }

    /// STT stub mirroring `DummyProvider` for the transcription traits.
    #[derive(Debug)]
    struct DummySTTProvider;

    #[async_trait]
    impl STTSpeechProvider for DummySTTProvider {
        async fn transcribe(
            &self,
            request: TranscriptionRequest,
        ) -> STTResult<TranscriptionResponse> {
            let sample_count = request.audio.samples.len();
            Ok(TranscriptionResponse {
                text: format!("Transcribed {} samples", sample_count),
                timestamps: None,
                duration_ms: 0,
            })
        }
    }

    #[async_trait]
    impl STTModelsProvider for DummySTTProvider {
        fn get_current_model(&self) -> ModelInfo {
            ModelInfo {
                id: String::from("dummy"),
                name: String::from("Dummy STT"),
                description: None,
                languages: vec![String::from("en")],
            }
        }
    }

    impl STTProvider for DummySTTProvider {}

    #[tokio::test]
    async fn test_stt_default_streaming_not_supported() {
        let provider = DummySTTProvider;
        let audio = AudioData {
            samples: vec![0.0; 16000],
            sample_rate: 16000,
            channels: 1,
        };
        let request = TranscriptionRequest {
            audio: SharedAudioData::new(audio),
            language: None,
            include_timestamps: false,
        };

        // The inherited default must refuse with StreamingNotSupported.
        let error = match provider.transcribe_stream(request).await {
            Err(error) => error,
            Ok(_) => panic!("expected streaming not supported"),
        };
        assert!(matches!(
            error,
            crate::error::STTError::StreamingNotSupported(_)
        ));
        assert!(!provider.supports_streaming());
    }

    #[test]
    fn test_stt_default_provider_settings() {
        let provider = DummySTTProvider;
        assert_eq!(provider.supported_sample_rate(), 16000);
        assert_eq!(provider.supported_channels(), 1);
        assert_eq!(provider.supported_languages(), vec!["en".to_string()]);
        assert!(!provider.supports_timestamps());
    }
}