llm_toolkit/
multimodal.rs1use crate::prompt::{PromptPart, ToPrompt};
4use base64::{Engine, engine::general_purpose::STANDARD};
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ImageData {
14 pub media_type: String,
16 pub data: Vec<u8>,
18}
19
20impl ImageData {
21 pub fn new(media_type: impl Into<String>, data: Vec<u8>) -> Self {
23 Self {
24 media_type: media_type.into(),
25 data,
26 }
27 }
28
29 pub fn from_file(path: impl AsRef<Path>) -> std::io::Result<Self> {
38 let path = path.as_ref();
39 let data = std::fs::read(path)?;
40
41 let media_type = match path.extension().and_then(|ext| ext.to_str()) {
42 Some("jpg") | Some("jpeg") => "image/jpeg",
43 Some("png") => "image/png",
44 Some("gif") => "image/gif",
45 Some("webp") => "image/webp",
46 Some("bmp") => "image/bmp",
47 Some("svg") => "image/svg+xml",
48 _ => "application/octet-stream",
49 }
50 .to_string();
51
52 Ok(Self { media_type, data })
53 }
54
55 pub fn from_base64(
66 base64_str: &str,
67 media_type: impl Into<String>,
68 ) -> Result<Self, base64::DecodeError> {
69 let data = STANDARD.decode(base64_str)?;
70 Ok(Self {
71 media_type: media_type.into(),
72 data,
73 })
74 }
75
76 pub fn to_base64(&self) -> String {
78 STANDARD.encode(&self.data)
79 }
80}
81
82impl ToPrompt for ImageData {
83 fn to_prompt_parts(&self) -> Vec<PromptPart> {
84 vec![PromptPart::Image {
85 media_type: self.media_type.clone(),
86 data: self.data.clone(),
87 }]
88 }
89}
90
91impl TryFrom<&str> for ImageData {
116 type Error = String;
117
118 fn try_from(data_url: &str) -> Result<Self, Self::Error> {
119 if !data_url.starts_with("data:") {
120 return Err("Not a data URL".to_string());
121 }
122
123 let content = data_url.strip_prefix("data:").ok_or("Invalid data URL")?;
124
125 let parts: Vec<&str> = content.splitn(2, ',').collect();
126 if parts.len() != 2 {
127 return Err("Invalid data URL format".to_string());
128 }
129
130 let media_type = parts[0]
131 .split(';')
132 .next()
133 .unwrap_or("application/octet-stream")
134 .to_string();
135
136 let is_base64 = parts[0].contains("base64");
137
138 let data = if is_base64 {
139 STANDARD
140 .decode(parts[1])
141 .map_err(|e| format!("Failed to decode base64: {}", e))?
142 } else {
143 parts[1].as_bytes().to_vec()
145 };
146
147 Ok(Self { media_type, data })
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the media type and bytes unchanged.
    #[test]
    fn test_image_data_creation() {
        // First four bytes of a JPEG header.
        let data = vec![0xFF, 0xD8, 0xFF, 0xE0];
        let img = ImageData::new("image/jpeg", data.clone());

        assert_eq!(img.media_type, "image/jpeg");
        assert_eq!(img.data, data);
    }

    /// An image renders as exactly one `PromptPart::Image` with equal contents.
    #[test]
    fn test_image_data_to_prompt_parts() {
        let data = vec![1, 2, 3, 4];
        let img = ImageData::new("image/png", data.clone());
        let parts = img.to_prompt_parts();

        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Image {
                media_type,
                data: img_data,
            } => {
                assert_eq!(media_type, "image/png");
                assert_eq!(img_data, &data);
            }
            _ => panic!("Expected Image variant"),
        }
    }

    /// Encoding to base64 and decoding again yields the original bytes.
    #[test]
    fn test_base64_conversion() {
        // ASCII for "Hello".
        let original_data = vec![72, 101, 108, 108, 111];
        let img = ImageData::new("image/test", original_data.clone());

        let base64 = img.to_base64();
        let decoded = ImageData::from_base64(&base64, "image/test").unwrap();

        assert_eq!(decoded.data, original_data);
        assert_eq!(decoded.media_type, "image/test");
    }

    /// A base64 data URL is split into media type and decoded payload.
    #[test]
    fn test_data_url_parsing() {
        // "SGVsbG8=" is the base64 encoding of "Hello".
        let data_url = "data:image/png;base64,SGVsbG8=";
        let img = ImageData::try_from(data_url).unwrap();

        assert_eq!(img.media_type, "image/png");
        assert_eq!(img.data, b"Hello");
    }

    /// A data URL without the base64 marker keeps its payload bytes verbatim.
    #[test]
    fn test_data_url_plain_payload() {
        let img = ImageData::try_from("data:text/plain,Hello").unwrap();

        assert_eq!(img.media_type, "text/plain");
        assert_eq!(img.data, b"Hello");
    }

    /// Inputs that are not well-formed data URLs are rejected.
    #[test]
    fn test_data_url_rejects_malformed_input() {
        // Missing the `data:` scheme.
        assert!(ImageData::try_from("").is_err());
        assert!(ImageData::try_from("image/png;base64,SGVsbG8=").is_err());
        // Missing the comma that separates header from payload.
        assert!(ImageData::try_from("data:image/png;base64").is_err());
    }
}