// oar_ocr_core/models/detection/rtdetr.rs
//! RT-DETR Layout Detection Model
//!
//! This module provides a pure implementation of the RT-DETR model for layout detection.
//! The model is independent of any specific task and can be reused in different contexts.

6use crate::core::inference::OrtInfer;
7use crate::core::{OCRError, Tensor4D};
8use crate::processors::{
9    DetResizeForTest, ImageScaleInfo, LimitType, NormalizeImage, TensorLayout,
10};
11use image::{DynamicImage, RgbImage};
12use ndarray::Array2;
13
14type RTDetrPreprocessArtifacts = (Tensor4D, Vec<ImageScaleInfo>, Vec<[f32; 2]>, Vec<[f32; 2]>);
15type RTDetrPreprocessResult = Result<RTDetrPreprocessArtifacts, OCRError>;
16
/// Preprocessing configuration for the RT-DETR model.
#[derive(Debug, Clone)]
pub struct RTDetrPreprocessConfig {
    /// Target input shape as (height, width).
    pub image_shape: (u32, u32),
    /// If true, preserve the aspect ratio when resizing.
    pub keep_ratio: bool,
    /// Maximum/limit side length used by the resizer.
    pub limit_side_len: u32,
    /// Scale factor applied during normalization (e.g. 1/255).
    pub scale: f32,
    /// Per-channel normalization means, given in RGB order.
    pub mean: Vec<f32>,
    /// Per-channel normalization standard deviations, given in RGB order.
    pub std: Vec<f32>,
}
33
34impl Default for RTDetrPreprocessConfig {
35    fn default() -> Self {
36        Self {
37            image_shape: (640, 640),
38            keep_ratio: false,
39            limit_side_len: 640,
40            scale: 1.0 / 255.0,
41            // Paddle's RT-DETR exports expect inputs scaled to [0,1] without mean/std shift
42            mean: vec![0.0, 0.0, 0.0],
43            std: vec![1.0, 1.0, 1.0],
44        }
45    }
46}
47
/// Postprocessing configuration for the RT-DETR model.
#[derive(Debug, Clone)]
pub struct RTDetrPostprocessConfig {
    /// Total number of detection classes the model predicts.
    pub num_classes: usize,
}
54
55/// Output from RT-DETR model.
56#[derive(Debug, Clone)]
57pub struct RTDetrModelOutput {
58    /// Detection predictions tensor [batch_size, num_detections, 6]
59    /// Each detection: [x1, y1, x2, y2, score, class_id]
60    pub predictions: Tensor4D,
61}
62
63/// RT-DETR layout detection model.
64///
65/// This is a pure model implementation that handles:
66/// - Preprocessing: Image resizing and normalization
67/// - Inference: Running the ONNX model
68/// - Postprocessing: Returning raw predictions
69///
70/// The model is independent of any specific task or adapter.
71#[derive(Debug)]
72pub struct RTDetrModel {
73    inference: OrtInfer,
74    resizer: DetResizeForTest,
75    normalizer: NormalizeImage,
76    _preprocess_config: RTDetrPreprocessConfig,
77}
78
79impl RTDetrModel {
80    /// Creates a new RT-DETR model.
81    pub fn new(
82        inference: OrtInfer,
83        preprocess_config: RTDetrPreprocessConfig,
84    ) -> Result<Self, OCRError> {
85        // Create resizer
86        let resizer = DetResizeForTest::new(
87            None,
88            Some((
89                preprocess_config.image_shape.0,
90                preprocess_config.image_shape.1,
91            )),
92            Some(preprocess_config.keep_ratio),
93            Some(preprocess_config.limit_side_len),
94            Some(LimitType::Max),
95            None,
96            None,
97        );
98
99        // Create normalizer.
100        // Paddle models expect BGR input; treat config mean/std as RGB and reorder.
101        let normalizer = NormalizeImage::with_color_order_from_rgb_stats(
102            Some(preprocess_config.scale),
103            preprocess_config.mean.clone(),
104            preprocess_config.std.clone(),
105            Some(TensorLayout::CHW),
106            crate::processors::types::ColorOrder::BGR,
107        )?;
108
109        Ok(Self {
110            inference,
111            resizer,
112            normalizer,
113            _preprocess_config: preprocess_config,
114        })
115    }
116
117    /// Preprocesses images for RT-DETR model.
118    ///
119    /// Returns:
120    /// - Batch tensor ready for inference
121    /// - Image shapes after resizing [h, w, ratio_h, ratio_w]
122    /// - Original shapes [h, w]
123    /// - Resized shapes [h, w]
124    pub fn preprocess(&self, images: Vec<RgbImage>) -> RTDetrPreprocessResult {
125        // Store original dimensions
126        let orig_shapes: Vec<[f32; 2]> = images
127            .iter()
128            .map(|img| [img.height() as f32, img.width() as f32])
129            .collect();
130
131        // Convert to DynamicImage
132        let dynamic_images: Vec<DynamicImage> =
133            images.into_iter().map(DynamicImage::ImageRgb8).collect();
134
135        // Resize images
136        let (resized_images, img_shapes) = self.resizer.apply(
137            dynamic_images,
138            None, // Use configured limit_side_length
139            None, // Use configured limit_type
140            None,
141        );
142
143        // Get resized dimensions
144        let resized_shapes: Vec<[f32; 2]> = resized_images
145            .iter()
146            .map(|img| [img.height() as f32, img.width() as f32])
147            .collect();
148
149        // Normalize and convert to tensor
150        let batch_tensor = self.normalizer.normalize_batch_to(resized_images)?;
151
152        Ok((batch_tensor, img_shapes, orig_shapes, resized_shapes))
153    }
154
155    /// Runs inference on the preprocessed batch tensor.
156    ///
157    /// RT-DETR requires both `scale_factor` and `im_shape` inputs.
158    pub fn infer(
159        &self,
160        batch_tensor: &Tensor4D,
161        scale_factor: Array2<f32>,
162        im_shape: Array2<f32>,
163    ) -> Result<Tensor4D, OCRError> {
164        self.inference
165            .infer_4d_layout(batch_tensor, Some(scale_factor), Some(im_shape))
166    }
167
168    /// Postprocesses model predictions.
169    ///
170    /// For RT-DETR, we just return the raw predictions.
171    /// The adapter layer will handle converting these to task-specific outputs.
172    pub fn postprocess(
173        &self,
174        predictions: Tensor4D,
175        _config: &RTDetrPostprocessConfig,
176    ) -> Result<RTDetrModelOutput, OCRError> {
177        Ok(RTDetrModelOutput { predictions })
178    }
179
180    /// Runs the complete forward pass: preprocess -> infer -> postprocess.
181    pub fn forward(
182        &self,
183        images: Vec<RgbImage>,
184        config: &RTDetrPostprocessConfig,
185    ) -> Result<(RTDetrModelOutput, Vec<ImageScaleInfo>), OCRError> {
186        let (batch_tensor, img_shapes, _orig_shapes, resized_shapes) = self.preprocess(images)?;
187
188        let batch_size = batch_tensor.shape()[0];
189
190        // Build scale_factor array [ratio_h, ratio_w]
191        let scale_data: Vec<f32> = img_shapes
192            .iter()
193            .flat_map(|shape| [shape.ratio_h, shape.ratio_w])
194            .collect();
195        let scale_factor = Array2::from_shape_vec((batch_size, 2), scale_data).map_err(|e| {
196            OCRError::InvalidInput {
197                message: format!("Failed to create scale_factor array: {}", e),
198            }
199        })?;
200
201        // Build im_shape array using resized dimensions
202        let im_shape_data: Vec<f32> = resized_shapes
203            .iter()
204            .flat_map(|shape| [shape[0], shape[1]])
205            .collect();
206        let im_shape = Array2::from_shape_vec((batch_size, 2), im_shape_data).map_err(|e| {
207            OCRError::InvalidInput {
208                message: format!("Failed to create im_shape array: {}", e),
209            }
210        })?;
211
212        let predictions = self.infer(&batch_tensor, scale_factor, im_shape)?;
213        let output = self.postprocess(predictions, config)?;
214        Ok((output, img_shapes))
215    }
216}
217
218/// Builder for RT-DETR model.
219#[derive(Debug, Default)]
220pub struct RTDetrModelBuilder {
221    preprocess_config: Option<RTDetrPreprocessConfig>,
222}
223
224impl RTDetrModelBuilder {
225    /// Creates a new builder.
226    pub fn new() -> Self {
227        Self::default()
228    }
229
230    /// Sets the preprocessing configuration.
231    pub fn preprocess_config(mut self, config: RTDetrPreprocessConfig) -> Self {
232        self.preprocess_config = Some(config);
233        self
234    }
235
236    /// Sets the image shape.
237    pub fn image_shape(mut self, height: u32, width: u32) -> Self {
238        let mut config = self.preprocess_config.unwrap_or_default();
239        config.image_shape = (height, width);
240        config.limit_side_len = height.max(width);
241        self.preprocess_config = Some(config);
242        self
243    }
244
245    /// Builds the RT-DETR model.
246    pub fn build(self, inference: OrtInfer) -> Result<RTDetrModel, OCRError> {
247        let preprocess_config = self.preprocess_config.unwrap_or_default();
248        RTDetrModel::new(inference, preprocess_config)
249    }
250}