use adk_rust_mcp_common::auth::AuthProvider;
use adk_rust_mcp_common::config::Config;
use adk_rust_mcp_common::error::Error;
use adk_rust_mcp_common::gcs::{GcsClient, GcsUri};
use adk_rust_mcp_common::models::{ImagenModel, ModelRegistry, IMAGEN_MODELS};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::path::Path;
use tracing::{debug, info, instrument};
/// Aspect ratios accepted by `ImageGenerateParams::validate` when the requested model
/// cannot be resolved through the model registry.
pub const VALID_ASPECT_RATIOS: &[&str] = &[
    "1:1",
    "3:4",
    "4:3",
    "9:16",
    "16:9",
];

/// Model used when a generate request omits `model`.
pub const DEFAULT_MODEL: &str = "imagen-3.0-generate-002";

/// Smallest accepted `number_of_images` (inclusive).
pub const MIN_NUMBER_OF_IMAGES: u8 = 1;

/// Largest accepted `number_of_images` (inclusive).
pub const MAX_NUMBER_OF_IMAGES: u8 = 4;
/// Parameters for an Imagen text-to-image request.
///
/// The generated images are uploaded to `output_uri` when it is set, otherwise written to
/// `output_file` when that is set, otherwise returned inline as base64.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
pub struct ImageGenerateParams {
    /// Text prompt describing the image to generate. Must not be empty; the maximum
    /// length depends on the model.
    pub prompt: String,
    /// Optional negative prompt, sent to the Imagen API as `negativePrompt`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub negative_prompt: Option<String>,
    /// Imagen model ID or registry alias (e.g. `imagen-4`). Defaults to
    /// `imagen-3.0-generate-002`.
    #[serde(default = "default_model")]
    pub model: String,
    /// Output aspect ratio such as `1:1` (the default) or `16:9`. The accepted set
    /// depends on the model.
    #[serde(default = "default_aspect_ratio")]
    pub aspect_ratio: String,
    /// Number of images to generate, from 1 to 4. Defaults to 1.
    #[serde(default = "default_number_of_images")]
    pub number_of_images: u8,
    /// Optional seed, sent to the Imagen API as `seed`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// Local file path to save the image(s) to. When several images are produced,
    /// `_<index>` is inserted before the extension. Ignored if `output_uri` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_file: Option<String>,
    /// Cloud Storage URI (`gs://bucket/path`) to upload the image(s) to. When several
    /// images are produced, `_<index>` is inserted before the extension.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_uri: Option<String>,
}
/// Serde default for `ImageGenerateParams::model`: returns [`DEFAULT_MODEL`] as an owned string.
fn default_model() -> String {
    String::from(DEFAULT_MODEL)
}
/// Serde default for `ImageGenerateParams::aspect_ratio`: a square `1:1` image.
fn default_aspect_ratio() -> String {
    String::from("1:1")
}
/// Serde default for `ImageGenerateParams::number_of_images`: generate a single image.
fn default_number_of_images() -> u8 {
    1_u8
}
/// Upscale factors accepted by `ImageUpscaleParams::validate`.
pub const VALID_UPSCALE_FACTORS: &[&str] = &[
    "x2",
    "x4",
];

/// Model ID used to build the upscale `:predict` endpoint.
pub const UPSCALE_MODEL: &str = "imagen-4.0-upscale-preview";
/// Parameters for an Imagen upscale request.
///
/// The upscaled image is uploaded to `output_uri` when it is set, otherwise written to
/// `output_file` when that is set, otherwise returned inline as base64.
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)]
pub struct ImageUpscaleParams {
    /// Image to upscale: a `gs://` URI, a local file path, or base64-encoded image data.
    pub image: String,
    /// Upscale factor: `x2` (the default) or `x4`.
    #[serde(default = "default_upscale_factor")]
    pub upscale_factor: String,
    /// Local file path to save the upscaled image to. Ignored if `output_uri` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_file: Option<String>,
    /// Cloud Storage URI (`gs://bucket/path`) to upload the upscaled image to.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_uri: Option<String>,
}
/// Serde default for `ImageUpscaleParams::upscale_factor`: double the resolution.
fn default_upscale_factor() -> String {
    String::from("x2")
}
impl ImageUpscaleParams {
    /// Checks the upscale request and reports every problem found, not just the first.
    ///
    /// # Errors
    /// Returns one `ValidationError` for an empty or whitespace-only `image`, and one
    /// for an `upscale_factor` that is not listed in [`VALID_UPSCALE_FACTORS`].
    pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
        let mut problems: Vec<ValidationError> = Vec::new();

        let has_image = !self.image.trim().is_empty();
        if !has_image {
            problems.push(ValidationError {
                field: String::from("image"),
                message: String::from("Image cannot be empty"),
            });
        }

        let factor = self.upscale_factor.as_str();
        let factor_is_known = VALID_UPSCALE_FACTORS.iter().any(|candidate| *candidate == factor);
        if !factor_is_known {
            let allowed = VALID_UPSCALE_FACTORS.join(", ");
            problems.push(ValidationError {
                field: String::from("upscale_factor"),
                message: format!(
                    "Invalid upscale factor '{}'. Valid options: {}",
                    factor, allowed
                ),
            });
        }

        if !problems.is_empty() {
            return Err(problems);
        }
        Ok(())
    }
}
/// A single validation failure: the request field that is wrong, and why.
///
/// Rendered as `"<field>: <message>"` by its `Display` implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the offending parameter, e.g. `"prompt"` or `"number_of_images"`.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
}
impl std::fmt::Display for ValidationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}: {}", self.field, self.message)
}
}
impl ImageGenerateParams {
    /// Validates the parameters and collects every problem instead of stopping at the first.
    ///
    /// Checks performed:
    /// - `model` must resolve through [`ModelRegistry::resolve_imagen`].
    /// - When the model resolves, the prompt must fit within its `max_prompt_length`
    ///   (counted in characters) and `aspect_ratio` must be one of the model's
    ///   `supported_aspect_ratios`.
    /// - When the model does not resolve, `aspect_ratio` is checked against
    ///   [`VALID_ASPECT_RATIOS`] instead.
    /// - `number_of_images` must lie in `MIN_NUMBER_OF_IMAGES..=MAX_NUMBER_OF_IMAGES`.
    /// - The prompt must not be empty or whitespace-only.
    ///
    /// # Errors
    /// Returns every [`ValidationError`] found, in the order the checks above run.
    pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
        let mut errors = Vec::new();

        match self.get_model() {
            Some(model) => {
                // Count characters, not bytes: `len()` over-counts non-ASCII prompts (e.g. 3
                // bytes per CJK character) and rejects them early.
                // NOTE(review): assumes `max_prompt_length` is a character limit — confirm
                // against the model registry.
                let prompt_len = self.prompt.chars().count();
                if prompt_len > model.max_prompt_length {
                    errors.push(ValidationError {
                        field: "prompt".to_string(),
                        message: format!(
                            "Prompt length {} exceeds maximum {} for model {}",
                            prompt_len, model.max_prompt_length, model.id
                        ),
                    });
                }
                if !model.supported_aspect_ratios.contains(&self.aspect_ratio.as_str()) {
                    errors.push(ValidationError {
                        field: "aspect_ratio".to_string(),
                        message: format!(
                            "Invalid aspect ratio '{}'. Valid options for {}: {}",
                            self.aspect_ratio,
                            model.id,
                            model.supported_aspect_ratios.join(", ")
                        ),
                    });
                }
            }
            None => {
                let known_models = IMAGEN_MODELS
                    .iter()
                    .map(|m| m.id)
                    .collect::<Vec<_>>()
                    .join(", ");
                errors.push(ValidationError {
                    field: "model".to_string(),
                    message: format!(
                        "Unknown model '{}'. Valid models: {}",
                        self.model, known_models
                    ),
                });
                // No model-specific list is available, so fall back to the generic set.
                if !VALID_ASPECT_RATIOS.contains(&self.aspect_ratio.as_str()) {
                    errors.push(ValidationError {
                        field: "aspect_ratio".to_string(),
                        message: format!(
                            "Invalid aspect ratio '{}'. Valid options: {}",
                            self.aspect_ratio,
                            VALID_ASPECT_RATIOS.join(", ")
                        ),
                    });
                }
            }
        }

        if !(MIN_NUMBER_OF_IMAGES..=MAX_NUMBER_OF_IMAGES).contains(&self.number_of_images) {
            errors.push(ValidationError {
                field: "number_of_images".to_string(),
                message: format!(
                    "number_of_images must be between {} and {}, got {}",
                    MIN_NUMBER_OF_IMAGES, MAX_NUMBER_OF_IMAGES, self.number_of_images
                ),
            });
        }

        if self.prompt.trim().is_empty() {
            errors.push(ValidationError {
                field: "prompt".to_string(),
                message: "Prompt cannot be empty".to_string(),
            });
        }

        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }

    /// Resolves `self.model` (an ID or alias) to its registry entry, if known.
    pub fn get_model(&self) -> Option<&'static ImagenModel> {
        ModelRegistry::resolve_imagen(&self.model)
    }
}
/// Runs Imagen generation and upscale requests against Vertex AI and delivers the
/// results as base64, local files, or Cloud Storage objects.
pub struct ImageHandler {
    /// Project and location settings used to build Vertex AI endpoints.
    pub config: Config,
    /// Cloud Storage client used for `gs://` inputs and outputs.
    pub gcs: GcsClient,
    /// HTTP client used for Vertex AI `:predict` calls.
    pub http: reqwest::Client,
    /// Supplies bearer tokens for API requests.
    pub auth: AuthProvider,
}
impl ImageHandler {
#[instrument(level = "debug", name = "image_handler_new", skip_all)]
pub async fn new(config: Config) -> Result<Self, Error> {
debug!("Initializing ImageHandler");
let auth = AuthProvider::new().await?;
let gcs = GcsClient::with_auth(AuthProvider::new().await?);
let http = reqwest::Client::new();
Ok(Self {
config,
gcs,
http,
auth,
})
}
#[cfg(test)]
pub fn with_deps(config: Config, gcs: GcsClient, http: reqwest::Client, auth: AuthProvider) -> Self {
Self {
config,
gcs,
http,
auth,
}
}
pub fn get_endpoint(&self, model: &str) -> String {
format!(
"https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
self.config.location,
self.config.project_id,
self.config.location,
model
)
}
#[instrument(level = "info", name = "generate_image", skip(self, params), fields(model = %params.model, aspect_ratio = %params.aspect_ratio))]
pub async fn generate_image(&self, params: ImageGenerateParams) -> Result<ImageGenerateResult, Error> {
params.validate().map_err(|errors| {
let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
Error::validation(messages.join("; "))
})?;
let model = params.get_model().ok_or_else(|| {
Error::validation(format!("Unknown model: {}", params.model))
})?;
info!(model_id = model.id, "Generating image with Imagen API");
let request = ImagenRequest {
instances: vec![ImagenInstance {
prompt: params.prompt.clone(),
negative_prompt: params.negative_prompt.clone(),
}],
parameters: ImagenParameters {
sample_count: params.number_of_images,
aspect_ratio: params.aspect_ratio.clone(),
seed: params.seed,
},
};
let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
let endpoint = self.get_endpoint(model.id);
debug!(endpoint = %endpoint, "Calling Imagen API");
let response = self.http
.post(&endpoint)
.header("Authorization", format!("Bearer {}", token))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
return Err(Error::api(&endpoint, status.as_u16(), body));
}
let api_response: ImagenResponse = response.json().await.map_err(|e| {
Error::api(&endpoint, status.as_u16(), format!("Failed to parse response: {}", e))
})?;
let images: Vec<GeneratedImage> = api_response
.predictions
.into_iter()
.filter_map(|p| {
p.bytes_base64_encoded.map(|data| GeneratedImage {
data,
mime_type: p.mime_type.unwrap_or_else(|| "image/png".to_string()),
})
})
.collect();
if images.is_empty() {
return Err(Error::api(&endpoint, 200, "No images returned from API"));
}
info!(count = images.len(), "Received images from API");
self.handle_output(images, ¶ms).await
}
async fn handle_output(
&self,
images: Vec<GeneratedImage>,
params: &ImageGenerateParams,
) -> Result<ImageGenerateResult, Error> {
if let Some(output_uri) = ¶ms.output_uri {
return self.upload_to_storage(images, output_uri).await;
}
if let Some(output_file) = ¶ms.output_file {
return self.save_to_file(images, output_file).await;
}
Ok(ImageGenerateResult::Base64(images))
}
async fn upload_to_storage(
&self,
images: Vec<GeneratedImage>,
output_uri: &str,
) -> Result<ImageGenerateResult, Error> {
let mut uris = Vec::new();
for (i, image) in images.iter().enumerate() {
let data = BASE64.decode(&image.data).map_err(|e| {
Error::validation(format!("Invalid base64 data: {}", e))
})?;
let uri = if images.len() == 1 {
output_uri.to_string()
} else {
Self::add_index_suffix_to_uri(output_uri, i, "image", "png")
};
let gcs_uri = GcsUri::parse(&uri)?;
self.gcs.upload(&gcs_uri, &data, &image.mime_type).await?;
uris.push(uri);
}
info!(count = uris.len(), "Uploaded images to storage");
Ok(ImageGenerateResult::StorageUris(uris))
}
fn add_index_suffix_to_uri(uri: &str, index: usize, default_stem: &str, default_ext: &str) -> String {
if let Some(stripped) = uri.strip_prefix("gs://") {
if let Some(slash_pos) = stripped.find('/') {
let bucket = &stripped[..slash_pos];
let object_path = &stripped[slash_pos + 1..];
let (dir, filename) = if let Some(last_slash) = object_path.rfind('/') {
(&object_path[..last_slash], &object_path[last_slash + 1..])
} else {
("", object_path)
};
let (stem, ext) = if let Some(dot_pos) = filename.rfind('.') {
(&filename[..dot_pos], &filename[dot_pos + 1..])
} else {
(filename, default_ext)
};
let stem = if stem.is_empty() { default_stem } else { stem };
if dir.is_empty() {
format!("gs://{}/{}_{}.{}", bucket, stem, index, ext)
} else {
format!("gs://{}/{}/{}_{}.{}", bucket, dir, stem, index, ext)
}
} else {
format!("{}/{}_{}.{}", uri, default_stem, index, default_ext)
}
} else {
let path = Path::new(uri);
let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or(default_stem);
let ext = path.extension().and_then(|s| s.to_str()).unwrap_or(default_ext);
let parent = path.parent().and_then(|p| p.to_str()).unwrap_or("");
if parent.is_empty() {
format!("{}_{}.{}", stem, index, ext)
} else {
format!("{}/{}_{}.{}", parent, stem, index, ext)
}
}
}
async fn save_to_file(
&self,
images: Vec<GeneratedImage>,
output_file: &str,
) -> Result<ImageGenerateResult, Error> {
let mut paths = Vec::new();
for (i, image) in images.iter().enumerate() {
let data = BASE64.decode(&image.data).map_err(|e| {
Error::validation(format!("Invalid base64 data: {}", e))
})?;
let path = if images.len() == 1 {
output_file.to_string()
} else {
let p = Path::new(output_file);
let stem = p.file_stem().and_then(|s| s.to_str()).unwrap_or("image");
let ext = p.extension().and_then(|s| s.to_str()).unwrap_or("png");
let parent = p.parent().and_then(|p| p.to_str()).unwrap_or("");
if parent.is_empty() {
format!("{}_{}.{}", stem, i, ext)
} else {
format!("{}/{}_{}.{}", parent, stem, i, ext)
}
};
if let Some(parent) = Path::new(&path).parent() {
if !parent.as_os_str().is_empty() {
tokio::fs::create_dir_all(parent).await?;
}
}
tokio::fs::write(&path, &data).await?;
paths.push(path);
}
info!(count = paths.len(), "Saved images to local files");
Ok(ImageGenerateResult::LocalFiles(paths))
}
#[instrument(level = "info", name = "upscale_image", skip(self, params), fields(upscale_factor = %params.upscale_factor))]
pub async fn upscale_image(&self, params: ImageUpscaleParams) -> Result<ImageUpscaleResult, Error> {
params.validate().map_err(|errors| {
let messages: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
Error::validation(messages.join("; "))
})?;
info!(upscale_factor = %params.upscale_factor, "Upscaling image with Imagen Upscale API");
let image_data = self.resolve_image_input(¶ms.image).await?;
let request = UpscaleRequest {
instances: vec![UpscaleInstance {
image: UpscaleImageInput {
bytes_base64_encoded: image_data,
},
}],
parameters: UpscaleParameters {
upscale_factor: params.upscale_factor.clone(),
output_mime_type: "image/png".to_string(),
},
};
let token = self.auth.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
let endpoint = self.get_upscale_endpoint();
debug!(endpoint = %endpoint, "Calling Imagen Upscale API");
let response = self.http
.post(&endpoint)
.header("Authorization", format!("Bearer {}", token))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| Error::api(&endpoint, 0, format!("Request failed: {}", e)))?;
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
return Err(Error::api(&endpoint, status.as_u16(), body));
}
let api_response: UpscaleResponse = response.json().await.map_err(|e| {
Error::api(&endpoint, status.as_u16(), format!("Failed to parse response: {}", e))
})?;
let prediction = api_response.predictions.into_iter().next()
.ok_or_else(|| Error::api(&endpoint, 200, "No image returned from API"))?;
let image_data = prediction.bytes_base64_encoded
.ok_or_else(|| Error::api(&endpoint, 200, "No image data in response"))?;
let image = GeneratedImage {
data: image_data,
mime_type: prediction.mime_type.unwrap_or_else(|| "image/png".to_string()),
};
info!("Received upscaled image from API");
self.handle_upscale_output(image, ¶ms).await
}
pub fn get_upscale_endpoint(&self) -> String {
format!(
"https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
self.config.location,
self.config.project_id,
self.config.location,
UPSCALE_MODEL
)
}
async fn resolve_image_input(&self, image: &str) -> Result<String, Error> {
if image.starts_with("gs://") {
let uri = GcsUri::parse(image)?;
let data = self.gcs.download(&uri).await?;
return Ok(BASE64.encode(&data));
}
let looks_like_path = image.starts_with('/')
|| image.starts_with("./")
|| image.starts_with("../")
|| image.starts_with("~/")
|| (image.len() < 500 && image.contains('/'));
if looks_like_path {
let path = Path::new(image);
if !path.exists() {
return Err(Error::validation(format!("Image file not found: {}", image)));
}
let data = tokio::fs::read(path).await?;
return Ok(BASE64.encode(&data));
}
if image.len() > 100 {
if BASE64.decode(image).is_ok() {
return Ok(image.to_string());
}
}
let path = Path::new(image);
if path.exists() {
let data = tokio::fs::read(path).await?;
return Ok(BASE64.encode(&data));
}
if image.len() > 100 {
return Ok(image.to_string());
}
Err(Error::validation(format!(
"Image input is not a valid file path, GCS URI, or base64 data"
)))
}
async fn handle_upscale_output(
&self,
image: GeneratedImage,
params: &ImageUpscaleParams,
) -> Result<ImageUpscaleResult, Error> {
if let Some(output_uri) = ¶ms.output_uri {
let data = BASE64.decode(&image.data).map_err(|e| {
Error::validation(format!("Invalid base64 data: {}", e))
})?;
let gcs_uri = GcsUri::parse(output_uri)?;
self.gcs.upload(&gcs_uri, &data, &image.mime_type).await?;
info!(uri = %output_uri, "Uploaded upscaled image to storage");
return Ok(ImageUpscaleResult::StorageUri(output_uri.clone()));
}
if let Some(output_file) = ¶ms.output_file {
let data = BASE64.decode(&image.data).map_err(|e| {
Error::validation(format!("Invalid base64 data: {}", e))
})?;
if let Some(parent) = Path::new(output_file).parent() {
if !parent.as_os_str().is_empty() {
tokio::fs::create_dir_all(parent).await?;
}
}
tokio::fs::write(output_file, &data).await?;
info!(path = %output_file, "Saved upscaled image to local file");
return Ok(ImageUpscaleResult::LocalFile(output_file.clone()));
}
Ok(ImageUpscaleResult::Base64(image))
}
}
/// JSON body POSTed to the Imagen `:predict` endpoint.
#[derive(Debug, Serialize)]
pub struct ImagenRequest {
    /// Prompt instances; the handler sends exactly one.
    pub instances: Vec<ImagenInstance>,
    /// Generation settings applied to the request.
    pub parameters: ImagenParameters,
}
/// A single prompt instance. Field names serialize in camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ImagenInstance {
    /// Text prompt.
    pub prompt: String,
    /// Serialized as `negativePrompt`; omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub negative_prompt: Option<String>,
}
/// Generation settings. Field names serialize in camelCase (`sampleCount`, `aspectRatio`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ImagenParameters {
    /// Number of images requested.
    pub sample_count: u8,
    /// Requested aspect ratio, e.g. `"16:9"`.
    pub aspect_ratio: String,
    /// Optional seed; omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
}
/// JSON body returned by the Imagen `:predict` endpoint.
#[derive(Debug, Deserialize)]
pub struct ImagenResponse {
    /// One entry per returned image.
    pub predictions: Vec<ImagenPrediction>,
}
/// One prediction from the Imagen API; keys are camelCase on the wire.
///
/// Both fields are optional. `generate_image` skips predictions that have no image bytes.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImagenPrediction {
    /// Base64-encoded image bytes (`bytesBase64Encoded`).
    pub bytes_base64_encoded: Option<String>,
    /// Image MIME type (`mimeType`).
    pub mime_type: Option<String>,
}
/// JSON body POSTed to the upscale model's `:predict` endpoint.
#[derive(Debug, Serialize)]
pub struct UpscaleRequest {
    /// Image instances; the handler sends exactly one.
    pub instances: Vec<UpscaleInstance>,
    /// Upscale settings.
    pub parameters: UpscaleParameters,
}
/// A single upscale instance wrapping the source image.
#[derive(Debug, Serialize)]
pub struct UpscaleInstance {
    /// The image to upscale.
    pub image: UpscaleImageInput,
}
/// Source image payload; serialized as `{"bytesBase64Encoded": ...}`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UpscaleImageInput {
    /// Base64-encoded source image bytes.
    pub bytes_base64_encoded: String,
}
/// Upscale settings; serialized as `upscaleFactor` and `outputMimeType`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct UpscaleParameters {
    /// Factor such as `"x2"` or `"x4"`.
    pub upscale_factor: String,
    /// Requested output MIME type.
    pub output_mime_type: String,
}
/// JSON body returned by the upscale `:predict` endpoint.
#[derive(Debug, Deserialize)]
pub struct UpscaleResponse {
    /// Returned predictions; `upscale_image` uses only the first.
    pub predictions: Vec<UpscalePrediction>,
}
/// One upscale prediction; keys are camelCase on the wire.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UpscalePrediction {
    /// Base64-encoded upscaled image bytes (`bytesBase64Encoded`).
    pub bytes_base64_encoded: Option<String>,
    /// Image MIME type (`mimeType`).
    pub mime_type: Option<String>,
}
/// An image returned by the API, kept in its base64-encoded form.
#[derive(Debug, Clone)]
pub struct GeneratedImage {
    /// Base64-encoded image bytes.
    pub data: String,
    /// MIME type of the image, e.g. `"image/png"`.
    pub mime_type: String,
}
/// Where the generated images ended up.
#[derive(Debug)]
pub enum ImageGenerateResult {
    /// Returned inline (no output destination was requested).
    Base64(Vec<GeneratedImage>),
    /// Written to these local file paths.
    LocalFiles(Vec<String>),
    /// Uploaded to these Cloud Storage URIs.
    StorageUris(Vec<String>),
}
/// Where the upscaled image ended up.
#[derive(Debug)]
pub enum ImageUpscaleResult {
    /// Returned inline (no output destination was requested).
    Base64(GeneratedImage),
    /// Written to this local file path.
    LocalFile(String),
    /// Uploaded to this Cloud Storage URI.
    StorageUri(String),
}
// Unit tests for parameter defaults, validation, serialization, and URI index suffixing.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_params() {
        let params: ImageGenerateParams = serde_json::from_str(r#"{"prompt": "a cat"}"#).unwrap();
        assert_eq!(params.model, DEFAULT_MODEL);
        assert_eq!(params.aspect_ratio, "1:1");
        assert_eq!(params.number_of_images, 1);
        assert!(params.negative_prompt.is_none());
        assert!(params.seed.is_none());
        assert!(params.output_file.is_none());
        assert!(params.output_uri.is_none());
    }

    #[test]
    fn test_valid_params() {
        let params = ImageGenerateParams {
            prompt: "A beautiful sunset over mountains".to_string(),
            negative_prompt: Some("blurry, low quality".to_string()),
            model: "imagen-4".to_string(),
            aspect_ratio: "16:9".to_string(),
            number_of_images: 2,
            seed: Some(42),
            output_file: None,
            output_uri: None,
        };
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_invalid_number_of_images_zero() {
        let params = ImageGenerateParams {
            prompt: "A cat".to_string(),
            negative_prompt: None,
            model: DEFAULT_MODEL.to_string(),
            aspect_ratio: "1:1".to_string(),
            number_of_images: 0,
            seed: None,
            output_file: None,
            output_uri: None,
        };
        let result = params.validate();
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert!(errors.iter().any(|e| e.field == "number_of_images"));
    }

    #[test]
    fn test_invalid_number_of_images_too_high() {
        let params = ImageGenerateParams {
            prompt: "A cat".to_string(),
            negative_prompt: None,
            model: DEFAULT_MODEL.to_string(),
            aspect_ratio: "1:1".to_string(),
            number_of_images: 5,
            seed: None,
            output_file: None,
            output_uri: None,
        };
        let result = params.validate();
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert!(errors.iter().any(|e| e.field == "number_of_images"));
    }

    #[test]
    fn test_invalid_aspect_ratio() {
        let params = ImageGenerateParams {
            prompt: "A cat".to_string(),
            negative_prompt: None,
            model: DEFAULT_MODEL.to_string(),
            aspect_ratio: "2:1".to_string(),
            number_of_images: 1,
            seed: None,
            output_file: None,
            output_uri: None,
        };
        let result = params.validate();
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert!(errors.iter().any(|e| e.field == "aspect_ratio"));
    }

    #[test]
    fn test_invalid_model() {
        let params = ImageGenerateParams {
            prompt: "A cat".to_string(),
            negative_prompt: None,
            model: "unknown-model".to_string(),
            aspect_ratio: "1:1".to_string(),
            number_of_images: 1,
            seed: None,
            output_file: None,
            output_uri: None,
        };
        let result = params.validate();
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert!(errors.iter().any(|e| e.field == "model"));
    }

    #[test]
    fn test_empty_prompt() {
        let params = ImageGenerateParams {
            prompt: " ".to_string(),
            negative_prompt: None,
            model: DEFAULT_MODEL.to_string(),
            aspect_ratio: "1:1".to_string(),
            number_of_images: 1,
            seed: None,
            output_file: None,
            output_uri: None,
        };
        let result = params.validate();
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert!(errors.iter().any(|e| e.field == "prompt"));
    }

    #[test]
    fn test_prompt_too_long_imagen3() {
        let long_prompt = "a".repeat(500);
        let params = ImageGenerateParams {
            prompt: long_prompt,
            negative_prompt: None,
            model: "imagen-3".to_string(),
            aspect_ratio: "1:1".to_string(),
            number_of_images: 1,
            seed: None,
            output_file: None,
            output_uri: None,
        };
        let result = params.validate();
        assert!(result.is_err());
        let errors = result.unwrap_err();
        assert!(errors.iter().any(|e| e.field == "prompt" && e.message.contains("exceeds")));
    }

    #[test]
    fn test_prompt_ok_imagen4() {
        let long_prompt = "a".repeat(500);
        let params = ImageGenerateParams {
            prompt: long_prompt,
            negative_prompt: None,
            model: "imagen-4".to_string(),
            aspect_ratio: "1:1".to_string(),
            number_of_images: 1,
            seed: None,
            output_file: None,
            output_uri: None,
        };
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_all_valid_aspect_ratios() {
        for ratio in VALID_ASPECT_RATIOS {
            let params = ImageGenerateParams {
                prompt: "A cat".to_string(),
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: ratio.to_string(),
                number_of_images: 1,
                seed: None,
                output_file: None,
                output_uri: None,
            };
            assert!(params.validate().is_ok(), "Aspect ratio {} should be valid", ratio);
        }
    }

    #[test]
    fn test_all_valid_number_of_images() {
        for n in MIN_NUMBER_OF_IMAGES..=MAX_NUMBER_OF_IMAGES {
            let params = ImageGenerateParams {
                prompt: "A cat".to_string(),
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: "1:1".to_string(),
                number_of_images: n,
                seed: None,
                output_file: None,
                output_uri: None,
            };
            assert!(params.validate().is_ok(), "number_of_images {} should be valid", n);
        }
    }

    #[test]
    fn test_get_model() {
        let params = ImageGenerateParams {
            prompt: "A cat".to_string(),
            negative_prompt: None,
            model: "imagen-4".to_string(),
            aspect_ratio: "1:1".to_string(),
            number_of_images: 1,
            seed: None,
            output_file: None,
            output_uri: None,
        };
        let model = params.get_model();
        assert!(model.is_some());
        assert_eq!(model.unwrap().id, "imagen-4.0-generate-preview-06-06");
    }

    #[test]
    fn test_serialization_roundtrip() {
        let params = ImageGenerateParams {
            prompt: "A cat".to_string(),
            negative_prompt: Some("blurry".to_string()),
            model: "imagen-4".to_string(),
            aspect_ratio: "16:9".to_string(),
            number_of_images: 2,
            seed: Some(42),
            output_file: Some("/tmp/output.png".to_string()),
            output_uri: None,
        };
        // Serialize by reference so `params` stays available for the comparisons below.
        let json = serde_json::to_string(&params).unwrap();
        let deserialized: ImageGenerateParams = serde_json::from_str(&json).unwrap();
        assert_eq!(params.prompt, deserialized.prompt);
        assert_eq!(params.negative_prompt, deserialized.negative_prompt);
        assert_eq!(params.model, deserialized.model);
        assert_eq!(params.aspect_ratio, deserialized.aspect_ratio);
        assert_eq!(params.number_of_images, deserialized.number_of_images);
        assert_eq!(params.seed, deserialized.seed);
        assert_eq!(params.output_file, deserialized.output_file);
    }

    #[test]
    fn test_add_index_suffix_to_gcs_uri_simple() {
        let uri = "gs://bucket/output.png";
        let result = ImageHandler::add_index_suffix_to_uri(uri, 0, "image", "png");
        assert_eq!(result, "gs://bucket/output_0.png");
    }

    #[test]
    fn test_add_index_suffix_to_gcs_uri_with_path() {
        let uri = "gs://bucket/path/to/output.png";
        let result = ImageHandler::add_index_suffix_to_uri(uri, 1, "image", "png");
        assert_eq!(result, "gs://bucket/path/to/output_1.png");
    }

    #[test]
    fn test_add_index_suffix_to_gcs_uri_no_extension() {
        let uri = "gs://bucket/output";
        let result = ImageHandler::add_index_suffix_to_uri(uri, 2, "image", "png");
        assert_eq!(result, "gs://bucket/output_2.png");
    }

    #[test]
    fn test_add_index_suffix_to_local_path() {
        let path = "/tmp/output.png";
        let result = ImageHandler::add_index_suffix_to_uri(path, 0, "image", "png");
        assert_eq!(result, "/tmp/output_0.png");
    }

    #[test]
    fn test_add_index_suffix_to_local_path_no_dir() {
        let path = "output.png";
        let result = ImageHandler::add_index_suffix_to_uri(path, 1, "image", "png");
        assert_eq!(result, "output_1.png");
    }
}
// Property-based tests for `ImageGenerateParams::validate` over counts, aspect ratios
// and prompts.
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Any count inside the inclusive valid range.
    fn valid_number_of_images_strategy() -> impl Strategy<Value = u8> {
        MIN_NUMBER_OF_IMAGES..=MAX_NUMBER_OF_IMAGES
    }

    /// Zero, or any count above the maximum.
    fn invalid_number_of_images_strategy() -> impl Strategy<Value = u8> {
        prop_oneof![
            Just(0u8),
            (MAX_NUMBER_OF_IMAGES + 1)..=u8::MAX,
        ]
    }

    /// One entry of `VALID_ASPECT_RATIOS`. Drawing from the constant itself keeps this
    /// strategy from drifting out of sync with the production list.
    fn valid_aspect_ratio_strategy() -> impl Strategy<Value = &'static str> {
        proptest::sample::select(VALID_ASPECT_RATIOS)
    }

    /// Fixed known-bad ratios, plus arbitrary `N:M` strings not in `VALID_ASPECT_RATIOS`.
    fn invalid_aspect_ratio_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            Just("2:1".to_string()),
            Just("1:2".to_string()),
            Just("5:4".to_string()),
            Just("invalid".to_string()),
            Just("".to_string()),
            Just("16:10".to_string()),
            Just("21:9".to_string()),
            "[0-9]+:[0-9]+".prop_filter("Must not be a valid ratio", |s| {
                !VALID_ASPECT_RATIOS.contains(&s.as_str())
            }),
        ]
    }

    /// Non-empty alphanumeric prompts of at most 100 characters, trimmed.
    fn valid_prompt_strategy() -> impl Strategy<Value = String> {
        "[a-zA-Z0-9 ]{1,100}".prop_map(|s| s.trim().to_string())
            .prop_filter("Must not be empty", |s| !s.trim().is_empty())
    }

    proptest! {
        #[test]
        fn valid_number_of_images_passes_validation(
            num in valid_number_of_images_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = ImageGenerateParams {
                prompt,
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: "1:1".to_string(),
                number_of_images: num,
                seed: None,
                output_file: None,
                output_uri: None,
            };
            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "number_of_images {} should be valid, but got errors: {:?}",
                num,
                result.err()
            );
        }

        #[test]
        fn invalid_number_of_images_fails_validation(
            num in invalid_number_of_images_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = ImageGenerateParams {
                prompt,
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: "1:1".to_string(),
                number_of_images: num,
                seed: None,
                output_file: None,
                output_uri: None,
            };
            let result = params.validate();
            prop_assert!(
                result.is_err(),
                "number_of_images {} should be invalid",
                num
            );
            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "number_of_images"),
                "Should have a number_of_images validation error for value {}",
                num
            );
        }

        #[test]
        fn valid_aspect_ratio_passes_validation(
            ratio in valid_aspect_ratio_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = ImageGenerateParams {
                prompt,
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: ratio.to_string(),
                number_of_images: 1,
                seed: None,
                output_file: None,
                output_uri: None,
            };
            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "aspect_ratio '{}' should be valid, but got errors: {:?}",
                ratio,
                result.err()
            );
        }

        #[test]
        fn invalid_aspect_ratio_fails_validation(
            ratio in invalid_aspect_ratio_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = ImageGenerateParams {
                prompt,
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: ratio.clone(),
                number_of_images: 1,
                seed: None,
                output_file: None,
                output_uri: None,
            };
            let result = params.validate();
            prop_assert!(
                result.is_err(),
                "aspect_ratio '{}' should be invalid",
                ratio
            );
            let errors = result.unwrap_err();
            prop_assert!(
                errors.iter().any(|e| e.field == "aspect_ratio"),
                "Should have an aspect_ratio validation error for value '{}'",
                ratio
            );
            let aspect_error = errors.iter().find(|e| e.field == "aspect_ratio").unwrap();
            prop_assert!(
                aspect_error.message.contains("Valid options"),
                "Error message should list valid options: {}",
                aspect_error.message
            );
        }

        #[test]
        fn valid_params_combination_passes(
            num in valid_number_of_images_strategy(),
            ratio in valid_aspect_ratio_strategy(),
            prompt in valid_prompt_strategy(),
        ) {
            let params = ImageGenerateParams {
                prompt,
                negative_prompt: None,
                model: DEFAULT_MODEL.to_string(),
                aspect_ratio: ratio.to_string(),
                number_of_images: num,
                seed: None,
                output_file: None,
                output_uri: None,
            };
            let result = params.validate();
            prop_assert!(
                result.is_ok(),
                "Valid params (num={}, ratio='{}') should pass, but got: {:?}",
                num,
                ratio,
                result.err()
            );
        }
    }
}
// API wire-format tests: request serialization, response deserialization, and the
// result and error types. None of these tests construct an ImageHandler or make
// network requests.
#[cfg(test)]
mod api_tests {
use super::*;
// Fields that are set must appear under their camelCase wire names.
#[test]
fn test_imagen_request_serialization() {
let request = ImagenRequest {
instances: vec![ImagenInstance {
prompt: "A beautiful sunset".to_string(),
negative_prompt: Some("blurry".to_string()),
}],
parameters: ImagenParameters {
sample_count: 2,
aspect_ratio: "16:9".to_string(),
seed: Some(42),
},
};
let json = serde_json::to_value(&request).unwrap();
assert!(json["instances"].is_array());
assert_eq!(json["instances"][0]["prompt"], "A beautiful sunset");
assert_eq!(json["instances"][0]["negativePrompt"], "blurry");
assert_eq!(json["parameters"]["sampleCount"], 2);
assert_eq!(json["parameters"]["aspectRatio"], "16:9");
assert_eq!(json["parameters"]["seed"], 42);
}
// `None` optionals must be omitted from the JSON entirely, not sent as null.
#[test]
fn test_imagen_request_serialization_minimal() {
let request = ImagenRequest {
instances: vec![ImagenInstance {
prompt: "A cat".to_string(),
negative_prompt: None,
}],
parameters: ImagenParameters {
sample_count: 1,
aspect_ratio: "1:1".to_string(),
seed: None,
},
};
let json = serde_json::to_value(&request).unwrap();
assert!(json["instances"][0].get("negativePrompt").is_none());
assert!(json["parameters"].get("seed").is_none());
}
// A 1x1 PNG payload deserializes from the camelCase response keys.
#[test]
fn test_imagen_response_deserialization() {
let json = r#"{
"predictions": [
{
"bytesBase64Encoded": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
"mimeType": "image/png"
}
]
}"#;
let response: ImagenResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.predictions.len(), 1);
assert!(response.predictions[0].bytes_base64_encoded.is_some());
assert_eq!(response.predictions[0].mime_type, Some("image/png".to_string()));
}
// Prediction order is preserved. The payloads are placeholders; deserialization does
// not check that they are valid base64.
#[test]
fn test_imagen_response_multiple_predictions() {
let json = r#"{
"predictions": [
{
"bytesBase64Encoded": "base64data1",
"mimeType": "image/png"
},
{
"bytesBase64Encoded": "base64data2",
"mimeType": "image/png"
}
]
}"#;
let response: ImagenResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.predictions.len(), 2);
assert_eq!(response.predictions[0].bytes_base64_encoded, Some("base64data1".to_string()));
assert_eq!(response.predictions[1].bytes_base64_encoded, Some("base64data2".to_string()));
}
// An empty predictions array deserializes successfully; the handler decides what it means.
#[test]
fn test_imagen_response_empty_predictions() {
let json = r#"{"predictions": []}"#;
let response: ImagenResponse = serde_json::from_str(json).unwrap();
assert!(response.predictions.is_empty());
}
// A missing `bytesBase64Encoded` key becomes `None` instead of failing deserialization.
#[test]
fn test_imagen_response_no_image_data() {
let json = r#"{
"predictions": [
{
"mimeType": "image/png"
}
]
}"#;
let response: ImagenResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.predictions.len(), 1);
assert!(response.predictions[0].bytes_base64_encoded.is_none());
}
// NOTE(review): this test builds the expected URL itself and never calls
// `ImageHandler::get_endpoint`, so it does not exercise production code. Creating a
// handler requires GcsClient/AuthProvider instances (see `with_deps`). Consider
// extracting the URL formatting into a free function so it can be tested directly.
#[test]
fn test_get_endpoint() {
let config = Config {
project_id: "my-project".to_string(),
location: "us-central1".to_string(),
gcs_bucket: None,
port: 8080,
..Default::default()
};
let expected_url = format!(
"https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
config.location,
config.project_id,
config.location,
"imagen-4.0-generate-preview-05-20"
);
assert!(expected_url.contains("us-central1-aiplatform.googleapis.com"));
assert!(expected_url.contains("my-project"));
assert!(expected_url.contains("imagen-4.0-generate-preview-05-20"));
assert!(expected_url.ends_with(":predict"));
}
// Plain field access on the result carrier type.
#[test]
fn test_generated_image() {
let image = GeneratedImage {
data: "base64encodeddata".to_string(),
mime_type: "image/png".to_string(),
};
assert_eq!(image.data, "base64encodeddata");
assert_eq!(image.mime_type, "image/png");
}
// Each ImageGenerateResult variant round-trips its payload through a match.
#[test]
fn test_image_generate_result_base64() {
let images = vec![
GeneratedImage {
data: "data1".to_string(),
mime_type: "image/png".to_string(),
},
GeneratedImage {
data: "data2".to_string(),
mime_type: "image/jpeg".to_string(),
},
];
let result = ImageGenerateResult::Base64(images);
match result {
ImageGenerateResult::Base64(imgs) => {
assert_eq!(imgs.len(), 2);
assert_eq!(imgs[0].data, "data1");
assert_eq!(imgs[1].mime_type, "image/jpeg");
}
_ => panic!("Expected Base64 variant"),
}
}
#[test]
fn test_image_generate_result_local_files() {
let paths = vec!["/tmp/image1.png".to_string(), "/tmp/image2.png".to_string()];
let result = ImageGenerateResult::LocalFiles(paths);
match result {
ImageGenerateResult::LocalFiles(p) => {
assert_eq!(p.len(), 2);
assert!(p[0].contains("image1"));
}
_ => panic!("Expected LocalFiles variant"),
}
}
#[test]
fn test_image_generate_result_storage_uris() {
let uris = vec![
"gs://bucket/image1.png".to_string(),
"gs://bucket/image2.png".to_string(),
];
let result = ImageGenerateResult::StorageUris(uris);
match result {
ImageGenerateResult::StorageUris(u) => {
assert_eq!(u.len(), 2);
assert!(u[0].starts_with("gs://"));
}
_ => panic!("Expected StorageUris variant"),
}
}
// Display renders as "<field>: <message>".
#[test]
fn test_validation_error_display() {
let error = ValidationError {
field: "prompt".to_string(),
message: "cannot be empty".to_string(),
};
let display = format!("{}", error);
assert_eq!(display, "prompt: cannot be empty");
}
// validate() accumulates errors across fields instead of stopping at the first.
#[test]
fn test_validation_multiple_errors() {
let params = ImageGenerateParams {
prompt: " ".to_string(), negative_prompt: None,
model: "unknown-model".to_string(), aspect_ratio: "invalid".to_string(), number_of_images: 10, seed: None,
output_file: None,
output_uri: None,
};
let result = params.validate();
assert!(result.is_err());
let errors = result.unwrap_err();
assert!(errors.len() >= 3, "Expected at least 3 validation errors, got {}", errors.len());
let fields: Vec<&str> = errors.iter().map(|e| e.field.as_str()).collect();
assert!(fields.contains(&"prompt"));
assert!(fields.contains(&"model"));
assert!(fields.contains(&"number_of_images"));
}
}