use {
image::{imageops::FilterType, load_from_memory, ImageError},
ort::{
session::{RunOptions, Session},
value::Value,
Error as OrtError,
},
std::{collections::HashMap, path::Path},
thiserror::Error,
};
/// Errors that can occur while loading the OCR model or running inference.
#[derive(Debug, Error)]
pub enum DdddOcrError {
    /// The input bytes could not be decoded as an image.
    #[error("Image error: {0}")]
    Image(#[from] ImageError),
    /// ONNX Runtime failed while building the session or running inference.
    #[error("ONNX Runtime error: {0}")]
    Ort(#[from] OrtError),
}
const CHARSET_DATA: [&str; 8210] = include!("../charset.json");
/// OCR engine wrapping an ONNX Runtime session for a ddddocr-style model.
pub struct DdddOcr {
    // Loaded ONNX inference session; used (mutably) by `classification`.
    session: Session,
}
impl DdddOcr {
    /// Creates a new OCR instance from an ONNX model file on disk.
    ///
    /// # Errors
    ///
    /// Returns [`DdddOcrError::Ort`] if the session builder fails or the
    /// model file cannot be loaded or parsed.
    pub fn new<P>(model_path: P) -> Result<Self, DdddOcrError>
    where
        P: AsRef<Path>,
    {
        let session = Session::builder()?.commit_from_file(model_path)?;
        Ok(DdddOcr { session })
    }

    /// Recognizes the text in an encoded image (e.g. PNG/JPEG bytes).
    ///
    /// Pipeline: decode -> rescale to a fixed height of 64 px (aspect
    /// preserved) -> grayscale -> normalize to [-1, 1] -> run the model ->
    /// greedy CTC-style decode against [`CHARSET_DATA`].
    ///
    /// # Errors
    ///
    /// Returns [`DdddOcrError::Image`] if the bytes cannot be decoded as an
    /// image, or [`DdddOcrError::Ort`] if inference fails.
    pub async fn classification(&mut self, img: &[u8]) -> Result<String, DdddOcrError> {
        let img = load_from_memory(img)?;
        // Scale to height 64 while preserving aspect ratio. Clamp to >= 1 so
        // a degenerate (extremely tall and narrow) image cannot truncate to a
        // zero-width resize target.
        let new_width = ((img.width() as f32 * (64.0 / img.height() as f32)) as u32).max(1);
        let resized = img.resize_exact(new_width, 64, FilterType::Lanczos3);
        let gray_image = resized.to_luma8();
        let height = gray_image.height() as usize;
        let width = gray_image.width() as usize;
        // Normalize each 8-bit luma value to [-1, 1]: (x / 255 - 0.5) / 0.5.
        let img_data: Vec<f32> = gray_image
            .pixels()
            .map(|pixel| (pixel[0] as f32 / 255.0 - 0.5) / 0.5)
            .collect();
        // NCHW tensor layout: batch = 1, channels = 1 (grayscale).
        let shape = vec![1usize, 1, height, width];
        let input_value = Value::from_array((shape, img_data))?;
        // "input1" is the input name baked into the ddddocr ONNX model.
        let inputs = HashMap::from([("input1".to_string(), input_value)]);
        let run_options = RunOptions::new()?;
        let outputs = self.session.run_async(inputs, &run_options)?.await?;
        let output = &outputs[0];
        let (_, output_data) = output.try_extract_tensor::<i64>()?;
        // Greedy CTC decode: collapse consecutive repeated ids. Index 0 is
        // assumed to be the CTC blank mapping to an empty string in the
        // charset, so a blank between repeats correctly resets `last_item`
        // and keeps genuine double characters.
        let mut result = String::new();
        let mut last_item = 0i64;
        for &item in output_data.iter() {
            if item == last_item {
                continue;
            }
            last_item = item;
            // Out-of-range ids (beyond the charset) are silently skipped.
            if let Some(char_str) = CHARSET_DATA.get(item as usize) {
                result.push_str(char_str);
            }
        }
        Ok(result)
    }
}