use crate::core::OCRError;
use crate::core::inference::{OrtInfer, TensorInput};
use crate::core::validation::{
validate_division, validate_image_dimensions, validate_non_empty, validate_positive,
validate_same_length,
};
use crate::processors::types::ColorOrder;
use crate::processors::{
DetResizeForTest, ImageScaleInfo, LimitType, NormalizeImage, TensorLayout,
};
use image::imageops::FilterType;
use image::{DynamicImage, RgbImage};
/// Selects which auxiliary input tensors are fed to the detector
/// alongside the image batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScaleAwareDetectorInferenceMode {
    /// Feed only `scale_factor` (used by the PicoDet preset).
    ScaleFactorOnly,
    /// Feed `scale_factor` plus the resized shape as `im_shape`
    /// (used by the PP-DocLayout preset).
    ScaleFactorAndImageShape,
}
/// Preprocessing parameters for a scale-aware detector.
#[derive(Debug, Clone)]
pub struct ScaleAwareDetectorPreprocessConfig {
    /// Target (height, width) the resizer aims for.
    pub image_shape: (u32, u32),
    /// Whether resizing should preserve the aspect ratio.
    pub keep_ratio: bool,
    /// Side-length limit handed to the resizer (paired with `LimitType::Max`).
    pub limit_side_len: u32,
    /// Pixel scaling applied by the normalizer (e.g. 1/255).
    pub scale: f32,
    /// Per-channel mean for normalization; interpreted as RGB-order stats
    /// by `NormalizeImage::with_color_order_from_rgb_stats`.
    pub mean: Vec<f32>,
    /// Per-channel std for normalization; must match `mean` in length and
    /// every entry must be positive (see `validate`).
    pub std: Vec<f32>,
    /// Resampling filter used when resizing.
    pub resize_filter: FilterType,
    /// Channel order passed through to the normalizer.
    pub color_order: ColorOrder,
}
impl ScaleAwareDetectorPreprocessConfig {
    /// Preset for the PicoDet layout model: 800x608 target shape,
    /// ImageNet mean/std, BGR channel order, Lanczos3 resampling.
    pub fn picodet() -> Self {
        Self {
            image_shape: (800, 608),
            keep_ratio: false,
            limit_side_len: 800,
            scale: 1.0 / 255.0,
            mean: vec![0.485, 0.456, 0.406],
            std: vec![0.229, 0.224, 0.225],
            resize_filter: FilterType::Lanczos3,
            color_order: ColorOrder::BGR,
        }
    }
    /// Preset for PP-DocLayout: 800x800 target shape, identity
    /// normalization (zero mean, unit std), RGB order, CatmullRom
    /// resampling. Remaining fields inherit the PicoDet defaults.
    pub fn pp_doclayout() -> Self {
        Self {
            image_shape: (800, 800),
            mean: vec![0.0; 3],
            std: vec![1.0; 3],
            resize_filter: FilterType::CatmullRom,
            color_order: ColorOrder::RGB,
            ..Self::picodet()
        }
    }
    /// Checks that every field holds a usable value.
    ///
    /// # Errors
    /// Returns the first violated constraint: non-zero image dimensions,
    /// positive `limit_side_len` and `scale`, equally sized non-empty
    /// `mean`/`std`, and strictly positive `std` entries.
    pub fn validate(&self) -> Result<(), OCRError> {
        let (height, width) = self.image_shape;
        validate_image_dimensions(height, width, "ScaleAwareDetectorPreprocessConfig")?;
        validate_positive(self.limit_side_len, "limit_side_len")?;
        validate_positive(self.scale, "scale")?;
        validate_same_length(&self.mean, &self.std, "mean", "std")?;
        validate_non_empty(&self.mean, "mean")?;
        validate_non_empty(&self.std, "std")?;
        self.std
            .iter()
            .enumerate()
            .try_for_each(|(i, &value)| validate_positive(value, &format!("std[{}]", i)))
    }
}
/// Options controlling postprocessing of raw detector output.
#[derive(Debug, Clone)]
pub struct ScaleAwareDetectorPostprocessConfig {
    /// Number of detection classes. NOTE(review): not consumed by
    /// `postprocess` in this file — presumably used by downstream callers.
    pub num_classes: usize,
}
/// Raw detector predictions wrapped for downstream consumers.
#[derive(Debug, Clone)]
pub struct ScaleAwareDetectorModelOutput {
    /// Predictions with layout `(1, num_boxes, 1, box_dim)` when the
    /// backend emitted a 2-D detection matrix, or the backend's native
    /// 4-D layout otherwise (see `run_inference_with_inputs`).
    pub predictions: ndarray::Array4<f32>,
}
/// Artifacts produced by `preprocess`: (normalized NCHW batch tensor,
/// per-image scale info from the resizer, original `[height, width]` per
/// image, resized `[height, width]` per image).
type ScaleAwareDetectorPreprocessArtifacts = (
    ndarray::Array4<f32>,
    Vec<ImageScaleInfo>,
    Vec<[f32; 2]>,
    Vec<[f32; 2]>,
);
/// Shorthand for the fallible result of `preprocess`.
type ScaleAwareDetectorPreprocessResult = Result<ScaleAwareDetectorPreprocessArtifacts, OCRError>;
/// Object detector that supplies image-scale metadata (scale factors and,
/// optionally, the resized shape) to the inference backend along with the
/// image tensor.
#[derive(Debug)]
pub struct ScaleAwareDetectorModel {
    /// Inference backend session.
    inference: OrtInfer,
    /// Resizes inputs toward the configured target shape.
    resizer: DetResizeForTest,
    /// Converts resized images to a normalized CHW tensor batch.
    normalizer: NormalizeImage,
    /// Which auxiliary tensors are fed at inference time.
    inference_mode: ScaleAwareDetectorInferenceMode,
    /// Retained after construction for introspection; the resizer and
    /// normalizer above are prebuilt from it in `new`.
    _preprocess_config: ScaleAwareDetectorPreprocessConfig,
}
impl ScaleAwareDetectorModel {
    /// Builds a model from an inference backend, a preprocessing
    /// configuration, and the tensor-feeding mode.
    ///
    /// The resizer and normalizer pipeline stages are constructed here
    /// from the (validated) configuration.
    ///
    /// # Errors
    /// Fails when the configuration is invalid or the normalizer cannot
    /// be built from the configured scale/mean/std.
    pub fn new(
        inference: OrtInfer,
        preprocess_config: ScaleAwareDetectorPreprocessConfig,
        inference_mode: ScaleAwareDetectorInferenceMode,
    ) -> Result<Self, OCRError> {
        preprocess_config.validate()?;
        // Resizer targets the fixed image_shape; limit_side_len is paired
        // with LimitType::Max as in the original pipeline configuration.
        let resizer = DetResizeForTest::new(
            None,
            Some((
                preprocess_config.image_shape.0,
                preprocess_config.image_shape.1,
            )),
            Some(preprocess_config.keep_ratio),
            Some(preprocess_config.limit_side_len),
            Some(LimitType::Max),
            None,
            None,
        )
        .with_filter(preprocess_config.resize_filter);
        let normalizer = NormalizeImage::with_color_order_from_rgb_stats(
            Some(preprocess_config.scale),
            preprocess_config.mean.clone(),
            preprocess_config.std.clone(),
            Some(TensorLayout::CHW),
            preprocess_config.color_order,
        )?;
        Ok(Self {
            inference,
            resizer,
            normalizer,
            inference_mode,
            _preprocess_config: preprocess_config,
        })
    }
    /// Resizes and normalizes a batch of images into an NCHW tensor.
    ///
    /// Returns the batch tensor together with the resizer's per-image
    /// scale info and the original / resized `[height, width]` pairs that
    /// [`Self::infer`] needs to build its scale-factor inputs.
    ///
    /// # Errors
    /// Fails on an empty batch or when normalization fails.
    pub fn preprocess(&self, images: Vec<RgbImage>) -> ScaleAwareDetectorPreprocessResult {
        validate_non_empty(&images, "images")?;
        // Capture original sizes before the images are consumed below.
        let orig_shapes: Vec<[f32; 2]> = images
            .iter()
            .map(|img| [img.height() as f32, img.width() as f32])
            .collect();
        let dynamic_images: Vec<DynamicImage> =
            images.into_iter().map(DynamicImage::ImageRgb8).collect();
        let (resized_images, img_shapes) = self.resizer.apply(dynamic_images, None, None, None);
        let resized_shapes: Vec<[f32; 2]> = resized_images
            .iter()
            .map(|img| [img.height() as f32, img.width() as f32])
            .collect();
        let batch_tensor = self.normalizer.normalize_batch_to(resized_images)?;
        Ok((batch_tensor, img_shapes, orig_shapes, resized_shapes))
    }
    /// Runs the detector on a preprocessed batch.
    ///
    /// Computes per-image `[scale_y, scale_x]` factors from the original
    /// and resized shapes, then feeds them (plus `im_shape` when the mode
    /// requires it) alongside the image tensor.
    ///
    /// # Errors
    /// Fails on an empty batch, mismatched shape-array lengths, invalid
    /// (e.g. zero) original dimensions, or a backend inference error.
    pub fn infer(
        &self,
        batch_tensor: &ndarray::Array4<f32>,
        orig_shapes: &[[f32; 2]],
        resized_shapes: &[[f32; 2]],
    ) -> Result<ndarray::Array4<f32>, OCRError> {
        let batch_size = batch_tensor.shape()[0];
        if batch_size == 0 {
            return Err(OCRError::InvalidInput {
                message: "Batch size cannot be zero".to_string(),
            });
        }
        validate_same_length(orig_shapes, resized_shapes, "orig_shapes", "resized_shapes")?;
        if orig_shapes.len() != batch_size {
            return Err(OCRError::InvalidInput {
                message: format!(
                    "Shape arrays length ({}) does not match batch size ({})",
                    orig_shapes.len(),
                    batch_size
                ),
            });
        }
        let mut scale_factors = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            let orig_h = orig_shapes[i][0];
            let orig_w = orig_shapes[i][1];
            let resized_h = resized_shapes[i][0];
            let resized_w = resized_shapes[i][1];
            // Guard both divisions before computing the ratios.
            validate_division(
                resized_h,
                orig_h,
                &format!("scale_y calculation for image {}", i),
            )?;
            validate_division(
                resized_w,
                orig_w,
                &format!("scale_x calculation for image {}", i),
            )?;
            scale_factors.push([resized_h / orig_h, resized_w / orig_w]);
        }
        let scale_factor = ndarray::Array2::from_shape_vec(
            (batch_size, 2),
            scale_factors.into_iter().flatten().collect(),
        )
        .map_err(|e| OCRError::InvalidInput {
            message: format!("Failed to create scale_factor array: {}", e),
        })?;
        match self.inference_mode {
            ScaleAwareDetectorInferenceMode::ScaleFactorOnly => {
                // Stack array instead of a Vec: the callee only needs a slice.
                let inputs = [
                    ("image", TensorInput::Array4(batch_tensor)),
                    ("scale_factor", TensorInput::Array2(&scale_factor)),
                ];
                self.run_inference_with_inputs(&inputs)
            }
            ScaleAwareDetectorInferenceMode::ScaleFactorAndImageShape => {
                // This mode additionally expects the resized [h, w] per
                // image under the `im_shape` input name.
                let image_shape = ndarray::Array2::from_shape_vec(
                    (batch_size, 2),
                    resized_shapes
                        .iter()
                        .flat_map(|s| s.iter().copied())
                        .collect(),
                )
                .map_err(|e| OCRError::InvalidInput {
                    message: format!("Failed to create image_shape array: {}", e),
                })?;
                let inputs = [
                    ("image", TensorInput::Array4(batch_tensor)),
                    ("scale_factor", TensorInput::Array2(&scale_factor)),
                    ("im_shape", TensorInput::Array2(&image_shape)),
                ];
                self.run_inference_with_inputs(&inputs)
            }
        }
    }
    /// Executes the backend and coerces its first output tensor into the
    /// 4-D layout `(1, num_boxes, 1, box_dim)` used downstream.
    ///
    /// Accepts either a 2-D `(num_boxes, box_dim)` detection matrix with
    /// `box_dim` in {6, 7, 8}, or an already 4-D tensor (passed through).
    fn run_inference_with_inputs(
        &self,
        inputs: &[(&str, TensorInput)],
    ) -> Result<ndarray::Array4<f32>, OCRError> {
        let outputs = self
            .inference
            .infer(inputs)
            .map_err(|e| OCRError::Inference {
                model_name: "Scale-Aware Detector".to_string(),
                context: "failed to run inference".to_string(),
                source: Box::new(e),
            })?;
        let output = outputs.first().ok_or_else(|| OCRError::InvalidInput {
            message: "No output tensors available from model".to_string(),
        })?;
        let output_shape = output.1.shape();
        match output_shape.len() {
            2 => {
                let num_boxes = output_shape[0] as usize;
                let box_dim = output_shape[1] as usize;
                if ![6, 7, 8].contains(&box_dim) {
                    return Err(OCRError::InvalidInput {
                        message: format!(
                            "Expected box dimension 6, 7, or 8, got {} with shape {:?}",
                            box_dim, output_shape
                        ),
                    });
                }
                let output_array =
                    output
                        .1
                        .clone()
                        .try_into_array_f32()
                        .map_err(|e| OCRError::InvalidInput {
                            message: format!("Failed to extract output tensor: {}", e),
                        })?;
                // Collect in logical (row-major) order. The previous
                // `into_raw_vec_and_offset()` discarded the offset and is
                // only correct for standard-layout arrays; iterating is
                // correct regardless of the array's memory layout.
                let data: Vec<f32> = output_array.iter().copied().collect();
                ndarray::Array::from_shape_vec((1, num_boxes, 1, box_dim), data).map_err(|e| {
                    OCRError::InvalidInput {
                        message: format!("Failed to reshape 2D output to 4D: {}", e),
                    }
                })
            }
            4 => output
                .1
                .clone()
                .try_into_array4_f32()
                .map_err(|e| OCRError::InvalidInput {
                    message: format!("Failed to convert to 4D array: {}", e),
                }),
            _ => Err(OCRError::InvalidInput {
                message: format!(
                    "Layout inference: expected 2D or 4D output tensor, got {}D with shape {:?}",
                    output_shape.len(),
                    output_shape
                ),
            }),
        }
    }
    /// Wraps raw predictions in the output type. `_config` is currently
    /// unused but kept so the signature stays stable for callers.
    pub fn postprocess(
        &self,
        predictions: ndarray::Array4<f32>,
        _config: &ScaleAwareDetectorPostprocessConfig,
    ) -> Result<ScaleAwareDetectorModelOutput, OCRError> {
        Ok(ScaleAwareDetectorModelOutput { predictions })
    }
    /// Full pipeline: preprocess, infer, postprocess.
    ///
    /// Returns the predictions together with the resizer's per-image
    /// [`ImageScaleInfo`] produced during preprocessing.
    pub fn forward(
        &self,
        images: Vec<RgbImage>,
        config: &ScaleAwareDetectorPostprocessConfig,
    ) -> Result<(ScaleAwareDetectorModelOutput, Vec<ImageScaleInfo>), OCRError> {
        let (batch_tensor, img_shapes, orig_shapes, resized_shapes) = self.preprocess(images)?;
        let predictions = self.infer(&batch_tensor, &orig_shapes, &resized_shapes)?;
        let output = self.postprocess(predictions, config)?;
        Ok((output, img_shapes))
    }
}
/// Fluent builder for [`ScaleAwareDetectorModel`].
#[derive(Debug, Default)]
pub struct ScaleAwareDetectorModelBuilder {
    /// Preprocessing config; `build` falls back to PicoDet defaults when unset.
    preprocess_config: Option<ScaleAwareDetectorPreprocessConfig>,
    /// Inference mode; `build` falls back to `ScaleFactorOnly` when unset.
    inference_mode: Option<ScaleAwareDetectorInferenceMode>,
}
impl ScaleAwareDetectorModelBuilder {
    /// Creates an empty builder; unset options are defaulted in `build`.
    pub fn new() -> Self {
        Self::default()
    }
    /// Builder preconfigured for PicoDet (scale-factor-only input).
    pub fn picodet() -> Self {
        Self::new()
            .preprocess_config(ScaleAwareDetectorPreprocessConfig::picodet())
            .inference_mode(ScaleAwareDetectorInferenceMode::ScaleFactorOnly)
    }
    /// Builder preconfigured for PP-DocLayout (scale factor + image shape).
    pub fn pp_doclayout() -> Self {
        Self::new()
            .preprocess_config(ScaleAwareDetectorPreprocessConfig::pp_doclayout())
            .inference_mode(ScaleAwareDetectorInferenceMode::ScaleFactorAndImageShape)
    }
    /// Sets the full preprocessing configuration.
    pub fn preprocess_config(mut self, config: ScaleAwareDetectorPreprocessConfig) -> Self {
        self.preprocess_config = Some(config);
        self
    }
    /// Sets which auxiliary tensors are fed at inference time.
    pub fn inference_mode(mut self, mode: ScaleAwareDetectorInferenceMode) -> Self {
        self.inference_mode = Some(mode);
        self
    }
    /// Overrides the target image shape and keeps `limit_side_len` in sync
    /// with the longer side, starting from PicoDet defaults when no
    /// configuration has been set yet.
    pub fn image_shape(mut self, height: u32, width: u32) -> Self {
        let config = self
            .preprocess_config
            .get_or_insert_with(ScaleAwareDetectorPreprocessConfig::picodet);
        config.image_shape = (height, width);
        config.limit_side_len = height.max(width);
        self
    }
    /// Consumes the builder, defaulting any unset option, and constructs
    /// (and validates) the model.
    pub fn build(self, inference: OrtInfer) -> Result<ScaleAwareDetectorModel, OCRError> {
        let config = self
            .preprocess_config
            .unwrap_or_else(ScaleAwareDetectorPreprocessConfig::picodet);
        let mode = self
            .inference_mode
            .unwrap_or(ScaleAwareDetectorInferenceMode::ScaleFactorOnly);
        ScaleAwareDetectorModel::new(inference, config, mode)
    }
}