1use anyhow::{Context, Result};
7use async_openai::{
8 config::OpenAIConfig,
9 types::audio::{
10 AudioResponseFormat, CreateSpeechRequest, CreateTranscriptionRequestArgs, SpeechModel,
11 SpeechResponseFormat, Voice,
12 },
13 Client,
14};
15use serde::{Deserialize, Serialize};
16use std::sync::Arc;
17use tokio::sync::Mutex;
18use tracing::{debug, info, warn};
19
/// High-level facade over pluggable speech-to-text and text-to-speech backends.
///
/// The concrete providers are selected from `VoiceConfig` at construction time
/// and stored behind trait objects so backends can be swapped without touching
/// callers.
pub struct VoiceInterface {
    /// Configuration passed to every provider call.
    config: VoiceConfig,
    /// Active speech-to-text backend.
    // NOTE(review): both provider traits take `&self` only, so the Mutex
    // serializes all concurrent calls; `Arc<dyn SpeechToTextProvider>` would
    // likely suffice — confirm before changing.
    stt_provider: Arc<Mutex<dyn SpeechToTextProvider>>,
    /// Active text-to-speech backend.
    tts_provider: Arc<Mutex<dyn TextToSpeechProvider>>,
}
26
/// Tunable settings for the voice pipeline; serializable for config files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoiceConfig {
    /// Whether speech-to-text calls are allowed (`transcribe` bails otherwise).
    pub enable_stt: bool,
    /// Whether text-to-speech calls are allowed (`synthesize` bails otherwise).
    pub enable_tts: bool,
    /// Which STT backend `VoiceInterface::new` instantiates.
    pub stt_provider: SttProviderType,
    /// Which TTS backend `VoiceInterface::new` instantiates.
    pub tts_provider: TtsProviderType,
    /// Audio sample rate in Hz (default 16000).
    // NOTE(review): not consulted by any provider in this file — presumably
    // honored by the capture layer; confirm.
    pub sample_rate: u32,
    /// Number of audio channels (default 1, i.e. mono).
    pub channels: u16,
    /// Maximum recording duration in seconds (default 300).
    // NOTE(review): not enforced anywhere in this file — confirm the caller does.
    pub max_duration_secs: u64,
    /// Language tag such as "en-US"; the OpenAI STT provider sends only the
    /// leading two bytes to Whisper.
    pub language: String,
    /// TTS voice name; unrecognized names fall back to "alloy" in the OpenAI
    /// TTS provider.
    pub voice: String,
}
49
50impl Default for VoiceConfig {
51 fn default() -> Self {
52 Self {
53 enable_stt: true,
54 enable_tts: true,
55 stt_provider: SttProviderType::OpenAI,
56 tts_provider: TtsProviderType::OpenAI,
57 sample_rate: 16000,
58 channels: 1,
59 max_duration_secs: 300, language: "en-US".to_string(),
61 voice: "alloy".to_string(),
62 }
63 }
64}
65
/// Selects the speech-to-text backend.
///
/// Serialized in snake_case; note that `OpenAI` becomes `"open_a_i"` under
/// serde's conversion (see the serialization tests at the bottom of the file).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SttProviderType {
    /// OpenAI Whisper API (the only fully implemented backend in this file).
    OpenAI,
    /// Google Cloud STT — placeholder implementation only.
    Google,
    /// Azure STT — placeholder implementation only.
    Azure,
    /// On-device Whisper — placeholder implementation only.
    LocalWhisper,
}
79
/// Selects the text-to-speech backend.
///
/// Serialized in snake_case; `OpenAI` serializes as `"open_a_i"` under serde's
/// conversion rules.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TtsProviderType {
    /// OpenAI TTS API (the only fully implemented backend in this file).
    OpenAI,
    /// Google Cloud TTS — placeholder implementation only.
    Google,
    /// Azure TTS — placeholder implementation only.
    Azure,
    /// Local synthesis engine — placeholder implementation only.
    LocalEngine,
}
93
94#[async_trait::async_trait]
96pub trait SpeechToTextProvider: Send + Sync {
97 async fn transcribe(&self, audio_data: &[u8], _config: &VoiceConfig) -> Result<SttResult>;
99
100 async fn transcribe_stream(
102 &self,
103 audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
104 config: &VoiceConfig,
105 ) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>>;
106}
107
108#[async_trait::async_trait]
110pub trait TextToSpeechProvider: Send + Sync {
111 async fn synthesize(&self, text: &str, _config: &VoiceConfig) -> Result<TtsResult>;
113
114 async fn synthesize_stream(
116 &self,
117 _text: &str,
118 _config: &VoiceConfig,
119 ) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>>;
120}
121
/// Outcome of a completed (non-streaming) transcription.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SttResult {
    /// The recognized text.
    pub text: String,
    /// Confidence in [0, 1].
    // NOTE(review): the OpenAI provider hard-codes 0.95 because the Whisper
    // API does not report confidence — treat as an estimate.
    pub confidence: f32,
    /// Detected or assumed language tag, if known.
    pub language: Option<String>,
    /// Elapsed time of the transcription call in milliseconds (request
    /// round-trip, not audio length, in the OpenAI implementation).
    pub duration_ms: u64,
    /// Per-word timing, when the backend provides it (currently always empty
    /// in the providers in this file).
    pub word_timestamps: Vec<WordTimestamp>,
}
136
/// One interim or final segment emitted during streaming transcription.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SttStreamResult {
    /// Text recognized for this segment.
    pub text: String,
    /// True for the last segment of the stream; false for interim updates.
    pub is_final: bool,
    /// Confidence in [0, 1] (hard-coded 0.95 in the OpenAI provider).
    pub confidence: f32,
}
147
/// Timing of a single recognized word within the source audio.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WordTimestamp {
    /// The word as recognized.
    pub word: String,
    /// Offset of the word's start from the beginning of the audio, in ms.
    pub start_ms: u64,
    /// Offset of the word's end from the beginning of the audio, in ms.
    pub end_ms: u64,
}
155
/// Outcome of a completed (non-streaming) speech synthesis.
#[derive(Debug, Clone)]
pub struct TtsResult {
    /// Encoded audio bytes in `format`.
    pub audio_data: Vec<u8>,
    /// Encoding of `audio_data`.
    pub format: AudioFormat,
    /// Elapsed time in milliseconds.
    // NOTE(review): in the OpenAI provider this is the request round-trip
    // time, not the playback duration of the audio — confirm intent.
    pub duration_ms: u64,
}
166
/// Supported audio container/encoding formats for synthesized output.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AudioFormat {
    Wav,
    Mp3,
    Opus,
    /// Raw pulse-code-modulated samples (no container).
    Pcm,
}
176
177impl VoiceInterface {
178 pub fn new(config: VoiceConfig) -> Self {
180 let stt_provider: Arc<Mutex<dyn SpeechToTextProvider>> = match config.stt_provider {
181 SttProviderType::OpenAI => Arc::new(Mutex::new(OpenAISttProvider::new(config.clone()))),
182 SttProviderType::Google => Arc::new(Mutex::new(GoogleSttProvider::new(config.clone()))),
183 SttProviderType::Azure => Arc::new(Mutex::new(AzureSttProvider::new(config.clone()))),
184 SttProviderType::LocalWhisper => {
185 Arc::new(Mutex::new(LocalWhisperProvider::new(config.clone())))
186 }
187 };
188
189 let tts_provider: Arc<Mutex<dyn TextToSpeechProvider>> = match config.tts_provider {
190 TtsProviderType::OpenAI => Arc::new(Mutex::new(OpenAITtsProvider::new(config.clone()))),
191 TtsProviderType::Google => Arc::new(Mutex::new(GoogleTtsProvider::new(config.clone()))),
192 TtsProviderType::Azure => Arc::new(Mutex::new(AzureTtsProvider::new(config.clone()))),
193 TtsProviderType::LocalEngine => {
194 Arc::new(Mutex::new(LocalTtsEngine::new(config.clone())))
195 }
196 };
197
198 Self {
199 config,
200 stt_provider,
201 tts_provider,
202 }
203 }
204
205 pub async fn transcribe(&self, audio_data: &[u8]) -> Result<SttResult> {
207 if !self.config.enable_stt {
208 anyhow::bail!("Speech-to-text is disabled");
209 }
210
211 let provider = self.stt_provider.lock().await;
212 provider.transcribe(audio_data, &self.config).await
213 }
214
215 pub async fn synthesize(&self, text: &str) -> Result<TtsResult> {
217 if !self.config.enable_tts {
218 anyhow::bail!("Text-to-speech is disabled");
219 }
220
221 let provider = self.tts_provider.lock().await;
222 provider.synthesize(text, &self.config).await
223 }
224}
225
/// Speech-to-text backend using the OpenAI Whisper API.
struct OpenAISttProvider {
    /// Stored configuration snapshot.
    // NOTE(review): the trait methods receive a `config` argument and do not
    // read this field — confirm whether it is still needed.
    config: VoiceConfig,
    /// OpenAI API client (credentials come from the default client config,
    /// e.g. environment variables).
    client: Client<OpenAIConfig>,
}

impl OpenAISttProvider {
    /// Creates the provider with a default-configured OpenAI client.
    fn new(config: VoiceConfig) -> Self {
        let client = Client::new();
        Self { config, client }
    }
}
240
241#[async_trait::async_trait]
242impl SpeechToTextProvider for OpenAISttProvider {
243 async fn transcribe(&self, audio_data: &[u8], config: &VoiceConfig) -> Result<SttResult> {
244 info!(
245 "Transcribing audio with OpenAI Whisper (size: {} bytes)",
246 audio_data.len()
247 );
248
249 let start_time = std::time::Instant::now();
250
251 let request = CreateTranscriptionRequestArgs::default()
253 .file(async_openai::types::audio::AudioInput {
254 source: async_openai::types::InputSource::Bytes {
255 filename: "audio.mp3".to_string(),
256 bytes: audio_data.to_vec().into(),
257 },
258 })
259 .model("whisper-1")
260 .language(&config.language[..2]) .response_format(AudioResponseFormat::VerboseJson)
262 .build()
263 .context("Failed to build transcription request")?;
264
265 let response = self
267 .client
268 .audio()
269 .transcription()
270 .create(request)
271 .await
272 .context("Failed to transcribe audio with OpenAI Whisper")?;
273
274 let duration_ms = start_time.elapsed().as_millis() as u64;
275
276 debug!(
277 "OpenAI Whisper transcription completed: '{}' (duration: {}ms)",
278 response.text, duration_ms
279 );
280
281 let word_timestamps = vec![]; Ok(SttResult {
285 text: response.text,
286 confidence: 0.95, language: Some(config.language.clone()),
288 duration_ms,
289 word_timestamps,
290 })
291 }
292
293 async fn transcribe_stream(
294 &self,
295 mut audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
296 config: &VoiceConfig,
297 ) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
298 let (tx, rx) = tokio::sync::mpsc::channel(100);
299
300 let client = self.client.clone();
301 let language = config.language.clone();
302
303 tokio::spawn(async move {
306 let mut accumulated_audio = Vec::new();
307
308 while let Some(audio_chunk) = audio_stream.recv().await {
309 accumulated_audio.extend_from_slice(&audio_chunk);
310
311 if accumulated_audio.len() >= 160_000 {
314 match CreateTranscriptionRequestArgs::default()
316 .file(async_openai::types::audio::AudioInput {
317 source: async_openai::types::InputSource::Bytes {
318 filename: "audio_chunk.mp3".to_string(),
319 bytes: accumulated_audio.clone().into(),
320 },
321 })
322 .model("whisper-1")
323 .language(&language[..2])
324 .response_format(AudioResponseFormat::Json)
325 .build()
326 {
327 Ok(request) => {
328 if let Ok(response) =
329 client.audio().transcription().create(request).await
330 {
331 let _ = tx
332 .send(SttStreamResult {
333 text: response.text,
334 is_final: false,
335 confidence: 0.95,
336 })
337 .await;
338 }
339 }
340 Err(e) => {
341 warn!("Failed to create transcription request: {}", e);
342 }
343 }
344
345 accumulated_audio.clear();
347 }
348 }
349
350 if !accumulated_audio.is_empty() {
352 if let Ok(request) = CreateTranscriptionRequestArgs::default()
353 .file(async_openai::types::audio::AudioInput {
354 source: async_openai::types::InputSource::Bytes {
355 filename: "audio_final.mp3".to_string(),
356 bytes: accumulated_audio.into(),
357 },
358 })
359 .model("whisper-1")
360 .language(&language[..2])
361 .response_format(AudioResponseFormat::Json)
362 .build()
363 {
364 if let Ok(response) = client.audio().transcription().create(request).await {
365 let _ = tx
366 .send(SttStreamResult {
367 text: response.text,
368 is_final: true,
369 confidence: 0.95,
370 })
371 .await;
372 }
373 }
374 }
375 });
376
377 Ok(rx)
378 }
379}
380
381struct GoogleSttProvider {
383 config: VoiceConfig,
384}
385
386impl GoogleSttProvider {
387 fn new(config: VoiceConfig) -> Self {
388 Self { config }
389 }
390}
391
392#[async_trait::async_trait]
393impl SpeechToTextProvider for GoogleSttProvider {
394 async fn transcribe(&self, _audio_data: &[u8], _config: &VoiceConfig) -> Result<SttResult> {
395 warn!("Google STT integration not yet implemented");
396 Ok(SttResult {
397 text: "[Google STT placeholder]".to_string(),
398 confidence: 0.90,
399 language: Some("en-US".to_string()),
400 duration_ms: 1000,
401 word_timestamps: vec![],
402 })
403 }
404
405 async fn transcribe_stream(
406 &self,
407 _audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
408 _config: &VoiceConfig,
409 ) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
410 let (_tx, rx) = tokio::sync::mpsc::channel(100);
411 Ok(rx)
412 }
413}
414
415struct AzureSttProvider {
417 config: VoiceConfig,
418}
419
420impl AzureSttProvider {
421 fn new(config: VoiceConfig) -> Self {
422 Self { config }
423 }
424}
425
426#[async_trait::async_trait]
427impl SpeechToTextProvider for AzureSttProvider {
428 async fn transcribe(&self, _audio_data: &[u8], _config: &VoiceConfig) -> Result<SttResult> {
429 warn!("Azure STT integration not yet implemented");
430 Ok(SttResult {
431 text: "[Azure STT placeholder]".to_string(),
432 confidence: 0.92,
433 language: Some("en-US".to_string()),
434 duration_ms: 1000,
435 word_timestamps: vec![],
436 })
437 }
438
439 async fn transcribe_stream(
440 &self,
441 _audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
442 _config: &VoiceConfig,
443 ) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
444 let (_tx, rx) = tokio::sync::mpsc::channel(100);
445 Ok(rx)
446 }
447}
448
449struct LocalWhisperProvider {
451 config: VoiceConfig,
452}
453
454impl LocalWhisperProvider {
455 fn new(config: VoiceConfig) -> Self {
456 Self { config }
457 }
458}
459
460#[async_trait::async_trait]
461impl SpeechToTextProvider for LocalWhisperProvider {
462 async fn transcribe(&self, _audio_data: &[u8], _config: &VoiceConfig) -> Result<SttResult> {
463 warn!("Local Whisper integration not yet implemented");
464 Ok(SttResult {
465 text: "[Local Whisper placeholder]".to_string(),
466 confidence: 0.88,
467 language: Some("en-US".to_string()),
468 duration_ms: 1000,
469 word_timestamps: vec![],
470 })
471 }
472
473 async fn transcribe_stream(
474 &self,
475 _audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
476 _config: &VoiceConfig,
477 ) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
478 let (_tx, rx) = tokio::sync::mpsc::channel(100);
479 Ok(rx)
480 }
481}
482
/// Text-to-speech backend using the OpenAI TTS API.
struct OpenAITtsProvider {
    /// Stored configuration snapshot.
    // NOTE(review): the trait methods receive a `config` argument and do not
    // read this field — confirm whether it is still needed.
    config: VoiceConfig,
    /// OpenAI API client (credentials come from the default client config,
    /// e.g. environment variables).
    client: Client<OpenAIConfig>,
}

impl OpenAITtsProvider {
    /// Creates the provider with a default-configured OpenAI client.
    fn new(config: VoiceConfig) -> Self {
        let client = Client::new();
        Self { config, client }
    }
}
495
496#[async_trait::async_trait]
497impl TextToSpeechProvider for OpenAITtsProvider {
498 async fn synthesize(&self, text: &str, config: &VoiceConfig) -> Result<TtsResult> {
499 info!(
500 "Synthesizing speech with OpenAI TTS (text length: {} chars)",
501 text.len()
502 );
503
504 let start_time = std::time::Instant::now();
505
506 let voice = match config.voice.as_str() {
508 "alloy" => Voice::Alloy,
509 "echo" => Voice::Echo,
510 "fable" => Voice::Fable,
511 "onyx" => Voice::Onyx,
512 "nova" => Voice::Nova,
513 "shimmer" => Voice::Shimmer,
514 _ => Voice::Alloy, };
516
517 let request = CreateSpeechRequest {
519 model: SpeechModel::Tts1,
520 input: text.to_string(),
521 voice,
522 instructions: None,
523 response_format: Some(SpeechResponseFormat::Mp3),
524 speed: Some(1.0),
525 stream_format: None,
526 };
527
528 let response = self
530 .client
531 .audio()
532 .speech()
533 .create(request)
534 .await
535 .context("Failed to synthesize speech with OpenAI TTS")?;
536
537 let duration_ms = start_time.elapsed().as_millis() as u64;
538
539 let audio_data = response.bytes.to_vec();
541
542 debug!(
543 "OpenAI TTS synthesis completed: {} bytes (duration: {}ms)",
544 audio_data.len(),
545 duration_ms
546 );
547
548 Ok(TtsResult {
549 audio_data,
550 format: AudioFormat::Mp3,
551 duration_ms,
552 })
553 }
554
555 async fn synthesize_stream(
556 &self,
557 text: &str,
558 config: &VoiceConfig,
559 ) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
560 let (tx, rx) = tokio::sync::mpsc::channel(100);
561
562 let client = self.client.clone();
563 let text = text.to_string();
564 let voice_str = config.voice.clone();
565
566 tokio::spawn(async move {
568 let voice = match voice_str.as_str() {
570 "alloy" => Voice::Alloy,
571 "echo" => Voice::Echo,
572 "fable" => Voice::Fable,
573 "onyx" => Voice::Onyx,
574 "nova" => Voice::Nova,
575 "shimmer" => Voice::Shimmer,
576 _ => Voice::Alloy,
577 };
578
579 let sentences: Vec<&str> = text
581 .split(['.', '!', '?'])
582 .filter(|s| !s.trim().is_empty())
583 .collect();
584
585 for sentence in sentences {
586 let request = CreateSpeechRequest {
587 model: SpeechModel::Tts1,
588 input: sentence.trim().to_string(),
589 voice: voice.clone(),
590 instructions: None,
591 response_format: Some(SpeechResponseFormat::Mp3),
592 speed: Some(1.0),
593 stream_format: None,
594 };
595
596 match client.audio().speech().create(request).await {
597 Ok(response) => {
598 let audio_chunk = response.bytes.to_vec();
599 if tx.send(audio_chunk).await.is_err() {
600 break; }
602 }
603 Err(e) => {
604 warn!("Failed to synthesize sentence in streaming mode: {}", e);
605 break;
606 }
607 }
608 }
609 });
610
611 Ok(rx)
612 }
613}
614
615struct GoogleTtsProvider {
617 config: VoiceConfig,
618}
619
620impl GoogleTtsProvider {
621 fn new(config: VoiceConfig) -> Self {
622 Self { config }
623 }
624}
625
626#[async_trait::async_trait]
627impl TextToSpeechProvider for GoogleTtsProvider {
628 async fn synthesize(&self, text: &str, _config: &VoiceConfig) -> Result<TtsResult> {
629 warn!("Google TTS integration not yet implemented");
630 Ok(TtsResult {
631 audio_data: vec![],
632 format: AudioFormat::Mp3,
633 duration_ms: (text.len() as u64) * 100,
634 })
635 }
636
637 async fn synthesize_stream(
638 &self,
639 _text: &str,
640 _config: &VoiceConfig,
641 ) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
642 let (_tx, rx) = tokio::sync::mpsc::channel(100);
643 Ok(rx)
644 }
645}
646
647struct AzureTtsProvider {
649 config: VoiceConfig,
650}
651
652impl AzureTtsProvider {
653 fn new(config: VoiceConfig) -> Self {
654 Self { config }
655 }
656}
657
658#[async_trait::async_trait]
659impl TextToSpeechProvider for AzureTtsProvider {
660 async fn synthesize(&self, text: &str, _config: &VoiceConfig) -> Result<TtsResult> {
661 warn!("Azure TTS integration not yet implemented");
662 Ok(TtsResult {
663 audio_data: vec![],
664 format: AudioFormat::Wav,
665 duration_ms: (text.len() as u64) * 100,
666 })
667 }
668
669 async fn synthesize_stream(
670 &self,
671 _text: &str,
672 _config: &VoiceConfig,
673 ) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
674 let (_tx, rx) = tokio::sync::mpsc::channel(100);
675 Ok(rx)
676 }
677}
678
679struct LocalTtsEngine {
681 config: VoiceConfig,
682}
683
684impl LocalTtsEngine {
685 fn new(config: VoiceConfig) -> Self {
686 Self { config }
687 }
688}
689
690#[async_trait::async_trait]
691impl TextToSpeechProvider for LocalTtsEngine {
692 async fn synthesize(&self, text: &str, _config: &VoiceConfig) -> Result<TtsResult> {
693 warn!("Local TTS engine not yet implemented");
694 Ok(TtsResult {
695 audio_data: vec![],
696 format: AudioFormat::Wav,
697 duration_ms: (text.len() as u64) * 100,
698 })
699 }
700
701 async fn synthesize_stream(
702 &self,
703 _text: &str,
704 _config: &VoiceConfig,
705 ) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
706 let (_tx, rx) = tokio::sync::mpsc::channel(100);
707 Ok(rx)
708 }
709}
710
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config should leave both directions enabled.
    #[tokio::test]
    async fn test_voice_interface_creation() {
        let interface = VoiceInterface::new(VoiceConfig::default());

        assert!(interface.config.enable_stt);
        assert!(interface.config.enable_tts);
    }

    /// Transcription must be rejected when STT is disabled.
    #[tokio::test]
    async fn test_transcribe_disabled() {
        let interface = VoiceInterface::new(VoiceConfig {
            enable_stt: false,
            ..Default::default()
        });

        let result = interface.transcribe(&[0u8; 1000]).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("disabled"));
    }

    /// Synthesis must be rejected when TTS is disabled.
    #[tokio::test]
    async fn test_synthesize_disabled() {
        let interface = VoiceInterface::new(VoiceConfig {
            enable_tts: false,
            ..Default::default()
        });

        let result = interface.synthesize("Hello, world!").await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("disabled"));
    }

    /// A fully customized config should round-trip its field values.
    #[test]
    fn test_voice_config_custom() {
        let config = VoiceConfig {
            enable_stt: true,
            enable_tts: false,
            stt_provider: SttProviderType::Google,
            tts_provider: TtsProviderType::Azure,
            sample_rate: 48_000,
            channels: 2,
            max_duration_secs: 600,
            language: "ja-JP".to_string(),
            voice: "echo".to_string(),
        };

        assert_eq!(config.sample_rate, 48000);
        assert_eq!(config.channels, 2);
        assert_eq!(config.language, "ja-JP");
        assert_eq!(config.voice, "echo");
        assert!(!config.enable_tts);
    }

    #[test]
    fn test_audio_format_variants() {
        assert!(matches!(AudioFormat::Wav, AudioFormat::Wav));
        assert!(matches!(AudioFormat::Mp3, AudioFormat::Mp3));
    }

    /// serde's snake_case conversion renders `OpenAI` as "open_a_i"
    /// (an underscore before each uppercase letter).
    #[test]
    fn test_stt_provider_serialization() {
        let serialized =
            serde_json::to_string(&SttProviderType::OpenAI).expect("should succeed");
        assert_eq!(serialized, "\"open_a_i\"");

        let deserialized: SttProviderType =
            serde_json::from_str(&serialized).expect("should succeed");
        assert_eq!(deserialized, SttProviderType::OpenAI);
    }

    #[test]
    fn test_tts_provider_serialization() {
        let serialized =
            serde_json::to_string(&TtsProviderType::Google).expect("should succeed");
        assert_eq!(serialized, "\"google\"");

        let deserialized: TtsProviderType =
            serde_json::from_str(&serialized).expect("should succeed");
        assert_eq!(deserialized, TtsProviderType::Google);
    }
}