use crate::{Document, Error, Extractor, Result};
use oar_ocr::prelude::*;
use std::path::{Path, PathBuf};
/// Image-to-text extractor that runs an `oar-ocr` pipeline.
///
/// Instances are created with `OnnxOcrExtractor::with_models`, which builds
/// the pipeline once; `extract` then reuses it for every image.
pub struct OnnxOcrExtractor {
/// OCR pipeline produced by `OAROCRBuilder` at construction time.
ocr: OAROCR,
/// Copy of the detection-model path passed to the constructor; returned by
/// `detection_model_path`.
detection_model: PathBuf,
}
impl OnnxOcrExtractor {
pub fn with_models(
detection_model: &Path,
recognition_model: &Path,
dict: &Path,
) -> Result<Self> {
for (label, p) in [
("detection_model", detection_model),
("recognition_model", recognition_model),
("dict", dict),
] {
if !p.exists() {
return Err(Error::MissingDependency {
name: format!("oar-ocr {label}"),
details: format!(
"file not found: {} — download from \
https://github.com/GreatV/oar-ocr/releases",
p.display()
),
});
}
}
let ocr = OAROCRBuilder::new(
detection_model.to_path_buf(),
recognition_model.to_path_buf(),
dict.to_path_buf(),
)
.build()
.map_err(|e| {
Error::ParseError(format!(
"oar-ocr pipeline construction failed (libonnxruntime missing or \
model file rejected): {e}"
))
})?;
Ok(Self {
ocr,
detection_model: detection_model.to_path_buf(),
})
}
#[must_use]
pub fn detection_model_path(&self) -> &Path {
&self.detection_model
}
}
impl Extractor for OnnxOcrExtractor {
    /// Image file extensions this extractor declares, lowercase and without
    /// the leading dot.
    fn extensions(&self) -> &[&'static str] {
        &["png", "jpg", "jpeg", "tiff", "tif", "bmp", "gif"]
    }

    /// Identifier of this extractor.
    fn name(&self) -> &'static str {
        "ocr-onnx"
    }

    /// Decodes the image at `path`, runs OCR on it, and returns the recognized
    /// text as a [`Document`].
    ///
    /// Only the first prediction result is read. Each region's text is
    /// trimmed; regions with no text or only whitespace are dropped, and the
    /// remaining lines are joined with `'\n'`. `title` is always `None` and
    /// `metadata` is always empty.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ParseError`] if the image cannot be opened/decoded or
    /// if the OCR prediction fails.
    fn extract(&self, path: &Path) -> Result<Document> {
        let decoded = image::open(path).map_err(|e| {
            Error::ParseError(format!("could not decode image {}: {e}", path.display()))
        })?;
        let rgb = decoded.to_rgb8();

        let pages = self.ocr.predict(vec![rgb]).map_err(|e| {
            Error::ParseError(format!("oar-ocr predict failed on {}: {e}", path.display()))
        })?;

        // A single image was submitted, so only the first result is consulted.
        let lines: Vec<&str> = pages
            .first()
            .into_iter()
            .flat_map(|page| &page.text_regions)
            .filter_map(|region| region.text.as_ref())
            .map(|text| text.trim())
            .filter(|line| !line.is_empty())
            .collect();

        Ok(Document {
            markdown: lines.join("\n"),
            title: None,
            metadata: std::collections::HashMap::new(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the common image extensions appear in the expected list.
    // NOTE(review): `ext` is a local literal, not the slice returned by
    // `Extractor::extensions`, so this test cannot detect a change to the
    // real list. Building an extractor needs model files, which is presumably
    // why the list is duplicated here — consider sharing one constant.
    #[test]
    fn extensions_cover_common_image_formats() {
        let ext: &[&str] = &["png", "jpg", "jpeg", "tiff", "tif", "bmp", "gif"];
        ["png", "jpg", "jpeg", "tiff", "bmp", "gif"]
            .iter()
            .for_each(|required| {
                assert!(
                    ext.contains(required),
                    "expected ocr-onnx to handle .{required}, got {ext:?}"
                );
            });
    }

    /// Nonexistent model paths must produce `Error::MissingDependency`.
    #[test]
    fn missing_model_file_returns_typed_error() {
        let detection = Path::new("/nonexistent-detection.onnx");
        let recognition = Path::new("/nonexistent-recognition.onnx");
        let dictionary = Path::new("/nonexistent-dict.txt");

        let result = OnnxOcrExtractor::with_models(detection, recognition, dictionary);

        assert!(matches!(result, Err(Error::MissingDependency { .. })));
    }

    /// End-to-end run against fixture models and a fixture image; ignored by
    /// default because it needs the ONNX runtime and the model files.
    #[test]
    #[ignore = "requires libonnxruntime AND model files in tests/fixtures/onnx-models/"]
    fn extracts_text_from_a_real_image() {
        let detection = Path::new("tests/fixtures/onnx-models/pp-ocrv5_mobile_det.onnx");
        let recognition = Path::new("tests/fixtures/onnx-models/en_pp-ocrv5_mobile_rec.onnx");
        let dictionary = Path::new("tests/fixtures/onnx-models/ppocrv5_en_dict.txt");
        let sample = Path::new("tests/fixtures/hello.png");

        let extractor = OnnxOcrExtractor::with_models(detection, recognition, dictionary)
            .expect("model construction failed");
        let doc = extractor.extract(sample).expect("extraction failed");

        assert!(
            !doc.markdown.is_empty(),
            "expected non-empty markdown from hello.png"
        );
        assert!(
            doc.markdown.to_lowercase().contains("hello"),
            "expected 'hello' in OCR output: {:?}",
            doc.markdown
        );
    }
}