use image::{imageops::FilterType, DynamicImage, GenericImageView};
use ndarray::{s, Array3, Array4, Axis};
use tract_onnx::prelude::Tensor;

/// Configuration for [`DetPreProcessor`].
#[derive(Debug, Clone, Copy)]
pub struct DetPreProcessorConfig {
    /// Upper bound, in pixels, for the longer image side before padding.
    ///
    /// Images whose longer side exceeds this value are scaled down so that the
    /// longer side matches it. A value of `0` disables resizing.
    pub limit_side_len: u32,
}

impl Default for DetPreProcessorConfig {
    /// Returns a configuration with `limit_side_len` set to `960`.
    fn default() -> Self {
        Self {
            limit_side_len: 960,
        }
    }
}

/// Errors returned by [`DetPreProcessor::process`].
#[derive(Debug)]
pub enum DetPreProcessorError {
    /// The input image has a width or height of zero.
    EmptyImage,
}

impl std::fmt::Display for DetPreProcessorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyImage => f.write_str("input image dimensions must be positive"),
        }
    }
}

impl std::error::Error for DetPreProcessorError {}

38#[derive(Debug, Clone)]
40pub struct PreprocessedDetInput {
41 pub tensor: Tensor,
42 pub resized_dims: (u32, u32),
43 pub scale_ratio: f64,
44}
45
46#[derive(Debug, Clone)]
48pub struct DetPreProcessor {
49 config: DetPreProcessorConfig,
50}
51
52impl DetPreProcessor {
53 pub fn new(config: DetPreProcessorConfig) -> Self {
54 Self { config }
55 }
56
57 pub fn process(
58 &self,
59 image: &DynamicImage,
60 ) -> Result<PreprocessedDetInput, DetPreProcessorError> {
61 let (orig_w, orig_h) = image.dimensions();
62 if orig_w == 0 || orig_h == 0 {
63 return Err(DetPreProcessorError::EmptyImage);
64 }
65
66 let (resized_w, resized_h, scale_ratio) =
67 compute_resized_dims(orig_w, orig_h, self.config.limit_side_len);
68
69 let resized = if resized_w == orig_w && resized_h == orig_h {
70 image.clone()
71 } else {
72 image.resize_exact(resized_w, resized_h, FilterType::Lanczos3)
73 };
74
75 let rgb_image = resized.to_rgb8();
76 let padded_w = round_up_to_multiple(resized_w, 32);
77 let padded_h = round_up_to_multiple(resized_h, 32);
78
79 let mut array_hwc = Array3::<f32>::zeros((padded_h as usize, padded_w as usize, 3));
80
81 for y in 0..resized_h as usize {
82 for x in 0..resized_w as usize {
83 let pixel = rgb_image.get_pixel(x as u32, y as u32);
84 for c in 0..3 {
85 array_hwc[[y, x, c]] = pixel[c] as f32 / 255.0;
86 }
87 }
88 }
89
90 let array_chw = array_hwc.permuted_axes([2, 0, 1]);
91 let array_nchw = array_chw.insert_axis(Axis(0));
92 let tensor: Tensor = array_nchw.into_dyn().into();
93
94 Ok(PreprocessedDetInput {
95 tensor,
96 resized_dims: (padded_w, padded_h),
97 scale_ratio,
98 })
99 }
100}
101
/// Computes the size an image is resized to so that its longer side does not
/// exceed `limit_side_len`.
///
/// Returns `(width, height, scale_ratio)`. When `limit_side_len` is `0` or the
/// longer side already fits, the original dimensions are returned with a ratio
/// of `1.0`. Otherwise both sides are multiplied by
/// `limit_side_len / max(orig_w, orig_h)`, rounded to the nearest integer and
/// kept at a minimum of one pixel.
fn compute_resized_dims(orig_w: u32, orig_h: u32, limit_side_len: u32) -> (u32, u32, f64) {
    let max_side = orig_w.max(orig_h);
    if limit_side_len == 0 || max_side <= limit_side_len {
        return (orig_w, orig_h, 1.0);
    }

    let scale_ratio = f64::from(limit_side_len) / f64::from(max_side);
    // The scaled value lies in [1, limit_side_len], so the cast back to `u32`
    // cannot truncate.
    let scale = |side: u32| (f64::from(side) * scale_ratio).round().max(1.0) as u32;

    (scale(orig_w), scale(orig_h), scale_ratio)
}

/// Rounds `value` up to the nearest multiple of `multiple`.
///
/// Returns `value` unchanged when `multiple` is `0` or when `value` is already
/// a multiple. The arithmetic overflows only if the rounded result itself does
/// not fit in a `u32`.
fn round_up_to_multiple(value: u32, multiple: u32) -> u32 {
    if multiple == 0 {
        return value;
    }

    match value % multiple {
        0 => value,
        // Add only the missing distance; `value + multiple` could overflow
        // even when the rounded result is representable.
        remainder => value + (multiple - remainder),
    }
}

/// Axis-aligned rectangle, in pixel coordinates of the source image, that
/// selects the crop for one batch sample.
#[derive(Debug, Clone, Copy)]
pub struct RecTextRegion {
    /// Horizontal offset of the rectangle within the image.
    pub x: u32,
    /// Vertical offset of the rectangle within the image.
    pub y: u32,
    /// Width of the rectangle in pixels; must be non-zero to be processed.
    pub width: u32,
    /// Height of the rectangle in pixels; must be non-zero to be processed.
    pub height: u32,
}

/// Configuration for [`RecPreProcessor`].
#[derive(Debug, Clone)]
pub struct RecPreProcessorConfig {
    /// Height, in pixels, every cropped region is resized to. Must be non-zero.
    pub target_height: u32,
    /// Width of the batch tensor and upper bound for the resized crop width.
    /// Must be non-zero.
    pub max_width: u32,
    /// Per-channel mean subtracted from pixel values after scaling to `[0, 1]`.
    pub mean: [f32; 3],
    /// Per-channel divisor applied after subtracting `mean`. A channel whose
    /// divisor is `0.0` normalizes to `0.0`.
    pub std: [f32; 3],
    /// Per-channel fill value for columns not covered by the crop. It is
    /// normalized with `mean`/`std` before being written.
    pub pad_value: [f32; 3],
}

impl Default for RecPreProcessorConfig {
    /// Returns a 48-pixel-high, 320-pixel-wide configuration with mean and std
    /// of `0.5` per channel and a pad value of `0.0`.
    fn default() -> Self {
        Self {
            target_height: 48,
            max_width: 320,
            mean: [0.5, 0.5, 0.5],
            std: [0.5, 0.5, 0.5],
            pad_value: [0.0, 0.0, 0.0],
        }
    }
}

164#[derive(Debug)]
166pub enum RecPreProcessorError {
167 EmptyRegions,
169 EmptyImage,
171 InvalidConfiguration,
173 ZeroArea { index: usize },
175 RegionOutOfBounds {
177 index: usize,
178 image_dims: (u32, u32),
179 region: RecTextRegion,
180 },
181}
182
183impl std::fmt::Display for RecPreProcessorError {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 match self {
186 RecPreProcessorError::EmptyRegions => {
187 write!(f, "at least one text region is required for recognition")
188 }
189 RecPreProcessorError::EmptyImage => {
190 write!(f, "input image dimensions must be positive")
191 }
192 RecPreProcessorError::InvalidConfiguration => {
193 write!(f, "recognition preprocessor configuration is invalid")
194 }
195 RecPreProcessorError::ZeroArea { index } => {
196 write!(f, "text region at index {} has zero area", index)
197 }
198 RecPreProcessorError::RegionOutOfBounds {
199 index,
200 image_dims,
201 region,
202 } => write!(
203 f,
204 "text region at index {} (x={}, y={}, w={}, h={}) exceeds image bounds {:?}",
205 index, region.x, region.y, region.width, region.height, image_dims
206 ),
207 }
208 }
209}
210
211impl std::error::Error for RecPreProcessorError {}
212
213#[derive(Debug, Clone)]
215pub struct PreprocessedRecBatch {
216 pub tensor: Tensor,
217 pub valid_widths: Vec<u32>,
218 pub max_width: u32,
219}
220
221impl PreprocessedRecBatch {
222 pub fn valid_width_ratios(&self) -> Vec<f32> {
223 if self.max_width == 0 {
224 return vec![0.0; self.valid_widths.len()];
225 }
226 self.valid_widths
227 .iter()
228 .map(|width| *width as f32 / self.max_width as f32)
229 .collect()
230 }
231}
232
233#[derive(Debug, Clone)]
235pub struct RecPreProcessor {
236 config: RecPreProcessorConfig,
237}
238
239impl RecPreProcessor {
240 pub fn new(config: RecPreProcessorConfig) -> Self {
241 Self { config }
242 }
243
244 pub fn process(
245 &self,
246 image: &DynamicImage,
247 regions: &[RecTextRegion],
248 ) -> Result<PreprocessedRecBatch, RecPreProcessorError> {
249 if regions.is_empty() {
250 return Err(RecPreProcessorError::EmptyRegions);
251 }
252
253 if self.config.target_height == 0 || self.config.max_width == 0 {
254 return Err(RecPreProcessorError::InvalidConfiguration);
255 }
256
257 let (img_w, img_h) = image.dimensions();
258 if img_w == 0 || img_h == 0 {
259 return Err(RecPreProcessorError::EmptyImage);
260 }
261
262 let target_height = self.config.target_height;
263 let max_width = self.config.max_width;
264 let batch_size = regions.len();
265
266 let mut batch =
267 Array4::<f32>::zeros((batch_size, 3, target_height as usize, max_width as usize));
268
269 for sample in 0..batch_size {
270 for channel in 0..3 {
271 let pad = normalize_value(
272 self.config.pad_value[channel],
273 self.config.mean[channel],
274 self.config.std[channel],
275 );
276 batch.slice_mut(s![sample, channel, .., ..]).fill(pad);
277 }
278 }
279
280 let mut valid_widths = Vec::with_capacity(batch_size);
281
282 for (index, region) in regions.iter().copied().enumerate() {
283 if region.width == 0 || region.height == 0 {
284 return Err(RecPreProcessorError::ZeroArea { index });
285 }
286
287 if region.x >= img_w
288 || region.y >= img_h
289 || region.x + region.width > img_w
290 || region.y + region.height > img_h
291 {
292 return Err(RecPreProcessorError::RegionOutOfBounds {
293 index,
294 image_dims: (img_w, img_h),
295 region,
296 });
297 }
298
299 let cropped = image.crop_imm(region.x, region.y, region.width, region.height);
300 let aspect_ratio = region.width as f32 / region.height as f32;
301 let mut target_width = (aspect_ratio * target_height as f32)
302 .round()
303 .clamp(1.0, max_width as f32) as u32;
304 if target_width == 0 {
305 target_width = 1;
306 }
307
308 let resized = cropped.resize_exact(target_width, target_height, FilterType::Lanczos3);
309 let rgb_image = resized.to_rgb8();
310
311 for y in 0..target_height as usize {
312 for x in 0..target_width as usize {
313 let pixel = rgb_image.get_pixel(x as u32, y as u32);
314 for channel in 0..3 {
315 let value = pixel[channel] as f32 / 255.0;
316 let normalized = normalize_value(
317 value,
318 self.config.mean[channel],
319 self.config.std[channel],
320 );
321 batch[[index, channel, y, x]] = normalized;
322 }
323 }
324 }
325
326 valid_widths.push(target_width);
327 }
328
329 let tensor: Tensor = batch.into_dyn().into();
330 Ok(PreprocessedRecBatch {
331 tensor,
332 valid_widths,
333 max_width,
334 })
335 }
336}
337
/// Returns `(value - mean) / std`, or `0.0` when `std` is zero so that a
/// degenerate configuration cannot produce NaN or infinity.
fn normalize_value(value: f32, mean: f32, std: f32) -> f32 {
    if std == 0.0 {
        return 0.0;
    }
    (value - mean) / std
}

#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Rgb};

    /// Builds a uniformly coloured RGB image; the green and blue channels are
    /// offset by one from `value` (saturating) so the channels are distinguishable.
    fn solid_image(width: u32, height: u32, value: u8) -> DynamicImage {
        let pixel = Rgb([value, value.saturating_sub(1), value.saturating_add(1)]);
        let buffer = ImageBuffer::from_pixel(width, height, pixel);
        DynamicImage::ImageRgb8(buffer)
    }

    /// Builds an RGB image whose red channel is `(x + y) % 256`, with green and
    /// blue offset by 32 and 64 (saturating).
    fn gradient_image(width: u32, height: u32) -> DynamicImage {
        let mut buffer = ImageBuffer::new(width, height);
        for (x, y, pixel) in buffer.enumerate_pixels_mut() {
            let base = ((x + y) % 256) as u8;
            let green = base.saturating_add(32);
            let blue = base.saturating_add(64);
            *pixel = Rgb([base, green, blue]);
        }
        DynamicImage::ImageRgb8(buffer)
    }

    #[test]
    fn resize_long_side_to_limit() {
        let image = solid_image(1920, 1080, 128);
        let preprocessor = DetPreProcessor::new(DetPreProcessorConfig::default());

        let result = preprocessor.process(&image).unwrap();

        // 1920x1080 scales to 960x540; the height is then padded to 544.
        assert_eq!(result.resized_dims, (960, 544));
        assert!((result.scale_ratio - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn keep_original_size_when_within_limit() {
        let image = solid_image(800, 600, 64);
        let preprocessor = DetPreProcessor::new(DetPreProcessorConfig::default());

        let result = preprocessor.process(&image).unwrap();

        // No resize happens, but the height is still padded from 600 to 608.
        assert_eq!(result.resized_dims, (800, 608));
        assert!((result.scale_ratio - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn tensor_shape_and_normalization() {
        let image = solid_image(320, 320, 255);
        let preprocessor = DetPreProcessor::new(DetPreProcessorConfig {
            limit_side_len: 320,
        });

        let result = preprocessor.process(&image).unwrap();
        assert_eq!(result.tensor.shape(), &[1, 3, 320, 320]);

        let array = result.tensor.to_array_view::<f32>().unwrap();
        let min = array.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = array.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        assert!(min >= 0.0);
        assert!(max <= 1.0);
        assert!((max - 1.0).abs() < 1e-6);
    }

    #[test]
    fn detection_tensor_dims_are_padded_to_multiple_of_32() {
        let image = solid_image(123, 77, 200);
        let preprocessor = DetPreProcessor::new(DetPreProcessorConfig::default());

        let result = preprocessor.process(&image).unwrap();

        assert_eq!(result.resized_dims, (128, 96));
        assert_eq!(result.tensor.shape(), &[1, 3, 96, 128]);
        assert!((result.scale_ratio - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn recognition_single_region_preprocessing() {
        let image = gradient_image(200, 100);
        let config = RecPreProcessorConfig::default();
        let regions = vec![RecTextRegion {
            x: 20,
            y: 10,
            width: 80,
            height: 40,
        }];

        let preprocessor = RecPreProcessor::new(config.clone());
        let batch = preprocessor.process(&image, &regions).unwrap();

        let expected_shape = [
            1,
            3,
            config.target_height as usize,
            config.max_width as usize,
        ];
        assert_eq!(batch.tensor.shape(), &expected_shape);
        // Aspect ratio 80:40 at a target height of 48 gives a width of 96.
        assert_eq!(batch.valid_widths, vec![96]);

        let tensor = batch.tensor.to_array_view::<f32>().unwrap();
        let pad = normalize_value(config.pad_value[0], config.mean[0], config.std[0]);
        assert!(
            (tensor[[0, 0, 0, (config.max_width - 1) as usize]] - pad).abs() < 1e-6,
            "padded area should remain at pad value"
        );
        assert!(
            (tensor[[0, 0, 0, 0]] - pad).abs() > 1e-3,
            "cropped content should differ from pad value"
        );

        let ratios = batch.valid_width_ratios();
        assert_eq!(ratios.len(), 1);
        assert!((ratios[0] - 96.0 / config.max_width as f32).abs() < f32::EPSILON);
    }

    #[test]
    fn recognition_multiple_regions_padding() {
        let image = gradient_image(320, 160);
        let config = RecPreProcessorConfig::default();
        let regions = vec![
            RecTextRegion {
                x: 0,
                y: 0,
                width: 120,
                height: 60,
            },
            RecTextRegion {
                x: 150,
                y: 40,
                width: 40,
                height: 80,
            },
        ];

        let preprocessor = RecPreProcessor::new(config.clone());
        let batch = preprocessor.process(&image, &regions).unwrap();

        assert_eq!(batch.valid_widths, vec![96, 24]);

        let tensor = batch.tensor.to_array_view::<f32>().unwrap();
        let pad = normalize_value(config.pad_value[0], config.mean[0], config.std[0]);

        // The last column of the first sample lies beyond its 96 valid columns.
        assert!((tensor[[0, 0, 10, (config.max_width - 1) as usize]] - pad).abs() < 1e-6);
        // The last column of the second sample lies beyond its 24 valid columns.
        assert!((tensor[[1, 1, 20, (config.max_width - 1) as usize]] - pad).abs() < 1e-6);
    }

    #[test]
    fn recognition_region_out_of_bounds_is_error() {
        let image = gradient_image(100, 50);
        let config = RecPreProcessorConfig::default();
        let regions = vec![RecTextRegion {
            x: 80,
            y: 10,
            width: 30,
            height: 20,
        }];

        let preprocessor = RecPreProcessor::new(config);
        let error = preprocessor.process(&image, &regions).unwrap_err();
        assert!(matches!(
            error,
            RecPreProcessorError::RegionOutOfBounds { index: 0, .. }
        ));
    }

    #[test]
    fn recognition_zero_area_region_is_error() {
        let image = gradient_image(100, 50);
        let config = RecPreProcessorConfig::default();
        let regions = vec![RecTextRegion {
            x: 10,
            y: 10,
            width: 0,
            height: 20,
        }];

        let preprocessor = RecPreProcessor::new(config);
        let error = preprocessor.process(&image, &regions).unwrap_err();
        assert!(matches!(error, RecPreProcessorError::ZeroArea { index: 0 }));
    }
}