use std::path::PathBuf;
use std::sync::{LazyLock, Mutex};
use image::DynamicImage;
use oar_ocr::core::config::{OrtExecutionProvider, OrtGraphOptimizationLevel, OrtSessionConfig};
use oar_ocr::oarocr::{OAROCRBuilder, OAROCR};
use crate::element::Element;
/// Output of [`detect_text`]: the OCR regions that survived filtering, each
/// converted into an `Element` of kind `"Text"`.
#[derive(Clone, Debug)]
pub struct TextResult {
    /// Accepted text regions, in the order the OCR engine reported them.
    pub texts: Vec<Element>,
}
/// Process-wide OCR engine slot: `None` until the engine is first built, and
/// reset to `None` by `clean_ocr`.
static OCR_INSTANCE: LazyLock<Mutex<Option<OAROCR>>> = LazyLock::new(Default::default);

/// Files the models directory must provide; listed verbatim in setup errors.
const MODEL_FILES: &[&str] = &[
    "ppocrv5_mobile_det.onnx",
    "ppocrv5_mobile_rec.onnx",
    "ppocrv5_dict.txt",
];

/// Longest image side (in pixels) handed to the OCR engine; larger inputs are
/// downsampled before recognition.
const OCR_MAX_DIM: u32 = 1200;
/// Runs OCR on `img` and returns the text regions that pass the quality filters.
///
/// If the longest side of `img` exceeds `OCR_MAX_DIM`, the image is downsampled
/// (Lanczos3) before recognition and every bounding box is scaled back into the
/// original image's pixel space.
///
/// A region is dropped when:
/// - its text is empty, or consists only of ASCII punctuation and/or whitespace;
/// - its confidence is below a length-dependent threshold: 0.97 for a single
///   character, 0.9 for 2–5 characters, 0.85 for more than 5 characters
///   (missing confidence counts as 0.0);
/// - its rescaled box is narrower than 5 px or shorter than 3 px.
///
/// This function does not return errors. Engine-initialization and prediction
/// failures are logged to stderr, and an empty [`TextResult`] is returned.
pub fn detect_text(img: &DynamicImage) -> TextResult {
    let guard = match get_or_init_ocr() {
        Ok(g) => g,
        Err(e) => {
            // Most init errors already start with "[OCR] "; don't print the tag twice.
            let msg = e.strip_prefix("[OCR] ").unwrap_or(e.as_str());
            eprintln!("[OCR] {}", msg);
            return TextResult { texts: Vec::new() };
        }
    };
    let Some(ocr) = guard.as_ref() else {
        eprintln!("[OCR] Engine not initialized");
        return TextResult { texts: Vec::new() };
    };

    let (orig_w, orig_h) = (img.width(), img.height());
    let longest_side = orig_w.max(orig_h);
    // Both branches convert straight to RGB8. An image that needs no resizing is
    // then copied once, instead of being cloned and converted.
    let (ocr_rgb, scale_x, scale_y) = if longest_side > OCR_MAX_DIM {
        let ratio = OCR_MAX_DIM as f64 / longest_side as f64;
        // Clamp to at least 1 px. For a very elongated image the short side can
        // round to 0, which would make the back-projection scale infinite.
        let new_w = ((orig_w as f64 * ratio).round() as u32).max(1);
        let new_h = ((orig_h as f64 * ratio).round() as u32).max(1);
        let resized = img.resize_exact(new_w, new_h, image::imageops::FilterType::Lanczos3);
        println!(
            " → OCR downsampled: {}x{} → {}x{} (scale={:.3})",
            orig_w, orig_h, new_w, new_h, ratio
        );
        (
            resized.to_rgb8(),
            orig_w as f64 / new_w as f64,
            orig_h as f64 / new_h as f64,
        )
    } else {
        (img.to_rgb8(), 1.0, 1.0)
    };

    let results = match ocr.predict(vec![ocr_rgb]) {
        Ok(r) => r,
        Err(e) => {
            eprintln!("[OCR] Prediction failed: {}", e);
            return TextResult { texts: Vec::new() };
        }
    };
    // Exactly one image was submitted, so only the first result is relevant.
    let Some(result) = results.first() else {
        return TextResult { texts: Vec::new() };
    };

    let mut elements = Vec::with_capacity(result.text_regions.len());
    for region in &result.text_regions {
        let text = region.text.as_deref().unwrap_or("");
        let confidence = region.confidence.unwrap_or(0.0);
        if text.is_empty() {
            continue;
        }
        if text
            .chars()
            .all(|c| c.is_ascii_punctuation() || c.is_whitespace())
        {
            continue;
        }
        // Thresholds depend on the number of characters, not UTF-8 bytes. With
        // bytes, a single CJK character (3 bytes) would never get the strict
        // single-character threshold.
        let char_count = text.chars().count();
        let min_confidence = if char_count > 5 {
            0.85
        } else if char_count == 1 {
            0.97
        } else {
            0.9
        };
        // Kept as a `>=` test so that a NaN confidence fails and the region is dropped.
        let confident_enough = confidence >= min_confidence;
        if !confident_enough {
            continue;
        }
        // Map the box from downsampled coordinates back to the original image.
        let bbox = &region.bounding_box;
        let x_min = (bbox.x_min() as f64 * scale_x) as i32;
        let y_min = (bbox.y_min() as f64 * scale_y) as i32;
        let x_max = (bbox.x_max() as f64 * scale_x) as i32;
        let y_max = (bbox.y_max() as f64 * scale_y) as i32;
        if (x_max - x_min) < 5 || (y_max - y_min) < 3 {
            continue;
        }
        let mut element = Element::from_parts(0, x_min, y_min, x_max, y_max, "Text");
        element.text_content = Some(text.to_string());
        elements.push(element);
    }
    println!("[OCR] Detected {} text elements", elements.len());
    TextResult { texts: elements }
}
/// Eagerly builds the shared OCR engine, so the first `detect_text` call does
/// not have to build it.
///
/// # Errors
/// Propagates the error message produced while locking the engine slot or
/// constructing the engine.
pub fn init_ocr() -> Result<(), String> {
    // The guard is bound (not discarded) so the log line is printed while the lock is held.
    get_or_init_ocr().map(|_engine_guard| {
        println!("[OCR] Engine pre-initialized");
    })
}
/// Drops the shared OCR engine, if one was built. A later `detect_text` or
/// `init_ocr` call will build it again.
///
/// If the engine mutex is poisoned, this does nothing (and prints nothing).
pub fn clean_ocr() {
    let Ok(mut slot) = OCR_INSTANCE.lock() else {
        return;
    };
    *slot = None;
    println!("[OCR] Engine cleaned up");
}
/// Locks the shared engine slot, building the engine into it on first use.
///
/// On `Ok`, the returned guard holds the lock and its slot is `Some`: either the
/// engine was already present or it was just constructed here.
///
/// Construction does the following:
/// - locates the models directory via `find_models_dir`;
/// - checks that the detection model, recognition model and dictionary exist;
/// - configures an ONNX Runtime session whose execution providers come from
///   `build_provider_chain`;
/// - builds an `OAROCR` with a region batch size of 16.
///
/// # Errors
/// Returns a human-readable message in any of these cases:
/// - the mutex is poisoned;
/// - the models directory cannot be found;
/// - a required file is missing;
/// - the engine builder fails.
fn get_or_init_ocr() -> Result<std::sync::MutexGuard<'static, Option<OAROCR>>, String> {
    let mut slot = match OCR_INSTANCE.lock() {
        Ok(held) => held,
        Err(poisoned) => return Err(format!("[OCR] Lock error: {}", poisoned)),
    };
    // Fast path: the engine was built by an earlier call.
    if slot.is_some() {
        return Ok(slot);
    }

    let models_dir = find_models_dir()?;
    let required_files = [
        ("detection model", models_dir.join("ppocrv5_mobile_det.onnx")),
        ("recognition model", models_dir.join("ppocrv5_mobile_rec.onnx")),
        ("dictionary", models_dir.join("ppocrv5_dict.txt")),
    ];
    // Report the first missing file, in the order listed above.
    if let Some((label, missing)) = required_files.iter().find(|(_, path)| !path.exists()) {
        return Err(format!(
            "[OCR] {} not found at: {}\n\
             Please download OCR models from:\n\
             https://github.com/GreatV/oar-ocr/releases\n\
             And place them in: {}",
            label,
            missing.display(),
            models_dir.display()
        ));
    }
    let [(_, det_path), (_, rec_path), (_, dict_path)] = &required_files;

    let session_config = OrtSessionConfig::new()
        .with_memory_pattern(true)
        .with_intra_threads(4)
        .with_inter_threads(2)
        .with_optimization_level(OrtGraphOptimizationLevel::Level3)
        .add_config_entry("session.intra_op.allow_spinning", "1")
        .add_config_entry("session.inter_op.allow_spinning", "1")
        .with_execution_providers(build_provider_chain());

    let engine = OAROCRBuilder::new(
        det_path.to_string_lossy().as_ref(),
        rec_path.to_string_lossy().as_ref(),
        dict_path.to_string_lossy().as_ref(),
    )
    .ort_session(session_config)
    .region_batch_size(16)
    .build()
    .map_err(|e| format!("[OCR] Failed to initialize: {}", e))?;

    println!("[OCR] Engine initialized (oar-ocr + PaddleOCR v5)");
    *slot = Some(engine);
    Ok(slot)
}
/// Returns the ONNX Runtime execution providers in priority order.
///
/// On Windows, DirectML (device 0) is listed first. On macOS, CoreML is listed
/// first, with `ane_only = false` and `subgraphs = true`. The CPU provider is
/// always last, on every platform.
fn build_provider_chain() -> Vec<OrtExecutionProvider> {
    let mut chain = Vec::new();

    #[cfg(target_os = "windows")]
    {
        chain.push(OrtExecutionProvider::DirectML { device_id: Some(0) });
    }

    #[cfg(target_os = "macos")]
    {
        chain.push(OrtExecutionProvider::CoreML {
            ane_only: Some(false),
            subgraphs: Some(true),
        });
    }

    chain.push(OrtExecutionProvider::CPU);
    chain
}
/// Locates the directory holding the OCR models.
///
/// A candidate is accepted when it contains `ppocrv5_mobile_det.onnx`; only
/// that file is probed here. Candidates are tried in this order, and the first
/// match wins:
/// 1. `$QUASIVISION_MODELS_DIR`
/// 2. `<directory of the running executable>/ocr-models`
/// 3. `<current working directory>/ocr-models`
/// 4. `<CARGO_MANIFEST_DIR at compile time>/resources/ocr-models`
///
/// # Errors
/// Returns setup instructions, including the list of required files, when no
/// candidate matches.
fn find_models_dir() -> Result<PathBuf, String> {
    let has_det_model = |dir: &PathBuf| dir.join("ppocrv5_mobile_det.onnx").exists();

    std::env::var("QUASIVISION_MODELS_DIR")
        .ok()
        .map(PathBuf::from)
        .filter(has_det_model)
        .or_else(|| {
            std::env::current_exe()
                .ok()
                .and_then(|exe| exe.parent().map(|exe_dir| exe_dir.join("ocr-models")))
                .filter(has_det_model)
        })
        .or_else(|| {
            std::env::current_dir()
                .ok()
                .map(|cwd| cwd.join("ocr-models"))
                .filter(has_det_model)
        })
        .or_else(|| {
            option_env!("CARGO_MANIFEST_DIR")
                .map(|manifest| PathBuf::from(manifest).join("resources").join("ocr-models"))
                .filter(has_det_model)
        })
        .ok_or_else(|| {
            format!(
                "OCR models directory not found.\n\
                 Searched locations:\n\
                 - $QUASIVISION_MODELS_DIR\n\
                 - (exe dir)/ocr-models/\n\
                 - ./ocr-models/\n\
                 - <project>/resources/ocr-models/\n\n\
                 Setup:\n\
                 1. Download models from: https://github.com/GreatV/oar-ocr/releases\n\
                 2. Required files:\n\
                 {}\n\
                 3. Place them in: ./ocr-models/\n\
                 Or set env: QUASIVISION_MODELS_DIR=/path/to/models",
                MODEL_FILES.join("\n "),
            )
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A blank 100x100 RGB image must yield no text elements. This holds both
    /// when the engine cannot be initialized (`detect_text` then returns an empty
    /// result) and when OCR runs but finds nothing.
    #[test]
    fn test_detect_text_empty() {
        let blank = DynamicImage::new_rgb8(100, 100);
        let TextResult { texts } = detect_text(&blank);
        assert!(texts.is_empty());
    }
}