use serde::{Deserialize, Serialize};
/// Body of a message: either plain text or an ordered list of typed parts.
///
/// `#[serde(untagged)]` means no discriminator is written: `Text` serializes
/// as a bare JSON string (e.g. `"Hello"`) and `Parts` as a JSON array of
/// [`ContentPart`] objects. When deserializing, serde tries the variants in
/// declaration order and keeps the first one that matches.
// NOTE(review): the shape looks like an OpenAI-style chat `content` field — confirm against the API consumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Content {
    /// Plain text, serialized as a bare JSON string.
    Text(String),
    /// Ordered text/image parts, serialized as a JSON array.
    Parts(Vec<ContentPart>),
}
impl Default for Content {
fn default() -> Self {
Content::Text(String::new())
}
}
impl From<String> for Content {
fn from(s: String) -> Self {
Content::Text(s)
}
}
impl From<&str> for Content {
fn from(s: &str) -> Self {
Content::Text(s.to_string())
}
}
/// One element of [`Content::Parts`].
///
/// Internally tagged: the snake_case variant name is written to a `"type"`
/// field next to the variant's own fields, e.g.
/// `{"type":"text","text":"..."}` or
/// `{"type":"image_url","image_url":{"url":"..."}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// A text segment (`"type":"text"`).
    Text { text: String },
    /// An image reference (`"type":"image_url"`).
    ImageUrl { image_url: ImageUrl },
}
/// Image reference carried by [`ContentPart::ImageUrl`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
    /// Image location; `Content::multipart` fills this with a base64 `data:` URL.
    pub url: String,
    /// Optional `detail` value, passed through verbatim. `Content::multipart`
    /// sets it to `"auto"`. It is omitted from the serialized JSON when `None`.
    // NOTE(review): presumably an image-fidelity hint for the downstream API — confirm accepted values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
impl Content {
    /// Builds plain-text content from anything convertible into a `String`.
    pub fn text(s: impl Into<String>) -> Self {
        Self::Text(s.into())
    }

    /// Builds multi-part content: one leading text part, followed by one
    /// `image_url` part per image in the order given.
    ///
    /// Each image is inlined as a `data:` URL via [`ImageData::to_data_url`],
    /// with `detail` set to `"auto"`.
    pub fn multipart(text: String, images: Vec<ImageData>) -> Self {
        let image_parts = images.into_iter().map(|image| ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: image.to_data_url(),
                detail: Some(String::from("auto")),
            },
        });
        let parts = std::iter::once(ContentPart::Text { text })
            .chain(image_parts)
            .collect();
        Self::Parts(parts)
    }

    /// Returns `true` if this is multi-part content containing at least one
    /// image part. Plain-text content never has images.
    pub fn has_images(&self) -> bool {
        if let Self::Parts(parts) = self {
            parts
                .iter()
                .any(|part| matches!(part, ContentPart::ImageUrl { .. }))
        } else {
            false
        }
    }

    /// Returns the textual portion of the content as an owned string.
    ///
    /// For `Text` this is a copy of the string. For `Parts` it is every text
    /// part, in order, separated by `\n`; image parts are skipped.
    pub fn text_content(&self) -> String {
        match self {
            Self::Text(text) => text.clone(),
            Self::Parts(parts) => parts
                .iter()
                .filter_map(|part| match part {
                    ContentPart::Text { text } => Some(text.as_str()),
                    ContentPart::ImageUrl { .. } => None,
                })
                .collect::<Vec<&str>>()
                .join("\n"),
        }
    }
}
/// Raw image bytes plus their MIME type, used to build inline `data:` URLs
/// via `ImageData::to_data_url`.
#[derive(Debug, Clone)]
pub struct ImageData {
    /// Image bytes as supplied by the caller; they are base64-encoded only
    /// when a data URL is built.
    pub bytes: Vec<u8>,
    /// MIME type written into the data URL. `ImageData::new` sniffs it from
    /// `bytes`, falling back to `application/octet-stream`.
    pub mime_type: String,
}
impl ImageData {
    /// Wraps raw image bytes, deriving `mime_type` from their leading magic
    /// number (see `detect_mime_type`).
    pub fn new(bytes: Vec<u8>) -> Self {
        // The sniff borrows `bytes` before the field below moves it.
        Self {
            mime_type: detect_mime_type(&bytes),
            bytes,
        }
    }

    /// Encodes the image as a `data:<mime>;base64,<payload>` URL, using the
    /// standard-alphabet base64 engine.
    pub fn to_data_url(&self) -> String {
        use base64::{engine::general_purpose::STANDARD, Engine as _};
        let payload = STANDARD.encode(&self.bytes);
        format!("data:{};base64,{}", self.mime_type, payload)
    }
}
/// Sniffs an image MIME type from the leading magic number of `bytes`.
///
/// Recognises:
/// - PNG: `89 50 4E 47`
/// - JPEG: `FF D8 FF`
/// - GIF: `GIF8`
/// - WebP: `RIFF` at offset 0 and `WEBP` at offset 8, so at least 12 bytes
///
/// Anything else, including input too short to carry a signature, maps to
/// `"application/octet-stream"`.
fn detect_mime_type(bytes: &[u8]) -> String {
    // `starts_with`/`get` do their own length checks. The JPEG signature is
    // only 3 bytes, so a blanket 4-byte minimum would wrongly reject it.
    let mime = if bytes.starts_with(&[0x89, 0x50, 0x4E, 0x47]) {
        "image/png"
    } else if bytes.starts_with(&[0xFF, 0xD8, 0xFF]) {
        "image/jpeg"
    } else if bytes.starts_with(b"GIF8") {
        "image/gif"
    } else if bytes.starts_with(b"RIFF") && bytes.get(8..12) == Some(&b"WEBP"[..]) {
        "image/webp"
    } else {
        "application/octet-stream"
    };
    mime.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The full 8-byte PNG file signature; the first 4 bytes are the magic
    /// number `detect_mime_type` checks.
    const PNG_SIGNATURE: [u8; 8] = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];

    #[test]
    fn test_detect_mime_type_png() {
        assert_eq!(detect_mime_type(&PNG_SIGNATURE), "image/png");
    }

    #[test]
    fn test_detect_mime_type_jpeg() {
        assert_eq!(detect_mime_type(&[0xFF, 0xD8, 0xFF, 0xE0]), "image/jpeg");
    }

    #[test]
    fn test_image_data_url() {
        let data_url = ImageData::new(PNG_SIGNATURE[..4].to_vec()).to_data_url();
        assert!(data_url.starts_with("data:image/png;base64,"));
    }

    #[test]
    fn test_content_serialization_text() {
        let serialized = serde_json::to_string(&Content::text("Hello")).unwrap();
        assert_eq!(serialized, r#""Hello""#);
    }

    #[test]
    fn test_content_serialization_multipart() {
        let image = ImageData::new(PNG_SIGNATURE[..4].to_vec());
        let content = Content::multipart(String::from("What's this?"), vec![image]);
        let serialized = serde_json::to_string(&content).unwrap();
        for tag in [r#""type":"text""#, r#""type":"image_url""#].iter() {
            assert!(serialized.contains(tag));
        }
    }
}