use anyhow::{Context, Result};
use async_openai::{
    config::OpenAIConfig,
    types::{
        AudioResponseFormat, CreateSpeechRequest, CreateTranscriptionRequestArgs, SpeechModel,
        SpeechResponseFormat, Voice,
    },
    Client,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, info, warn};

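/// High-level voice interface that routes requests to the configured
/// speech-to-text and text-to-speech providers.
///
/// # Example
///
/// A minimal sketch (crate path omitted; the OpenAI-backed providers read
/// `OPENAI_API_KEY` from the environment at request time):
///
/// ```ignore
/// let interface = VoiceInterface::new(VoiceConfig::default());
/// let audio = std::fs::read("input.mp3")?;
/// let transcript = interface.transcribe(&audio).await?;
/// println!("{}", transcript.text);
/// let speech = interface.synthesize("Hello, world!").await?;
/// std::fs::write("output.mp3", &speech.audio_data)?;
/// ```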
pub struct VoiceInterface {
    config: VoiceConfig,
    stt_provider: Arc<Mutex<dyn SpeechToTextProvider>>,
    tts_provider: Arc<Mutex<dyn TextToSpeechProvider>>,
}

/// Configuration for the voice interface.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoiceConfig {
    /// Enable speech-to-text.
    pub enable_stt: bool,
    /// Enable text-to-speech.
    pub enable_tts: bool,
    /// Which speech-to-text backend to use.
    pub stt_provider: SttProviderType,
    /// Which text-to-speech backend to use.
    pub tts_provider: TtsProviderType,
    /// Audio sample rate in Hz.
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channels: u16,
    /// Maximum audio duration in seconds.
    pub max_duration_secs: u64,
    /// Language tag, e.g. "en-US".
    pub language: String,
    /// Voice name used for synthesis, e.g. "alloy".
    pub voice: String,
}

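// Illustrative usage note: callers typically start from `VoiceConfig::default()`
// and override individual fields with struct-update syntax, e.g.
//
//     let config = VoiceConfig {
//         language: "ja-JP".to_string(),
//         ..Default::default()
//     };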
impl Default for VoiceConfig {
    fn default() -> Self {
        Self {
            enable_stt: true,
            enable_tts: true,
            stt_provider: SttProviderType::OpenAI,
            tts_provider: TtsProviderType::OpenAI,
            sample_rate: 16000,
            channels: 1,
            max_duration_secs: 300,
            language: "en-US".to_string(),
            voice: "alloy".to_string(),
        }
    }
}

/// Supported speech-to-text backends.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SttProviderType {
    OpenAI,
    Google,
    Azure,
    LocalWhisper,
}

/// Supported text-to-speech backends.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TtsProviderType {
    OpenAI,
    Google,
    Azure,
    LocalEngine,
}

/// Abstraction over speech-to-text backends.
#[async_trait::async_trait]
pub trait SpeechToTextProvider: Send + Sync {
    /// Transcribe a complete audio clip into text.
    async fn transcribe(&self, audio_data: &[u8], config: &VoiceConfig) -> Result<SttResult>;

    /// Transcribe audio incrementally, yielding partial results as chunks arrive.
    async fn transcribe_stream(
        &self,
        audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
        config: &VoiceConfig,
    ) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>>;
}

/// Abstraction over text-to-speech backends.
#[async_trait::async_trait]
pub trait TextToSpeechProvider: Send + Sync {
    /// Synthesize speech for the given text and return the audio in one piece.
    async fn synthesize(&self, text: &str, config: &VoiceConfig) -> Result<TtsResult>;

    /// Synthesize speech incrementally, yielding encoded audio chunks.
    async fn synthesize_stream(
        &self,
        text: &str,
        config: &VoiceConfig,
    ) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>>;
}

/// Result of a complete (non-streaming) transcription.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SttResult {
    /// Transcribed text.
    pub text: String,
    /// Provider-reported (or estimated) confidence score.
    pub confidence: f32,
    /// Detected or requested language, if known.
    pub language: Option<String>,
    pub duration_ms: u64,
    /// Word-level timing information, if available.
    pub word_timestamps: Vec<WordTimestamp>,
}

/// A partial or final result from a streaming transcription.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SttStreamResult {
    pub text: String,
    /// True once no further updates will follow for this segment.
    pub is_final: bool,
    pub confidence: f32,
}

/// Timing information for a single transcribed word.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WordTimestamp {
    pub word: String,
    pub start_ms: u64,
    pub end_ms: u64,
}

/// Result of a text-to-speech synthesis request.
#[derive(Debug, Clone)]
pub struct TtsResult {
    /// Encoded audio bytes in `format`.
    pub audio_data: Vec<u8>,
    pub format: AudioFormat,
    pub duration_ms: u64,
}

/// Audio container/encoding formats used by the providers.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AudioFormat {
    Wav,
    Mp3,
    Opus,
    Pcm,
}

impl VoiceInterface {
    /// Build a voice interface with the providers selected in `config`.
    pub fn new(config: VoiceConfig) -> Self {
        let stt_provider: Arc<Mutex<dyn SpeechToTextProvider>> = match config.stt_provider {
            SttProviderType::OpenAI => Arc::new(Mutex::new(OpenAISttProvider::new(config.clone()))),
            SttProviderType::Google => Arc::new(Mutex::new(GoogleSttProvider::new(config.clone()))),
            SttProviderType::Azure => Arc::new(Mutex::new(AzureSttProvider::new(config.clone()))),
            SttProviderType::LocalWhisper => {
                Arc::new(Mutex::new(LocalWhisperProvider::new(config.clone())))
            }
        };

        let tts_provider: Arc<Mutex<dyn TextToSpeechProvider>> = match config.tts_provider {
            TtsProviderType::OpenAI => Arc::new(Mutex::new(OpenAITtsProvider::new(config.clone()))),
            TtsProviderType::Google => Arc::new(Mutex::new(GoogleTtsProvider::new(config.clone()))),
            TtsProviderType::Azure => Arc::new(Mutex::new(AzureTtsProvider::new(config.clone()))),
            TtsProviderType::LocalEngine => {
                Arc::new(Mutex::new(LocalTtsEngine::new(config.clone())))
            }
        };

        Self {
            config,
            stt_provider,
            tts_provider,
        }
    }

    /// Transcribe an audio clip using the configured STT provider.
    pub async fn transcribe(&self, audio_data: &[u8]) -> Result<SttResult> {
        if !self.config.enable_stt {
            anyhow::bail!("Speech-to-text is disabled");
        }

        let provider = self.stt_provider.lock().await;
        provider.transcribe(audio_data, &self.config).await
    }

    /// Synthesize speech for `text` using the configured TTS provider.
    pub async fn synthesize(&self, text: &str) -> Result<TtsResult> {
        if !self.config.enable_tts {
            anyhow::bail!("Text-to-speech is disabled");
        }

        let provider = self.tts_provider.lock().await;
        provider.synthesize(text, &self.config).await
    }
}

/// Speech-to-text provider backed by the OpenAI Whisper API.
struct OpenAISttProvider {
    config: VoiceConfig,
    client: Client<OpenAIConfig>,
}

impl OpenAISttProvider {
    fn new(config: VoiceConfig) -> Self {
        // Client::new() picks up credentials from the environment (typically OPENAI_API_KEY).
        let client = Client::new();
        Self { config, client }
    }
}

#[async_trait::async_trait]
impl SpeechToTextProvider for OpenAISttProvider {
    async fn transcribe(&self, audio_data: &[u8], config: &VoiceConfig) -> Result<SttResult> {
        info!(
            "Transcribing audio with OpenAI Whisper (size: {} bytes)",
            audio_data.len()
        );

        let start_time = std::time::Instant::now();

        let request = CreateTranscriptionRequestArgs::default()
            .file(async_openai::types::AudioInput {
                source: async_openai::types::InputSource::Bytes {
                    filename: "audio.mp3".to_string(),
                    bytes: audio_data.to_vec().into(),
                },
            })
            .model("whisper-1")
            // Whisper expects an ISO 639-1 code, so use the two-letter prefix of e.g. "en-US".
            .language(&config.language[..2])
            .response_format(AudioResponseFormat::VerboseJson)
            .build()
            .context("Failed to build transcription request")?;

        let response = self
            .client
            .audio()
            .transcribe(request)
            .await
            .context("Failed to transcribe audio with OpenAI Whisper")?;

        let duration_ms = start_time.elapsed().as_millis() as u64;

        debug!(
            "OpenAI Whisper transcription completed: '{}' (duration: {}ms)",
            response.text, duration_ms
        );

        // Word-level timestamps are not extracted from the verbose response here.
        let word_timestamps = vec![];

        Ok(SttResult {
            text: response.text,
            // Whisper does not return a confidence score; use a fixed estimate.
            confidence: 0.95,
            language: Some(config.language.clone()),
            duration_ms,
            word_timestamps,
        })
    }

    async fn transcribe_stream(
        &self,
        mut audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
        config: &VoiceConfig,
    ) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
        let (tx, rx) = tokio::sync::mpsc::channel(100);

        let client = self.client.clone();
        let language = config.language.clone();

        tokio::spawn(async move {
            // Accumulate incoming chunks and transcribe them in batches.
            let mut accumulated_audio = Vec::new();

            while let Some(audio_chunk) = audio_stream.recv().await {
                accumulated_audio.extend_from_slice(&audio_chunk);

                // Send an interim transcription once roughly 160 KB of audio has accumulated.
                if accumulated_audio.len() >= 160_000 {
                    match CreateTranscriptionRequestArgs::default()
                        .file(async_openai::types::AudioInput {
                            source: async_openai::types::InputSource::Bytes {
                                filename: "audio_chunk.mp3".to_string(),
                                bytes: accumulated_audio.clone().into(),
                            },
                        })
                        .model("whisper-1")
                        .language(&language[..2])
                        .response_format(AudioResponseFormat::Json)
                        .build()
                    {
                        Ok(request) => {
                            if let Ok(response) = client.audio().transcribe(request).await {
                                let _ = tx
                                    .send(SttStreamResult {
                                        text: response.text,
                                        is_final: false,
                                        confidence: 0.95,
                                    })
                                    .await;
                            }
                        }
                        Err(e) => {
                            warn!("Failed to create transcription request: {}", e);
                        }
                    }

                    accumulated_audio.clear();
                }
            }

            // Transcribe any remaining audio as the final result.
            if !accumulated_audio.is_empty() {
                if let Ok(request) = CreateTranscriptionRequestArgs::default()
                    .file(async_openai::types::AudioInput {
                        source: async_openai::types::InputSource::Bytes {
                            filename: "audio_final.mp3".to_string(),
                            bytes: accumulated_audio.into(),
                        },
                    })
                    .model("whisper-1")
                    .language(&language[..2])
                    .response_format(AudioResponseFormat::Json)
                    .build()
                {
                    if let Ok(response) = client.audio().transcribe(request).await {
                        let _ = tx
                            .send(SttStreamResult {
                                text: response.text,
                                is_final: true,
                                confidence: 0.95,
                            })
                            .await;
                    }
                }
            }
        });

        Ok(rx)
    }
}

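// Illustrative sketch of driving the streaming transcription path above.
// `provider`, `config`, and `captured_chunks` are hypothetical bindings; the
// concrete providers in this module are crate-private, so real callers would go
// through whatever API the crate chooses to expose:
//
//     let (audio_tx, audio_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(16);
//     let mut results = provider.transcribe_stream(audio_rx, &config).await?;
//     tokio::spawn(async move {
//         for chunk in captured_chunks {
//             if audio_tx.send(chunk).await.is_err() {
//                 break;
//             }
//         }
//     });
//     while let Some(partial) = results.recv().await {
//         debug!("partial transcript: {} (final: {})", partial.text, partial.is_final);
//     }
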
/// Placeholder Google Cloud Speech-to-Text provider (integration not yet implemented).
struct GoogleSttProvider {
    config: VoiceConfig,
}

impl GoogleSttProvider {
    fn new(config: VoiceConfig) -> Self {
        Self { config }
    }
}

#[async_trait::async_trait]
impl SpeechToTextProvider for GoogleSttProvider {
    async fn transcribe(&self, _audio_data: &[u8], _config: &VoiceConfig) -> Result<SttResult> {
        warn!("Google STT integration not yet implemented");
        Ok(SttResult {
            text: "[Google STT placeholder]".to_string(),
            confidence: 0.90,
            language: Some("en-US".to_string()),
            duration_ms: 1000,
            word_timestamps: vec![],
        })
    }

    async fn transcribe_stream(
        &self,
        _audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
        _config: &VoiceConfig,
    ) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
        // The sender is dropped immediately, so the returned stream ends without results.
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
        Ok(rx)
    }
}

/// Placeholder Azure Speech-to-Text provider (integration not yet implemented).
struct AzureSttProvider {
    config: VoiceConfig,
}

impl AzureSttProvider {
    fn new(config: VoiceConfig) -> Self {
        Self { config }
    }
}

#[async_trait::async_trait]
impl SpeechToTextProvider for AzureSttProvider {
    async fn transcribe(&self, _audio_data: &[u8], _config: &VoiceConfig) -> Result<SttResult> {
        warn!("Azure STT integration not yet implemented");
        Ok(SttResult {
            text: "[Azure STT placeholder]".to_string(),
            confidence: 0.92,
            language: Some("en-US".to_string()),
            duration_ms: 1000,
            word_timestamps: vec![],
        })
    }

    async fn transcribe_stream(
        &self,
        _audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
        _config: &VoiceConfig,
    ) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
        Ok(rx)
    }
}

/// Placeholder local Whisper provider (integration not yet implemented).
struct LocalWhisperProvider {
    config: VoiceConfig,
}

impl LocalWhisperProvider {
    fn new(config: VoiceConfig) -> Self {
        Self { config }
    }
}

#[async_trait::async_trait]
impl SpeechToTextProvider for LocalWhisperProvider {
    async fn transcribe(&self, _audio_data: &[u8], _config: &VoiceConfig) -> Result<SttResult> {
        warn!("Local Whisper integration not yet implemented");
        Ok(SttResult {
            text: "[Local Whisper placeholder]".to_string(),
            confidence: 0.88,
            language: Some("en-US".to_string()),
            duration_ms: 1000,
            word_timestamps: vec![],
        })
    }

    async fn transcribe_stream(
        &self,
        _audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
        _config: &VoiceConfig,
    ) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
        Ok(rx)
    }
}

/// Text-to-speech provider backed by the OpenAI speech API.
struct OpenAITtsProvider {
    config: VoiceConfig,
    client: Client<OpenAIConfig>,
}

impl OpenAITtsProvider {
    fn new(config: VoiceConfig) -> Self {
        let client = Client::new();
        Self { config, client }
    }
}

#[async_trait::async_trait]
impl TextToSpeechProvider for OpenAITtsProvider {
    async fn synthesize(&self, text: &str, config: &VoiceConfig) -> Result<TtsResult> {
        info!(
            "Synthesizing speech with OpenAI TTS (text length: {} chars)",
            text.len()
        );

        let start_time = std::time::Instant::now();

        // Map the configured voice name onto the OpenAI voice enum, falling back
        // to Alloy for unrecognized names.
        let voice = match config.voice.as_str() {
            "alloy" => Voice::Alloy,
            "echo" => Voice::Echo,
            "fable" => Voice::Fable,
            "onyx" => Voice::Onyx,
            "nova" => Voice::Nova,
            "shimmer" => Voice::Shimmer,
            _ => Voice::Alloy,
        };

        let request = CreateSpeechRequest {
            model: SpeechModel::Tts1,
            input: text.to_string(),
            voice,
            response_format: Some(SpeechResponseFormat::Mp3),
            speed: Some(1.0),
        };

        let response = self
            .client
            .audio()
            .speech(request)
            .await
            .context("Failed to synthesize speech with OpenAI TTS")?;

        let duration_ms = start_time.elapsed().as_millis() as u64;

        let audio_data = response.bytes.to_vec();

        debug!(
            "OpenAI TTS synthesis completed: {} bytes (duration: {}ms)",
            audio_data.len(),
            duration_ms
        );

        Ok(TtsResult {
            audio_data,
            format: AudioFormat::Mp3,
            duration_ms,
        })
    }

    async fn synthesize_stream(
        &self,
        text: &str,
        config: &VoiceConfig,
    ) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
        let (tx, rx) = tokio::sync::mpsc::channel(100);

        let client = self.client.clone();
        let text = text.to_string();
        let voice_str = config.voice.clone();

        tokio::spawn(async move {
            let voice = match voice_str.as_str() {
                "alloy" => Voice::Alloy,
                "echo" => Voice::Echo,
                "fable" => Voice::Fable,
                "onyx" => Voice::Onyx,
                "nova" => Voice::Nova,
                "shimmer" => Voice::Shimmer,
                _ => Voice::Alloy,
            };

            // Split the text into sentences and synthesize them one at a time so
            // audio can be streamed back as it becomes available.
            let sentences: Vec<&str> = text
                .split(['.', '!', '?'])
                .filter(|s| !s.trim().is_empty())
                .collect();

            for sentence in sentences {
                let request = CreateSpeechRequest {
                    model: SpeechModel::Tts1,
                    input: sentence.trim().to_string(),
                    voice: voice.clone(),
                    response_format: Some(SpeechResponseFormat::Mp3),
                    speed: Some(1.0),
                };

                match client.audio().speech(request).await {
                    Ok(response) => {
                        let audio_chunk = response.bytes.to_vec();
                        // Stop if the receiver has been dropped.
                        if tx.send(audio_chunk).await.is_err() {
                            break;
                        }
                    }
                    Err(e) => {
                        warn!("Failed to synthesize sentence in streaming mode: {}", e);
                        break;
                    }
                }
            }
        });

        Ok(rx)
    }
}

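// Illustrative sketch of consuming the sentence-by-sentence TTS stream above
// (hypothetical `provider` and `config` bindings, same caveat as the STT sketch):
//
//     let mut audio_rx = provider.synthesize_stream("Hello there. How are you?", &config).await?;
//     let mut mp3_bytes = Vec::new();
//     while let Some(chunk) = audio_rx.recv().await {
//         mp3_bytes.extend_from_slice(&chunk);
//     }
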
/// Placeholder Google Cloud Text-to-Speech provider (integration not yet implemented).
struct GoogleTtsProvider {
    config: VoiceConfig,
}

impl GoogleTtsProvider {
    fn new(config: VoiceConfig) -> Self {
        Self { config }
    }
}

#[async_trait::async_trait]
impl TextToSpeechProvider for GoogleTtsProvider {
    async fn synthesize(&self, text: &str, _config: &VoiceConfig) -> Result<TtsResult> {
        warn!("Google TTS integration not yet implemented");
        Ok(TtsResult {
            audio_data: vec![],
            format: AudioFormat::Mp3,
            // Rough duration estimate derived from the text length.
            duration_ms: (text.len() as u64) * 100,
        })
    }

    async fn synthesize_stream(
        &self,
        _text: &str,
        _config: &VoiceConfig,
    ) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
        Ok(rx)
    }
}

/// Placeholder Azure Text-to-Speech provider (integration not yet implemented).
struct AzureTtsProvider {
    config: VoiceConfig,
}

impl AzureTtsProvider {
    fn new(config: VoiceConfig) -> Self {
        Self { config }
    }
}

#[async_trait::async_trait]
impl TextToSpeechProvider for AzureTtsProvider {
    async fn synthesize(&self, text: &str, _config: &VoiceConfig) -> Result<TtsResult> {
        warn!("Azure TTS integration not yet implemented");
        Ok(TtsResult {
            audio_data: vec![],
            format: AudioFormat::Wav,
            duration_ms: (text.len() as u64) * 100,
        })
    }

    async fn synthesize_stream(
        &self,
        _text: &str,
        _config: &VoiceConfig,
    ) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
        Ok(rx)
    }
}

/// Placeholder local text-to-speech engine (integration not yet implemented).
struct LocalTtsEngine {
    config: VoiceConfig,
}

impl LocalTtsEngine {
    fn new(config: VoiceConfig) -> Self {
        Self { config }
    }
}

#[async_trait::async_trait]
impl TextToSpeechProvider for LocalTtsEngine {
    async fn synthesize(&self, text: &str, _config: &VoiceConfig) -> Result<TtsResult> {
        warn!("Local TTS engine not yet implemented");
        Ok(TtsResult {
            audio_data: vec![],
            format: AudioFormat::Wav,
            duration_ms: (text.len() as u64) * 100,
        })
    }

    async fn synthesize_stream(
        &self,
        _text: &str,
        _config: &VoiceConfig,
    ) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
        Ok(rx)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_voice_interface_creation() {
        let config = VoiceConfig::default();
        let interface = VoiceInterface::new(config);
        assert!(interface.config.enable_stt);
        assert!(interface.config.enable_tts);
    }

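    // Sanity check of the Default impl; expected values mirror VoiceConfig::default() above.
    #[test]
    fn test_voice_config_default_values() {
        let config = VoiceConfig::default();
        assert_eq!(config.stt_provider, SttProviderType::OpenAI);
        assert_eq!(config.tts_provider, TtsProviderType::OpenAI);
        assert_eq!(config.sample_rate, 16000);
        assert_eq!(config.channels, 1);
        assert_eq!(config.max_duration_secs, 300);
        assert_eq!(config.language, "en-US");
        assert_eq!(config.voice, "alloy");
    }
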
    #[tokio::test]
    async fn test_transcribe_disabled() {
        let config = VoiceConfig {
            enable_stt: false,
            ..Default::default()
        };

        let interface = VoiceInterface::new(config);

        let audio_data = vec![0u8; 1000];
        let result = interface.transcribe(&audio_data).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("disabled"));
    }

    #[tokio::test]
    async fn test_synthesize_disabled() {
        let config = VoiceConfig {
            enable_tts: false,
            ..Default::default()
        };

        let interface = VoiceInterface::new(config);

        let text = "Hello, world!";
        let result = interface.synthesize(text).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("disabled"));
    }

    #[test]
    fn test_voice_config_custom() {
        let config = VoiceConfig {
            enable_stt: true,
            enable_tts: false,
            stt_provider: SttProviderType::Google,
            tts_provider: TtsProviderType::Azure,
            sample_rate: 48000,
            channels: 2,
            max_duration_secs: 600,
            language: "ja-JP".to_string(),
            voice: "echo".to_string(),
        };

        assert_eq!(config.sample_rate, 48000);
        assert_eq!(config.channels, 2);
        assert_eq!(config.language, "ja-JP");
        assert_eq!(config.voice, "echo");
        assert!(!config.enable_tts);
    }

    #[test]
    fn test_audio_format_variants() {
        assert!(matches!(AudioFormat::Wav, AudioFormat::Wav));
        assert!(matches!(AudioFormat::Mp3, AudioFormat::Mp3));
    }

    #[test]
    fn test_stt_provider_serialization() {
        let provider = SttProviderType::OpenAI;
        // serde's snake_case renaming turns "OpenAI" into "open_a_i".
        let serialized = serde_json::to_string(&provider).unwrap();
        assert_eq!(serialized, "\"open_a_i\"");

        let deserialized: SttProviderType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, SttProviderType::OpenAI);
    }

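    // Companion check for the multi-word variants under the same `rename_all = "snake_case"` attribute.
    #[test]
    fn test_local_provider_serialization() {
        assert_eq!(
            serde_json::to_string(&SttProviderType::LocalWhisper).unwrap(),
            "\"local_whisper\""
        );
        assert_eq!(
            serde_json::to_string(&TtsProviderType::LocalEngine).unwrap(),
            "\"local_engine\""
        );
    }
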
    #[test]
    fn test_tts_provider_serialization() {
        let provider = TtsProviderType::Google;
        let serialized = serde_json::to_string(&provider).unwrap();
        assert_eq!(serialized, "\"google\"");

        let deserialized: TtsProviderType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, TtsProviderType::Google);
    }
}