oar_ocr_core/models/detection/
rtdetr.rs1use crate::core::inference::OrtInfer;
7use crate::core::{OCRError, Tensor4D};
8use crate::processors::{
9 DetResizeForTest, ImageScaleInfo, LimitType, NormalizeImage, TensorLayout,
10};
11use image::{DynamicImage, RgbImage};
12use ndarray::Array2;
13
14type RTDetrPreprocessArtifacts = (Tensor4D, Vec<ImageScaleInfo>, Vec<[f32; 2]>, Vec<[f32; 2]>);
15type RTDetrPreprocessResult = Result<RTDetrPreprocessArtifacts, OCRError>;
16
/// Preprocessing settings for the RT-DETR model.
///
/// `RTDetrModel::new` forwards `image_shape`, `keep_ratio` and `limit_side_len`
/// to the resizer, and `scale`, `mean` and `std` to the normalizer.
///
/// `PartialEq` is derived so configurations can be compared directly
/// (for example in tests or when checking whether a config changed).
#[derive(Debug, Clone, PartialEq)]
pub struct RTDetrPreprocessConfig {
    /// Target input shape as `(height, width)`.
    pub image_shape: (u32, u32),
    /// Passed to the resizer as its keep-ratio flag.
    pub keep_ratio: bool,
    /// Side-length limit handed to the resizer (used with `LimitType::Max`).
    pub limit_side_len: u32,
    /// Scale factor handed to the normalizer.
    pub scale: f32,
    /// Per-channel mean handed to the normalizer.
    pub mean: Vec<f32>,
    /// Per-channel standard deviation handed to the normalizer.
    pub std: Vec<f32>,
}
33
34impl Default for RTDetrPreprocessConfig {
35 fn default() -> Self {
36 Self {
37 image_shape: (640, 640),
38 keep_ratio: false,
39 limit_side_len: 640,
40 scale: 1.0 / 255.0,
41 mean: vec![0.0, 0.0, 0.0],
43 std: vec![1.0, 1.0, 1.0],
44 }
45 }
46}
47
/// Postprocessing settings for the RT-DETR model.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RTDetrPostprocessConfig {
    /// Number of classes.
    pub num_classes: usize,
}
54
55#[derive(Debug, Clone)]
57pub struct RTDetrModelOutput {
58 pub predictions: Tensor4D,
61}
62
63#[derive(Debug)]
72pub struct RTDetrModel {
73 inference: OrtInfer,
74 resizer: DetResizeForTest,
75 normalizer: NormalizeImage,
76 _preprocess_config: RTDetrPreprocessConfig,
77}
78
79impl RTDetrModel {
80 pub fn new(
82 inference: OrtInfer,
83 preprocess_config: RTDetrPreprocessConfig,
84 ) -> Result<Self, OCRError> {
85 let resizer = DetResizeForTest::new(
87 None,
88 Some((
89 preprocess_config.image_shape.0,
90 preprocess_config.image_shape.1,
91 )),
92 Some(preprocess_config.keep_ratio),
93 Some(preprocess_config.limit_side_len),
94 Some(LimitType::Max),
95 None,
96 None,
97 );
98
99 let normalizer = NormalizeImage::with_color_order_from_rgb_stats(
102 Some(preprocess_config.scale),
103 preprocess_config.mean.clone(),
104 preprocess_config.std.clone(),
105 Some(TensorLayout::CHW),
106 crate::processors::types::ColorOrder::BGR,
107 )?;
108
109 Ok(Self {
110 inference,
111 resizer,
112 normalizer,
113 _preprocess_config: preprocess_config,
114 })
115 }
116
117 pub fn preprocess(&self, images: Vec<RgbImage>) -> RTDetrPreprocessResult {
125 let orig_shapes: Vec<[f32; 2]> = images
127 .iter()
128 .map(|img| [img.height() as f32, img.width() as f32])
129 .collect();
130
131 let dynamic_images: Vec<DynamicImage> =
133 images.into_iter().map(DynamicImage::ImageRgb8).collect();
134
135 let (resized_images, img_shapes) = self.resizer.apply(
137 dynamic_images,
138 None, None, None,
141 );
142
143 let resized_shapes: Vec<[f32; 2]> = resized_images
145 .iter()
146 .map(|img| [img.height() as f32, img.width() as f32])
147 .collect();
148
149 let batch_tensor = self.normalizer.normalize_batch_to(resized_images)?;
151
152 Ok((batch_tensor, img_shapes, orig_shapes, resized_shapes))
153 }
154
155 pub fn infer(
159 &self,
160 batch_tensor: &Tensor4D,
161 scale_factor: Array2<f32>,
162 im_shape: Array2<f32>,
163 ) -> Result<Tensor4D, OCRError> {
164 self.inference
165 .infer_4d_layout(batch_tensor, Some(scale_factor), Some(im_shape))
166 }
167
168 pub fn postprocess(
173 &self,
174 predictions: Tensor4D,
175 _config: &RTDetrPostprocessConfig,
176 ) -> Result<RTDetrModelOutput, OCRError> {
177 Ok(RTDetrModelOutput { predictions })
178 }
179
180 pub fn forward(
182 &self,
183 images: Vec<RgbImage>,
184 config: &RTDetrPostprocessConfig,
185 ) -> Result<(RTDetrModelOutput, Vec<ImageScaleInfo>), OCRError> {
186 let (batch_tensor, img_shapes, _orig_shapes, resized_shapes) = self.preprocess(images)?;
187
188 let batch_size = batch_tensor.shape()[0];
189
190 let scale_data: Vec<f32> = img_shapes
192 .iter()
193 .flat_map(|shape| [shape.ratio_h, shape.ratio_w])
194 .collect();
195 let scale_factor = Array2::from_shape_vec((batch_size, 2), scale_data).map_err(|e| {
196 OCRError::InvalidInput {
197 message: format!("Failed to create scale_factor array: {}", e),
198 }
199 })?;
200
201 let im_shape_data: Vec<f32> = resized_shapes
203 .iter()
204 .flat_map(|shape| [shape[0], shape[1]])
205 .collect();
206 let im_shape = Array2::from_shape_vec((batch_size, 2), im_shape_data).map_err(|e| {
207 OCRError::InvalidInput {
208 message: format!("Failed to create im_shape array: {}", e),
209 }
210 })?;
211
212 let predictions = self.infer(&batch_tensor, scale_factor, im_shape)?;
213 let output = self.postprocess(predictions, config)?;
214 Ok((output, img_shapes))
215 }
216}
217
218#[derive(Debug, Default)]
220pub struct RTDetrModelBuilder {
221 preprocess_config: Option<RTDetrPreprocessConfig>,
222}
223
224impl RTDetrModelBuilder {
225 pub fn new() -> Self {
227 Self::default()
228 }
229
230 pub fn preprocess_config(mut self, config: RTDetrPreprocessConfig) -> Self {
232 self.preprocess_config = Some(config);
233 self
234 }
235
236 pub fn image_shape(mut self, height: u32, width: u32) -> Self {
238 let mut config = self.preprocess_config.unwrap_or_default();
239 config.image_shape = (height, width);
240 config.limit_side_len = height.max(width);
241 self.preprocess_config = Some(config);
242 self
243 }
244
245 pub fn build(self, inference: OrtInfer) -> Result<RTDetrModel, OCRError> {
247 let preprocess_config = self.preprocess_config.unwrap_or_default();
248 RTDetrModel::new(inference, preprocess_config)
249 }
250}