use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Level of detail requested for an image attached to a vision request.
///
/// Serialized by serde as a lowercase string (`"auto"`, `"low"`, `"high"`).
/// NOTE(review): how each level is interpreted is decided by whatever consumes
/// this value (presumably the vision provider) — confirm against the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ImageDetail {
/// Default variant, returned by `ImageDetail::default()`.
#[default]
Auto,
/// Low detail.
Low,
/// High detail.
High,
}
/// A single image attached to a vision request, distinguished by how it is supplied.
///
/// Serialized by serde as an internally tagged object: the `type` field holds the
/// snake_case variant name (`image_url`, `image_base64`, `image_file`) and the
/// variant's fields sit next to it. In every variant `detail` may be omitted when
/// deserializing, in which case it falls back to `ImageDetail::default()`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum VisionContent {
/// Image referenced by URL.
ImageUrl {
/// Location of the image; `validate_format` checks it for a supported image extension.
url: String,
/// Requested detail level; optional on input.
#[serde(default)]
detail: ImageDetail,
},
/// Image supplied inline.
ImageBase64 {
/// Image payload — presumably base64-encoded, going by the variant name;
/// `validate_format` does not inspect it.
data: String,
/// MIME type of the payload (e.g. `image/png`); checked by `validate_format`.
media_type: String,
/// Requested detail level; optional on input.
#[serde(default)]
detail: ImageDetail,
},
/// Image referenced by a filesystem path.
ImageFile {
/// Path to the image; `validate_format` looks only at its extension.
path: PathBuf,
/// Requested detail level; optional on input.
#[serde(default)]
detail: ImageDetail,
},
}
impl VisionContent {
    /// Image file extensions (lowercase, without the leading dot) accepted for
    /// URL-based and file-based images.
    const SUPPORTED_EXTENSIONS: [&'static str; 5] = ["png", "jpg", "jpeg", "gif", "webp"];

    /// MIME types accepted for inline (base64) images. Matched exactly.
    const SUPPORTED_MEDIA_TYPES: [&'static str; 4] =
        ["image/png", "image/jpeg", "image/gif", "image/webp"];

    /// Checks that this image is declared in a supported format.
    ///
    /// Only the declared format is inspected — the image data is never read,
    /// decoded or fetched:
    ///
    /// * `ImageUrl`: the URL must end with a supported extension (ASCII case
    ///   is ignored). A trailing `?query` and/or `#fragment` is ignored, so
    ///   `https://host/a.png?w=100` is accepted.
    /// * `ImageBase64`: `media_type` must be exactly one of the supported MIME
    ///   types.
    /// * `ImageFile`: the path must have a supported extension (ASCII case is
    ///   ignored).
    ///
    /// # Errors
    ///
    /// Returns [`VisionError::UnsupportedFormat`] when the format is not
    /// supported, or when a file path has no (UTF-8) extension.
    pub fn validate_format(&self) -> Result<(), VisionError> {
        match self {
            VisionContent::ImageUrl { url, .. } => {
                if Self::url_has_supported_extension(url) {
                    Ok(())
                } else {
                    Err(VisionError::UnsupportedFormat(
                        "URL must end with .png, .jpg, .jpeg, .gif, or .webp".to_string(),
                    ))
                }
            }
            VisionContent::ImageBase64 { media_type, .. } => {
                let supported = Self::SUPPORTED_MEDIA_TYPES
                    .iter()
                    .any(|candidate| *candidate == media_type.as_str());
                if supported {
                    Ok(())
                } else {
                    Err(VisionError::UnsupportedFormat(format!(
                        "Unsupported media type: {}. Supported: image/png, image/jpeg, image/gif, image/webp",
                        media_type
                    )))
                }
            }
            VisionContent::ImageFile { path, .. } => {
                // A non-UTF-8 extension is reported the same way as a missing one.
                match path.extension().and_then(|e| e.to_str()) {
                    Some(ext) if Self::is_supported_extension(ext) => Ok(()),
                    Some(ext) => Err(VisionError::UnsupportedFormat(format!(
                        "Unsupported file extension: {}. Supported: png, jpg, jpeg, gif, webp",
                        ext
                    ))),
                    None => Err(VisionError::UnsupportedFormat(
                        "File has no extension".to_string(),
                    )),
                }
            }
        }
    }

    /// Returns `true` when `ext` (no leading dot) is a supported image
    /// extension, ignoring ASCII case.
    fn is_supported_extension(ext: &str) -> bool {
        Self::SUPPORTED_EXTENSIONS
            .iter()
            .any(|supported| ext.eq_ignore_ascii_case(supported))
    }

    /// Returns `true` when `text` ends with `.` followed by a supported
    /// extension, ignoring ASCII case.
    fn has_supported_suffix(text: &str) -> bool {
        // No supported extension contains a dot, so "ends with .<ext>" is the
        // same as "everything after the last dot is <ext>".
        matches!(
            text.rsplit_once('.'),
            Some((_, ext)) if Self::is_supported_extension(ext)
        )
    }

    /// Returns `true` when `url` points at a supported image extension.
    fn url_has_supported_extension(url: &str) -> bool {
        // Everything from the first `?` or `#` on is query/fragment, not part
        // of the resource path, so it must not hide the extension.
        let without_query_or_fragment = url
            .split(|c: char| matches!(c, '?' | '#'))
            .next()
            .unwrap_or(url);
        // The raw URL is still checked so that every URL accepted before this
        // query/fragment handling was added keeps being accepted.
        Self::has_supported_suffix(without_query_or_fragment) || Self::has_supported_suffix(url)
    }
}
/// A text prompt together with the images it refers to.
///
/// `VisionRequest::new` rejects empty prompts, empty image lists and images in
/// an unsupported format. Because the fields are public and the type derives
/// `Deserialize`, a value can also come into existence without passing through
/// `new`; `validate` re-runs the same checks on an existing value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct VisionRequest {
/// The text prompt.
pub text: String,
/// The images attached to the prompt.
pub images: Vec<VisionContent>,
}
impl VisionRequest {
pub fn new(text: String, images: Vec<VisionContent>) -> Result<Self, VisionError> {
if text.is_empty() {
return Err(VisionError::InvalidRequest(
"Text prompt cannot be empty".to_string(),
));
}
if images.is_empty() {
return Err(VisionError::InvalidRequest(
"At least one image is required".to_string(),
));
}
for image in &images {
image.validate_format()?;
}
Ok(Self { text, images })
}
pub fn validate(&self) -> Result<(), VisionError> {
if self.text.is_empty() {
return Err(VisionError::InvalidRequest(
"Text prompt cannot be empty".to_string(),
));
}
if self.images.is_empty() {
return Err(VisionError::InvalidRequest(
"At least one image is required".to_string(),
));
}
for image in &self.images {
image.validate_format()?;
}
Ok(())
}
}
/// Errors reported by the vision types in this file.
///
/// `Display` and `std::error::Error` are generated by `thiserror` from the
/// `#[error(...)]` attributes below. Within this file only `UnsupportedFormat`
/// and `InvalidRequest` are produced (by `VisionContent::validate_format` and
/// `VisionRequest::new` / `validate`); the remaining variants are presumably
/// raised by code outside this file — confirm against callers.
#[derive(Debug, thiserror::Error)]
pub enum VisionError {
/// The image's declared format (URL/file extension or media type) is not supported.
#[error("Unsupported image format: {0}")]
UnsupportedFormat(String),
/// An image exceeds a size limit; `size` is reported in bytes and `max` is
/// presumably in bytes as well.
#[error("Image file too large: {size} bytes (max: {max})")]
FileTooLarge { size: usize, max: usize },
/// The image data is invalid; carries a description.
#[error("Invalid image data: {0}")]
InvalidImage(String),
/// The selected model does not support vision; carries a description.
#[error("Model does not support vision: {0}")]
ModelNotSupported(String),
/// A network failure; carries a description.
#[error("Network error: {0}")]
NetworkError(String),
/// An encryption-related failure; carries a description.
#[error("Encryption error: {0}")]
EncryptionError(String),
/// The request itself is malformed (e.g. empty prompt or no images).
#[error("Invalid vision request: {0}")]
InvalidRequest(String),
/// Authentication failed; carries a description.
#[error("Authentication error: {0}")]
AuthenticationError(String),
/// A rate limit was exceeded; carries a description.
#[error("Rate limit exceeded: {0}")]
RateLimitExceeded(String),
/// A provider-side failure; carries a description.
#[error("Provider error: {0}")]
ProviderError(String),
/// The request timed out; the value is rendered as a number of seconds.
#[error("Request timeout after {0} seconds")]
Timeout(u64),
/// The named vision provider is not supported.
#[error("Unsupported vision provider: {0}")]
UnsupportedProvider(String),
/// Retrying was abandoned; the value is rendered as the number of attempts.
#[error("Maximum retry attempts exceeded: {0} attempts")]
MaxRetriesExceeded(u32),
/// An I/O failure. `#[from]` generates `From<std::io::Error>`, so `?` converts
/// `std::io::Error` into this variant automatically.
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a URL-backed image.
    fn url_image(url: &str, detail: ImageDetail) -> VisionContent {
        VisionContent::ImageUrl {
            url: url.to_string(),
            detail,
        }
    }

    /// Shorthand for a file-backed image.
    fn file_image(path: &str, detail: ImageDetail) -> VisionContent {
        VisionContent::ImageFile {
            path: PathBuf::from(path),
            detail,
        }
    }

    /// Shorthand for an inline image with the default detail level.
    fn base64_image(data: &str, media_type: &str) -> VisionContent {
        VisionContent::ImageBase64 {
            data: data.to_string(),
            media_type: media_type.to_string(),
            detail: ImageDetail::Auto,
        }
    }

    /// Asserts that `outcome` is an `InvalidRequest` error mentioning `expected`.
    fn assert_invalid_request(outcome: Result<VisionRequest, VisionError>, expected: &str) {
        assert!(outcome.is_err());
        match outcome {
            Err(VisionError::InvalidRequest(msg)) => {
                assert!(msg.contains(expected));
            }
            _ => panic!("Expected InvalidRequest error"),
        }
    }

    /// Variants compare equal to themselves, `Auto` is the default, and the
    /// JSON form is the lowercase variant name.
    #[test]
    fn test_image_detail_enum() {
        let (auto, low, high) = (ImageDetail::Auto, ImageDetail::Low, ImageDetail::High);
        assert_eq!(auto, ImageDetail::Auto);
        assert_eq!(low, ImageDetail::Low);
        assert_eq!(high, ImageDetail::High);
        assert_eq!(ImageDetail::default(), ImageDetail::Auto);

        let encoded = serde_json::to_string(&auto).unwrap();
        assert_eq!(encoded, "\"auto\"");
        let decoded: ImageDetail = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, auto);
    }

    /// URL images are accepted or rejected by extension.
    #[test]
    fn test_vision_content_validation() {
        let png = url_image("https://example.com/image.png", ImageDetail::Auto);
        assert!(png.validate_format().is_ok());

        let jpg = url_image("https://example.com/photo.jpg", ImageDetail::Low);
        assert!(jpg.validate_format().is_ok());

        let gif = url_image("https://example.com/animation.gif", ImageDetail::High);
        assert!(gif.validate_format().is_ok());

        let webp = url_image("https://example.com/image.webp", ImageDetail::Auto);
        assert!(webp.validate_format().is_ok());

        let pdf = url_image("https://example.com/document.pdf", ImageDetail::Auto);
        assert!(pdf.validate_format().is_err());
    }

    /// Inline images are accepted or rejected by media type.
    #[test]
    fn test_vision_content_base64_validation() {
        let png = base64_image(
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
            "image/png",
        );
        assert!(png.validate_format().is_ok());

        let bmp = base64_image("base64data", "image/bmp");
        assert!(bmp.validate_format().is_err());
    }

    /// File images are accepted or rejected by path extension.
    #[test]
    fn test_vision_content_file_validation() {
        let png = file_image("/path/to/image.png", ImageDetail::Auto);
        assert!(png.validate_format().is_ok());

        let jpeg = file_image("/path/to/photo.jpeg", ImageDetail::Auto);
        assert!(jpeg.validate_format().is_ok());

        let txt = file_image("/path/to/document.txt", ImageDetail::Auto);
        assert!(txt.validate_format().is_err());

        let bare = file_image("/path/to/file", ImageDetail::Auto);
        assert!(bare.validate_format().is_err());
    }

    /// A prompt plus one valid image yields a request holding both.
    #[test]
    fn test_vision_request_creation() {
        let attachments = vec![url_image(
            "https://example.com/image.png",
            ImageDetail::Auto,
        )];
        let outcome = VisionRequest::new("Describe this image".to_string(), attachments);
        assert!(outcome.is_ok());

        let VisionRequest { text, images } = outcome.unwrap();
        assert_eq!(text, "Describe this image");
        assert_eq!(images.len(), 1);
    }

    /// An empty prompt is rejected as an invalid request.
    #[test]
    fn test_vision_request_empty_text() {
        let attachments = vec![url_image(
            "https://example.com/image.png",
            ImageDetail::Auto,
        )];
        let outcome = VisionRequest::new("".to_string(), attachments);
        assert_invalid_request(outcome, "Text prompt cannot be empty");
    }

    /// A request without images is rejected as an invalid request.
    #[test]
    fn test_vision_request_no_images() {
        let outcome = VisionRequest::new("Describe this".to_string(), vec![]);
        assert_invalid_request(outcome, "At least one image is required");
    }

    /// URL and file images can be mixed in one request.
    #[test]
    fn test_vision_request_multiple_images() {
        let attachments = vec![
            url_image("https://example.com/image1.png", ImageDetail::Auto),
            url_image("https://example.com/image2.jpg", ImageDetail::Low),
            file_image("/local/image.gif", ImageDetail::High),
        ];
        let outcome = VisionRequest::new("Compare these images".to_string(), attachments);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().images.len(), 3);
    }

    /// An unsupported image format surfaces as `UnsupportedFormat`.
    #[test]
    fn test_vision_request_invalid_format() {
        let attachments = vec![url_image(
            "https://example.com/document.pdf",
            ImageDetail::Auto,
        )];
        let outcome = VisionRequest::new("Describe this".to_string(), attachments);
        assert!(outcome.is_err());
        if !matches!(outcome, Err(VisionError::UnsupportedFormat(_))) {
            panic!("Expected UnsupportedFormat error");
        }
    }

    /// Display output of the provider-facing variants contains both the fixed
    /// prefix and the payload.
    #[test]
    fn test_vision_error_variants() {
        let rendered = VisionError::AuthenticationError("Invalid API key".to_string()).to_string();
        assert!(rendered.contains("Authentication error"));
        assert!(rendered.contains("Invalid API key"));

        let rendered = VisionError::RateLimitExceeded("Too many requests".to_string()).to_string();
        assert!(rendered.contains("Rate limit exceeded"));
        assert!(rendered.contains("Too many requests"));

        let rendered = VisionError::ProviderError("Internal server error".to_string()).to_string();
        assert!(rendered.contains("Provider error"));
        assert!(rendered.contains("Internal server error"));

        let rendered = VisionError::Timeout(30).to_string();
        assert!(rendered.contains("timeout"));
        assert!(rendered.contains("30"));

        let rendered = VisionError::UnsupportedProvider("unknown-provider".to_string()).to_string();
        assert!(rendered.contains("Unsupported vision provider"));
        assert!(rendered.contains("unknown-provider"));

        let rendered = VisionError::MaxRetriesExceeded(3).to_string();
        assert!(rendered.contains("Maximum retry attempts exceeded"));
        assert!(rendered.contains("3"));
    }

    /// Display output of the image-related variants.
    #[test]
    fn test_vision_error_existing_variants() {
        let rendered = VisionError::InvalidImage("Corrupted data".to_string()).to_string();
        assert!(rendered.contains("Invalid image data"));

        let rendered = VisionError::UnsupportedFormat("BMP not supported".to_string()).to_string();
        assert!(rendered.contains("Unsupported image format"));

        let rendered = VisionError::NetworkError("Connection failed".to_string()).to_string();
        assert!(rendered.contains("Network error"));

        let rendered = VisionError::FileTooLarge {
            size: 10_000_000,
            max: 5_000_000,
        }
        .to_string();
        assert!(rendered.contains("too large"));
        assert!(rendered.contains("10000000"));
        assert!(rendered.contains("5000000"));
    }
}