use adk_rust_mcp_common::auth::AuthProvider;
use adk_rust_mcp_common::config::Config;
use adk_rust_mcp_common::error::Error;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::path::Path;
use tracing::{debug, info, instrument};
/// Voice name used when a request does not specify `voice`.
pub const DEFAULT_VOICE: &str = "en-US-Chirp3-HD-Achernar";
/// Language code used when a request omits `language_code`.
pub const DEFAULT_LANGUAGE_CODE: &str = "en-US";
/// Speaking rate used when a request omits `speaking_rate`.
pub const DEFAULT_SPEAKING_RATE: f32 = 1.0;
/// Inclusive lower bound accepted for `speaking_rate` by `SpeechSynthesizeParams::validate`.
pub const MIN_SPEAKING_RATE: f32 = 0.25;
/// Inclusive upper bound accepted for `speaking_rate` by `SpeechSynthesizeParams::validate`.
pub const MAX_SPEAKING_RATE: f32 = 4.0;
/// Pitch, in semitones, assumed when a request omits `pitch`.
pub const DEFAULT_PITCH: f32 = 0.0;
/// Inclusive lower bound for `pitch`, in semitones.
pub const MIN_PITCH: f32 = -20.0;
/// Inclusive upper bound for `pitch`, in semitones.
pub const MAX_PITCH: f32 = 20.0;
/// Phonetic alphabets accepted by `Pronunciation::validate`; the user's value is
/// lowercased before comparison, so matching is case-insensitive.
pub const VALID_ALPHABETS: &[&str] = &["ipa", "x-sampa"];
/// A custom pronunciation: occurrences of `word` in the request text are rendered
/// as an SSML `<phoneme>` element carrying `phonetic`.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
pub struct Pronunciation {
/// Text to match in the input (case-sensitive substring match); must not be blank.
pub word: String,
/// Phonetic spelling of `word`, written to the `ph` attribute; must not be blank.
pub phonetic: String,
/// Alphabet used by `phonetic`: "ipa" or "x-sampa" (case-insensitive).
pub alphabet: String,
}
impl Pronunciation {
    /// Checks that `word` and `phonetic` are non-blank and that `alphabet` is one
    /// of [`VALID_ALPHABETS`] (compared after lowercasing).
    ///
    /// # Errors
    ///
    /// Returns the first failing check as a [`ValidationError`] whose `field` is
    /// `word`, `phonetic`, or `alphabet`, in that order of precedence.
    pub fn validate(&self) -> Result<(), ValidationError> {
        if self.word.trim().is_empty() {
            return Err(ValidationError {
                field: "word".to_string(),
                message: "Word cannot be empty".to_string(),
            });
        }
        if self.phonetic.trim().is_empty() {
            return Err(ValidationError {
                field: "phonetic".to_string(),
                message: "Phonetic representation cannot be empty".to_string(),
            });
        }
        let alphabet_lower = self.alphabet.to_lowercase();
        if !VALID_ALPHABETS.contains(&alphabet_lower.as_str()) {
            return Err(ValidationError {
                field: "alphabet".to_string(),
                message: format!(
                    "Invalid alphabet '{}'. Must be one of: {}",
                    self.alphabet,
                    VALID_ALPHABETS.join(", ")
                ),
            });
        }
        Ok(())
    }

    /// Renders this pronunciation as an SSML `<phoneme>` element.
    ///
    /// The alphabet is lowercased. All three values are XML-escaped because
    /// X-SAMPA uses `"` to mark primary stress. Unescaped, it would terminate the
    /// `ph` attribute early and produce malformed SSML. A word containing `&` or
    /// `<` would break the element content in the same way.
    pub fn to_ssml(&self) -> String {
        // Escapes the five XML special characters so the value is safe both as
        // element content and inside a double-quoted attribute.
        fn escape_xml(raw: &str) -> String {
            let mut escaped = String::with_capacity(raw.len());
            for ch in raw.chars() {
                match ch {
                    '&' => escaped.push_str("&amp;"),
                    '<' => escaped.push_str("&lt;"),
                    '>' => escaped.push_str("&gt;"),
                    '"' => escaped.push_str("&quot;"),
                    '\'' => escaped.push_str("&apos;"),
                    other => escaped.push(other),
                }
            }
            escaped
        }

        format!(
            r#"<phoneme alphabet="{}" ph="{}">{}</phoneme>"#,
            escape_xml(&self.alphabet.to_lowercase()),
            escape_xml(&self.phonetic),
            escape_xml(&self.word)
        )
    }
}
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
pub struct SpeechSynthesizeParams {
pub text: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub voice: Option<String>,
#[serde(default = "default_language_code")]
pub language_code: String,
#[serde(default = "default_speaking_rate")]
pub speaking_rate: f32,
#[serde(default)]
pub pitch: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub pronunciations: Option<Vec<Pronunciation>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub output_file: Option<String>,
}
fn default_language_code() -> String {
DEFAULT_LANGUAGE_CODE.to_string()
}
fn default_speaking_rate() -> f32 {
DEFAULT_SPEAKING_RATE
}
/// A single field-level validation failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name or path of the offending field, e.g. `text` or `pronunciations[0].alphabet`.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
}

impl std::fmt::Display for ValidationError {
    /// Formats as `"<field>: <message>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.field, self.message)
    }
}

/// Allows `ValidationError` to be boxed as `dyn Error` and used with `?`.
impl std::error::Error for ValidationError {}
impl SpeechSynthesizeParams {
pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
let mut errors = Vec::new();
if self.text.trim().is_empty() {
errors.push(ValidationError {
field: "text".to_string(),
message: "Text cannot be empty".to_string(),
});
}
if self.speaking_rate < MIN_SPEAKING_RATE || self.speaking_rate > MAX_SPEAKING_RATE {
errors.push(ValidationError {
field: "speaking_rate".to_string(),
message: format!(
"speaking_rate must be between {} and {}, got {}",
MIN_SPEAKING_RATE, MAX_SPEAKING_RATE, self.speaking_rate
),
});
}
if self.pitch < MIN_PITCH || self.pitch > MAX_PITCH {
errors.push(ValidationError {
field: "pitch".to_string(),
message: format!(
"pitch must be between {} and {} semitones, got {}",
MIN_PITCH, MAX_PITCH, self.pitch
),
});
}
if let Some(ref pronunciations) = self.pronunciations {
for (i, pron) in pronunciations.iter().enumerate() {
if let Err(e) = pron.validate() {
errors.push(ValidationError {
field: format!("pronunciations[{}].{}", i, e.field),
message: e.message,
});
}
}
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
pub fn get_voice(&self) -> &str {
self.voice.as_deref().unwrap_or(DEFAULT_VOICE)
}
pub fn build_ssml(&self) -> String {
let mut text = self.text.clone();
if let Some(ref pronunciations) = self.pronunciations {
for pron in pronunciations {
text = text.replace(&pron.word, &pron.to_ssml());
}
}
format!(r#"<speak>{}</speak>"#, text)
}
}
/// Calls the Google Cloud Text-to-Speech REST API and delivers the resulting audio.
pub struct SpeechHandler {
/// Server configuration; `is_gemini()` selects API-key auth over OAuth auth.
pub config: Config,
/// HTTP client used for all API calls.
pub http: reqwest::Client,
/// Source of OAuth access tokens, used when not in API-key mode.
pub auth: AuthProvider,
}
impl SpeechHandler {
#[instrument(level = "debug", name = "speech_handler_new", skip_all)]
pub async fn new(config: Config) -> Result<Self, Error> {
debug!("Initializing SpeechHandler");
let auth = AuthProvider::new().await?;
let http = reqwest::Client::new();
Ok(Self { config, http, auth })
}
#[cfg(test)]
pub fn with_deps(config: Config, http: reqwest::Client, auth: AuthProvider) -> Self {
Self { config, http, auth }
}
pub fn get_endpoint(&self) -> String {
if self.config.is_gemini() {
"https://texttospeech.googleapis.com/v1/text:synthesize".to_string()
} else {
"https://texttospeech.googleapis.com/v1/text:synthesize".to_string()
}
}
pub fn get_voices_endpoint(&self) -> String {
"https://texttospeech.googleapis.com/v1/voices".to_string()
}
async fn add_auth(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::RequestBuilder, Error> {
if self.config.is_gemini() {
let key = self.config.gemini_api_key.as_deref().unwrap_or_default();
Ok(builder.header("x-goog-api-key", key))
} else {
let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
Ok(builder
.header("Authorization", format!("Bearer {}", token))
.header("x-goog-user-project", &self.config.project_id))
}
}
#[instrument(level = "info", name = "synthesize_speech", skip(self, params))]
pub async fn synthesize(&self, params: SpeechSynthesizeParams) -> Result<SpeechSynthesizeResult, Error> {
params.validate().map_err(|errors| {
let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
Error::validation(messages.join("; "))
})?;
info!(voice = %params.get_voice(), "Synthesizing speech with Cloud TTS API");
let (input, use_ssml) = if params.pronunciations.is_some() {
(params.build_ssml(), true)
} else {
(params.text.clone(), false)
};
let request = TtsRequest {
input: TtsInput {
text: if use_ssml { None } else { Some(input.clone()) },
ssml: if use_ssml { Some(input) } else { None },
},
voice: TtsVoice {
language_code: params.language_code.clone(),
name: params.get_voice().to_string(),
},
audio_config: TtsAudioConfig {
audio_encoding: "LINEAR16".to_string(),
speaking_rate: Some(params.speaking_rate),
pitch: Some(params.pitch),
sample_rate_hertz: Some(24000),
},
};
let endpoint = self.get_endpoint();
debug!(endpoint = %endpoint, "Calling Cloud TTS API");
let builder = self.http
.post(&endpoint)
.header("Content-Type", "application/json")
.json(&request);
let builder = self.add_auth(builder).await?;
let response = builder
.send()
.await
.map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
return Err(Error::api(&endpoint, status.as_u16(), body));
}
let api_response: TtsResponse = response.json().await.map_err(|e| {
Error::api(
&endpoint,
status.as_u16(),
format!("Failed to parse response: {}", e),
)
})?;
let audio_data = api_response.audio_content;
if audio_data.is_empty() {
return Err(Error::api(&endpoint, 200, "No audio content returned from API"));
}
info!("Received audio data from Cloud TTS API");
let audio = GeneratedAudio {
data: audio_data,
mime_type: "audio/wav".to_string(),
};
self.handle_output(audio, ¶ms).await
}
#[instrument(level = "info", name = "list_voices", skip(self))]
pub async fn list_voices(&self) -> Result<Vec<VoiceInfo>, Error> {
info!("Listing available voices from Cloud TTS API");
let endpoint = self.get_voices_endpoint();
debug!(endpoint = %endpoint, "Calling Cloud TTS voices API");
let builder = self.http.get(&endpoint);
let builder = self.add_auth(builder).await?;
let response = builder
.send()
.await
.map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
return Err(Error::api(&endpoint, status.as_u16(), body));
}
let api_response: VoicesResponse = response.json().await.map_err(|e| {
Error::api(
&endpoint,
status.as_u16(),
format!("Failed to parse response: {}", e),
)
})?;
let chirp3_voices: Vec<VoiceInfo> = api_response
.voices
.into_iter()
.filter(|v| v.name.contains("Chirp3-HD"))
.map(|v| VoiceInfo {
name: v.name,
language_codes: v.language_codes,
ssml_gender: v.ssml_gender,
natural_sample_rate_hertz: v.natural_sample_rate_hertz,
})
.collect();
info!(count = chirp3_voices.len(), "Found Chirp3-HD voices");
Ok(chirp3_voices)
}
async fn handle_output(
&self,
audio: GeneratedAudio,
params: &SpeechSynthesizeParams,
) -> Result<SpeechSynthesizeResult, Error> {
if let Some(output_file) = ¶ms.output_file {
return self.save_to_file(audio, output_file).await;
}
Ok(SpeechSynthesizeResult::Base64(audio))
}
async fn save_to_file(
&self,
audio: GeneratedAudio,
output_file: &str,
) -> Result<SpeechSynthesizeResult, Error> {
let data = BASE64.decode(&audio.data).map_err(|e| {
Error::validation(format!("Invalid base64 data: {}", e))
})?;
if let Some(parent) = Path::new(output_file).parent() {
if !parent.as_os_str().is_empty() {
tokio::fs::create_dir_all(parent).await?;
}
}
tokio::fs::write(output_file, &data).await?;
info!(path = %output_file, "Saved audio to local file");
Ok(SpeechSynthesizeResult::LocalFile(output_file.to_string()))
}
}
/// JSON body of a Cloud TTS `text:synthesize` request.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsRequest {
    /// What to speak: plain text or SSML.
    pub input: TtsInput,
    /// Voice selection.
    pub voice: TtsVoice,
    /// Encoding and prosody settings (serialized as `audioConfig`).
    pub audio_config: TtsAudioConfig,
}

/// Synthesis input; fields left as `None` are omitted from the JSON.
#[derive(Debug, Serialize)]
pub struct TtsInput {
    /// Plain-text input.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// SSML input.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssml: Option<String>,
}

/// Voice selection, serialized as `languageCode` and `name`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsVoice {
    /// Language code of the voice.
    pub language_code: String,
    /// Full voice name.
    pub name: String,
}

/// Audio output settings; optional fields left as `None` are omitted from the JSON.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsAudioConfig {
    /// Output encoding name, e.g. `LINEAR16`.
    pub audio_encoding: String,
    /// Speaking-rate multiplier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speaking_rate: Option<f32>,
    /// Pitch offset in semitones.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pitch: Option<f32>,
    /// Output sample rate in Hz.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sample_rate_hertz: Option<u32>,
}
/// Body of a successful `text:synthesize` response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsResponse {
    /// Base64-encoded audio. It defaults to empty when the field is absent, so the
    /// caller's "no audio content" check reports the problem instead of a JSON
    /// decode error.
    #[serde(default)]
    pub audio_content: String,
}

/// Body of a `voices` listing response.
#[derive(Debug, Deserialize)]
pub struct VoicesResponse {
    /// Available voices. It defaults to empty when absent, because the proto3
    /// JSON mapping used by Google REST APIs may omit empty repeated fields.
    #[serde(default)]
    pub voices: Vec<ApiVoiceInfo>,
}

/// One voice as reported by the API (camelCase JSON).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiVoiceInfo {
    /// Full voice name.
    pub name: String,
    /// Supported language codes; defaults to empty when absent, as with `voices`.
    #[serde(default)]
    pub language_codes: Vec<String>,
    /// Reported SSML gender, if any.
    pub ssml_gender: Option<String>,
    /// Reported natural sample rate in Hz, if any.
    pub natural_sample_rate_hertz: Option<u32>,
}
/// Audio returned by the API, still base64-encoded.
#[derive(Debug, Clone)]
pub struct GeneratedAudio {
/// Base64-encoded audio bytes; `save_to_file` decodes them before writing.
pub data: String,
/// MIME type of the decoded audio (`synthesize` sets "audio/wav").
pub mime_type: String,
}
/// Public description of a voice, as returned by `list_voices`.
/// Serialized with snake_case field names (no serde rename is applied).
#[derive(Debug, Clone, Serialize)]
pub struct VoiceInfo {
/// Full voice name.
pub name: String,
/// Language codes the voice supports.
pub language_codes: Vec<String>,
/// SSML gender reported by the API, if any.
pub ssml_gender: Option<String>,
/// Natural sample rate in Hz reported by the API, if any.
pub natural_sample_rate_hertz: Option<u32>,
}
/// Outcome of a synthesis request.
#[derive(Debug)]
pub enum SpeechSynthesizeResult {
/// Audio returned inline because no `output_file` was requested.
Base64(GeneratedAudio),
/// Path of the local file the decoded audio was written to.
LocalFile(String),
}
/// Unit tests for parameter parsing, validation, and SSML building.
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid baseline parameters; each test overrides only the fields it exercises.
    fn base_params() -> SpeechSynthesizeParams {
        SpeechSynthesizeParams {
            text: "Hello".to_string(),
            voice: None,
            language_code: "en-US".to_string(),
            speaking_rate: 1.0,
            pitch: 0.0,
            pronunciations: None,
            output_file: None,
        }
    }

    /// Shorthand constructor for a [`Pronunciation`].
    fn pronunciation(word: &str, phonetic: &str, alphabet: &str) -> Pronunciation {
        Pronunciation {
            word: word.to_string(),
            phonetic: phonetic.to_string(),
            alphabet: alphabet.to_string(),
        }
    }

    /// Asserts that `params` fails validation with an error naming `field`.
    fn assert_field_error(params: &SpeechSynthesizeParams, field: &str) {
        let errors = params.validate().expect_err("validation should fail");
        assert!(
            errors.iter().any(|e| e.field == field),
            "expected an error for `{}`, got {:?}",
            field,
            errors
        );
    }

    #[test]
    fn test_default_params() {
        let params: SpeechSynthesizeParams =
            serde_json::from_str(r#"{"text": "Hello world"}"#).unwrap();
        assert_eq!(params.language_code, DEFAULT_LANGUAGE_CODE);
        assert_eq!(params.speaking_rate, DEFAULT_SPEAKING_RATE);
        assert_eq!(params.pitch, DEFAULT_PITCH);
        assert!(params.voice.is_none());
        assert!(params.pronunciations.is_none());
        assert!(params.output_file.is_none());
    }

    #[test]
    fn test_valid_params() {
        let params = SpeechSynthesizeParams {
            text: "Hello world".to_string(),
            voice: Some("en-US-Chirp3-HD-Achernar".to_string()),
            speaking_rate: 1.5,
            pitch: 2.0,
            ..base_params()
        };
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_empty_text() {
        let params = SpeechSynthesizeParams {
            text: " ".to_string(),
            ..base_params()
        };
        assert_field_error(&params, "text");
    }

    #[test]
    fn test_speaking_rate_too_low() {
        let params = SpeechSynthesizeParams {
            speaking_rate: 0.1,
            ..base_params()
        };
        assert_field_error(&params, "speaking_rate");
    }

    #[test]
    fn test_speaking_rate_too_high() {
        let params = SpeechSynthesizeParams {
            speaking_rate: 5.0,
            ..base_params()
        };
        assert_field_error(&params, "speaking_rate");
    }

    #[test]
    fn test_pitch_too_low() {
        let params = SpeechSynthesizeParams {
            pitch: -25.0,
            ..base_params()
        };
        assert_field_error(&params, "pitch");
    }

    #[test]
    fn test_pitch_too_high() {
        let params = SpeechSynthesizeParams {
            pitch: 25.0,
            ..base_params()
        };
        assert_field_error(&params, "pitch");
    }

    #[test]
    fn test_valid_speaking_rate_boundaries() {
        for rate in [MIN_SPEAKING_RATE, MAX_SPEAKING_RATE] {
            let params = SpeechSynthesizeParams {
                speaking_rate: rate,
                ..base_params()
            };
            assert!(params.validate().is_ok(), "speaking_rate {} should be valid", rate);
        }
    }

    #[test]
    fn test_valid_pitch_boundaries() {
        for pitch in [MIN_PITCH, MAX_PITCH] {
            let params = SpeechSynthesizeParams {
                pitch,
                ..base_params()
            };
            assert!(params.validate().is_ok(), "pitch {} should be valid", pitch);
        }
    }

    #[test]
    fn test_pronunciation_valid_ipa() {
        assert!(pronunciation("tomato", "təˈmeɪtoʊ", "ipa").validate().is_ok());
    }

    #[test]
    fn test_pronunciation_valid_xsampa() {
        assert!(pronunciation("tomato", "t@\"meItoU", "x-sampa").validate().is_ok());
    }

    #[test]
    fn test_pronunciation_invalid_alphabet() {
        let error = pronunciation("tomato", "tomato", "invalid").validate().unwrap_err();
        assert_eq!(error.field, "alphabet");
    }

    #[test]
    fn test_pronunciation_empty_word() {
        let error = pronunciation("", "test", "ipa").validate().unwrap_err();
        assert_eq!(error.field, "word");
    }

    #[test]
    fn test_pronunciation_empty_phonetic() {
        let error = pronunciation("test", "", "ipa").validate().unwrap_err();
        assert_eq!(error.field, "phonetic");
    }

    #[test]
    fn test_pronunciation_to_ssml() {
        let ssml = pronunciation("tomato", "təˈmeɪtoʊ", "ipa").to_ssml();
        assert!(ssml.contains("phoneme"));
        assert!(ssml.contains("ipa"));
        assert!(ssml.contains("təˈmeɪtoʊ"));
        assert!(ssml.contains("tomato"));
    }

    #[test]
    fn test_build_ssml_with_pronunciations() {
        let params = SpeechSynthesizeParams {
            text: "I like tomato".to_string(),
            pronunciations: Some(vec![pronunciation("tomato", "təˈmeɪtoʊ", "ipa")]),
            ..base_params()
        };
        let ssml = params.build_ssml();
        assert!(ssml.starts_with("<speak>"));
        assert!(ssml.ends_with("</speak>"));
        assert!(ssml.contains("phoneme"));
        assert!(!ssml.contains("tomato</speak>"));
    }

    #[test]
    fn test_build_ssml_without_pronunciations() {
        let params = SpeechSynthesizeParams {
            text: "Hello world".to_string(),
            ..base_params()
        };
        assert_eq!(params.build_ssml(), "<speak>Hello world</speak>");
    }

    #[test]
    fn test_get_voice_default() {
        assert_eq!(base_params().get_voice(), DEFAULT_VOICE);
    }

    #[test]
    fn test_get_voice_custom() {
        let params = SpeechSynthesizeParams {
            voice: Some("custom-voice".to_string()),
            ..base_params()
        };
        assert_eq!(params.get_voice(), "custom-voice");
    }

    #[test]
    fn test_params_with_invalid_pronunciation() {
        let params = SpeechSynthesizeParams {
            pronunciations: Some(vec![pronunciation("test", "test", "invalid")]),
            ..base_params()
        };
        let errors = params.validate().unwrap_err();
        assert!(errors.iter().any(|e| e.field.contains("pronunciations")));
    }

    #[test]
    fn test_serialization_roundtrip() {
        let params = SpeechSynthesizeParams {
            text: "Hello world".to_string(),
            voice: Some("en-US-Chirp3-HD-Achernar".to_string()),
            language_code: "en-US".to_string(),
            speaking_rate: 1.5,
            pitch: 2.0,
            pronunciations: Some(vec![pronunciation("hello", "həˈloʊ", "ipa")]),
            output_file: Some("/tmp/output.wav".to_string()),
        };
        let json = serde_json::to_string(&params).unwrap();
        let deserialized: SpeechSynthesizeParams = serde_json::from_str(&json).unwrap();
        assert_eq!(params.text, deserialized.text);
        assert_eq!(params.voice, deserialized.voice);
        assert_eq!(params.language_code, deserialized.language_code);
        assert_eq!(params.speaking_rate, deserialized.speaking_rate);
        assert_eq!(params.pitch, deserialized.pitch);
        assert_eq!(params.output_file, deserialized.output_file);
        assert_eq!(
            params.pronunciations.as_ref().map(Vec::len),
            deserialized.pronunciations.as_ref().map(Vec::len)
        );
    }
}
/// Property-based checks of the validation rules.
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds params with default voice, language, and output for a property case.
    fn make_params(
        text: String,
        speaking_rate: f32,
        pitch: f32,
        pronunciations: Option<Vec<Pronunciation>>,
    ) -> SpeechSynthesizeParams {
        SpeechSynthesizeParams {
            text,
            voice: None,
            language_code: "en-US".to_string(),
            speaking_rate,
            pitch,
            pronunciations,
            output_file: None,
        }
    }

    /// Speaking rates inside the valid range, rounded to two decimals.
    fn valid_speaking_rate_strategy() -> impl Strategy<Value = f32> {
        (MIN_SPEAKING_RATE..=MAX_SPEAKING_RATE).prop_map(|x| (x * 100.0).round() / 100.0)
    }

    /// Speaking rates outside the valid range. Each sub-range ends at least one
    /// rounding step (0.01) away from the nearest bound.
    fn invalid_speaking_rate_strategy() -> impl Strategy<Value = f32> {
        prop_oneof![
            (0.0f32..0.24f32).prop_map(|x| (x * 100.0).round() / 100.0),
            (4.01f32..10.0f32).prop_map(|x| (x * 100.0).round() / 100.0),
        ]
    }

    /// Pitches inside the valid range, rounded to one decimal.
    fn valid_pitch_strategy() -> impl Strategy<Value = f32> {
        (MIN_PITCH..=MAX_PITCH).prop_map(|x| (x * 10.0).round() / 10.0)
    }

    /// Pitches outside the valid range, rounded to one decimal.
    ///
    /// Both sub-ranges stop a full rounding step (0.1) short of the bounds. If the
    /// lower range ended at `MIN_PITCH`, a sample such as -20.04 would round to
    /// -20.0. That is a valid pitch, so the property would fail intermittently.
    fn invalid_pitch_strategy() -> impl Strategy<Value = f32> {
        prop_oneof![
            (-50.0f32..MIN_PITCH - 0.1).prop_map(|x| (x * 10.0).round() / 10.0),
            (MAX_PITCH + 0.1..50.0f32).prop_map(|x| (x * 10.0).round() / 10.0),
        ]
    }

    /// Non-blank alphanumeric text.
    fn valid_text_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z0-9 ]{1,100}"
            .prop_map(|s| s.trim().to_string())
            .prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    /// One of the accepted phonetic alphabets.
    fn valid_alphabet_strategy() -> impl Strategy<Value = String> {
        prop_oneof![Just("ipa".to_string()), Just("x-sampa".to_string()),]
    }

    /// Lowercase words that are not an accepted alphabet.
    fn invalid_alphabet_strategy() -> impl Strategy<Value = String> {
        "[a-z]{1,10}".prop_filter("Must not be valid alphabet", |s| {
            let lower = s.to_lowercase();
            lower != "ipa" && lower != "x-sampa"
        })
    }

    /// Non-blank alphabetic words.
    fn valid_word_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z]{1,20}".prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    /// Non-blank phonetic strings including common IPA symbols.
    fn valid_phonetic_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Zəˈɪʊæɑɔɛʌ]{1,30}".prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    proptest! {
        #[test]
        fn valid_speaking_rate_passes_validation(
            rate in valid_speaking_rate_strategy(),
            text in valid_text_strategy(),
        ) {
            let result = make_params(text, rate, 0.0, None).validate();
            prop_assert!(
                result.is_ok(),
                "speaking_rate {} should be valid, but got errors: {:?}",
                rate,
                result.err()
            );
        }

        #[test]
        fn invalid_speaking_rate_fails_validation(
            rate in invalid_speaking_rate_strategy(),
            text in valid_text_strategy(),
        ) {
            let result = make_params(text, rate, 0.0, None).validate();
            prop_assert!(result.is_err(), "speaking_rate {} should be invalid", rate);
            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "speaking_rate"),
                "Should have a speaking_rate validation error for value {}",
                rate
            );
        }

        #[test]
        fn valid_pitch_passes_validation(
            pitch in valid_pitch_strategy(),
            text in valid_text_strategy(),
        ) {
            let result = make_params(text, 1.0, pitch, None).validate();
            prop_assert!(
                result.is_ok(),
                "pitch {} should be valid, but got errors: {:?}",
                pitch,
                result.err()
            );
        }

        #[test]
        fn invalid_pitch_fails_validation(
            pitch in invalid_pitch_strategy(),
            text in valid_text_strategy(),
        ) {
            let result = make_params(text, 1.0, pitch, None).validate();
            prop_assert!(result.is_err(), "pitch {} should be invalid", pitch);
            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "pitch"),
                "Should have a pitch validation error for value {}",
                pitch
            );
        }

        #[test]
        fn valid_speaking_rate_and_pitch_passes_validation(
            rate in valid_speaking_rate_strategy(),
            pitch in valid_pitch_strategy(),
            text in valid_text_strategy(),
        ) {
            let result = make_params(text, rate, pitch, None).validate();
            prop_assert!(
                result.is_ok(),
                "speaking_rate {} and pitch {} should be valid, but got errors: {:?}",
                rate,
                pitch,
                result.err()
            );
        }

        #[test]
        fn valid_alphabet_passes_validation(
            alphabet in valid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
        ) {
            let pron = Pronunciation {
                word,
                phonetic,
                alphabet: alphabet.clone(),
            };
            let result = pron.validate();
            prop_assert!(
                result.is_ok(),
                "alphabet '{}' should be valid, but got error: {:?}",
                alphabet,
                result.err()
            );
        }

        #[test]
        fn invalid_alphabet_fails_validation(
            alphabet in invalid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
        ) {
            let pron = Pronunciation {
                word,
                phonetic,
                alphabet: alphabet.clone(),
            };
            let result = pron.validate();
            prop_assert!(result.is_err(), "alphabet '{}' should be invalid", alphabet);
            let error = result.unwrap_err();
            prop_assert!(
                error.field == "alphabet",
                "Should have an alphabet validation error for value '{}'",
                alphabet
            );
        }

        #[test]
        fn params_with_valid_pronunciation_passes_validation(
            alphabet in valid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
            text in valid_text_strategy(),
        ) {
            let pronunciations = vec![Pronunciation {
                word,
                phonetic,
                alphabet: alphabet.clone(),
            }];
            let result = make_params(text, 1.0, 0.0, Some(pronunciations)).validate();
            prop_assert!(
                result.is_ok(),
                "params with alphabet '{}' should be valid, but got errors: {:?}",
                alphabet,
                result.err()
            );
        }

        #[test]
        fn params_with_invalid_pronunciation_fails_validation(
            alphabet in invalid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
            text in valid_text_strategy(),
        ) {
            let pronunciations = vec![Pronunciation {
                word,
                phonetic,
                alphabet: alphabet.clone(),
            }];
            let result = make_params(text, 1.0, 0.0, Some(pronunciations)).validate();
            prop_assert!(
                result.is_err(),
                "params with alphabet '{}' should be invalid",
                alphabet
            );
            let errors = result.unwrap_err();
            prop_assert!(
                errors
                    .iter()
                    .any(|e| e.field.contains("pronunciations") && e.field.contains("alphabet")),
                "Should have a pronunciations.alphabet validation error for value '{}'",
                alphabet
            );
        }

        #[test]
        fn empty_text_fails_validation(
            rate in valid_speaking_rate_strategy(),
            pitch in valid_pitch_strategy(),
        ) {
            let result = make_params(" ".to_string(), rate, pitch, None).validate();
            prop_assert!(result.is_err());
            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "text"),
                "Should have a text validation error"
            );
        }
    }
}