use crate::core::OCRError;
use crate::core::inference::{OrtInfer, TensorInput};
use crate::processors::{CTCLabelDecode, OCRResize};
use image::RgbImage;
use std::path::Path;
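
/// Decoded output for a batch of text-line images.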
#[derive(Debug, Clone)]
pub struct CRNNModelOutput {
    /// Recognized text for each input image.
    pub texts: Vec<String>,
    /// Confidence score for each recognized text.
    pub scores: Vec<f32>,
    /// Per-text character positions reported by the decoder (empty unless positions were requested).
    pub char_positions: Vec<Vec<f32>>,
    /// Per-text column indices of the decoded characters in the prediction sequence (empty unless positions were requested).
    pub char_col_indices: Vec<Vec<usize>>,
    /// Prediction sequence length for each text (empty unless positions were requested).
    pub sequence_lengths: Vec<usize>,
}
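
/// CRNN-style text recognition model: preprocessing, ONNX inference, and CTC decoding.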
#[derive(Debug)]
pub struct CRNNModel {
    /// ONNX Runtime session wrapper that runs the recognition network.
    inference: OrtInfer,
    /// Resize parameters (target input shape and maximum tensor width).
    resizer: OCRResize,
    /// CTC decoder that turns the prediction tensor into text and scores.
    decoder: CTCLabelDecode,
}
impl CRNNModel {
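    /// Creates a model from already-constructed inference, resize, and decode components.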
pub fn new(inference: OrtInfer, resizer: OCRResize, decoder: CTCLabelDecode) -> Self {
Self {
inference,
resizer,
decoder,
}
}
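
    /// Resizes each image to the model height (preserving aspect ratio), copies it into a
    /// shared zero-padded batch tensor, and normalizes pixel values to [-1, 1].
    /// Returns a degenerate (0, 0, 0, 0) tensor for an empty batch.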
pub fn preprocess(&self, images: Vec<RgbImage>) -> Result<ndarray::Array4<f32>, OCRError> {
if images.is_empty() {
return Ok(ndarray::Array4::zeros((0, 0, 0, 0)));
}
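        // Size the shared tensor width from the widest aspect ratio in the batch
        // (never narrower than the model's default W/H ratio), capped at the
        // resizer's maximum width.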
let [_img_c, img_h, img_w] = self.resizer.rec_image_shape;
let base_ratio = img_w as f32 / img_h.max(1) as f32;
let max_wh_ratio = images
.iter()
.map(|img| img.width() as f32 / img.height().max(1) as f32)
.fold(base_ratio, |acc, r| acc.max(r));
let tensor_width = ((img_h as f32 * max_wh_ratio) as usize).min(self.resizer.max_img_w);
let batch_size = images.len();
let mut batch_tensor = ndarray::Array4::<f32>::zeros((batch_size, 3, img_h, tensor_width));
for (batch_idx, img) in images.iter().enumerate() {
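            // Scale this image to the target height, keeping its aspect ratio but
            // never exceeding the shared tensor width.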
let (orig_w, orig_h) = (img.width() as f32, img.height() as f32);
let ratio = orig_w / orig_h;
let resized_w = ((img_h as f32 * ratio).ceil() as usize).min(tensor_width);
let resized = image::imageops::resize(
img,
resized_w as u32,
img_h as u32,
image::imageops::FilterType::Triangle,
);
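            // Write the resized pixels into the tensor, normalizing each channel to
            // [-1, 1]; channels are stored in BGR order and the unused width on the
            // right is left as zero padding.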
for y in 0..img_h {
for x in 0..resized_w {
let pixel = resized.get_pixel(x as u32, y as u32);
let b = (pixel[2] as f32 / 255.0 - 0.5) / 0.5;
let g = (pixel[1] as f32 / 255.0 - 0.5) / 0.5;
let r = (pixel[0] as f32 / 255.0 - 0.5) / 0.5;
batch_tensor[[batch_idx, 0, y, x]] = b;
batch_tensor[[batch_idx, 1, y, x]] = g;
batch_tensor[[batch_idx, 2, y, x]] = r;
}
}
}
Ok(batch_tensor)
}
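
    /// Runs the recognition network on a preprocessed batch and converts the first
    /// output into a 3-D `f32` prediction tensor.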
pub fn infer(
&self,
batch_tensor: &ndarray::Array4<f32>,
) -> Result<ndarray::Array3<f32>, OCRError> {
let input_name = self.inference.input_name();
let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
let outputs = self
.inference
.infer(&inputs)
.map_err(|e| OCRError::Inference {
model_name: "CRNN".to_string(),
context: format!(
"failed to run inference on batch with shape {:?}",
batch_tensor.shape()
),
source: Box::new(e),
})?;
let output = outputs
.into_iter()
.next()
.ok_or_else(|| OCRError::InvalidInput {
message: "CRNN: no output returned from inference".to_string(),
})?;
output
.1
.try_into_array3_f32()
.map_err(|e| OCRError::Inference {
model_name: "CRNN".to_string(),
context: "failed to convert output to 3D array".to_string(),
source: Box::new(e),
})
}
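
    /// Decodes the prediction tensor into texts and scores. When `return_positions`
    /// is true, character positions, column indices, and sequence lengths are also
    /// returned; otherwise those vectors are left empty.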
pub fn postprocess(
&self,
predictions: &ndarray::Array3<f32>,
return_positions: bool,
) -> CRNNModelOutput {
if return_positions {
let (texts, scores, char_positions, char_col_indices, sequence_lengths) =
self.decoder.apply_with_positions(predictions);
CRNNModelOutput {
texts,
scores,
char_positions,
char_col_indices,
sequence_lengths,
}
} else {
let (texts, scores) = self.decoder.apply(predictions);
CRNNModelOutput {
texts,
scores,
char_positions: Vec::new(),
char_col_indices: Vec::new(),
sequence_lengths: Vec::new(),
}
}
}
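
    /// Runs the full recognition pipeline (preprocess, infer, postprocess) on a batch of images.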
pub fn forward(
&self,
images: Vec<RgbImage>,
return_positions: bool,
) -> Result<CRNNModelOutput, OCRError> {
tracing::debug!("CRNN forward: {} images", images.len());
if !images.is_empty() {
tracing::debug!(
"First image size: {}x{}",
images[0].width(),
images[0].height()
);
}
        // Short-circuit an empty batch so inference is never run on a zero-sized tensor.
        if images.is_empty() {
            return Ok(CRNNModelOutput {
                texts: Vec::new(),
                scores: Vec::new(),
                char_positions: Vec::new(),
                char_col_indices: Vec::new(),
                sequence_lengths: Vec::new(),
            });
        }
        let batch_tensor = self.preprocess(images)?;
tracing::debug!("CRNN preprocess output shape: {:?}", batch_tensor.shape());
let predictions = self.infer(&batch_tensor)?;
tracing::debug!("CRNN infer output shape: {:?}", predictions.shape());
let output = self.postprocess(&predictions, return_positions);
tracing::debug!(
"CRNN postprocess: {} texts, first 3: {:?}",
output.texts.len(),
&output.texts[..3.min(output.texts.len())]
);
Ok(output)
}
}
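
/// Preprocessing configuration for the recognition model.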
#[derive(Debug, Clone)]
pub struct CRNNPreprocessConfig {
    /// Model input shape as `[channels, height, width]`.
    pub model_input_shape: [usize; 3],
    /// Optional upper bound on the padded batch tensor width.
    pub max_img_w: Option<usize>,
}
impl Default for CRNNPreprocessConfig {
fn default() -> Self {
Self {
model_input_shape: [3, 48, 320],
max_img_w: None,
}
}
}
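
/// Builder for [`CRNNModel`].
///
/// A minimal usage sketch, marked `ignore` because the model path, width cap, and
/// character dictionary below are illustrative assumptions rather than values
/// shipped with this crate:
///
/// ```ignore
/// use std::path::Path;
///
/// let model = CRNNModelBuilder::new()
///     .model_input_shape([3, 48, 320])
///     .max_img_w(3200)
///     .character_dict(vec!["a".to_string(), "b".to_string()])
///     .build(Path::new("models/rec_crnn.onnx"))?;
/// // `line_image` is any `image::RgbImage` containing a single text line.
/// let output = model.forward(vec![line_image], false)?;
/// println!("{:?} {:?}", output.texts, output.scores);
/// ```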
pub struct CRNNModelBuilder {
preprocess_config: CRNNPreprocessConfig,
character_dict: Option<Vec<String>>,
ort_config: Option<crate::core::config::OrtSessionConfig>,
}
impl CRNNModelBuilder {
pub fn new() -> Self {
Self {
preprocess_config: CRNNPreprocessConfig::default(),
character_dict: None,
ort_config: None,
}
}
pub fn preprocess_config(mut self, config: CRNNPreprocessConfig) -> Self {
self.preprocess_config = config;
self
}
pub fn model_input_shape(mut self, shape: [usize; 3]) -> Self {
self.preprocess_config.model_input_shape = shape;
self
}
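
    /// Sets the character dictionary used by the CTC decoder (one entry per character class).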
pub fn character_dict(mut self, character_dict: Vec<String>) -> Self {
self.character_dict = Some(character_dict);
self
}
pub fn max_img_w(mut self, max_img_w: usize) -> Self {
self.preprocess_config.max_img_w = Some(max_img_w);
self
}
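
    /// Sets the ONNX Runtime session configuration used when building the inference session.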
pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
self.ort_config = Some(config);
self
}
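
    /// Builds the [`CRNNModel`]: creates the ONNX Runtime session from `model_path`
    /// (honoring the optional session config) and wires up the resizer and CTC decoder.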
pub fn build(self, model_path: &Path) -> Result<CRNNModel, OCRError> {
let inference = if self.ort_config.is_some() {
let common = crate::core::config::ModelInferenceConfig {
ort_session: self.ort_config,
..Default::default()
};
OrtInfer::from_config(&common, model_path, None)?
} else {
OrtInfer::new(model_path, None)?
};
        // Apply the configured input shape and optional width cap to the resizer.
        let resizer = OCRResize::new(
            Some(self.preprocess_config.model_input_shape),
            self.preprocess_config.max_img_w,
        );
let decoder = if let Some(character_dict) = self.character_dict {
CTCLabelDecode::from_string_list(Some(&character_dict), true, false)
} else {
CTCLabelDecode::new(None, true)
};
Ok(CRNNModel::new(inference, resizer, decoder))
}
}
impl Default for CRNNModelBuilder {
fn default() -> Self {
Self::new()
}
}