Skip to main content

clawft_plugin/voice/
cloud_tts.rs

//! Cloud-based text-to-speech providers.
//!
//! Defines the [`CloudTtsProvider`] trait and provides implementations
//! for OpenAI TTS ([`OpenAiTtsProvider`]) and ElevenLabs
//! ([`ElevenLabsTtsProvider`]).
6
7use async_trait::async_trait;
8
9use crate::PluginError;
10
/// Result from a cloud TTS synthesis.
///
/// Holds the encoded audio exactly as the provider returned it, together with
/// the MIME type describing that encoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloudTtsResult {
    /// Raw (encoded) audio bytes as returned by the provider.
    pub audio_data: Vec<u8>,
    /// MIME type of the audio as reported by the provider, e.g. `"audio/mpeg"`
    /// (the IANA-registered type for MP3, RFC 3003).
    pub mime_type: String,
    /// Duration of the synthesized audio in milliseconds, if known.
    pub duration_ms: Option<u64>,
}
21
22/// Information about an available voice.
23#[derive(Debug, Clone, serde::Serialize)]
24pub struct VoiceInfo {
25    /// Provider-specific voice ID.
26    pub id: String,
27    /// Human-readable voice name.
28    pub name: String,
29    /// Primary language code (BCP-47).
30    pub language: String,
31}
32
33/// Trait for cloud-based text-to-speech providers.
34#[async_trait]
35pub trait CloudTtsProvider: Send + Sync {
36    /// Provider name (e.g., "openai-tts", "elevenlabs").
37    fn name(&self) -> &str;
38
39    /// List available voices for this provider.
40    fn available_voices(&self) -> Vec<VoiceInfo>;
41
42    /// Synthesize text to audio bytes.
43    ///
44    /// * `text` - The text to synthesize.
45    /// * `voice_id` - Provider-specific voice identifier.
46    async fn synthesize(
47        &self,
48        text: &str,
49        voice_id: &str,
50    ) -> Result<CloudTtsResult, PluginError>;
51}
52
// ---------------------------------------------------------------------------
// OpenAI TTS
// ---------------------------------------------------------------------------
56
57/// OpenAI TTS API implementation.
58///
59/// Posts to `https://api.openai.com/v1/audio/speech` with model "tts-1".
60/// Available voices: alloy, echo, fable, onyx, nova, shimmer.
61pub struct OpenAiTtsProvider {
62    api_key: String,
63    model: String,
64    client: reqwest::Client,
65}
66
67impl OpenAiTtsProvider {
68    /// Create a new OpenAI TTS provider with the given API key.
69    pub fn new(api_key: String) -> Self {
70        Self {
71            api_key,
72            model: "tts-1".to_string(),
73            client: reqwest::Client::new(),
74        }
75    }
76
77    /// Override the TTS model (default: "tts-1").
78    pub fn with_model(mut self, model: impl Into<String>) -> Self {
79        self.model = model.into();
80        self
81    }
82}
83
84#[async_trait]
85impl CloudTtsProvider for OpenAiTtsProvider {
86    fn name(&self) -> &str {
87        "openai-tts"
88    }
89
90    fn available_voices(&self) -> Vec<VoiceInfo> {
91        ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
92            .iter()
93            .map(|v| VoiceInfo {
94                id: v.to_string(),
95                name: v.to_string(),
96                language: "en".to_string(),
97            })
98            .collect()
99    }
100
101    async fn synthesize(
102        &self,
103        text: &str,
104        voice_id: &str,
105    ) -> Result<CloudTtsResult, PluginError> {
106        let body = serde_json::json!({
107            "model": self.model,
108            "input": text,
109            "voice": voice_id,
110            "response_format": "mp3",
111        });
112
113        let resp = self
114            .client
115            .post("https://api.openai.com/v1/audio/speech")
116            .bearer_auth(&self.api_key)
117            .json(&body)
118            .send()
119            .await
120            .map_err(|e| {
121                PluginError::ExecutionFailed(format!("OpenAI TTS request failed: {e}"))
122            })?;
123
124        if !resp.status().is_success() {
125            let status = resp.status();
126            let err_body = resp.text().await.unwrap_or_default();
127            return Err(PluginError::ExecutionFailed(format!(
128                "OpenAI TTS returned {status}: {err_body}"
129            )));
130        }
131
132        let audio_data = resp
133            .bytes()
134            .await
135            .map_err(|e| PluginError::ExecutionFailed(format!("TTS response read error: {e}")))?
136            .to_vec();
137
138        Ok(CloudTtsResult {
139            audio_data,
140            mime_type: "audio/mp3".to_string(),
141            duration_ms: None,
142        })
143    }
144}
145
// ---------------------------------------------------------------------------
// ElevenLabs TTS
// ---------------------------------------------------------------------------
149
150/// ElevenLabs TTS API implementation.
151///
152/// Posts to `https://api.elevenlabs.io/v1/text-to-speech/{voice_id}`
153/// with `xi-api-key` header authentication.
154pub struct ElevenLabsTtsProvider {
155    api_key: String,
156    client: reqwest::Client,
157}
158
159impl ElevenLabsTtsProvider {
160    /// Create a new ElevenLabs TTS provider with the given API key.
161    pub fn new(api_key: String) -> Self {
162        Self {
163            api_key,
164            client: reqwest::Client::new(),
165        }
166    }
167}
168
169#[async_trait]
170impl CloudTtsProvider for ElevenLabsTtsProvider {
171    fn name(&self) -> &str {
172        "elevenlabs"
173    }
174
175    fn available_voices(&self) -> Vec<VoiceInfo> {
176        vec![
177            VoiceInfo {
178                id: "21m00Tcm4TlvDq8ikWAM".into(),
179                name: "Rachel".into(),
180                language: "en".into(),
181            },
182            VoiceInfo {
183                id: "AZnzlk1XvdvUeBnXmlld".into(),
184                name: "Domi".into(),
185                language: "en".into(),
186            },
187            VoiceInfo {
188                id: "EXAVITQu4vr4xnSDxMaL".into(),
189                name: "Bella".into(),
190                language: "en".into(),
191            },
192            VoiceInfo {
193                id: "ErXwobaYiN019PkySvjV".into(),
194                name: "Antoni".into(),
195                language: "en".into(),
196            },
197        ]
198    }
199
200    async fn synthesize(
201        &self,
202        text: &str,
203        voice_id: &str,
204    ) -> Result<CloudTtsResult, PluginError> {
205        let url = format!("https://api.elevenlabs.io/v1/text-to-speech/{voice_id}");
206        let body = serde_json::json!({
207            "text": text,
208            "model_id": "eleven_monolingual_v1",
209            "voice_settings": {
210                "stability": 0.5,
211                "similarity_boost": 0.75,
212            },
213        });
214
215        let resp = self
216            .client
217            .post(&url)
218            .header("xi-api-key", &self.api_key)
219            .header("Content-Type", "application/json")
220            .header("Accept", "audio/mpeg")
221            .json(&body)
222            .send()
223            .await
224            .map_err(|e| {
225                PluginError::ExecutionFailed(format!("ElevenLabs request failed: {e}"))
226            })?;
227
228        if !resp.status().is_success() {
229            let status = resp.status();
230            let err_body = resp.text().await.unwrap_or_default();
231            return Err(PluginError::ExecutionFailed(format!(
232                "ElevenLabs returned {status}: {err_body}"
233            )));
234        }
235
236        let audio_data = resp
237            .bytes()
238            .await
239            .map_err(|e| {
240                PluginError::ExecutionFailed(format!("ElevenLabs read error: {e}"))
241            })?
242            .to_vec();
243
244        Ok(CloudTtsResult {
245            audio_data,
246            mime_type: "audio/mpeg".to_string(),
247            duration_ms: None,
248        })
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    // ===== OpenAI =====

    #[test]
    fn openai_tts_provider_name() {
        let openai = OpenAiTtsProvider::new("test-key".into());
        assert_eq!(openai.name(), "openai-tts");
    }

    #[test]
    fn openai_tts_available_voices() {
        let openai = OpenAiTtsProvider::new("test-key".into());
        let voices = openai.available_voices();
        assert_eq!(voices.len(), 6);
        for wanted in ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] {
            assert!(voices.iter().any(|v| v.id == wanted));
        }
    }

    #[test]
    fn openai_tts_with_model_builder() {
        let openai = OpenAiTtsProvider::new("test-key".into()).with_model("tts-1-hd");
        assert_eq!(openai.model, "tts-1-hd");
    }

    /// Talks to the real endpoint: a bogus key (or no network at all) must
    /// surface as `Err`.
    #[tokio::test]
    async fn openai_tts_synthesize_invalid_key_errors() {
        let openai = OpenAiTtsProvider::new("invalid-key".into());
        let outcome = openai.synthesize("hello", "alloy").await;
        assert!(outcome.is_err());
    }

    // ===== ElevenLabs =====

    #[test]
    fn elevenlabs_provider_name() {
        let eleven = ElevenLabsTtsProvider::new("test-key".into());
        assert_eq!(eleven.name(), "elevenlabs");
    }

    #[test]
    fn elevenlabs_available_voices() {
        let eleven = ElevenLabsTtsProvider::new("test-key".into());
        let voices = eleven.available_voices();
        assert_eq!(voices.len(), 4);
        for wanted in ["Rachel", "Domi", "Bella", "Antoni"] {
            assert!(voices.iter().any(|v| v.name == wanted));
        }
    }

    #[test]
    fn cloud_tts_result_fields() {
        let result = CloudTtsResult {
            audio_data: vec![1, 2, 3],
            mime_type: "audio/mp3".into(),
            duration_ms: Some(1500),
        };
        assert_eq!(result.audio_data, vec![1, 2, 3]);
        assert_eq!(result.mime_type, "audio/mp3");
        assert_eq!(result.duration_ms, Some(1500));
    }

    /// Talks to the real endpoint: a bogus key (or no network at all) must
    /// surface as `Err`.
    #[tokio::test]
    async fn elevenlabs_synthesize_invalid_key_errors() {
        let eleven = ElevenLabsTtsProvider::new("invalid-key".into());
        let outcome = eleven.synthesize("hello", "21m00Tcm4TlvDq8ikWAM").await;
        assert!(outcome.is_err());
    }

    // ===== Shared types =====

    #[test]
    fn voice_info_serializable() {
        let info = VoiceInfo {
            id: "alloy".into(),
            name: "Alloy".into(),
            language: "en".into(),
        };
        let json = serde_json::to_string(&info).unwrap();
        for fragment in [
            "\"id\":\"alloy\"",
            "\"name\":\"Alloy\"",
            "\"language\":\"en\"",
        ] {
            assert!(json.contains(fragment));
        }
    }
}