1use std::pin::Pin;
9
10use async_trait::async_trait;
11use bytes::Bytes;
12use futures::Stream;
13
14use crate::error::{AudioError, AudioResult};
15use crate::frame::AudioFrame;
16use crate::providers::tts::CloudTtsConfig;
17use crate::traits::{TtsProvider, TtsRequest, Voice};
18
/// Model identifiers accepted by the Gemini `generateContent` endpoint for speech output.
///
/// `GEMINI_3_1_FLASH_TTS` is the default picked by `GeminiTts::new`; the others can be
/// selected with `GeminiTts::with_model`.
// Some of these identifiers may be unused inside the crate; the allow keeps that quiet.
#[allow(dead_code)]
pub mod models {
    /// Gemini 2.5 Pro TTS (preview).
    pub const GEMINI_2_5_PRO_TTS: &str = "gemini-2.5-pro-preview-tts";

    /// Gemini 2.5 Flash TTS (preview).
    pub const GEMINI_2_5_FLASH_TTS: &str = "gemini-2.5-flash-preview-tts";

    /// Gemini 3.1 Flash TTS (preview) — the provider's default model.
    pub const GEMINI_3_1_FLASH_TTS: &str = "gemini-3.1-flash-tts-preview";
}
29
/// Voice assignment for one speaker in a multi-speaker synthesis request.
///
/// `name` is sent as the `speaker` field of a `multiSpeakerVoiceConfig` entry (per the
/// Gemini docs it should match the speaker label used in the prompt text), and `voice`
/// is sent as that entry's prebuilt `voiceName`, e.g. `"Kore"` or `"Puck"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SpeakerConfig {
    /// Speaker label, sent as `speaker`.
    pub name: String,
    /// Prebuilt voice name, sent as `voiceName`.
    pub voice: String,
}

impl SpeakerConfig {
    /// Creates a speaker-to-voice mapping from anything convertible into `String`.
    pub fn new(name: impl Into<String>, voice: impl Into<String>) -> Self {
        Self { name: name.into(), voice: voice.into() }
    }
}
45
/// Text-to-speech provider backed by the Gemini `generateContent` REST endpoint.
///
/// Construct with `GeminiTts::new` or `GeminiTts::from_env`, then optionally adjust
/// with `with_model` / `with_speakers`.
pub struct GeminiTts {
    /// API key plus an optional full endpoint URL override.
    config: CloudTtsConfig,
    /// HTTP client reused for every request.
    client: reqwest::Client,
    /// Model id interpolated into the default endpoint URL.
    model: String,
    /// Voices reported by `voice_catalog`.
    voices: Vec<Voice>,
    /// Speaker-to-voice mapping for multi-speaker synthesis; `None` or an empty list
    /// selects single-voice synthesis.
    speakers: Option<Vec<SpeakerConfig>>,
}
72
73impl GeminiTts {
74 pub fn from_env() -> AudioResult<Self> {
76 let api_key = std::env::var("GEMINI_API_KEY")
77 .or_else(|_| std::env::var("GOOGLE_API_KEY"))
78 .map_err(|_| AudioError::Tts {
79 provider: "gemini".into(),
80 message: "GEMINI_API_KEY or GOOGLE_API_KEY not set".into(),
81 })?;
82 Ok(Self::new(CloudTtsConfig::new(api_key)))
83 }
84
85 pub fn new(config: CloudTtsConfig) -> Self {
87 Self {
88 config,
89 client: reqwest::Client::new(),
90 model: models::GEMINI_3_1_FLASH_TTS.into(),
91 voices: build_voice_catalog(),
92 speakers: None,
93 }
94 }
95
96 pub fn with_model(mut self, model: impl Into<String>) -> Self {
98 self.model = model.into();
99 self
100 }
101
102 pub fn with_speakers(mut self, speakers: Vec<SpeakerConfig>) -> Self {
107 self.speakers = Some(speakers);
108 self
109 }
110
111 fn base_url(&self) -> String {
112 self.config.base_url.clone().unwrap_or_else(|| {
113 format!(
114 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
115 self.model
116 )
117 })
118 }
119
120 fn build_speech_config(&self, voice: &str) -> serde_json::Value {
121 match &self.speakers {
122 Some(speakers) if !speakers.is_empty() => {
123 let speaker_configs: Vec<serde_json::Value> = speakers
124 .iter()
125 .map(|s| {
126 serde_json::json!({
127 "speaker": s.name,
128 "voiceConfig": {
129 "prebuiltVoiceConfig": {
130 "voiceName": s.voice
131 }
132 }
133 })
134 })
135 .collect();
136 serde_json::json!({
137 "multiSpeakerVoiceConfig": {
138 "speakerVoiceConfigs": speaker_configs
139 }
140 })
141 }
142 _ => {
143 let voice_name = if voice.is_empty() { "Kore" } else { voice };
144 serde_json::json!({
145 "voiceConfig": {
146 "prebuiltVoiceConfig": {
147 "voiceName": voice_name
148 }
149 }
150 })
151 }
152 }
153 }
154}
155
156#[async_trait]
157impl TtsProvider for GeminiTts {
158 async fn synthesize(&self, request: &TtsRequest) -> AudioResult<AudioFrame> {
159 let url = self.base_url();
160 let speech_config = self.build_speech_config(&request.voice);
161
162 let body = serde_json::json!({
163 "contents": [{"parts": [{"text": request.text}]}],
164 "generationConfig": {
165 "response_modalities": ["AUDIO"],
166 "speech_config": speech_config
167 }
168 });
169
170 let resp = self
171 .client
172 .post(&url)
173 .header("x-goog-api-key", &self.config.api_key)
174 .json(&body)
175 .send()
176 .await
177 .map_err(|e| AudioError::Tts { provider: "gemini".into(), message: e.to_string() })?;
178
179 if !resp.status().is_success() {
180 let status = resp.status();
181 let body = resp.text().await.unwrap_or_default();
182 return Err(AudioError::Tts {
183 provider: "gemini".into(),
184 message: format!("HTTP {status}: {body}"),
185 });
186 }
187
188 let json: serde_json::Value = resp
189 .json()
190 .await
191 .map_err(|e| AudioError::Tts { provider: "gemini".into(), message: e.to_string() })?;
192
193 let audio_b64 = json["candidates"][0]["content"]["parts"][0]["inlineData"]["data"]
194 .as_str()
195 .ok_or_else(|| AudioError::Tts {
196 provider: "gemini".into(),
197 message: "no audio data in response".into(),
198 })?;
199
200 use base64::Engine;
201 let pcm = base64::engine::general_purpose::STANDARD.decode(audio_b64).map_err(|e| {
202 AudioError::Tts {
203 provider: "gemini".into(),
204 message: format!("base64 decode failed: {e}"),
205 }
206 })?;
207
208 Ok(AudioFrame::new(Bytes::from(pcm), 24000, 1))
209 }
210
211 async fn synthesize_stream(
212 &self,
213 request: &TtsRequest,
214 ) -> AudioResult<Pin<Box<dyn Stream<Item = AudioResult<AudioFrame>> + Send>>> {
215 let frame = self.synthesize(request).await?;
217 Ok(Box::pin(futures::stream::once(async { Ok(frame) })))
218 }
219
220 fn voice_catalog(&self) -> &[Voice] {
221 &self.voices
222 }
223}
224
225fn build_voice_catalog() -> Vec<Voice> {
227 let voices = [
228 ("Zephyr", "Bright"),
229 ("Puck", "Upbeat"),
230 ("Charon", "Informative"),
231 ("Kore", "Firm"),
232 ("Fenrir", "Excitable"),
233 ("Leda", "Youthful"),
234 ("Orus", "Firm"),
235 ("Aoede", "Breezy"),
236 ("Callirrhoe", "Easy-going"),
237 ("Autonoe", "Bright"),
238 ("Enceladus", "Breathy"),
239 ("Iapetus", "Clear"),
240 ("Umbriel", "Easy-going"),
241 ("Algieba", "Smooth"),
242 ("Despina", "Smooth"),
243 ("Erinome", "Clear"),
244 ("Algenib", "Gravelly"),
245 ("Rasalgethi", "Informative"),
246 ("Laomedeia", "Upbeat"),
247 ("Achernar", "Soft"),
248 ("Alnilam", "Firm"),
249 ("Schedar", "Even"),
250 ("Gacrux", "Mature"),
251 ("Pulcherrima", "Forward"),
252 ("Achird", "Friendly"),
253 ("Zubenelgenubi", "Casual"),
254 ("Vindemiatrix", "Gentle"),
255 ("Sadachbia", "Lively"),
256 ("Sadaltager", "Knowledgeable"),
257 ("Sulafat", "Warm"),
258 ];
259
260 voices
261 .iter()
262 .map(|(name, style)| Voice {
263 id: name.to_string(),
264 name: format!("{name} — {style}"),
265 language: "multilingual".into(),
266 gender: None,
267 })
268 .collect()
269}