llm_toolkit/
multimodal.rs

1//! Multimodal support for prompts, including image data handling.
2
3use crate::prompt::{PromptPart, ToPrompt};
4use base64::{Engine, engine::general_purpose::STANDARD};
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7
8/// Helper structure for handling image data in prompts.
9///
10/// This struct provides a convenient way to represent images
11/// that can be included in multimodal prompts.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ImageData {
14    /// The MIME media type (e.g., "image/jpeg", "image/png").
15    pub media_type: String,
16    /// The raw image data.
17    pub data: Vec<u8>,
18}
19
20impl ImageData {
21    /// Creates a new `ImageData` instance with the given media type and data.
22    pub fn new(media_type: impl Into<String>, data: Vec<u8>) -> Self {
23        Self {
24            media_type: media_type.into(),
25            data,
26        }
27    }
28
29    /// Creates an `ImageData` instance from a file path.
30    ///
31    /// The media type is inferred from the file extension.
32    ///
33    /// # Errors
34    ///
35    /// Returns an error if the file cannot be read or if the media type
36    /// cannot be determined from the file extension.
37    pub fn from_file(path: impl AsRef<Path>) -> std::io::Result<Self> {
38        let path = path.as_ref();
39        let data = std::fs::read(path)?;
40
41        let media_type = match path.extension().and_then(|ext| ext.to_str()) {
42            Some("jpg") | Some("jpeg") => "image/jpeg",
43            Some("png") => "image/png",
44            Some("gif") => "image/gif",
45            Some("webp") => "image/webp",
46            Some("bmp") => "image/bmp",
47            Some("svg") => "image/svg+xml",
48            _ => "application/octet-stream",
49        }
50        .to_string();
51
52        Ok(Self { media_type, data })
53    }
54
55    /// Creates an `ImageData` instance from a base64-encoded string.
56    ///
57    /// # Arguments
58    ///
59    /// * `base64_str` - The base64-encoded image data
60    /// * `media_type` - The MIME media type of the image
61    ///
62    /// # Errors
63    ///
64    /// Returns an error if the base64 string cannot be decoded.
65    pub fn from_base64(
66        base64_str: &str,
67        media_type: impl Into<String>,
68    ) -> Result<Self, base64::DecodeError> {
69        let data = STANDARD.decode(base64_str)?;
70        Ok(Self {
71            media_type: media_type.into(),
72            data,
73        })
74    }
75
76    /// Converts the image data to a base64-encoded string.
77    pub fn to_base64(&self) -> String {
78        STANDARD.encode(&self.data)
79    }
80}
81
82impl ToPrompt for ImageData {
83    fn to_prompt_parts(&self) -> Vec<PromptPart> {
84        vec![PromptPart::Image {
85            media_type: self.media_type.clone(),
86            data: self.data.clone(),
87        }]
88    }
89}
90
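// Optional: `From` conversions for common image-library types.
// In a real implementation these would be gated behind Cargo feature flags;
// they are commented out for now because the `image` feature is not defined.
// If re-enabled, prefer `TryFrom` so an encoding failure is returned as an
// error instead of panicking via `expect`.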
94
95// #[cfg(feature = "image")]
96// impl From<image::DynamicImage> for ImageData {
97//     fn from(img: image::DynamicImage) -> Self {
98//         use std::io::Cursor;
99//
100//         let mut buffer = Vec::new();
101//         let mut cursor = Cursor::new(&mut buffer);
102//
103//         // Default to PNG format
104//         img.write_to(&mut cursor, image::ImageFormat::Png)
105//             .expect("Failed to encode image");
106//
107//         Self {
108//             media_type: "image/png".to_string(),
109//             data: buffer,
110//         }
111//     }
112// }
113
114// From implementation for data URL strings (e.g., "data:image/png;base64,...")
115impl TryFrom<&str> for ImageData {
116    type Error = String;
117
118    fn try_from(data_url: &str) -> Result<Self, Self::Error> {
119        if !data_url.starts_with("data:") {
120            return Err("Not a data URL".to_string());
121        }
122
123        let content = data_url.strip_prefix("data:").ok_or("Invalid data URL")?;
124
125        let parts: Vec<&str> = content.splitn(2, ',').collect();
126        if parts.len() != 2 {
127            return Err("Invalid data URL format".to_string());
128        }
129
130        let media_type = parts[0]
131            .split(';')
132            .next()
133            .unwrap_or("application/octet-stream")
134            .to_string();
135
136        let is_base64 = parts[0].contains("base64");
137
138        let data = if is_base64 {
139            STANDARD
140                .decode(parts[1])
141                .map_err(|e| format!("Failed to decode base64: {}", e))?
142        } else {
143            // URL-encoded data
144            parts[1].as_bytes().to_vec()
145        };
146
147        Ok(Self { media_type, data })
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_data_creation() {
        let data = vec![0xFF, 0xD8, 0xFF, 0xE0]; // JPEG magic bytes
        let img = ImageData::new("image/jpeg", data.clone());

        assert_eq!(img.media_type, "image/jpeg");
        assert_eq!(img.data, data);
    }

    #[test]
    fn test_image_data_to_prompt_parts() {
        let data = vec![1, 2, 3, 4];
        let img = ImageData::new("image/png", data.clone());
        let parts = img.to_prompt_parts();

        assert_eq!(parts.len(), 1);
        match &parts[0] {
            PromptPart::Image {
                media_type,
                data: img_data,
            } => {
                assert_eq!(media_type, "image/png");
                assert_eq!(img_data, &data);
            }
            _ => panic!("Expected Image variant"),
        }
    }

    #[test]
    fn test_base64_conversion() {
        let original_data = vec![72, 101, 108, 108, 111]; // "Hello" in ASCII
        let img = ImageData::new("image/test", original_data.clone());

        let base64 = img.to_base64();
        let decoded = ImageData::from_base64(&base64, "image/test").unwrap();

        assert_eq!(decoded.data, original_data);
        assert_eq!(decoded.media_type, "image/test");
    }

    #[test]
    fn test_from_base64_rejects_invalid_input() {
        assert!(ImageData::from_base64("not base64!", "image/png").is_err());
    }

    #[test]
    fn test_from_file_infers_png_media_type() {
        let path = std::env::temp_dir().join(format!(
            "llm_toolkit_multimodal_{}.png",
            std::process::id()
        ));
        let bytes = vec![0x89, b'P', b'N', b'G'];
        std::fs::write(&path, &bytes).unwrap();

        // Read before asserting so the temp file is removed even on failure.
        let result = ImageData::from_file(&path);
        let _ = std::fs::remove_file(&path);
        let img = result.unwrap();

        assert_eq!(img.media_type, "image/png");
        assert_eq!(img.data, bytes);
    }

    #[test]
    fn test_from_file_missing_file_is_error() {
        let path = std::env::temp_dir().join("llm_toolkit_multimodal_missing_image.png");
        assert!(ImageData::from_file(&path).is_err());
    }

    #[test]
    fn test_data_url_parsing() {
        let data_url = "data:image/png;base64,SGVsbG8="; // "Hello" in base64
        let img = ImageData::try_from(data_url).unwrap();

        assert_eq!(img.media_type, "image/png");
        assert_eq!(img.data, b"Hello");
    }

    #[test]
    fn test_data_url_plain_payload() {
        let img = ImageData::try_from("data:text/plain,hello").unwrap();

        assert_eq!(img.media_type, "text/plain");
        assert_eq!(img.data, b"hello");
    }

    #[test]
    fn test_data_url_rejects_missing_scheme() {
        let err = ImageData::try_from("image/png;base64,SGVsbG8=").unwrap_err();
        assert_eq!(err, "Not a data URL");
    }

    #[test]
    fn test_data_url_rejects_missing_comma() {
        let err = ImageData::try_from("data:image/png;base64").unwrap_err();
        assert_eq!(err, "Invalid data URL format");
    }

    #[test]
    fn test_data_url_rejects_invalid_base64() {
        let err = ImageData::try_from("data:image/png;base64,***").unwrap_err();
        assert!(err.starts_with("Failed to decode base64"));
    }
}