use std::path::{Path, PathBuf};
use crate::ocr::{OcrBackend, OcrError, OcrOptions};
use crate::pixmap::Pixmap;
use crate::text::{Rect, TextLayer, TextZone, TextZoneKind};
/// OCR backend that runs a CTC-style text-recognition model via tract-onnx.
pub struct OnnxBackend {
// Filesystem path the model was loaded from; retained after loading
// (presumably for diagnostics/reloading — not read elsewhere in this file).
model_path: PathBuf,
// Optimized, runnable tract inference plan built from the ONNX graph.
model: tract_onnx::prelude::SimplePlan<
tract_onnx::prelude::TypedFact,
Box<dyn tract_onnx::prelude::TypedOp>,
tract_onnx::prelude::Graph<
tract_onnx::prelude::TypedFact,
Box<dyn tract_onnx::prelude::TypedOp>,
>,
>,
// Character classes for CTC decoding: model class index i (i > 0) maps to
// vocab[i - 1]; class 0 is the CTC blank.
vocab: Vec<char>,
}
impl OnnxBackend {
    /// Loads an ONNX text-recognition model from `model_path` and prepares it
    /// for inference.
    ///
    /// `vocab_path`, when given, points at a character-vocabulary file; its
    /// characters (minus line terminators) become the CTC classes, in order.
    /// Without it, the printable-ASCII range `' '..='~'` is used.
    ///
    /// # Errors
    /// Returns [`OcrError::InitFailed`] when the model cannot be loaded,
    /// optimized, or made runnable; I/O errors from reading the vocabulary
    /// file are propagated via `?`.
    pub fn load(model_path: impl AsRef<Path>, vocab_path: Option<&Path>) -> Result<Self, OcrError> {
        use tract_onnx::prelude::*;
        let model_path = model_path.as_ref().to_path_buf();
        let model = tract_onnx::onnx()
            .model_for_path(&model_path)
            .map_err(|e| OcrError::InitFailed(format!("failed to load ONNX model: {e}")))?
            .into_optimized()
            .map_err(|e| OcrError::InitFailed(format!("failed to optimize model: {e}")))?
            .into_runnable()
            .map_err(|e| OcrError::InitFailed(format!("failed to make model runnable: {e}")))?;
        let vocab: Vec<char> = if let Some(vp) = vocab_path {
            // Vocab files are conventionally one character per line. Previously
            // the line terminators themselves were collected as vocabulary
            // entries, which shifted every CTC class index after the first
            // character; strip them so indices line up with the model's classes.
            // (Assumes '\n'/'\r' are never legitimate vocabulary characters.)
            std::fs::read_to_string(vp)?
                .chars()
                .filter(|&c| c != '\n' && c != '\r')
                .collect()
        } else {
            // Printable-ASCII fallback (0x20..=0x7E).
            (' '..='~').collect()
        };
        Ok(Self {
            model_path,
            model,
            vocab,
        })
    }

    /// Converts `pixmap` into the model's input tensor: grayscale `f32`
    /// scaled to `[0, 1]`, laid out NCHW as `(1, 1, h, w)`.
    ///
    /// # Panics
    /// Panics only if `gray.data` does not hold `w * h` bytes, which would be
    /// a bug in `Pixmap::to_gray8` — TODO confirm that invariant at its site.
    fn preprocess(&self, pixmap: &Pixmap) -> tract_onnx::prelude::Tensor {
        use tract_onnx::prelude::*;
        let gray = pixmap.to_gray8();
        let w = gray.width as usize;
        let h = gray.height as usize;
        let data: Vec<f32> = gray.data.iter().map(|&v| f32::from(v) / 255.0).collect();
        tract_ndarray::Array4::from_shape_vec((1, 1, h, w), data)
            .expect("gray pixmap data must hold w * h bytes")
            .into_tensor()
    }

    /// Greedy CTC decoding: per timestep take the argmax class, then collapse
    /// consecutive repeats and drop blanks (class 0). Class `i > 0` maps to
    /// `vocab[i - 1]`.
    ///
    /// `output` is interpreted as up to `seq_len` rows of `vocab.len() + 1`
    /// scores each; a trailing partial row is ignored.
    fn ctc_decode(&self, output: &[f32], seq_len: usize) -> String {
        let classes = self.vocab.len() + 1; // +1 for the CTC blank at index 0
        let mut decoded = String::new();
        let mut prev_class = None;
        // `chunks_exact` drops any incomplete trailing frame, which matches
        // the bounds check a manual offset loop would otherwise need.
        for frame in output.chunks_exact(classes).take(seq_len) {
            // Argmax with first-wins tie-breaking; NaN scores never win.
            let (best, _) = frame
                .iter()
                .enumerate()
                .fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
                    if v > bv { (i, v) } else { (bi, bv) }
                });
            // Emit only non-blank classes that differ from the previous frame.
            if best > 0 && prev_class != Some(best) {
                if let Some(&ch) = self.vocab.get(best - 1) {
                    decoded.push(ch);
                }
            }
            prev_class = Some(best);
        }
        decoded
    }
}
impl OcrBackend for OnnxBackend {
    /// Recognizes text in `pixmap` by running the whole image through the
    /// model and greedily CTC-decoding the output.
    ///
    /// This backend performs no layout analysis: the returned layer holds one
    /// page zone containing one line zone, both covering the full pixmap.
    ///
    /// # Errors
    /// Returns [`OcrError::RecognitionFailed`] when inference fails, the
    /// model yields no outputs, or the output tensor is not `f32`.
    fn recognize(&self, pixmap: &Pixmap, _options: &OcrOptions) -> Result<TextLayer, OcrError> {
        use tract_onnx::prelude::*;
        let input = self.preprocess(pixmap);
        let outputs = self
            .model
            .run(tvec![input.into()])
            .map_err(|e| OcrError::RecognitionFailed(format!("model inference failed: {e}")))?;
        // Indexing `outputs[0]` would panic on a model with no outputs; turn
        // that into a proper error instead.
        let output = outputs
            .first()
            .ok_or_else(|| OcrError::RecognitionFailed("model produced no outputs".to_string()))?
            .to_array_view::<f32>()
            .map_err(|e| OcrError::RecognitionFailed(format!("output tensor error: {e}")))?;
        // Assumes the output is (batch, seq, classes); a rank-1 tensor is
        // treated as one sequence axis. TODO confirm against the exported model.
        let shape = output.shape();
        let seq_len = if shape.len() >= 2 { shape[1] } else { shape[0] };
        // `as_slice` is `None` for non-contiguous layouts; previously that
        // silently decoded an empty buffer. Copy the elements out instead.
        let copied: Vec<f32>;
        let scores: &[f32] = match output.as_slice() {
            Some(s) => s,
            None => {
                copied = output.iter().copied().collect();
                &copied
            }
        };
        let text = self.ctc_decode(scores, seq_len);
        // One page zone wrapping one line zone, both spanning the full image.
        let line = TextZone {
            kind: TextZoneKind::Line,
            rect: Rect {
                x: 0,
                y: 0,
                width: pixmap.width,
                height: pixmap.height,
            },
            text: text.clone(),
            children: Vec::new(),
        };
        let zones = vec![TextZone {
            kind: TextZoneKind::Page,
            rect: Rect {
                x: 0,
                y: 0,
                width: pixmap.width,
                height: pixmap.height,
            },
            text: text.clone(),
            children: vec![line],
        }];
        Ok(TextLayer { text, zones })
    }
}