1use async_trait::async_trait;
8
9use crate::PluginError;
10
/// Audio returned by a cloud text-to-speech request.
///
/// Derives `PartialEq`/`Eq` so whole results can be compared by value
/// (e.g. with `assert_eq!`) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloudTtsResult {
    /// Encoded audio bytes as read from the provider's response body.
    pub audio_data: Vec<u8>,
    /// MIME type describing `audio_data` (for example `audio/mpeg`).
    pub mime_type: String,
    /// Playback length in milliseconds, or `None` when it was not reported.
    pub duration_ms: Option<u64>,
}
21
22#[derive(Debug, Clone, serde::Serialize)]
24pub struct VoiceInfo {
25 pub id: String,
27 pub name: String,
29 pub language: String,
31}
32
33#[async_trait]
35pub trait CloudTtsProvider: Send + Sync {
36 fn name(&self) -> &str;
38
39 fn available_voices(&self) -> Vec<VoiceInfo>;
41
42 async fn synthesize(
47 &self,
48 text: &str,
49 voice_id: &str,
50 ) -> Result<CloudTtsResult, PluginError>;
51}
52
53pub struct OpenAiTtsProvider {
62 api_key: String,
63 model: String,
64 client: reqwest::Client,
65}
66
67impl OpenAiTtsProvider {
68 pub fn new(api_key: String) -> Self {
70 Self {
71 api_key,
72 model: "tts-1".to_string(),
73 client: reqwest::Client::new(),
74 }
75 }
76
77 pub fn with_model(mut self, model: impl Into<String>) -> Self {
79 self.model = model.into();
80 self
81 }
82}
83
84#[async_trait]
85impl CloudTtsProvider for OpenAiTtsProvider {
86 fn name(&self) -> &str {
87 "openai-tts"
88 }
89
90 fn available_voices(&self) -> Vec<VoiceInfo> {
91 ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
92 .iter()
93 .map(|v| VoiceInfo {
94 id: v.to_string(),
95 name: v.to_string(),
96 language: "en".to_string(),
97 })
98 .collect()
99 }
100
101 async fn synthesize(
102 &self,
103 text: &str,
104 voice_id: &str,
105 ) -> Result<CloudTtsResult, PluginError> {
106 let body = serde_json::json!({
107 "model": self.model,
108 "input": text,
109 "voice": voice_id,
110 "response_format": "mp3",
111 });
112
113 let resp = self
114 .client
115 .post("https://api.openai.com/v1/audio/speech")
116 .bearer_auth(&self.api_key)
117 .json(&body)
118 .send()
119 .await
120 .map_err(|e| {
121 PluginError::ExecutionFailed(format!("OpenAI TTS request failed: {e}"))
122 })?;
123
124 if !resp.status().is_success() {
125 let status = resp.status();
126 let err_body = resp.text().await.unwrap_or_default();
127 return Err(PluginError::ExecutionFailed(format!(
128 "OpenAI TTS returned {status}: {err_body}"
129 )));
130 }
131
132 let audio_data = resp
133 .bytes()
134 .await
135 .map_err(|e| PluginError::ExecutionFailed(format!("TTS response read error: {e}")))?
136 .to_vec();
137
138 Ok(CloudTtsResult {
139 audio_data,
140 mime_type: "audio/mp3".to_string(),
141 duration_ms: None,
142 })
143 }
144}
145
146pub struct ElevenLabsTtsProvider {
155 api_key: String,
156 client: reqwest::Client,
157}
158
159impl ElevenLabsTtsProvider {
160 pub fn new(api_key: String) -> Self {
162 Self {
163 api_key,
164 client: reqwest::Client::new(),
165 }
166 }
167}
168
169#[async_trait]
170impl CloudTtsProvider for ElevenLabsTtsProvider {
171 fn name(&self) -> &str {
172 "elevenlabs"
173 }
174
175 fn available_voices(&self) -> Vec<VoiceInfo> {
176 vec![
177 VoiceInfo {
178 id: "21m00Tcm4TlvDq8ikWAM".into(),
179 name: "Rachel".into(),
180 language: "en".into(),
181 },
182 VoiceInfo {
183 id: "AZnzlk1XvdvUeBnXmlld".into(),
184 name: "Domi".into(),
185 language: "en".into(),
186 },
187 VoiceInfo {
188 id: "EXAVITQu4vr4xnSDxMaL".into(),
189 name: "Bella".into(),
190 language: "en".into(),
191 },
192 VoiceInfo {
193 id: "ErXwobaYiN019PkySvjV".into(),
194 name: "Antoni".into(),
195 language: "en".into(),
196 },
197 ]
198 }
199
200 async fn synthesize(
201 &self,
202 text: &str,
203 voice_id: &str,
204 ) -> Result<CloudTtsResult, PluginError> {
205 let url = format!("https://api.elevenlabs.io/v1/text-to-speech/{voice_id}");
206 let body = serde_json::json!({
207 "text": text,
208 "model_id": "eleven_monolingual_v1",
209 "voice_settings": {
210 "stability": 0.5,
211 "similarity_boost": 0.75,
212 },
213 });
214
215 let resp = self
216 .client
217 .post(&url)
218 .header("xi-api-key", &self.api_key)
219 .header("Content-Type", "application/json")
220 .header("Accept", "audio/mpeg")
221 .json(&body)
222 .send()
223 .await
224 .map_err(|e| {
225 PluginError::ExecutionFailed(format!("ElevenLabs request failed: {e}"))
226 })?;
227
228 if !resp.status().is_success() {
229 let status = resp.status();
230 let err_body = resp.text().await.unwrap_or_default();
231 return Err(PluginError::ExecutionFailed(format!(
232 "ElevenLabs returned {status}: {err_body}"
233 )));
234 }
235
236 let audio_data = resp
237 .bytes()
238 .await
239 .map_err(|e| {
240 PluginError::ExecutionFailed(format!("ElevenLabs read error: {e}"))
241 })?
242 .to_vec();
243
244 Ok(CloudTtsResult {
245 audio_data,
246 mime_type: "audio/mpeg".to_string(),
247 duration_ms: None,
248 })
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder credential for tests that never reach the network.
    const DUMMY_KEY: &str = "test-key";

    /// The OpenAI provider reports its fixed identifier.
    #[test]
    fn openai_tts_provider_name() {
        let name_owner = OpenAiTtsProvider::new(DUMMY_KEY.into());
        assert_eq!(name_owner.name(), "openai-tts");
    }

    /// The OpenAI provider advertises exactly the six built-in voices.
    #[test]
    fn openai_tts_available_voices() {
        let voices = OpenAiTtsProvider::new(DUMMY_KEY.into()).available_voices();
        assert_eq!(voices.len(), 6);
        for expected in ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] {
            assert!(
                voices.iter().any(|v| v.id == expected),
                "missing voice id {expected}"
            );
        }
    }

    /// `with_model` replaces the default model name.
    #[test]
    fn openai_tts_with_model_builder() {
        let customised = OpenAiTtsProvider::new(DUMMY_KEY.into()).with_model("tts-1-hd");
        assert_eq!(customised.model, "tts-1-hd");
    }

    /// Issues a real HTTP request to the OpenAI endpoint with a bogus key
    /// and only checks that the call reports an error.
    #[tokio::test]
    async fn openai_tts_synthesize_invalid_key_errors() {
        let outcome = OpenAiTtsProvider::new("invalid-key".into())
            .synthesize("hello", "alloy")
            .await;
        assert!(outcome.is_err());
    }

    /// The ElevenLabs provider reports its fixed identifier.
    #[test]
    fn elevenlabs_provider_name() {
        let name_owner = ElevenLabsTtsProvider::new(DUMMY_KEY.into());
        assert_eq!(name_owner.name(), "elevenlabs");
    }

    /// The ElevenLabs provider advertises exactly the four hard-coded voices.
    #[test]
    fn elevenlabs_available_voices() {
        let voices = ElevenLabsTtsProvider::new(DUMMY_KEY.into()).available_voices();
        assert_eq!(voices.len(), 4);
        for expected in ["Rachel", "Domi", "Bella", "Antoni"] {
            assert!(
                voices.iter().any(|v| v.name == expected),
                "missing voice name {expected}"
            );
        }
    }

    /// Fields of `CloudTtsResult` hold exactly what they were built with.
    #[test]
    fn cloud_tts_result_fields() {
        let CloudTtsResult {
            audio_data,
            mime_type,
            duration_ms,
        } = CloudTtsResult {
            audio_data: vec![1, 2, 3],
            mime_type: "audio/mp3".into(),
            duration_ms: Some(1500),
        };
        assert_eq!(audio_data, vec![1, 2, 3]);
        assert_eq!(mime_type, "audio/mp3");
        assert_eq!(duration_ms, Some(1500));
    }

    /// Issues a real HTTP request to the ElevenLabs endpoint with a bogus
    /// key and only checks that the call reports an error.
    #[tokio::test]
    async fn elevenlabs_synthesize_invalid_key_errors() {
        let outcome = ElevenLabsTtsProvider::new("invalid-key".into())
            .synthesize("hello", "21m00Tcm4TlvDq8ikWAM")
            .await;
        assert!(outcome.is_err());
    }

    /// `VoiceInfo` serializes each field under its own JSON key.
    #[test]
    fn voice_info_serializable() {
        let info = VoiceInfo {
            id: "alloy".into(),
            name: "Alloy".into(),
            language: "en".into(),
        };
        let json = serde_json::to_string(&info).unwrap();
        for fragment in [
            "\"id\":\"alloy\"",
            "\"name\":\"Alloy\"",
            "\"language\":\"en\"",
        ] {
            assert!(json.contains(fragment), "missing fragment {fragment}");
        }
    }
}