Skip to main content

adk_audio/providers/tts/
gemini.rs

1//! Gemini native audio TTS provider.
2//!
3//! Supports all Gemini TTS models:
4//! - `gemini-3.1-flash-tts-preview` — expressive, audio tags, multi-speaker (default)
5//! - `gemini-2.5-flash-preview-tts` — fast, multi-speaker
6//! - `gemini-2.5-pro-preview-tts` — high-fidelity, multi-speaker
7
8use std::pin::Pin;
9
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::Stream;
13
14use crate::error::{AudioError, AudioResult};
15use crate::frame::AudioFrame;
16use crate::providers::tts::CloudTtsConfig;
17use crate::traits::{TtsProvider, TtsRequest, Voice};
18
/// Identifiers of the Gemini TTS models this provider knows about.
///
/// Each constant is the model ID string that ends up in the request URL.
// NOTE(review): the reason for `allow(dead_code)` is not visible here — presumably some
// constants are unused inside this crate; confirm before removing.
#[allow(dead_code)]
pub mod models {
    /// Gemini 2.5 Flash TTS — fast, multi-speaker.
    pub const GEMINI_2_5_FLASH_TTS: &str = "gemini-2.5-flash-preview-tts";
    /// Gemini 2.5 Pro TTS — high-fidelity, multi-speaker.
    pub const GEMINI_2_5_PRO_TTS: &str = "gemini-2.5-pro-preview-tts";
    /// Gemini 3.1 Flash TTS — expressive audio tags, multi-speaker, low-latency.
    pub const GEMINI_3_1_FLASH_TTS: &str = "gemini-3.1-flash-tts-preview";
}
29
/// One speaker entry for multi-speaker synthesis: a transcript name paired with a voice.
#[derive(Debug, Clone)]
pub struct SpeakerConfig {
    /// Name of the speaker; it has to match the name used in the transcript text.
    pub name: String,
    /// Name of the prebuilt voice (one of the 30 available voices) assigned to this speaker.
    pub voice: String,
}

impl SpeakerConfig {
    /// Build a speaker entry from anything convertible into `String`.
    ///
    /// `name` is the speaker's name in the transcript; `voice` is the voice to use for it.
    pub fn new(name: impl Into<String>, voice: impl Into<String>) -> Self {
        let name: String = name.into();
        let voice: String = voice.into();
        Self { name, voice }
    }
}
45
46/// Gemini TTS provider using `generateContent` with audio response modality.
47///
48/// # Example
49///
50/// ```rust,ignore
51/// use adk_audio::GeminiTts;
52///
53/// // Default: gemini-3.1-flash-tts-preview
54/// let tts = GeminiTts::from_env()?;
55///
56/// // Specific model
57/// let tts = GeminiTts::from_env()?.with_model("gemini-2.5-pro-preview-tts");
58///
59/// // Multi-speaker
60/// let tts = GeminiTts::from_env()?.with_speakers(vec![
61///     SpeakerConfig::new("Alice", "Kore"),
62///     SpeakerConfig::new("Bob", "Puck"),
63/// ]);
64/// ```
65pub struct GeminiTts {
66    config: CloudTtsConfig,
67    client: reqwest::Client,
68    model: String,
69    voices: Vec<Voice>,
70    speakers: Option<Vec<SpeakerConfig>>,
71}
72
73impl GeminiTts {
74    /// Create from environment variable `GEMINI_API_KEY`.
75    pub fn from_env() -> AudioResult<Self> {
76        let api_key = std::env::var("GEMINI_API_KEY")
77            .or_else(|_| std::env::var("GOOGLE_API_KEY"))
78            .map_err(|_| AudioError::Tts {
79                provider: "gemini".into(),
80                message: "GEMINI_API_KEY or GOOGLE_API_KEY not set".into(),
81            })?;
82        Ok(Self::new(CloudTtsConfig::new(api_key)))
83    }
84
85    /// Create with explicit config.
86    pub fn new(config: CloudTtsConfig) -> Self {
87        Self {
88            config,
89            client: reqwest::Client::new(),
90            model: models::GEMINI_3_1_FLASH_TTS.into(),
91            voices: build_voice_catalog(),
92            speakers: None,
93        }
94    }
95
96    /// Set the TTS model.
97    pub fn with_model(mut self, model: impl Into<String>) -> Self {
98        self.model = model.into();
99        self
100    }
101
102    /// Configure multi-speaker synthesis.
103    ///
104    /// Speaker names must match the names used in the transcript text.
105    /// Up to 2 speakers are supported.
106    pub fn with_speakers(mut self, speakers: Vec<SpeakerConfig>) -> Self {
107        self.speakers = Some(speakers);
108        self
109    }
110
111    fn base_url(&self) -> String {
112        self.config.base_url.clone().unwrap_or_else(|| {
113            format!(
114                "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
115                self.model
116            )
117        })
118    }
119
120    fn build_speech_config(&self, voice: &str) -> serde_json::Value {
121        match &self.speakers {
122            Some(speakers) if !speakers.is_empty() => {
123                let speaker_configs: Vec<serde_json::Value> = speakers
124                    .iter()
125                    .map(|s| {
126                        serde_json::json!({
127                            "speaker": s.name,
128                            "voiceConfig": {
129                                "prebuiltVoiceConfig": {
130                                    "voiceName": s.voice
131                                }
132                            }
133                        })
134                    })
135                    .collect();
136                serde_json::json!({
137                    "multiSpeakerVoiceConfig": {
138                        "speakerVoiceConfigs": speaker_configs
139                    }
140                })
141            }
142            _ => {
143                let voice_name = if voice.is_empty() { "Kore" } else { voice };
144                serde_json::json!({
145                    "voiceConfig": {
146                        "prebuiltVoiceConfig": {
147                            "voiceName": voice_name
148                        }
149                    }
150                })
151            }
152        }
153    }
154}
155
156#[async_trait]
157impl TtsProvider for GeminiTts {
158    async fn synthesize(&self, request: &TtsRequest) -> AudioResult<AudioFrame> {
159        let url = self.base_url();
160        let speech_config = self.build_speech_config(&request.voice);
161
162        let body = serde_json::json!({
163            "contents": [{"parts": [{"text": request.text}]}],
164            "generationConfig": {
165                "response_modalities": ["AUDIO"],
166                "speech_config": speech_config
167            }
168        });
169
170        let resp = self
171            .client
172            .post(&url)
173            .header("x-goog-api-key", &self.config.api_key)
174            .json(&body)
175            .send()
176            .await
177            .map_err(|e| AudioError::Tts { provider: "gemini".into(), message: e.to_string() })?;
178
179        if !resp.status().is_success() {
180            let status = resp.status();
181            let body = resp.text().await.unwrap_or_default();
182            return Err(AudioError::Tts {
183                provider: "gemini".into(),
184                message: format!("HTTP {status}: {body}"),
185            });
186        }
187
188        let json: serde_json::Value = resp
189            .json()
190            .await
191            .map_err(|e| AudioError::Tts { provider: "gemini".into(), message: e.to_string() })?;
192
193        let audio_b64 = json["candidates"][0]["content"]["parts"][0]["inlineData"]["data"]
194            .as_str()
195            .ok_or_else(|| AudioError::Tts {
196                provider: "gemini".into(),
197                message: "no audio data in response".into(),
198            })?;
199
200        use base64::Engine;
201        let pcm = base64::engine::general_purpose::STANDARD.decode(audio_b64).map_err(|e| {
202            AudioError::Tts {
203                provider: "gemini".into(),
204                message: format!("base64 decode failed: {e}"),
205            }
206        })?;
207
208        Ok(AudioFrame::new(Bytes::from(pcm), 24000, 1))
209    }
210
211    async fn synthesize_stream(
212        &self,
213        request: &TtsRequest,
214    ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<AudioFrame>> + Send>>> {
215        // Gemini TTS does not support streaming — return single frame
216        let frame = self.synthesize(request).await?;
217        Ok(Box::pin(futures::stream::once(async { Ok(frame) })))
218    }
219
220    fn voice_catalog(&self) -> &[Voice] {
221        &self.voices
222    }
223}
224
225/// Build the full 30-voice catalog.
226fn build_voice_catalog() -> Vec<Voice> {
227    let voices = [
228        ("Zephyr", "Bright"),
229        ("Puck", "Upbeat"),
230        ("Charon", "Informative"),
231        ("Kore", "Firm"),
232        ("Fenrir", "Excitable"),
233        ("Leda", "Youthful"),
234        ("Orus", "Firm"),
235        ("Aoede", "Breezy"),
236        ("Callirrhoe", "Easy-going"),
237        ("Autonoe", "Bright"),
238        ("Enceladus", "Breathy"),
239        ("Iapetus", "Clear"),
240        ("Umbriel", "Easy-going"),
241        ("Algieba", "Smooth"),
242        ("Despina", "Smooth"),
243        ("Erinome", "Clear"),
244        ("Algenib", "Gravelly"),
245        ("Rasalgethi", "Informative"),
246        ("Laomedeia", "Upbeat"),
247        ("Achernar", "Soft"),
248        ("Alnilam", "Firm"),
249        ("Schedar", "Even"),
250        ("Gacrux", "Mature"),
251        ("Pulcherrima", "Forward"),
252        ("Achird", "Friendly"),
253        ("Zubenelgenubi", "Casual"),
254        ("Vindemiatrix", "Gentle"),
255        ("Sadachbia", "Lively"),
256        ("Sadaltager", "Knowledgeable"),
257        ("Sulafat", "Warm"),
258    ];
259
260    voices
261        .iter()
262        .map(|(name, style)| Voice {
263            id: name.to_string(),
264            name: format!("{name} — {style}"),
265            language: "multilingual".into(),
266            gender: None,
267        })
268        .collect()
269}