use anyhow::{Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::rate_limiter::RateLimiter;
/// Default base URL of the Google Cloud Text-to-Speech REST API (v1).
const GOOGLE_TTS_API_BASE: &str = "https://texttospeech.googleapis.com/v1";

/// Async client for the Google Cloud Text-to-Speech REST API, authenticated
/// with an API key sent in the `X-Goog-Api-Key` header.
pub struct GoogleTtsClient {
    /// HTTP client used for every request issued by this wrapper.
    http_client: Client,
    /// Root URL that endpoint paths are appended to.
    base_url: String,
    /// API key attached to each request.
    api_key: String,
    /// Optional throttle awaited before each request; `None` means unthrottled.
    rate_limiter: Option<std::sync::Arc<RateLimiter>>,
}
impl GoogleTtsClient {
pub fn new(api_key: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
base_url: GOOGLE_TTS_API_BASE.to_string(),
http_client: Client::new(),
rate_limiter: None,
}
}
pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
self.base_url = url.into();
self
}
pub fn with_rate_limit(mut self, requests_per_minute: u32) -> Self {
self.rate_limiter = Some(std::sync::Arc::new(RateLimiter::new(requests_per_minute)));
self
}
async fn acquire_rate_limit(&self) {
if let Some(ref limiter) = self.rate_limiter {
limiter.acquire().await;
}
}
pub async fn synthesize(
&self,
req: &GoogleTtsSynthesizeRequest,
) -> Result<GoogleTtsSynthesizeResponse> {
self.acquire_rate_limit().await;
let url = format!("{}/text:synthesize", self.base_url);
let response = self
.http_client
.post(&url)
.header("X-Goog-Api-Key", &self.api_key)
.header("Content-Type", "application/json")
.json(req)
.send()
.await
.context("Failed to send Google TTS request")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("Google TTS API error ({}): {}", status, body);
}
response
.json()
.await
.context("Failed to parse Google TTS response")
}
pub async fn list_voices(
&self,
language_code: Option<&str>,
) -> Result<GoogleTtsVoicesResponse> {
self.acquire_rate_limit().await;
let mut url = format!("{}/voices", self.base_url);
if let Some(lang) = language_code {
url = format!("{}?languageCode={}", url, lang);
}
let response = self
.http_client
.get(&url)
.header("X-Goog-Api-Key", &self.api_key)
.send()
.await
.context("Failed to list Google TTS voices")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("Google TTS voices API error ({}): {}", status, body);
}
response
.json()
.await
.context("Failed to parse Google TTS voices response")
}
}
/// JSON body sent to the `text:synthesize` endpoint: what to say, which
/// voice to use, and how the audio should be encoded.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleTtsSynthesizeRequest {
    /// Text or SSML to synthesize.
    pub input: GoogleTtsInput,
    /// Voice selection criteria.
    pub voice: GoogleTtsVoiceSelection,
    /// Output audio settings (serialized as `audioConfig`).
    pub audio_config: GoogleTtsAudioConfig,
}
/// Synthesis input; a field left as `None` is omitted from the JSON.
///
/// NOTE(review): presumably the API expects exactly one of `text` / `ssml`
/// to be set — confirm against the API reference.
#[derive(Debug, Clone, Serialize)]
pub struct GoogleTtsInput {
    /// Plain text to speak.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// SSML markup to speak.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssml: Option<String>,
}
/// Voice selection criteria, using camelCase JSON keys (`languageCode`,
/// `name`, `ssmlGender`); `None` fields are omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleTtsVoiceSelection {
    /// Language tag such as `"en-US"`.
    pub language_code: String,
    /// Specific voice name such as `"en-US-Neural2-A"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Requested SSML gender, passed through as a raw string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssml_gender: Option<String>,
}
/// Output audio settings, serialized with camelCase keys (`audioEncoding`,
/// `speakingRate`, `pitch`, `sampleRateHertz`); `None` fields are omitted.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleTtsAudioConfig {
    /// Encoding name passed through as a raw string, e.g. `"LINEAR16"`.
    pub audio_encoding: String,
    /// Optional speaking-rate multiplier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speaking_rate: Option<f32>,
    /// Optional pitch adjustment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pitch: Option<f32>,
    /// Optional output sample rate in hertz.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sample_rate_hertz: Option<u32>,
}
/// Successful `text:synthesize` response body.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleTtsSynthesizeResponse {
    /// Audio payload from the `audioContent` field, kept as a string
    /// (presumably base64-encoded — confirm before decoding).
    pub audio_content: String,
}
/// Response body of the `voices` listing endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct GoogleTtsVoicesResponse {
    /// Available voices; an absent `voices` key yields an empty list.
    #[serde(default)]
    pub voices: Vec<GoogleTtsVoiceEntry>,
}
/// One voice from the `voices` listing, read from camelCase JSON keys.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleTtsVoiceEntry {
    /// Language tags the voice supports (`languageCodes`); empty when absent.
    #[serde(default)]
    pub language_codes: Vec<String>,
    /// Voice identifier; required in the JSON.
    pub name: String,
    /// SSML gender as a raw string (`ssmlGender`), if present.
    pub ssml_gender: Option<String>,
    /// Native sample rate in hertz (`naturalSampleRateHertz`), if present.
    pub natural_sample_rate_hertz: Option<u32>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A request with some optional fields set and others left as `None`.
    fn sample_request() -> GoogleTtsSynthesizeRequest {
        GoogleTtsSynthesizeRequest {
            input: GoogleTtsInput {
                text: Some("Hello world".to_string()),
                ssml: None,
            },
            voice: GoogleTtsVoiceSelection {
                language_code: "en-US".to_string(),
                name: Some("en-US-Neural2-A".to_string()),
                ssml_gender: None,
            },
            audio_config: GoogleTtsAudioConfig {
                audio_encoding: "LINEAR16".to_string(),
                speaking_rate: None,
                pitch: None,
                sample_rate_hertz: Some(24000),
            },
        }
    }

    #[test]
    fn test_client_creation() {
        let client = GoogleTtsClient::new("test-key");
        assert_eq!(client.base_url, GOOGLE_TTS_API_BASE);
        assert_eq!(client.api_key, "test-key");
        assert!(client.rate_limiter.is_none());
    }

    #[test]
    fn test_with_base_url_overrides_default() {
        let client = GoogleTtsClient::new("k").with_base_url("http://localhost:8080/v1");
        assert_eq!(client.base_url, "http://localhost:8080/v1");
    }

    #[test]
    fn test_synthesize_request_serialization() {
        let json = serde_json::to_value(sample_request()).unwrap();
        assert_eq!(json["input"]["text"], "Hello world");
        assert_eq!(json["voice"]["languageCode"], "en-US");
        assert_eq!(json["voice"]["name"], "en-US-Neural2-A");
        // Renamed keys must use the API's camelCase spelling.
        assert_eq!(json["audioConfig"]["audioEncoding"], "LINEAR16");
        assert_eq!(json["audioConfig"]["sampleRateHertz"], 24000);
        assert!(json.get("audio_config").is_none());
    }

    #[test]
    fn test_none_fields_are_omitted() {
        let json = serde_json::to_value(sample_request()).unwrap();
        assert!(json["input"].get("ssml").is_none());
        assert!(json["voice"].get("ssmlGender").is_none());
        assert!(json["audioConfig"].get("speakingRate").is_none());
        assert!(json["audioConfig"].get("pitch").is_none());
    }

    #[test]
    fn test_synthesize_response_deserialization() {
        let parsed: GoogleTtsSynthesizeResponse =
            serde_json::from_str(r#"{"audioContent":"AAAA"}"#).unwrap();
        assert_eq!(parsed.audio_content, "AAAA");
    }

    #[test]
    fn test_voices_response_deserialization() {
        let body = r#"{"voices":[{"languageCodes":["en-US"],"name":"en-US-Neural2-A","ssmlGender":"FEMALE","naturalSampleRateHertz":24000}]}"#;
        let parsed: GoogleTtsVoicesResponse = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.voices.len(), 1);
        let voice = &parsed.voices[0];
        assert_eq!(voice.language_codes, vec!["en-US".to_string()]);
        assert_eq!(voice.name, "en-US-Neural2-A");
        assert_eq!(voice.ssml_gender.as_deref(), Some("FEMALE"));
        assert_eq!(voice.natural_sample_rate_hertz, Some(24000));
    }

    #[test]
    fn test_voices_response_defaults() {
        let parsed: GoogleTtsVoicesResponse = serde_json::from_str("{}").unwrap();
        assert!(parsed.voices.is_empty());

        let entry: GoogleTtsVoiceEntry = serde_json::from_str(r#"{"name":"x"}"#).unwrap();
        assert!(entry.language_codes.is_empty());
        assert!(entry.ssml_gender.is_none());
        assert!(entry.natural_sample_rate_hertz.is_none());
    }
}