use crate::prompt::{PromptPart, ToPrompt};
use base64::{Engine, engine::general_purpose::STANDARD};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// An in-memory image: encoded image bytes paired with their MIME type.
///
/// The struct itself performs no decoding or validation; `data` holds whatever
/// bytes it was constructed with (for example, the raw contents of a PNG file).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ImageData {
    /// MIME type describing `data`, e.g. `"image/png"`.
    pub media_type: String,
    /// Raw image bytes, stored verbatim.
    pub data: Vec<u8>,
}
impl ImageData {
    /// Creates an image from already-loaded bytes and an explicit MIME type.
    ///
    /// Neither argument is validated; both are stored exactly as given.
    pub fn new(media_type: impl Into<String>, data: Vec<u8>) -> Self {
        Self {
            media_type: media_type.into(),
            data,
        }
    }

    /// Reads the file at `path` and infers its MIME type from the file extension.
    ///
    /// Extension matching is ASCII case-insensitive, so `photo.JPG` is treated
    /// like `photo.jpg`. Unknown, missing, or non-UTF-8 extensions fall back to
    /// `application/octet-stream`. The file contents are not inspected.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by [`std::fs::read`].
    pub fn from_file(path: impl AsRef<Path>) -> std::io::Result<Self> {
        let path = path.as_ref();
        let data = std::fs::read(path)?;
        let media_type = Self::media_type_for_path(path).to_string();
        Ok(Self { media_type, data })
    }

    /// Maps a path's extension (compared ASCII case-insensitively) to a MIME type.
    fn media_type_for_path(path: &Path) -> &'static str {
        // Lowercase first so upper/mixed-case extensions (common on files from
        // cameras and Windows tools) are recognised.
        let extension = path
            .extension()
            .and_then(|ext| ext.to_str())
            .map(str::to_ascii_lowercase);
        match extension.as_deref() {
            Some("jpg") | Some("jpeg") => "image/jpeg",
            Some("png") => "image/png",
            Some("gif") => "image/gif",
            Some("webp") => "image/webp",
            Some("bmp") => "image/bmp",
            Some("svg") => "image/svg+xml",
            Some("tif") | Some("tiff") => "image/tiff",
            Some("avif") => "image/avif",
            _ => "application/octet-stream",
        }
    }

    /// Decodes a base64 string (as accepted by the `STANDARD` engine) into an image
    /// with the given MIME type.
    ///
    /// # Errors
    ///
    /// Returns the decoder's [`base64::DecodeError`] if `base64_str` is not valid.
    pub fn from_base64(
        base64_str: &str,
        media_type: impl Into<String>,
    ) -> Result<Self, base64::DecodeError> {
        let data = STANDARD.decode(base64_str)?;
        Ok(Self {
            media_type: media_type.into(),
            data,
        })
    }

    /// Encodes the image bytes as base64 using the `STANDARD` engine.
    pub fn to_base64(&self) -> String {
        STANDARD.encode(&self.data)
    }
}
impl ToPrompt for ImageData {
    /// Exposes the image as a single [`PromptPart::Image`] part.
    ///
    /// The trait only borrows `self`, so the MIME type and the bytes are copied
    /// into the returned part.
    fn to_prompt_parts(&self) -> Vec<PromptPart> {
        let Self { media_type, data } = self;
        let image_part = PromptPart::Image {
            media_type: media_type.to_owned(),
            data: data.to_owned(),
        };
        vec![image_part]
    }
}
impl TryFrom<&str> for ImageData {
    type Error = String;

    /// Parses an RFC 2397 `data:` URL of the form
    /// `data:[<media type>][;param...][;base64],<payload>`.
    ///
    /// * The media type is the header text before the first `;`. When it is
    ///   empty (e.g. `data:,...` or `data:;base64,...`) it defaults to
    ///   `application/octet-stream`.
    /// * The payload is base64-decoded only when one of the `;`-separated header
    ///   parameters is exactly `base64` (ASCII case-insensitive). Otherwise it is
    ///   percent-decoded; a `%` not followed by two hex digits is kept literally.
    ///
    /// # Errors
    ///
    /// Returns a message if the input does not start with `data:`, has no `,`
    /// separating header and payload, or carries an invalid base64 payload.
    fn try_from(data_url: &str) -> Result<Self, Self::Error> {
        let content = data_url
            .strip_prefix("data:")
            .ok_or_else(|| "Not a data URL".to_string())?;
        let (header, payload) = content
            .split_once(',')
            .ok_or_else(|| "Invalid data URL format".to_string())?;

        let mut header_parts = header.split(';');
        // `split` always yields at least one item, so a missing media type shows
        // up as an empty string rather than `None`; treat both as "absent".
        let media_type = match header_parts.next() {
            Some(mt) if !mt.is_empty() => mt,
            _ => "application/octet-stream",
        }
        .to_string();
        // Compare whole parameters so media types or parameter values that merely
        // contain the substring "base64" are not misread as base64 payloads.
        let is_base64 = header_parts.any(|param| param.eq_ignore_ascii_case("base64"));

        let data = if is_base64 {
            STANDARD
                .decode(payload)
                .map_err(|e| format!("Failed to decode base64: {}", e))?
        } else {
            percent_decode(payload)
        };
        Ok(Self { media_type, data })
    }
}

/// Decodes `%XX` percent-escapes into raw bytes, as RFC 2397 requires for
/// non-base64 `data:` URL payloads.
///
/// A `%` that is not followed by two hex digits is copied through unchanged,
/// so loosely encoded payloads still parse.
fn percent_decode(input: &str) -> Vec<u8> {
    fn hex_digit(byte: u8) -> Option<u8> {
        (byte as char).to_digit(16).map(|d| d as u8)
    }

    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            if let Some(&[hi, lo]) = bytes.get(i + 1..i + 3) {
                if let (Some(h), Some(l)) = (hex_digit(hi), hex_digit(lo)) {
                    decoded.push((h << 4) | l);
                    i += 3;
                    continue;
                }
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    decoded
}
#[cfg(test)]
mod tests {
    //! Unit tests for `ImageData` construction and conversions.
    use super::*;

    /// `new` keeps the MIME type and bytes exactly as supplied.
    #[test]
    fn test_image_data_creation() {
        let jpeg_magic = vec![0xFF, 0xD8, 0xFF, 0xE0];
        let image = ImageData::new("image/jpeg", jpeg_magic.clone());
        assert_eq!(image.media_type, "image/jpeg");
        assert_eq!(image.data, jpeg_magic);
    }

    /// An image becomes exactly one `PromptPart::Image` carrying its fields.
    #[test]
    fn test_image_data_to_prompt_parts() {
        let payload = vec![1, 2, 3, 4];
        let prompt_parts = ImageData::new("image/png", payload.clone()).to_prompt_parts();
        assert_eq!(prompt_parts.len(), 1);
        if let PromptPart::Image { media_type, data } = &prompt_parts[0] {
            assert_eq!(media_type, "image/png");
            assert_eq!(data, &payload);
        } else {
            panic!("Expected Image variant");
        }
    }

    /// Encoding to base64 and decoding back reproduces the original image.
    #[test]
    fn test_base64_conversion() {
        let hello_bytes = vec![72, 101, 108, 108, 111];
        let encoded = ImageData::new("image/test", hello_bytes.clone()).to_base64();
        let round_tripped = ImageData::from_base64(&encoded, "image/test").unwrap();
        assert_eq!(round_tripped.data, hello_bytes);
        assert_eq!(round_tripped.media_type, "image/test");
    }

    /// A base64 `data:` URL yields its media type and decoded payload.
    #[test]
    fn test_data_url_parsing() {
        let parsed = ImageData::try_from("data:image/png;base64,SGVsbG8=").unwrap();
        assert_eq!(parsed.media_type, "image/png");
        assert_eq!(parsed.data, b"Hello");
    }
}