use std::path::Path;
use image::DynamicImage;
use super::config::OcrConfig;
use super::detector::TextDetector;
use super::error::OcrResult;
use super::preprocessor::crop_text_region;
use super::recognizer::TextRecognizer;
/// One region of recognized text produced by [`OcrEngine::ocr_image`].
#[derive(Debug, Clone)]
pub struct OcrSpan {
    /// Text returned by the recognizer for this region.
    pub text: String,
    /// Four `[x, y]` corner points of the detected region, in the pixel space of the
    /// image that was passed to the OCR engine. Reading-order sorting uses `polygon[0]`
    /// as the anchor point.
    pub polygon: [[f32; 2]; 4],
    /// Recognizer confidence for the whole span.
    pub confidence: f32,
    /// Per-character confidences reported by the recognizer.
    pub char_confidences: Vec<f32>,
}

impl OcrSpan {
    /// Axis-aligned extents of `polygon` as `(min_x, min_y, max_x, max_y)`, in image pixels.
    fn extents(&self) -> (f32, f32, f32, f32) {
        self.polygon.iter().fold(
            (f32::MAX, f32::MAX, f32::MIN, f32::MIN),
            |(min_x, min_y, max_x, max_y), p| {
                (min_x.min(p[0]), min_y.min(p[1]), max_x.max(p[0]), max_y.max(p[1]))
            },
        )
    }

    /// Converts this OCR result into a layout [`TextSpan`](crate::layout::text_block::TextSpan).
    ///
    /// * `sequence` — stored verbatim as the span's `sequence` (callers pass the
    ///   reading-order position).
    /// * `scale` — divisor applied to pixel coordinates. Presumably pixels per output
    ///   unit (e.g. render DPI / 72); confirm with the caller.
    ///
    /// The bbox is the polygon's axis-aligned bounding box divided by `scale`. Style
    /// attributes are fixed defaults: font name `"OCR"`, normal weight, black, no
    /// spacing adjustments.
    ///
    /// NOTE(review): the bbox keeps the image's y axis (smaller y = nearer the image
    /// top). Confirm this matches what `TextSpan` consumers expect.
    pub fn to_text_span(&self, sequence: usize, scale: f32) -> crate::layout::text_block::TextSpan {
        use crate::geometry::Rect;
        use crate::layout::text_block::{Color, FontWeight, TextSpan};
        let (min_x, min_y, max_x, max_y) = self.extents();
        let width_pixels = max_x - min_x;
        let height_pixels = max_y - min_y;
        // `Rect::new` takes (x, y, width, height) -- the same convention `bounding_rect`
        // uses -- so pass the scaled extents rather than the scaled far corner.
        let bbox = Rect::new(
            min_x / scale,
            min_y / scale,
            width_pixels / scale,
            height_pixels / scale,
        );
        let font_size = self.estimate_font_size(height_pixels, scale);
        TextSpan {
            artifact_type: None,
            text: self.text.clone(),
            bbox,
            font_name: "OCR".to_string(),
            font_size,
            font_weight: FontWeight::Normal,
            is_italic: false,
            is_monospace: false,
            color: Color::black(),
            mcid: None,
            sequence,
            split_boundary_before: false,
            offset_semantic: false,
            char_spacing: 0.0,
            word_spacing: 0.0,
            horizontal_scaling: 100.0,
            primary_detected: false,
            char_widths: Vec::new(),
            heading_level: None,
            rotation_degrees: 0.0,
        }
    }

    /// Heuristic font size: the box height divided by `scale`, times 0.75, clamped
    /// to `[6.0, 72.0]`.
    ///
    /// The 0.75 factor is an empirical choice whose origin is not documented here.
    fn estimate_font_size(&self, height_pixels: f32, scale: f32) -> f32 {
        let height_points = height_pixels / scale;
        (height_points * 0.75).clamp(6.0, 72.0)
    }

    /// Unscaled axis-aligned bounding box of `polygon` as `(x, y, width, height)`
    /// in image pixels.
    pub fn bounding_rect(&self) -> crate::geometry::Rect {
        use crate::geometry::Rect;
        let (min_x, min_y, max_x, max_y) = self.extents();
        Rect::new(min_x, min_y, max_x - min_x, max_y - min_y)
    }
}
/// Result of running OCR over one image.
#[derive(Debug, Clone)]
pub struct OcrOutput {
    /// Accepted spans. As built by `OcrEngine::ocr_image`, they are in the order the
    /// detector returned the boxes.
    pub spans: Vec<OcrSpan>,
    /// Despite the name, `OcrEngine::ocr_image` stores the *mean* confidence of
    /// `spans` here (0.0 when there are none).
    pub total_confidence: f32,
}

impl OcrOutput {
    /// Joins the text of all spans in stored order, separated by single spaces.
    pub fn text(&self) -> String {
        self.spans
            .iter()
            .map(|s| s.text.as_str())
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Compares two span polygons for reading order: top-to-bottom, then left-to-right.
    ///
    /// The anchor is `polygon[0]`. Its y coordinate is rounded into bands of `Y_BAND`
    /// pixels, so spans whose anchors fall in the same band compare by x. Remaining
    /// ties are broken by the raw y value. Because this is a lexicographic comparison
    /// of per-polygon keys, it is a total order, as `sort_by` requires (provided
    /// `safe_float_cmp` is itself total).
    fn reading_order_cmp(a: &[[f32; 2]; 4], b: &[[f32; 2]; 4]) -> std::cmp::Ordering {
        const Y_BAND: f32 = 10.0;
        let band = |y: f32| (y / Y_BAND).round() as i64;
        band(a[0][1])
            .cmp(&band(b[0][1]))
            .then_with(|| crate::utils::safe_float_cmp(a[0][0], b[0][0]))
            .then_with(|| crate::utils::safe_float_cmp(a[0][1], b[0][1]))
    }

    /// Borrows the spans sorted by [`Self::reading_order_cmp`].
    ///
    /// The sort is stable, so spans that compare equal keep their stored order.
    fn spans_in_reading_order(&self) -> Vec<&OcrSpan> {
        let mut ordered: Vec<&OcrSpan> = self.spans.iter().collect();
        ordered.sort_by(|a, b| Self::reading_order_cmp(&a.polygon, &b.polygon));
        ordered
    }

    /// Joins the text of all spans in reading order, separated by single spaces.
    pub fn text_in_reading_order(&self) -> String {
        self.spans_in_reading_order()
            .iter()
            .map(|s| s.text.as_str())
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Converts every span to a layout `TextSpan`, in reading order.
    ///
    /// Each span's `sequence` is its position in reading order. `scale` is forwarded
    /// to [`OcrSpan::to_text_span`].
    pub fn to_text_spans(&self, scale: f32) -> Vec<crate::layout::text_block::TextSpan> {
        self.spans_in_reading_order()
            .into_iter()
            .enumerate()
            .map(|(seq, ocr_span)| ocr_span.to_text_span(seq, scale))
            .collect()
    }
}
pub struct OcrEngine {
detector: TextDetector,
recognizer: TextRecognizer,
config: OcrConfig,
}
impl OcrEngine {
pub fn new(
det_model_path: impl AsRef<Path>,
rec_model_path: impl AsRef<Path>,
dict_path: impl AsRef<Path>,
config: OcrConfig,
) -> OcrResult<Self> {
let detector = TextDetector::new(det_model_path, config.clone())?;
let recognizer = TextRecognizer::new(rec_model_path, dict_path, config.clone())?;
Ok(Self {
detector,
recognizer,
config,
})
}
pub fn from_bytes(
det_model_bytes: &[u8],
rec_model_bytes: &[u8],
dict_content: &str,
config: OcrConfig,
) -> OcrResult<Self> {
let detector = TextDetector::from_bytes(det_model_bytes, config.clone())?;
let recognizer = TextRecognizer::from_bytes(rec_model_bytes, dict_content, config.clone())?;
Ok(Self {
detector,
recognizer,
config,
})
}
pub fn ocr_image(&self, image: &DynamicImage) -> OcrResult<OcrOutput> {
let boxes = self.detector.detect(image)?;
if boxes.is_empty() {
return Ok(OcrOutput {
spans: Vec::new(),
total_confidence: 0.0,
});
}
let mut spans = Vec::new();
let mut total_confidence = 0.0;
for detected_box in &boxes {
let crop = crop_text_region(image, &detected_box.polygon)?;
let recognition = self.recognizer.recognize(&crop)?;
if recognition.confidence >= self.config.rec_threshold
&& !recognition.text.trim().is_empty()
{
total_confidence += recognition.confidence;
spans.push(OcrSpan {
text: recognition.text,
polygon: detected_box.polygon,
confidence: recognition.confidence,
char_confidences: recognition.char_confidences,
});
}
}
let avg_confidence = if spans.is_empty() {
0.0
} else {
total_confidence / spans.len() as f32
};
Ok(OcrOutput {
spans,
total_confidence: avg_confidence,
})
}
pub fn detector(&self) -> &TextDetector {
&self.detector
}
pub fn recognizer(&self) -> &TextRecognizer {
&self.recognizer
}
pub fn config(&self) -> &OcrConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `OcrSpan` with no per-character confidences.
    fn span(text: &str, polygon: [[f32; 2]; 4], confidence: f32) -> OcrSpan {
        OcrSpan {
            text: text.to_string(),
            polygon,
            confidence,
            char_confidences: vec![],
        }
    }

    #[test]
    fn test_ocr_output_text() {
        let hello = span("Hello", [[0.0, 0.0], [50.0, 0.0], [50.0, 20.0], [0.0, 20.0]], 0.95);
        let world = span("World", [[60.0, 0.0], [110.0, 0.0], [110.0, 20.0], [60.0, 20.0]], 0.92);
        let output = OcrOutput {
            spans: vec![hello, world],
            total_confidence: 0.935,
        };
        assert_eq!(output.text(), "Hello World");
    }

    #[test]
    fn test_reading_order_cmp_is_total_order() {
        use std::cmp::Ordering;

        /// A small 5x2 box whose first corner is at `(x, y)`.
        fn rect_at(x: f32, y: f32) -> [[f32; 2]; 4] {
            [[x, y], [x + 5.0, y], [x + 5.0, y + 2.0], [x, y + 2.0]]
        }

        let samples = [
            rect_at(10.0, 0.0),
            rect_at(5.0, 8.0),
            rect_at(0.0, 16.0),
            rect_at(0.0, 0.0),
            rect_at(100.0, 1.0),
            rect_at(50.0, 9.0),
            rect_at(7.0, 23.0),
            rect_at(7.0, 24.0),
            rect_at(7.0, 25.0),
        ];

        // Antisymmetry: swapping the operands must reverse the result.
        for lhs in samples.iter() {
            for rhs in samples.iter() {
                assert_eq!(
                    OcrOutput::reading_order_cmp(lhs, rhs),
                    OcrOutput::reading_order_cmp(rhs, lhs).reverse(),
                    "antisymmetry"
                );
            }
        }

        // Transitivity of <=.
        let not_greater = |x: &[[f32; 2]; 4], y: &[[f32; 2]; 4]| {
            OcrOutput::reading_order_cmp(x, y) != Ordering::Greater
        };
        for a in samples.iter() {
            for b in samples.iter() {
                for c in samples.iter() {
                    if not_greater(a, b) && not_greater(b, c) {
                        assert!(not_greater(a, c), "transitivity violated");
                    }
                }
            }
        }

        let mut sorted = samples.to_vec();
        sorted.sort_by(|a, b| OcrOutput::reading_order_cmp(a, b));
        assert_eq!(sorted.len(), samples.len());
    }

    #[test]
    fn test_ocr_output_reading_order() {
        let second_line = span("Line2", [[0.0, 50.0], [50.0, 50.0], [50.0, 70.0], [0.0, 70.0]], 0.9);
        let first_line = span("Line1", [[0.0, 0.0], [50.0, 0.0], [50.0, 20.0], [0.0, 20.0]], 0.9);
        let output = OcrOutput {
            spans: vec![second_line, first_line],
            total_confidence: 0.9,
        };
        assert_eq!(output.text_in_reading_order(), "Line1 Line2");
    }

    #[test]
    fn test_ocr_span() {
        let s = OcrSpan {
            text: "Test".to_string(),
            polygon: [[10.0, 20.0], [110.0, 20.0], [110.0, 60.0], [10.0, 60.0]],
            confidence: 0.98,
            char_confidences: vec![0.99, 0.97, 0.98, 0.99],
        };
        assert_eq!(s.text, "Test");
        assert!(s.confidence > 0.9);
    }
}