use std::path::Path;
use image::DynamicImage;
use super::config::OcrConfig;
use super::detector::TextDetector;
use super::error::OcrResult;
use super::preprocessor::crop_text_region;
use super::recognizer::TextRecognizer;
/// One piece of text found by OCR, positioned in the source image's pixel space.
#[derive(Debug, Clone)]
pub struct OcrSpan {
    /// Text produced by the recognizer for this region.
    pub text: String,
    /// Quadrilateral around the text: four `[x, y]` points in image pixels.
    pub polygon: [[f32; 2]; 4],
    /// Recognizer confidence for the whole span.
    pub confidence: f32,
    /// Per-character confidences reported by the recognizer.
    pub char_confidences: Vec<f32>,
}

impl OcrSpan {
    /// Convert this OCR result into a layout `TextSpan`.
    ///
    /// The span's axis-aligned bounding box is divided by `scale` to map image
    /// pixels back to page units. `scale` must therefore be non-zero. The
    /// coordinates keep the image's origin; no y-axis flip is applied here.
    /// `sequence` is copied verbatim into the resulting span. Font attributes that
    /// OCR cannot observe are given fixed placeholder values (font name "OCR",
    /// normal weight, upright, black).
    pub fn to_text_span(&self, sequence: usize, scale: f32) -> crate::layout::text_block::TextSpan {
        use crate::geometry::Rect;
        use crate::layout::text_block::{Color, FontWeight, TextSpan};

        let (min_x, min_y, max_x, max_y) = self.polygon_extent();
        // `Rect::new` takes (x, y, width, height), as used by `bounding_rect`,
        // so the far corner has to be turned into a size before scaling.
        let bbox = Rect::new(
            min_x / scale,
            min_y / scale,
            (max_x - min_x) / scale,
            (max_y - min_y) / scale,
        );
        let font_size = self.estimate_font_size(max_y - min_y, scale);

        TextSpan {
            artifact_type: None,
            text: self.text.clone(),
            bbox,
            font_name: "OCR".to_string(),
            font_size,
            font_weight: FontWeight::Normal,
            is_italic: false,
            is_monospace: false,
            color: Color::black(),
            mcid: None,
            sequence,
            split_boundary_before: false,
            offset_semantic: false,
            char_spacing: 0.0,
            word_spacing: 0.0,
            horizontal_scaling: 100.0,
            primary_detected: false,
            char_widths: Vec::new(),
        }
    }

    /// Heuristic font size from the span's pixel height, clamped to `[6.0, 72.0]`.
    ///
    /// `height_pixels / scale` converts the height to page units. The 0.75 factor
    /// is an empirical ratio between box height and font size.
    /// NOTE(review): if `scale` already maps pixels to points, 0.75 may be a
    /// leftover px→pt conversion. Confirm against the renderer.
    fn estimate_font_size(&self, height_pixels: f32, scale: f32) -> f32 {
        let height_points = height_pixels / scale;
        (height_points * 0.75).clamp(6.0, 72.0)
    }

    /// Axis-aligned bounding box of `polygon` in image pixels.
    ///
    /// The box is built from its minimum corner plus its width and height.
    pub fn bounding_rect(&self) -> crate::geometry::Rect {
        use crate::geometry::Rect;
        let (min_x, min_y, max_x, max_y) = self.polygon_extent();
        Rect::new(min_x, min_y, max_x - min_x, max_y - min_y)
    }

    /// Per-axis extremes of `polygon` as `(min_x, min_y, max_x, max_y)`.
    ///
    /// This is a single pass over the four vertices. Like `f32::min`/`f32::max`,
    /// it ignores a NaN coordinate in favour of the other operand.
    fn polygon_extent(&self) -> (f32, f32, f32, f32) {
        self.polygon.iter().fold(
            (f32::MAX, f32::MAX, f32::MIN, f32::MIN),
            |(min_x, min_y, max_x, max_y), &[x, y]| {
                (min_x.min(x), min_y.min(y), max_x.max(x), max_y.max(y))
            },
        )
    }
}
/// Result of running OCR over one image.
#[derive(Debug, Clone)]
pub struct OcrOutput {
    /// Recognized spans. As produced by `OcrEngine::ocr_image` they are in
    /// detector order; use the reading-order helpers for top-to-bottom text.
    pub spans: Vec<OcrSpan>,
    /// Confidence summary for the spans.
    ///
    /// Despite the name, `OcrEngine::ocr_image` stores the *mean* confidence
    /// of the accepted spans here (`0.0` when there are none), not a sum.
    pub total_confidence: f32,
}

impl OcrOutput {
    /// Maximum vertical distance, in image pixels, between a span's anchor and
    /// the first span of a line for the two to be treated as the same line.
    const LINE_TOLERANCE: f32 = 10.0;

    /// Join every span's text with single spaces, in stored order.
    pub fn text(&self) -> String {
        self.spans
            .iter()
            .map(|s| s.text.as_str())
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Join every span's text with single spaces, in reading order.
    ///
    /// Reading order means top-to-bottom by line, then left-to-right within a
    /// line. See `spans_in_reading_order` for how lines are formed.
    pub fn text_in_reading_order(&self) -> String {
        self.spans_in_reading_order()
            .iter()
            .map(|s| s.text.as_str())
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Convert all spans to layout `TextSpan`s in reading order.
    ///
    /// Each span's `sequence` is its position in that order (0-based). `scale`
    /// is forwarded to `OcrSpan::to_text_span`.
    pub fn to_text_spans(&self, scale: f32) -> Vec<crate::layout::text_block::TextSpan> {
        self.spans_in_reading_order()
            .into_iter()
            .enumerate()
            .map(|(sequence, span)| span.to_text_span(sequence, scale))
            .collect()
    }

    /// Borrow the spans sorted into reading order.
    ///
    /// Each span is anchored at its first polygon vertex. This is presumably the
    /// top-left corner, but it depends on the detector's vertex ordering; TODO confirm.
    ///
    /// The spans are first sorted by anchor y. Consecutive spans whose y lies
    /// within `LINE_TOLERANCE` of the line's first span form one line, and each
    /// line is then sorted by anchor x.
    ///
    /// A single tolerance-based comparator is not transitive. `sort_by` requires
    /// a total order and may panic without one, so the grouping is done
    /// explicitly instead. `total_cmp` keeps both sorts total even for NaN.
    fn spans_in_reading_order(&self) -> Vec<&OcrSpan> {
        let mut ordered: Vec<&OcrSpan> = self.spans.iter().collect();
        ordered.sort_by(|a, b| a.polygon[0][1].total_cmp(&b.polygon[0][1]));

        let mut line_start = 0;
        for i in 1..=ordered.len() {
            let line_ends_here = i == ordered.len()
                || ordered[i].polygon[0][1] - ordered[line_start].polygon[0][1]
                    >= Self::LINE_TOLERANCE;
            if line_ends_here {
                ordered[line_start..i]
                    .sort_by(|a, b| a.polygon[0][0].total_cmp(&b.polygon[0][0]));
                line_start = i;
            }
        }
        ordered
    }
}
pub struct OcrEngine {
detector: TextDetector,
recognizer: TextRecognizer,
config: OcrConfig,
}
impl OcrEngine {
pub fn new(
det_model_path: impl AsRef<Path>,
rec_model_path: impl AsRef<Path>,
dict_path: impl AsRef<Path>,
config: OcrConfig,
) -> OcrResult<Self> {
let detector = TextDetector::new(det_model_path, config.clone())?;
let recognizer = TextRecognizer::new(rec_model_path, dict_path, config.clone())?;
Ok(Self {
detector,
recognizer,
config,
})
}
pub fn from_bytes(
det_model_bytes: &[u8],
rec_model_bytes: &[u8],
dict_content: &str,
config: OcrConfig,
) -> OcrResult<Self> {
let detector = TextDetector::from_bytes(det_model_bytes, config.clone())?;
let recognizer = TextRecognizer::from_bytes(rec_model_bytes, dict_content, config.clone())?;
Ok(Self {
detector,
recognizer,
config,
})
}
pub fn ocr_image(&self, image: &DynamicImage) -> OcrResult<OcrOutput> {
let boxes = self.detector.detect(image)?;
if boxes.is_empty() {
return Ok(OcrOutput {
spans: Vec::new(),
total_confidence: 0.0,
});
}
let mut spans = Vec::new();
let mut total_confidence = 0.0;
for detected_box in &boxes {
let crop = crop_text_region(image, &detected_box.polygon)?;
let recognition = self.recognizer.recognize(&crop)?;
if recognition.confidence >= self.config.rec_threshold
&& !recognition.text.trim().is_empty()
{
total_confidence += recognition.confidence;
spans.push(OcrSpan {
text: recognition.text,
polygon: detected_box.polygon,
confidence: recognition.confidence,
char_confidences: recognition.char_confidences,
});
}
}
let avg_confidence = if spans.is_empty() {
0.0
} else {
total_confidence / spans.len() as f32
};
Ok(OcrOutput {
spans,
total_confidence: avg_confidence,
})
}
pub fn detector(&self) -> &TextDetector {
&self.detector
}
pub fn recognizer(&self) -> &TextRecognizer {
&self.recognizer
}
pub fn config(&self) -> &OcrConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `OcrSpan` with no per-character confidences.
    fn make_span(text: &str, polygon: [[f32; 2]; 4], confidence: f32) -> OcrSpan {
        OcrSpan {
            text: text.to_string(),
            polygon,
            confidence,
            char_confidences: vec![],
        }
    }

    #[test]
    fn test_ocr_output_text() {
        let hello = make_span("Hello", [[0.0, 0.0], [50.0, 0.0], [50.0, 20.0], [0.0, 20.0]], 0.95);
        let world = make_span("World", [[60.0, 0.0], [110.0, 0.0], [110.0, 20.0], [60.0, 20.0]], 0.92);
        let output = OcrOutput {
            spans: vec![hello, world],
            total_confidence: 0.935,
        };
        assert_eq!(output.text(), "Hello World");
    }

    #[test]
    fn test_ocr_output_reading_order() {
        // Stored bottom line first; reading order must put the top line first.
        let lower = make_span("Line2", [[0.0, 50.0], [50.0, 50.0], [50.0, 70.0], [0.0, 70.0]], 0.9);
        let upper = make_span("Line1", [[0.0, 0.0], [50.0, 0.0], [50.0, 20.0], [0.0, 20.0]], 0.9);
        let output = OcrOutput {
            spans: vec![lower, upper],
            total_confidence: 0.9,
        };
        assert_eq!(output.text_in_reading_order(), "Line1 Line2");
    }

    #[test]
    fn test_ocr_span() {
        let span = OcrSpan {
            char_confidences: vec![0.99, 0.97, 0.98, 0.99],
            ..make_span("Test", [[10.0, 20.0], [110.0, 20.0], [110.0, 60.0], [10.0, 60.0]], 0.98)
        };
        assert_eq!(span.text, "Test");
        assert!(span.confidence > 0.9);
    }
}