1use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10
11use crate::error::{AiError, Result};
12use crate::llm::{ChatRequest, LlmClient};
13
/// A pluggable OCR backend: extracts text from image bytes or a URL.
///
/// Implementations must be thread-safe (`Send + Sync`) so they can sit
/// behind a `Box<dyn OcrProvider>` shared across tasks.
#[async_trait]
pub trait OcrProvider: Send + Sync {
    /// Extract text from in-memory image bytes in the given `format`.
    async fn extract_text(&self, image_data: &[u8], format: ImageFormat) -> Result<OcrResult>;

    /// Extract text from an image referenced by `url`.
    async fn extract_text_from_url(&self, url: &str) -> Result<OcrResult>;

    /// Short, human-readable identifier for this provider.
    fn name(&self) -> &str;
}
26
/// Raster image formats recognized by magic bytes or file extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImageFormat {
    Jpeg,
    Png,
    WebP,
    Gif,
    Bmp,
    Tiff,
    /// Format could not be determined.
    Unknown,
}
45
46impl ImageFormat {
47 #[must_use]
49 pub fn detect(data: &[u8]) -> Self {
50 if data.len() < 2 {
51 return ImageFormat::Unknown;
52 }
53
54 if data.starts_with(&[0x42, 0x4D]) {
57 return ImageFormat::Bmp;
58 }
59 if data.len() >= 3 && data.starts_with(&[0xFF, 0xD8, 0xFF]) {
61 return ImageFormat::Jpeg;
62 }
63 if data.len() >= 4
65 && (data.starts_with(&[0x49, 0x49, 0x2A, 0x00])
66 || data.starts_with(&[0x4D, 0x4D, 0x00, 0x2A]))
67 {
68 return ImageFormat::Tiff;
69 }
70 if data.len() >= 6 && (data.starts_with(b"GIF87a") || data.starts_with(b"GIF89a")) {
72 return ImageFormat::Gif;
73 }
74 if data.len() >= 8 && data.starts_with(&[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) {
76 return ImageFormat::Png;
77 }
78 if data.len() >= 12 && data.starts_with(b"RIFF") && &data[8..12] == b"WEBP" {
80 return ImageFormat::WebP;
81 }
82
83 ImageFormat::Unknown
84 }
85
86 #[must_use]
88 pub fn from_extension(ext: &str) -> Self {
89 match ext.to_lowercase().as_str() {
90 "jpg" | "jpeg" => ImageFormat::Jpeg,
91 "png" => ImageFormat::Png,
92 "webp" => ImageFormat::WebP,
93 "gif" => ImageFormat::Gif,
94 "bmp" => ImageFormat::Bmp,
95 "tif" | "tiff" => ImageFormat::Tiff,
96 _ => ImageFormat::Unknown,
97 }
98 }
99
100 #[must_use]
102 pub fn mime_type(&self) -> &'static str {
103 match self {
104 ImageFormat::Jpeg => "image/jpeg",
105 ImageFormat::Png => "image/png",
106 ImageFormat::WebP => "image/webp",
107 ImageFormat::Gif => "image/gif",
108 ImageFormat::Bmp => "image/bmp",
109 ImageFormat::Tiff => "image/tiff",
110 ImageFormat::Unknown => "application/octet-stream",
111 }
112 }
113}
114
/// The outcome of one OCR extraction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
    /// All extracted text.
    pub text: String,
    /// Overall confidence estimate in `[0.0, 1.0]`.
    pub confidence: f64,
    /// Detected language code, when the provider reported one.
    pub language: Option<String>,
    /// Segmented text blocks, when the provider can structure the text.
    pub blocks: Vec<TextBlock>,
    /// Processing time in milliseconds, as measured by the provider.
    pub processing_time_ms: u64,
    /// Name of the provider that produced this result (e.g. "llm").
    pub provider: String,
}
131
/// One segmented unit of recognized text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextBlock {
    /// Text content of this block.
    pub text: String,
    /// Confidence estimate for this block in `[0.0, 1.0]`.
    pub confidence: f64,
    /// Location within the image, when the provider reports geometry.
    pub bounding_box: Option<BoundingBox>,
    /// Granularity of this block (page, paragraph, line, ...).
    pub block_type: BlockType,
}
144
/// Axis-aligned rectangle locating a text block within the image.
///
/// NOTE(review): no provider in this module populates bounding boxes, so
/// the coordinate units (pixels vs. normalized) are unspecified here —
/// confirm with actual producers before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBox {
    /// X coordinate of the box origin.
    pub x: f64,
    /// Y coordinate of the box origin.
    pub y: f64,
    pub width: f64,
    pub height: f64,
}
157
/// Granularity level of a [`TextBlock`], from whole page down to symbol.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockType {
    Page,
    Paragraph,
    Line,
    Word,
    Symbol,
}
172
/// OCR provider implemented on top of a vision-capable LLM.
pub struct LlmOcrProvider {
    // Client used to send vision chat requests.
    llm: LlmClient,
    // Size limits and prompt-detail tunables.
    config: LlmOcrConfig,
}
178
/// Configuration for [`LlmOcrProvider`].
#[derive(Debug, Clone)]
pub struct LlmOcrConfig {
    /// Maximum accepted image size in bytes; larger inputs are rejected.
    pub max_image_size: usize,
    /// When true, request structured JSON output (blocks, language, confidence).
    pub detailed_analysis: bool,
    /// Optional target-language hint.
    /// NOTE(review): not read anywhere in this module (the prompt builder
    /// ignores it) — confirm intended behavior.
    pub target_language: Option<String>,
}
189
190impl Default for LlmOcrConfig {
191 fn default() -> Self {
192 Self {
193 max_image_size: 20 * 1024 * 1024, detailed_analysis: true,
195 target_language: None,
196 }
197 }
198}
199
200impl LlmOcrProvider {
201 #[must_use]
203 pub fn new(llm: LlmClient) -> Self {
204 Self {
205 llm,
206 config: LlmOcrConfig::default(),
207 }
208 }
209
210 #[must_use]
212 pub fn with_config(llm: LlmClient, config: LlmOcrConfig) -> Self {
213 Self { llm, config }
214 }
215
216 fn build_ocr_prompt(&self, detailed: bool) -> String {
217 if detailed {
218 r#"Extract all text visible in this image. Please provide:
219
2201. **Full Text**: All text content exactly as it appears
2212. **Structure**: Identify headings, paragraphs, lists, tables if present
2223. **Layout**: Note any important layout information
223
224Respond in JSON format:
225{
226 "text": "<all extracted text>",
227 "blocks": [
228 {"text": "<block text>", "type": "<paragraph|heading|list|table|other>"}
229 ],
230 "language": "<detected language code or null>",
231 "confidence": <0.0-1.0 confidence estimate>
232}"#
233 .to_string()
234 } else {
235 "Extract all text from this image. Return only the text content, preserving the layout as much as possible.".to_string()
236 }
237 }
238
239 fn parse_ocr_response(&self, response: &str, detailed: bool) -> Result<OcrResult> {
240 let start_time = std::time::Instant::now();
241
242 if detailed {
243 if let Some(start) = response.find('{') {
245 if let Some(end) = response.rfind('}') {
246 let json_str = &response[start..=end];
247 if let Ok(parsed) = serde_json::from_str::<LlmOcrResponse>(json_str) {
248 return Ok(OcrResult {
249 text: parsed.text,
250 confidence: parsed.confidence.unwrap_or(0.8),
251 language: parsed.language,
252 blocks: parsed
253 .blocks
254 .unwrap_or_default()
255 .into_iter()
256 .map(|b| TextBlock {
257 text: b.text,
258 confidence: 0.8,
259 bounding_box: None,
260 block_type: match b.block_type.as_str() {
261 "heading" => BlockType::Line,
262 "paragraph" => BlockType::Paragraph,
263 "list" | "table" => BlockType::Paragraph,
264 _ => BlockType::Paragraph,
265 },
266 })
267 .collect(),
268 processing_time_ms: start_time.elapsed().as_millis() as u64,
269 provider: "llm".to_string(),
270 });
271 }
272 }
273 }
274 }
275
276 Ok(OcrResult {
278 text: response.trim().to_string(),
279 confidence: 0.7,
280 language: None,
281 blocks: vec![TextBlock {
282 text: response.trim().to_string(),
283 confidence: 0.7,
284 bounding_box: None,
285 block_type: BlockType::Page,
286 }],
287 processing_time_ms: start_time.elapsed().as_millis() as u64,
288 provider: "llm".to_string(),
289 })
290 }
291}
292
/// Shape of the JSON object requested from the LLM in detailed mode.
#[derive(Debug, Deserialize)]
struct LlmOcrResponse {
    text: String,
    blocks: Option<Vec<LlmOcrBlock>>,
    language: Option<String>,
    confidence: Option<f64>,
}
300
/// One block entry inside the LLM's JSON reply.
#[derive(Debug, Deserialize)]
struct LlmOcrBlock {
    text: String,
    // The JSON field is named "type" ("paragraph", "heading", "list", ...).
    #[serde(rename = "type")]
    block_type: String,
}
307
308#[async_trait]
309impl OcrProvider for LlmOcrProvider {
310 async fn extract_text(&self, image_data: &[u8], format: ImageFormat) -> Result<OcrResult> {
311 if image_data.len() > self.config.max_image_size {
313 return Err(AiError::Validation(format!(
314 "Image too large: {} bytes (max {} bytes)",
315 image_data.len(),
316 self.config.max_image_size
317 )));
318 }
319
320 let base64_image =
322 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, image_data);
323
324 let prompt = self.build_ocr_prompt(self.config.detailed_analysis);
325
326 let image_url = format!("data:{};base64,{}", format.mime_type(), base64_image);
328
329 let request = ChatRequest::with_vision(
330 "You are an expert OCR system. Extract text accurately from images.",
331 prompt,
332 image_url,
333 )
334 .max_tokens(4096)
335 .temperature(0.1);
336
337 let response = self.llm.chat(request).await?;
338
339 self.parse_ocr_response(&response.message.content, self.config.detailed_analysis)
340 }
341
342 async fn extract_text_from_url(&self, url: &str) -> Result<OcrResult> {
343 let prompt = self.build_ocr_prompt(self.config.detailed_analysis);
344
345 let request = ChatRequest::with_vision(
346 "You are an expert OCR system. Extract text accurately from images.",
347 prompt,
348 url.to_string(),
349 )
350 .max_tokens(4096)
351 .temperature(0.1);
352
353 let response = self.llm.chat(request).await?;
354
355 self.parse_ocr_response(&response.message.content, self.config.detailed_analysis)
356 }
357
358 fn name(&self) -> &'static str {
359 "llm-vision"
360 }
361}
362
/// Stateless helpers for cheap, local image inspection (no OCR involved).
pub struct ImageAnalyzer;
365
366impl ImageAnalyzer {
367 pub fn get_dimensions(data: &[u8]) -> Result<(u32, u32)> {
369 use image::GenericImageView;
370
371 let img = image::load_from_memory(data)
372 .map_err(|e| AiError::Validation(format!("Failed to load image: {e}")))?;
373
374 Ok(img.dimensions())
375 }
376
377 #[must_use]
379 pub fn is_likely_screenshot(width: u32, height: u32) -> bool {
380 let common_widths = [1920, 2560, 3840, 1366, 1440, 1536, 1280];
382 let common_heights = [1080, 1440, 2160, 768, 900, 864, 720, 800];
383
384 common_widths.contains(&width) || common_heights.contains(&height)
385 }
386
387 pub fn estimate_text_regions(data: &[u8]) -> Result<TextRegionEstimate> {
389 use image::GenericImageView;
390
391 let img = image::load_from_memory(data)
392 .map_err(|e| AiError::Validation(format!("Failed to load image: {e}")))?;
393
394 let (width, height) = img.dimensions();
395 let grayscale = img.to_luma8();
396
397 let mut high_contrast_pixels = 0u64;
399 let total_pixels = u64::from(width * height);
400
401 for y in 1..height - 1 {
402 for x in 1..width - 1 {
403 let center = i32::from(grayscale.get_pixel(x, y)[0]);
404 let right = i32::from(grayscale.get_pixel(x + 1, y)[0]);
405 let down = i32::from(grayscale.get_pixel(x, y + 1)[0]);
406
407 if (center - right).abs() > 30 || (center - down).abs() > 30 {
409 high_contrast_pixels += 1;
410 }
411 }
412 }
413
414 let edge_ratio = high_contrast_pixels as f64 / total_pixels as f64;
415
416 let likely_has_text = edge_ratio > 0.05 && edge_ratio < 0.4;
418
419 Ok(TextRegionEstimate {
420 width,
421 height,
422 edge_ratio,
423 likely_has_text,
424 estimated_text_coverage: if likely_has_text {
425 (edge_ratio * 2.0).min(1.0)
426 } else {
427 0.0
428 },
429 })
430 }
431}
432
/// Result of the edge-density text heuristic in
/// `ImageAnalyzer::estimate_text_regions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextRegionEstimate {
    /// Image width in pixels.
    pub width: u32,
    /// Image height in pixels.
    pub height: u32,
    /// Fraction of pixels with a high-contrast right/down neighbour.
    pub edge_ratio: f64,
    /// True when `edge_ratio` falls within the text-typical band.
    pub likely_has_text: bool,
    /// Rough fraction of the image covered by text (0.0 when unlikely).
    pub estimated_text_coverage: f64,
}
447
/// High-level screenshot pipeline combining local image heuristics with a
/// pluggable [`OcrProvider`].
pub struct ScreenshotOcr {
    // Backend used for the actual text extraction.
    provider: Box<dyn OcrProvider>,
}
452
453impl ScreenshotOcr {
454 #[must_use]
456 pub fn new(llm: LlmClient) -> Self {
457 Self {
458 provider: Box::new(LlmOcrProvider::new(llm)),
459 }
460 }
461
462 #[must_use]
464 pub fn with_provider(provider: Box<dyn OcrProvider>) -> Self {
465 Self { provider }
466 }
467
468 pub async fn process_screenshot(&self, data: &[u8]) -> Result<ScreenshotAnalysis> {
470 let format = ImageFormat::detect(data);
471
472 let (width, height) = ImageAnalyzer::get_dimensions(data)?;
474 let is_screenshot = ImageAnalyzer::is_likely_screenshot(width, height);
475 let text_estimate = ImageAnalyzer::estimate_text_regions(data)?;
476
477 if !text_estimate.likely_has_text {
479 return Ok(ScreenshotAnalysis {
480 ocr_result: None,
481 is_screenshot,
482 dimensions: (width, height),
483 format,
484 has_text: false,
485 text_estimate,
486 });
487 }
488
489 let ocr_result = self.provider.extract_text(data, format).await?;
491 let has_text = !ocr_result.text.trim().is_empty();
492
493 Ok(ScreenshotAnalysis {
494 ocr_result: Some(ocr_result),
495 is_screenshot,
496 dimensions: (width, height),
497 format,
498 has_text,
499 text_estimate,
500 })
501 }
502
503 pub async fn process_file(&self, path: &Path) -> Result<ScreenshotAnalysis> {
505 let data = std::fs::read(path)
506 .map_err(|e| AiError::Validation(format!("Failed to read file: {e}")))?;
507 self.process_screenshot(&data).await
508 }
509
510 pub async fn process_url(&self, url: &str) -> Result<ScreenshotAnalysis> {
512 let ocr_result = self.provider.extract_text_from_url(url).await?;
513 let has_text = !ocr_result.text.trim().is_empty();
514
515 Ok(ScreenshotAnalysis {
516 ocr_result: Some(ocr_result),
517 is_screenshot: true, dimensions: (0, 0), format: ImageFormat::Unknown,
520 has_text,
521 text_estimate: TextRegionEstimate {
522 width: 0,
523 height: 0,
524 edge_ratio: 0.0,
525 likely_has_text: true,
526 estimated_text_coverage: 0.0,
527 },
528 })
529 }
530}
531
/// Full analysis of one screenshot or image.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScreenshotAnalysis {
    /// OCR output; `None` when OCR was skipped (no text expected).
    pub ocr_result: Option<OcrResult>,
    /// Heuristic: dimensions match a common display resolution.
    pub is_screenshot: bool,
    /// `(width, height)` in pixels; `(0, 0)` for URL-based analysis.
    pub dimensions: (u32, u32),
    /// Detected image format; `Unknown` for URL-based analysis.
    pub format: ImageFormat,
    /// True when OCR ran and produced non-whitespace text.
    pub has_text: bool,
    /// Local text-presence heuristic (placeholder values for URLs).
    pub text_estimate: TextRegionEstimate,
}
548
549impl ScreenshotAnalysis {
550 #[must_use]
552 pub fn text(&self) -> Option<&str> {
553 self.ocr_result.as_ref().map(|r| r.text.as_str())
554 }
555
556 #[must_use]
558 pub fn confidence(&self) -> f64 {
559 self.ocr_result.as_ref().map_or(0.0, |r| r.confidence)
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_image_format_detection() {
        // Magic-byte detection for known and unknown signatures.
        assert_eq!(
            ImageFormat::detect(&[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]),
            ImageFormat::Png
        );
        assert_eq!(
            ImageFormat::detect(&[0xFF, 0xD8, 0xFF, 0xE0]),
            ImageFormat::Jpeg
        );
        assert_eq!(
            ImageFormat::detect(&[0x00, 0x01, 0x02, 0x03]),
            ImageFormat::Unknown
        );
    }

    #[test]
    fn test_format_from_extension() {
        let cases = [
            ("png", ImageFormat::Png),
            ("jpg", ImageFormat::Jpeg),
            ("jpeg", ImageFormat::Jpeg),
            ("webp", ImageFormat::WebP),
            ("txt", ImageFormat::Unknown),
        ];
        for (ext, expected) in cases {
            assert_eq!(ImageFormat::from_extension(ext), expected);
        }
    }

    #[test]
    fn test_mime_type() {
        let cases = [
            (ImageFormat::Png, "image/png"),
            (ImageFormat::Jpeg, "image/jpeg"),
            (ImageFormat::WebP, "image/webp"),
        ];
        for (format, mime) in cases {
            assert_eq!(format.mime_type(), mime);
        }
    }

    #[test]
    fn test_is_likely_screenshot() {
        assert!(ImageAnalyzer::is_likely_screenshot(1920, 1080));
        assert!(ImageAnalyzer::is_likely_screenshot(2560, 1440));
        assert!(!ImageAnalyzer::is_likely_screenshot(500, 500));
    }
}