1use adk_rust_mcp_common::auth::AuthProvider;
7use adk_rust_mcp_common::config::Config;
8use adk_rust_mcp_common::error::Error;
9use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use std::path::Path;
13use tracing::{debug, info, instrument};
14
/// Voice name used when a request does not specify `voice`
/// (see `SpeechSynthesizeParams::get_voice`).
pub const DEFAULT_VOICE: &str = "en-US-Chirp3-HD-Achernar";

/// Language code applied by serde when a request omits `language_code`.
pub const DEFAULT_LANGUAGE_CODE: &str = "en-US";

/// Speaking rate applied by serde when a request omits `speaking_rate`.
pub const DEFAULT_SPEAKING_RATE: f32 = 1.0;

/// Smallest `speaking_rate` accepted by `SpeechSynthesizeParams::validate` (inclusive).
pub const MIN_SPEAKING_RATE: f32 = 0.25;

/// Largest `speaking_rate` accepted by `SpeechSynthesizeParams::validate` (inclusive).
pub const MAX_SPEAKING_RATE: f32 = 4.0;

/// Pitch value meaning "no adjustment"; equal to the `f32` default that serde
/// uses when a request omits `pitch`.
pub const DEFAULT_PITCH: f32 = 0.0;

/// Lowest `pitch`, in semitones, accepted by `SpeechSynthesizeParams::validate` (inclusive).
pub const MIN_PITCH: f32 = -20.0;

/// Highest `pitch`, in semitones, accepted by `SpeechSynthesizeParams::validate` (inclusive).
pub const MAX_PITCH: f32 = 20.0;

/// Phonetic alphabets accepted by `Pronunciation::validate`; the comparison is
/// done on the lowercased input, so entries here must be lowercase.
pub const VALID_ALPHABETS: &[&str] = &["ipa", "x-sampa"];
41
42
43#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
47pub struct Pronunciation {
48 pub word: String,
50
51 pub phonetic: String,
53
54 pub alphabet: String,
56}
57
58impl Pronunciation {
59 pub fn validate(&self) -> Result<(), ValidationError> {
61 if self.word.trim().is_empty() {
62 return Err(ValidationError {
63 field: "word".to_string(),
64 message: "Word cannot be empty".to_string(),
65 });
66 }
67
68 if self.phonetic.trim().is_empty() {
69 return Err(ValidationError {
70 field: "phonetic".to_string(),
71 message: "Phonetic representation cannot be empty".to_string(),
72 });
73 }
74
75 let alphabet_lower = self.alphabet.to_lowercase();
76 if !VALID_ALPHABETS.contains(&alphabet_lower.as_str()) {
77 return Err(ValidationError {
78 field: "alphabet".to_string(),
79 message: format!(
80 "Invalid alphabet '{}'. Must be one of: {}",
81 self.alphabet,
82 VALID_ALPHABETS.join(", ")
83 ),
84 });
85 }
86
87 Ok(())
88 }
89
90 pub fn to_ssml(&self) -> String {
92 let alphabet = self.alphabet.to_lowercase();
93 format!(
94 r#"<phoneme alphabet="{}" ph="{}">{}</phoneme>"#,
95 alphabet, self.phonetic, self.word
96 )
97 }
98}
99
100#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
104pub struct SpeechSynthesizeParams {
105 pub text: String,
107
108 #[serde(default, skip_serializing_if = "Option::is_none")]
110 pub voice: Option<String>,
111
112 #[serde(default = "default_language_code")]
114 pub language_code: String,
115
116 #[serde(default = "default_speaking_rate")]
118 pub speaking_rate: f32,
119
120 #[serde(default)]
122 pub pitch: f32,
123
124 #[serde(default, skip_serializing_if = "Option::is_none")]
126 pub pronunciations: Option<Vec<Pronunciation>>,
127
128 #[serde(default, skip_serializing_if = "Option::is_none")]
131 pub output_file: Option<String>,
132}
133
134fn default_language_code() -> String {
135 DEFAULT_LANGUAGE_CODE.to_string()
136}
137
138fn default_speaking_rate() -> f32 {
139 DEFAULT_SPEAKING_RATE
140}
141
142
/// A single validation failure, tied to the parameter that caused it.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Name of the offending field (for nested entries, a path such as
    /// `pronunciations[0].alphabet`).
    pub field: String,
    /// Human-readable description of what is wrong with the field.
    pub message: String,
}
151
152impl std::fmt::Display for ValidationError {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 write!(f, "{}: {}", self.field, self.message)
155 }
156}
157
158impl SpeechSynthesizeParams {
159 pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
165 let mut errors = Vec::new();
166
167 if self.text.trim().is_empty() {
169 errors.push(ValidationError {
170 field: "text".to_string(),
171 message: "Text cannot be empty".to_string(),
172 });
173 }
174
175 if self.speaking_rate < MIN_SPEAKING_RATE || self.speaking_rate > MAX_SPEAKING_RATE {
177 errors.push(ValidationError {
178 field: "speaking_rate".to_string(),
179 message: format!(
180 "speaking_rate must be between {} and {}, got {}",
181 MIN_SPEAKING_RATE, MAX_SPEAKING_RATE, self.speaking_rate
182 ),
183 });
184 }
185
186 if self.pitch < MIN_PITCH || self.pitch > MAX_PITCH {
188 errors.push(ValidationError {
189 field: "pitch".to_string(),
190 message: format!(
191 "pitch must be between {} and {} semitones, got {}",
192 MIN_PITCH, MAX_PITCH, self.pitch
193 ),
194 });
195 }
196
197 if let Some(ref pronunciations) = self.pronunciations {
199 for (i, pron) in pronunciations.iter().enumerate() {
200 if let Err(e) = pron.validate() {
201 errors.push(ValidationError {
202 field: format!("pronunciations[{}].{}", i, e.field),
203 message: e.message,
204 });
205 }
206 }
207 }
208
209 if errors.is_empty() {
210 Ok(())
211 } else {
212 Err(errors)
213 }
214 }
215
216 pub fn get_voice(&self) -> &str {
218 self.voice.as_deref().unwrap_or(DEFAULT_VOICE)
219 }
220
221 pub fn build_ssml(&self) -> String {
223 let mut text = self.text.clone();
224
225 if let Some(ref pronunciations) = self.pronunciations {
227 for pron in pronunciations {
228 text = text.replace(&pron.word, &pron.to_ssml());
230 }
231 }
232
233 format!(r#"<speak>{}</speak>"#, text)
235 }
236}
237
238
/// Handler that performs speech synthesis and voice listing against the
/// Cloud Text-to-Speech REST API.
pub struct SpeechHandler {
    /// Configuration; consulted for `is_gemini()`, `gemini_api_key` and `project_id`.
    pub config: Config,
    /// HTTP client used for every API request.
    pub http: reqwest::Client,
    /// Source of bearer tokens for requests made when `config.is_gemini()` is false.
    pub auth: AuthProvider,
}
250
251impl SpeechHandler {
252 #[instrument(level = "debug", name = "speech_handler_new", skip_all)]
257 pub async fn new(config: Config) -> Result<Self, Error> {
258 debug!("Initializing SpeechHandler");
259
260 let auth = AuthProvider::new().await?;
261 let http = reqwest::Client::new();
262
263 Ok(Self { config, http, auth })
264 }
265
266 #[cfg(test)]
268 pub fn with_deps(config: Config, http: reqwest::Client, auth: AuthProvider) -> Self {
269 Self { config, http, auth }
270 }
271
272 pub fn get_endpoint(&self) -> String {
274 if self.config.is_gemini() {
275 "https://texttospeech.googleapis.com/v1/text:synthesize".to_string()
276 } else {
277 "https://texttospeech.googleapis.com/v1/text:synthesize".to_string()
278 }
279 }
280
281 pub fn get_voices_endpoint(&self) -> String {
283 "https://texttospeech.googleapis.com/v1/voices".to_string()
284 }
285
286 async fn add_auth(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::RequestBuilder, Error> {
288 if self.config.is_gemini() {
289 let key = self.config.gemini_api_key.as_deref().unwrap_or_default();
290 Ok(builder.header("x-goog-api-key", key))
291 } else {
292 let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
293 Ok(builder
294 .header("Authorization", format!("Bearer {}", token))
295 .header("x-goog-user-project", &self.config.project_id))
296 }
297 }
298
299 #[instrument(level = "info", name = "synthesize_speech", skip(self, params))]
308 pub async fn synthesize(&self, params: SpeechSynthesizeParams) -> Result<SpeechSynthesizeResult, Error> {
309 params.validate().map_err(|errors| {
311 let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
312 Error::validation(messages.join("; "))
313 })?;
314
315 info!(voice = %params.get_voice(), "Synthesizing speech with Cloud TTS API");
316
317 let (input, use_ssml) = if params.pronunciations.is_some() {
319 (params.build_ssml(), true)
320 } else {
321 (params.text.clone(), false)
322 };
323
324 let request = TtsRequest {
326 input: TtsInput {
327 text: if use_ssml { None } else { Some(input.clone()) },
328 ssml: if use_ssml { Some(input) } else { None },
329 },
330 voice: TtsVoice {
331 language_code: params.language_code.clone(),
332 name: params.get_voice().to_string(),
333 },
334 audio_config: TtsAudioConfig {
335 audio_encoding: "LINEAR16".to_string(),
336 speaking_rate: Some(params.speaking_rate),
337 pitch: Some(params.pitch),
338 sample_rate_hertz: Some(24000),
339 },
340 };
341
342 let endpoint = self.get_endpoint();
344 debug!(endpoint = %endpoint, "Calling Cloud TTS API");
345
346 let builder = self.http
347 .post(&endpoint)
348 .header("Content-Type", "application/json")
349 .json(&request);
350 let builder = self.add_auth(builder).await?;
351
352 let response = builder
353 .send()
354 .await
355 .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
356
357 let status = response.status();
358 if !status.is_success() {
359 let body = response.text().await.unwrap_or_default();
360 return Err(Error::api(&endpoint, status.as_u16(), body));
361 }
362
363 let api_response: TtsResponse = response.json().await.map_err(|e| {
365 Error::api(
366 &endpoint,
367 status.as_u16(),
368 format!("Failed to parse response: {}", e),
369 )
370 })?;
371
372 let audio_data = api_response.audio_content;
373 if audio_data.is_empty() {
374 return Err(Error::api(&endpoint, 200, "No audio content returned from API"));
375 }
376
377 info!("Received audio data from Cloud TTS API");
378
379 let audio = GeneratedAudio {
380 data: audio_data,
381 mime_type: "audio/wav".to_string(),
382 };
383
384 self.handle_output(audio, ¶ms).await
386 }
387
388
389 #[instrument(level = "info", name = "list_voices", skip(self))]
395 pub async fn list_voices(&self) -> Result<Vec<VoiceInfo>, Error> {
396 info!("Listing available voices from Cloud TTS API");
397
398 let endpoint = self.get_voices_endpoint();
400 debug!(endpoint = %endpoint, "Calling Cloud TTS voices API");
401
402 let builder = self.http.get(&endpoint);
403 let builder = self.add_auth(builder).await?;
404
405 let response = builder
406 .send()
407 .await
408 .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
409
410 let status = response.status();
411 if !status.is_success() {
412 let body = response.text().await.unwrap_or_default();
413 return Err(Error::api(&endpoint, status.as_u16(), body));
414 }
415
416 let api_response: VoicesResponse = response.json().await.map_err(|e| {
418 Error::api(
419 &endpoint,
420 status.as_u16(),
421 format!("Failed to parse response: {}", e),
422 )
423 })?;
424
425 let chirp3_voices: Vec<VoiceInfo> = api_response
427 .voices
428 .into_iter()
429 .filter(|v| v.name.contains("Chirp3-HD"))
430 .map(|v| VoiceInfo {
431 name: v.name,
432 language_codes: v.language_codes,
433 ssml_gender: v.ssml_gender,
434 natural_sample_rate_hertz: v.natural_sample_rate_hertz,
435 })
436 .collect();
437
438 info!(count = chirp3_voices.len(), "Found Chirp3-HD voices");
439 Ok(chirp3_voices)
440 }
441
442 async fn handle_output(
444 &self,
445 audio: GeneratedAudio,
446 params: &SpeechSynthesizeParams,
447 ) -> Result<SpeechSynthesizeResult, Error> {
448 if let Some(output_file) = ¶ms.output_file {
450 return self.save_to_file(audio, output_file).await;
451 }
452
453 Ok(SpeechSynthesizeResult::Base64(audio))
455 }
456
457 async fn save_to_file(
459 &self,
460 audio: GeneratedAudio,
461 output_file: &str,
462 ) -> Result<SpeechSynthesizeResult, Error> {
463 let data = BASE64.decode(&audio.data).map_err(|e| {
465 Error::validation(format!("Invalid base64 data: {}", e))
466 })?;
467
468 if let Some(parent) = Path::new(output_file).parent() {
470 if !parent.as_os_str().is_empty() {
471 tokio::fs::create_dir_all(parent).await?;
472 }
473 }
474
475 tokio::fs::write(output_file, &data).await?;
477
478 info!(path = %output_file, "Saved audio to local file");
479 Ok(SpeechSynthesizeResult::LocalFile(output_file.to_string()))
480 }
481}
482
483
/// JSON request body that `SpeechHandler::synthesize` posts to the
/// `text:synthesize` endpoint.
#[derive(Debug, Serialize)]
pub struct TtsRequest {
    /// What to synthesize (plain text or SSML).
    pub input: TtsInput,
    /// Which voice to use.
    pub voice: TtsVoice,
    /// Output audio settings; serialized under the key `audioConfig`.
    #[serde(rename = "audioConfig")]
    pub audio_config: TtsAudioConfig,
}
499
/// Synthesis input. `SpeechHandler::synthesize` sets exactly one of the two
/// fields; a `None` field is left out of the serialized JSON.
#[derive(Debug, Serialize)]
pub struct TtsInput {
    /// Plain-text input.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// SSML input.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssml: Option<String>,
}
510
/// Voice selection part of a [`TtsRequest`]; field names are serialized in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsVoice {
    /// Language code of the voice (serialized as `languageCode`).
    pub language_code: String,
    /// Voice name.
    pub name: String,
}
520
/// Audio output settings of a [`TtsRequest`]; field names are serialized in
/// camelCase and `None` fields are omitted.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsAudioConfig {
    /// Audio encoding name; `SpeechHandler::synthesize` always sends `"LINEAR16"`.
    pub audio_encoding: String,
    /// Speaking rate to apply.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speaking_rate: Option<f32>,
    /// Pitch adjustment to apply.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pitch: Option<f32>,
    /// Sample rate in hertz; `SpeechHandler::synthesize` always sends 24000.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sample_rate_hertz: Option<u32>,
}
537
/// Response body of the `text:synthesize` endpoint, as far as this module reads it.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TtsResponse {
    /// Audio payload (JSON key `audioContent`); treated as base64 text by
    /// `SpeechHandler::save_to_file`.
    pub audio_content: String,
}
545
/// Response body of the voices listing endpoint.
#[derive(Debug, Deserialize)]
pub struct VoicesResponse {
    /// All voices returned by the API, before any filtering.
    pub voices: Vec<ApiVoiceInfo>,
}
552
/// One voice entry as deserialized from the voices listing response
/// (JSON keys are camelCase).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiVoiceInfo {
    /// Voice name.
    pub name: String,
    /// Language codes reported for the voice (JSON key `languageCodes`).
    pub language_codes: Vec<String>,
    /// Gender label, if present (JSON key `ssmlGender`).
    pub ssml_gender: Option<String>,
    /// Natural sample rate in hertz, if present (JSON key `naturalSampleRateHertz`).
    pub natural_sample_rate_hertz: Option<u32>,
}
566
/// Audio produced by a synthesis call.
#[derive(Debug, Clone)]
pub struct GeneratedAudio {
    /// Audio bytes encoded as base64 text, exactly as received from the API.
    pub data: String,
    /// MIME type of the decoded audio; `SpeechHandler::synthesize` sets `"audio/wav"`.
    pub mime_type: String,
}
579
/// Voice description returned by `SpeechHandler::list_voices`; a field-for-field
/// copy of [`ApiVoiceInfo`] that can be serialized.
#[derive(Debug, Clone, Serialize)]
pub struct VoiceInfo {
    /// Voice name.
    pub name: String,
    /// Language codes reported for the voice.
    pub language_codes: Vec<String>,
    /// Gender label, if the API reported one.
    pub ssml_gender: Option<String>,
    /// Natural sample rate in hertz, if the API reported one.
    pub natural_sample_rate_hertz: Option<u32>,
}
592
/// Outcome of `SpeechHandler::synthesize`.
#[derive(Debug)]
pub enum SpeechSynthesizeResult {
    /// The audio itself, still base64-encoded (no `output_file` was requested).
    Base64(GeneratedAudio),
    /// Path of the local file the decoded audio was written to.
    LocalFile(String),
}
601
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameters that pass validation. Tests override only the fields they
    /// exercise, using struct-update syntax.
    fn base_params() -> SpeechSynthesizeParams {
        SpeechSynthesizeParams {
            text: "Hello".to_string(),
            voice: None,
            language_code: "en-US".to_string(),
            speaking_rate: 1.0,
            pitch: 0.0,
            pronunciations: None,
            output_file: None,
        }
    }

    /// Shorthand constructor for a `Pronunciation`.
    fn pronunciation(word: &str, phonetic: &str, alphabet: &str) -> Pronunciation {
        Pronunciation {
            word: word.to_string(),
            phonetic: phonetic.to_string(),
            alphabet: alphabet.to_string(),
        }
    }

    /// Asserts that validating `params` fails with an error on exactly `field`.
    fn assert_invalid_field(params: &SpeechSynthesizeParams, field: &str) {
        let errors = params.validate().expect_err("validation should fail");
        assert!(
            errors.iter().any(|e| e.field == field),
            "expected an error for `{}`, got {:?}",
            field,
            errors
        );
    }

    /// Asserts that validating `pron` fails on `field`.
    fn assert_pronunciation_error(pron: &Pronunciation, field: &str) {
        let error = pron.validate().expect_err("validation should fail");
        assert_eq!(error.field, field);
    }

    #[test]
    fn test_default_params() {
        let params: SpeechSynthesizeParams =
            serde_json::from_str(r#"{"text": "Hello world"}"#).unwrap();
        assert_eq!(params.language_code, DEFAULT_LANGUAGE_CODE);
        assert_eq!(params.speaking_rate, DEFAULT_SPEAKING_RATE);
        assert_eq!(params.pitch, DEFAULT_PITCH);
        assert!(params.voice.is_none());
        assert!(params.pronunciations.is_none());
        assert!(params.output_file.is_none());
    }

    #[test]
    fn test_valid_params() {
        let params = SpeechSynthesizeParams {
            text: "Hello world".to_string(),
            voice: Some("en-US-Chirp3-HD-Achernar".to_string()),
            speaking_rate: 1.5,
            pitch: 2.0,
            ..base_params()
        };

        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_empty_text() {
        let params = SpeechSynthesizeParams {
            text: " ".to_string(),
            ..base_params()
        };

        assert_invalid_field(&params, "text");
    }

    #[test]
    fn test_speaking_rate_too_low() {
        let params = SpeechSynthesizeParams {
            speaking_rate: 0.1,
            ..base_params()
        };

        assert_invalid_field(&params, "speaking_rate");
    }

    #[test]
    fn test_speaking_rate_too_high() {
        let params = SpeechSynthesizeParams {
            speaking_rate: 5.0,
            ..base_params()
        };

        assert_invalid_field(&params, "speaking_rate");
    }

    #[test]
    fn test_pitch_too_low() {
        let params = SpeechSynthesizeParams {
            pitch: -25.0,
            ..base_params()
        };

        assert_invalid_field(&params, "pitch");
    }

    #[test]
    fn test_pitch_too_high() {
        let params = SpeechSynthesizeParams {
            pitch: 25.0,
            ..base_params()
        };

        assert_invalid_field(&params, "pitch");
    }

    #[test]
    fn test_valid_speaking_rate_boundaries() {
        for speaking_rate in [MIN_SPEAKING_RATE, MAX_SPEAKING_RATE] {
            let params = SpeechSynthesizeParams {
                speaking_rate,
                ..base_params()
            };
            assert!(params.validate().is_ok(), "rate {} should be valid", speaking_rate);
        }
    }

    #[test]
    fn test_valid_pitch_boundaries() {
        for pitch in [MIN_PITCH, MAX_PITCH] {
            let params = SpeechSynthesizeParams {
                pitch,
                ..base_params()
            };
            assert!(params.validate().is_ok(), "pitch {} should be valid", pitch);
        }
    }

    #[test]
    fn test_pronunciation_valid_ipa() {
        assert!(pronunciation("tomato", "təˈmeɪtoʊ", "ipa").validate().is_ok());
    }

    #[test]
    fn test_pronunciation_valid_xsampa() {
        assert!(pronunciation("tomato", "t@\"meItoU", "x-sampa").validate().is_ok());
    }

    #[test]
    fn test_pronunciation_invalid_alphabet() {
        assert_pronunciation_error(&pronunciation("tomato", "tomato", "invalid"), "alphabet");
    }

    #[test]
    fn test_pronunciation_empty_word() {
        assert_pronunciation_error(&pronunciation("", "test", "ipa"), "word");
    }

    #[test]
    fn test_pronunciation_empty_phonetic() {
        assert_pronunciation_error(&pronunciation("test", "", "ipa"), "phonetic");
    }

    #[test]
    fn test_pronunciation_to_ssml() {
        let ssml = pronunciation("tomato", "təˈmeɪtoʊ", "ipa").to_ssml();
        assert!(ssml.contains("phoneme"));
        assert!(ssml.contains("ipa"));
        assert!(ssml.contains("təˈmeɪtoʊ"));
        assert!(ssml.contains("tomato"));
    }

    #[test]
    fn test_build_ssml_with_pronunciations() {
        let params = SpeechSynthesizeParams {
            text: "I like tomato".to_string(),
            pronunciations: Some(vec![pronunciation("tomato", "təˈmeɪtoʊ", "ipa")]),
            ..base_params()
        };

        let ssml = params.build_ssml();
        assert!(ssml.starts_with("<speak>"));
        assert!(ssml.ends_with("</speak>"));
        assert!(ssml.contains("phoneme"));
        // The bare word must have been wrapped, not left directly before the closing tag.
        assert!(!ssml.contains("tomato</speak>"));
    }

    #[test]
    fn test_build_ssml_without_pronunciations() {
        let params = SpeechSynthesizeParams {
            text: "Hello world".to_string(),
            ..base_params()
        };

        assert_eq!(params.build_ssml(), "<speak>Hello world</speak>");
    }

    #[test]
    fn test_get_voice_default() {
        assert_eq!(base_params().get_voice(), DEFAULT_VOICE);
    }

    #[test]
    fn test_get_voice_custom() {
        let params = SpeechSynthesizeParams {
            voice: Some("custom-voice".to_string()),
            ..base_params()
        };

        assert_eq!(params.get_voice(), "custom-voice");
    }

    #[test]
    fn test_params_with_invalid_pronunciation() {
        let params = SpeechSynthesizeParams {
            pronunciations: Some(vec![pronunciation("test", "test", "invalid")]),
            ..base_params()
        };

        let errors = params.validate().expect_err("validation should fail");
        assert!(errors.iter().any(|e| e.field.contains("pronunciations")));
    }

    #[test]
    fn test_serialization_roundtrip() {
        let params = SpeechSynthesizeParams {
            text: "Hello world".to_string(),
            voice: Some("en-US-Chirp3-HD-Achernar".to_string()),
            speaking_rate: 1.5,
            pitch: 2.0,
            pronunciations: Some(vec![pronunciation("hello", "həˈloʊ", "ipa")]),
            output_file: Some("/tmp/output.wav".to_string()),
            ..base_params()
        };

        let json = serde_json::to_string(&params).unwrap();
        let deserialized: SpeechSynthesizeParams = serde_json::from_str(&json).unwrap();

        assert_eq!(params.text, deserialized.text);
        assert_eq!(params.voice, deserialized.voice);
        assert_eq!(params.language_code, deserialized.language_code);
        assert_eq!(params.speaking_rate, deserialized.speaking_rate);
        assert_eq!(params.pitch, deserialized.pitch);
        assert_eq!(params.output_file, deserialized.output_file);

        // Pronunciations must survive the roundtrip as well.
        let original = params.pronunciations.as_deref().unwrap_or_default();
        let restored = deserialized.pronunciations.as_deref().unwrap_or_default();
        assert_eq!(original.len(), restored.len());
        for (before, after) in original.iter().zip(restored) {
            assert_eq!(before.word, after.word);
            assert_eq!(before.phonetic, after.phonetic);
            assert_eq!(before.alphabet, after.alphabet);
        }
    }
}
966
967
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds parameters with the fields the properties vary; everything else
    /// is fixed to values that pass validation.
    fn build_params(
        text: String,
        speaking_rate: f32,
        pitch: f32,
        pronunciations: Option<Vec<Pronunciation>>,
    ) -> SpeechSynthesizeParams {
        SpeechSynthesizeParams {
            text,
            voice: None,
            language_code: "en-US".to_string(),
            speaking_rate,
            pitch,
            pronunciations,
            output_file: None,
        }
    }

    /// Speaking rates inside the accepted range, rounded to two decimals.
    fn valid_speaking_rate_strategy() -> impl Strategy<Value = f32> {
        (MIN_SPEAKING_RATE..=MAX_SPEAKING_RATE).prop_map(|x| (x * 100.0).round() / 100.0)
    }

    /// Speaking rates outside the accepted range. The bounds keep a margin of
    /// 0.01 so that rounding to two decimals cannot land on a valid value.
    fn invalid_speaking_rate_strategy() -> impl Strategy<Value = f32> {
        prop_oneof![
            (0.0f32..0.24f32).prop_map(|x| (x * 100.0).round() / 100.0),
            (4.01f32..10.0f32).prop_map(|x| (x * 100.0).round() / 100.0),
        ]
    }

    /// Pitches inside the accepted range, rounded to one decimal.
    fn valid_pitch_strategy() -> impl Strategy<Value = f32> {
        (MIN_PITCH..=MAX_PITCH).prop_map(|x| (x * 10.0).round() / 10.0)
    }

    /// Pitches outside the accepted range. Both bounds keep a margin of 0.1:
    /// without it a sample such as -20.03 rounds to -20.0, which is valid, and
    /// the "invalid" property would fail intermittently.
    fn invalid_pitch_strategy() -> impl Strategy<Value = f32> {
        prop_oneof![
            (-50.0f32..(MIN_PITCH - 0.1)).prop_map(|x| (x * 10.0).round() / 10.0),
            ((MAX_PITCH + 0.1)..50.0f32).prop_map(|x| (x * 10.0).round() / 10.0),
        ]
    }

    /// Non-blank alphanumeric text, trimmed.
    fn valid_text_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z0-9 ]{1,100}"
            .prop_map(|s| s.trim().to_string())
            .prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    proptest! {
        #[test]
        fn valid_speaking_rate_passes_validation(
            rate in valid_speaking_rate_strategy(),
            text in valid_text_strategy(),
        ) {
            let result = build_params(text, rate, 0.0, None).validate();
            prop_assert!(
                result.is_ok(),
                "speaking_rate {} should be valid, but got errors: {:?}",
                rate,
                result.err()
            );
        }

        #[test]
        fn invalid_speaking_rate_fails_validation(
            rate in invalid_speaking_rate_strategy(),
            text in valid_text_strategy(),
        ) {
            let result = build_params(text, rate, 0.0, None).validate();
            prop_assert!(
                result.is_err(),
                "speaking_rate {} should be invalid",
                rate
            );

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "speaking_rate"),
                "Should have a speaking_rate validation error for value {}",
                rate
            );
        }

        #[test]
        fn valid_pitch_passes_validation(
            pitch in valid_pitch_strategy(),
            text in valid_text_strategy(),
        ) {
            let result = build_params(text, 1.0, pitch, None).validate();
            prop_assert!(
                result.is_ok(),
                "pitch {} should be valid, but got errors: {:?}",
                pitch,
                result.err()
            );
        }

        #[test]
        fn invalid_pitch_fails_validation(
            pitch in invalid_pitch_strategy(),
            text in valid_text_strategy(),
        ) {
            let result = build_params(text, 1.0, pitch, None).validate();
            prop_assert!(
                result.is_err(),
                "pitch {} should be invalid",
                pitch
            );

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "pitch"),
                "Should have a pitch validation error for value {}",
                pitch
            );
        }

        #[test]
        fn valid_speaking_rate_and_pitch_passes_validation(
            rate in valid_speaking_rate_strategy(),
            pitch in valid_pitch_strategy(),
            text in valid_text_strategy(),
        ) {
            let result = build_params(text, rate, pitch, None).validate();
            prop_assert!(
                result.is_ok(),
                "speaking_rate {} and pitch {} should be valid, but got errors: {:?}",
                rate,
                pitch,
                result.err()
            );
        }
    }

    /// One of the two accepted alphabet names.
    fn valid_alphabet_strategy() -> impl Strategy<Value = String> {
        prop_oneof![Just("ipa".to_string()), Just("x-sampa".to_string()),]
    }

    /// Lowercase words that are not an accepted alphabet name.
    fn invalid_alphabet_strategy() -> impl Strategy<Value = String> {
        "[a-z]{1,10}"
            .prop_filter("Must not be valid alphabet", |s| {
                let lower = s.to_lowercase();
                lower != "ipa" && lower != "x-sampa"
            })
    }

    /// Non-empty alphabetic words.
    fn valid_word_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z]{1,20}".prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    /// Non-empty strings of letters and a few IPA symbols.
    fn valid_phonetic_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Zəˈɪʊæɑɔɛʌ]{1,30}".prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    proptest! {
        #[test]
        fn valid_alphabet_passes_validation(
            alphabet in valid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
        ) {
            let pron = Pronunciation {
                word,
                phonetic,
                alphabet: alphabet.clone(),
            };

            let result = pron.validate();
            prop_assert!(
                result.is_ok(),
                "alphabet '{}' should be valid, but got error: {:?}",
                alphabet,
                result.err()
            );
        }

        #[test]
        fn invalid_alphabet_fails_validation(
            alphabet in invalid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
        ) {
            let pron = Pronunciation {
                word,
                phonetic,
                alphabet: alphabet.clone(),
            };

            let result = pron.validate();
            prop_assert!(
                result.is_err(),
                "alphabet '{}' should be invalid",
                alphabet
            );

            let error = result.unwrap_err();
            prop_assert!(
                error.field == "alphabet",
                "Should have an alphabet validation error for value '{}'",
                alphabet
            );
        }

        #[test]
        fn params_with_valid_pronunciation_passes_validation(
            alphabet in valid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
            text in valid_text_strategy(),
        ) {
            let pronunciations = vec![Pronunciation {
                word,
                phonetic,
                alphabet: alphabet.clone(),
            }];

            let result = build_params(text, 1.0, 0.0, Some(pronunciations)).validate();
            prop_assert!(
                result.is_ok(),
                "params with alphabet '{}' should be valid, but got errors: {:?}",
                alphabet,
                result.err()
            );
        }

        #[test]
        fn params_with_invalid_pronunciation_fails_validation(
            alphabet in invalid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
            text in valid_text_strategy(),
        ) {
            let pronunciations = vec![Pronunciation {
                word,
                phonetic,
                alphabet: alphabet.clone(),
            }];

            let result = build_params(text, 1.0, 0.0, Some(pronunciations)).validate();
            prop_assert!(
                result.is_err(),
                "params with alphabet '{}' should be invalid",
                alphabet
            );

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field.contains("pronunciations") && e.field.contains("alphabet")),
                "Should have a pronunciations.alphabet validation error for value '{}'",
                alphabet
            );
        }

        #[test]
        fn empty_text_fails_validation(
            rate in valid_speaking_rate_strategy(),
            pitch in valid_pitch_strategy(),
        ) {
            let result = build_params(" ".to_string(), rate, pitch, None).validate();
            prop_assert!(result.is_err());

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "text"),
                "Should have a text validation error"
            );
        }
    }
}