use anyhow::{Context, Result};
use async_openai::{
config::OpenAIConfig,
types::audio::{
AudioResponseFormat, CreateSpeechRequest, CreateTranscriptionRequestArgs, SpeechModel,
SpeechResponseFormat, Voice,
},
Client,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, info, warn};
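/// High-level voice facade that dispatches to the speech-to-text and
/// text-to-speech providers selected in [`VoiceConfig`].
///
/// A minimal usage sketch (assumes `OPENAI_API_KEY` is set and `audio_bytes`
/// holds encoded audio the API accepts, e.g. MP3):
///
/// ```ignore
/// let interface = VoiceInterface::new(VoiceConfig::default());
/// let transcript = interface.transcribe(&audio_bytes).await?;
/// let speech = interface.synthesize(&transcript.text).await?;
/// std::fs::write("reply.mp3", &speech.audio_data)?;
/// ```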
pub struct VoiceInterface {
config: VoiceConfig,
stt_provider: Arc<Mutex<dyn SpeechToTextProvider>>,
tts_provider: Arc<Mutex<dyn TextToSpeechProvider>>,
}
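/// Settings shared by all providers. `language` is a BCP-47 tag such as
/// "en-US"; `voice` names a provider-specific preset (for OpenAI: "alloy",
/// "echo", "fable", "onyx", "nova", or "shimmer").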
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoiceConfig {
pub enable_stt: bool,
pub enable_tts: bool,
pub stt_provider: SttProviderType,
pub tts_provider: TtsProviderType,
pub sample_rate: u32,
pub channels: u16,
pub max_duration_secs: u64,
pub language: String,
pub voice: String,
}
impl Default for VoiceConfig {
fn default() -> Self {
Self {
enable_stt: true,
enable_tts: true,
stt_provider: SttProviderType::OpenAI,
tts_provider: TtsProviderType::OpenAI,
sample_rate: 16000,
channels: 1,
            max_duration_secs: 300,
            language: "en-US".to_string(),
voice: "alloy".to_string(),
}
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SttProviderType {
OpenAI,
Google,
Azure,
LocalWhisper,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TtsProviderType {
OpenAI,
Google,
Azure,
LocalEngine,
}
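/// Speech-to-text backend: one-shot transcription of a complete buffer, or
/// incremental transcription of audio chunks arriving over a channel.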
#[async_trait::async_trait]
pub trait SpeechToTextProvider: Send + Sync {
    async fn transcribe(&self, audio_data: &[u8], config: &VoiceConfig) -> Result<SttResult>;
async fn transcribe_stream(
&self,
audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
config: &VoiceConfig,
) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>>;
}
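/// Text-to-speech backend: one-shot synthesis into a single buffer, or
/// streamed synthesis that yields encoded audio chunks as they are ready.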
#[async_trait::async_trait]
pub trait TextToSpeechProvider: Send + Sync {
    async fn synthesize(&self, text: &str, config: &VoiceConfig) -> Result<TtsResult>;
    async fn synthesize_stream(
        &self,
        text: &str,
        config: &VoiceConfig,
) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>>;
}
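// A minimal consumption sketch for the streaming side, assuming `provider`
// implements `TextToSpeechProvider` and `config` is a `VoiceConfig`:
//
//     let mut rx = provider.synthesize_stream("Hi there. How are you?", &config).await?;
//     while let Some(chunk) = rx.recv().await {
//         // feed `chunk` (encoded audio bytes) to a player or file writer
//     }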
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SttResult {
pub text: String,
pub confidence: f32,
pub language: Option<String>,
pub duration_ms: u64,
pub word_timestamps: Vec<WordTimestamp>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SttStreamResult {
pub text: String,
pub is_final: bool,
pub confidence: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WordTimestamp {
pub word: String,
pub start_ms: u64,
pub end_ms: u64,
}
#[derive(Debug, Clone)]
pub struct TtsResult {
pub audio_data: Vec<u8>,
pub format: AudioFormat,
pub duration_ms: u64,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AudioFormat {
Wav,
Mp3,
Opus,
Pcm,
}
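// A small convenience sketch (an addition, not part of the provider
// contract): map each format to its conventional file extension, e.g. when
// persisting `TtsResult::audio_data` to disk.
impl AudioFormat {
    pub fn extension(self) -> &'static str {
        match self {
            AudioFormat::Wav => "wav",
            AudioFormat::Mp3 => "mp3",
            AudioFormat::Opus => "opus",
            AudioFormat::Pcm => "pcm",
        }
    }
}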
impl VoiceInterface {
pub fn new(config: VoiceConfig) -> Self {
let stt_provider: Arc<Mutex<dyn SpeechToTextProvider>> = match config.stt_provider {
SttProviderType::OpenAI => Arc::new(Mutex::new(OpenAISttProvider::new(config.clone()))),
SttProviderType::Google => Arc::new(Mutex::new(GoogleSttProvider::new(config.clone()))),
SttProviderType::Azure => Arc::new(Mutex::new(AzureSttProvider::new(config.clone()))),
SttProviderType::LocalWhisper => {
Arc::new(Mutex::new(LocalWhisperProvider::new(config.clone())))
}
};
let tts_provider: Arc<Mutex<dyn TextToSpeechProvider>> = match config.tts_provider {
TtsProviderType::OpenAI => Arc::new(Mutex::new(OpenAITtsProvider::new(config.clone()))),
TtsProviderType::Google => Arc::new(Mutex::new(GoogleTtsProvider::new(config.clone()))),
TtsProviderType::Azure => Arc::new(Mutex::new(AzureTtsProvider::new(config.clone()))),
TtsProviderType::LocalEngine => {
Arc::new(Mutex::new(LocalTtsEngine::new(config.clone())))
}
};
Self {
config,
stt_provider,
tts_provider,
}
}
pub async fn transcribe(&self, audio_data: &[u8]) -> Result<SttResult> {
if !self.config.enable_stt {
anyhow::bail!("Speech-to-text is disabled");
}
let provider = self.stt_provider.lock().await;
provider.transcribe(audio_data, &self.config).await
}
pub async fn synthesize(&self, text: &str) -> Result<TtsResult> {
if !self.config.enable_tts {
anyhow::bail!("Text-to-speech is disabled");
}
let provider = self.tts_provider.lock().await;
provider.synthesize(text, &self.config).await
}
}
struct OpenAISttProvider {
config: VoiceConfig,
client: Client<OpenAIConfig>,
}
impl OpenAISttProvider {
fn new(config: VoiceConfig) -> Self {
        // `Client::new()` picks up `OPENAI_API_KEY` from the environment.
        let client = Client::new();
Self { config, client }
}
}
#[async_trait::async_trait]
impl SpeechToTextProvider for OpenAISttProvider {
async fn transcribe(&self, audio_data: &[u8], config: &VoiceConfig) -> Result<SttResult> {
info!(
"Transcribing audio with OpenAI Whisper (size: {} bytes)",
audio_data.len()
);
let start_time = std::time::Instant::now();
let request = CreateTranscriptionRequestArgs::default()
            // The filename extension is how the API infers the audio container
            // format, so it should match the bytes actually supplied.
            .file(async_openai::types::audio::AudioInput {
source: async_openai::types::InputSource::Bytes {
filename: "audio.mp3".to_string(),
bytes: audio_data.to_vec().into(),
},
})
.model("whisper-1")
            // Whisper expects an ISO-639-1 code, so take the primary subtag of
            // a BCP-47 tag such as "en-US" (slicing `[..2]` would panic on
            // shorter tags).
            .language(config.language.split('-').next().unwrap_or("en"))
            .response_format(AudioResponseFormat::VerboseJson)
.build()
.context("Failed to build transcription request")?;
let response = self
.client
.audio()
.transcription()
.create(request)
.await
.context("Failed to transcribe audio with OpenAI Whisper")?;
        // Wall-clock latency of the API call, not the duration of the audio.
        let duration_ms = start_time.elapsed().as_millis() as u64;
debug!(
"OpenAI Whisper transcription completed: '{}' (duration: {}ms)",
response.text, duration_ms
);
        // Word-level timestamps require the `timestamp_granularities` request
        // parameter, which is not set here, so leave them empty.
        let word_timestamps = vec![];
Ok(SttResult {
text: response.text,
            // The API does not return an overall confidence score, so report a
            // fixed placeholder.
            confidence: 0.95,
            language: Some(config.language.clone()),
duration_ms,
word_timestamps,
})
}
async fn transcribe_stream(
&self,
mut audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
config: &VoiceConfig,
) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
let (tx, rx) = tokio::sync::mpsc::channel(100);
let client = self.client.clone();
        // Reduce a BCP-47 tag like "en-US" to the ISO-639-1 primary subtag
        // Whisper expects, before moving it into the task.
        let language = config.language.split('-').next().unwrap_or("en").to_string();
tokio::spawn(async move {
let mut accumulated_audio = Vec::new();
while let Some(audio_chunk) = audio_stream.recv().await {
accumulated_audio.extend_from_slice(&audio_chunk);
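                // Flush roughly every five seconds of audio, assuming 16 kHz
                // mono 16-bit PCM (32 000 bytes/s * 5 s = 160 000 bytes).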
if accumulated_audio.len() >= 160_000 {
match CreateTranscriptionRequestArgs::default()
.file(async_openai::types::audio::AudioInput {
source: async_openai::types::InputSource::Bytes {
filename: "audio_chunk.mp3".to_string(),
bytes: accumulated_audio.clone().into(),
},
})
.model("whisper-1")
                        .language(&language)
.response_format(AudioResponseFormat::Json)
.build()
{
Ok(request) => {
if let Ok(response) =
client.audio().transcription().create(request).await
{
let _ = tx
.send(SttStreamResult {
text: response.text,
is_final: false,
confidence: 0.95,
})
.await;
}
}
Err(e) => {
warn!("Failed to create transcription request: {}", e);
}
}
accumulated_audio.clear();
}
}
if !accumulated_audio.is_empty() {
if let Ok(request) = CreateTranscriptionRequestArgs::default()
.file(async_openai::types::audio::AudioInput {
source: async_openai::types::InputSource::Bytes {
filename: "audio_final.mp3".to_string(),
bytes: accumulated_audio.into(),
},
})
.model("whisper-1")
                    .language(&language)
.response_format(AudioResponseFormat::Json)
.build()
{
if let Ok(response) = client.audio().transcription().create(request).await {
let _ = tx
.send(SttStreamResult {
text: response.text,
is_final: true,
confidence: 0.95,
})
.await;
}
}
}
});
Ok(rx)
}
}
struct GoogleSttProvider {
config: VoiceConfig,
}
impl GoogleSttProvider {
fn new(config: VoiceConfig) -> Self {
Self { config }
}
}
#[async_trait::async_trait]
impl SpeechToTextProvider for GoogleSttProvider {
async fn transcribe(&self, _audio_data: &[u8], _config: &VoiceConfig) -> Result<SttResult> {
warn!("Google STT integration not yet implemented");
Ok(SttResult {
text: "[Google STT placeholder]".to_string(),
confidence: 0.90,
language: Some("en-US".to_string()),
duration_ms: 1000,
word_timestamps: vec![],
})
}
async fn transcribe_stream(
&self,
_audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
_config: &VoiceConfig,
) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
        // Placeholder: the sender is dropped immediately, so the returned
        // receiver closes without yielding anything.
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
Ok(rx)
}
}
struct AzureSttProvider {
config: VoiceConfig,
}
impl AzureSttProvider {
fn new(config: VoiceConfig) -> Self {
Self { config }
}
}
#[async_trait::async_trait]
impl SpeechToTextProvider for AzureSttProvider {
async fn transcribe(&self, _audio_data: &[u8], _config: &VoiceConfig) -> Result<SttResult> {
warn!("Azure STT integration not yet implemented");
Ok(SttResult {
text: "[Azure STT placeholder]".to_string(),
confidence: 0.92,
language: Some("en-US".to_string()),
duration_ms: 1000,
word_timestamps: vec![],
})
}
async fn transcribe_stream(
&self,
_audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
_config: &VoiceConfig,
) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
        // Placeholder: the sender is dropped immediately, so the returned
        // receiver closes without yielding anything.
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
Ok(rx)
}
}
struct LocalWhisperProvider {
config: VoiceConfig,
}
impl LocalWhisperProvider {
fn new(config: VoiceConfig) -> Self {
Self { config }
}
}
#[async_trait::async_trait]
impl SpeechToTextProvider for LocalWhisperProvider {
async fn transcribe(&self, _audio_data: &[u8], _config: &VoiceConfig) -> Result<SttResult> {
warn!("Local Whisper integration not yet implemented");
Ok(SttResult {
text: "[Local Whisper placeholder]".to_string(),
confidence: 0.88,
language: Some("en-US".to_string()),
duration_ms: 1000,
word_timestamps: vec![],
})
}
async fn transcribe_stream(
&self,
_audio_stream: tokio::sync::mpsc::Receiver<Vec<u8>>,
_config: &VoiceConfig,
) -> Result<tokio::sync::mpsc::Receiver<SttStreamResult>> {
        // Placeholder: the sender is dropped immediately, so the returned
        // receiver closes without yielding anything.
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
Ok(rx)
}
}
struct OpenAITtsProvider {
config: VoiceConfig,
client: Client<OpenAIConfig>,
}
impl OpenAITtsProvider {
fn new(config: VoiceConfig) -> Self {
        // `Client::new()` picks up `OPENAI_API_KEY` from the environment.
        let client = Client::new();
Self { config, client }
}
}
#[async_trait::async_trait]
impl TextToSpeechProvider for OpenAITtsProvider {
async fn synthesize(&self, text: &str, config: &VoiceConfig) -> Result<TtsResult> {
info!(
"Synthesizing speech with OpenAI TTS (text length: {} chars)",
text.len()
);
let start_time = std::time::Instant::now();
let voice = match config.voice.as_str() {
"alloy" => Voice::Alloy,
"echo" => Voice::Echo,
"fable" => Voice::Fable,
"onyx" => Voice::Onyx,
"nova" => Voice::Nova,
"shimmer" => Voice::Shimmer,
            // Fall back to the default voice for unrecognized names.
            _ => Voice::Alloy,
        };
let request = CreateSpeechRequest {
model: SpeechModel::Tts1,
input: text.to_string(),
voice,
instructions: None,
response_format: Some(SpeechResponseFormat::Mp3),
speed: Some(1.0),
stream_format: None,
};
let response = self
.client
.audio()
.speech()
.create(request)
.await
.context("Failed to synthesize speech with OpenAI TTS")?;
        // Wall-clock latency of the API call, not the duration of the audio.
        let duration_ms = start_time.elapsed().as_millis() as u64;
let audio_data = response.bytes.to_vec();
debug!(
"OpenAI TTS synthesis completed: {} bytes (duration: {}ms)",
audio_data.len(),
duration_ms
);
Ok(TtsResult {
audio_data,
format: AudioFormat::Mp3,
duration_ms,
})
}
async fn synthesize_stream(
&self,
text: &str,
config: &VoiceConfig,
) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
let (tx, rx) = tokio::sync::mpsc::channel(100);
let client = self.client.clone();
let text = text.to_string();
let voice_str = config.voice.clone();
tokio::spawn(async move {
let voice = match voice_str.as_str() {
"alloy" => Voice::Alloy,
"echo" => Voice::Echo,
"fable" => Voice::Fable,
"onyx" => Voice::Onyx,
"nova" => Voice::Nova,
"shimmer" => Voice::Shimmer,
_ => Voice::Alloy,
};
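            // Naive sentence segmentation: split on terminal punctuation (the
            // delimiters themselves are dropped) so each sentence can be
            // synthesized and streamed independently.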
let sentences: Vec<&str> = text
.split(['.', '!', '?'])
.filter(|s| !s.trim().is_empty())
.collect();
for sentence in sentences {
let request = CreateSpeechRequest {
model: SpeechModel::Tts1,
input: sentence.trim().to_string(),
voice: voice.clone(),
instructions: None,
response_format: Some(SpeechResponseFormat::Mp3),
speed: Some(1.0),
stream_format: None,
};
match client.audio().speech().create(request).await {
Ok(response) => {
let audio_chunk = response.bytes.to_vec();
if tx.send(audio_chunk).await.is_err() {
                            // Receiver dropped; stop synthesizing further sentences.
                            break;
                        }
}
Err(e) => {
warn!("Failed to synthesize sentence in streaming mode: {}", e);
break;
}
}
}
});
Ok(rx)
}
}
struct GoogleTtsProvider {
config: VoiceConfig,
}
impl GoogleTtsProvider {
fn new(config: VoiceConfig) -> Self {
Self { config }
}
}
#[async_trait::async_trait]
impl TextToSpeechProvider for GoogleTtsProvider {
async fn synthesize(&self, text: &str, _config: &VoiceConfig) -> Result<TtsResult> {
warn!("Google TTS integration not yet implemented");
Ok(TtsResult {
audio_data: vec![],
format: AudioFormat::Mp3,
            // Rough placeholder estimate: ~100 ms per character.
            duration_ms: (text.len() as u64) * 100,
})
}
async fn synthesize_stream(
&self,
_text: &str,
_config: &VoiceConfig,
) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
        // Placeholder: the sender is dropped immediately, so the returned
        // receiver closes without yielding anything.
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
Ok(rx)
}
}
struct AzureTtsProvider {
config: VoiceConfig,
}
impl AzureTtsProvider {
fn new(config: VoiceConfig) -> Self {
Self { config }
}
}
#[async_trait::async_trait]
impl TextToSpeechProvider for AzureTtsProvider {
async fn synthesize(&self, text: &str, _config: &VoiceConfig) -> Result<TtsResult> {
warn!("Azure TTS integration not yet implemented");
Ok(TtsResult {
audio_data: vec![],
format: AudioFormat::Wav,
            // Rough placeholder estimate: ~100 ms per character.
            duration_ms: (text.len() as u64) * 100,
})
}
async fn synthesize_stream(
&self,
_text: &str,
_config: &VoiceConfig,
) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
        // Placeholder: the sender is dropped immediately, so the returned
        // receiver closes without yielding anything.
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
Ok(rx)
}
}
struct LocalTtsEngine {
config: VoiceConfig,
}
impl LocalTtsEngine {
fn new(config: VoiceConfig) -> Self {
Self { config }
}
}
#[async_trait::async_trait]
impl TextToSpeechProvider for LocalTtsEngine {
async fn synthesize(&self, text: &str, _config: &VoiceConfig) -> Result<TtsResult> {
warn!("Local TTS engine not yet implemented");
Ok(TtsResult {
audio_data: vec![],
format: AudioFormat::Wav,
            // Rough placeholder estimate: ~100 ms per character.
            duration_ms: (text.len() as u64) * 100,
})
}
async fn synthesize_stream(
&self,
_text: &str,
_config: &VoiceConfig,
) -> Result<tokio::sync::mpsc::Receiver<Vec<u8>>> {
        // Placeholder: the sender is dropped immediately, so the returned
        // receiver closes without yielding anything.
        let (_tx, rx) = tokio::sync::mpsc::channel(100);
Ok(rx)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_voice_interface_creation() {
let config = VoiceConfig::default();
let interface = VoiceInterface::new(config);
assert!(interface.config.enable_stt);
assert!(interface.config.enable_tts);
}
#[tokio::test]
async fn test_transcribe_disabled() {
let config = VoiceConfig {
enable_stt: false,
..Default::default()
};
let interface = VoiceInterface::new(config);
let audio_data = vec![0u8; 1000];
let result = interface.transcribe(&audio_data).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("disabled"));
}
#[tokio::test]
async fn test_synthesize_disabled() {
let config = VoiceConfig {
enable_tts: false,
..Default::default()
};
let interface = VoiceInterface::new(config);
let text = "Hello, world!";
let result = interface.synthesize(text).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("disabled"));
}
#[test]
fn test_voice_config_custom() {
let config = VoiceConfig {
enable_stt: true,
enable_tts: false,
stt_provider: SttProviderType::Google,
tts_provider: TtsProviderType::Azure,
sample_rate: 48000,
channels: 2,
max_duration_secs: 600,
language: "ja-JP".to_string(),
voice: "echo".to_string(),
};
assert_eq!(config.sample_rate, 48000);
assert_eq!(config.channels, 2);
assert_eq!(config.language, "ja-JP");
assert_eq!(config.voice, "echo");
assert!(!config.enable_tts);
}
#[test]
    fn test_audio_format_serialization() {
        // `rename_all = "snake_case"` lowercases the single-word variants.
        let wav = serde_json::to_string(&AudioFormat::Wav).expect("should succeed");
        assert_eq!(wav, "\"wav\"");
        let mp3 = serde_json::to_string(&AudioFormat::Mp3).expect("should succeed");
        assert_eq!(mp3, "\"mp3\"");
    }
#[test]
fn test_stt_provider_serialization() {
let provider = SttProviderType::OpenAI;
        let serialized = serde_json::to_string(&provider).expect("should succeed");
        // serde's snake_case puts an underscore before every capital, so
        // `OpenAI` serializes as "open_a_i".
        assert_eq!(serialized, "\"open_a_i\"");
let deserialized: SttProviderType =
serde_json::from_str(&serialized).expect("should succeed");
assert_eq!(deserialized, SttProviderType::OpenAI);
}
#[test]
fn test_tts_provider_serialization() {
let provider = TtsProviderType::Google;
let serialized = serde_json::to_string(&provider).expect("should succeed");
assert_eq!(serialized, "\"google\"");
let deserialized: TtsProviderType =
serde_json::from_str(&serialized).expect("should succeed");
assert_eq!(deserialized, TtsProviderType::Google);
}
}