use async_trait::async_trait;
use crate::PluginError;
/// Audio returned by a cloud text-to-speech provider.
///
/// Derives `PartialEq`/`Eq` so whole results can be compared directly
/// (for example with `assert_eq!`) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloudTtsResult {
    /// Encoded audio bytes exactly as received from the provider.
    pub audio_data: Vec<u8>,
    /// MIME type describing the encoding of `audio_data`.
    pub mime_type: String,
    /// Clip length in milliseconds (per the field name), or `None` when the
    /// provider does not report it.
    pub duration_ms: Option<u64>,
}
/// Description of a single voice offered by a text-to-speech provider.
///
/// Serializes with the field names `id`, `name` and `language`. Derives
/// `PartialEq`/`Eq` so voice lists can be compared or de-duplicated.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
pub struct VoiceInfo {
    /// Identifier passed to the provider as `voice_id` when synthesizing.
    pub id: String,
    /// Human-readable voice name.
    pub name: String,
    /// Language tag of the voice; the providers in this file use `"en"`.
    pub language: String,
}
/// Common interface for cloud text-to-speech backends.
///
/// Implementors must be `Send + Sync`. The trait goes through `#[async_trait]`
/// so that `synthesize` can be declared as an `async fn`.
#[async_trait]
pub trait CloudTtsProvider: Send + Sync {
/// Returns the provider's name.
fn name(&self) -> &str;
/// Returns the voices this provider offers.
fn available_voices(&self) -> Vec<VoiceInfo>;
/// Converts `text` to audio using the voice identified by `voice_id`.
///
/// # Errors
///
/// Returns a [`PluginError`] when synthesis fails.
async fn synthesize(
&self,
text: &str,
voice_id: &str,
) -> Result<CloudTtsResult, PluginError>;
}
/// Text-to-speech provider that calls the OpenAI speech endpoint.
pub struct OpenAiTtsProvider {
/// API key, sent as a bearer token with each synthesis request.
api_key: String,
/// Model name placed in the request's `model` field; `new` sets it to `tts-1`.
model: String,
/// HTTP client used to send synthesis requests.
client: reqwest::Client,
}
impl OpenAiTtsProvider {
    /// Creates a provider that authenticates with `api_key`.
    ///
    /// The model starts out as `tts-1` and a new `reqwest::Client` is created
    /// for this provider.
    pub fn new(api_key: String) -> Self {
        let model = String::from("tts-1");
        let client = reqwest::Client::new();
        Self {
            api_key,
            model,
            client,
        }
    }

    /// Consumes the provider and returns it with the model replaced by `model`.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..self
        }
    }
}
#[async_trait]
impl CloudTtsProvider for OpenAiTtsProvider {
    /// Returns the fixed provider identifier `"openai-tts"`.
    fn name(&self) -> &str {
        "openai-tts"
    }

    /// Lists the six hard-coded voices.
    ///
    /// Each entry uses the voice identifier as both `id` and `name` and is
    /// tagged with language `"en"`.
    fn available_voices(&self) -> Vec<VoiceInfo> {
        ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
            .iter()
            // Destructure to a plain `&str` so the conversion is `str -> String`
            // rather than going through `Display` on a `&&str`.
            .map(|&voice| VoiceInfo {
                id: voice.to_owned(),
                name: voice.to_owned(),
                language: "en".to_owned(),
            })
            .collect()
    }

    /// Requests MP3 audio for `text` spoken by `voice_id`.
    ///
    /// Sends a JSON body containing the configured model, the input text, the
    /// voice and `response_format: "mp3"` to the OpenAI speech endpoint,
    /// authenticated with the API key as a bearer token.
    ///
    /// # Errors
    ///
    /// Returns `PluginError::ExecutionFailed` when the request cannot be sent,
    /// when the response status is not a success (the message includes the
    /// status and the response body), or when the audio body cannot be read.
    async fn synthesize(
        &self,
        text: &str,
        voice_id: &str,
    ) -> Result<CloudTtsResult, PluginError> {
        let body = serde_json::json!({
            "model": self.model,
            "input": text,
            "voice": voice_id,
            "response_format": "mp3",
        });

        let resp = self
            .client
            .post("https://api.openai.com/v1/audio/speech")
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await
            .map_err(|e| {
                PluginError::ExecutionFailed(format!("OpenAI TTS request failed: {e}"))
            })?;

        let status = resp.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is reported as an empty string.
            let err_body = resp.text().await.unwrap_or_default();
            return Err(PluginError::ExecutionFailed(format!(
                "OpenAI TTS returned {status}: {err_body}"
            )));
        }

        let audio_data = resp
            .bytes()
            .await
            .map_err(|e| PluginError::ExecutionFailed(format!("TTS response read error: {e}")))?
            .to_vec();

        Ok(CloudTtsResult {
            audio_data,
            // `audio/mpeg` is the registered media type for MP3 and matches what
            // the ElevenLabs provider reports for the same encoding.
            mime_type: "audio/mpeg".to_string(),
            duration_ms: None,
        })
    }
}
/// Text-to-speech provider that calls the ElevenLabs API.
pub struct ElevenLabsTtsProvider {
/// API key, sent in the `xi-api-key` request header.
api_key: String,
/// HTTP client used to send synthesis requests.
client: reqwest::Client,
}
impl ElevenLabsTtsProvider {
    /// Creates a provider that authenticates with `api_key`.
    ///
    /// A new `reqwest::Client` is created for this provider.
    pub fn new(api_key: String) -> Self {
        let client = reqwest::Client::new();
        Self { api_key, client }
    }
}
#[async_trait]
impl CloudTtsProvider for ElevenLabsTtsProvider {
fn name(&self) -> &str {
"elevenlabs"
}
fn available_voices(&self) -> Vec<VoiceInfo> {
vec![
VoiceInfo {
id: "21m00Tcm4TlvDq8ikWAM".into(),
name: "Rachel".into(),
language: "en".into(),
},
VoiceInfo {
id: "AZnzlk1XvdvUeBnXmlld".into(),
name: "Domi".into(),
language: "en".into(),
},
VoiceInfo {
id: "EXAVITQu4vr4xnSDxMaL".into(),
name: "Bella".into(),
language: "en".into(),
},
VoiceInfo {
id: "ErXwobaYiN019PkySvjV".into(),
name: "Antoni".into(),
language: "en".into(),
},
]
}
async fn synthesize(
&self,
text: &str,
voice_id: &str,
) -> Result<CloudTtsResult, PluginError> {
let url = format!("https://api.elevenlabs.io/v1/text-to-speech/{voice_id}");
let body = serde_json::json!({
"text": text,
"model_id": "eleven_monolingual_v1",
"voice_settings": {
"stability": 0.5,
"similarity_boost": 0.75,
},
});
let resp = self
.client
.post(&url)
.header("xi-api-key", &self.api_key)
.header("Content-Type", "application/json")
.header("Accept", "audio/mpeg")
.json(&body)
.send()
.await
.map_err(|e| {
PluginError::ExecutionFailed(format!("ElevenLabs request failed: {e}"))
})?;
if !resp.status().is_success() {
let status = resp.status();
let err_body = resp.text().await.unwrap_or_default();
return Err(PluginError::ExecutionFailed(format!(
"ElevenLabs returned {status}: {err_body}"
)));
}
let audio_data = resp
.bytes()
.await
.map_err(|e| {
PluginError::ExecutionFailed(format!("ElevenLabs read error: {e}"))
})?
.to_vec();
Ok(CloudTtsResult {
audio_data,
mime_type: "audio/mpeg".to_string(),
duration_ms: None,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an OpenAI provider with the given placeholder key.
    fn openai(key: &str) -> OpenAiTtsProvider {
        OpenAiTtsProvider::new(key.into())
    }

    /// Builds an ElevenLabs provider with the given placeholder key.
    fn elevenlabs(key: &str) -> ElevenLabsTtsProvider {
        ElevenLabsTtsProvider::new(key.into())
    }

    #[test]
    fn openai_tts_provider_name() {
        assert_eq!(openai("test-key").name(), "openai-tts");
    }

    #[test]
    fn openai_tts_available_voices() {
        let voices = openai("test-key").available_voices();
        assert_eq!(voices.len(), 6);
        for expected in ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] {
            assert!(voices.iter().any(|voice| voice.id == expected));
        }
    }

    #[test]
    fn openai_tts_with_model_builder() {
        let configured = openai("test-key").with_model("tts-1-hd");
        assert_eq!(configured.model, "tts-1-hd");
    }

    // Drives the real `synthesize` path, so this sends an outbound HTTP
    // request; a transport failure or a non-success status both satisfy it.
    #[tokio::test]
    async fn openai_tts_synthesize_invalid_key_errors() {
        let provider = openai("invalid-key");
        let outcome = provider.synthesize("hello", "alloy").await;
        assert!(outcome.is_err());
    }

    #[test]
    fn elevenlabs_provider_name() {
        assert_eq!(elevenlabs("test-key").name(), "elevenlabs");
    }

    #[test]
    fn elevenlabs_available_voices() {
        let voices = elevenlabs("test-key").available_voices();
        assert_eq!(voices.len(), 4);
        for expected in ["Rachel", "Domi", "Bella", "Antoni"] {
            assert!(voices.iter().any(|voice| voice.name == expected));
        }
    }

    #[test]
    fn cloud_tts_result_fields() {
        let clip = CloudTtsResult {
            audio_data: vec![1, 2, 3],
            mime_type: "audio/mp3".into(),
            duration_ms: Some(1500),
        };
        assert_eq!(clip.audio_data, vec![1, 2, 3]);
        assert_eq!(clip.mime_type, "audio/mp3");
        assert_eq!(clip.duration_ms, Some(1500));
    }

    // Drives the real `synthesize` path, so this sends an outbound HTTP
    // request; a transport failure or a non-success status both satisfy it.
    #[tokio::test]
    async fn elevenlabs_synthesize_invalid_key_errors() {
        let provider = elevenlabs("invalid-key");
        let outcome = provider.synthesize("hello", "21m00Tcm4TlvDq8ikWAM").await;
        assert!(outcome.is_err());
    }

    #[test]
    fn voice_info_serializable() {
        let voice = VoiceInfo {
            id: "alloy".into(),
            name: "Alloy".into(),
            language: "en".into(),
        };
        let encoded = serde_json::to_string(&voice).unwrap();
        for fragment in [
            "\"id\":\"alloy\"",
            "\"name\":\"Alloy\"",
            "\"language\":\"en\"",
        ] {
            assert!(encoded.contains(fragment));
        }
    }
}