use super::error::{OcrError, OcrResult};
pub(crate) trait InferenceBackend: Send + Sync {
fn run(&self, input: &ndarray::Array4<f32>) -> OcrResult<ndarray::ArrayD<f32>>;
}
#[allow(unused_variables)]
pub(crate) fn build_backend(
model_bytes: &[u8],
num_threads: usize,
) -> OcrResult<Box<dyn InferenceBackend>> {
#[cfg(feature = "ocr")]
{
Ok(Box::new(OrtBackend::from_bytes(model_bytes, num_threads)?))
}
#[cfg(all(not(feature = "ocr"), feature = "ocr-tract"))]
{
Ok(Box::new(TractBackend::from_bytes(model_bytes)?))
}
#[cfg(all(not(feature = "ocr"), not(feature = "ocr-tract")))]
{
Err(OcrError::ModelLoadError(
"no OCR inference backend compiled in (enable `ocr` or `ocr-tract`)".to_string(),
))
}
}
#[cfg(feature = "ocr")]
pub(crate) struct OrtBackend {
session: std::sync::Mutex<ort::session::Session>,
}
#[cfg(feature = "ocr")]
impl OrtBackend {
pub(crate) fn from_bytes(model_bytes: &[u8], num_threads: usize) -> OcrResult<Self> {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
ort::session::Session::builder()
.map_err(|e| {
OcrError::ModelLoadError(format!("Failed to create session builder: {}", e))
})?
.with_intra_threads(num_threads)
.map_err(|e| OcrError::ModelLoadError(format!("Failed to set threads: {}", e)))?
.commit_from_memory(model_bytes)
.map_err(|e| OcrError::ModelLoadError(format!("Failed to load model: {}", e)))
}));
let session = match result {
Ok(Ok(s)) => s,
Ok(Err(e)) => return Err(e),
Err(panic_payload) => {
let detail = panic_payload
.downcast::<String>()
.map(|b| *b)
.unwrap_or_else(|p| {
p.downcast::<&'static str>()
.map(|b| (*b).to_string())
.unwrap_or_else(|_| "unknown panic".to_string())
});
return Err(OcrError::ModelLoadError(format!(
"ONNX Runtime initialisation panicked (typically: \
libonnxruntime dylib failed to load — install onnxruntime \
or set ORT_DYLIB_PATH). Detail: {}",
detail
)));
},
};
Ok(Self {
session: std::sync::Mutex::new(session),
})
}
}
#[cfg(feature = "ocr")]
impl InferenceBackend for OrtBackend {
fn run(&self, input: &ndarray::Array4<f32>) -> OcrResult<ndarray::ArrayD<f32>> {
use ort::value::TensorRef;
let mut session = self.session.lock().map_err(|e| {
OcrError::InferenceError(format!("Failed to acquire session lock: {}", e))
})?;
let input_tensor = TensorRef::from_array_view(input).map_err(|e| {
OcrError::InferenceError(format!("Failed to create input tensor: {}", e))
})?;
let outputs = session
.run(ort::inputs!["x" => input_tensor])
.map_err(|e| OcrError::InferenceError(format!("Inference failed: {}", e)))?;
let (_, output_tensor) = outputs
.iter()
.next()
.ok_or_else(|| OcrError::InferenceError("No output tensor found".to_string()))?;
let view = output_tensor
.try_extract_array::<f32>()
.map_err(|e| OcrError::InferenceError(format!("Failed to extract output: {}", e)))?;
Ok(view.to_owned())
}
}
#[cfg(feature = "ocr-tract")]
#[cfg_attr(feature = "ocr", allow(dead_code))]
pub(crate) struct TractBackend {
model: tract_onnx::prelude::InferenceModel,
plans: std::sync::Mutex<std::collections::HashMap<Vec<usize>, std::sync::Arc<TractPlan>>>,
}
#[cfg(feature = "ocr-tract")]
#[cfg_attr(feature = "ocr", allow(dead_code))]
type TractPlan = tract_onnx::prelude::TypedRunnableModel<tract_onnx::prelude::TypedModel>;
#[cfg(feature = "ocr-tract")]
#[cfg_attr(feature = "ocr", allow(dead_code))]
impl TractBackend {
pub(crate) fn from_bytes(model_bytes: &[u8]) -> OcrResult<Self> {
use tract_onnx::prelude::*;
let model = tract_onnx::onnx()
.model_for_read(&mut std::io::Cursor::new(model_bytes))
.map_err(|e| OcrError::ModelLoadError(format!("tract: parse ONNX: {}", e)))?;
Ok(Self {
model,
plans: std::sync::Mutex::new(std::collections::HashMap::new()),
})
}
fn plan_for(&self, shape: &[usize]) -> OcrResult<std::sync::Arc<TractPlan>> {
use tract_onnx::prelude::*;
let key = shape.to_vec();
let mut plans = self
.plans
.lock()
.map_err(|e| OcrError::InferenceError(format!("tract: plan lock: {}", e)))?;
if let Some(p) = plans.get(&key) {
return Ok(p.clone());
}
let runnable = self
.model
.clone()
.with_input_fact(0, f32::fact(shape).into())
.map_err(|e| OcrError::InferenceError(format!("tract: input fact: {}", e)))?
.into_optimized()
.map_err(|e| OcrError::InferenceError(format!("tract: optimize: {}", e)))?
.into_runnable()
.map_err(|e| OcrError::InferenceError(format!("tract: runnable: {}", e)))?;
let arc = std::sync::Arc::new(runnable);
plans.insert(key, arc.clone());
Ok(arc)
}
}
#[cfg(feature = "ocr-tract")]
impl InferenceBackend for TractBackend {
fn run(&self, input: &ndarray::Array4<f32>) -> OcrResult<ndarray::ArrayD<f32>> {
use tract_onnx::prelude::*;
let shape: Vec<usize> = input.shape().to_vec();
let plan = self.plan_for(&shape)?;
let data: Vec<f32> = input.iter().copied().collect();
let tensor = Tensor::from_shape(&shape, &data)
.map_err(|e| OcrError::InferenceError(format!("tract: input tensor: {}", e)))?;
let result = plan
.run(tvec!(tensor.into()))
.map_err(|e| OcrError::InferenceError(format!("tract: run: {}", e)))?;
let out = result
.into_iter()
.next()
.ok_or_else(|| OcrError::InferenceError("tract: no output tensor".to_string()))?;
let out_shape: Vec<usize> = out.shape().to_vec();
let out_data = out
.as_slice::<f32>()
.map_err(|e| OcrError::InferenceError(format!("tract: extract output: {}", e)))?;
ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&out_shape), out_data.to_vec())
.map_err(|e| OcrError::InferenceError(format!("tract: reshape output: {}", e)))
}
}
#[cfg(all(test, feature = "ocr", feature = "ocr-tract"))]
mod parity {
use super::*;
fn models_dir() -> std::path::PathBuf {
std::env::var_os("PDF_OXIDE_MODEL_DIR")
.map(std::path::PathBuf::from)
.unwrap_or_else(|| {
std::path::PathBuf::from(std::env::var("HOME").expect("HOME"))
.join(".cache/pdf_oxide/models")
})
}
fn deterministic_input(shape: [usize; 4]) -> ndarray::Array4<f32> {
let n: usize = shape.iter().product();
let v: Vec<f32> = (0..n).map(|i| (i as f32 * 0.013).sin() * 2.0).collect();
ndarray::Array4::from_shape_vec(shape, v).expect("input shape")
}
fn diff(a: &ndarray::ArrayD<f32>, b: &ndarray::ArrayD<f32>) -> (f32, f64) {
assert_eq!(
a.shape(),
b.shape(),
"ort/tract output SHAPES differ: {:?} vs {:?}",
a.shape(),
b.shape()
);
let mut max = 0f32;
let mut sum = 0f64;
for (x, y) in a.iter().zip(b.iter()) {
let d = (x - y).abs();
max = max.max(d);
sum += d as f64;
}
(max, sum / a.len().max(1) as f64)
}
#[test]
#[ignore = "needs PDF_OXIDE_MODEL_DIR models + ORT_DYLIB_PATH"]
fn ort_vs_tract_detector() {
let m = std::fs::read(models_dir().join("det.onnx")).expect("det.onnx");
let ort = OrtBackend::from_bytes(&m, 1).expect("ort det");
let tract = TractBackend::from_bytes(&m).expect("tract det");
let inp = deterministic_input([1, 3, 640, 640]);
let o = ort.run(&inp).expect("ort run");
let t = tract.run(&inp).expect("tract run");
let (mx, mean) = diff(&o, &t);
println!("DET shape={:?} max_abs_diff={mx:.6} mean_abs_diff={mean:.6}", o.shape());
assert!(mx < 1e-2, "detector ort/tract diverge: max_abs_diff={mx}");
}
#[test]
#[ignore = "needs PDF_OXIDE_MODEL_DIR models + ORT_DYLIB_PATH"]
fn ort_vs_tract_recognizer() {
let m = std::fs::read(models_dir().join("rec.onnx")).expect("rec.onnx");
let ort = OrtBackend::from_bytes(&m, 1).expect("ort rec");
let tract = TractBackend::from_bytes(&m).expect("tract rec");
let inp = deterministic_input([1, 3, 48, 320]);
let o = ort.run(&inp).expect("ort run");
let t = tract.run(&inp).expect("tract run");
let (mx, mean) = diff(&o, &t);
println!("REC shape={:?} max_abs_diff={mx:.6} mean_abs_diff={mean:.6}", o.shape());
assert!(mx < 1e-2, "recognizer ort/tract diverge: max_abs_diff={mx}");
}
}