Skip to main content

adk_rust_mcp_multimodal/
handler.rs

//! Multimodal generation handler for the MCP Multimodal server.
//!
//! This module provides the `MultimodalHandler` struct and parameter types for
//! image generation and text-to-speech using Google's Gemini API.
5
6use adk_rust_mcp_common::auth::AuthProvider;
7use adk_rust_mcp_common::config::Config;
8use adk_rust_mcp_common::error::Error;
9use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use std::path::Path;
13use tracing::{debug, info, instrument};
14
/// Model identifier used for image generation when a request does not name one.
pub const DEFAULT_IMAGE_MODEL: &str = "gemini-2.5-flash-image";

/// Model identifier used for text-to-speech when a request does not name one.
pub const DEFAULT_TTS_MODEL: &str = "gemini-2.5-flash-preview-tts";

/// Voice used for text-to-speech when a request does not name one.
pub const DEFAULT_VOICE: &str = "Kore";
23
/// Voice names accepted by TTS parameter validation.
///
/// Matching is exact and case-sensitive; the order here is the order used when
/// the list is shown in validation messages and voice listings.
pub const AVAILABLE_VOICES: &[&str] = &[
    "Zephyr", "Puck", "Charon", "Kore", "Fenrir", "Leda",
    "Orus", "Aoede", "Callirrhoe", "Autonoe", "Enceladus", "Iapetus",
    "Umbriel", "Algieba", "Despina", "Erinome", "Algenib", "Rasalgethi",
    "Laomedeia", "Achernar", "Alnilam", "Schedar", "Gacrux", "Pulcherrima",
    "Achird", "Zubenelgenubi", "Vindemiatrix", "Sadachbia", "Sadaltager", "Sulafat",
];
32
/// Style/tone keywords accepted by TTS parameter validation.
///
/// Matching is exact and case-sensitive.
pub const AVAILABLE_STYLES: &[&str] = &[
    "neutral",
    "cheerful",
    "sad",
    "angry",
    "fearful",
    "surprised",
    "calm",
];
37
/// Language table for Gemini TTS as `(code, display name)` pairs
/// (auto-detected, BCP-47).
///
/// The codes are bare language subtags such as `"en"`; nothing in this file
/// sends them to the API, they are only exposed through the language listing.
pub const SUPPORTED_LANGUAGE_CODES: &[(&str, &str)] = &[
    ("en", "English"), ("ar", "Arabic"),
    ("bn", "Bangla"), ("nl", "Dutch"),
    ("fr", "French"), ("de", "German"),
    ("hi", "Hindi"), ("id", "Indonesian"),
    ("it", "Italian"), ("ja", "Japanese"),
    ("ko", "Korean"), ("mr", "Marathi"),
    ("pl", "Polish"), ("pt", "Portuguese"),
    ("ro", "Romanian"), ("ru", "Russian"),
    ("es", "Spanish"), ("ta", "Tamil"),
    ("te", "Telugu"), ("th", "Thai"),
    ("tr", "Turkish"), ("uk", "Ukrainian"),
    ("vi", "Vietnamese"), ("fil", "Filipino"),
    ("fi", "Finnish"), ("el", "Greek"),
    ("gu", "Gujarati"), ("he", "Hebrew"),
    ("hu", "Hungarian"), ("sv", "Swedish"),
    ("zh", "Chinese (Mandarin)"), ("cs", "Czech"),
    ("da", "Danish"), ("nb", "Norwegian"),
];
50
51/// Multimodal image generation parameters.
52///
53/// These parameters control image generation via the Gemini API.
54#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
55pub struct MultimodalImageParams {
56    /// Text prompt describing the image to generate.
57    pub prompt: String,
58
59    /// Model to use for generation.
60    #[serde(default = "default_image_model")]
61    pub model: String,
62
63    /// Output file path for saving the image locally.
64    /// If not specified, returns base64-encoded data.
65    #[serde(default, skip_serializing_if = "Option::is_none")]
66    pub output_file: Option<String>,
67}
68
69fn default_image_model() -> String {
70    DEFAULT_IMAGE_MODEL.to_string()
71}
72
73/// Multimodal TTS parameters.
74///
75/// These parameters control text-to-speech synthesis via the Gemini API.
76#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
77pub struct MultimodalTtsParams {
78    /// Text to synthesize into speech.
79    pub text: String,
80
81    /// Voice name to use.
82    #[serde(default, skip_serializing_if = "Option::is_none")]
83    pub voice: Option<String>,
84
85    /// Style/tone for the speech (e.g., "cheerful", "calm", "neutral").
86    #[serde(default, skip_serializing_if = "Option::is_none")]
87    pub style: Option<String>,
88
89    /// Model to use for TTS.
90    #[serde(default = "default_tts_model")]
91    pub model: String,
92
93    /// Output file path for saving the audio locally.
94    /// If not specified, returns base64-encoded data.
95    #[serde(default, skip_serializing_if = "Option::is_none")]
96    pub output_file: Option<String>,
97}
98
99fn default_tts_model() -> String {
100    DEFAULT_TTS_MODEL.to_string()
101}
102
/// One parameter-validation failure.
///
/// Validation collects these into a `Vec` so that every problem with a
/// request can be reported at once.
#[derive(Clone, Debug)]
pub struct ValidationError {
    /// Name of the rejected parameter (for example `"prompt"`).
    pub field: String,
    /// Human-readable explanation of why the value was rejected.
    pub message: String,
}
111
112impl std::fmt::Display for ValidationError {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        write!(f, "{}: {}", self.field, self.message)
115    }
116}
117
118impl MultimodalImageParams {
119    /// Validate the parameters.
120    pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
121        let mut errors = Vec::new();
122
123        // Validate prompt is not empty
124        if self.prompt.trim().is_empty() {
125            errors.push(ValidationError {
126                field: "prompt".to_string(),
127                message: "Prompt cannot be empty".to_string(),
128            });
129        }
130
131        if errors.is_empty() {
132            Ok(())
133        } else {
134            Err(errors)
135        }
136    }
137}
138
139impl MultimodalTtsParams {
140    /// Validate the parameters.
141    pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
142        let mut errors = Vec::new();
143
144        // Validate text is not empty
145        if self.text.trim().is_empty() {
146            errors.push(ValidationError {
147                field: "text".to_string(),
148                message: "Text cannot be empty".to_string(),
149            });
150        }
151
152        // Validate voice if provided
153        if let Some(ref voice) = self.voice {
154            if !AVAILABLE_VOICES.contains(&voice.as_str()) {
155                errors.push(ValidationError {
156                    field: "voice".to_string(),
157                    message: format!(
158                        "Invalid voice '{}'. Available voices: {}",
159                        voice,
160                        AVAILABLE_VOICES.join(", ")
161                    ),
162                });
163            }
164        }
165
166        // Validate style if provided
167        if let Some(ref style) = self.style {
168            if !AVAILABLE_STYLES.contains(&style.as_str()) {
169                errors.push(ValidationError {
170                    field: "style".to_string(),
171                    message: format!(
172                        "Invalid style '{}'. Available styles: {}",
173                        style,
174                        AVAILABLE_STYLES.join(", ")
175                    ),
176                });
177            }
178        }
179
180        if errors.is_empty() {
181            Ok(())
182        } else {
183            Err(errors)
184        }
185    }
186
187    /// Get the voice name to use, defaulting if not specified.
188    pub fn get_voice(&self) -> &str {
189        self.voice.as_deref().unwrap_or(DEFAULT_VOICE)
190    }
191}
192
193/// Multimodal generation handler.
194///
195/// Handles image generation and TTS requests using the Gemini API.
196pub struct MultimodalHandler {
197    /// Application configuration.
198    pub config: Config,
199    /// HTTP client for API requests.
200    pub http: reqwest::Client,
201    /// Authentication provider.
202    pub auth: AuthProvider,
203}
204
205impl MultimodalHandler {
206    /// Create a new MultimodalHandler with the given configuration.
207    ///
208    /// # Errors
209    /// Returns an error if auth provider initialization fails.
210    #[instrument(level = "debug", name = "multimodal_handler_new", skip_all)]
211    pub async fn new(config: Config) -> Result<Self, Error> {
212        debug!("Initializing MultimodalHandler");
213
214        let auth = AuthProvider::new().await?;
215        let http = reqwest::Client::new();
216
217        Ok(Self { config, http, auth })
218    }
219
220    /// Create a new MultimodalHandler with provided dependencies (for testing).
221    #[cfg(test)]
222    pub fn with_deps(config: Config, http: reqwest::Client, auth: AuthProvider) -> Self {
223        Self { config, http, auth }
224    }
225
226    /// Get the Gemini API endpoint for image generation.
227    pub fn get_image_endpoint(&self, model: &str) -> String {
228        if self.config.is_gemini() {
229            format!("{}/models/{}:generateContent", self.config.gemini_base_url(), model)
230        } else {
231            format!(
232                "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:generateContent",
233                self.config.location, self.config.project_id, self.config.location, model
234            )
235        }
236    }
237
238    /// Get the Gemini API endpoint for TTS.
239    pub fn get_tts_endpoint(&self, model: &str) -> String {
240        if self.config.is_gemini() {
241            format!("{}/models/{}:generateContent", self.config.gemini_base_url(), model)
242        } else {
243            format!(
244                "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:generateContent",
245                self.config.location, self.config.project_id, self.config.location, model
246            )
247        }
248    }
249
250    /// Add auth headers based on provider.
251    async fn add_auth(&self, builder: reqwest::RequestBuilder) -> Result<reqwest::RequestBuilder, Error> {
252        if self.config.is_gemini() {
253            let key = self.config.gemini_api_key.as_deref().unwrap_or_default();
254            Ok(builder.header("x-goog-api-key", key))
255        } else {
256            let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
257            Ok(builder.header("Authorization", format!("Bearer {}", token)))
258        }
259    }
260
261
262    /// Generate an image from a text prompt using Gemini.
263    ///
264    /// # Arguments
265    /// * `params` - Image generation parameters
266    ///
267    /// # Returns
268    /// * `Ok(ImageGenerateResult)` - Generated image with data or path
269    /// * `Err(Error)` - If validation fails, API call fails, or output handling fails
270    #[instrument(level = "info", name = "multimodal_generate_image", skip(self, params))]
271    pub async fn generate_image(
272        &self,
273        params: MultimodalImageParams,
274    ) -> Result<ImageGenerateResult, Error> {
275        // Validate parameters
276        params.validate().map_err(|errors| {
277            let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
278            Error::validation(messages.join("; "))
279        })?;
280
281        info!(model = %params.model, "Generating image with Gemini API");
282
283        // Build the API request
284        let request = GeminiImageRequest {
285            contents: vec![GeminiContent {
286                role: "user".to_string(),
287                parts: vec![GeminiPart::Text {
288                    text: format!("Generate an image of: {}", params.prompt),
289                }],
290            }],
291            generation_config: GeminiGenerationConfig {
292                response_modalities: vec!["TEXT".to_string(), "IMAGE".to_string()],
293                image_config: Some(GeminiImageConfig {
294                    aspect_ratio: "1:1".to_string(),
295                }),
296                temperature: None,
297                max_output_tokens: None,
298            },
299        };
300
301        // Make API request
302        let endpoint = self.get_image_endpoint(&params.model);
303        debug!(endpoint = %endpoint, "Calling Gemini API for image generation");
304
305        let builder = self.http
306            .post(&endpoint)
307            .header("Content-Type", "application/json")
308            .json(&request);
309        let builder = self.add_auth(builder).await?;
310
311        let response = builder
312            .send()
313            .await
314            .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
315
316        let status = response.status();
317        if !status.is_success() {
318            let body = response.text().await.unwrap_or_default();
319            return Err(Error::api(&endpoint, status.as_u16(), body));
320        }
321
322        // Get raw response text for debugging
323        let response_text = response.text().await.map_err(|e| {
324            Error::api(&endpoint, status.as_u16(), format!("Failed to read response: {}", e))
325        })?;
326        
327        debug!(response = %response_text, "Raw Gemini image API response");
328
329        // Parse response
330        let api_response: GeminiResponse = serde_json::from_str(&response_text).map_err(|e| {
331            Error::api(
332                &endpoint,
333                status.as_u16(),
334                format!("Failed to parse response: {}. Raw: {}", e, &response_text[..response_text.len().min(1000)]),
335            )
336        })?;
337
338        // Extract image from response
339        let image = self.extract_image_from_response(&api_response)?;
340
341        info!("Received image from Gemini API");
342
343        // Handle output based on params
344        self.handle_image_output(image, &params).await
345    }
346
347    /// Synthesize speech from text using Gemini.
348    ///
349    /// # Arguments
350    /// * `params` - TTS parameters
351    ///
352    /// # Returns
353    /// * `Ok(TtsResult)` - Generated audio with data or path
354    /// * `Err(Error)` - If validation fails, API call fails, or output handling fails
355    #[instrument(level = "info", name = "multimodal_synthesize_speech", skip(self, params))]
356    pub async fn synthesize_speech(&self, params: MultimodalTtsParams) -> Result<TtsResult, Error> {
357        // Validate parameters
358        params.validate().map_err(|errors| {
359            let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
360            Error::validation(messages.join("; "))
361        })?;
362
363        let voice = params.get_voice();
364        info!(voice = %voice, model = %params.model, "Synthesizing speech with Gemini API");
365
366        // Build the prompt with style if provided
367        let prompt = if let Some(ref style) = params.style {
368            format!(
369                "Say the following text in a {} tone: {}",
370                style, params.text
371            )
372        } else {
373            params.text.clone()
374        };
375
376        // Build the API request
377        let request = GeminiTtsRequest {
378            contents: vec![GeminiContent {
379                role: "user".to_string(),
380                parts: vec![GeminiPart::Text { text: prompt }],
381            }],
382            generation_config: GeminiTtsGenerationConfig {
383                response_modalities: vec!["AUDIO".to_string()],
384                speech_config: GeminiSpeechConfig {
385                    voice_config: GeminiVoiceConfig {
386                        prebuilt_voice_config: GeminiPrebuiltVoiceConfig {
387                            voice_name: voice.to_string(),
388                        },
389                    },
390                },
391            },
392        };
393
394        // Make API request
395        let endpoint = self.get_tts_endpoint(&params.model);
396        debug!(endpoint = %endpoint, "Calling Gemini API for TTS");
397
398        let builder = self.http
399            .post(&endpoint)
400            .header("Content-Type", "application/json")
401            .json(&request);
402        let builder = self.add_auth(builder).await?;
403
404        let response = builder
405            .send()
406            .await
407            .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
408
409        let status = response.status();
410        if !status.is_success() {
411            let body = response.text().await.unwrap_or_default();
412            return Err(Error::api(&endpoint, status.as_u16(), body));
413        }
414
415        // Get raw response text for debugging
416        let response_text = response.text().await.map_err(|e| {
417            Error::api(&endpoint, status.as_u16(), format!("Failed to read response: {}", e))
418        })?;
419        
420        debug!(response = %response_text, "Raw Gemini TTS API response");
421
422        // Parse response
423        let api_response: GeminiResponse = serde_json::from_str(&response_text).map_err(|e| {
424            Error::api(
425                &endpoint,
426                status.as_u16(),
427                format!("Failed to parse response: {}. Raw: {}", e, &response_text[..response_text.len().min(1000)]),
428            )
429        })?;
430
431        // Extract audio from response
432        let audio = self.extract_audio_from_response(&api_response)?;
433
434        info!("Received audio from Gemini API");
435
436        // Handle output based on params
437        self.handle_audio_output(audio, &params).await
438    }
439
440    /// List available voices.
441    pub fn list_voices(&self) -> Vec<VoiceInfo> {
442        AVAILABLE_VOICES
443            .iter()
444            .map(|&name| VoiceInfo {
445                name: name.to_string(),
446                description: format!("Gemini TTS voice: {}", name),
447            })
448            .collect()
449    }
450
451    /// List supported language codes.
452    pub fn list_language_codes(&self) -> Vec<LanguageCodeInfo> {
453        SUPPORTED_LANGUAGE_CODES
454            .iter()
455            .map(|&(code, name)| LanguageCodeInfo {
456                code: code.to_string(),
457                name: name.to_string(),
458            })
459            .collect()
460    }
461
462    /// Extract image data from Gemini response.
463    fn extract_image_from_response(
464        &self,
465        response: &GeminiResponse,
466    ) -> Result<GeneratedImage, Error> {
467        for candidate in &response.candidates {
468            if let Some(ref content) = candidate.content {
469                for part in &content.parts {
470                    if let GeminiResponsePart::InlineData { inline_data } = part {
471                        return Ok(GeneratedImage {
472                            data: inline_data.data.clone(),
473                            mime_type: inline_data.mime_type.clone(),
474                        });
475                    }
476                }
477            }
478        }
479
480        Err(Error::api(
481            "gemini",
482            200,
483            "No image data found in response".to_string(),
484        ))
485    }
486
487    /// Extract audio data from Gemini response.
488    fn extract_audio_from_response(
489        &self,
490        response: &GeminiResponse,
491    ) -> Result<GeneratedAudio, Error> {
492        for candidate in &response.candidates {
493            if let Some(ref content) = candidate.content {
494                for part in &content.parts {
495                    if let GeminiResponsePart::InlineData { inline_data } = part {
496                        return Ok(GeneratedAudio {
497                            data: inline_data.data.clone(),
498                            mime_type: inline_data.mime_type.clone(),
499                        });
500                    }
501                }
502            }
503        }
504
505        Err(Error::api(
506            "gemini",
507            200,
508            "No audio data found in response".to_string(),
509        ))
510    }
511
512    /// Handle output of generated image based on params.
513    async fn handle_image_output(
514        &self,
515        image: GeneratedImage,
516        params: &MultimodalImageParams,
517    ) -> Result<ImageGenerateResult, Error> {
518        // If output_file is specified, save to local file
519        if let Some(output_file) = &params.output_file {
520            return self.save_image_to_file(image, output_file).await;
521        }
522
523        // Otherwise, return base64-encoded data
524        Ok(ImageGenerateResult::Base64(image))
525    }
526
527    /// Handle output of generated audio based on params.
528    async fn handle_audio_output(
529        &self,
530        audio: GeneratedAudio,
531        params: &MultimodalTtsParams,
532    ) -> Result<TtsResult, Error> {
533        // If output_file is specified, save to local file
534        if let Some(output_file) = &params.output_file {
535            return self.save_audio_to_file(audio, output_file).await;
536        }
537
538        // Otherwise, return base64-encoded data
539        Ok(TtsResult::Base64(audio))
540    }
541
542    /// Save image to local file.
543    async fn save_image_to_file(
544        &self,
545        image: GeneratedImage,
546        output_file: &str,
547    ) -> Result<ImageGenerateResult, Error> {
548        // Decode base64 data
549        let data = BASE64
550            .decode(&image.data)
551            .map_err(|e| Error::validation(format!("Invalid base64 data: {}", e)))?;
552
553        // Ensure parent directory exists
554        if let Some(parent) = Path::new(output_file).parent() {
555            if !parent.as_os_str().is_empty() {
556                tokio::fs::create_dir_all(parent).await?;
557            }
558        }
559
560        // Write to file
561        tokio::fs::write(output_file, &data).await?;
562
563        info!(path = %output_file, "Saved image to local file");
564        Ok(ImageGenerateResult::LocalFile(output_file.to_string()))
565    }
566
567    /// Save audio to local file.
568    async fn save_audio_to_file(
569        &self,
570        audio: GeneratedAudio,
571        output_file: &str,
572    ) -> Result<TtsResult, Error> {
573        // Decode base64 data
574        let data = BASE64
575            .decode(&audio.data)
576            .map_err(|e| Error::validation(format!("Invalid base64 data: {}", e)))?;
577
578        // Ensure parent directory exists
579        if let Some(parent) = Path::new(output_file).parent() {
580            if !parent.as_os_str().is_empty() {
581                tokio::fs::create_dir_all(parent).await?;
582            }
583        }
584
585        // Write to file
586        tokio::fs::write(output_file, &data).await?;
587
588        info!(path = %output_file, "Saved audio to local file");
589        Ok(TtsResult::LocalFile(output_file.to_string()))
590    }
591}
592
593
594// =============================================================================
595// API Request/Response Types
596// =============================================================================
597
598/// Gemini API request for image generation.
599#[derive(Debug, Serialize)]
600#[serde(rename_all = "camelCase")]
601pub struct GeminiImageRequest {
602    /// Content parts
603    pub contents: Vec<GeminiContent>,
604    /// Generation configuration
605    pub generation_config: GeminiGenerationConfig,
606}
607
608/// Gemini API request for TTS.
609#[derive(Debug, Serialize)]
610#[serde(rename_all = "camelCase")]
611pub struct GeminiTtsRequest {
612    /// Content parts
613    pub contents: Vec<GeminiContent>,
614    /// Generation configuration
615    pub generation_config: GeminiTtsGenerationConfig,
616}
617
618/// Gemini content structure.
619#[derive(Debug, Serialize, Deserialize)]
620pub struct GeminiContent {
621    /// Role (user or model)
622    pub role: String,
623    /// Content parts
624    pub parts: Vec<GeminiPart>,
625}
626
627/// Gemini content part (request).
628#[derive(Debug, Serialize, Deserialize)]
629#[serde(untagged)]
630pub enum GeminiPart {
631    /// Text content
632    Text { text: String },
633}
634
635/// Gemini generation config for image generation.
636#[derive(Debug, Serialize)]
637#[serde(rename_all = "camelCase")]
638pub struct GeminiGenerationConfig {
639    /// Response modalities (TEXT, IMAGE, AUDIO)
640    pub response_modalities: Vec<String>,
641    /// Image configuration
642    #[serde(skip_serializing_if = "Option::is_none")]
643    pub image_config: Option<GeminiImageConfig>,
644    /// Temperature for generation
645    #[serde(skip_serializing_if = "Option::is_none")]
646    pub temperature: Option<f32>,
647    /// Max output tokens
648    #[serde(skip_serializing_if = "Option::is_none")]
649    pub max_output_tokens: Option<u32>,
650}
651
652/// Gemini image configuration.
653#[derive(Debug, Serialize)]
654#[serde(rename_all = "camelCase")]
655pub struct GeminiImageConfig {
656    /// Aspect ratio for generated images
657    pub aspect_ratio: String,
658}
659
660/// Gemini generation config for TTS.
661#[derive(Debug, Serialize)]
662#[serde(rename_all = "camelCase")]
663pub struct GeminiTtsGenerationConfig {
664    /// Response modalities (AUDIO)
665    pub response_modalities: Vec<String>,
666    /// Speech configuration
667    pub speech_config: GeminiSpeechConfig,
668}
669
670/// Gemini speech configuration.
671#[derive(Debug, Serialize)]
672#[serde(rename_all = "camelCase")]
673pub struct GeminiSpeechConfig {
674    /// Voice configuration
675    pub voice_config: GeminiVoiceConfig,
676}
677
678/// Gemini voice configuration.
679#[derive(Debug, Serialize)]
680#[serde(rename_all = "camelCase")]
681pub struct GeminiVoiceConfig {
682    /// Prebuilt voice configuration
683    pub prebuilt_voice_config: GeminiPrebuiltVoiceConfig,
684}
685
686/// Gemini prebuilt voice configuration.
687#[derive(Debug, Serialize)]
688#[serde(rename_all = "camelCase")]
689pub struct GeminiPrebuiltVoiceConfig {
690    /// Voice name
691    pub voice_name: String,
692}
693
694/// Gemini API response.
695#[derive(Debug, Deserialize)]
696#[serde(rename_all = "camelCase")]
697pub struct GeminiResponse {
698    /// Response candidates
699    #[serde(default)]
700    pub candidates: Vec<GeminiCandidate>,
701}
702
703/// Gemini response candidate.
704#[derive(Debug, Deserialize)]
705#[serde(rename_all = "camelCase")]
706pub struct GeminiCandidate {
707    /// Content
708    pub content: Option<GeminiResponseContent>,
709}
710
711/// Gemini response content.
712#[derive(Debug, Deserialize)]
713pub struct GeminiResponseContent {
714    /// Content parts
715    pub parts: Vec<GeminiResponsePart>,
716}
717
718/// Gemini response part.
719#[derive(Debug, Deserialize)]
720#[serde(untagged)]
721pub enum GeminiResponsePart {
722    /// Inline data (image or audio)
723    InlineData {
724        #[serde(rename = "inlineData")]
725        inline_data: GeminiInlineData,
726    },
727    /// Text content
728    Text { text: String },
729}
730
731/// Gemini inline data (base64 encoded).
732#[derive(Debug, Deserialize)]
733#[serde(rename_all = "camelCase")]
734pub struct GeminiInlineData {
735    /// MIME type
736    pub mime_type: String,
737    /// Base64-encoded data
738    pub data: String,
739}
740
741// =============================================================================
742// Result Types
743// =============================================================================
744
/// An image as returned by the API, still base64-encoded.
#[derive(Clone, Debug)]
pub struct GeneratedImage {
    /// Image bytes, base64-encoded.
    pub data: String,
    /// MIME type reported alongside the data.
    pub mime_type: String,
}
753
/// Audio as returned by the API, still base64-encoded.
#[derive(Clone, Debug)]
pub struct GeneratedAudio {
    /// Audio bytes, base64-encoded.
    pub data: String,
    /// MIME type reported alongside the data.
    pub mime_type: String,
}
762
763/// Result of image generation.
764#[derive(Debug)]
765pub enum ImageGenerateResult {
766    /// Base64-encoded image data (when no output specified)
767    Base64(GeneratedImage),
768    /// Local file path (when output_file specified)
769    LocalFile(String),
770}
771
772/// Result of TTS synthesis.
773#[derive(Debug)]
774pub enum TtsResult {
775    /// Base64-encoded audio data (when no output specified)
776    Base64(GeneratedAudio),
777    /// Local file path (when output_file specified)
778    LocalFile(String),
779}
780
781/// Voice information.
782#[derive(Debug, Clone, Serialize)]
783pub struct VoiceInfo {
784    /// Voice name
785    pub name: String,
786    /// Voice description
787    pub description: String,
788}
789
790/// Language code information.
791#[derive(Debug, Clone, Serialize)]
792pub struct LanguageCodeInfo {
793    /// Language code (e.g., "en-US")
794    pub code: String,
795    /// Language name (e.g., "English (US)")
796    pub name: String,
797}
798
#[cfg(test)]
mod tests {
    use super::*;

    /// Image params with the default model and no output file.
    fn image_params(prompt: &str) -> MultimodalImageParams {
        MultimodalImageParams {
            prompt: prompt.to_string(),
            model: DEFAULT_IMAGE_MODEL.to_string(),
            output_file: None,
        }
    }

    /// TTS params with the default model and no output file.
    fn tts_params(text: &str, voice: Option<&str>, style: Option<&str>) -> MultimodalTtsParams {
        MultimodalTtsParams {
            text: text.to_string(),
            voice: voice.map(String::from),
            style: style.map(String::from),
            model: DEFAULT_TTS_MODEL.to_string(),
            output_file: None,
        }
    }

    /// Asserts that validation failed and reported `field` among its errors.
    fn assert_rejects(outcome: Result<(), Vec<ValidationError>>, field: &str) {
        assert!(outcome.is_err());
        let errors = outcome.unwrap_err();
        assert!(errors.iter().any(|e| e.field == field));
    }

    #[test]
    fn test_default_image_params() {
        let params: MultimodalImageParams =
            serde_json::from_str(r#"{"prompt": "A cat"}"#).unwrap();
        assert_eq!(params.model, DEFAULT_IMAGE_MODEL);
        assert!(params.output_file.is_none());
    }

    #[test]
    fn test_valid_image_params() {
        assert!(image_params("A beautiful sunset").validate().is_ok());
    }

    #[test]
    fn test_empty_prompt_image() {
        assert_rejects(image_params("   ").validate(), "prompt");
    }

    #[test]
    fn test_default_tts_params() {
        let params: MultimodalTtsParams =
            serde_json::from_str(r#"{"text": "Hello world"}"#).unwrap();
        assert_eq!(params.model, DEFAULT_TTS_MODEL);
        assert!(params.voice.is_none());
        assert!(params.style.is_none());
        assert!(params.output_file.is_none());
    }

    #[test]
    fn test_valid_tts_params() {
        let params = tts_params("Hello world", Some("Kore"), Some("cheerful"));
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_empty_text_tts() {
        assert_rejects(tts_params("   ", None, None).validate(), "text");
    }

    #[test]
    fn test_invalid_voice() {
        let params = tts_params("Hello", Some("InvalidVoice"), None);
        assert_rejects(params.validate(), "voice");
    }

    #[test]
    fn test_invalid_style() {
        let params = tts_params("Hello", None, Some("invalid_style"));
        assert_rejects(params.validate(), "style");
    }

    #[test]
    fn test_get_voice_default() {
        assert_eq!(tts_params("Hello", None, None).get_voice(), DEFAULT_VOICE);
    }

    #[test]
    fn test_get_voice_custom() {
        assert_eq!(tts_params("Hello", Some("Puck"), None).get_voice(), "Puck");
    }

    #[test]
    fn test_all_valid_voices() {
        for voice in AVAILABLE_VOICES {
            let params = tts_params("Hello", Some(voice), None);
            assert!(
                params.validate().is_ok(),
                "Voice {} should be valid",
                voice
            );
        }
    }

    #[test]
    fn test_all_valid_styles() {
        for style in AVAILABLE_STYLES {
            let params = tts_params("Hello", None, Some(style));
            assert!(
                params.validate().is_ok(),
                "Style {} should be valid",
                style
            );
        }
    }

    #[test]
    fn test_serialization_roundtrip_image() {
        let params = MultimodalImageParams {
            model: "custom-model".to_string(),
            output_file: Some("/tmp/output.png".to_string()),
            ..image_params("A cat")
        };

        let json = serde_json::to_string(&params).unwrap();
        let deserialized: MultimodalImageParams = serde_json::from_str(&json).unwrap();

        assert_eq!(params.prompt, deserialized.prompt);
        assert_eq!(params.model, deserialized.model);
        assert_eq!(params.output_file, deserialized.output_file);
    }

    #[test]
    fn test_serialization_roundtrip_tts() {
        let params = MultimodalTtsParams {
            model: "custom-model".to_string(),
            output_file: Some("/tmp/output.wav".to_string()),
            ..tts_params("Hello world", Some("Kore"), Some("cheerful"))
        };

        let json = serde_json::to_string(&params).unwrap();
        let deserialized: MultimodalTtsParams = serde_json::from_str(&json).unwrap();

        assert_eq!(params.text, deserialized.text);
        assert_eq!(params.voice, deserialized.voice);
        assert_eq!(params.style, deserialized.style);
        assert_eq!(params.model, deserialized.model);
        assert_eq!(params.output_file, deserialized.output_file);
    }
}