// kaccy_ai/ocr.rs

1//! Optical Character Recognition (OCR) module
2//!
3//! This module provides OCR capabilities for extracting text from images
4//! and screenshots. It supports multiple OCR backends including cloud
5//! services and local processing.
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10
11use crate::error::{AiError, Result};
12use crate::llm::{ChatRequest, LlmClient};
13
/// OCR provider trait for different backends.
///
/// Abstracts over OCR engines (LLM vision models, cloud services, local
/// processing) behind a uniform async interface. Implementations must be
/// `Send + Sync` so they can sit behind `Box<dyn OcrProvider>`.
#[async_trait]
pub trait OcrProvider: Send + Sync {
    /// Extract text from raw image bytes.
    ///
    /// `format` is a hint for the image encoding; providers may use it when
    /// constructing requests (e.g. to build a MIME type).
    async fn extract_text(&self, image_data: &[u8], format: ImageFormat) -> Result<OcrResult>;

    /// Extract text from an image addressed by `url` (the backend fetches it).
    async fn extract_text_from_url(&self, url: &str) -> Result<OcrResult>;

    /// Short identifier for this backend (e.g. "llm-vision").
    fn name(&self) -> &str;
}
26
/// Supported image formats.
///
/// Produced by `ImageFormat::detect` (magic bytes) or
/// `ImageFormat::from_extension`; `Unknown` is the fallback when nothing
/// matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImageFormat {
    /// JPEG image
    Jpeg,
    /// PNG image
    Png,
    /// WebP image
    WebP,
    /// GIF image (first frame)
    Gif,
    /// BMP image
    Bmp,
    /// TIFF image
    Tiff,
    /// Unknown format (auto-detect)
    Unknown,
}
45
46impl ImageFormat {
47    /// Detect image format from magic bytes
48    #[must_use]
49    pub fn detect(data: &[u8]) -> Self {
50        if data.len() < 2 {
51            return ImageFormat::Unknown;
52        }
53
54        // Check magic bytes (ordered by minimum required length)
55        // BMP: 2 bytes
56        if data.starts_with(&[0x42, 0x4D]) {
57            return ImageFormat::Bmp;
58        }
59        // JPEG: 3 bytes
60        if data.len() >= 3 && data.starts_with(&[0xFF, 0xD8, 0xFF]) {
61            return ImageFormat::Jpeg;
62        }
63        // TIFF: 4 bytes
64        if data.len() >= 4
65            && (data.starts_with(&[0x49, 0x49, 0x2A, 0x00])
66                || data.starts_with(&[0x4D, 0x4D, 0x00, 0x2A]))
67        {
68            return ImageFormat::Tiff;
69        }
70        // GIF: 6 bytes
71        if data.len() >= 6 && (data.starts_with(b"GIF87a") || data.starts_with(b"GIF89a")) {
72            return ImageFormat::Gif;
73        }
74        // PNG: 8 bytes
75        if data.len() >= 8 && data.starts_with(&[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) {
76            return ImageFormat::Png;
77        }
78        // WebP: 12 bytes
79        if data.len() >= 12 && data.starts_with(b"RIFF") && &data[8..12] == b"WEBP" {
80            return ImageFormat::WebP;
81        }
82
83        ImageFormat::Unknown
84    }
85
86    /// Detect format from file extension
87    #[must_use]
88    pub fn from_extension(ext: &str) -> Self {
89        match ext.to_lowercase().as_str() {
90            "jpg" | "jpeg" => ImageFormat::Jpeg,
91            "png" => ImageFormat::Png,
92            "webp" => ImageFormat::WebP,
93            "gif" => ImageFormat::Gif,
94            "bmp" => ImageFormat::Bmp,
95            "tif" | "tiff" => ImageFormat::Tiff,
96            _ => ImageFormat::Unknown,
97        }
98    }
99
100    /// Get MIME type for format
101    #[must_use]
102    pub fn mime_type(&self) -> &'static str {
103        match self {
104            ImageFormat::Jpeg => "image/jpeg",
105            ImageFormat::Png => "image/png",
106            ImageFormat::WebP => "image/webp",
107            ImageFormat::Gif => "image/gif",
108            ImageFormat::Bmp => "image/bmp",
109            ImageFormat::Tiff => "image/tiff",
110            ImageFormat::Unknown => "application/octet-stream",
111        }
112    }
113}
114
/// OCR extraction result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
    /// Extracted text content (full document text).
    pub text: String,
    /// Confidence score (0.0-1.0); provider-estimated, not a guarantee.
    pub confidence: f64,
    /// Language detected (ISO 639-1 code), if the provider reports one.
    pub language: Option<String>,
    /// Text blocks with positions.
    pub blocks: Vec<TextBlock>,
    /// Processing time in milliseconds, as reported by the provider.
    pub processing_time_ms: u64,
    /// Provider used for extraction (e.g. "llm").
    pub provider: String,
}
131
/// Text block with position information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextBlock {
    /// Text content of this block.
    pub text: String,
    /// Confidence for this block (0.0-1.0).
    pub confidence: f64,
    /// Bounding box, if the provider supplies positional data.
    /// Coordinates are normalized to the image size (see [`BoundingBox`]).
    pub bounding_box: Option<BoundingBox>,
    /// Block type (paragraph, line, word, etc.).
    pub block_type: BlockType,
}
144
/// Bounding box coordinates, normalized to 0.0-1.0 relative to image size.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBox {
    /// X coordinate of the left edge.
    pub x: f64,
    /// Y coordinate of the top edge.
    pub y: f64,
    /// Width of the box.
    pub width: f64,
    /// Height of the box.
    pub height: f64,
}
157
/// Type of text block, from coarsest (page) to finest (symbol) granularity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockType {
    /// Full page of text.
    Page,
    /// Paragraph.
    Paragraph,
    /// Single line.
    Line,
    /// Single word.
    Word,
    /// Symbol/character.
    Symbol,
}
172
/// LLM-based OCR provider using vision models.
///
/// Sends images (inlined as base64 data URLs, or as remote URLs) to a
/// vision-capable LLM and parses the reply into an [`OcrResult`].
pub struct LlmOcrProvider {
    // Underlying chat client used for the vision requests.
    llm: LlmClient,
    // Size limit and prompt/parsing behavior.
    config: LlmOcrConfig,
}
178
/// Configuration for LLM OCR.
#[derive(Debug, Clone)]
pub struct LlmOcrConfig {
    /// Maximum accepted image size in bytes; larger inputs are rejected
    /// before base64 encoding.
    pub max_image_size: usize,
    /// Whether to request detailed (structured JSON) analysis instead of a
    /// plain text dump.
    pub detailed_analysis: bool,
    /// Target language for extraction (optional).
    /// NOTE(review): not referenced anywhere in this module's prompt
    /// building — confirm whether it is consumed elsewhere or is dead config.
    pub target_language: Option<String>,
}
189
190impl Default for LlmOcrConfig {
191    fn default() -> Self {
192        Self {
193            max_image_size: 20 * 1024 * 1024, // 20MB
194            detailed_analysis: true,
195            target_language: None,
196        }
197    }
198}
199
200impl LlmOcrProvider {
201    /// Create a new LLM OCR provider
202    #[must_use]
203    pub fn new(llm: LlmClient) -> Self {
204        Self {
205            llm,
206            config: LlmOcrConfig::default(),
207        }
208    }
209
210    /// Create with custom configuration
211    #[must_use]
212    pub fn with_config(llm: LlmClient, config: LlmOcrConfig) -> Self {
213        Self { llm, config }
214    }
215
216    fn build_ocr_prompt(&self, detailed: bool) -> String {
217        if detailed {
218            r#"Extract all text visible in this image. Please provide:
219
2201. **Full Text**: All text content exactly as it appears
2212. **Structure**: Identify headings, paragraphs, lists, tables if present
2223. **Layout**: Note any important layout information
223
224Respond in JSON format:
225{
226    "text": "<all extracted text>",
227    "blocks": [
228        {"text": "<block text>", "type": "<paragraph|heading|list|table|other>"}
229    ],
230    "language": "<detected language code or null>",
231    "confidence": <0.0-1.0 confidence estimate>
232}"#
233            .to_string()
234        } else {
235            "Extract all text from this image. Return only the text content, preserving the layout as much as possible.".to_string()
236        }
237    }
238
239    fn parse_ocr_response(&self, response: &str, detailed: bool) -> Result<OcrResult> {
240        let start_time = std::time::Instant::now();
241
242        if detailed {
243            // Try to parse JSON response
244            if let Some(start) = response.find('{') {
245                if let Some(end) = response.rfind('}') {
246                    let json_str = &response[start..=end];
247                    if let Ok(parsed) = serde_json::from_str::<LlmOcrResponse>(json_str) {
248                        return Ok(OcrResult {
249                            text: parsed.text,
250                            confidence: parsed.confidence.unwrap_or(0.8),
251                            language: parsed.language,
252                            blocks: parsed
253                                .blocks
254                                .unwrap_or_default()
255                                .into_iter()
256                                .map(|b| TextBlock {
257                                    text: b.text,
258                                    confidence: 0.8,
259                                    bounding_box: None,
260                                    block_type: match b.block_type.as_str() {
261                                        "heading" => BlockType::Line,
262                                        "paragraph" => BlockType::Paragraph,
263                                        "list" | "table" => BlockType::Paragraph,
264                                        _ => BlockType::Paragraph,
265                                    },
266                                })
267                                .collect(),
268                            processing_time_ms: start_time.elapsed().as_millis() as u64,
269                            provider: "llm".to_string(),
270                        });
271                    }
272                }
273            }
274        }
275
276        // Fall back to plain text response
277        Ok(OcrResult {
278            text: response.trim().to_string(),
279            confidence: 0.7,
280            language: None,
281            blocks: vec![TextBlock {
282                text: response.trim().to_string(),
283                confidence: 0.7,
284                bounding_box: None,
285                block_type: BlockType::Page,
286            }],
287            processing_time_ms: start_time.elapsed().as_millis() as u64,
288            provider: "llm".to_string(),
289        })
290    }
291}
292
/// JSON shape the detailed OCR prompt asks the model to produce.
#[derive(Debug, Deserialize)]
struct LlmOcrResponse {
    /// Full extracted text.
    text: String,
    /// Optional structured blocks.
    blocks: Option<Vec<LlmOcrBlock>>,
    /// Detected language code, if any.
    language: Option<String>,
    /// Model's own confidence estimate (0.0-1.0), if reported.
    confidence: Option<f64>,
}
300
/// One structured block within [`LlmOcrResponse`].
#[derive(Debug, Deserialize)]
struct LlmOcrBlock {
    /// Block text content.
    text: String,
    /// Block kind as reported by the model ("paragraph", "heading", ...).
    #[serde(rename = "type")]
    block_type: String,
}
307
308#[async_trait]
309impl OcrProvider for LlmOcrProvider {
310    async fn extract_text(&self, image_data: &[u8], format: ImageFormat) -> Result<OcrResult> {
311        // Check image size
312        if image_data.len() > self.config.max_image_size {
313            return Err(AiError::Validation(format!(
314                "Image too large: {} bytes (max {} bytes)",
315                image_data.len(),
316                self.config.max_image_size
317            )));
318        }
319
320        // Encode image as base64
321        let base64_image =
322            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, image_data);
323
324        let prompt = self.build_ocr_prompt(self.config.detailed_analysis);
325
326        // Create vision request with image
327        let image_url = format!("data:{};base64,{}", format.mime_type(), base64_image);
328
329        let request = ChatRequest::with_vision(
330            "You are an expert OCR system. Extract text accurately from images.",
331            prompt,
332            image_url,
333        )
334        .max_tokens(4096)
335        .temperature(0.1);
336
337        let response = self.llm.chat(request).await?;
338
339        self.parse_ocr_response(&response.message.content, self.config.detailed_analysis)
340    }
341
342    async fn extract_text_from_url(&self, url: &str) -> Result<OcrResult> {
343        let prompt = self.build_ocr_prompt(self.config.detailed_analysis);
344
345        let request = ChatRequest::with_vision(
346            "You are an expert OCR system. Extract text accurately from images.",
347            prompt,
348            url.to_string(),
349        )
350        .max_tokens(4096)
351        .temperature(0.1);
352
353        let response = self.llm.chat(request).await?;
354
355        self.parse_ocr_response(&response.message.content, self.config.detailed_analysis)
356    }
357
358    fn name(&self) -> &'static str {
359        "llm-vision"
360    }
361}
362
/// Simple image analyzer for basic image properties.
///
/// Stateless: all functionality is exposed as associated functions.
pub struct ImageAnalyzer;
365
366impl ImageAnalyzer {
367    /// Get image dimensions from bytes
368    pub fn get_dimensions(data: &[u8]) -> Result<(u32, u32)> {
369        use image::GenericImageView;
370
371        let img = image::load_from_memory(data)
372            .map_err(|e| AiError::Validation(format!("Failed to load image: {e}")))?;
373
374        Ok(img.dimensions())
375    }
376
377    /// Check if image is likely a screenshot (based on dimensions)
378    #[must_use]
379    pub fn is_likely_screenshot(width: u32, height: u32) -> bool {
380        // Common screen resolutions
381        let common_widths = [1920, 2560, 3840, 1366, 1440, 1536, 1280];
382        let common_heights = [1080, 1440, 2160, 768, 900, 864, 720, 800];
383
384        common_widths.contains(&width) || common_heights.contains(&height)
385    }
386
387    /// Analyze image content for text regions (simple heuristic)
388    pub fn estimate_text_regions(data: &[u8]) -> Result<TextRegionEstimate> {
389        use image::GenericImageView;
390
391        let img = image::load_from_memory(data)
392            .map_err(|e| AiError::Validation(format!("Failed to load image: {e}")))?;
393
394        let (width, height) = img.dimensions();
395        let grayscale = img.to_luma8();
396
397        // Calculate contrast and edges (simplified)
398        let mut high_contrast_pixels = 0u64;
399        let total_pixels = u64::from(width * height);
400
401        for y in 1..height - 1 {
402            for x in 1..width - 1 {
403                let center = i32::from(grayscale.get_pixel(x, y)[0]);
404                let right = i32::from(grayscale.get_pixel(x + 1, y)[0]);
405                let down = i32::from(grayscale.get_pixel(x, y + 1)[0]);
406
407                // Simple edge detection
408                if (center - right).abs() > 30 || (center - down).abs() > 30 {
409                    high_contrast_pixels += 1;
410                }
411            }
412        }
413
414        let edge_ratio = high_contrast_pixels as f64 / total_pixels as f64;
415
416        // Text-heavy images typically have moderate edge ratios
417        let likely_has_text = edge_ratio > 0.05 && edge_ratio < 0.4;
418
419        Ok(TextRegionEstimate {
420            width,
421            height,
422            edge_ratio,
423            likely_has_text,
424            estimated_text_coverage: if likely_has_text {
425                (edge_ratio * 2.0).min(1.0)
426            } else {
427                0.0
428            },
429        })
430    }
431}
432
/// Estimate of text regions in an image, produced by
/// `ImageAnalyzer::estimate_text_regions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextRegionEstimate {
    /// Image width in pixels.
    pub width: u32,
    /// Image height in pixels.
    pub height: u32,
    /// Ratio of high-contrast (edge) pixels to total pixels.
    pub edge_ratio: f64,
    /// Whether the image likely contains text (moderate edge ratio).
    pub likely_has_text: bool,
    /// Estimated text coverage (0.0-1.0); 0.0 when text is unlikely.
    pub estimated_text_coverage: f64,
}
447
/// Screenshot-specific OCR service.
///
/// Wraps an [`OcrProvider`] and adds cheap local pre-analysis so the
/// (potentially expensive) OCR backend is only invoked when text is likely.
pub struct ScreenshotOcr {
    // Pluggable OCR backend (defaults to `LlmOcrProvider`).
    provider: Box<dyn OcrProvider>,
}
452
453impl ScreenshotOcr {
454    /// Create with default LLM provider
455    #[must_use]
456    pub fn new(llm: LlmClient) -> Self {
457        Self {
458            provider: Box::new(LlmOcrProvider::new(llm)),
459        }
460    }
461
462    /// Create with custom provider
463    #[must_use]
464    pub fn with_provider(provider: Box<dyn OcrProvider>) -> Self {
465        Self { provider }
466    }
467
468    /// Process a screenshot and extract text
469    pub async fn process_screenshot(&self, data: &[u8]) -> Result<ScreenshotAnalysis> {
470        let format = ImageFormat::detect(data);
471
472        // Analyze image properties
473        let (width, height) = ImageAnalyzer::get_dimensions(data)?;
474        let is_screenshot = ImageAnalyzer::is_likely_screenshot(width, height);
475        let text_estimate = ImageAnalyzer::estimate_text_regions(data)?;
476
477        // Skip OCR if unlikely to have text
478        if !text_estimate.likely_has_text {
479            return Ok(ScreenshotAnalysis {
480                ocr_result: None,
481                is_screenshot,
482                dimensions: (width, height),
483                format,
484                has_text: false,
485                text_estimate,
486            });
487        }
488
489        // Perform OCR
490        let ocr_result = self.provider.extract_text(data, format).await?;
491        let has_text = !ocr_result.text.trim().is_empty();
492
493        Ok(ScreenshotAnalysis {
494            ocr_result: Some(ocr_result),
495            is_screenshot,
496            dimensions: (width, height),
497            format,
498            has_text,
499            text_estimate,
500        })
501    }
502
503    /// Process screenshot from file path
504    pub async fn process_file(&self, path: &Path) -> Result<ScreenshotAnalysis> {
505        let data = std::fs::read(path)
506            .map_err(|e| AiError::Validation(format!("Failed to read file: {e}")))?;
507        self.process_screenshot(&data).await
508    }
509
510    /// Process screenshot from URL
511    pub async fn process_url(&self, url: &str) -> Result<ScreenshotAnalysis> {
512        let ocr_result = self.provider.extract_text_from_url(url).await?;
513        let has_text = !ocr_result.text.trim().is_empty();
514
515        Ok(ScreenshotAnalysis {
516            ocr_result: Some(ocr_result),
517            is_screenshot: true, // Assume URL screenshots are actual screenshots
518            dimensions: (0, 0),  // Unknown without downloading
519            format: ImageFormat::Unknown,
520            has_text,
521            text_estimate: TextRegionEstimate {
522                width: 0,
523                height: 0,
524                edge_ratio: 0.0,
525                likely_has_text: true,
526                estimated_text_coverage: 0.0,
527            },
528        })
529    }
530}
531
/// Screenshot analysis result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScreenshotAnalysis {
    /// OCR result (`None` when OCR was skipped because no text was expected).
    pub ocr_result: Option<OcrResult>,
    /// Whether the image dimensions match a common screen resolution.
    pub is_screenshot: bool,
    /// Image dimensions (width, height); (0, 0) when processed from a URL.
    pub dimensions: (u32, u32),
    /// Image format detected from magic bytes; `Unknown` for URL inputs.
    pub format: ImageFormat,
    /// Whether non-empty text was actually extracted.
    pub has_text: bool,
    /// Text region estimate from the local edge heuristic.
    pub text_estimate: TextRegionEstimate,
}
548
549impl ScreenshotAnalysis {
550    /// Get extracted text if available
551    #[must_use]
552    pub fn text(&self) -> Option<&str> {
553        self.ocr_result.as_ref().map(|r| r.text.as_str())
554    }
555
556    /// Get confidence score
557    #[must_use]
558    pub fn confidence(&self) -> f64 {
559        self.ocr_result.as_ref().map_or(0.0, |r| r.confidence)
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_format_detection() {
        // PNG magic bytes
        let png_data = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
        assert_eq!(ImageFormat::detect(&png_data), ImageFormat::Png);

        // JPEG magic bytes
        let jpeg_data = [0xFF, 0xD8, 0xFF, 0xE0];
        assert_eq!(ImageFormat::detect(&jpeg_data), ImageFormat::Jpeg);

        // BMP ("BM" prefix)
        assert_eq!(ImageFormat::detect(&[0x42, 0x4D, 0x00]), ImageFormat::Bmp);

        // GIF, both signature versions
        assert_eq!(ImageFormat::detect(b"GIF87a..."), ImageFormat::Gif);
        assert_eq!(ImageFormat::detect(b"GIF89a..."), ImageFormat::Gif);

        // TIFF, little- and big-endian
        assert_eq!(
            ImageFormat::detect(&[0x49, 0x49, 0x2A, 0x00]),
            ImageFormat::Tiff
        );
        assert_eq!(
            ImageFormat::detect(&[0x4D, 0x4D, 0x00, 0x2A]),
            ImageFormat::Tiff
        );

        // WebP: RIFF container with "WEBP" tag at offset 8
        let mut webp = Vec::from(*b"RIFF");
        webp.extend_from_slice(&[0, 0, 0, 0]); // RIFF chunk size (ignored)
        webp.extend_from_slice(b"WEBP");
        assert_eq!(ImageFormat::detect(&webp), ImageFormat::WebP);

        // Unknown signature
        let unknown_data = [0x00, 0x01, 0x02, 0x03];
        assert_eq!(ImageFormat::detect(&unknown_data), ImageFormat::Unknown);

        // Too short to match anything
        assert_eq!(ImageFormat::detect(&[]), ImageFormat::Unknown);
        assert_eq!(ImageFormat::detect(&[0x89]), ImageFormat::Unknown);
    }

    #[test]
    fn test_format_from_extension() {
        assert_eq!(ImageFormat::from_extension("png"), ImageFormat::Png);
        assert_eq!(ImageFormat::from_extension("jpg"), ImageFormat::Jpeg);
        assert_eq!(ImageFormat::from_extension("jpeg"), ImageFormat::Jpeg);
        assert_eq!(ImageFormat::from_extension("webp"), ImageFormat::WebP);
        // Extension matching is case-insensitive.
        assert_eq!(ImageFormat::from_extension("PNG"), ImageFormat::Png);
        assert_eq!(ImageFormat::from_extension("TIFF"), ImageFormat::Tiff);
        assert_eq!(ImageFormat::from_extension("txt"), ImageFormat::Unknown);
    }

    #[test]
    fn test_mime_type() {
        assert_eq!(ImageFormat::Png.mime_type(), "image/png");
        assert_eq!(ImageFormat::Jpeg.mime_type(), "image/jpeg");
        assert_eq!(ImageFormat::WebP.mime_type(), "image/webp");
        // Unknown falls back to the generic octet-stream type.
        assert_eq!(ImageFormat::Unknown.mime_type(), "application/octet-stream");
    }

    #[test]
    fn test_is_likely_screenshot() {
        assert!(ImageAnalyzer::is_likely_screenshot(1920, 1080));
        assert!(ImageAnalyzer::is_likely_screenshot(2560, 1440));
        assert!(!ImageAnalyzer::is_likely_screenshot(500, 500));
    }
}