ricecoder_images/
provider_integration.rs

1//! Integration with ricecoder-providers for image analysis.
2//!
3//! This module handles:
4//! - Extending ChatRequest with image data
5//! - Image serialization for different providers
6//! - Provider-specific image format handling
//! - Recording provider-reported token usage for image requests
8//! - Audit logging of image analysis requests
9
10use crate::error::ImageResult;
11use ricecoder_providers::models::ChatRequest;
12use serde::{Deserialize, Serialize};
13
14/// Image data for inclusion in chat requests.
15///
16/// This struct wraps image data in a format that can be serialized
17/// and sent to AI providers. Different providers may require different
18/// formats (base64, URL, etc.).
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ImageData {
21    /// Image format (png, jpg, gif, webp)
22    pub format: String,
23    /// Image data encoded as base64
24    pub data: String,
25    /// Image dimensions (width, height)
26    pub dimensions: (u32, u32),
27    /// Image size in bytes
28    pub size_bytes: u64,
29}
30
31impl ImageData {
32    /// Create image data from raw bytes.
33    pub fn from_bytes(
34        format: &str,
35        data: &[u8],
36        width: u32,
37        height: u32,
38    ) -> Self {
39        let base64_data = base64_encode(data);
40        Self {
41            format: format.to_string(),
42            data: base64_data,
43            dimensions: (width, height),
44            size_bytes: data.len() as u64,
45        }
46    }
47
48    /// Get the MIME type for this image format.
49    pub fn mime_type(&self) -> &str {
50        match self.format.as_str() {
51            "png" => "image/png",
52            "jpg" | "jpeg" => "image/jpeg",
53            "gif" => "image/gif",
54            "webp" => "image/webp",
55            _ => "application/octet-stream",
56        }
57    }
58
59    /// Get the data URL for this image (for providers that support it).
60    pub fn data_url(&self) -> String {
61        format!("data:{};base64,{}", self.mime_type(), self.data)
62    }
63}
64
65/// Extended chat request with image support.
66///
67/// This wraps the standard ChatRequest and adds image data that can be
68/// serialized for different providers.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct ChatRequestWithImages {
71    /// The base chat request
72    pub request: ChatRequest,
73    /// Images to include in the request
74    pub images: Vec<ImageData>,
75    /// Provider-specific image handling
76    pub provider_format: ProviderImageFormat,
77}
78
79/// How to format images for a specific provider.
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
81pub enum ProviderImageFormat {
82    /// OpenAI format (base64 in message content)
83    OpenAi,
84    /// Anthropic format (base64 in message content)
85    Anthropic,
86    /// Google format (base64 in message content)
87    Google,
88    /// Ollama format (base64 in message content)
89    Ollama,
90    /// Generic format (base64 in message content)
91    Generic,
92}
93
94impl ProviderImageFormat {
95    /// Get the format for a provider name.
96    pub fn for_provider(provider_name: &str) -> Self {
97        match provider_name.to_lowercase().as_str() {
98            "openai" => ProviderImageFormat::OpenAi,
99            "anthropic" => ProviderImageFormat::Anthropic,
100            "google" => ProviderImageFormat::Google,
101            "ollama" => ProviderImageFormat::Ollama,
102            _ => ProviderImageFormat::Generic,
103        }
104    }
105}
106
107impl ChatRequestWithImages {
108    /// Create a new chat request with images.
109    pub fn new(request: ChatRequest, provider_name: &str) -> Self {
110        Self {
111            request,
112            images: Vec::new(),
113            provider_format: ProviderImageFormat::for_provider(provider_name),
114        }
115    }
116
117    /// Add an image to the request.
118    pub fn add_image(&mut self, image: ImageData) {
119        self.images.push(image);
120    }
121
122    /// Add multiple images to the request.
123    pub fn add_images(&mut self, images: Vec<ImageData>) {
124        self.images.extend(images);
125    }
126
127    /// Serialize the request for the target provider.
128    ///
129    /// This converts the images into the format expected by the provider
130    /// and updates the message content accordingly.
131    pub fn serialize_for_provider(&self) -> ImageResult<ChatRequest> {
132        let mut request = self.request.clone();
133
134        // If there are no images, return the request as-is
135        if self.images.is_empty() {
136            return Ok(request);
137        }
138
139        // Update the last user message to include image references
140        if let Some(last_message) = request.messages.iter_mut().rev().find(|m| m.role == "user") {
141            let image_content = self.format_images_for_provider();
142            last_message.content = format!("{}\n\n{}", last_message.content, image_content);
143        }
144
145        Ok(request)
146    }
147
148    /// Format images according to the provider's requirements.
149    fn format_images_for_provider(&self) -> String {
150        match self.provider_format {
151            ProviderImageFormat::OpenAi => self.format_for_openai(),
152            ProviderImageFormat::Anthropic => self.format_for_anthropic(),
153            ProviderImageFormat::Google => self.format_for_google(),
154            ProviderImageFormat::Ollama => self.format_for_ollama(),
155            ProviderImageFormat::Generic => self.format_generic(),
156        }
157    }
158
159    /// Format images for OpenAI (base64 with MIME type).
160    fn format_for_openai(&self) -> String {
161        let mut content = String::new();
162        for (i, image) in self.images.iter().enumerate() {
163            content.push_str(&format!(
164                "[Image {}]\nFormat: {}\nDimensions: {}x{}\nSize: {} bytes\nData: data:{}base64,{}...\n",
165                i + 1,
166                image.format,
167                image.dimensions.0,
168                image.dimensions.1,
169                image.size_bytes,
170                image.mime_type(),
171                &image.data[..std::cmp::min(50, image.data.len())]
172            ));
173        }
174        content
175    }
176
177    /// Format images for Anthropic (base64 with metadata).
178    fn format_for_anthropic(&self) -> String {
179        let mut content = String::new();
180        for (i, image) in self.images.iter().enumerate() {
181            content.push_str(&format!(
182                "[Image {}]\nType: {}\nDimensions: {}x{}\nSize: {} bytes\n",
183                i + 1,
184                image.mime_type(),
185                image.dimensions.0,
186                image.dimensions.1,
187                image.size_bytes
188            ));
189        }
190        content
191    }
192
193    /// Format images for Google (base64 with metadata).
194    fn format_for_google(&self) -> String {
195        let mut content = String::new();
196        for (i, image) in self.images.iter().enumerate() {
197            content.push_str(&format!(
198                "[Image {}]\nMIME Type: {}\nResolution: {}x{}\nSize: {} bytes\n",
199                i + 1,
200                image.mime_type(),
201                image.dimensions.0,
202                image.dimensions.1,
203                image.size_bytes
204            ));
205        }
206        content
207    }
208
209    /// Format images for Ollama (base64 with metadata).
210    fn format_for_ollama(&self) -> String {
211        let mut content = String::new();
212        for (i, image) in self.images.iter().enumerate() {
213            content.push_str(&format!(
214                "[Image {}]\nFormat: {}\nSize: {}x{}\nBytes: {}\n",
215                i + 1,
216                image.format,
217                image.dimensions.0,
218                image.dimensions.1,
219                image.size_bytes
220            ));
221        }
222        content
223    }
224
225    /// Format images in generic format (base64 with minimal metadata).
226    fn format_generic(&self) -> String {
227        let mut content = String::new();
228        for (i, image) in self.images.iter().enumerate() {
229            content.push_str(&format!(
230                "[Image {}] {} ({}x{}, {} bytes)\n",
231                i + 1,
232                image.format.to_uppercase(),
233                image.dimensions.0,
234                image.dimensions.1,
235                image.size_bytes
236            ));
237        }
238        content
239    }
240
241    /// Get the number of images in this request.
242    pub fn image_count(&self) -> usize {
243        self.images.len()
244    }
245
246    /// Check if this request has any images.
247    pub fn has_images(&self) -> bool {
248        !self.images.is_empty()
249    }
250}
251
252/// Audit log entry for image analysis requests.
253#[derive(Debug, Clone, Serialize, Deserialize)]
254pub struct ImageAuditLogEntry {
255    /// Timestamp of the request
256    pub timestamp: chrono::DateTime<chrono::Utc>,
257    /// Provider name
258    pub provider: String,
259    /// Model used
260    pub model: String,
261    /// Number of images analyzed
262    pub image_count: usize,
263    /// Total image size in bytes
264    pub total_image_size: u64,
265    /// Image hashes (for deduplication tracking)
266    pub image_hashes: Vec<String>,
267    /// Request status (success, failure, timeout)
268    pub status: String,
269    /// Error message if failed
270    pub error: Option<String>,
271    /// Tokens used
272    pub tokens_used: Option<u32>,
273}
274
275impl ImageAuditLogEntry {
276    /// Create a new audit log entry for a successful analysis.
277    pub fn success(
278        provider: String,
279        model: String,
280        image_count: usize,
281        total_image_size: u64,
282        image_hashes: Vec<String>,
283        tokens_used: u32,
284    ) -> Self {
285        Self {
286            timestamp: chrono::Utc::now(),
287            provider,
288            model,
289            image_count,
290            total_image_size,
291            image_hashes,
292            status: "success".to_string(),
293            error: None,
294            tokens_used: Some(tokens_used),
295        }
296    }
297
298    /// Create a new audit log entry for a failed analysis.
299    pub fn failure(
300        provider: String,
301        model: String,
302        image_count: usize,
303        total_image_size: u64,
304        image_hashes: Vec<String>,
305        error: String,
306    ) -> Self {
307        Self {
308            timestamp: chrono::Utc::now(),
309            provider,
310            model,
311            image_count,
312            total_image_size,
313            image_hashes,
314            status: "failure".to_string(),
315            error: Some(error),
316            tokens_used: None,
317        }
318    }
319
320    /// Create a new audit log entry for a timeout.
321    pub fn timeout(
322        provider: String,
323        model: String,
324        image_count: usize,
325        total_image_size: u64,
326        image_hashes: Vec<String>,
327    ) -> Self {
328        Self {
329            timestamp: chrono::Utc::now(),
330            provider,
331            model,
332            image_count,
333            total_image_size,
334            image_hashes,
335            status: "timeout".to_string(),
336            error: Some("Analysis timeout".to_string()),
337            tokens_used: None,
338        }
339    }
340}
341
/// Encode `data` as standard base64 (RFC 4648 alphabet, `=` padding).
///
/// The output is always `4 * ceil(len / 3)` characters long. The buffer is
/// reserved up front, so multi-megabyte images are encoded without repeated
/// reallocation.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Pick the 6-bit group of `n` starting `shift` bits from the right.
    let sextet = |n: u32, shift: u32| ALPHABET[((n >> shift) & 0x3f) as usize] as char;

    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);
    for chunk in data.chunks(3) {
        // Missing trailing bytes of the final chunk are treated as zero; the
        // corresponding output characters are replaced by '=' below.
        let b0 = u32::from(chunk[0]);
        let b1 = chunk.get(1).map_or(0, |&b| u32::from(b));
        let b2 = chunk.get(2).map_or(0, |&b| u32::from(b));
        let n = (b0 << 16) | (b1 << 8) | b2;

        encoded.push(sextet(n, 18));
        encoded.push(sextet(n, 12));
        encoded.push(if chunk.len() > 1 { sextet(n, 6) } else { '=' });
        encoded.push(if chunk.len() > 2 { sextet(n, 0) } else { '=' });
    }
    encoded
}
372
#[cfg(test)]
mod tests {
    use super::*;
    use ricecoder_providers::models::Message;

    /// Build the single-user-message request shared by the request tests.
    fn sample_request() -> ChatRequest {
        ChatRequest {
            model: "gpt-4".to_string(),
            messages: vec![Message {
                role: "user".to_string(),
                content: "Analyze this image".to_string(),
            }],
            temperature: Some(0.7),
            max_tokens: Some(1000),
            stream: false,
        }
    }

    #[test]
    fn test_image_data_creation() {
        let data = vec![1, 2, 3, 4, 5];
        let image = ImageData::from_bytes("png", &data, 800, 600);

        assert_eq!(image.format, "png");
        assert_eq!(image.dimensions, (800, 600));
        assert_eq!(image.size_bytes, 5);
        assert_eq!(image.data, "AQIDBAU=");
    }

    #[test]
    fn test_image_data_mime_type() {
        let data = vec![1, 2, 3];
        let cases = [
            ("png", "image/png"),
            ("jpg", "image/jpeg"),
            ("jpeg", "image/jpeg"),
            ("gif", "image/gif"),
            ("webp", "image/webp"),
            ("bmp", "application/octet-stream"),
        ];

        for &(format, expected) in cases.iter() {
            let image = ImageData::from_bytes(format, &data, 100, 100);
            assert_eq!(image.mime_type(), expected, "format {}", format);
        }
    }

    #[test]
    fn test_image_data_url() {
        let image = ImageData::from_bytes("png", &[1, 2, 3], 100, 100);
        assert_eq!(image.data_url(), "data:image/png;base64,AQID");
    }

    #[test]
    fn test_provider_image_format() {
        assert_eq!(
            ProviderImageFormat::for_provider("openai"),
            ProviderImageFormat::OpenAi
        );
        assert_eq!(
            ProviderImageFormat::for_provider("OpenAI"),
            ProviderImageFormat::OpenAi
        );
        assert_eq!(
            ProviderImageFormat::for_provider("anthropic"),
            ProviderImageFormat::Anthropic
        );
        assert_eq!(
            ProviderImageFormat::for_provider("google"),
            ProviderImageFormat::Google
        );
        assert_eq!(
            ProviderImageFormat::for_provider("ollama"),
            ProviderImageFormat::Ollama
        );
        assert_eq!(
            ProviderImageFormat::for_provider("unknown"),
            ProviderImageFormat::Generic
        );
    }

    #[test]
    fn test_chat_request_with_images_creation() {
        let chat_with_images = ChatRequestWithImages::new(sample_request(), "openai");
        assert_eq!(chat_with_images.image_count(), 0);
        assert!(!chat_with_images.has_images());
        assert_eq!(chat_with_images.provider_format, ProviderImageFormat::OpenAi);
    }

    #[test]
    fn test_chat_request_add_image() {
        let mut chat_with_images = ChatRequestWithImages::new(sample_request(), "openai");
        chat_with_images.add_image(ImageData::from_bytes("png", &[1, 2, 3], 800, 600));

        assert_eq!(chat_with_images.image_count(), 1);
        assert!(chat_with_images.has_images());

        chat_with_images.add_images(vec![
            ImageData::from_bytes("gif", &[4], 1, 1),
            ImageData::from_bytes("webp", &[5], 2, 2),
        ]);
        assert_eq!(chat_with_images.image_count(), 3);
    }

    #[test]
    fn test_serialize_without_images_is_unchanged() {
        let chat = ChatRequestWithImages::new(sample_request(), "openai");
        let serialized = chat.serialize_for_provider().unwrap();

        assert_eq!(serialized.messages.len(), 1);
        assert_eq!(serialized.messages[0].content, "Analyze this image");
    }

    #[test]
    fn test_chat_request_serialize_for_provider() {
        let mut chat_with_images = ChatRequestWithImages::new(sample_request(), "openai");
        chat_with_images.add_image(ImageData::from_bytes("png", &[1, 2, 3], 800, 600));

        let serialized = chat_with_images.serialize_for_provider().unwrap();
        assert_eq!(serialized.messages.len(), 1);

        let content = &serialized.messages[0].content;
        assert!(content.starts_with("Analyze this image\n\n[Image 1]\nFormat: png\n"));
        assert!(content.contains("Dimensions: 800x600"));
        assert!(content.contains("Size: 3 bytes"));
    }

    #[test]
    fn test_serialize_generic_format() {
        let mut chat = ChatRequestWithImages::new(sample_request(), "unknown");
        chat.add_image(ImageData::from_bytes("png", &[1, 2, 3], 800, 600));

        let serialized = chat.serialize_for_provider().unwrap();
        assert_eq!(
            serialized.messages[0].content,
            "Analyze this image\n\n[Image 1] PNG (800x600, 3 bytes)\n"
        );
    }

    #[test]
    fn test_audit_log_entry_success() {
        let entry = ImageAuditLogEntry::success(
            "openai".to_string(),
            "gpt-4".to_string(),
            1,
            1024,
            vec!["hash1".to_string()],
            100,
        );

        assert_eq!(entry.provider, "openai");
        assert_eq!(entry.model, "gpt-4");
        assert_eq!(entry.image_count, 1);
        assert_eq!(entry.total_image_size, 1024);
        assert_eq!(entry.image_hashes, vec!["hash1".to_string()]);
        assert_eq!(entry.status, "success");
        assert_eq!(entry.tokens_used, Some(100));
        assert!(entry.error.is_none());
    }

    #[test]
    fn test_audit_log_entry_failure() {
        let entry = ImageAuditLogEntry::failure(
            "openai".to_string(),
            "gpt-4".to_string(),
            1,
            1024,
            vec!["hash1".to_string()],
            "Provider error".to_string(),
        );

        assert_eq!(entry.status, "failure");
        assert!(entry.tokens_used.is_none());
        assert_eq!(entry.error.as_deref(), Some("Provider error"));
    }

    #[test]
    fn test_audit_log_entry_timeout() {
        let entry = ImageAuditLogEntry::timeout(
            "openai".to_string(),
            "gpt-4".to_string(),
            1,
            1024,
            vec!["hash1".to_string()],
        );

        assert_eq!(entry.status, "timeout");
        assert!(entry.tokens_used.is_none());
        assert_eq!(entry.error.as_deref(), Some("Analysis timeout"));
    }

    #[test]
    fn test_base64_encode() {
        // RFC 4648 test vectors plus padding and high-bit edge cases.
        assert_eq!(base64_encode(&[]), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foob"), "Zm9vYg==");
        assert_eq!(base64_encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
        assert_eq!(base64_encode(b"Hello"), "SGVsbG8=");
        assert_eq!(base64_encode(&[65]), "QQ==");
        assert_eq!(base64_encode(&[0xfb, 0xff]), "+/8=");
        assert_eq!(base64_encode(&[0xff, 0xff, 0xff]), "////");
    }
}