use crate::core::OCRError;
use crate::core::inference::{OrtInfer, TensorInput};
use crate::processors::{NormalizeImage, TensorLayout, UVDocPostProcess};
use image::{DynamicImage, RgbImage, imageops::FilterType};
/// Return type of [`UVDocModel::preprocess`]: the normalized batch tensor together with
/// the `(width, height)` of every input image as it was before any resizing, in input order.
type PreprocessResult = Result<(ndarray::Array4<f32>, Vec<(u32, u32)>), OCRError>;
/// Preprocessing settings for the UVDoc model.
#[derive(Debug, Clone)]
pub struct UVDocPreprocessConfig {
    /// Model input shape. `UVDocModel::preprocess` reads index 1 as the target height
    /// and index 2 as the target width; index 0 is presumably the channel count
    /// (it is not read anywhere in this file — confirm against the model).
    pub rec_image_shape: [usize; 3],
}

impl Default for UVDocPreprocessConfig {
    /// Returns a configuration whose input shape is `[3, 512, 512]`.
    fn default() -> Self {
        let rec_image_shape = [3, 512, 512];
        Self { rec_image_shape }
    }
}
/// Output of [`UVDocModel::forward`].
#[derive(Debug, Clone)]
pub struct UVDocModelOutput {
/// Rectified images produced by `UVDocModel::postprocess`, one per input image and in
/// input order.
pub images: Vec<RgbImage>,
}
/// UVDoc rectification model: bundles the inference backend with the normalization and
/// postprocessing steps used around it.
#[derive(Debug)]
pub struct UVDocModel {
/// Inference backend invoked by `infer`.
inference: OrtInfer,
/// Converts the (optionally resized) input images into the batch tensor in `preprocess`.
normalizer: NormalizeImage,
/// Converts the prediction tensor back into images in `postprocess`.
postprocessor: UVDocPostProcess,
/// Model input shape; `preprocess` reads index 1 as target height and index 2 as
/// target width.
rec_image_shape: [usize; 3],
}
impl UVDocModel {
pub fn new(
inference: OrtInfer,
normalizer: NormalizeImage,
postprocessor: UVDocPostProcess,
rec_image_shape: [usize; 3],
) -> Self {
Self {
inference,
normalizer,
postprocessor,
rec_image_shape,
}
}
pub fn preprocess(&self, images: Vec<RgbImage>) -> PreprocessResult {
let mut original_sizes = Vec::with_capacity(images.len());
let mut processed_images = Vec::with_capacity(images.len());
let target_height = self.rec_image_shape[1] as u32;
let target_width = self.rec_image_shape[2] as u32;
let should_resize = target_height > 0 && target_width > 0;
for img in images {
let original_size = (img.width(), img.height());
original_sizes.push(original_size);
if should_resize && (img.width() != target_width || img.height() != target_height) {
let resized = DynamicImage::ImageRgb8(img).resize_exact(
target_width,
target_height,
FilterType::Triangle,
);
processed_images.push(resized);
} else {
processed_images.push(DynamicImage::ImageRgb8(img));
}
}
let batch_tensor = self.normalizer.normalize_batch_to(processed_images)?;
Ok((batch_tensor, original_sizes))
}
pub fn infer(
&self,
batch_tensor: &ndarray::Array4<f32>,
) -> Result<ndarray::Array4<f32>, OCRError> {
let input_name = self.inference.input_name();
let inputs = vec![(input_name, TensorInput::Array4(batch_tensor))];
let outputs = self
.inference
.infer(&inputs)
.map_err(|e| OCRError::Inference {
model_name: "UVDoc".to_string(),
context: format!(
"failed to run inference on batch with shape {:?}",
batch_tensor.shape()
),
source: Box::new(e),
})?;
let output = outputs
.into_iter()
.next()
.ok_or_else(|| OCRError::InvalidInput {
message: "UVDoc: no output returned from inference".to_string(),
})?;
output
.1
.try_into_array4_f32()
.map_err(|e| OCRError::Inference {
model_name: "UVDoc".to_string(),
context: "failed to convert output to 4D array".to_string(),
source: Box::new(e),
})
}
pub fn postprocess(
&self,
predictions: &ndarray::Array4<f32>,
original_sizes: &[(u32, u32)],
) -> Result<Vec<RgbImage>, OCRError> {
let mut images =
self.postprocessor
.apply_batch(predictions)
.map_err(|e| OCRError::ConfigError {
message: format!("Failed to postprocess rectification output: {}", e),
})?;
if images.len() != original_sizes.len() {
return Err(OCRError::InvalidInput {
message: format!(
"Mismatched rectification batch sizes: predictions={}, originals={}",
images.len(),
original_sizes.len()
),
});
}
for (img, &(orig_w, orig_h)) in images.iter_mut().zip(original_sizes) {
if orig_w == 0 || orig_h == 0 {
continue;
}
if img.width() != orig_w || img.height() != orig_h {
let resized = DynamicImage::ImageRgb8(std::mem::take(img)).resize_exact(
orig_w,
orig_h,
FilterType::Triangle,
);
*img = resized.into_rgb8();
}
}
Ok(images)
}
pub fn forward(&self, images: Vec<RgbImage>) -> Result<UVDocModelOutput, OCRError> {
let (batch_tensor, original_sizes) = self.preprocess(images)?;
let predictions = self.infer(&batch_tensor)?;
let rectified_images = self.postprocess(&predictions, &original_sizes)?;
Ok(UVDocModelOutput {
images: rectified_images,
})
}
}
/// Builder for [`UVDocModel`].
#[derive(Debug, Default)]
pub struct UVDocModelBuilder {
/// Preprocessing settings; its `rec_image_shape` is handed to the model in `build`.
preprocess_config: UVDocPreprocessConfig,
/// Optional session configuration; when `None`, `build` uses `OrtInfer::new`.
ort_config: Option<crate::core::config::OrtSessionConfig>,
}
impl UVDocModelBuilder {
    /// Creates a builder with the default preprocessing configuration and no session
    /// configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the whole preprocessing configuration.
    pub fn preprocess_config(self, config: UVDocPreprocessConfig) -> Self {
        Self {
            preprocess_config: config,
            ..self
        }
    }

    /// Overrides only the input shape of the current preprocessing configuration.
    pub fn rec_image_shape(self, shape: [usize; 3]) -> Self {
        let mut builder = self;
        builder.preprocess_config.rec_image_shape = shape;
        builder
    }

    /// Sets the session configuration used when the inference backend is created.
    pub fn with_ort_config(self, config: crate::core::config::OrtSessionConfig) -> Self {
        Self {
            ort_config: Some(config),
            ..self
        }
    }

    /// Loads the model at `model_path` and assembles a [`UVDocModel`].
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the inference backend or the normalizer.
    pub fn build(self, model_path: &std::path::Path) -> Result<UVDocModel, OCRError> {
        // `"image"` is passed through to `OrtInfer` in both branches — presumably the
        // model's input tensor name; confirm against `OrtInfer`.
        let inference = match self.ort_config {
            Some(session_config) => {
                use crate::core::config::ModelInferenceConfig;
                let common_config = ModelInferenceConfig {
                    ort_session: Some(session_config),
                    ..Default::default()
                };
                OrtInfer::from_config(&common_config, model_path, Some("image"))?
            }
            None => OrtInfer::new(model_path, Some("image"))?,
        };

        // Arguments as passed: 1/255, zeros, ones, CHW layout, BGR color order
        // (the first three are presumably scale, mean and std — confirm against
        // `NormalizeImage::with_color_order`).
        let normalizer = NormalizeImage::with_color_order(
            Some(1.0 / 255.0),
            Some(vec![0.0, 0.0, 0.0]),
            Some(vec![1.0, 1.0, 1.0]),
            Some(TensorLayout::CHW),
            Some(crate::processors::types::ColorOrder::BGR),
        )?;

        // NOTE(review): 255.0 is presumably the factor mapping model output back to
        // 8-bit pixel values — confirm against `UVDocPostProcess`.
        let postprocessor = UVDocPostProcess::new(255.0);

        Ok(UVDocModel::new(
            inference,
            normalizer,
            postprocessor,
            self.preprocess_config.rec_image_shape,
        ))
    }
}