use crate::core::OCRError;
use crate::core::inference::{OrtInfer, TensorInput};
use crate::domain::adapters::preprocessing::rgb_to_dynamic;
use crate::processors::{NormalizeImage, TensorLayout};
use crate::utils::topk::Topk;
use image::{RgbImage, imageops::FilterType};
/// Preprocessing options for a PP-LCNet model.
#[derive(Debug, Clone)]
pub struct PPLCNetPreprocessConfig {
    /// Target size of the network input, read as `(height, width)` during preprocessing.
    pub input_shape: (u32, u32),
    /// Interpolation filter used for every resize operation.
    pub resize_filter: FilterType,
    /// When `Some(s)`, the shorter image side is scaled to `s` and the result is
    /// center-cropped to `input_shape`; when `None`, images are resized directly
    /// to `input_shape`.
    pub resize_short: Option<u32>,
    /// Scale value handed to the image normalizer.
    pub normalize_scale: f32,
    /// Mean values handed to the image normalizer (presumably one per channel).
    pub normalize_mean: Vec<f32>,
    /// Standard-deviation values handed to the image normalizer (presumably one per channel).
    pub normalize_std: Vec<f32>,
    /// Tensor layout handed to the image normalizer.
    pub tensor_layout: TensorLayout,
}

impl Default for PPLCNetPreprocessConfig {
    /// Returns the default configuration: a 224x224 input, shorter side resized
    /// to 256 with a triangle filter, `1/255` scaling and a CHW layout.
    fn default() -> Self {
        const DEFAULT_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
        const DEFAULT_STD: [f32; 3] = [0.229, 0.224, 0.225];

        Self {
            input_shape: (224, 224),
            resize_filter: FilterType::Triangle,
            resize_short: Some(256),
            normalize_scale: 1.0 / 255.0,
            normalize_mean: DEFAULT_MEAN.to_vec(),
            normalize_std: DEFAULT_STD.to_vec(),
            tensor_layout: TensorLayout::CHW,
        }
    }
}
/// Postprocessing options for a PP-LCNet model.
#[derive(Debug, Clone)]
pub struct PPLCNetPostprocessConfig {
    /// Optional class labels indexed by class id; when empty, no label lookup
    /// is performed during postprocessing.
    pub labels: Vec<String>,
    /// Number of top-scoring classes requested from the top-k processor.
    pub topk: usize,
}

impl Default for PPLCNetPostprocessConfig {
    /// Returns a configuration with no labels and `topk = 1`.
    fn default() -> Self {
        Self {
            labels: Vec::new(),
            topk: 1,
        }
    }
}
/// Result of PP-LCNet postprocessing.
#[derive(Debug, Clone)]
pub struct PPLCNetModelOutput {
/// Class indices reported by the top-k processor (presumably one inner vector per
/// prediction row — confirm against `Topk::process`).
pub class_ids: Vec<Vec<usize>>,
/// Scores reported by the top-k processor, parallel to `class_ids`.
pub scores: Vec<Vec<f32>>,
/// Label names for `class_ids`. Built from the configured labels when those are
/// non-empty (falling back to `class_<id>` for ids without a label); otherwise
/// whatever class names the top-k processor itself returned, which may be `None`.
pub label_names: Option<Vec<Vec<String>>>,
}
/// PP-LCNet model: bundles the inference session with the pre- and
/// postprocessing components it needs.
#[derive(Debug)]
pub struct PPLCNetModel {
// Inference backend used to run the network.
inference: OrtInfer,
// Converts resized images into the normalized input tensor.
normalizer: NormalizeImage,
// Extracts the top-k classes from the raw predictions.
topk_processor: Topk,
// Network input size, read as (height, width) in `preprocess`.
input_shape: (u32, u32),
// Interpolation filter used for every resize.
resize_filter: FilterType,
// Optional shorter-side length for the resize-then-center-crop path.
resize_short: Option<u32>,
}
impl PPLCNetModel {
    /// Creates a model from already-constructed components.
    ///
    /// `input_shape` is interpreted as `(height, width)` by [`Self::preprocess`].
    pub fn new(
        inference: OrtInfer,
        normalizer: NormalizeImage,
        topk_processor: Topk,
        input_shape: (u32, u32),
        resize_filter: FilterType,
        resize_short: Option<u32>,
    ) -> Self {
        Self {
            inference,
            normalizer,
            topk_processor,
            input_shape,
            resize_filter,
            resize_short,
        }
    }

    /// Resizes the images to the network input size and normalizes them into a
    /// batch tensor.
    ///
    /// With `resize_short` set, each image is scaled so that its shorter side
    /// matches that length (but never below the crop size on either axis) and
    /// is then center-cropped to `input_shape`. Without it, each image is
    /// resized directly to `input_shape`, ignoring the aspect ratio.
    ///
    /// Images with a zero width or height are skipped, so fewer images than
    /// were supplied may reach the normalizer.
    /// NOTE(review): callers that pair outputs with inputs by position should
    /// filter such images beforehand.
    ///
    /// # Errors
    ///
    /// Returns whatever error the normalizer reports for the resized batch.
    pub fn preprocess(&self, images: Vec<RgbImage>) -> Result<ndarray::Array4<f32>, OCRError> {
        let (crop_h, crop_w) = self.input_shape;

        let resized_rgb: Vec<RgbImage> = images
            .into_iter()
            // Degenerate images cannot be scaled; drop them from the batch.
            .filter(|img| img.width() != 0 && img.height() != 0)
            .map(|img| match self.resize_short {
                Some(resize_short) => {
                    let (w, h) = (img.width(), img.height());
                    let short = w.min(h) as f32;
                    let scale = (resize_short as f32) / short;
                    // Clamp to the crop size so the center crop always fits.
                    let new_w = ((w as f32) * scale).round().max(crop_w as f32) as u32;
                    let new_h = ((h as f32) * scale).round().max(crop_h as f32) as u32;
                    let resized = image::imageops::resize(&img, new_w, new_h, self.resize_filter);
                    let x1 = new_w.saturating_sub(crop_w) / 2;
                    let y1 = new_h.saturating_sub(crop_h) / 2;
                    image::imageops::crop_imm(&resized, x1, y1, crop_w, crop_h).to_image()
                }
                None => image::imageops::resize(&img, crop_w, crop_h, self.resize_filter),
            })
            .collect();

        let dynamic_images = rgb_to_dynamic(resized_rgb);
        self.normalizer.normalize_batch_to(dynamic_images)
    }

    /// Runs the inference session on a preprocessed batch and returns the first
    /// model output as a 2-D `f32` array.
    ///
    /// # Errors
    ///
    /// * [`OCRError::Inference`] if the session fails or the output cannot be
    ///   converted to a 2-D `f32` array.
    /// * [`OCRError::InvalidInput`] if the session returns no outputs.
    pub fn infer(
        &self,
        batch_tensor: &ndarray::Array4<f32>,
    ) -> Result<ndarray::Array2<f32>, OCRError> {
        let input_name = self.inference.input_name();
        let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
        let outputs = self
            .inference
            .infer(&inputs)
            .map_err(|e| OCRError::Inference {
                model_name: "PP-LCNet".to_string(),
                context: format!(
                    "failed to run inference on batch with shape {:?}",
                    batch_tensor.shape()
                ),
                source: Box::new(e),
            })?;

        // Only the first output is consumed.
        let output = outputs
            .into_iter()
            .next()
            .ok_or_else(|| OCRError::InvalidInput {
                message: "PP-LCNet: no output returned from inference".to_string(),
            })?;

        output
            .1
            .try_into_array2_f32()
            .map_err(|e| OCRError::Inference {
                model_name: "PP-LCNet".to_string(),
                context: "failed to convert output to 2D array".to_string(),
                source: Box::new(e),
            })
    }

    /// Converts raw predictions into top-k class ids, scores and label names.
    ///
    /// Each row of `predictions` is passed to the top-k processor with
    /// `config.topk`. When `config.labels` is non-empty, label names are looked
    /// up by class id, falling back to `class_<id>` for ids without a label;
    /// otherwise the class names reported by the top-k processor (if any) are
    /// returned.
    ///
    /// # Errors
    ///
    /// Returns [`OCRError::InvalidInput`] when the top-k processor rejects the
    /// predictions or the requested `topk`. Such failures used to be silently
    /// replaced by an empty result, which made a misconfiguration
    /// indistinguishable from "no predictions".
    pub fn postprocess(
        &self,
        predictions: &ndarray::Array2<f32>,
        config: &PPLCNetPostprocessConfig,
    ) -> Result<PPLCNetModelOutput, OCRError> {
        let predictions_vec: Vec<Vec<f32>> =
            predictions.outer_iter().map(|row| row.to_vec()).collect();

        let topk_result = self
            .topk_processor
            .process(&predictions_vec, config.topk)
            .map_err(|e| OCRError::InvalidInput {
                message: format!("PP-LCNet: top-k postprocessing failed: {}", e),
            })?;

        let class_ids = topk_result.indexes;
        let scores = topk_result.scores;

        let label_names: Option<Vec<Vec<String>>> = if config.labels.is_empty() {
            topk_result.class_names
        } else {
            let names = class_ids
                .iter()
                .map(|ids| {
                    ids.iter()
                        .map(|&id| {
                            config
                                .labels
                                .get(id)
                                .cloned()
                                .unwrap_or_else(|| format!("class_{}", id))
                        })
                        .collect()
                })
                .collect();
            Some(names)
        };

        Ok(PPLCNetModelOutput {
            class_ids,
            scores,
            label_names,
        })
    }

    /// Runs the full pipeline: [`Self::preprocess`], [`Self::infer`], then
    /// [`Self::postprocess`].
    ///
    /// # Errors
    ///
    /// Propagates the first error raised by any of the three stages.
    pub fn forward(
        &self,
        images: Vec<RgbImage>,
        config: &PPLCNetPostprocessConfig,
    ) -> Result<PPLCNetModelOutput, OCRError> {
        let batch_tensor = self.preprocess(images)?;
        let predictions = self.infer(&batch_tensor)?;
        self.postprocess(&predictions, config)
    }
}
/// Builder for [`PPLCNetModel`].
#[derive(Debug, Default)]
pub struct PPLCNetModelBuilder {
// Preprocessing options copied into the built model and its normalizer.
preprocess_config: PPLCNetPreprocessConfig,
// Optional session configuration; when `None`, `build` creates the inference
// backend with `OrtInfer::new` instead of `OrtInfer::from_config`.
ort_config: Option<crate::core::config::OrtSessionConfig>,
}
impl PPLCNetModelBuilder {
pub fn new() -> Self {
Self {
preprocess_config: PPLCNetPreprocessConfig::default(),
ort_config: None,
}
}
pub fn preprocess_config(mut self, config: PPLCNetPreprocessConfig) -> Self {
self.preprocess_config = config;
self
}
pub fn input_shape(mut self, shape: (u32, u32)) -> Self {
self.preprocess_config.input_shape = shape;
self
}
pub fn resize_filter(mut self, filter: FilterType) -> Self {
self.preprocess_config.resize_filter = filter;
self
}
pub fn with_ort_config(mut self, config: crate::core::config::OrtSessionConfig) -> Self {
self.ort_config = Some(config);
self
}
pub fn build(self, model_path: &std::path::Path) -> Result<PPLCNetModel, OCRError> {
let inference = if self.ort_config.is_some() {
use crate::core::config::ModelInferenceConfig;
let common_config = ModelInferenceConfig {
ort_session: self.ort_config,
..Default::default()
};
OrtInfer::from_config(&common_config, model_path, None)?
} else {
OrtInfer::new(model_path, None)?
};
let mean = self.preprocess_config.normalize_mean.clone();
let std = self.preprocess_config.normalize_std.clone();
let normalizer = NormalizeImage::with_color_order(
Some(self.preprocess_config.normalize_scale),
Some(mean),
Some(std),
Some(self.preprocess_config.tensor_layout),
Some(crate::processors::types::ColorOrder::RGB),
)?;
let topk_processor = Topk::new(None);
Ok(PPLCNetModel::new(
inference,
normalizer,
topk_processor,
self.preprocess_config.input_shape,
self.preprocess_config.resize_filter,
self.preprocess_config.resize_short,
))
}
}