Skip to main content

adk_rust_mcp_speech/
handler.rs

1//! Speech synthesis handler for the MCP Speech server.
2//!
3//! This module provides the `SpeechHandler` struct and parameter types for
4//! text-to-speech synthesis using Google's Cloud TTS Chirp3-HD API.
5
6use adk_rust_mcp_common::auth::AuthProvider;
7use adk_rust_mcp_common::config::Config;
8use adk_rust_mcp_common::error::Error;
9use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use std::path::Path;
13use tracing::{debug, info, instrument};
14
// ----- Voice selection ------------------------------------------------------

/// Chirp3-HD voice name used when a request leaves `voice` unset.
pub const DEFAULT_VOICE: &str = "en-US-Chirp3-HD-Achernar";

/// Language code used when a request leaves `language_code` unset.
pub const DEFAULT_LANGUAGE_CODE: &str = "en-US";

// ----- Speaking rate: accepted inclusive range and default ------------------

/// Speaking rate used when a request leaves `speaking_rate` unset.
pub const DEFAULT_SPEAKING_RATE: f32 = 1.0;

/// Lowest accepted speaking rate (inclusive).
pub const MIN_SPEAKING_RATE: f32 = 0.25;

/// Highest accepted speaking rate (inclusive).
pub const MAX_SPEAKING_RATE: f32 = 4.0;

// ----- Pitch, in semitones: accepted inclusive range and default ------------

/// Pitch offset (semitones) used when a request leaves `pitch` unset.
pub const DEFAULT_PITCH: f32 = 0.0;

/// Lowest accepted pitch offset in semitones (inclusive).
pub const MIN_PITCH: f32 = -20.0;

/// Highest accepted pitch offset in semitones (inclusive).
pub const MAX_PITCH: f32 = 20.0;

// ----- Custom pronunciations ------------------------------------------------

/// Phonetic alphabets accepted for custom pronunciations (lowercase; input is
/// lowercased before comparison).
pub const VALID_ALPHABETS: &[&str] = &["ipa", "x-sampa"];
41
42
43/// Custom pronunciation for a word.
44///
45/// Allows specifying phonetic pronunciation using IPA or X-SAMPA alphabets.
46#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
47pub struct Pronunciation {
48    /// The word to apply custom pronunciation to.
49    pub word: String,
50
51    /// The phonetic representation of the word.
52    pub phonetic: String,
53
54    /// The phonetic alphabet used: "ipa" or "x-sampa".
55    pub alphabet: String,
56}
57
58impl Pronunciation {
59    /// Validate the pronunciation entry.
60    pub fn validate(&self) -> Result<(), ValidationError> {
61        if self.word.trim().is_empty() {
62            return Err(ValidationError {
63                field: "word".to_string(),
64                message: "Word cannot be empty".to_string(),
65            });
66        }
67
68        if self.phonetic.trim().is_empty() {
69            return Err(ValidationError {
70                field: "phonetic".to_string(),
71                message: "Phonetic representation cannot be empty".to_string(),
72            });
73        }
74
75        let alphabet_lower = self.alphabet.to_lowercase();
76        if !VALID_ALPHABETS.contains(&alphabet_lower.as_str()) {
77            return Err(ValidationError {
78                field: "alphabet".to_string(),
79                message: format!(
80                    "Invalid alphabet '{}'. Must be one of: {}",
81                    self.alphabet,
82                    VALID_ALPHABETS.join(", ")
83                ),
84            });
85        }
86
87        Ok(())
88    }
89
90    /// Convert to SSML phoneme element.
91    pub fn to_ssml(&self) -> String {
92        let alphabet = self.alphabet.to_lowercase();
93        format!(
94            r#"<phoneme alphabet="{}" ph="{}">{}</phoneme>"#,
95            alphabet, self.phonetic, self.word
96        )
97    }
98}
99
100/// Speech synthesis parameters.
101///
102/// These parameters control the text-to-speech synthesis via the Cloud TTS API.
103#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
104pub struct SpeechSynthesizeParams {
105    /// Text to synthesize into speech.
106    pub text: String,
107
108    /// Voice name to use (Chirp3-HD voice).
109    #[serde(default, skip_serializing_if = "Option::is_none")]
110    pub voice: Option<String>,
111
112    /// Language code (e.g., "en-US", "es-ES").
113    #[serde(default = "default_language_code")]
114    pub language_code: String,
115
116    /// Speaking rate (0.25-4.0, default 1.0).
117    #[serde(default = "default_speaking_rate")]
118    pub speaking_rate: f32,
119
120    /// Pitch adjustment in semitones (-20.0 to 20.0, default 0.0).
121    #[serde(default)]
122    pub pitch: f32,
123
124    /// Custom pronunciations for specific words.
125    #[serde(default, skip_serializing_if = "Option::is_none")]
126    pub pronunciations: Option<Vec<Pronunciation>>,
127
128    /// Output file path for saving the WAV locally.
129    /// If not specified, returns base64-encoded data.
130    #[serde(default, skip_serializing_if = "Option::is_none")]
131    pub output_file: Option<String>,
132}
133
134fn default_language_code() -> String {
135    DEFAULT_LANGUAGE_CODE.to_string()
136}
137
138fn default_speaking_rate() -> f32 {
139    DEFAULT_SPEAKING_RATE
140}
141
142
/// A single validation failure for speech synthesis parameters or one of
/// their pronunciation entries.
///
/// `Display` renders it as `"<field>: <message>"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// The field that failed validation (e.g. `"text"` or
    /// `"pronunciations[0].word"`).
    pub field: String,
    /// Description of the validation failure.
    pub message: String,
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.field, self.message)
    }
}

// Implementing `Error` lets callers box this type, use it as an error
// `source`, or propagate it with `?` into `Box<dyn Error>`-returning code.
impl std::error::Error for ValidationError {}
157
158impl SpeechSynthesizeParams {
159    /// Validate the parameters.
160    ///
161    /// # Returns
162    /// - `Ok(())` if all parameters are valid
163    /// - `Err(Vec<ValidationError>)` with all validation errors
164    pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
165        let mut errors = Vec::new();
166
167        // Validate text is not empty
168        if self.text.trim().is_empty() {
169            errors.push(ValidationError {
170                field: "text".to_string(),
171                message: "Text cannot be empty".to_string(),
172            });
173        }
174
175        // Validate speaking_rate range
176        if self.speaking_rate < MIN_SPEAKING_RATE || self.speaking_rate > MAX_SPEAKING_RATE {
177            errors.push(ValidationError {
178                field: "speaking_rate".to_string(),
179                message: format!(
180                    "speaking_rate must be between {} and {}, got {}",
181                    MIN_SPEAKING_RATE, MAX_SPEAKING_RATE, self.speaking_rate
182                ),
183            });
184        }
185
186        // Validate pitch range
187        if self.pitch < MIN_PITCH || self.pitch > MAX_PITCH {
188            errors.push(ValidationError {
189                field: "pitch".to_string(),
190                message: format!(
191                    "pitch must be between {} and {} semitones, got {}",
192                    MIN_PITCH, MAX_PITCH, self.pitch
193                ),
194            });
195        }
196
197        // Validate pronunciations if provided
198        if let Some(ref pronunciations) = self.pronunciations {
199            for (i, pron) in pronunciations.iter().enumerate() {
200                if let Err(e) = pron.validate() {
201                    errors.push(ValidationError {
202                        field: format!("pronunciations[{}].{}", i, e.field),
203                        message: e.message,
204                    });
205                }
206            }
207        }
208
209        if errors.is_empty() {
210            Ok(())
211        } else {
212            Err(errors)
213        }
214    }
215
216    /// Get the voice name to use, defaulting if not specified.
217    pub fn get_voice(&self) -> &str {
218        self.voice.as_deref().unwrap_or(DEFAULT_VOICE)
219    }
220
221    /// Build SSML text with pronunciations applied.
222    pub fn build_ssml(&self) -> String {
223        let mut text = self.text.clone();
224
225        // Apply pronunciations if provided
226        if let Some(ref pronunciations) = self.pronunciations {
227            for pron in pronunciations {
228                // Replace word with SSML phoneme
229                text = text.replace(&pron.word, &pron.to_ssml());
230            }
231        }
232
233        // Wrap in SSML speak element
234        format!(r#"<speak>{}</speak>"#, text)
235    }
236}
237
238
239/// Speech synthesis handler.
240///
241/// Handles text-to-speech requests using the Cloud TTS Chirp3-HD API.
242pub struct SpeechHandler {
243    /// Application configuration.
244    pub config: Config,
245    /// HTTP client for API requests.
246    pub http: reqwest::Client,
247    /// Authentication provider.
248    pub auth: AuthProvider,
249}
250
251impl SpeechHandler {
252    /// Create a new SpeechHandler with the given configuration.
253    ///
254    /// # Errors
255    /// Returns an error if auth provider initialization fails.
256    #[instrument(level = "debug", name = "speech_handler_new", skip_all)]
257    pub async fn new(config: Config) -> Result<Self, Error> {
258        debug!("Initializing SpeechHandler");
259
260        let auth = AuthProvider::new().await?;
261        let http = reqwest::Client::new();
262
263        Ok(Self { config, http, auth })
264    }
265
266    /// Create a new SpeechHandler with provided dependencies (for testing).
267    #[cfg(test)]
268    pub fn with_deps(config: Config, http: reqwest::Client, auth: AuthProvider) -> Self {
269        Self { config, http, auth }
270    }
271
272    /// Get the Cloud TTS API endpoint.
273    pub fn get_endpoint(&self) -> String {
274        if self.config.is_gemini() {
275            "https://texttospeech.googleapis.com/v1/text:synthesize".to_string()
276        } else {
277            "https://texttospeech.googleapis.com/v1/text:synthesize".to_string()
278        }
279    }
280
281    /// Get the Cloud TTS voices list endpoint.
282    pub fn get_voices_endpoint(&self) -> String {
283        "https://texttospeech.googleapis.com/v1/voices".to_string()
284    }
285
286    /// Add auth headers based on provider.
287    async fn add_auth(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::RequestBuilder, Error> {
288        if self.config.is_gemini() {
289            let key = self.config.gemini_api_key.as_deref().unwrap_or_default();
290            Ok(builder.header("x-goog-api-key", key))
291        } else {
292            let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
293            Ok(builder
294                .header("Authorization", format!("Bearer {}", token))
295                .header("x-goog-user-project", &self.config.project_id))
296        }
297    }
298
299    /// Synthesize speech from text.
300    ///
301    /// # Arguments
302    /// * `params` - Speech synthesis parameters
303    ///
304    /// # Returns
305    /// * `Ok(SpeechSynthesizeResult)` - Generated audio with data or path
306    /// * `Err(Error)` - If validation fails, API call fails, or output handling fails
307    #[instrument(level = "info", name = "synthesize_speech", skip(self, params))]
308    pub async fn synthesize(&self, params: SpeechSynthesizeParams) -> Result<SpeechSynthesizeResult, Error> {
309        // Validate parameters
310        params.validate().map_err(|errors| {
311            let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
312            Error::validation(messages.join("; "))
313        })?;
314
315        info!(voice = %params.get_voice(), "Synthesizing speech with Cloud TTS API");
316
317        // Determine if we need SSML (for pronunciations)
318        let (input, use_ssml) = if params.pronunciations.is_some() {
319            (params.build_ssml(), true)
320        } else {
321            (params.text.clone(), false)
322        };
323
324        // Build the API request
325        let request = TtsRequest {
326            input: TtsInput {
327                text: if use_ssml { None } else { Some(input.clone()) },
328                ssml: if use_ssml { Some(input) } else { None },
329            },
330            voice: TtsVoice {
331                language_code: params.language_code.clone(),
332                name: params.get_voice().to_string(),
333            },
334            audio_config: TtsAudioConfig {
335                audio_encoding: "LINEAR16".to_string(),
336                speaking_rate: Some(params.speaking_rate),
337                pitch: Some(params.pitch),
338                sample_rate_hertz: Some(24000),
339            },
340        };
341
342        // Make API request
343        let endpoint = self.get_endpoint();
344        debug!(endpoint = %endpoint, "Calling Cloud TTS API");
345
346        let builder = self.http
347            .post(&endpoint)
348            .header("Content-Type", "application/json")
349            .json(&request);
350        let builder = self.add_auth(builder).await?;
351
352        let response = builder
353            .send()
354            .await
355            .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
356
357        let status = response.status();
358        if !status.is_success() {
359            let body = response.text().await.unwrap_or_default();
360            return Err(Error::api(&endpoint, status.as_u16(), body));
361        }
362
363        // Parse response
364        let api_response: TtsResponse = response.json().await.map_err(|e| {
365            Error::api(
366                &endpoint,
367                status.as_u16(),
368                format!("Failed to parse response: {}", e),
369            )
370        })?;
371
372        let audio_data = api_response.audio_content;
373        if audio_data.is_empty() {
374            return Err(Error::api(&endpoint, 200, "No audio content returned from API"));
375        }
376
377        info!("Received audio data from Cloud TTS API");
378
379        let audio = GeneratedAudio {
380            data: audio_data,
381            mime_type: "audio/wav".to_string(),
382        };
383
384        // Handle output based on params
385        self.handle_output(audio, &params).await
386    }
387
388
389    /// List available voices.
390    ///
391    /// # Returns
392    /// * `Ok(Vec<VoiceInfo>)` - List of available voices
393    /// * `Err(Error)` - If API call fails
394    #[instrument(level = "info", name = "list_voices", skip(self))]
395    pub async fn list_voices(&self) -> Result<Vec<VoiceInfo>, Error> {
396        info!("Listing available voices from Cloud TTS API");
397
398        // Make API request
399        let endpoint = self.get_voices_endpoint();
400        debug!(endpoint = %endpoint, "Calling Cloud TTS voices API");
401
402        let builder = self.http.get(&endpoint);
403        let builder = self.add_auth(builder).await?;
404
405        let response = builder
406            .send()
407            .await
408            .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
409
410        let status = response.status();
411        if !status.is_success() {
412            let body = response.text().await.unwrap_or_default();
413            return Err(Error::api(&endpoint, status.as_u16(), body));
414        }
415
416        // Parse response
417        let api_response: VoicesResponse = response.json().await.map_err(|e| {
418            Error::api(
419                &endpoint,
420                status.as_u16(),
421                format!("Failed to parse response: {}", e),
422            )
423        })?;
424
425        // Filter for Chirp3-HD voices
426        let chirp3_voices: Vec<VoiceInfo> = api_response
427            .voices
428            .into_iter()
429            .filter(|v| v.name.contains("Chirp3-HD"))
430            .map(|v| VoiceInfo {
431                name: v.name,
432                language_codes: v.language_codes,
433                ssml_gender: v.ssml_gender,
434                natural_sample_rate_hertz: v.natural_sample_rate_hertz,
435            })
436            .collect();
437
438        info!(count = chirp3_voices.len(), "Found Chirp3-HD voices");
439        Ok(chirp3_voices)
440    }
441
442    /// Handle output of generated audio based on params.
443    async fn handle_output(
444        &self,
445        audio: GeneratedAudio,
446        params: &SpeechSynthesizeParams,
447    ) -> Result<SpeechSynthesizeResult, Error> {
448        // If output_file is specified, save to local file
449        if let Some(output_file) = &params.output_file {
450            return self.save_to_file(audio, output_file).await;
451        }
452
453        // Otherwise, return base64-encoded data
454        Ok(SpeechSynthesizeResult::Base64(audio))
455    }
456
457    /// Save audio to local file.
458    async fn save_to_file(
459        &self,
460        audio: GeneratedAudio,
461        output_file: &str,
462    ) -> Result<SpeechSynthesizeResult, Error> {
463        // Decode base64 data
464        let data = BASE64.decode(&audio.data).map_err(|e| {
465            Error::validation(format!("Invalid base64 data: {}", e))
466        })?;
467
468        // Ensure parent directory exists
469        if let Some(parent) = Path::new(output_file).parent() {
470            if !parent.as_os_str().is_empty() {
471                tokio::fs::create_dir_all(parent).await?;
472            }
473        }
474
475        // Write to file
476        tokio::fs::write(output_file, &data).await?;
477
478        info!(path = %output_file, "Saved audio to local file");
479        Ok(SpeechSynthesizeResult::LocalFile(output_file.to_string()))
480    }
481}
482
483
484// =============================================================================
485// API Request/Response Types
486// =============================================================================
487
488/// Cloud TTS API request.
489#[derive(Debug, Serialize)]
490pub struct TtsRequest {
491    /// Input text or SSML
492    pub input: TtsInput,
493    /// Voice configuration
494    pub voice: TtsVoice,
495    /// Audio configuration
496    #[serde(rename = "audioConfig")]
497    pub audio_config: TtsAudioConfig,
498}
499
500/// TTS input (text or SSML).
501#[derive(Debug, Serialize)]
502pub struct TtsInput {
503    /// Plain text input
504    #[serde(skip_serializing_if = "Option::is_none")]
505    pub text: Option<String>,
506    /// SSML input
507    #[serde(skip_serializing_if = "Option::is_none")]
508    pub ssml: Option<String>,
509}
510
511/// TTS voice configuration.
512#[derive(Debug, Serialize)]
513#[serde(rename_all = "camelCase")]
514pub struct TtsVoice {
515    /// Language code (e.g., "en-US")
516    pub language_code: String,
517    /// Voice name
518    pub name: String,
519}
520
521/// TTS audio configuration.
522#[derive(Debug, Serialize)]
523#[serde(rename_all = "camelCase")]
524pub struct TtsAudioConfig {
525    /// Audio encoding format
526    pub audio_encoding: String,
527    /// Speaking rate
528    #[serde(skip_serializing_if = "Option::is_none")]
529    pub speaking_rate: Option<f32>,
530    /// Pitch adjustment
531    #[serde(skip_serializing_if = "Option::is_none")]
532    pub pitch: Option<f32>,
533    /// Sample rate in Hz
534    #[serde(skip_serializing_if = "Option::is_none")]
535    pub sample_rate_hertz: Option<u32>,
536}
537
538/// Cloud TTS API response.
539#[derive(Debug, Deserialize)]
540#[serde(rename_all = "camelCase")]
541pub struct TtsResponse {
542    /// Base64-encoded audio content
543    pub audio_content: String,
544}
545
546/// Cloud TTS voices list response.
547#[derive(Debug, Deserialize)]
548pub struct VoicesResponse {
549    /// List of available voices
550    pub voices: Vec<ApiVoiceInfo>,
551}
552
553/// Voice information from API.
554#[derive(Debug, Deserialize)]
555#[serde(rename_all = "camelCase")]
556pub struct ApiVoiceInfo {
557    /// Voice name
558    pub name: String,
559    /// Supported language codes
560    pub language_codes: Vec<String>,
561    /// SSML gender
562    pub ssml_gender: Option<String>,
563    /// Natural sample rate
564    pub natural_sample_rate_hertz: Option<u32>,
565}
566
567// =============================================================================
568// Result Types
569// =============================================================================
570
/// Audio produced by a synthesis call, still in its transport encoding.
#[derive(Debug, Clone)]
pub struct GeneratedAudio {
    /// The audio bytes, base64-encoded as received from the API.
    pub data: String,
    /// MIME type of the decoded audio (for example `"audio/wav"`).
    pub mime_type: String,
}
579
580/// Voice information.
581#[derive(Debug, Clone, Serialize)]
582pub struct VoiceInfo {
583    /// Voice name
584    pub name: String,
585    /// Supported language codes
586    pub language_codes: Vec<String>,
587    /// SSML gender
588    pub ssml_gender: Option<String>,
589    /// Natural sample rate
590    pub natural_sample_rate_hertz: Option<u32>,
591}
592
593/// Result of speech synthesis.
594#[derive(Debug)]
595pub enum SpeechSynthesizeResult {
596    /// Base64-encoded audio data (when no output specified)
597    Base64(GeneratedAudio),
598    /// Local file path (when output_file specified)
599    LocalFile(String),
600}
601
602
#[cfg(test)]
mod tests {
    use super::*;

    // ----- Fixtures ---------------------------------------------------------

    /// Parameters that pass validation. Individual tests override fields
    /// using struct-update syntax.
    fn base_params(text: &str) -> SpeechSynthesizeParams {
        SpeechSynthesizeParams {
            text: text.to_string(),
            voice: None,
            language_code: "en-US".to_string(),
            speaking_rate: 1.0,
            pitch: 0.0,
            pronunciations: None,
            output_file: None,
        }
    }

    /// Shorthand constructor for a pronunciation entry.
    fn pron(word: &str, phonetic: &str, alphabet: &str) -> Pronunciation {
        Pronunciation {
            word: word.to_string(),
            phonetic: phonetic.to_string(),
            alphabet: alphabet.to_string(),
        }
    }

    /// Asserts that `params` fails validation with an error on exactly `field`.
    fn assert_field_rejected(params: &SpeechSynthesizeParams, field: &str) {
        let outcome = params.validate();
        assert!(outcome.is_err());
        let errors = outcome.unwrap_err();
        assert!(errors.iter().any(|e| e.field == field));
    }

    // ----- Defaults and accepted input --------------------------------------

    #[test]
    fn test_default_params() {
        let params: SpeechSynthesizeParams =
            serde_json::from_str(r#"{"text": "Hello world"}"#).unwrap();
        assert_eq!(params.language_code, DEFAULT_LANGUAGE_CODE);
        assert_eq!(params.speaking_rate, DEFAULT_SPEAKING_RATE);
        assert_eq!(params.pitch, DEFAULT_PITCH);
        assert!(params.voice.is_none());
        assert!(params.pronunciations.is_none());
        assert!(params.output_file.is_none());
    }

    #[test]
    fn test_valid_params() {
        let params = SpeechSynthesizeParams {
            voice: Some("en-US-Chirp3-HD-Achernar".to_string()),
            speaking_rate: 1.5,
            pitch: 2.0,
            ..base_params("Hello world")
        };
        assert!(params.validate().is_ok());
    }

    // ----- Rejected input ---------------------------------------------------

    #[test]
    fn test_empty_text() {
        assert_field_rejected(&base_params("   "), "text");
    }

    #[test]
    fn test_speaking_rate_too_low() {
        let params = SpeechSynthesizeParams { speaking_rate: 0.1, ..base_params("Hello") };
        assert_field_rejected(&params, "speaking_rate");
    }

    #[test]
    fn test_speaking_rate_too_high() {
        let params = SpeechSynthesizeParams { speaking_rate: 5.0, ..base_params("Hello") };
        assert_field_rejected(&params, "speaking_rate");
    }

    #[test]
    fn test_pitch_too_low() {
        let params = SpeechSynthesizeParams { pitch: -25.0, ..base_params("Hello") };
        assert_field_rejected(&params, "pitch");
    }

    #[test]
    fn test_pitch_too_high() {
        let params = SpeechSynthesizeParams { pitch: 25.0, ..base_params("Hello") };
        assert_field_rejected(&params, "pitch");
    }

    // ----- Inclusive range boundaries ---------------------------------------

    #[test]
    fn test_valid_speaking_rate_boundaries() {
        for rate in [MIN_SPEAKING_RATE, MAX_SPEAKING_RATE] {
            let params = SpeechSynthesizeParams { speaking_rate: rate, ..base_params("Hello") };
            assert!(params.validate().is_ok());
        }
    }

    #[test]
    fn test_valid_pitch_boundaries() {
        for pitch in [MIN_PITCH, MAX_PITCH] {
            let params = SpeechSynthesizeParams { pitch, ..base_params("Hello") };
            assert!(params.validate().is_ok());
        }
    }

    // ----- Pronunciation entries --------------------------------------------

    #[test]
    fn test_pronunciation_valid_ipa() {
        assert!(pron("tomato", "təˈmeɪtoʊ", "ipa").validate().is_ok());
    }

    #[test]
    fn test_pronunciation_valid_xsampa() {
        assert!(pron("tomato", "t@\"meItoU", "x-sampa").validate().is_ok());
    }

    #[test]
    fn test_pronunciation_invalid_alphabet() {
        let err = pron("tomato", "tomato", "invalid").validate().unwrap_err();
        assert_eq!(err.field, "alphabet");
    }

    #[test]
    fn test_pronunciation_empty_word() {
        let err = pron("", "test", "ipa").validate().unwrap_err();
        assert_eq!(err.field, "word");
    }

    #[test]
    fn test_pronunciation_empty_phonetic() {
        let err = pron("test", "", "ipa").validate().unwrap_err();
        assert_eq!(err.field, "phonetic");
    }

    #[test]
    fn test_pronunciation_to_ssml() {
        let ssml = pron("tomato", "təˈmeɪtoʊ", "ipa").to_ssml();
        for needle in ["phoneme", "ipa", "təˈmeɪtoʊ", "tomato"] {
            assert!(ssml.contains(needle));
        }
    }

    // ----- SSML assembly ----------------------------------------------------

    #[test]
    fn test_build_ssml_with_pronunciations() {
        let params = SpeechSynthesizeParams {
            pronunciations: Some(vec![pron("tomato", "təˈmeɪtoʊ", "ipa")]),
            ..base_params("I like tomato")
        };

        let ssml = params.build_ssml();
        assert!(ssml.starts_with("<speak>"));
        assert!(ssml.ends_with("</speak>"));
        assert!(ssml.contains("phoneme"));
        // "tomato" must be wrapped in a <phoneme> element, not left bare.
        assert!(!ssml.contains("tomato</speak>"));
    }

    #[test]
    fn test_build_ssml_without_pronunciations() {
        assert_eq!(base_params("Hello world").build_ssml(), "<speak>Hello world</speak>");
    }

    // ----- Voice selection --------------------------------------------------

    #[test]
    fn test_get_voice_default() {
        assert_eq!(base_params("Hello").get_voice(), DEFAULT_VOICE);
    }

    #[test]
    fn test_get_voice_custom() {
        let params = SpeechSynthesizeParams {
            voice: Some("custom-voice".to_string()),
            ..base_params("Hello")
        };
        assert_eq!(params.get_voice(), "custom-voice");
    }

    // ----- Aggregate validation and serde -----------------------------------

    #[test]
    fn test_params_with_invalid_pronunciation() {
        let params = SpeechSynthesizeParams {
            pronunciations: Some(vec![pron("test", "test", "invalid")]),
            ..base_params("Hello")
        };

        let outcome = params.validate();
        assert!(outcome.is_err());
        let errors = outcome.unwrap_err();
        assert!(errors.iter().any(|e| e.field.contains("pronunciations")));
    }

    #[test]
    fn test_serialization_roundtrip() {
        let params = SpeechSynthesizeParams {
            voice: Some("en-US-Chirp3-HD-Achernar".to_string()),
            speaking_rate: 1.5,
            pitch: 2.0,
            pronunciations: Some(vec![pron("hello", "həˈloʊ", "ipa")]),
            output_file: Some("/tmp/output.wav".to_string()),
            ..base_params("Hello world")
        };

        let json = serde_json::to_string(&params).unwrap();
        let restored: SpeechSynthesizeParams = serde_json::from_str(&json).unwrap();

        assert_eq!(params.text, restored.text);
        assert_eq!(params.voice, restored.voice);
        assert_eq!(params.language_code, restored.language_code);
        assert_eq!(params.speaking_rate, restored.speaking_rate);
        assert_eq!(params.pitch, restored.pitch);
        assert_eq!(params.output_file, restored.output_file);
    }
}
966
967
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    // Feature: rust-mcp-genmedia, Property 8: Numeric Parameter Range Validation (speaking_rate, pitch)
    // **Validates: Requirements 7.6, 7.7**
    //
    // For any numeric parameter with defined bounds (speaking_rate 0.25-4.0, pitch -20.0 to 20.0),
    // values outside the valid range SHALL be rejected with a validation error.
    //
    // The strategies below round generated values (speaking_rate to 0.01, pitch to 0.1) so that
    // failing cases print readable numbers. Rounding can move a value by up to half a step, so each
    // "invalid" range keeps one full step of margin from the nearest bound. Without that margin a
    // draw such as -20.04 would round onto -20.0, which is a *valid* pitch, and the test would fail
    // spuriously.

    /// Strategy to generate valid speaking_rate values (0.25-4.0), rounded to 0.01.
    fn valid_speaking_rate_strategy() -> impl Strategy<Value = f32> {
        (MIN_SPEAKING_RATE..=MAX_SPEAKING_RATE).prop_map(|x| (x * 100.0).round() / 100.0)
    }

    /// Strategy to generate invalid speaking_rate values (< 0.25 or > 4.0), rounded to 0.01.
    fn invalid_speaking_rate_strategy() -> impl Strategy<Value = f32> {
        prop_oneof![
            // Below the minimum; the 0.01 margin keeps rounding from landing on MIN_SPEAKING_RATE.
            (0.0f32..MIN_SPEAKING_RATE - 0.01).prop_map(|x| (x * 100.0).round() / 100.0),
            // Above the maximum; the 0.01 margin keeps rounding from landing on MAX_SPEAKING_RATE.
            (MAX_SPEAKING_RATE + 0.01..10.0f32).prop_map(|x| (x * 100.0).round() / 100.0),
        ]
    }

    /// Strategy to generate valid pitch values (-20.0 to 20.0), rounded to 0.1.
    fn valid_pitch_strategy() -> impl Strategy<Value = f32> {
        (MIN_PITCH..=MAX_PITCH).prop_map(|x| (x * 10.0).round() / 10.0)
    }

    /// Strategy to generate invalid pitch values (< -20.0 or > 20.0), rounded to 0.1.
    fn invalid_pitch_strategy() -> impl Strategy<Value = f32> {
        prop_oneof![
            // Below the minimum; the 0.1 margin keeps e.g. -20.04 from rounding onto MIN_PITCH.
            (-50.0f32..MIN_PITCH - 0.1).prop_map(|x| (x * 10.0).round() / 10.0),
            // Above the maximum; the 0.1 margin keeps rounding from landing on MAX_PITCH.
            (MAX_PITCH + 0.1..50.0f32).prop_map(|x| (x * 10.0).round() / 10.0),
        ]
    }

    /// Strategy to generate valid text (non-empty after trimming).
    fn valid_text_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z0-9 ]{1,100}"
            .prop_map(|s| s.trim().to_string())
            .prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    proptest! {
        /// Property 8: Valid speaking_rate values (0.25-4.0) should pass validation
        #[test]
        fn valid_speaking_rate_passes_validation(
            rate in valid_speaking_rate_strategy(),
            text in valid_text_strategy(),
        ) {
            let params = SpeechSynthesizeParams {
                text,
                voice: None,
                language_code: "en-US".to_string(),
                speaking_rate: rate,
                pitch: 0.0,
                pronunciations: None,
                output_file: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "speaking_rate {} should be valid, but got errors: {:?}",
                rate,
                result.err()
            );
        }

        /// Property 8: Invalid speaking_rate values (< 0.25 or > 4.0) should fail validation
        #[test]
        fn invalid_speaking_rate_fails_validation(
            rate in invalid_speaking_rate_strategy(),
            text in valid_text_strategy(),
        ) {
            let params = SpeechSynthesizeParams {
                text,
                voice: None,
                language_code: "en-US".to_string(),
                speaking_rate: rate,
                pitch: 0.0,
                pronunciations: None,
                output_file: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_err(),
                "speaking_rate {} should be invalid",
                rate
            );

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "speaking_rate"),
                "Should have a speaking_rate validation error for value {}",
                rate
            );
        }

        /// Property 8: Valid pitch values (-20.0 to 20.0) should pass validation
        #[test]
        fn valid_pitch_passes_validation(
            pitch in valid_pitch_strategy(),
            text in valid_text_strategy(),
        ) {
            let params = SpeechSynthesizeParams {
                text,
                voice: None,
                language_code: "en-US".to_string(),
                speaking_rate: 1.0,
                pitch,
                pronunciations: None,
                output_file: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "pitch {} should be valid, but got errors: {:?}",
                pitch,
                result.err()
            );
        }

        /// Property 8: Invalid pitch values (< -20.0 or > 20.0) should fail validation
        #[test]
        fn invalid_pitch_fails_validation(
            pitch in invalid_pitch_strategy(),
            text in valid_text_strategy(),
        ) {
            let params = SpeechSynthesizeParams {
                text,
                voice: None,
                language_code: "en-US".to_string(),
                speaking_rate: 1.0,
                pitch,
                pronunciations: None,
                output_file: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_err(),
                "pitch {} should be invalid",
                pitch
            );

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "pitch"),
                "Should have a pitch validation error for value {}",
                pitch
            );
        }

        /// Property: Combined valid speaking_rate and pitch should pass validation
        #[test]
        fn valid_speaking_rate_and_pitch_passes_validation(
            rate in valid_speaking_rate_strategy(),
            pitch in valid_pitch_strategy(),
            text in valid_text_strategy(),
        ) {
            let params = SpeechSynthesizeParams {
                text,
                voice: None,
                language_code: "en-US".to_string(),
                speaking_rate: rate,
                pitch,
                pronunciations: None,
                output_file: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "speaking_rate {} and pitch {} should be valid, but got errors: {:?}",
                rate,
                pitch,
                result.err()
            );
        }
    }

    // Feature: rust-mcp-genmedia, Property 12: Pronunciation Alphabet Validation
    // **Validates: Requirements 7.9**
    //
    // For any pronunciation entry in speech_synthesize, the alphabet field SHALL be
    // either "ipa" or "x-sampa". Other values SHALL be rejected with a validation error.

    /// Strategy to generate valid alphabet values.
    fn valid_alphabet_strategy() -> impl Strategy<Value = String> {
        prop_oneof![Just("ipa".to_string()), Just("x-sampa".to_string())]
    }

    /// Strategy to generate invalid alphabet values (lowercase words other than the valid ones).
    fn invalid_alphabet_strategy() -> impl Strategy<Value = String> {
        "[a-z]{1,10}"
            .prop_filter("Must not be valid alphabet", |s| {
                let lower = s.to_lowercase();
                lower != "ipa" && lower != "x-sampa"
            })
    }

    /// Strategy to generate a valid (non-empty) word.
    fn valid_word_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z]{1,20}".prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    /// Strategy to generate a valid (non-empty) phonetic string, including some IPA symbols.
    fn valid_phonetic_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Zəˈɪʊæɑɔɛʌ]{1,30}".prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    proptest! {
        /// Property 12: Valid alphabet values ("ipa", "x-sampa") should pass validation
        #[test]
        fn valid_alphabet_passes_validation(
            alphabet in valid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
        ) {
            let pron = Pronunciation {
                word,
                phonetic,
                alphabet: alphabet.clone(),
            };

            let result = pron.validate();
            prop_assert!(
                result.is_ok(),
                "alphabet '{}' should be valid, but got error: {:?}",
                alphabet,
                result.err()
            );
        }

        /// Property 12: Invalid alphabet values should fail validation
        #[test]
        fn invalid_alphabet_fails_validation(
            alphabet in invalid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
        ) {
            let pron = Pronunciation {
                word,
                phonetic,
                alphabet: alphabet.clone(),
            };

            let result = pron.validate();
            prop_assert!(
                result.is_err(),
                "alphabet '{}' should be invalid",
                alphabet
            );

            let error = result.unwrap_err();
            prop_assert!(
                error.field == "alphabet",
                "Should have an alphabet validation error for value '{}'",
                alphabet
            );
        }

        /// Property 12: Pronunciation with valid alphabet in params should pass validation
        #[test]
        fn params_with_valid_pronunciation_passes_validation(
            alphabet in valid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
            text in valid_text_strategy(),
        ) {
            let params = SpeechSynthesizeParams {
                text,
                voice: None,
                language_code: "en-US".to_string(),
                speaking_rate: 1.0,
                pitch: 0.0,
                pronunciations: Some(vec![Pronunciation {
                    word,
                    phonetic,
                    alphabet: alphabet.clone(),
                }]),
                output_file: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "params with alphabet '{}' should be valid, but got errors: {:?}",
                alphabet,
                result.err()
            );
        }

        /// Property 12: Pronunciation with invalid alphabet in params should fail validation
        #[test]
        fn params_with_invalid_pronunciation_fails_validation(
            alphabet in invalid_alphabet_strategy(),
            word in valid_word_strategy(),
            phonetic in valid_phonetic_strategy(),
            text in valid_text_strategy(),
        ) {
            let params = SpeechSynthesizeParams {
                text,
                voice: None,
                language_code: "en-US".to_string(),
                speaking_rate: 1.0,
                pitch: 0.0,
                pronunciations: Some(vec![Pronunciation {
                    word,
                    phonetic,
                    alphabet: alphabet.clone(),
                }]),
                output_file: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_err(),
                "params with alphabet '{}' should be invalid",
                alphabet
            );

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field.contains("pronunciations") && e.field.contains("alphabet")),
                "Should have a pronunciations.alphabet validation error for value '{}'",
                alphabet
            );
        }

        /// Property: Empty text should always fail validation regardless of other params
        #[test]
        fn empty_text_fails_validation(
            rate in valid_speaking_rate_strategy(),
            pitch in valid_pitch_strategy(),
        ) {
            let params = SpeechSynthesizeParams {
                text: "   ".to_string(),
                voice: None,
                language_code: "en-US".to_string(),
                speaking_rate: rate,
                pitch,
                pronunciations: None,
                output_file: None,
            };

            let result = params.validate();
            prop_assert!(result.is_err());

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "text"),
                "Should have a text validation error"
            );
        }
    }
}