llm_toolkit/
multimodal.rs1use crate::prompt::{PromptPart, ToPrompt};
4use base64::{Engine, engine::general_purpose::STANDARD};
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ImageData {
14    pub media_type: String,
16    pub data: Vec<u8>,
18}
19
20impl ImageData {
21    pub fn new(media_type: impl Into<String>, data: Vec<u8>) -> Self {
23        Self {
24            media_type: media_type.into(),
25            data,
26        }
27    }
28
29    pub fn from_file(path: impl AsRef<Path>) -> std::io::Result<Self> {
38        let path = path.as_ref();
39        let data = std::fs::read(path)?;
40
41        let media_type = match path.extension().and_then(|ext| ext.to_str()) {
42            Some("jpg") | Some("jpeg") => "image/jpeg",
43            Some("png") => "image/png",
44            Some("gif") => "image/gif",
45            Some("webp") => "image/webp",
46            Some("bmp") => "image/bmp",
47            Some("svg") => "image/svg+xml",
48            _ => "application/octet-stream",
49        }
50        .to_string();
51
52        Ok(Self { media_type, data })
53    }
54
55    pub fn from_base64(
66        base64_str: &str,
67        media_type: impl Into<String>,
68    ) -> Result<Self, base64::DecodeError> {
69        let data = STANDARD.decode(base64_str)?;
70        Ok(Self {
71            media_type: media_type.into(),
72            data,
73        })
74    }
75
76    pub fn to_base64(&self) -> String {
78        STANDARD.encode(&self.data)
79    }
80}
81
82impl ToPrompt for ImageData {
83    fn to_prompt_parts(&self) -> Vec<PromptPart> {
84        vec![PromptPart::Image {
85            media_type: self.media_type.clone(),
86            data: self.data.clone(),
87        }]
88    }
89}
90
91impl TryFrom<&str> for ImageData {
116    type Error = String;
117
118    fn try_from(data_url: &str) -> Result<Self, Self::Error> {
119        if !data_url.starts_with("data:") {
120            return Err("Not a data URL".to_string());
121        }
122
123        let content = data_url.strip_prefix("data:").ok_or("Invalid data URL")?;
124
125        let parts: Vec<&str> = content.splitn(2, ',').collect();
126        if parts.len() != 2 {
127            return Err("Invalid data URL format".to_string());
128        }
129
130        let media_type = parts[0]
131            .split(';')
132            .next()
133            .unwrap_or("application/octet-stream")
134            .to_string();
135
136        let is_base64 = parts[0].contains("base64");
137
138        let data = if is_base64 {
139            STANDARD
140                .decode(parts[1])
141                .map_err(|e| format!("Failed to decode base64: {}", e))?
142        } else {
143            parts[1].as_bytes().to_vec()
145        };
146
147        Ok(Self { media_type, data })
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_data_creation() {
        // JPEG start-of-image marker followed by an APP0 marker.
        let data = vec![0xFF, 0xD8, 0xFF, 0xE0];
        let img = ImageData::new("image/jpeg", data.clone());

        assert_eq!(img.media_type, "image/jpeg");
        assert_eq!(img.data, data);
    }

    #[test]
    fn test_image_data_to_prompt_parts() {
        let data = vec![1, 2, 3, 4];
        let img = ImageData::new("image/png", data.clone());
        let parts = img.to_prompt_parts();

        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Image {
                media_type,
                data: img_data,
            } => {
                assert_eq!(media_type, "image/png");
                assert_eq!(img_data, &data);
            }
            _ => panic!("Expected Image variant"),
        }
    }

    #[test]
    fn test_base64_conversion() {
        // ASCII "Hello".
        let original_data = vec![72, 101, 108, 108, 111];
        let img = ImageData::new("image/test", original_data.clone());

        let base64 = img.to_base64();
        let decoded = ImageData::from_base64(&base64, "image/test").unwrap();

        assert_eq!(decoded.data, original_data);
        assert_eq!(decoded.media_type, "image/test");
    }

    #[test]
    fn test_invalid_base64_is_rejected() {
        assert!(ImageData::from_base64("!!!", "image/png").is_err());
    }

    #[test]
    fn test_data_url_parsing() {
        let data_url = "data:image/png;base64,SGVsbG8=";
        let img = ImageData::try_from(data_url).unwrap();

        assert_eq!(img.media_type, "image/png");
        assert_eq!(img.data, b"Hello");
    }

    #[test]
    fn test_plain_data_url_parsing() {
        let img = ImageData::try_from("data:text/plain,Hello").unwrap();

        assert_eq!(img.media_type, "text/plain");
        assert_eq!(img.data, b"Hello");
    }

    #[test]
    fn test_data_url_errors() {
        // Wrong scheme.
        assert!(ImageData::try_from("https://example.com/cat.png").is_err());
        // Missing the `,` between header and payload.
        assert!(ImageData::try_from("data:image/png;base64").is_err());
        // Payload is not valid base64.
        assert!(ImageData::try_from("data:image/png;base64,@@@@").is_err());
    }
}