1use adk_rust_mcp_common::auth::AuthProvider;
7use adk_rust_mcp_common::config::Config;
8use adk_rust_mcp_common::error::Error;
9use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use std::path::Path;
13use tracing::{debug, info, instrument};
14
/// Model name used for image generation when the request does not specify one.
pub const DEFAULT_IMAGE_MODEL: &str = "gemini-2.5-flash-image";

/// Model name used for speech synthesis when the request does not specify one.
pub const DEFAULT_TTS_MODEL: &str = "gemini-2.5-flash-preview-tts";

/// Voice used for speech synthesis when the request does not specify one.
pub const DEFAULT_VOICE: &str = "Kore";
23
/// Voice names accepted by TTS parameter validation (matched case-sensitively).
pub const AVAILABLE_VOICES: &[&str] = &[
    "Zephyr", "Puck", "Charon", "Kore", "Fenrir",
    "Leda", "Orus", "Aoede", "Callirrhoe", "Autonoe",
    "Enceladus", "Iapetus", "Umbriel", "Algieba", "Despina",
    "Erinome", "Algenib", "Rasalgethi", "Laomedeia", "Achernar",
    "Alnilam", "Schedar", "Gacrux", "Pulcherrima", "Achird",
    "Zubenelgenubi", "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat",
];
32
/// Speaking styles accepted by TTS parameter validation (matched case-sensitively).
pub const AVAILABLE_STYLES: &[&str] = &[
    "neutral",
    "cheerful",
    "sad",
    "angry",
    "fearful",
    "surprised",
    "calm",
];
37
/// `(code, display name)` pairs for the languages reported as supported.
pub const SUPPORTED_LANGUAGE_CODES: &[(&str, &str)] = &[
    ("en", "English"), ("ar", "Arabic"), ("bn", "Bangla"),
    ("nl", "Dutch"), ("fr", "French"), ("de", "German"),
    ("hi", "Hindi"), ("id", "Indonesian"), ("it", "Italian"),
    ("ja", "Japanese"), ("ko", "Korean"), ("mr", "Marathi"),
    ("pl", "Polish"), ("pt", "Portuguese"), ("ro", "Romanian"),
    ("ru", "Russian"), ("es", "Spanish"), ("ta", "Tamil"),
    ("te", "Telugu"), ("th", "Thai"), ("tr", "Turkish"),
    ("uk", "Ukrainian"), ("vi", "Vietnamese"), ("fil", "Filipino"),
    ("fi", "Finnish"), ("el", "Greek"), ("gu", "Gujarati"),
    ("he", "Hebrew"), ("hu", "Hungarian"), ("sv", "Swedish"),
    ("zh", "Chinese (Mandarin)"), ("cs", "Czech"), ("da", "Danish"),
    ("nb", "Norwegian"),
];
50
/// Parameters for an image generation request.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
pub struct MultimodalImageParams {
    /// Text prompt describing the image; rejected by `validate` when blank.
    pub prompt: String,

    /// Model to call; filled with `DEFAULT_IMAGE_MODEL` when absent from the input.
    #[serde(default = "default_image_model")]
    pub model: String,

    /// Optional local path. When set, the image is written to this file instead
    /// of being returned as base64 data.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_file: Option<String>,
}
68
69fn default_image_model() -> String {
70 DEFAULT_IMAGE_MODEL.to_string()
71}
72
/// Parameters for a text-to-speech request.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
pub struct MultimodalTtsParams {
    /// Text to speak; rejected by `validate` when blank.
    pub text: String,

    /// Optional voice name; must be one of `AVAILABLE_VOICES`. `DEFAULT_VOICE`
    /// is used when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub voice: Option<String>,

    /// Optional speaking style; must be one of `AVAILABLE_STYLES`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub style: Option<String>,

    /// Model to call; filled with `DEFAULT_TTS_MODEL` when absent from the input.
    #[serde(default = "default_tts_model")]
    pub model: String,

    /// Optional local path. When set, the audio is written to this file instead
    /// of being returned as base64 data.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_file: Option<String>,
}
98
99fn default_tts_model() -> String {
100 DEFAULT_TTS_MODEL.to_string()
101}
102
/// A single parameter validation failure.
///
/// Implements [`std::error::Error`] so it can be boxed or propagated like any
/// other error, and `PartialEq`/`Eq` so callers and tests can compare values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the parameter that failed validation.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
}

impl std::fmt::Display for ValidationError {
    /// Formats the error as `<field>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.field, self.message)
    }
}

impl std::error::Error for ValidationError {}
117
118impl MultimodalImageParams {
119 pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
121 let mut errors = Vec::new();
122
123 if self.prompt.trim().is_empty() {
125 errors.push(ValidationError {
126 field: "prompt".to_string(),
127 message: "Prompt cannot be empty".to_string(),
128 });
129 }
130
131 if errors.is_empty() {
132 Ok(())
133 } else {
134 Err(errors)
135 }
136 }
137}
138
139impl MultimodalTtsParams {
140 pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
142 let mut errors = Vec::new();
143
144 if self.text.trim().is_empty() {
146 errors.push(ValidationError {
147 field: "text".to_string(),
148 message: "Text cannot be empty".to_string(),
149 });
150 }
151
152 if let Some(ref voice) = self.voice {
154 if !AVAILABLE_VOICES.contains(&voice.as_str()) {
155 errors.push(ValidationError {
156 field: "voice".to_string(),
157 message: format!(
158 "Invalid voice '{}'. Available voices: {}",
159 voice,
160 AVAILABLE_VOICES.join(", ")
161 ),
162 });
163 }
164 }
165
166 if let Some(ref style) = self.style {
168 if !AVAILABLE_STYLES.contains(&style.as_str()) {
169 errors.push(ValidationError {
170 field: "style".to_string(),
171 message: format!(
172 "Invalid style '{}'. Available styles: {}",
173 style,
174 AVAILABLE_STYLES.join(", ")
175 ),
176 });
177 }
178 }
179
180 if errors.is_empty() {
181 Ok(())
182 } else {
183 Err(errors)
184 }
185 }
186
187 pub fn get_voice(&self) -> &str {
189 self.voice.as_deref().unwrap_or(DEFAULT_VOICE)
190 }
191}
192
/// Handler that sends image generation and speech synthesis requests.
pub struct MultimodalHandler {
    /// Configuration used to pick the backend and build endpoint URLs.
    pub config: Config,
    /// HTTP client used for all API requests.
    pub http: reqwest::Client,
    /// Token provider used when requests are not authenticated with the Gemini API key.
    pub auth: AuthProvider,
}
204
205impl MultimodalHandler {
206 #[instrument(level = "debug", name = "multimodal_handler_new", skip_all)]
211 pub async fn new(config: Config) -> Result<Self, Error> {
212 debug!("Initializing MultimodalHandler");
213
214 let auth = AuthProvider::new().await?;
215 let http = reqwest::Client::new();
216
217 Ok(Self { config, http, auth })
218 }
219
    /// Builds a handler from already-constructed dependencies (test builds only).
    #[cfg(test)]
    pub fn with_deps(config: Config, http: reqwest::Client, auth: AuthProvider) -> Self {
        Self { config, http, auth }
    }
225
226 pub fn get_image_endpoint(&self, model: &str) -> String {
228 if self.config.is_gemini() {
229 format!("{}/models/{}:generateContent", self.config.gemini_base_url(), model)
230 } else {
231 format!(
232 "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:generateContent",
233 self.config.location, self.config.project_id, self.config.location, model
234 )
235 }
236 }
237
238 pub fn get_tts_endpoint(&self, model: &str) -> String {
240 if self.config.is_gemini() {
241 format!("{}/models/{}:generateContent", self.config.gemini_base_url(), model)
242 } else {
243 format!(
244 "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:generateContent",
245 self.config.location, self.config.project_id, self.config.location, model
246 )
247 }
248 }
249
250 async fn add_auth(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::RequestBuilder, Error> {
252 if self.config.is_gemini() {
253 let key = self.config.gemini_api_key.as_deref().unwrap_or_default();
254 Ok(builder.header("x-goog-api-key", key))
255 } else {
256 let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
257 Ok(builder.header("Authorization", format!("Bearer {}", token)))
258 }
259 }
260
261
262 #[instrument(level = "info", name = "multimodal_generate_image", skip(self, params))]
271 pub async fn generate_image(
272 &self,
273 params: MultimodalImageParams,
274 ) -> Result<ImageGenerateResult, Error> {
275 params.validate().map_err(|errors| {
277 let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
278 Error::validation(messages.join("; "))
279 })?;
280
281 info!(model = %params.model, "Generating image with Gemini API");
282
283 let request = GeminiImageRequest {
285 contents: vec![GeminiContent {
286 role: "user".to_string(),
287 parts: vec![GeminiPart::Text {
288 text: format!("Generate an image of: {}", params.prompt),
289 }],
290 }],
291 generation_config: GeminiGenerationConfig {
292 response_modalities: vec!["TEXT".to_string(), "IMAGE".to_string()],
293 image_config: Some(GeminiImageConfig {
294 aspect_ratio: "1:1".to_string(),
295 }),
296 temperature: None,
297 max_output_tokens: None,
298 },
299 };
300
301 let endpoint = self.get_image_endpoint(¶ms.model);
303 debug!(endpoint = %endpoint, "Calling Gemini API for image generation");
304
305 let builder = self.http
306 .post(&endpoint)
307 .header("Content-Type", "application/json")
308 .json(&request);
309 let builder = self.add_auth(builder).await?;
310
311 let response = builder
312 .send()
313 .await
314 .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
315
316 let status = response.status();
317 if !status.is_success() {
318 let body = response.text().await.unwrap_or_default();
319 return Err(Error::api(&endpoint, status.as_u16(), body));
320 }
321
322 let response_text = response.text().await.map_err(|e| {
324 Error::api(&endpoint, status.as_u16(), format!("Failed to read response: {}", e))
325 })?;
326
327 debug!(response = %response_text, "Raw Gemini image API response");
328
329 let api_response: GeminiResponse = serde_json::from_str(&response_text).map_err(|e| {
331 Error::api(
332 &endpoint,
333 status.as_u16(),
334 format!("Failed to parse response: {}. Raw: {}", e, &response_text[..response_text.len().min(1000)]),
335 )
336 })?;
337
338 let image = self.extract_image_from_response(&api_response)?;
340
341 info!("Received image from Gemini API");
342
343 self.handle_image_output(image, ¶ms).await
345 }
346
347 #[instrument(level = "info", name = "multimodal_synthesize_speech", skip(self, params))]
356 pub async fn synthesize_speech(&self, params: MultimodalTtsParams) -> Result<TtsResult, Error> {
357 params.validate().map_err(|errors| {
359 let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
360 Error::validation(messages.join("; "))
361 })?;
362
363 let voice = params.get_voice();
364 info!(voice = %voice, model = %params.model, "Synthesizing speech with Gemini API");
365
366 let prompt = if let Some(ref style) = params.style {
368 format!(
369 "Say the following text in a {} tone: {}",
370 style, params.text
371 )
372 } else {
373 params.text.clone()
374 };
375
376 let request = GeminiTtsRequest {
378 contents: vec![GeminiContent {
379 role: "user".to_string(),
380 parts: vec![GeminiPart::Text { text: prompt }],
381 }],
382 generation_config: GeminiTtsGenerationConfig {
383 response_modalities: vec!["AUDIO".to_string()],
384 speech_config: GeminiSpeechConfig {
385 voice_config: GeminiVoiceConfig {
386 prebuilt_voice_config: GeminiPrebuiltVoiceConfig {
387 voice_name: voice.to_string(),
388 },
389 },
390 },
391 },
392 };
393
394 let endpoint = self.get_tts_endpoint(¶ms.model);
396 debug!(endpoint = %endpoint, "Calling Gemini API for TTS");
397
398 let builder = self.http
399 .post(&endpoint)
400 .header("Content-Type", "application/json")
401 .json(&request);
402 let builder = self.add_auth(builder).await?;
403
404 let response = builder
405 .send()
406 .await
407 .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
408
409 let status = response.status();
410 if !status.is_success() {
411 let body = response.text().await.unwrap_or_default();
412 return Err(Error::api(&endpoint, status.as_u16(), body));
413 }
414
415 let response_text = response.text().await.map_err(|e| {
417 Error::api(&endpoint, status.as_u16(), format!("Failed to read response: {}", e))
418 })?;
419
420 debug!(response = %response_text, "Raw Gemini TTS API response");
421
422 let api_response: GeminiResponse = serde_json::from_str(&response_text).map_err(|e| {
424 Error::api(
425 &endpoint,
426 status.as_u16(),
427 format!("Failed to parse response: {}. Raw: {}", e, &response_text[..response_text.len().min(1000)]),
428 )
429 })?;
430
431 let audio = self.extract_audio_from_response(&api_response)?;
433
434 info!("Received audio from Gemini API");
435
436 self.handle_audio_output(audio, ¶ms).await
438 }
439
440 pub fn list_voices(&self) -> Vec<VoiceInfo> {
442 AVAILABLE_VOICES
443 .iter()
444 .map(|&name| VoiceInfo {
445 name: name.to_string(),
446 description: format!("Gemini TTS voice: {}", name),
447 })
448 .collect()
449 }
450
451 pub fn list_language_codes(&self) -> Vec<LanguageCodeInfo> {
453 SUPPORTED_LANGUAGE_CODES
454 .iter()
455 .map(|&(code, name)| LanguageCodeInfo {
456 code: code.to_string(),
457 name: name.to_string(),
458 })
459 .collect()
460 }
461
462 fn extract_image_from_response(
464 &self,
465 response: &GeminiResponse,
466 ) -> Result<GeneratedImage, Error> {
467 for candidate in &response.candidates {
468 if let Some(ref content) = candidate.content {
469 for part in &content.parts {
470 if let GeminiResponsePart::InlineData { inline_data } = part {
471 return Ok(GeneratedImage {
472 data: inline_data.data.clone(),
473 mime_type: inline_data.mime_type.clone(),
474 });
475 }
476 }
477 }
478 }
479
480 Err(Error::api(
481 "gemini",
482 200,
483 "No image data found in response".to_string(),
484 ))
485 }
486
487 fn extract_audio_from_response(
489 &self,
490 response: &GeminiResponse,
491 ) -> Result<GeneratedAudio, Error> {
492 for candidate in &response.candidates {
493 if let Some(ref content) = candidate.content {
494 for part in &content.parts {
495 if let GeminiResponsePart::InlineData { inline_data } = part {
496 return Ok(GeneratedAudio {
497 data: inline_data.data.clone(),
498 mime_type: inline_data.mime_type.clone(),
499 });
500 }
501 }
502 }
503 }
504
505 Err(Error::api(
506 "gemini",
507 200,
508 "No audio data found in response".to_string(),
509 ))
510 }
511
512 async fn handle_image_output(
514 &self,
515 image: GeneratedImage,
516 params: &MultimodalImageParams,
517 ) -> Result<ImageGenerateResult, Error> {
518 if let Some(output_file) = ¶ms.output_file {
520 return self.save_image_to_file(image, output_file).await;
521 }
522
523 Ok(ImageGenerateResult::Base64(image))
525 }
526
527 async fn handle_audio_output(
529 &self,
530 audio: GeneratedAudio,
531 params: &MultimodalTtsParams,
532 ) -> Result<TtsResult, Error> {
533 if let Some(output_file) = ¶ms.output_file {
535 return self.save_audio_to_file(audio, output_file).await;
536 }
537
538 Ok(TtsResult::Base64(audio))
540 }
541
542 async fn save_image_to_file(
544 &self,
545 image: GeneratedImage,
546 output_file: &str,
547 ) -> Result<ImageGenerateResult, Error> {
548 let data = BASE64
550 .decode(&image.data)
551 .map_err(|e| Error::validation(format!("Invalid base64 data: {}", e)))?;
552
553 if let Some(parent) = Path::new(output_file).parent() {
555 if !parent.as_os_str().is_empty() {
556 tokio::fs::create_dir_all(parent).await?;
557 }
558 }
559
560 tokio::fs::write(output_file, &data).await?;
562
563 info!(path = %output_file, "Saved image to local file");
564 Ok(ImageGenerateResult::LocalFile(output_file.to_string()))
565 }
566
567 async fn save_audio_to_file(
569 &self,
570 audio: GeneratedAudio,
571 output_file: &str,
572 ) -> Result<TtsResult, Error> {
573 let data = BASE64
575 .decode(&audio.data)
576 .map_err(|e| Error::validation(format!("Invalid base64 data: {}", e)))?;
577
578 if let Some(parent) = Path::new(output_file).parent() {
580 if !parent.as_os_str().is_empty() {
581 tokio::fs::create_dir_all(parent).await?;
582 }
583 }
584
585 tokio::fs::write(output_file, &data).await?;
587
588 info!(path = %output_file, "Saved audio to local file");
589 Ok(TtsResult::LocalFile(output_file.to_string()))
590 }
591}
592
593
/// Request body for image generation; fields serialize in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiImageRequest {
    /// Conversation turns sent to the model.
    pub contents: Vec<GeminiContent>,
    /// Generation settings (serialized as `generationConfig`).
    pub generation_config: GeminiGenerationConfig,
}
607
/// Request body for speech synthesis; fields serialize in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiTtsRequest {
    /// Conversation turns sent to the model.
    pub contents: Vec<GeminiContent>,
    /// Generation settings (serialized as `generationConfig`).
    pub generation_config: GeminiTtsGenerationConfig,
}
617
/// One turn of request content.
#[derive(Debug, Serialize, Deserialize)]
pub struct GeminiContent {
    /// Author of the turn; this file always sends `"user"`.
    pub role: String,
    /// Parts that make up the turn.
    pub parts: Vec<GeminiPart>,
}
626
/// A request content part. Untagged, so `Text` serializes as `{"text": ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum GeminiPart {
    /// Plain text part.
    Text { text: String },
}
634
/// Generation settings for image requests; `None` fields are omitted from the JSON.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiGenerationConfig {
    /// Requested output modalities (this file sends `"TEXT"` and `"IMAGE"`).
    pub response_modalities: Vec<String>,
    /// Optional image settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_config: Option<GeminiImageConfig>,
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional cap on output tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
}
651
/// Image-specific generation settings.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiImageConfig {
    /// Aspect ratio string such as `"1:1"` (serialized as `aspectRatio`).
    pub aspect_ratio: String,
}
659
/// Generation settings for speech synthesis requests.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiTtsGenerationConfig {
    /// Requested output modalities (this file sends `"AUDIO"`).
    pub response_modalities: Vec<String>,
    /// Voice selection (serialized as `speechConfig`).
    pub speech_config: GeminiSpeechConfig,
}
669
/// Speech settings wrapper (serialized with a `voiceConfig` key).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiSpeechConfig {
    /// Voice configuration.
    pub voice_config: GeminiVoiceConfig,
}
677
/// Voice configuration wrapper (serialized with a `prebuiltVoiceConfig` key).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiVoiceConfig {
    /// Prebuilt voice selection.
    pub prebuilt_voice_config: GeminiPrebuiltVoiceConfig,
}
685
/// Selects a prebuilt voice by name.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiPrebuiltVoiceConfig {
    /// Voice name (serialized as `voiceName`).
    pub voice_name: String,
}
693
/// Top-level `generateContent` response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiResponse {
    /// Response candidates; empty when the field is absent.
    #[serde(default)]
    pub candidates: Vec<GeminiCandidate>,
}
702
/// One response candidate.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiCandidate {
    /// Generated content; `None` when the candidate carries none.
    pub content: Option<GeminiResponseContent>,
}
710
711#[derive(Debug, Deserialize)]
713pub struct GeminiResponseContent {
714 pub parts: Vec<GeminiResponsePart>,
716}
717
/// A response content part.
///
/// Untagged: the variant is picked from the fields present, trying
/// `InlineData` before `Text`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum GeminiResponsePart {
    /// Binary payload carried under the `inlineData` key.
    InlineData {
        #[serde(rename = "inlineData")]
        inline_data: GeminiInlineData,
    },
    /// Plain text part.
    Text { text: String },
}
730
/// Inline binary payload of a response part.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiInlineData {
    /// MIME type of the payload (read from `mimeType`).
    pub mime_type: String,
    /// Payload data; decoded as base64 when saved to a file.
    pub data: String,
}
740
/// An image extracted from an API response.
#[derive(Debug, Clone)]
pub struct GeneratedImage {
    /// Base64-encoded image data.
    pub data: String,
    /// MIME type reported by the API.
    pub mime_type: String,
}
753
/// Audio extracted from an API response.
#[derive(Debug, Clone)]
pub struct GeneratedAudio {
    /// Base64-encoded audio data.
    pub data: String,
    /// MIME type reported by the API.
    pub mime_type: String,
}
762
/// Outcome of an image generation request.
#[derive(Debug)]
pub enum ImageGenerateResult {
    /// Image returned inline as base64 (no output file was requested).
    Base64(GeneratedImage),
    /// Path of the local file the image was written to.
    LocalFile(String),
}
771
/// Outcome of a speech synthesis request.
#[derive(Debug)]
pub enum TtsResult {
    /// Audio returned inline as base64 (no output file was requested).
    Base64(GeneratedAudio),
    /// Path of the local file the audio was written to.
    LocalFile(String),
}
780
/// Description of an available voice, as returned by `list_voices`.
#[derive(Debug, Clone, Serialize)]
pub struct VoiceInfo {
    /// Voice name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
}
789
/// A supported language, as returned by `list_language_codes`.
#[derive(Debug, Clone, Serialize)]
pub struct LanguageCodeInfo {
    /// Language code, e.g. `"en"`.
    pub code: String,
    /// Display name, e.g. `"English"`.
    pub name: String,
}
798
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds image params with the default model and no output file.
    fn image_params(prompt: &str) -> MultimodalImageParams {
        MultimodalImageParams {
            prompt: prompt.to_string(),
            model: DEFAULT_IMAGE_MODEL.to_string(),
            output_file: None,
        }
    }

    /// Builds TTS params with the default model and no output file.
    fn tts_params(text: &str, voice: Option<&str>, style: Option<&str>) -> MultimodalTtsParams {
        MultimodalTtsParams {
            text: text.to_string(),
            voice: voice.map(str::to_string),
            style: style.map(str::to_string),
            model: DEFAULT_TTS_MODEL.to_string(),
            output_file: None,
        }
    }

    /// Asserts that validation failed and reported an error for `field`.
    fn assert_rejects(result: Result<(), Vec<ValidationError>>, field: &str) {
        let errors = result.expect_err("validation should fail");
        assert!(errors.iter().any(|e| e.field == field));
    }

    #[test]
    fn test_default_image_params() {
        let params: MultimodalImageParams =
            serde_json::from_str(r#"{"prompt": "A cat"}"#).unwrap();
        assert_eq!(params.model, DEFAULT_IMAGE_MODEL);
        assert!(params.output_file.is_none());
    }

    #[test]
    fn test_valid_image_params() {
        assert!(image_params("A beautiful sunset").validate().is_ok());
    }

    #[test]
    fn test_empty_prompt_image() {
        assert_rejects(image_params("   ").validate(), "prompt");
    }

    #[test]
    fn test_default_tts_params() {
        let params: MultimodalTtsParams =
            serde_json::from_str(r#"{"text": "Hello world"}"#).unwrap();
        assert_eq!(params.model, DEFAULT_TTS_MODEL);
        assert!(params.voice.is_none());
        assert!(params.style.is_none());
        assert!(params.output_file.is_none());
    }

    #[test]
    fn test_valid_tts_params() {
        let params = tts_params("Hello world", Some("Kore"), Some("cheerful"));
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_empty_text_tts() {
        assert_rejects(tts_params("   ", None, None).validate(), "text");
    }

    #[test]
    fn test_invalid_voice() {
        let params = tts_params("Hello", Some("InvalidVoice"), None);
        assert_rejects(params.validate(), "voice");
    }

    #[test]
    fn test_invalid_style() {
        let params = tts_params("Hello", None, Some("invalid_style"));
        assert_rejects(params.validate(), "style");
    }

    #[test]
    fn test_get_voice_default() {
        assert_eq!(tts_params("Hello", None, None).get_voice(), DEFAULT_VOICE);
    }

    #[test]
    fn test_get_voice_custom() {
        assert_eq!(tts_params("Hello", Some("Puck"), None).get_voice(), "Puck");
    }

    #[test]
    fn test_all_valid_voices() {
        for &voice in AVAILABLE_VOICES {
            let params = tts_params("Hello", Some(voice), None);
            assert!(
                params.validate().is_ok(),
                "Voice {} should be valid",
                voice
            );
        }
    }

    #[test]
    fn test_all_valid_styles() {
        for &style in AVAILABLE_STYLES {
            let params = tts_params("Hello", None, Some(style));
            assert!(
                params.validate().is_ok(),
                "Style {} should be valid",
                style
            );
        }
    }

    #[test]
    fn test_serialization_roundtrip_image() {
        let params = MultimodalImageParams {
            prompt: "A cat".to_string(),
            model: "custom-model".to_string(),
            output_file: Some("/tmp/output.png".to_string()),
        };

        let json = serde_json::to_string(&params).unwrap();
        let deserialized: MultimodalImageParams = serde_json::from_str(&json).unwrap();

        assert_eq!(params.prompt, deserialized.prompt);
        assert_eq!(params.model, deserialized.model);
        assert_eq!(params.output_file, deserialized.output_file);
    }

    #[test]
    fn test_serialization_roundtrip_tts() {
        let params = MultimodalTtsParams {
            text: "Hello world".to_string(),
            voice: Some("Kore".to_string()),
            style: Some("cheerful".to_string()),
            model: "custom-model".to_string(),
            output_file: Some("/tmp/output.wav".to_string()),
        };

        let json = serde_json::to_string(&params).unwrap();
        let deserialized: MultimodalTtsParams = serde_json::from_str(&json).unwrap();

        assert_eq!(params.text, deserialized.text);
        assert_eq!(params.voice, deserialized.voice);
        assert_eq!(params.style, deserialized.style);
        assert_eq!(params.model, deserialized.model);
        assert_eq!(params.output_file, deserialized.output_file);
    }
}