use crate::{
AudioChunk, AudioFormat, ModelInfo, STTResult, SpeechRequest, SpeechResponse, TTSResult,
TextChunk, TranscriptionRequest, TranscriptionResponse,
};
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
pub trait TTSProvider: TTSSpeechProvider + TTSModelsProvider + Send + Sync {}
#[async_trait]
pub trait TTSSpeechProvider: Send + Sync {
async fn generate_speech(&self, request: SpeechRequest) -> TTSResult<SpeechResponse>;
async fn generate_speech_stream<'a>(
&'a self,
_request: SpeechRequest,
) -> TTSResult<Pin<Box<dyn Stream<Item = TTSResult<AudioChunk>> + Send + 'a>>> {
Err(crate::error::TTSError::StreamingNotSupported(
"Not Supported".to_string(),
))
}
fn supports_streaming(&self) -> bool {
false
}
fn supported_formats(&self) -> Vec<AudioFormat> {
vec![AudioFormat::Wav]
}
fn default_sample_rate(&self) -> u32 {
24000
}
}
#[async_trait]
pub trait TTSModelsProvider: Send + Sync {
async fn list_models(&self) -> TTSResult<Vec<ModelInfo>> {
Ok(vec![])
}
fn get_current_model(&self) -> ModelInfo;
fn supported_languages(&self) -> Vec<String> {
vec!["en".to_string()]
}
}
pub trait STTProvider: STTSpeechProvider + STTModelsProvider + Send + Sync {}
#[async_trait]
pub trait STTSpeechProvider: Send + Sync {
async fn transcribe(&self, request: TranscriptionRequest) -> STTResult<TranscriptionResponse>;
async fn transcribe_stream<'a>(
&'a self,
_request: TranscriptionRequest,
) -> STTResult<Pin<Box<dyn Stream<Item = STTResult<TextChunk>> + Send + 'a>>> {
Err(crate::error::STTError::StreamingNotSupported(
"Not Supported".to_string(),
))
}
fn supports_streaming(&self) -> bool {
false
}
fn supported_sample_rate(&self) -> u32 {
16000
}
fn supported_channels(&self) -> u16 {
1
}
fn supports_timestamps(&self) -> bool {
false
}
}
#[async_trait]
pub trait STTModelsProvider: Send + Sync {
async fn list_models(&self) -> STTResult<Vec<ModelInfo>> {
Ok(vec![])
}
fn get_current_model(&self) -> ModelInfo;
fn supported_languages(&self) -> Vec<String> {
vec!["en".to_string()]
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
AudioData, AudioFormat, ModelInfo, SharedAudioData, SpeechRequest, SpeechResponse,
TranscriptionRequest, TranscriptionResponse, VoiceIdentifier,
};
use async_trait::async_trait;
#[derive(Debug)]
struct DummyProvider;
#[async_trait]
impl TTSSpeechProvider for DummyProvider {
async fn generate_speech(&self, request: SpeechRequest) -> TTSResult<SpeechResponse> {
Ok(SpeechResponse {
audio: AudioData {
samples: vec![0.0],
channels: 1,
sample_rate: request.sample_rate.unwrap_or(24000),
},
text: request.text,
duration_ms: 0,
})
}
}
#[async_trait]
impl TTSModelsProvider for DummyProvider {
fn get_current_model(&self) -> ModelInfo {
ModelInfo {
id: "dummy".to_string(),
name: "Dummy".to_string(),
description: None,
languages: vec!["en".to_string()],
}
}
}
impl TTSProvider for DummyProvider {}
#[tokio::test]
async fn test_default_streaming_not_supported() {
let provider = DummyProvider;
let request = SpeechRequest {
text: "hello".to_string(),
voice: VoiceIdentifier::new("test"),
format: AudioFormat::Wav,
sample_rate: None,
};
let err = match provider.generate_speech_stream(request).await {
Ok(_) => panic!("expected streaming not supported"),
Err(err) => err,
};
assert!(matches!(
err,
crate::error::TTSError::StreamingNotSupported(_)
));
assert!(!provider.supports_streaming());
}
#[test]
fn test_default_provider_formats_and_languages() {
let provider = DummyProvider;
assert_eq!(provider.supported_formats(), vec![AudioFormat::Wav]);
assert_eq!(provider.default_sample_rate(), 24000);
assert_eq!(provider.supported_languages(), vec!["en".to_string()]);
}
#[derive(Debug)]
struct DummySTTProvider;
#[async_trait]
impl STTSpeechProvider for DummySTTProvider {
async fn transcribe(
&self,
request: TranscriptionRequest,
) -> STTResult<TranscriptionResponse> {
Ok(TranscriptionResponse {
text: format!("Transcribed {} samples", request.audio.samples.len()),
timestamps: None,
duration_ms: 0,
})
}
}
#[async_trait]
impl STTModelsProvider for DummySTTProvider {
fn get_current_model(&self) -> ModelInfo {
ModelInfo {
id: "dummy".to_string(),
name: "Dummy STT".to_string(),
description: None,
languages: vec!["en".to_string()],
}
}
}
impl STTProvider for DummySTTProvider {}
#[tokio::test]
async fn test_stt_default_streaming_not_supported() {
let provider = DummySTTProvider;
let request = TranscriptionRequest {
audio: SharedAudioData::new(AudioData {
samples: vec![0.0; 16000],
sample_rate: 16000,
channels: 1,
}),
language: None,
include_timestamps: false,
};
let err = match provider.transcribe_stream(request).await {
Ok(_) => panic!("expected streaming not supported"),
Err(err) => err,
};
assert!(matches!(
err,
crate::error::STTError::StreamingNotSupported(_)
));
assert!(!provider.supports_streaming());
}
#[test]
fn test_stt_default_provider_settings() {
let provider = DummySTTProvider;
assert_eq!(provider.supported_sample_rate(), 16000);
assert_eq!(provider.supported_channels(), 1);
assert_eq!(provider.supported_languages(), vec!["en".to_string()]);
assert!(!provider.supports_timestamps());
}
}