autoagents_speech/
provider.rs1use crate::{
2 AudioChunk, AudioFormat, ModelInfo, STTResult, SpeechRequest, SpeechResponse, TTSResult,
3 TextChunk, TranscriptionRequest, TranscriptionResponse,
4};
5use async_trait::async_trait;
6use futures::Stream;
7use std::pin::Pin;
8
/// Complete text-to-speech provider: speech synthesis ([`TTSSpeechProvider`])
/// combined with model discovery ([`TTSModelsProvider`]).
pub trait TTSProvider: TTSSpeechProvider + TTSModelsProvider + Send + Sync {}
14
15#[async_trait]
17pub trait TTSSpeechProvider: Send + Sync {
18 async fn generate_speech(&self, request: SpeechRequest) -> TTSResult<SpeechResponse>;
26
27 async fn generate_speech_stream<'a>(
35 &'a self,
36 _request: SpeechRequest,
37 ) -> TTSResult<Pin<Box<dyn Stream<Item = TTSResult<AudioChunk>> + Send + 'a>>> {
38 Err(crate::error::TTSError::StreamingNotSupported(
40 "Not Supported".to_string(),
41 ))
42 }
43
44 fn supports_streaming(&self) -> bool {
46 false
47 }
48
49 fn supported_formats(&self) -> Vec<AudioFormat> {
51 vec![AudioFormat::Wav]
52 }
53
54 fn default_sample_rate(&self) -> u32 {
56 24000
57 }
58}
59
60#[async_trait]
62pub trait TTSModelsProvider: Send + Sync {
63 async fn list_models(&self) -> TTSResult<Vec<ModelInfo>> {
68 Ok(vec![])
69 }
70
71 fn get_current_model(&self) -> ModelInfo;
76
77 fn supported_languages(&self) -> Vec<String> {
79 vec!["en".to_string()]
80 }
81}
82
/// Complete speech-to-text provider: transcription ([`STTSpeechProvider`])
/// combined with model discovery ([`STTModelsProvider`]).
pub trait STTProvider: STTSpeechProvider + STTModelsProvider + Send + Sync {}
88
89#[async_trait]
91pub trait STTSpeechProvider: Send + Sync {
92 async fn transcribe(&self, request: TranscriptionRequest) -> STTResult<TranscriptionResponse>;
100
101 async fn transcribe_stream<'a>(
109 &'a self,
110 _request: TranscriptionRequest,
111 ) -> STTResult<Pin<Box<dyn Stream<Item = STTResult<TextChunk>> + Send + 'a>>> {
112 Err(crate::error::STTError::StreamingNotSupported(
113 "Not Supported".to_string(),
114 ))
115 }
116
117 fn supports_streaming(&self) -> bool {
119 false
120 }
121
122 fn supported_sample_rate(&self) -> u32 {
124 16000
125 }
126
127 fn supported_channels(&self) -> u16 {
129 1
130 }
131
132 fn supports_timestamps(&self) -> bool {
134 false
135 }
136}
137
138#[async_trait]
140pub trait STTModelsProvider: Send + Sync {
141 async fn list_models(&self) -> STTResult<Vec<ModelInfo>> {
146 Ok(vec![])
147 }
148
149 fn get_current_model(&self) -> ModelInfo;
154
155 fn supported_languages(&self) -> Vec<String> {
157 vec!["en".to_string()]
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AudioData, AudioFormat, ModelInfo, SharedAudioData, SpeechRequest, SpeechResponse,
        TranscriptionRequest, TranscriptionResponse, VoiceIdentifier,
    };
    use async_trait::async_trait;

    /// Minimal TTS provider that leans on every trait-provided default.
    #[derive(Debug)]
    struct DummyProvider;

    #[async_trait]
    impl TTSSpeechProvider for DummyProvider {
        async fn generate_speech(&self, request: SpeechRequest) -> TTSResult<SpeechResponse> {
            // Echo the request back as a single-sample response.
            let sample_rate = request.sample_rate.unwrap_or(24000);
            Ok(SpeechResponse {
                audio: AudioData {
                    samples: vec![0.0],
                    channels: 1,
                    sample_rate,
                },
                text: request.text,
                duration_ms: 0,
            })
        }
    }

    #[async_trait]
    impl TTSModelsProvider for DummyProvider {
        fn get_current_model(&self) -> ModelInfo {
            ModelInfo {
                id: String::from("dummy"),
                name: String::from("Dummy"),
                description: None,
                languages: vec![String::from("en")],
            }
        }
    }

    impl TTSProvider for DummyProvider {}

    #[tokio::test]
    async fn test_default_streaming_not_supported() {
        let tts = DummyProvider;
        let req = SpeechRequest {
            text: String::from("hello"),
            voice: VoiceIdentifier::new("test"),
            format: AudioFormat::Wav,
            sample_rate: None,
        };

        // The trait default must refuse to stream.
        let error = match tts.generate_speech_stream(req).await {
            Err(e) => e,
            Ok(_) => panic!("expected streaming not supported"),
        };
        assert!(matches!(
            error,
            crate::error::TTSError::StreamingNotSupported(_)
        ));
        assert!(!tts.supports_streaming());
    }

    #[test]
    fn test_default_provider_formats_and_languages() {
        let tts = DummyProvider;
        assert_eq!(tts.supported_formats(), vec![AudioFormat::Wav]);
        assert_eq!(tts.default_sample_rate(), 24000);
        assert_eq!(tts.supported_languages(), vec![String::from("en")]);
    }

    /// Minimal STT provider that leans on every trait-provided default.
    #[derive(Debug)]
    struct DummySTTProvider;

    #[async_trait]
    impl STTSpeechProvider for DummySTTProvider {
        async fn transcribe(
            &self,
            request: TranscriptionRequest,
        ) -> STTResult<TranscriptionResponse> {
            let sample_count = request.audio.samples.len();
            Ok(TranscriptionResponse {
                text: format!("Transcribed {} samples", sample_count),
                timestamps: None,
                duration_ms: 0,
            })
        }
    }

    #[async_trait]
    impl STTModelsProvider for DummySTTProvider {
        fn get_current_model(&self) -> ModelInfo {
            ModelInfo {
                id: String::from("dummy"),
                name: String::from("Dummy STT"),
                description: None,
                languages: vec![String::from("en")],
            }
        }
    }

    impl STTProvider for DummySTTProvider {}

    #[tokio::test]
    async fn test_stt_default_streaming_not_supported() {
        let stt = DummySTTProvider;
        let req = TranscriptionRequest {
            audio: SharedAudioData::new(AudioData {
                samples: vec![0.0; 16000],
                sample_rate: 16000,
                channels: 1,
            }),
            language: None,
            include_timestamps: false,
        };

        // The trait default must refuse to stream.
        let error = match stt.transcribe_stream(req).await {
            Err(e) => e,
            Ok(_) => panic!("expected streaming not supported"),
        };
        assert!(matches!(
            error,
            crate::error::STTError::StreamingNotSupported(_)
        ));
        assert!(!stt.supports_streaming());
    }

    #[test]
    fn test_stt_default_provider_settings() {
        let stt = DummySTTProvider;
        assert_eq!(stt.supported_sample_rate(), 16000);
        assert_eq!(stt.supported_channels(), 1);
        assert_eq!(stt.supported_languages(), vec![String::from("en")]);
        assert!(!stt.supports_timestamps());
    }
}