Skip to main content

adk_rust_mcp_image/
handler.rs

//! Image generation handler for the MCP Image server.
//!
//! This module provides the `ImageHandler` struct and parameter types for
//! text-to-image generation and image upscaling using Google's Vertex AI Imagen API.
5
6use adk_rust_mcp_common::auth::AuthProvider;
7use adk_rust_mcp_common::config::Config;
8use adk_rust_mcp_common::error::Error;
9use adk_rust_mcp_common::gcs::{GcsClient, GcsUri};
10use adk_rust_mcp_common::models::{ImagenModel, ModelRegistry, IMAGEN_MODELS};
11use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize};
14use std::path::Path;
15use tracing::{debug, info, instrument};
16
/// Common aspect ratios for image generation.
///
/// `ImageGenerateParams::validate` checks against this list only when the
/// requested model is not known to the model registry.
pub const VALID_ASPECT_RATIOS: &[&str] = &[
    "1:1",
    "3:4",
    "4:3",
    "9:16",
    "16:9",
];

/// Model identifier used when a request does not name one.
pub const DEFAULT_MODEL: &str = "imagen-3.0-generate-002";

/// Smallest accepted value (inclusive) for `number_of_images`.
pub const MIN_NUMBER_OF_IMAGES: u8 = 1;

/// Largest accepted value (inclusive) for `number_of_images`.
pub const MAX_NUMBER_OF_IMAGES: u8 = 4;
28
29/// Text-to-image generation parameters.
30///
31/// These parameters control the image generation process via the Vertex AI Imagen API.
32#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
33pub struct ImageGenerateParams {
34    /// Text prompt describing the image to generate.
35    /// Maximum length depends on the model (480 chars for Imagen 3, 2000 for Imagen 4).
36    pub prompt: String,
37
38    /// Negative prompt - what to avoid in the generated image.
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub negative_prompt: Option<String>,
41
42    /// Model to use for generation.
43    /// Defaults to "imagen-4.0-generate-preview-05-20".
44    #[serde(default = "default_model")]
45    pub model: String,
46
47    /// Aspect ratio for the generated image.
48    /// Valid values: "1:1", "3:4", "4:3", "9:16", "16:9".
49    #[serde(default = "default_aspect_ratio")]
50    pub aspect_ratio: String,
51
52    /// Number of images to generate (1-4).
53    #[serde(default = "default_number_of_images")]
54    pub number_of_images: u8,
55
56    /// Random seed for reproducible generation.
57    #[serde(default, skip_serializing_if = "Option::is_none")]
58    pub seed: Option<i64>,
59
60    /// Output file path for saving the image locally.
61    /// If not specified and output_uri is not specified, returns base64-encoded data.
62    #[serde(default, skip_serializing_if = "Option::is_none")]
63    pub output_file: Option<String>,
64
65    /// Output storage URI (e.g., gs://bucket/path).
66    /// If specified, uploads the image to the storage backend.
67    #[serde(default, skip_serializing_if = "Option::is_none")]
68    pub output_uri: Option<String>,
69}
70
71fn default_model() -> String {
72    DEFAULT_MODEL.to_string()
73}
74
/// Serde default for `ImageGenerateParams::aspect_ratio`: a square `"1:1"` image.
fn default_aspect_ratio() -> String {
    String::from("1:1")
}
78
/// Serde default for `ImageGenerateParams::number_of_images`: one image.
fn default_number_of_images() -> u8 {
    1u8
}
82
/// Upscale factors accepted by `ImageUpscaleParams::validate`.
pub const VALID_UPSCALE_FACTORS: &[&str] = &[
    "x2",
    "x4",
];

/// Model identifier whose `:predict` endpoint serves every upscale request.
pub const UPSCALE_MODEL: &str = "imagen-4.0-upscale-preview";
88
89/// Image upscaling parameters.
90///
91/// These parameters control the image upscaling process via the Vertex AI Imagen Upscale API.
92#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
93pub struct ImageUpscaleParams {
94    /// Source image to upscale.
95    /// Can be base64 data, local file path, or GCS URI.
96    pub image: String,
97
98    /// Upscale factor: "x2" or "x4".
99    #[serde(default = "default_upscale_factor")]
100    pub upscale_factor: String,
101
102    /// Output file path for saving the upscaled image locally.
103    /// If not specified and output_uri is not specified, returns base64-encoded data.
104    #[serde(default, skip_serializing_if = "Option::is_none")]
105    pub output_file: Option<String>,
106
107    /// Output storage URI (e.g., gs://bucket/path).
108    /// If specified, uploads the upscaled image to the storage backend.
109    #[serde(default, skip_serializing_if = "Option::is_none")]
110    pub output_uri: Option<String>,
111}
112
/// Serde default for `ImageUpscaleParams::upscale_factor`: the `"x2"` factor.
fn default_upscale_factor() -> String {
    String::from("x2")
}
116
117impl ImageUpscaleParams {
118    /// Validate the upscale parameters.
119    pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
120        let mut errors = Vec::new();
121
122        // Validate image is not empty
123        if self.image.trim().is_empty() {
124            errors.push(ValidationError {
125                field: "image".to_string(),
126                message: "Image cannot be empty".to_string(),
127            });
128        }
129
130        // Validate upscale factor
131        if !VALID_UPSCALE_FACTORS.contains(&self.upscale_factor.as_str()) {
132            errors.push(ValidationError {
133                field: "upscale_factor".to_string(),
134                message: format!(
135                    "Invalid upscale factor '{}'. Valid options: {}",
136                    self.upscale_factor,
137                    VALID_UPSCALE_FACTORS.join(", ")
138                ),
139            });
140        }
141
142        if errors.is_empty() {
143            Ok(())
144        } else {
145            Err(errors)
146        }
147    }
148}
149
/// A single failed parameter check, produced by the `validate` methods of the
/// request parameter types in this module.
#[derive(Clone, Debug)]
pub struct ValidationError {
    /// Name of the rejected parameter (for example `"prompt"`).
    pub field: String,
    /// Human-readable reason the value was rejected.
    pub message: String,
}
158
159impl std::fmt::Display for ValidationError {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        write!(f, "{}: {}", self.field, self.message)
162    }
163}
164
165impl ImageGenerateParams {
166    /// Validate the parameters against the model constraints.
167    ///
168    /// # Returns
169    /// - `Ok(())` if all parameters are valid
170    /// - `Err(Vec<ValidationError>)` with all validation errors
171    pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
172        let mut errors = Vec::new();
173
174        // Resolve the model to get constraints
175        let model = ModelRegistry::resolve_imagen(&self.model);
176
177        // Validate model exists
178        if model.is_none() {
179            errors.push(ValidationError {
180                field: "model".to_string(),
181                message: format!(
182                    "Unknown model '{}'. Valid models: {}",
183                    self.model,
184                    IMAGEN_MODELS
185                        .iter()
186                        .map(|m| m.id)
187                        .collect::<Vec<_>>()
188                        .join(", ")
189                ),
190            });
191        }
192
193        // Validate prompt length (if model is known)
194        if let Some(model) = model {
195            if self.prompt.len() > model.max_prompt_length {
196                errors.push(ValidationError {
197                    field: "prompt".to_string(),
198                    message: format!(
199                        "Prompt length {} exceeds maximum {} for model {}",
200                        self.prompt.len(),
201                        model.max_prompt_length,
202                        model.id
203                    ),
204                });
205            }
206
207            // Validate aspect ratio against model's supported ratios
208            if !model.supported_aspect_ratios.contains(&self.aspect_ratio.as_str()) {
209                errors.push(ValidationError {
210                    field: "aspect_ratio".to_string(),
211                    message: format!(
212                        "Invalid aspect ratio '{}'. Valid options for {}: {}",
213                        self.aspect_ratio,
214                        model.id,
215                        model.supported_aspect_ratios.join(", ")
216                    ),
217                });
218            }
219        } else {
220            // If model is unknown, validate against common aspect ratios
221            if !VALID_ASPECT_RATIOS.contains(&self.aspect_ratio.as_str()) {
222                errors.push(ValidationError {
223                    field: "aspect_ratio".to_string(),
224                    message: format!(
225                        "Invalid aspect ratio '{}'. Valid options: {}",
226                        self.aspect_ratio,
227                        VALID_ASPECT_RATIOS.join(", ")
228                    ),
229                });
230            }
231        }
232
233        // Validate number_of_images range
234        if self.number_of_images < MIN_NUMBER_OF_IMAGES
235            || self.number_of_images > MAX_NUMBER_OF_IMAGES
236        {
237            errors.push(ValidationError {
238                field: "number_of_images".to_string(),
239                message: format!(
240                    "number_of_images must be between {} and {}, got {}",
241                    MIN_NUMBER_OF_IMAGES, MAX_NUMBER_OF_IMAGES, self.number_of_images
242                ),
243            });
244        }
245
246        // Validate prompt is not empty
247        if self.prompt.trim().is_empty() {
248            errors.push(ValidationError {
249                field: "prompt".to_string(),
250                message: "Prompt cannot be empty".to_string(),
251            });
252        }
253
254        if errors.is_empty() {
255            Ok(())
256        } else {
257            Err(errors)
258        }
259    }
260
261    /// Get the resolved model definition.
262    pub fn get_model(&self) -> Option<&'static ImagenModel> {
263        ModelRegistry::resolve_imagen(&self.model)
264    }
265}
266
267/// Image generation handler.
268///
269/// Handles image generation requests using the Vertex AI Imagen API.
270pub struct ImageHandler {
271    /// Application configuration.
272    pub config: Config,
273    /// GCS client for storage operations.
274    pub gcs: GcsClient,
275    /// HTTP client for API requests.
276    pub http: reqwest::Client,
277    /// Authentication provider.
278    pub auth: AuthProvider,
279}
280
281impl ImageHandler {
282    /// Create a new ImageHandler with the given configuration.
283    ///
284    /// # Errors
285    /// Returns an error if GCS client or auth provider initialization fails.
286    #[instrument(level = "debug", name = "image_handler_new", skip_all)]
287    pub async fn new(config: Config) -> Result<Self, Error> {
288        debug!("Initializing ImageHandler");
289
290        let auth = AuthProvider::new().await?;
291        let gcs = GcsClient::with_auth(AuthProvider::new().await?);
292        let http = reqwest::Client::new();
293
294        Ok(Self {
295            config,
296            gcs,
297            http,
298            auth,
299        })
300    }
301
302    /// Create a new ImageHandler with provided dependencies (for testing).
303    #[cfg(test)]
304    pub fn with_deps(config: Config, gcs: GcsClient, http: reqwest::Client, auth: AuthProvider) -> Self {
305        Self {
306            config,
307            gcs,
308            http,
309            auth,
310        }
311    }
312
313    /// Get the Vertex AI Imagen API endpoint for the given model.
314    pub fn get_endpoint(&self, model: &str) -> String {
315        format!(
316            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
317            self.config.location,
318            self.config.project_id,
319            self.config.location,
320            model
321        )
322    }
323
324    /// Generate images from a text prompt.
325    ///
326    /// # Arguments
327    /// * `params` - Image generation parameters
328    ///
329    /// # Returns
330    /// * `Ok(ImageGenerateResult)` - Generated images with their data or paths
331    /// * `Err(Error)` - If validation fails, API call fails, or output handling fails
332    #[instrument(level = "info", name = "generate_image", skip(self, params), fields(model = %params.model, aspect_ratio = %params.aspect_ratio))]
333    pub async fn generate_image(&self, params: ImageGenerateParams) -> Result<ImageGenerateResult, Error> {
334        // Validate parameters
335        params.validate().map_err(|errors| {
336            let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
337            Error::validation(messages.join("; "))
338        })?;
339
340        // Resolve the model to get the canonical ID
341        let model = params.get_model().ok_or_else(|| {
342            Error::validation(format!("Unknown model: {}", params.model))
343        })?;
344
345        info!(model_id = model.id, "Generating image with Imagen API");
346
347        // Build the API request
348        let request = ImagenRequest {
349            instances: vec![ImagenInstance {
350                prompt: params.prompt.clone(),
351                negative_prompt: params.negative_prompt.clone(),
352            }],
353            parameters: ImagenParameters {
354                sample_count: params.number_of_images,
355                aspect_ratio: params.aspect_ratio.clone(),
356                seed: params.seed,
357            },
358        };
359
360        // Get auth token
361        let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
362
363        // Make API request
364        let endpoint = self.get_endpoint(model.id);
365        debug!(endpoint = %endpoint, "Calling Imagen API");
366
367        let response = self.http
368            .post(&endpoint)
369            .header("Authorization", format!("Bearer {}", token))
370            .header("Content-Type", "application/json")
371            .json(&request)
372            .send()
373            .await
374            .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
375
376        let status = response.status();
377        if !status.is_success() {
378            let body = response.text().await.unwrap_or_default();
379            return Err(Error::api(&endpoint, status.as_u16(), body));
380        }
381
382        // Parse response
383        let api_response: ImagenResponse = response.json().await.map_err(|e| {
384            Error::api(&endpoint, status.as_u16(), format!("Failed to parse response: {}", e))
385        })?;
386
387        // Extract images from response
388        let images: Vec<GeneratedImage> = api_response
389            .predictions
390            .into_iter()
391            .filter_map(|p| {
392                p.bytes_base64_encoded.map(|data| GeneratedImage {
393                    data,
394                    mime_type: p.mime_type.unwrap_or_else(|| "image/png".to_string()),
395                })
396            })
397            .collect();
398
399        if images.is_empty() {
400            return Err(Error::api(&endpoint, 200, "No images returned from API"));
401        }
402
403        info!(count = images.len(), "Received images from API");
404
405        // Handle output based on params
406        self.handle_output(images, &params).await
407    }
408
409    /// Handle output of generated images based on params.
410    async fn handle_output(
411        &self,
412        images: Vec<GeneratedImage>,
413        params: &ImageGenerateParams,
414    ) -> Result<ImageGenerateResult, Error> {
415        // If output_uri is specified, upload to storage
416        if let Some(output_uri) = &params.output_uri {
417            return self.upload_to_storage(images, output_uri).await;
418        }
419
420        // If output_file is specified, save to local file
421        if let Some(output_file) = &params.output_file {
422            return self.save_to_file(images, output_file).await;
423        }
424
425        // Otherwise, return base64-encoded data
426        Ok(ImageGenerateResult::Base64(images))
427    }
428
429    /// Upload images to cloud storage.
430    async fn upload_to_storage(
431        &self,
432        images: Vec<GeneratedImage>,
433        output_uri: &str,
434    ) -> Result<ImageGenerateResult, Error> {
435        let mut uris = Vec::new();
436
437        for (i, image) in images.iter().enumerate() {
438            // Decode base64 data
439            let data = BASE64.decode(&image.data).map_err(|e| {
440                Error::validation(format!("Invalid base64 data: {}", e))
441            })?;
442
443            // Determine the URI for this image
444            let uri = if images.len() == 1 {
445                output_uri.to_string()
446            } else {
447                // Add index suffix for multiple images
448                // Handle GCS URIs properly - don't use Path which treats gs:// as filesystem path
449                Self::add_index_suffix_to_uri(output_uri, i, "image", "png")
450            };
451
452            // Parse GCS URI and upload
453            let gcs_uri = GcsUri::parse(&uri)?;
454            self.gcs.upload(&gcs_uri, &data, &image.mime_type).await?;
455            uris.push(uri);
456        }
457
458        info!(count = uris.len(), "Uploaded images to storage");
459        Ok(ImageGenerateResult::StorageUris(uris))
460    }
461
462    /// Add an index suffix to a URI or path for multi-output scenarios.
463    /// Handles both GCS URIs (gs://bucket/path) and local paths correctly.
464    fn add_index_suffix_to_uri(uri: &str, index: usize, default_stem: &str, default_ext: &str) -> String {
465        // For GCS URIs, extract the path portion after gs://bucket/
466        if let Some(stripped) = uri.strip_prefix("gs://") {
467            if let Some(slash_pos) = stripped.find('/') {
468                let bucket = &stripped[..slash_pos];
469                let object_path = &stripped[slash_pos + 1..];
470                
471                // Find the last component (filename)
472                let (dir, filename) = if let Some(last_slash) = object_path.rfind('/') {
473                    (&object_path[..last_slash], &object_path[last_slash + 1..])
474                } else {
475                    ("", object_path)
476                };
477                
478                // Split filename into stem and extension
479                let (stem, ext) = if let Some(dot_pos) = filename.rfind('.') {
480                    (&filename[..dot_pos], &filename[dot_pos + 1..])
481                } else {
482                    (filename, default_ext)
483                };
484                
485                let stem = if stem.is_empty() { default_stem } else { stem };
486                
487                if dir.is_empty() {
488                    format!("gs://{}/{}_{}.{}", bucket, stem, index, ext)
489                } else {
490                    format!("gs://{}/{}/{}_{}.{}", bucket, dir, stem, index, ext)
491                }
492            } else {
493                // Malformed GCS URI (no path after bucket), just append index
494                format!("{}/{}_{}.{}", uri, default_stem, index, default_ext)
495            }
496        } else {
497            // Local filesystem path - use Path
498            let path = Path::new(uri);
499            let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or(default_stem);
500            let ext = path.extension().and_then(|s| s.to_str()).unwrap_or(default_ext);
501            let parent = path.parent().and_then(|p| p.to_str()).unwrap_or("");
502            if parent.is_empty() {
503                format!("{}_{}.{}", stem, index, ext)
504            } else {
505                format!("{}/{}_{}.{}", parent, stem, index, ext)
506            }
507        }
508    }
509
510    /// Save images to local files.
511    async fn save_to_file(
512        &self,
513        images: Vec<GeneratedImage>,
514        output_file: &str,
515    ) -> Result<ImageGenerateResult, Error> {
516        let mut paths = Vec::new();
517
518        for (i, image) in images.iter().enumerate() {
519            // Decode base64 data
520            let data = BASE64.decode(&image.data).map_err(|e| {
521                Error::validation(format!("Invalid base64 data: {}", e))
522            })?;
523
524            // Determine the path for this image
525            let path = if images.len() == 1 {
526                output_file.to_string()
527            } else {
528                // Add index suffix for multiple images
529                let p = Path::new(output_file);
530                let stem = p.file_stem().and_then(|s| s.to_str()).unwrap_or("image");
531                let ext = p.extension().and_then(|s| s.to_str()).unwrap_or("png");
532                let parent = p.parent().and_then(|p| p.to_str()).unwrap_or("");
533                if parent.is_empty() {
534                    format!("{}_{}.{}", stem, i, ext)
535                } else {
536                    format!("{}/{}_{}.{}", parent, stem, i, ext)
537                }
538            };
539
540            // Ensure parent directory exists
541            if let Some(parent) = Path::new(&path).parent() {
542                if !parent.as_os_str().is_empty() {
543                    tokio::fs::create_dir_all(parent).await?;
544                }
545            }
546
547            // Write to file
548            tokio::fs::write(&path, &data).await?;
549            paths.push(path);
550        }
551
552        info!(count = paths.len(), "Saved images to local files");
553        Ok(ImageGenerateResult::LocalFiles(paths))
554    }
555
556    /// Upscale an image using the Imagen Upscale API.
557    ///
558    /// # Arguments
559    /// * `params` - Image upscale parameters
560    ///
561    /// # Returns
562    /// * `Ok(ImageUpscaleResult)` - Upscaled image with data or path
563    /// * `Err(Error)` - If validation fails, API call fails, or output handling fails
564    #[instrument(level = "info", name = "upscale_image", skip(self, params), fields(upscale_factor = %params.upscale_factor))]
565    pub async fn upscale_image(&self, params: ImageUpscaleParams) -> Result<ImageUpscaleResult, Error> {
566        // Validate parameters
567        params.validate().map_err(|errors| {
568            let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
569            Error::validation(messages.join("; "))
570        })?;
571
572        info!(upscale_factor = %params.upscale_factor, "Upscaling image with Imagen Upscale API");
573
574        // Resolve the image input
575        let image_data = self.resolve_image_input(&params.image).await?;
576
577        // Build the API request
578        let request = UpscaleRequest {
579            instances: vec![UpscaleInstance {
580                image: UpscaleImageInput {
581                    bytes_base64_encoded: image_data,
582                },
583            }],
584            parameters: UpscaleParameters {
585                upscale_factor: params.upscale_factor.clone(),
586                output_mime_type: "image/png".to_string(),
587            },
588        };
589
590        // Get auth token
591        let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
592
593        // Make API request
594        let endpoint = self.get_upscale_endpoint();
595        debug!(endpoint = %endpoint, "Calling Imagen Upscale API");
596
597        let response = self.http
598            .post(&endpoint)
599            .header("Authorization", format!("Bearer {}", token))
600            .header("Content-Type", "application/json")
601            .json(&request)
602            .send()
603            .await
604            .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
605
606        let status = response.status();
607        if !status.is_success() {
608            let body = response.text().await.unwrap_or_default();
609            return Err(Error::api(&endpoint, status.as_u16(), body));
610        }
611
612        // Parse response
613        let api_response: UpscaleResponse = response.json().await.map_err(|e| {
614            Error::api(&endpoint, status.as_u16(), format!("Failed to parse response: {}", e))
615        })?;
616
617        // Extract upscaled image from response
618        let prediction = api_response.predictions.into_iter().next()
619            .ok_or_else(|| Error::api(&endpoint, 200, "No image returned from API"))?;
620
621        let image_data = prediction.bytes_base64_encoded
622            .ok_or_else(|| Error::api(&endpoint, 200, "No image data in response"))?;
623
624        let image = GeneratedImage {
625            data: image_data,
626            mime_type: prediction.mime_type.unwrap_or_else(|| "image/png".to_string()),
627        };
628
629        info!("Received upscaled image from API");
630
631        // Handle output based on params
632        self.handle_upscale_output(image, &params).await
633    }
634
635    /// Get the Vertex AI Imagen Upscale API endpoint.
636    pub fn get_upscale_endpoint(&self) -> String {
637        format!(
638            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
639            self.config.location,
640            self.config.project_id,
641            self.config.location,
642            UPSCALE_MODEL
643        )
644    }
645
646    /// Resolve image input to base64 data.
647    async fn resolve_image_input(&self, image: &str) -> Result<String, Error> {
648        // Check if it's a GCS URI first (explicit protocol)
649        if image.starts_with("gs://") {
650            let uri = GcsUri::parse(image)?;
651            let data = self.gcs.download(&uri).await?;
652            return Ok(BASE64.encode(&data));
653        }
654
655        // Check if it looks like a file path
656        let looks_like_path = image.starts_with('/')
657            || image.starts_with("./")
658            || image.starts_with("../")
659            || image.starts_with("~/")
660            || (image.len() < 500 && image.contains('/'));
661
662        if looks_like_path {
663            let path = Path::new(image);
664            if !path.exists() {
665                return Err(Error::validation(format!("Image file not found: {}", image)));
666            }
667            let data = tokio::fs::read(path).await?;
668            return Ok(BASE64.encode(&data));
669        }
670
671        // Try to validate as base64
672        if image.len() > 100 {
673            if BASE64.decode(image).is_ok() {
674                return Ok(image.to_string());
675            }
676        }
677
678        // Last resort: try as file path
679        let path = Path::new(image);
680        if path.exists() {
681            let data = tokio::fs::read(path).await?;
682            return Ok(BASE64.encode(&data));
683        }
684
685        // If nothing worked and it's long, assume it's base64
686        if image.len() > 100 {
687            return Ok(image.to_string());
688        }
689
690        Err(Error::validation(format!(
691            "Image input is not a valid file path, GCS URI, or base64 data"
692        )))
693    }
694
695    /// Handle output of upscaled image based on params.
696    async fn handle_upscale_output(
697        &self,
698        image: GeneratedImage,
699        params: &ImageUpscaleParams,
700    ) -> Result<ImageUpscaleResult, Error> {
701        // If output_uri is specified, upload to storage
702        if let Some(output_uri) = &params.output_uri {
703            let data = BASE64.decode(&image.data).map_err(|e| {
704                Error::validation(format!("Invalid base64 data: {}", e))
705            })?;
706            let gcs_uri = GcsUri::parse(output_uri)?;
707            self.gcs.upload(&gcs_uri, &data, &image.mime_type).await?;
708            info!(uri = %output_uri, "Uploaded upscaled image to storage");
709            return Ok(ImageUpscaleResult::StorageUri(output_uri.clone()));
710        }
711
712        // If output_file is specified, save to local file
713        if let Some(output_file) = &params.output_file {
714            let data = BASE64.decode(&image.data).map_err(|e| {
715                Error::validation(format!("Invalid base64 data: {}", e))
716            })?;
717
718            // Ensure parent directory exists
719            if let Some(parent) = Path::new(output_file).parent() {
720                if !parent.as_os_str().is_empty() {
721                    tokio::fs::create_dir_all(parent).await?;
722                }
723            }
724
725            tokio::fs::write(output_file, &data).await?;
726            info!(path = %output_file, "Saved upscaled image to local file");
727            return Ok(ImageUpscaleResult::LocalFile(output_file.clone()));
728        }
729
730        // Otherwise, return base64-encoded data
731        Ok(ImageUpscaleResult::Base64(image))
732    }
733}
734
735// =============================================================================
736// API Request/Response Types
737// =============================================================================
738
739/// Vertex AI Imagen API request.
740#[derive(Debug, Serialize)]
741pub struct ImagenRequest {
742    /// Input instances (prompts)
743    pub instances: Vec<ImagenInstance>,
744    /// Generation parameters
745    pub parameters: ImagenParameters,
746}
747
748/// Imagen API instance (prompt).
749#[derive(Debug, Serialize)]
750#[serde(rename_all = "camelCase")]
751pub struct ImagenInstance {
752    /// Text prompt describing the image
753    pub prompt: String,
754    /// Negative prompt - what to avoid
755    #[serde(skip_serializing_if = "Option::is_none")]
756    pub negative_prompt: Option<String>,
757}
758
759/// Imagen API parameters.
760#[derive(Debug, Serialize)]
761#[serde(rename_all = "camelCase")]
762pub struct ImagenParameters {
763    /// Number of images to generate
764    pub sample_count: u8,
765    /// Aspect ratio
766    pub aspect_ratio: String,
767    /// Random seed for reproducibility
768    #[serde(skip_serializing_if = "Option::is_none")]
769    pub seed: Option<i64>,
770}
771
772/// Vertex AI Imagen API response.
773#[derive(Debug, Deserialize)]
774pub struct ImagenResponse {
775    /// Generated image predictions
776    pub predictions: Vec<ImagenPrediction>,
777}
778
779/// Imagen API prediction (generated image).
780#[derive(Debug, Deserialize)]
781#[serde(rename_all = "camelCase")]
782pub struct ImagenPrediction {
783    /// Base64-encoded image data
784    pub bytes_base64_encoded: Option<String>,
785    /// MIME type of the image
786    pub mime_type: Option<String>,
787}
788
789// =============================================================================
790// Upscale API Request/Response Types
791// =============================================================================
792
793/// Vertex AI Imagen Upscale API request.
794#[derive(Debug, Serialize)]
795pub struct UpscaleRequest {
796    /// Input instances (images to upscale)
797    pub instances: Vec<UpscaleInstance>,
798    /// Upscale parameters
799    pub parameters: UpscaleParameters,
800}
801
802/// Upscale API instance.
803#[derive(Debug, Serialize)]
804pub struct UpscaleInstance {
805    /// Source image to upscale
806    pub image: UpscaleImageInput,
807}
808
809/// Upscale image input.
810#[derive(Debug, Serialize)]
811#[serde(rename_all = "camelCase")]
812pub struct UpscaleImageInput {
813    /// Base64-encoded image data
814    pub bytes_base64_encoded: String,
815}
816
817/// Upscale API parameters.
818#[derive(Debug, Serialize)]
819#[serde(rename_all = "camelCase")]
820pub struct UpscaleParameters {
821    /// Upscale factor: "x2" or "x4"
822    pub upscale_factor: String,
823    /// Output MIME type
824    pub output_mime_type: String,
825}
826
827/// Vertex AI Imagen Upscale API response.
828#[derive(Debug, Deserialize)]
829pub struct UpscaleResponse {
830    /// Upscaled image predictions
831    pub predictions: Vec<UpscalePrediction>,
832}
833
834/// Upscale API prediction (upscaled image).
835#[derive(Debug, Deserialize)]
836#[serde(rename_all = "camelCase")]
837pub struct UpscalePrediction {
838    /// Base64-encoded image data
839    pub bytes_base64_encoded: Option<String>,
840    /// MIME type of the image
841    pub mime_type: Option<String>,
842}
843
844// =============================================================================
845// Result Types
846// =============================================================================
847
/// Generated image data.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly, e.g. with
/// `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedImage {
    /// Base64-encoded image data
    pub data: String,
    /// MIME type of the image
    pub mime_type: String,
}
856
857/// Result of image generation.
858#[derive(Debug)]
859pub enum ImageGenerateResult {
860    /// Base64-encoded image data (when no output specified)
861    Base64(Vec<GeneratedImage>),
862    /// Local file paths (when output_file specified)
863    LocalFiles(Vec<String>),
864    /// Storage URIs (when output_uri specified)
865    StorageUris(Vec<String>),
866}
867
868/// Result of image upscaling.
869#[derive(Debug)]
870pub enum ImageUpscaleResult {
871    /// Base64-encoded image data (when no output specified)
872    Base64(GeneratedImage),
873    /// Local file path (when output_file specified)
874    LocalFile(String),
875    /// Storage URI (when output_uri specified)
876    StorageUri(String),
877}
878
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a baseline `ImageGenerateParams` that passes validation.
    ///
    /// The baseline is prompt "A cat", `DEFAULT_MODEL`, aspect ratio "1:1", one
    /// image, and no negative prompt, seed or output target. Tests override
    /// only the fields they exercise via struct-update syntax.
    fn base_params() -> ImageGenerateParams {
        ImageGenerateParams {
            prompt: "A cat".to_string(),
            negative_prompt: None,
            model: DEFAULT_MODEL.to_string(),
            aspect_ratio: "1:1".to_string(),
            number_of_images: 1,
            seed: None,
            output_file: None,
            output_uri: None,
        }
    }

    /// Asserts that `params` fails validation with at least one error on `field`.
    fn assert_rejects_field(params: ImageGenerateParams, field: &str) {
        let errors = params
            .validate()
            .expect_err("params should have failed validation");
        assert!(
            errors.iter().any(|e| e.field == field),
            "expected a validation error on `{}`, got {:?}",
            field,
            errors
        );
    }

    #[test]
    fn test_default_params() {
        let params: ImageGenerateParams = serde_json::from_str(r#"{"prompt": "a cat"}"#).unwrap();
        assert_eq!(params.model, DEFAULT_MODEL);
        assert_eq!(params.aspect_ratio, "1:1");
        assert_eq!(params.number_of_images, 1);
        assert!(params.negative_prompt.is_none());
        assert!(params.seed.is_none());
        assert!(params.output_file.is_none());
        assert!(params.output_uri.is_none());
    }

    #[test]
    fn test_valid_params() {
        let params = ImageGenerateParams {
            prompt: "A beautiful sunset over mountains".to_string(),
            negative_prompt: Some("blurry, low quality".to_string()),
            model: "imagen-4".to_string(),
            aspect_ratio: "16:9".to_string(),
            number_of_images: 2,
            seed: Some(42),
            ..base_params()
        };

        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_invalid_number_of_images_zero() {
        let params = ImageGenerateParams { number_of_images: 0, ..base_params() };
        assert_rejects_field(params, "number_of_images");
    }

    #[test]
    fn test_invalid_number_of_images_too_high() {
        let params = ImageGenerateParams {
            number_of_images: MAX_NUMBER_OF_IMAGES + 1,
            ..base_params()
        };
        assert_rejects_field(params, "number_of_images");
    }

    #[test]
    fn test_invalid_aspect_ratio() {
        let params = ImageGenerateParams { aspect_ratio: "2:1".to_string(), ..base_params() };
        assert_rejects_field(params, "aspect_ratio");
    }

    #[test]
    fn test_invalid_model() {
        let params = ImageGenerateParams { model: "unknown-model".to_string(), ..base_params() };
        assert_rejects_field(params, "model");
    }

    #[test]
    fn test_empty_prompt() {
        // Whitespace-only prompts count as empty.
        let params = ImageGenerateParams { prompt: "   ".to_string(), ..base_params() };
        assert_rejects_field(params, "prompt");
    }

    #[test]
    fn test_prompt_too_long_imagen3() {
        // 500 characters exceeds the 480-character limit for Imagen 3.
        let params = ImageGenerateParams {
            prompt: "a".repeat(500),
            model: "imagen-3".to_string(),
            ..base_params()
        };

        let errors = params
            .validate()
            .expect_err("over-long prompt should be rejected");
        assert!(errors
            .iter()
            .any(|e| e.field == "prompt" && e.message.contains("exceeds")));
    }

    #[test]
    fn test_prompt_ok_imagen4() {
        // 500 characters is within the 2000-character limit for Imagen 4.
        let params = ImageGenerateParams {
            prompt: "a".repeat(500),
            model: "imagen-4".to_string(),
            ..base_params()
        };

        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_all_valid_aspect_ratios() {
        for &ratio in VALID_ASPECT_RATIOS {
            let params = ImageGenerateParams { aspect_ratio: ratio.to_string(), ..base_params() };
            assert!(params.validate().is_ok(), "Aspect ratio {} should be valid", ratio);
        }
    }

    #[test]
    fn test_all_valid_number_of_images() {
        for n in MIN_NUMBER_OF_IMAGES..=MAX_NUMBER_OF_IMAGES {
            let params = ImageGenerateParams { number_of_images: n, ..base_params() };
            assert!(params.validate().is_ok(), "number_of_images {} should be valid", n);
        }
    }

    #[test]
    fn test_get_model() {
        let params = ImageGenerateParams { model: "imagen-4".to_string(), ..base_params() };

        let model = params.get_model().expect("alias `imagen-4` should resolve");
        assert_eq!(model.id, "imagen-4.0-generate-preview-06-06");
    }

    #[test]
    fn test_serialization_roundtrip() {
        // Every optional field is populated so the roundtrip covers all of them.
        let params = ImageGenerateParams {
            prompt: "A cat".to_string(),
            negative_prompt: Some("blurry".to_string()),
            model: "imagen-4".to_string(),
            aspect_ratio: "16:9".to_string(),
            number_of_images: 2,
            seed: Some(42),
            output_file: Some("/tmp/output.png".to_string()),
            output_uri: Some("gs://bucket/output.png".to_string()),
        };

        let json = serde_json::to_string(&params).unwrap();
        let deserialized: ImageGenerateParams = serde_json::from_str(&json).unwrap();

        assert_eq!(params.prompt, deserialized.prompt);
        assert_eq!(params.negative_prompt, deserialized.negative_prompt);
        assert_eq!(params.model, deserialized.model);
        assert_eq!(params.aspect_ratio, deserialized.aspect_ratio);
        assert_eq!(params.number_of_images, deserialized.number_of_images);
        assert_eq!(params.seed, deserialized.seed);
        assert_eq!(params.output_file, deserialized.output_file);
        assert_eq!(params.output_uri, deserialized.output_uri);
    }

    // Tests for GCS URI handling (P1 fix)
    #[test]
    fn test_add_index_suffix_to_gcs_uri_simple() {
        let uri = "gs://bucket/output.png";
        let result = ImageHandler::add_index_suffix_to_uri(uri, 0, "image", "png");
        assert_eq!(result, "gs://bucket/output_0.png");
    }

    #[test]
    fn test_add_index_suffix_to_gcs_uri_with_path() {
        let uri = "gs://bucket/path/to/output.png";
        let result = ImageHandler::add_index_suffix_to_uri(uri, 1, "image", "png");
        assert_eq!(result, "gs://bucket/path/to/output_1.png");
    }

    #[test]
    fn test_add_index_suffix_to_gcs_uri_no_extension() {
        let uri = "gs://bucket/output";
        let result = ImageHandler::add_index_suffix_to_uri(uri, 2, "image", "png");
        assert_eq!(result, "gs://bucket/output_2.png");
    }

    #[test]
    fn test_add_index_suffix_to_local_path() {
        let path = "/tmp/output.png";
        let result = ImageHandler::add_index_suffix_to_uri(path, 0, "image", "png");
        assert_eq!(result, "/tmp/output_0.png");
    }

    #[test]
    fn test_add_index_suffix_to_local_path_no_dir() {
        let path = "output.png";
        let result = ImageHandler::add_index_suffix_to_uri(path, 1, "image", "png");
        assert_eq!(result, "output_1.png");
    }
}
1156
1157
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds params on `DEFAULT_MODEL` with no negative prompt, seed or output
    /// target. Only the fields the properties below exercise are varied.
    fn params_with(
        prompt: String,
        aspect_ratio: &str,
        number_of_images: u8,
    ) -> ImageGenerateParams {
        ImageGenerateParams {
            prompt,
            negative_prompt: None,
            model: DEFAULT_MODEL.to_string(),
            aspect_ratio: aspect_ratio.to_string(),
            number_of_images,
            seed: None,
            output_file: None,
            output_uri: None,
        }
    }

    // Feature: rust-mcp-genmedia, Property 8: Numeric Parameter Range Validation (number_of_images)
    // **Validates: Requirements 4.5, 4.6**
    //
    // For any numeric parameter with defined bounds (number_of_images 1-4),
    // values outside the valid range SHALL be rejected with a validation error.

    /// Strategy to generate valid number_of_images values (MIN..=MAX).
    fn valid_number_of_images_strategy() -> impl Strategy<Value = u8> {
        MIN_NUMBER_OF_IMAGES..=MAX_NUMBER_OF_IMAGES
    }

    /// Strategy to generate invalid number_of_images values (0 or > MAX).
    fn invalid_number_of_images_strategy() -> impl Strategy<Value = u8> {
        prop_oneof![
            Just(0u8),
            (MAX_NUMBER_OF_IMAGES + 1)..=u8::MAX,
        ]
    }

    /// Strategy to generate valid aspect ratios.
    ///
    /// Samples directly from `VALID_ASPECT_RATIOS`, so the property tests track
    /// the production list instead of a hand-copied duplicate that could drift.
    fn valid_aspect_ratio_strategy() -> impl Strategy<Value = &'static str> {
        prop::sample::select(VALID_ASPECT_RATIOS)
    }

    /// Strategy to generate invalid aspect ratios.
    fn invalid_aspect_ratio_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            Just("2:1".to_string()),
            Just("1:2".to_string()),
            Just("5:4".to_string()),
            Just("invalid".to_string()),
            Just("".to_string()),
            Just("16:10".to_string()),
            Just("21:9".to_string()),
            // Random "N:M" strings, excluding the ones that happen to be valid.
            "[0-9]+:[0-9]+".prop_filter("Must not be a valid ratio", |s| {
                !VALID_ASPECT_RATIOS.contains(&s.as_str())
            }),
        ]
    }

    /// Strategy to generate valid prompts (non-empty after trimming, at most 100 chars).
    fn valid_prompt_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z0-9 ]{1,100}"
            .prop_map(|s| s.trim().to_string())
            .prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    proptest! {
        /// Property 8: Valid number_of_images values (1-4) should pass validation
        #[test]
        fn valid_number_of_images_passes_validation(
            num in valid_number_of_images_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = params_with(prompt, "1:1", num);
            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "number_of_images {} should be valid, but got errors: {:?}",
                num,
                result.err()
            );
        }

        /// Property 8: Invalid number_of_images values (0 or > 4) should fail validation
        #[test]
        fn invalid_number_of_images_fails_validation(
            num in invalid_number_of_images_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = params_with(prompt, "1:1", num);
            let result = params.validate();
            prop_assert!(result.is_err(), "number_of_images {} should be invalid", num);

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "number_of_images"),
                "Should have a number_of_images validation error for value {}",
                num
            );
        }

        // Feature: rust-mcp-genmedia, Property 10: Aspect Ratio Validation
        // **Validates: Requirements 4.5, 4.6**
        //
        // For any aspect_ratio parameter value, it SHALL be one of the model's
        // supported aspect ratios. Invalid aspect ratios SHALL be rejected with
        // a validation error listing valid options.

        /// Property 10: Valid aspect ratios should pass validation
        #[test]
        fn valid_aspect_ratio_passes_validation(
            ratio in valid_aspect_ratio_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = params_with(prompt, ratio, 1);
            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "aspect_ratio '{}' should be valid, but got errors: {:?}",
                ratio,
                result.err()
            );
        }

        /// Property 10: Invalid aspect ratios should fail validation with descriptive error
        #[test]
        fn invalid_aspect_ratio_fails_validation(
            ratio in invalid_aspect_ratio_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = params_with(prompt, &ratio, 1);
            let result = params.validate();
            prop_assert!(result.is_err(), "aspect_ratio '{}' should be invalid", ratio);

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "aspect_ratio"),
                "Should have an aspect_ratio validation error for value '{}'",
                ratio
            );

            // Verify the error message lists valid options.
            let aspect_error = errors.iter().find(|e| e.field == "aspect_ratio").unwrap();
            prop_assert!(
                aspect_error.message.contains("Valid options"),
                "Error message should list valid options: {}",
                aspect_error.message
            );
        }

        /// Property: Combination of valid parameters should always pass validation
        #[test]
        fn valid_params_combination_passes(
            num in valid_number_of_images_strategy(),
            ratio in valid_aspect_ratio_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = params_with(prompt, ratio, num);
            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "Valid params (num={}, ratio='{}') should pass, but got: {:?}",
                num,
                ratio,
                result.err()
            );
        }
    }
}
1377
/// Unit tests for the Vertex AI request/response wire types, the result types,
/// and validation error reporting.
///
/// No network calls are made: these tests serialize and deserialize the API
/// types directly with `serde_json`.
#[cfg(test)]
mod api_tests {
    use super::*;

    /// Test that ImagenRequest serializes correctly for the API.
    #[test]
    fn test_imagen_request_serialization() {
        let request = ImagenRequest {
            instances: vec![ImagenInstance {
                prompt: "A beautiful sunset".to_string(),
                negative_prompt: Some("blurry".to_string()),
            }],
            parameters: ImagenParameters {
                sample_count: 2,
                aspect_ratio: "16:9".to_string(),
                seed: Some(42),
            },
        };

        let json = serde_json::to_value(&request).unwrap();

        // Verify structure and camelCase key names.
        assert!(json["instances"].is_array());
        assert_eq!(json["instances"][0]["prompt"], "A beautiful sunset");
        assert_eq!(json["instances"][0]["negativePrompt"], "blurry");
        assert_eq!(json["parameters"]["sampleCount"], 2);
        assert_eq!(json["parameters"]["aspectRatio"], "16:9");
        assert_eq!(json["parameters"]["seed"], 42);
    }

    /// Test that ImagenRequest serializes without optional fields when not provided.
    #[test]
    fn test_imagen_request_serialization_minimal() {
        let request = ImagenRequest {
            instances: vec![ImagenInstance {
                prompt: "A cat".to_string(),
                negative_prompt: None,
            }],
            parameters: ImagenParameters {
                sample_count: 1,
                aspect_ratio: "1:1".to_string(),
                seed: None,
            },
        };

        let json = serde_json::to_value(&request).unwrap();

        // Optional fields must be omitted entirely, not sent as null.
        assert!(json["instances"][0].get("negativePrompt").is_none());
        assert!(json["parameters"].get("seed").is_none());
    }

    /// Test that ImagenResponse deserializes correctly.
    #[test]
    fn test_imagen_response_deserialization() {
        let json = r#"{
            "predictions": [
                {
                    "bytesBase64Encoded": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
                    "mimeType": "image/png"
                }
            ]
        }"#;

        let response: ImagenResponse = serde_json::from_str(json).unwrap();

        assert_eq!(response.predictions.len(), 1);
        assert!(response.predictions[0].bytes_base64_encoded.is_some());
        assert_eq!(response.predictions[0].mime_type, Some("image/png".to_string()));
    }

    /// Test that ImagenResponse handles multiple predictions.
    #[test]
    fn test_imagen_response_multiple_predictions() {
        let json = r#"{
            "predictions": [
                {
                    "bytesBase64Encoded": "base64data1",
                    "mimeType": "image/png"
                },
                {
                    "bytesBase64Encoded": "base64data2",
                    "mimeType": "image/png"
                }
            ]
        }"#;

        let response: ImagenResponse = serde_json::from_str(json).unwrap();

        assert_eq!(response.predictions.len(), 2);
        assert_eq!(response.predictions[0].bytes_base64_encoded, Some("base64data1".to_string()));
        assert_eq!(response.predictions[1].bytes_base64_encoded, Some("base64data2".to_string()));
    }

    /// Test that ImagenResponse handles empty predictions gracefully.
    #[test]
    fn test_imagen_response_empty_predictions() {
        let json = r#"{"predictions": []}"#;

        let response: ImagenResponse = serde_json::from_str(json).unwrap();

        assert!(response.predictions.is_empty());
    }

    /// Test that ImagenResponse handles predictions without image data.
    #[test]
    fn test_imagen_response_no_image_data() {
        let json = r#"{
            "predictions": [
                {
                    "mimeType": "image/png"
                }
            ]
        }"#;

        let response: ImagenResponse = serde_json::from_str(json).unwrap();

        assert_eq!(response.predictions.len(), 1);
        assert!(response.predictions[0].bytes_base64_encoded.is_none());
    }

    /// Test that UpscaleRequest serializes with the expected camelCase keys.
    #[test]
    fn test_upscale_request_serialization() {
        let request = UpscaleRequest {
            instances: vec![UpscaleInstance {
                image: UpscaleImageInput {
                    bytes_base64_encoded: "aGVsbG8=".to_string(),
                },
            }],
            parameters: UpscaleParameters {
                upscale_factor: "x2".to_string(),
                output_mime_type: "image/png".to_string(),
            },
        };

        let json = serde_json::to_value(&request).unwrap();

        assert!(json["instances"].is_array());
        assert_eq!(json["instances"][0]["image"]["bytesBase64Encoded"], "aGVsbG8=");
        assert_eq!(json["parameters"]["upscaleFactor"], "x2");
        assert_eq!(json["parameters"]["outputMimeType"], "image/png");
    }

    /// Test that UpscaleResponse deserializes correctly.
    #[test]
    fn test_upscale_response_deserialization() {
        let json = r#"{
            "predictions": [
                {
                    "bytesBase64Encoded": "aGVsbG8=",
                    "mimeType": "image/png"
                }
            ]
        }"#;

        let response: UpscaleResponse = serde_json::from_str(json).unwrap();

        assert_eq!(response.predictions.len(), 1);
        assert_eq!(response.predictions[0].bytes_base64_encoded.as_deref(), Some("aGVsbG8="));
        assert_eq!(response.predictions[0].mime_type.as_deref(), Some("image/png"));
    }

    /// Pins the Vertex AI `:predict` endpoint URL template.
    ///
    /// NOTE(review): this rebuilds the template locally rather than calling the
    /// handler, because constructing an `ImageHandler` needs auth. It therefore
    /// cannot catch drift in the handler's own endpoint builder. Prefer
    /// exercising that builder directly if it can be made callable without
    /// credentials.
    #[test]
    fn test_get_endpoint() {
        let config = Config {
            project_id: "my-project".to_string(),
            location: "us-central1".to_string(),
            gcs_bucket: None,
            port: 8080,
            ..Default::default()
        };
        // Same model id that the `imagen-4` alias resolves to in `test_get_model`.
        let model_id = "imagen-4.0-generate-preview-06-06";

        let url = format!(
            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
            config.location, config.project_id, config.location, model_id
        );

        assert_eq!(
            url,
            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-preview-06-06:predict"
        );
    }

    /// Test GeneratedImage structure.
    #[test]
    fn test_generated_image() {
        let image = GeneratedImage {
            data: "base64encodeddata".to_string(),
            mime_type: "image/png".to_string(),
        };

        assert_eq!(image.data, "base64encodeddata");
        assert_eq!(image.mime_type, "image/png");
    }

    /// Test ImageGenerateResult variants.
    #[test]
    fn test_image_generate_result_base64() {
        let images = vec![
            GeneratedImage {
                data: "data1".to_string(),
                mime_type: "image/png".to_string(),
            },
            GeneratedImage {
                data: "data2".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
        ];

        let result = ImageGenerateResult::Base64(images);

        match result {
            ImageGenerateResult::Base64(imgs) => {
                assert_eq!(imgs.len(), 2);
                assert_eq!(imgs[0].data, "data1");
                assert_eq!(imgs[1].mime_type, "image/jpeg");
            }
            _ => panic!("Expected Base64 variant"),
        }
    }

    /// Test ImageGenerateResult LocalFiles variant.
    #[test]
    fn test_image_generate_result_local_files() {
        let paths = vec!["/tmp/image1.png".to_string(), "/tmp/image2.png".to_string()];
        let result = ImageGenerateResult::LocalFiles(paths);

        match result {
            ImageGenerateResult::LocalFiles(p) => {
                assert_eq!(p.len(), 2);
                assert!(p[0].contains("image1"));
            }
            _ => panic!("Expected LocalFiles variant"),
        }
    }

    /// Test ImageGenerateResult StorageUris variant.
    #[test]
    fn test_image_generate_result_storage_uris() {
        let uris = vec![
            "gs://bucket/image1.png".to_string(),
            "gs://bucket/image2.png".to_string(),
        ];
        let result = ImageGenerateResult::StorageUris(uris);

        match result {
            ImageGenerateResult::StorageUris(u) => {
                assert_eq!(u.len(), 2);
                assert!(u[0].starts_with("gs://"));
            }
            _ => panic!("Expected StorageUris variant"),
        }
    }

    /// Test validation error formatting.
    #[test]
    fn test_validation_error_display() {
        let error = ValidationError {
            field: "prompt".to_string(),
            message: "cannot be empty".to_string(),
        };

        let display = format!("{}", error);
        assert_eq!(display, "prompt: cannot be empty");
    }

    /// Test that validation collects multiple errors.
    #[test]
    fn test_validation_multiple_errors() {
        let params = ImageGenerateParams {
            prompt: "   ".to_string(),               // Empty prompt
            negative_prompt: None,
            model: "unknown-model".to_string(),      // Invalid model
            aspect_ratio: "invalid".to_string(),     // Invalid aspect ratio
            number_of_images: 10,                    // Out of range
            seed: None,
            output_file: None,
            output_uri: None,
        };

        let result = params.validate();
        assert!(result.is_err());

        let errors = result.unwrap_err();
        // Required errors: prompt, model and number_of_images. An aspect_ratio
        // error is deliberately not required here. NOTE(review): presumably the
        // aspect-ratio check may not run when the model cannot be resolved —
        // confirm against `validate()`.
        assert!(errors.len() >= 3, "Expected at least 3 validation errors, got {}", errors.len());

        let fields: Vec<&str> = errors.iter().map(|e| e.field.as_str()).collect();
        assert!(fields.contains(&"prompt"));
        assert!(fields.contains(&"model"));
        assert!(fields.contains(&"number_of_images"));
    }
}