use crate::core::OCRError;
use crate::core::inference::{OrtInfer, TensorInput};
use crate::processors::{
DetResizeForTest, ImageScaleInfo, LimitType, NormalizeImage, TensorLayout,
};
use image::{DynamicImage, RgbImage};
use ndarray::Array2;
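
/// Artifacts produced by preprocessing: the NCHW batch tensor, per-image
/// scale information, original (height, width) pairs, and resized
/// (height, width) pairs.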
type RTDetrPreprocessArtifacts = (
ndarray::Array4<f32>,
Vec<ImageScaleInfo>,
Vec<[f32; 2]>,
Vec<[f32; 2]>,
);
type RTDetrPreprocessResult = Result<RTDetrPreprocessArtifacts, OCRError>;
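
/// Preprocessing configuration for RT-DETR input images.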
#[derive(Debug, Clone)]
pub struct RTDetrPreprocessConfig {
    /// Target network input shape as (height, width).
    pub image_shape: (u32, u32),
    /// Whether to preserve the aspect ratio when resizing.
    pub keep_ratio: bool,
    /// Upper bound on the longer image side during resizing.
    pub limit_side_len: u32,
    /// Multiplicative factor applied to pixel values (typically 1/255).
    pub scale: f32,
    /// Per-channel means subtracted during normalization.
    pub mean: Vec<f32>,
    /// Per-channel standard deviations used during normalization.
    pub std: Vec<f32>,
}
impl Default for RTDetrPreprocessConfig {
fn default() -> Self {
Self {
image_shape: (640, 640),
keep_ratio: false,
limit_side_len: 640,
scale: 1.0 / 255.0,
mean: vec![0.0, 0.0, 0.0],
std: vec![1.0, 1.0, 1.0],
}
}
}
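
/// Postprocessing configuration for RT-DETR detections.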
#[derive(Debug, Clone)]
pub struct RTDetrPostprocessConfig {
    /// Number of object classes the model predicts.
    pub num_classes: usize,
}
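
/// Raw RT-DETR predictions in the 4-D layout produced by [`RTDetrModel::infer`].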
#[derive(Debug, Clone)]
pub struct RTDetrModelOutput {
pub predictions: ndarray::Array4<f32>,
}
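
/// RT-DETR detection model: wraps an [`OrtInfer`] session together with the
/// resizing and normalization steps its inputs require.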
#[derive(Debug)]
pub struct RTDetrModel {
inference: OrtInfer,
resizer: DetResizeForTest,
normalizer: NormalizeImage,
_preprocess_config: RTDetrPreprocessConfig,
}
impl RTDetrModel {
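    /// Creates a model from an inference session and preprocessing config,
    /// deriving the resizer and normalizer from the config.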
pub fn new(
inference: OrtInfer,
preprocess_config: RTDetrPreprocessConfig,
) -> Result<Self, OCRError> {
let resizer = DetResizeForTest::new(
None,
Some((
preprocess_config.image_shape.0,
preprocess_config.image_shape.1,
)),
Some(preprocess_config.keep_ratio),
Some(preprocess_config.limit_side_len),
Some(LimitType::Max),
None,
None,
);
let normalizer = NormalizeImage::with_color_order_from_rgb_stats(
Some(preprocess_config.scale),
preprocess_config.mean.clone(),
preprocess_config.std.clone(),
Some(TensorLayout::CHW),
crate::processors::types::ColorOrder::BGR,
)?;
Ok(Self {
inference,
resizer,
normalizer,
_preprocess_config: preprocess_config,
})
}
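
    /// Resizes and normalizes a batch of RGB images, returning the input
    /// tensor together with per-image scale info, original shapes, and
    /// resized shapes.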
pub fn preprocess(&self, images: Vec<RgbImage>) -> RTDetrPreprocessResult {
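        // Record original (height, width) before any resizing.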
let orig_shapes: Vec<[f32; 2]> = images
.iter()
.map(|img| [img.height() as f32, img.width() as f32])
.collect();
let dynamic_images: Vec<DynamicImage> =
images.into_iter().map(DynamicImage::ImageRgb8).collect();
        let (resized_images, img_shapes) =
            self.resizer.apply(dynamic_images, None, None, None);
let resized_shapes: Vec<[f32; 2]> = resized_images
.iter()
.map(|img| [img.height() as f32, img.width() as f32])
.collect();
let batch_tensor = self.normalizer.normalize_batch_to(resized_images)?;
Ok((batch_tensor, img_shapes, orig_shapes, resized_shapes))
}
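
    /// Runs the detector on a preprocessed batch. The model takes three named
    /// inputs (`image`, `scale_factor`, `im_shape`); the output is lifted to a
    /// 4-D array when the model emits a flat 2-D tensor.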
pub fn infer(
&self,
batch_tensor: &ndarray::Array4<f32>,
scale_factor: &Array2<f32>,
im_shape: &Array2<f32>,
) -> Result<ndarray::Array4<f32>, OCRError> {
let inputs = vec![
("image", TensorInput::Array4(batch_tensor)),
("scale_factor", TensorInput::Array2(scale_factor)),
("im_shape", TensorInput::Array2(im_shape)),
];
let outputs = self
.inference
.infer(&inputs)
.map_err(|e| OCRError::Inference {
model_name: "RT-DETR".to_string(),
context: "failed to run inference".to_string(),
source: Box::new(e),
})?;
        // Prefer the conventional `fetch_name_0` output name; fall back to the
        // first output when the model uses a different naming scheme.
        let output = outputs
            .iter()
            .find(|(name, _)| name == "fetch_name_0")
            .or_else(|| outputs.first())
            .ok_or_else(|| OCRError::InvalidInput {
                message: "RT-DETR: no outputs found from model".to_string(),
            })?;
let output_shape = output.1.shape();
match output_shape.len() {
            2 => {
                // Some exports emit a flat (num_boxes, box_dim) tensor; lift it
                // to the (1, num_boxes, 1, box_dim) layout downstream code
                // expects. The flat form is assumed to carry a batch of one.
                let output_array = output.1.clone().try_into_array_f32().map_err(|e| {
                    OCRError::InvalidInput {
                        message: format!("Failed to extract output tensor: {}", e),
                    }
                })?;
                let num_boxes = output_shape[0] as usize;
                let box_dim = output_shape[1] as usize;
                let (data, _offset) = output_array.into_raw_vec_and_offset();
                ndarray::Array::from_shape_vec((1, num_boxes, 1, box_dim), data).map_err(|e| {
                    OCRError::InvalidInput {
                        message: format!("Failed to reshape 2D output to 4D: {}", e),
                    }
                })
            }
            // Already 4-D; convert the tensor type directly.
            4 => output.1.clone().try_into_array4_f32().map_err(|e| {
                OCRError::InvalidInput {
                    message: format!("Failed to convert to 4D array: {}", e),
                }
            }),
_ => Err(OCRError::InvalidInput {
message: format!(
"RT-DETR inference: expected 2D or 4D output, got {}D with shape {:?}",
output_shape.len(),
output_shape
),
}),
}
}
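
    /// Wraps raw predictions in [`RTDetrModelOutput`]; currently a
    /// pass-through, and the config is unused.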
pub fn postprocess(
&self,
predictions: ndarray::Array4<f32>,
_config: &RTDetrPostprocessConfig,
) -> Result<RTDetrModelOutput, OCRError> {
Ok(RTDetrModelOutput { predictions })
}
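
    /// Full pipeline: preprocess, assemble the `scale_factor` and `im_shape`
    /// inputs, run inference, and postprocess. The returned scale info allows
    /// mapping detections back to original image coordinates.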
pub fn forward(
&self,
images: Vec<RgbImage>,
config: &RTDetrPostprocessConfig,
) -> Result<(RTDetrModelOutput, Vec<ImageScaleInfo>), OCRError> {
let (batch_tensor, img_shapes, _orig_shapes, resized_shapes) = self.preprocess(images)?;
let batch_size = batch_tensor.shape()[0];
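        // Per-image (ratio_h, ratio_w) resize ratios for the `scale_factor` input.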
let scale_data: Vec<f32> = img_shapes
.iter()
.flat_map(|shape| [shape.ratio_h, shape.ratio_w])
.collect();
let scale_factor = Array2::from_shape_vec((batch_size, 2), scale_data).map_err(|e| {
OCRError::InvalidInput {
message: format!("Failed to create scale_factor array: {}", e),
}
})?;
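        // Per-image resized (height, width) for the `im_shape` input.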
let im_shape_data: Vec<f32> = resized_shapes
.iter()
.flat_map(|shape| [shape[0], shape[1]])
.collect();
let im_shape = Array2::from_shape_vec((batch_size, 2), im_shape_data).map_err(|e| {
OCRError::InvalidInput {
message: format!("Failed to create im_shape array: {}", e),
}
})?;
let predictions = self.infer(&batch_tensor, &scale_factor, &im_shape)?;
let output = self.postprocess(predictions, config)?;
Ok((output, img_shapes))
}
}
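
/// Builder for [`RTDetrModel`].
///
/// A minimal usage sketch; how the [`OrtInfer`] session is constructed and
/// the class count (80 below) are assumptions, not part of this module:
///
/// ```ignore
/// fn detect(inference: OrtInfer, images: Vec<image::RgbImage>) -> Result<(), OCRError> {
///     let model = RTDetrModelBuilder::new()
///         .image_shape(640, 640)
///         .build(inference)?;
///     // Hypothetical class count; use whatever your checkpoint was trained with.
///     let config = RTDetrPostprocessConfig { num_classes: 80 };
///     let (output, scale_infos) = model.forward(images, &config)?;
///     println!("predictions shape: {:?}", output.predictions.shape());
///     println!("processed {} images", scale_infos.len());
///     Ok(())
/// }
/// ```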
#[derive(Debug, Default)]
pub struct RTDetrModelBuilder {
preprocess_config: Option<RTDetrPreprocessConfig>,
}
impl RTDetrModelBuilder {
pub fn new() -> Self {
Self::default()
}
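
    /// Replaces the entire preprocessing configuration.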
pub fn preprocess_config(mut self, config: RTDetrPreprocessConfig) -> Self {
self.preprocess_config = Some(config);
self
}
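
    /// Sets the target input shape as (height, width) and updates
    /// `limit_side_len` to the longer side.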
pub fn image_shape(mut self, height: u32, width: u32) -> Self {
let mut config = self.preprocess_config.unwrap_or_default();
config.image_shape = (height, width);
config.limit_side_len = height.max(width);
self.preprocess_config = Some(config);
self
}
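
    /// Builds the model, using the default preprocessing configuration if
    /// none was supplied.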
pub fn build(self, inference: OrtInfer) -> Result<RTDetrModel, OCRError> {
let preprocess_config = self.preprocess_config.unwrap_or_default();
RTDetrModel::new(inference, preprocess_config)
}
}