1use adk_rust_mcp_common::auth::AuthProvider;
7use adk_rust_mcp_common::config::Config;
8use adk_rust_mcp_common::error::Error;
9use adk_rust_mcp_common::gcs::{GcsClient, GcsUri};
10use adk_rust_mcp_common::models::{ImagenModel, ModelRegistry, IMAGEN_MODELS};
11use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize};
14use std::path::Path;
15use tracing::{debug, info, instrument};
16
/// Aspect ratios accepted by `ImageGenerateParams::validate` when the requested
/// model cannot be resolved (a resolved model supplies its own list instead).
pub const VALID_ASPECT_RATIOS: &[&str] = &[
    "1:1",
    "3:4",
    "4:3",
    "9:16",
    "16:9",
];

/// Model id used when a request does not specify `model`.
pub const DEFAULT_MODEL: &str = "imagen-3.0-generate-002";

/// Smallest accepted `number_of_images` (inclusive).
pub const MIN_NUMBER_OF_IMAGES: u8 = 1;

/// Largest accepted `number_of_images` (inclusive).
pub const MAX_NUMBER_OF_IMAGES: u8 = 4;
28
/// Parameters for the Imagen image-generation tool.
///
/// Output routing: when `output_uri` is set the images are uploaded there;
/// otherwise when `output_file` is set they are written locally; otherwise
/// they are returned base64-encoded.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
pub struct ImageGenerateParams {
    /// Text prompt describing the image. Must not be blank; its maximum length
    /// depends on the selected model.
    pub prompt: String,

    /// Optional text describing what should not appear in the image.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub negative_prompt: Option<String>,

    /// Imagen model id or alias (e.g. "imagen-4"). Defaults to
    /// "imagen-3.0-generate-002".
    #[serde(default = "default_model")]
    pub model: String,

    /// Aspect ratio such as "1:1" or "16:9"; must be supported by the model.
    /// Defaults to "1:1".
    #[serde(default = "default_aspect_ratio")]
    pub aspect_ratio: String,

    /// Number of images to generate, from 1 to 4. Defaults to 1.
    #[serde(default = "default_number_of_images")]
    pub number_of_images: u8,

    /// Optional seed forwarded to the API.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,

    /// Local file path for the output. With several images, "_<index>" is
    /// inserted before the extension. Ignored when `output_uri` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_file: Option<String>,

    /// "gs://" URI to upload the output to; takes precedence over `output_file`.
    /// With several images, "_<index>" is inserted before the extension.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_uri: Option<String>,
}
70
71fn default_model() -> String {
72 DEFAULT_MODEL.to_string()
73}
74
/// Serde default for `ImageGenerateParams::aspect_ratio`: a square image.
fn default_aspect_ratio() -> String {
    String::from("1:1")
}

/// Serde default for `ImageGenerateParams::number_of_images`: a single image.
fn default_number_of_images() -> u8 {
    1u8
}
82
/// Upscale factors accepted by `ImageUpscaleParams::validate`.
pub const VALID_UPSCALE_FACTORS: &[&str] = &[
    "x2",
    "x4",
];

/// Model id used for every upscale request.
pub const UPSCALE_MODEL: &str = "imagen-4.0-upscale-preview";
88
/// Parameters for the Imagen upscale tool.
///
/// Output routing: `output_uri` takes precedence over `output_file`; with
/// neither, the upscaled image is returned base64-encoded.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
pub struct ImageUpscaleParams {
    /// Image to upscale: a "gs://" URI, a local file path, or base64-encoded
    /// image data.
    pub image: String,

    /// Upscale factor, "x2" or "x4". Defaults to "x2".
    #[serde(default = "default_upscale_factor")]
    pub upscale_factor: String,

    /// Local file path to write the upscaled image to. Ignored when
    /// `output_uri` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_file: Option<String>,

    /// "gs://" URI to upload the upscaled image to.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_uri: Option<String>,
}
112
/// Serde default for `ImageUpscaleParams::upscale_factor`: double resolution.
fn default_upscale_factor() -> String {
    String::from("x2")
}
116
117impl ImageUpscaleParams {
118 pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
120 let mut errors = Vec::new();
121
122 if self.image.trim().is_empty() {
124 errors.push(ValidationError {
125 field: "image".to_string(),
126 message: "Image cannot be empty".to_string(),
127 });
128 }
129
130 if !VALID_UPSCALE_FACTORS.contains(&self.upscale_factor.as_str()) {
132 errors.push(ValidationError {
133 field: "upscale_factor".to_string(),
134 message: format!(
135 "Invalid upscale factor '{}'. Valid options: {}",
136 self.upscale_factor,
137 VALID_UPSCALE_FACTORS.join(", ")
138 ),
139 });
140 }
141
142 if errors.is_empty() {
143 Ok(())
144 } else {
145 Err(errors)
146 }
147 }
148}
149
/// One failed parameter check, pairing the offending field with a
/// human-readable explanation.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Name of the parameter that failed validation (e.g. `"prompt"`).
    pub field: String,
    /// Explanation of what is wrong with the value.
    pub message: String,
}

impl std::fmt::Display for ValidationError {
    /// Renders the error as `"<field>: <message>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { field, message } = self;
        f.write_str(field)?;
        f.write_str(": ")?;
        f.write_str(message)
    }
}
164
165impl ImageGenerateParams {
166 pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
172 let mut errors = Vec::new();
173
174 let model = ModelRegistry::resolve_imagen(&self.model);
176
177 if model.is_none() {
179 errors.push(ValidationError {
180 field: "model".to_string(),
181 message: format!(
182 "Unknown model '{}'. Valid models: {}",
183 self.model,
184 IMAGEN_MODELS
185 .iter()
186 .map(|m| m.id)
187 .collect::<Vec<_>>()
188 .join(", ")
189 ),
190 });
191 }
192
193 if let Some(model) = model {
195 if self.prompt.len() > model.max_prompt_length {
196 errors.push(ValidationError {
197 field: "prompt".to_string(),
198 message: format!(
199 "Prompt length {} exceeds maximum {} for model {}",
200 self.prompt.len(),
201 model.max_prompt_length,
202 model.id
203 ),
204 });
205 }
206
207 if !model.supported_aspect_ratios.contains(&self.aspect_ratio.as_str()) {
209 errors.push(ValidationError {
210 field: "aspect_ratio".to_string(),
211 message: format!(
212 "Invalid aspect ratio '{}'. Valid options for {}: {}",
213 self.aspect_ratio,
214 model.id,
215 model.supported_aspect_ratios.join(", ")
216 ),
217 });
218 }
219 } else {
220 if !VALID_ASPECT_RATIOS.contains(&self.aspect_ratio.as_str()) {
222 errors.push(ValidationError {
223 field: "aspect_ratio".to_string(),
224 message: format!(
225 "Invalid aspect ratio '{}'. Valid options: {}",
226 self.aspect_ratio,
227 VALID_ASPECT_RATIOS.join(", ")
228 ),
229 });
230 }
231 }
232
233 if self.number_of_images < MIN_NUMBER_OF_IMAGES
235 || self.number_of_images > MAX_NUMBER_OF_IMAGES
236 {
237 errors.push(ValidationError {
238 field: "number_of_images".to_string(),
239 message: format!(
240 "number_of_images must be between {} and {}, got {}",
241 MIN_NUMBER_OF_IMAGES, MAX_NUMBER_OF_IMAGES, self.number_of_images
242 ),
243 });
244 }
245
246 if self.prompt.trim().is_empty() {
248 errors.push(ValidationError {
249 field: "prompt".to_string(),
250 message: "Prompt cannot be empty".to_string(),
251 });
252 }
253
254 if errors.is_empty() {
255 Ok(())
256 } else {
257 Err(errors)
258 }
259 }
260
261 pub fn get_model(&self) -> Option<&'static ImagenModel> {
263 ModelRegistry::resolve_imagen(&self.model)
264 }
265}
266
/// Runs Imagen generation and upscale requests against Vertex AI and delivers
/// the results as base64, local files, or Cloud Storage objects.
pub struct ImageHandler {
    /// Supplies `project_id` and `location` for building endpoint URLs.
    pub config: Config,
    /// Cloud Storage client used for `gs://` uploads and downloads.
    pub gcs: GcsClient,
    /// HTTP client used for the `:predict` calls.
    pub http: reqwest::Client,
    /// Provides bearer tokens for the `cloud-platform` scope.
    pub auth: AuthProvider,
}
280
281impl ImageHandler {
282 #[instrument(level = "debug", name = "image_handler_new", skip_all)]
287 pub async fn new(config: Config) -> Result<Self, Error> {
288 debug!("Initializing ImageHandler");
289
290 let auth = AuthProvider::new().await?;
291 let gcs = GcsClient::with_auth(AuthProvider::new().await?);
292 let http = reqwest::Client::new();
293
294 Ok(Self {
295 config,
296 gcs,
297 http,
298 auth,
299 })
300 }
301
302 #[cfg(test)]
304 pub fn with_deps(config: Config, gcs: GcsClient, http: reqwest::Client, auth: AuthProvider) -> Self {
305 Self {
306 config,
307 gcs,
308 http,
309 auth,
310 }
311 }
312
313 pub fn get_endpoint(&self, model: &str) -> String {
315 format!(
316 "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
317 self.config.location,
318 self.config.project_id,
319 self.config.location,
320 model
321 )
322 }
323
324 #[instrument(level = "info", name = "generate_image", skip(self, params), fields(model = %params.model, aspect_ratio = %params.aspect_ratio))]
333 pub async fn generate_image(&self, params: ImageGenerateParams) -> Result<ImageGenerateResult, Error> {
334 params.validate().map_err(|errors| {
336 let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
337 Error::validation(messages.join("; "))
338 })?;
339
340 let model = params.get_model().ok_or_else(|| {
342 Error::validation(format!("Unknown model: {}", params.model))
343 })?;
344
345 info!(model_id = model.id, "Generating image with Imagen API");
346
347 let request = ImagenRequest {
349 instances: vec![ImagenInstance {
350 prompt: params.prompt.clone(),
351 negative_prompt: params.negative_prompt.clone(),
352 }],
353 parameters: ImagenParameters {
354 sample_count: params.number_of_images,
355 aspect_ratio: params.aspect_ratio.clone(),
356 seed: params.seed,
357 },
358 };
359
360 let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
362
363 let endpoint = self.get_endpoint(model.id);
365 debug!(endpoint = %endpoint, "Calling Imagen API");
366
367 let response = self.http
368 .post(&endpoint)
369 .header("Authorization", format!("Bearer {}", token))
370 .header("Content-Type", "application/json")
371 .json(&request)
372 .send()
373 .await
374 .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
375
376 let status = response.status();
377 if !status.is_success() {
378 let body = response.text().await.unwrap_or_default();
379 return Err(Error::api(&endpoint, status.as_u16(), body));
380 }
381
382 let api_response: ImagenResponse = response.json().await.map_err(|e| {
384 Error::api(&endpoint, status.as_u16(), format!("Failed to parse response: {}", e))
385 })?;
386
387 let images: Vec<GeneratedImage> = api_response
389 .predictions
390 .into_iter()
391 .filter_map(|p| {
392 p.bytes_base64_encoded.map(|data| GeneratedImage {
393 data,
394 mime_type: p.mime_type.unwrap_or_else(|| "image/png".to_string()),
395 })
396 })
397 .collect();
398
399 if images.is_empty() {
400 return Err(Error::api(&endpoint, 200, "No images returned from API"));
401 }
402
403 info!(count = images.len(), "Received images from API");
404
405 self.handle_output(images, ¶ms).await
407 }
408
409 async fn handle_output(
411 &self,
412 images: Vec<GeneratedImage>,
413 params: &ImageGenerateParams,
414 ) -> Result<ImageGenerateResult, Error> {
415 if let Some(output_uri) = ¶ms.output_uri {
417 return self.upload_to_storage(images, output_uri).await;
418 }
419
420 if let Some(output_file) = ¶ms.output_file {
422 return self.save_to_file(images, output_file).await;
423 }
424
425 Ok(ImageGenerateResult::Base64(images))
427 }
428
429 async fn upload_to_storage(
431 &self,
432 images: Vec<GeneratedImage>,
433 output_uri: &str,
434 ) -> Result<ImageGenerateResult, Error> {
435 let mut uris = Vec::new();
436
437 for (i, image) in images.iter().enumerate() {
438 let data = BASE64.decode(&image.data).map_err(|e| {
440 Error::validation(format!("Invalid base64 data: {}", e))
441 })?;
442
443 let uri = if images.len() == 1 {
445 output_uri.to_string()
446 } else {
447 Self::add_index_suffix_to_uri(output_uri, i, "image", "png")
450 };
451
452 let gcs_uri = GcsUri::parse(&uri)?;
454 self.gcs.upload(&gcs_uri, &data, &image.mime_type).await?;
455 uris.push(uri);
456 }
457
458 info!(count = uris.len(), "Uploaded images to storage");
459 Ok(ImageGenerateResult::StorageUris(uris))
460 }
461
462 fn add_index_suffix_to_uri(uri: &str, index: usize, default_stem: &str, default_ext: &str) -> String {
465 if let Some(stripped) = uri.strip_prefix("gs://") {
467 if let Some(slash_pos) = stripped.find('/') {
468 let bucket = &stripped[..slash_pos];
469 let object_path = &stripped[slash_pos + 1..];
470
471 let (dir, filename) = if let Some(last_slash) = object_path.rfind('/') {
473 (&object_path[..last_slash], &object_path[last_slash + 1..])
474 } else {
475 ("", object_path)
476 };
477
478 let (stem, ext) = if let Some(dot_pos) = filename.rfind('.') {
480 (&filename[..dot_pos], &filename[dot_pos + 1..])
481 } else {
482 (filename, default_ext)
483 };
484
485 let stem = if stem.is_empty() { default_stem } else { stem };
486
487 if dir.is_empty() {
488 format!("gs://{}/{}_{}.{}", bucket, stem, index, ext)
489 } else {
490 format!("gs://{}/{}/{}_{}.{}", bucket, dir, stem, index, ext)
491 }
492 } else {
493 format!("{}/{}_{}.{}", uri, default_stem, index, default_ext)
495 }
496 } else {
497 let path = Path::new(uri);
499 let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or(default_stem);
500 let ext = path.extension().and_then(|s| s.to_str()).unwrap_or(default_ext);
501 let parent = path.parent().and_then(|p| p.to_str()).unwrap_or("");
502 if parent.is_empty() {
503 format!("{}_{}.{}", stem, index, ext)
504 } else {
505 format!("{}/{}_{}.{}", parent, stem, index, ext)
506 }
507 }
508 }
509
510 async fn save_to_file(
512 &self,
513 images: Vec<GeneratedImage>,
514 output_file: &str,
515 ) -> Result<ImageGenerateResult, Error> {
516 let mut paths = Vec::new();
517
518 for (i, image) in images.iter().enumerate() {
519 let data = BASE64.decode(&image.data).map_err(|e| {
521 Error::validation(format!("Invalid base64 data: {}", e))
522 })?;
523
524 let path = if images.len() == 1 {
526 output_file.to_string()
527 } else {
528 let p = Path::new(output_file);
530 let stem = p.file_stem().and_then(|s| s.to_str()).unwrap_or("image");
531 let ext = p.extension().and_then(|s| s.to_str()).unwrap_or("png");
532 let parent = p.parent().and_then(|p| p.to_str()).unwrap_or("");
533 if parent.is_empty() {
534 format!("{}_{}.{}", stem, i, ext)
535 } else {
536 format!("{}/{}_{}.{}", parent, stem, i, ext)
537 }
538 };
539
540 if let Some(parent) = Path::new(&path).parent() {
542 if !parent.as_os_str().is_empty() {
543 tokio::fs::create_dir_all(parent).await?;
544 }
545 }
546
547 tokio::fs::write(&path, &data).await?;
549 paths.push(path);
550 }
551
552 info!(count = paths.len(), "Saved images to local files");
553 Ok(ImageGenerateResult::LocalFiles(paths))
554 }
555
556 #[instrument(level = "info", name = "upscale_image", skip(self, params), fields(upscale_factor = %params.upscale_factor))]
565 pub async fn upscale_image(&self, params: ImageUpscaleParams) -> Result<ImageUpscaleResult, Error> {
566 params.validate().map_err(|errors| {
568 let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
569 Error::validation(messages.join("; "))
570 })?;
571
572 info!(upscale_factor = %params.upscale_factor, "Upscaling image with Imagen Upscale API");
573
574 let image_data = self.resolve_image_input(¶ms.image).await?;
576
577 let request = UpscaleRequest {
579 instances: vec![UpscaleInstance {
580 image: UpscaleImageInput {
581 bytes_base64_encoded: image_data,
582 },
583 }],
584 parameters: UpscaleParameters {
585 upscale_factor: params.upscale_factor.clone(),
586 output_mime_type: "image/png".to_string(),
587 },
588 };
589
590 let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
592
593 let endpoint = self.get_upscale_endpoint();
595 debug!(endpoint = %endpoint, "Calling Imagen Upscale API");
596
597 let response = self.http
598 .post(&endpoint)
599 .header("Authorization", format!("Bearer {}", token))
600 .header("Content-Type", "application/json")
601 .json(&request)
602 .send()
603 .await
604 .map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
605
606 let status = response.status();
607 if !status.is_success() {
608 let body = response.text().await.unwrap_or_default();
609 return Err(Error::api(&endpoint, status.as_u16(), body));
610 }
611
612 let api_response: UpscaleResponse = response.json().await.map_err(|e| {
614 Error::api(&endpoint, status.as_u16(), format!("Failed to parse response: {}", e))
615 })?;
616
617 let prediction = api_response.predictions.into_iter().next()
619 .ok_or_else(|| Error::api(&endpoint, 200, "No image returned from API"))?;
620
621 let image_data = prediction.bytes_base64_encoded
622 .ok_or_else(|| Error::api(&endpoint, 200, "No image data in response"))?;
623
624 let image = GeneratedImage {
625 data: image_data,
626 mime_type: prediction.mime_type.unwrap_or_else(|| "image/png".to_string()),
627 };
628
629 info!("Received upscaled image from API");
630
631 self.handle_upscale_output(image, ¶ms).await
633 }
634
635 pub fn get_upscale_endpoint(&self) -> String {
637 format!(
638 "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
639 self.config.location,
640 self.config.project_id,
641 self.config.location,
642 UPSCALE_MODEL
643 )
644 }
645
646 async fn resolve_image_input(&self, image: &str) -> Result<String, Error> {
648 if image.starts_with("gs://") {
650 let uri = GcsUri::parse(image)?;
651 let data = self.gcs.download(&uri).await?;
652 return Ok(BASE64.encode(&data));
653 }
654
655 let looks_like_path = image.starts_with('/')
657 || image.starts_with("./")
658 || image.starts_with("../")
659 || image.starts_with("~/")
660 || (image.len() < 500 && image.contains('/'));
661
662 if looks_like_path {
663 let path = Path::new(image);
664 if !path.exists() {
665 return Err(Error::validation(format!("Image file not found: {}", image)));
666 }
667 let data = tokio::fs::read(path).await?;
668 return Ok(BASE64.encode(&data));
669 }
670
671 if image.len() > 100 {
673 if BASE64.decode(image).is_ok() {
674 return Ok(image.to_string());
675 }
676 }
677
678 let path = Path::new(image);
680 if path.exists() {
681 let data = tokio::fs::read(path).await?;
682 return Ok(BASE64.encode(&data));
683 }
684
685 if image.len() > 100 {
687 return Ok(image.to_string());
688 }
689
690 Err(Error::validation(format!(
691 "Image input is not a valid file path, GCS URI, or base64 data"
692 )))
693 }
694
695 async fn handle_upscale_output(
697 &self,
698 image: GeneratedImage,
699 params: &ImageUpscaleParams,
700 ) -> Result<ImageUpscaleResult, Error> {
701 if let Some(output_uri) = ¶ms.output_uri {
703 let data = BASE64.decode(&image.data).map_err(|e| {
704 Error::validation(format!("Invalid base64 data: {}", e))
705 })?;
706 let gcs_uri = GcsUri::parse(output_uri)?;
707 self.gcs.upload(&gcs_uri, &data, &image.mime_type).await?;
708 info!(uri = %output_uri, "Uploaded upscaled image to storage");
709 return Ok(ImageUpscaleResult::StorageUri(output_uri.clone()));
710 }
711
712 if let Some(output_file) = ¶ms.output_file {
714 let data = BASE64.decode(&image.data).map_err(|e| {
715 Error::validation(format!("Invalid base64 data: {}", e))
716 })?;
717
718 if let Some(parent) = Path::new(output_file).parent() {
720 if !parent.as_os_str().is_empty() {
721 tokio::fs::create_dir_all(parent).await?;
722 }
723 }
724
725 tokio::fs::write(output_file, &data).await?;
726 info!(path = %output_file, "Saved upscaled image to local file");
727 return Ok(ImageUpscaleResult::LocalFile(output_file.clone()));
728 }
729
730 Ok(ImageUpscaleResult::Base64(image))
732 }
733}
734
/// JSON body sent to the Imagen `:predict` endpoint.
#[derive(Debug, Serialize)]
pub struct ImagenRequest {
    /// Prompt instances; `generate_image` always sends exactly one.
    pub instances: Vec<ImagenInstance>,
    /// Generation settings shared by all instances.
    pub parameters: ImagenParameters,
}

/// A single prompt instance (serialized with camelCase keys).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ImagenInstance {
    /// Text prompt.
    pub prompt: String,
    /// Sent as `negativePrompt`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub negative_prompt: Option<String>,
}

/// Generation settings (serialized with camelCase keys).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ImagenParameters {
    /// Sent as `sampleCount`: number of images requested.
    pub sample_count: u8,
    /// Sent as `aspectRatio`, e.g. "16:9".
    pub aspect_ratio: String,
    /// Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
}
771
772#[derive(Debug, Deserialize)]
774pub struct ImagenResponse {
775 pub predictions: Vec<ImagenPrediction>,
777}
778
779#[derive(Debug, Deserialize)]
781#[serde(rename_all = "camelCase")]
782pub struct ImagenPrediction {
783 pub bytes_base64_encoded: Option<String>,
785 pub mime_type: Option<String>,
787}
788
/// JSON body sent to the upscale model's `:predict` endpoint.
#[derive(Debug, Serialize)]
pub struct UpscaleRequest {
    /// Images to upscale; `upscale_image` always sends exactly one.
    pub instances: Vec<UpscaleInstance>,
    /// Upscale settings.
    pub parameters: UpscaleParameters,
}

/// A single upscale instance wrapping the source image.
#[derive(Debug, Serialize)]
pub struct UpscaleInstance {
    pub image: UpscaleImageInput,
}

/// Inline source image (camelCase keys).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UpscaleImageInput {
    /// Sent as `bytesBase64Encoded`.
    pub bytes_base64_encoded: String,
}

/// Upscale settings (camelCase keys).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UpscaleParameters {
    /// Sent as `upscaleFactor`: "x2" or "x4" after validation.
    pub upscale_factor: String,
    /// Sent as `outputMimeType`; `upscale_image` requests "image/png".
    pub output_mime_type: String,
}
826
827#[derive(Debug, Deserialize)]
829pub struct UpscaleResponse {
830 pub predictions: Vec<UpscalePrediction>,
832}
833
834#[derive(Debug, Deserialize)]
836#[serde(rename_all = "camelCase")]
837pub struct UpscalePrediction {
838 pub bytes_base64_encoded: Option<String>,
840 pub mime_type: Option<String>,
842}
843
/// An image returned by the API, still base64-encoded.
#[derive(Debug, Clone)]
pub struct GeneratedImage {
    /// Base64 (standard alphabet) image bytes.
    pub data: String,
    /// MIME type reported by the API, or "image/png" when it was absent.
    pub mime_type: String,
}

/// Where the generated images ended up.
#[derive(Debug)]
pub enum ImageGenerateResult {
    /// Neither `output_uri` nor `output_file` was set; images returned inline.
    Base64(Vec<GeneratedImage>),
    /// Paths of the local files written.
    LocalFiles(Vec<String>),
    /// `gs://` URIs of the uploaded objects.
    StorageUris(Vec<String>),
}

/// Where the upscaled image ended up.
#[derive(Debug)]
pub enum ImageUpscaleResult {
    /// Neither `output_uri` nor `output_file` was set; image returned inline.
    Base64(GeneratedImage),
    /// Path of the local file written.
    LocalFile(String),
    /// `gs://` URI of the uploaded object.
    StorageUri(String),
}
878
#[cfg(test)]
mod tests {
    use super::*;

    /// A request that passes validation with the default model; each test
    /// overrides only the field it exercises via struct-update syntax.
    fn base_params() -> ImageGenerateParams {
        ImageGenerateParams {
            prompt: "A cat".to_string(),
            negative_prompt: None,
            model: DEFAULT_MODEL.to_string(),
            aspect_ratio: "1:1".to_string(),
            number_of_images: 1,
            seed: None,
            output_file: None,
            output_uri: None,
        }
    }

    /// Runs validation, asserts it failed, and hands back the errors.
    fn expect_errors(params: &ImageGenerateParams) -> Vec<ValidationError> {
        let outcome = params.validate();
        assert!(outcome.is_err());
        outcome.unwrap_err()
    }

    /// Asserts validation fails with at least one error on `field`.
    fn assert_rejects_field(params: ImageGenerateParams, field: &str) {
        let errors = expect_errors(&params);
        assert!(errors.iter().any(|e| e.field == field));
    }

    #[test]
    fn test_default_params() {
        let parsed: ImageGenerateParams = serde_json::from_str(r#"{"prompt": "a cat"}"#).unwrap();
        assert_eq!(parsed.model, DEFAULT_MODEL);
        assert_eq!(parsed.aspect_ratio, "1:1");
        assert_eq!(parsed.number_of_images, 1);
        assert!(parsed.negative_prompt.is_none());
        assert!(parsed.seed.is_none());
        assert!(parsed.output_file.is_none());
        assert!(parsed.output_uri.is_none());
    }

    #[test]
    fn test_valid_params() {
        let params = ImageGenerateParams {
            prompt: "A beautiful sunset over mountains".to_string(),
            negative_prompt: Some("blurry, low quality".to_string()),
            model: "imagen-4".to_string(),
            aspect_ratio: "16:9".to_string(),
            number_of_images: 2,
            seed: Some(42),
            ..base_params()
        };
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_invalid_number_of_images_zero() {
        assert_rejects_field(
            ImageGenerateParams { number_of_images: 0, ..base_params() },
            "number_of_images",
        );
    }

    #[test]
    fn test_invalid_number_of_images_too_high() {
        assert_rejects_field(
            ImageGenerateParams { number_of_images: 5, ..base_params() },
            "number_of_images",
        );
    }

    #[test]
    fn test_invalid_aspect_ratio() {
        assert_rejects_field(
            ImageGenerateParams { aspect_ratio: "2:1".to_string(), ..base_params() },
            "aspect_ratio",
        );
    }

    #[test]
    fn test_invalid_model() {
        assert_rejects_field(
            ImageGenerateParams { model: "unknown-model".to_string(), ..base_params() },
            "model",
        );
    }

    #[test]
    fn test_empty_prompt() {
        assert_rejects_field(
            ImageGenerateParams { prompt: " ".to_string(), ..base_params() },
            "prompt",
        );
    }

    #[test]
    fn test_prompt_too_long_imagen3() {
        let params = ImageGenerateParams {
            prompt: "a".repeat(500),
            model: "imagen-3".to_string(),
            ..base_params()
        };
        let errors = expect_errors(&params);
        assert!(errors
            .iter()
            .any(|e| e.field == "prompt" && e.message.contains("exceeds")));
    }

    #[test]
    fn test_prompt_ok_imagen4() {
        let params = ImageGenerateParams {
            prompt: "a".repeat(500),
            model: "imagen-4".to_string(),
            ..base_params()
        };
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_all_valid_aspect_ratios() {
        for ratio in VALID_ASPECT_RATIOS {
            let params = ImageGenerateParams {
                aspect_ratio: ratio.to_string(),
                ..base_params()
            };
            assert!(params.validate().is_ok(), "Aspect ratio {} should be valid", ratio);
        }
    }

    #[test]
    fn test_all_valid_number_of_images() {
        for n in MIN_NUMBER_OF_IMAGES..=MAX_NUMBER_OF_IMAGES {
            let params = ImageGenerateParams {
                number_of_images: n,
                ..base_params()
            };
            assert!(params.validate().is_ok(), "number_of_images {} should be valid", n);
        }
    }

    #[test]
    fn test_get_model() {
        let params = ImageGenerateParams {
            model: "imagen-4".to_string(),
            ..base_params()
        };
        let resolved = params.get_model();
        assert!(resolved.is_some());
        assert_eq!(resolved.unwrap().id, "imagen-4.0-generate-preview-06-06");
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = ImageGenerateParams {
            negative_prompt: Some("blurry".to_string()),
            model: "imagen-4".to_string(),
            aspect_ratio: "16:9".to_string(),
            number_of_images: 2,
            seed: Some(42),
            output_file: Some("/tmp/output.png".to_string()),
            ..base_params()
        };

        let json = serde_json::to_string(&original).unwrap();
        let restored: ImageGenerateParams = serde_json::from_str(&json).unwrap();

        assert_eq!(original.prompt, restored.prompt);
        assert_eq!(original.negative_prompt, restored.negative_prompt);
        assert_eq!(original.model, restored.model);
        assert_eq!(original.aspect_ratio, restored.aspect_ratio);
        assert_eq!(original.number_of_images, restored.number_of_images);
        assert_eq!(original.seed, restored.seed);
        assert_eq!(original.output_file, restored.output_file);
    }

    #[test]
    fn test_add_index_suffix_to_gcs_uri_simple() {
        assert_eq!(
            ImageHandler::add_index_suffix_to_uri("gs://bucket/output.png", 0, "image", "png"),
            "gs://bucket/output_0.png"
        );
    }

    #[test]
    fn test_add_index_suffix_to_gcs_uri_with_path() {
        assert_eq!(
            ImageHandler::add_index_suffix_to_uri("gs://bucket/path/to/output.png", 1, "image", "png"),
            "gs://bucket/path/to/output_1.png"
        );
    }

    #[test]
    fn test_add_index_suffix_to_gcs_uri_no_extension() {
        assert_eq!(
            ImageHandler::add_index_suffix_to_uri("gs://bucket/output", 2, "image", "png"),
            "gs://bucket/output_2.png"
        );
    }

    #[test]
    fn test_add_index_suffix_to_local_path() {
        assert_eq!(
            ImageHandler::add_index_suffix_to_uri("/tmp/output.png", 0, "image", "png"),
            "/tmp/output_0.png"
        );
    }

    #[test]
    fn test_add_index_suffix_to_local_path_no_dir() {
        assert_eq!(
            ImageHandler::add_index_suffix_to_uri("output.png", 1, "image", "png"),
            "output_1.png"
        );
    }
}
1156
1157
// Property-based tests for `ImageGenerateParams::validate` using proptest.
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    // Every value in the accepted inclusive range of `number_of_images`.
    fn valid_number_of_images_strategy() -> impl Strategy<Value = u8> {
        MIN_NUMBER_OF_IMAGES..=MAX_NUMBER_OF_IMAGES
    }

    // Zero, or anything above the maximum up to `u8::MAX`.
    fn invalid_number_of_images_strategy() -> impl Strategy<Value = u8> {
        prop_oneof![
            Just(0u8),
            (MAX_NUMBER_OF_IMAGES + 1)..=u8::MAX,
        ]
    }

    // Mirrors `VALID_ASPECT_RATIOS` by hand; keep the two lists in sync.
    fn valid_aspect_ratio_strategy() -> impl Strategy<Value = &'static str> {
        prop_oneof![
            Just("1:1"),
            Just("3:4"),
            Just("4:3"),
            Just("9:16"),
            Just("16:9"),
        ]
    }

    // Hand-picked bad ratios plus random "N:M" strings not in `VALID_ASPECT_RATIOS`.
    // NOTE(review): tests run with `DEFAULT_MODEL`, whose own supported list is what
    // `validate` actually checks — this assumes it matches `VALID_ASPECT_RATIOS`.
    fn invalid_aspect_ratio_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            Just("2:1".to_string()),
            Just("1:2".to_string()),
            Just("5:4".to_string()),
            Just("invalid".to_string()),
            Just("".to_string()),
            Just("16:10".to_string()),
            Just("21:9".to_string()),
            "[0-9]+:[0-9]+".prop_filter("Must not be a valid ratio", |s| {
                !VALID_ASPECT_RATIOS.contains(&s.as_str())
            }),
        ]
    }

    // Short alphanumeric prompts, trimmed; all-space samples are filtered out.
    fn valid_prompt_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z0-9 ]{1,100}".prop_map(|s| s.trim().to_string())
            .prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    proptest! {
        #[test]
        // Any in-range count passes with an otherwise valid request.
        fn valid_number_of_images_passes_validation(
            num in valid_number_of_images_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = ImageGenerateParams {
                prompt,
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: "1:1".to_string(),
                number_of_images: num,
                seed: None,
                output_file: None,
                output_uri: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "number_of_images {} should be valid, but got errors: {:?}",
                num,
                result.err()
            );
        }

        #[test]
        // Out-of-range counts fail and the failure names `number_of_images`.
        fn invalid_number_of_images_fails_validation(
            num in invalid_number_of_images_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = ImageGenerateParams {
                prompt,
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: "1:1".to_string(),
                number_of_images: num,
                seed: None,
                output_file: None,
                output_uri: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_err(),
                "number_of_images {} should be invalid",
                num
            );

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "number_of_images"),
                "Should have a number_of_images validation error for value {}",
                num
            );
        }

        #[test]
        // Every listed ratio passes with an otherwise valid request.
        fn valid_aspect_ratio_passes_validation(
            ratio in valid_aspect_ratio_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = ImageGenerateParams {
                prompt,
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: ratio.to_string(),
                number_of_images: 1,
                seed: None,
                output_file: None,
                output_uri: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "aspect_ratio '{}' should be valid, but got errors: {:?}",
                ratio,
                result.err()
            );
        }

        #[test]
        // Unlisted ratios fail, name `aspect_ratio`, and list the valid options.
        fn invalid_aspect_ratio_fails_validation(
            ratio in invalid_aspect_ratio_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = ImageGenerateParams {
                prompt,
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: ratio.clone(),
                number_of_images: 1,
                seed: None,
                output_file: None,
                output_uri: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_err(),
                "aspect_ratio '{}' should be invalid",
                ratio
            );

            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "aspect_ratio"),
                "Should have an aspect_ratio validation error for value '{}'",
                ratio
            );

            let aspect_error = errors.iter().find(|e| e.field == "aspect_ratio").unwrap();
            prop_assert!(
                aspect_error.message.contains("Valid options"),
                "Error message should list valid options: {}",
                aspect_error.message
            );
        }

        #[test]
        // Any combination of valid count, ratio, and prompt passes together.
        fn valid_params_combination_passes(
            num in valid_number_of_images_strategy(),
            ratio in valid_aspect_ratio_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = ImageGenerateParams {
                prompt,
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: ratio.to_string(),
                number_of_images: num,
                seed: None,
                output_file: None,
                output_uri: None,
            };

            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "Valid params (num={}, ratio='{}') should pass, but got: {:?}",
                num,
                ratio,
                result.err()
            );
        }
    }
}
1377
#[cfg(test)]
mod api_tests {
    //! Unit tests for the Imagen wire types (request/response JSON shape), the
    //! generated-image result types, and validation error reporting.

    use super::*;

    /// A fully populated request serializes with camelCase keys, and every
    /// optional field is present in the JSON.
    #[test]
    fn test_imagen_request_serialization() {
        let request = ImagenRequest {
            instances: vec![ImagenInstance {
                prompt: "A beautiful sunset".to_string(),
                negative_prompt: Some("blurry".to_string()),
            }],
            parameters: ImagenParameters {
                sample_count: 2,
                aspect_ratio: "16:9".to_string(),
                seed: Some(42),
            },
        };

        let json = serde_json::to_value(&request).unwrap();

        assert!(json["instances"].is_array());
        assert_eq!(json["instances"][0]["prompt"], "A beautiful sunset");
        assert_eq!(json["instances"][0]["negativePrompt"], "blurry");
        assert_eq!(json["parameters"]["sampleCount"], 2);
        assert_eq!(json["parameters"]["aspectRatio"], "16:9");
        assert_eq!(json["parameters"]["seed"], 42);
    }

    /// Optional fields set to `None` are omitted from the serialized JSON
    /// entirely, rather than emitted as `null`.
    #[test]
    fn test_imagen_request_serialization_minimal() {
        let request = ImagenRequest {
            instances: vec![ImagenInstance {
                prompt: "A cat".to_string(),
                negative_prompt: None,
            }],
            parameters: ImagenParameters {
                sample_count: 1,
                aspect_ratio: "1:1".to_string(),
                seed: None,
            },
        };

        let json = serde_json::to_value(&request).unwrap();

        assert!(json["instances"][0].get("negativePrompt").is_none());
        assert!(json["parameters"].get("seed").is_none());
    }

    /// A single prediction carrying base64 data (a 1x1 PNG) and a MIME type
    /// deserializes into `Some` values for both fields.
    #[test]
    fn test_imagen_response_deserialization() {
        let json = r#"{
            "predictions": [
                {
                    "bytesBase64Encoded": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
                    "mimeType": "image/png"
                }
            ]
        }"#;

        let response: ImagenResponse = serde_json::from_str(json).unwrap();

        assert_eq!(response.predictions.len(), 1);
        assert!(response.predictions[0].bytes_base64_encoded.is_some());
        assert_eq!(response.predictions[0].mime_type.as_deref(), Some("image/png"));
    }

    /// Multiple predictions deserialize in their original order.
    #[test]
    fn test_imagen_response_multiple_predictions() {
        let json = r#"{
            "predictions": [
                {
                    "bytesBase64Encoded": "base64data1",
                    "mimeType": "image/png"
                },
                {
                    "bytesBase64Encoded": "base64data2",
                    "mimeType": "image/png"
                }
            ]
        }"#;

        let response: ImagenResponse = serde_json::from_str(json).unwrap();

        assert_eq!(response.predictions.len(), 2);
        assert_eq!(response.predictions[0].bytes_base64_encoded.as_deref(), Some("base64data1"));
        assert_eq!(response.predictions[1].bytes_base64_encoded.as_deref(), Some("base64data2"));
    }

    /// An empty `predictions` array deserializes into an empty collection.
    #[test]
    fn test_imagen_response_empty_predictions() {
        let json = r#"{"predictions": []}"#;

        let response: ImagenResponse = serde_json::from_str(json).unwrap();

        assert!(response.predictions.is_empty());
    }

    /// A prediction without `bytesBase64Encoded` still deserializes, and the
    /// image-data field comes back as `None`.
    #[test]
    fn test_imagen_response_no_image_data() {
        let json = r#"{
            "predictions": [
                {
                    "mimeType": "image/png"
                }
            ]
        }"#;

        let response: ImagenResponse = serde_json::from_str(json).unwrap();

        assert_eq!(response.predictions.len(), 1);
        assert!(response.predictions[0].bytes_base64_encoded.is_none());
    }

    /// Pins the exact Vertex AI `:predict` URL for a given project, location,
    /// and model.
    ///
    /// NOTE(review): this formats the URL locally rather than calling the
    /// handler's endpoint builder (presumably a `get_endpoint`-style method).
    /// It therefore documents the expected template but cannot detect drift in
    /// production code. Consider calling the real builder here.
    #[test]
    fn test_get_endpoint() {
        let config = Config {
            project_id: "my-project".to_string(),
            location: "us-central1".to_string(),
            gcs_bucket: None,
            port: 8080,
            ..Default::default()
        };
        let model = "imagen-4.0-generate-preview-05-20";

        let url = format!(
            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
            config.location, config.project_id, config.location, model
        );

        // Exact equality (instead of `contains` checks) also catches project and
        // location being swapped in the path, or a path segment going missing.
        assert_eq!(
            url,
            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-preview-05-20:predict"
        );
    }

    /// `GeneratedImage` stores its data and MIME type verbatim.
    #[test]
    fn test_generated_image() {
        let image = GeneratedImage {
            data: "base64encodeddata".to_string(),
            mime_type: "image/png".to_string(),
        };

        assert_eq!(image.data, "base64encodeddata");
        assert_eq!(image.mime_type, "image/png");
    }

    /// The `Base64` result variant carries its images in order, each with its
    /// own MIME type.
    #[test]
    fn test_image_generate_result_base64() {
        let images = vec![
            GeneratedImage {
                data: "data1".to_string(),
                mime_type: "image/png".to_string(),
            },
            GeneratedImage {
                data: "data2".to_string(),
                mime_type: "image/jpeg".to_string(),
            },
        ];

        let result = ImageGenerateResult::Base64(images);

        match result {
            ImageGenerateResult::Base64(imgs) => {
                assert_eq!(imgs.len(), 2);
                assert_eq!(imgs[0].data, "data1");
                assert_eq!(imgs[1].mime_type, "image/jpeg");
            }
            _ => panic!("Expected Base64 variant"),
        }
    }

    /// The `LocalFiles` result variant carries the written file paths.
    #[test]
    fn test_image_generate_result_local_files() {
        let paths = vec!["/tmp/image1.png".to_string(), "/tmp/image2.png".to_string()];
        let result = ImageGenerateResult::LocalFiles(paths);

        match result {
            ImageGenerateResult::LocalFiles(p) => {
                assert_eq!(p.len(), 2);
                assert!(p[0].contains("image1"));
            }
            _ => panic!("Expected LocalFiles variant"),
        }
    }

    /// The `StorageUris` result variant carries `gs://` URIs.
    #[test]
    fn test_image_generate_result_storage_uris() {
        let uris = vec![
            "gs://bucket/image1.png".to_string(),
            "gs://bucket/image2.png".to_string(),
        ];
        let result = ImageGenerateResult::StorageUris(uris);

        match result {
            ImageGenerateResult::StorageUris(u) => {
                assert_eq!(u.len(), 2);
                assert!(u[0].starts_with("gs://"));
            }
            _ => panic!("Expected StorageUris variant"),
        }
    }

    /// `ValidationError` displays as `"<field>: <message>"`.
    #[test]
    fn test_validation_error_display() {
        let error = ValidationError {
            field: "prompt".to_string(),
            message: "cannot be empty".to_string(),
        };

        let display = format!("{}", error);
        assert_eq!(display, "prompt: cannot be empty");
    }

    /// Validation collects errors for several invalid fields at once instead of
    /// stopping at the first one.
    #[test]
    fn test_validation_multiple_errors() {
        let params = ImageGenerateParams {
            prompt: " ".to_string(), // whitespace only
            negative_prompt: None,
            model: "unknown-model".to_string(), // not a known model
            aspect_ratio: "invalid".to_string(), // not a supported ratio
            number_of_images: 10, // above MAX_NUMBER_OF_IMAGES (4)
            seed: None,
            output_file: None,
            output_uri: None,
        };

        let result = params.validate();
        assert!(result.is_err());

        let errors = result.unwrap_err();
        // The count is a lower bound, and `aspect_ratio` is not asserted below.
        // NOTE(review): presumably the aspect-ratio check may depend on resolving
        // the (unknown) model. Confirm against `validate` before tightening.
        assert!(errors.len() >= 3, "Expected at least 3 validation errors, got {}", errors.len());

        let fields: Vec<&str> = errors.iter().map(|e| e.field.as_str()).collect();
        assert!(fields.contains(&"prompt"));
        assert!(fields.contains(&"model"));
        assert!(fields.contains(&"number_of_images"));
    }
}