//! pure_onnx_ocr/preprocessing.rs
//!
//! Image preprocessing for the detection (DBNet) and recognition (SVTR)
//! ONNX models: resizing, normalization, padding, and tensor packing.

1use image::{imageops::FilterType, DynamicImage, GenericImageView};
2use ndarray::{s, Array3, Array4, Axis};
3use tract_onnx::prelude::Tensor;
4
/// Configuration parameters for `DetPreProcessor`.
#[derive(Debug, Clone, Copy)]
pub struct DetPreProcessorConfig {
    /// Maximum length, in pixels, allowed for the input image's longest side
    /// before the image is downscaled. A value of `0` disables resizing.
    pub limit_side_len: u32,
}
10
11impl Default for DetPreProcessorConfig {
12    fn default() -> Self {
13        Self {
14            limit_side_len: 960,
15        }
16    }
17}
18
/// Error returned when detection preprocessing fails.
///
/// Implements `Display` and `std::error::Error` (see the impls below), so it
/// composes with `?` and boxed-error call sites.
#[derive(Debug)]
pub enum DetPreProcessorError {
    /// The provided image has zero width or height.
    EmptyImage,
}
25
26impl std::fmt::Display for DetPreProcessorError {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            DetPreProcessorError::EmptyImage => {
30                write!(f, "input image dimensions must be positive")
31            }
32        }
33    }
34}
35
// Marker impl: the `Debug` derive and `Display` impl above satisfy the
// `std::error::Error` supertrait requirements.
impl std::error::Error for DetPreProcessorError {}
37
/// Result of detection preprocessing.
#[derive(Debug, Clone)]
pub struct PreprocessedDetInput {
    /// NCHW `f32` tensor (shape `[1, 3, padded_h, padded_w]`) ready to feed
    /// to the detection model.
    pub tensor: Tensor,
    /// `(width, height)` of the tensor's spatial plane *after* padding each
    /// side up to a multiple of 32 — not the pre-padding resize size.
    pub resized_dims: (u32, u32),
    /// Uniform scale factor applied to the original image before padding
    /// (`1.0` when no resizing took place).
    pub scale_ratio: f64,
}
45
/// DBNet detection preprocessor.
///
/// Resizes, normalizes, and pads input images into the NCHW tensor layout
/// expected by the detection model; see [`DetPreProcessor::process`].
#[derive(Debug, Clone)]
pub struct DetPreProcessor {
    // Resize limits used by `process`.
    config: DetPreProcessorConfig,
}
51
52impl DetPreProcessor {
53    pub fn new(config: DetPreProcessorConfig) -> Self {
54        Self { config }
55    }
56
57    pub fn process(
58        &self,
59        image: &DynamicImage,
60    ) -> Result<PreprocessedDetInput, DetPreProcessorError> {
61        let (orig_w, orig_h) = image.dimensions();
62        if orig_w == 0 || orig_h == 0 {
63            return Err(DetPreProcessorError::EmptyImage);
64        }
65
66        let (resized_w, resized_h, scale_ratio) =
67            compute_resized_dims(orig_w, orig_h, self.config.limit_side_len);
68
69        let resized = if resized_w == orig_w && resized_h == orig_h {
70            image.clone()
71        } else {
72            image.resize_exact(resized_w, resized_h, FilterType::Lanczos3)
73        };
74
75        let rgb_image = resized.to_rgb8();
76        let padded_w = round_up_to_multiple(resized_w, 32);
77        let padded_h = round_up_to_multiple(resized_h, 32);
78
79        let mut array_hwc = Array3::<f32>::zeros((padded_h as usize, padded_w as usize, 3));
80
81        for y in 0..resized_h as usize {
82            for x in 0..resized_w as usize {
83                let pixel = rgb_image.get_pixel(x as u32, y as u32);
84                for c in 0..3 {
85                    array_hwc[[y, x, c]] = pixel[c] as f32 / 255.0;
86                }
87            }
88        }
89
90        let array_chw = array_hwc.permuted_axes([2, 0, 1]);
91        let array_nchw = array_chw.insert_axis(Axis(0));
92        let tensor: Tensor = array_nchw.into_dyn().into();
93
94        Ok(PreprocessedDetInput {
95            tensor,
96            resized_dims: (padded_w, padded_h),
97            scale_ratio,
98        })
99    }
100}
101
/// Computes the output dimensions for detection resizing.
///
/// Returns `(width, height, scale_ratio)`. The original size (ratio `1.0`)
/// is kept when `limit_side_len` is zero or the longest side already fits
/// within the limit; otherwise both sides are scaled uniformly so the
/// longest side equals `limit_side_len`, with rounding and a 1 px floor.
fn compute_resized_dims(orig_w: u32, orig_h: u32, limit_side_len: u32) -> (u32, u32, f64) {
    let longest = orig_w.max(orig_h);
    if limit_side_len == 0 || longest <= limit_side_len {
        return (orig_w, orig_h, 1.0);
    }

    let ratio = limit_side_len as f64 / longest as f64;
    let scale = |side: u32| ((side as f64 * ratio).round().max(1.0)) as u32;

    (scale(orig_w), scale(orig_h), ratio)
}
119
/// Rounds `value` up to the nearest multiple of `multiple`.
///
/// A `multiple` of zero leaves `value` unchanged (there is no sensible
/// rounding target, and `%` by zero would panic).
fn round_up_to_multiple(value: u32, multiple: u32) -> u32 {
    if multiple == 0 {
        return value;
    }

    match value % multiple {
        0 => value,
        rem => value + (multiple - rem),
    }
}
132
/// Rectangle specifying the area to crop for recognition preprocessing.
///
/// Coordinates are in pixels; `(x, y)` is the corner passed to the crop call
/// and the region must lie fully inside the source image.
#[derive(Debug, Clone, Copy)]
pub struct RecTextRegion {
    /// Horizontal offset of the region within the image, in pixels.
    pub x: u32,
    /// Vertical offset of the region within the image, in pixels.
    pub y: u32,
    /// Region width in pixels; must be non-zero.
    pub width: u32,
    /// Region height in pixels; must be non-zero.
    pub height: u32,
}
141
/// Configuration parameters for recognition preprocessing.
#[derive(Debug, Clone)]
pub struct RecPreProcessorConfig {
    /// Height, in pixels, every cropped text strip is resized to; must be
    /// non-zero.
    pub target_height: u32,
    /// Width, in pixels, of the batch tensor. Strips whose aspect-preserving
    /// width would exceed this are clamped; narrower strips are padded on
    /// the right. Must be non-zero.
    pub max_width: u32,
    /// Per-channel (RGB) mean subtracted during normalization.
    pub mean: [f32; 3],
    /// Per-channel (RGB) standard deviation divided out during
    /// normalization; a zero entry makes that channel normalize to `0.0`.
    pub std: [f32; 3],
    /// Per-channel (RGB) raw pixel value, on the `[0, 1]` scale, used to
    /// fill padded area; it is normalized with `mean`/`std` before use.
    pub pad_value: [f32; 3],
}
151
152impl Default for RecPreProcessorConfig {
153    fn default() -> Self {
154        Self {
155            target_height: 48,
156            max_width: 320,
157            mean: [0.5, 0.5, 0.5],
158            std: [0.5, 0.5, 0.5],
159            pad_value: [0.0, 0.0, 0.0],
160        }
161    }
162}
163
/// Errors that can be produced by recognition preprocessing.
#[derive(Debug)]
pub enum RecPreProcessorError {
    /// The provided batch of regions is empty.
    EmptyRegions,
    /// The input image has zero width or height.
    EmptyImage,
    /// The configuration contains an invalid parameter (e.g. zero
    /// height/width).
    InvalidConfiguration,
    /// A region had zero width or height.
    ZeroArea {
        /// Index of the offending region within the input slice.
        index: usize,
    },
    /// A region extended beyond the bounds of the image.
    RegionOutOfBounds {
        /// Index of the offending region within the input slice.
        index: usize,
        /// `(width, height)` of the source image.
        image_dims: (u32, u32),
        /// The region that failed the bounds check.
        region: RecTextRegion,
    },
}
182
183impl std::fmt::Display for RecPreProcessorError {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        match self {
186            RecPreProcessorError::EmptyRegions => {
187                write!(f, "at least one text region is required for recognition")
188            }
189            RecPreProcessorError::EmptyImage => {
190                write!(f, "input image dimensions must be positive")
191            }
192            RecPreProcessorError::InvalidConfiguration => {
193                write!(f, "recognition preprocessor configuration is invalid")
194            }
195            RecPreProcessorError::ZeroArea { index } => {
196                write!(f, "text region at index {} has zero area", index)
197            }
198            RecPreProcessorError::RegionOutOfBounds {
199                index,
200                image_dims,
201                region,
202            } => write!(
203                f,
204                "text region at index {} (x={}, y={}, w={}, h={}) exceeds image bounds {:?}",
205                index, region.x, region.y, region.width, region.height, image_dims
206            ),
207        }
208    }
209}
210
// Marker impl: the `Debug` derive and `Display` impl above satisfy the
// `std::error::Error` supertrait requirements.
impl std::error::Error for RecPreProcessorError {}
212
/// Result of recognition preprocessing.
#[derive(Debug, Clone)]
pub struct PreprocessedRecBatch {
    /// `f32` tensor of shape `(batch, 3, target_height, max_width)` holding
    /// the normalized, right-padded text strips.
    pub tensor: Tensor,
    /// Actual (unpadded) width of each sample, in tensor columns; one entry
    /// per input region, in input order.
    pub valid_widths: Vec<u32>,
    /// Width of the batch tensor (the configured `max_width`).
    pub max_width: u32,
}
220
221impl PreprocessedRecBatch {
222    pub fn valid_width_ratios(&self) -> Vec<f32> {
223        if self.max_width == 0 {
224            return vec![0.0; self.valid_widths.len()];
225        }
226        self.valid_widths
227            .iter()
228            .map(|width| *width as f32 / self.max_width as f32)
229            .collect()
230    }
231}
232
/// SVTR recognition preprocessor.
///
/// Crops, resizes, and normalizes text regions into a single batch tensor;
/// see [`RecPreProcessor::process`].
#[derive(Debug, Clone)]
pub struct RecPreProcessor {
    // Strip geometry and normalization constants used by `process`.
    config: RecPreProcessorConfig,
}
238
239impl RecPreProcessor {
240    pub fn new(config: RecPreProcessorConfig) -> Self {
241        Self { config }
242    }
243
244    pub fn process(
245        &self,
246        image: &DynamicImage,
247        regions: &[RecTextRegion],
248    ) -> Result<PreprocessedRecBatch, RecPreProcessorError> {
249        if regions.is_empty() {
250            return Err(RecPreProcessorError::EmptyRegions);
251        }
252
253        if self.config.target_height == 0 || self.config.max_width == 0 {
254            return Err(RecPreProcessorError::InvalidConfiguration);
255        }
256
257        let (img_w, img_h) = image.dimensions();
258        if img_w == 0 || img_h == 0 {
259            return Err(RecPreProcessorError::EmptyImage);
260        }
261
262        let target_height = self.config.target_height;
263        let max_width = self.config.max_width;
264        let batch_size = regions.len();
265
266        let mut batch =
267            Array4::<f32>::zeros((batch_size, 3, target_height as usize, max_width as usize));
268
269        for sample in 0..batch_size {
270            for channel in 0..3 {
271                let pad = normalize_value(
272                    self.config.pad_value[channel],
273                    self.config.mean[channel],
274                    self.config.std[channel],
275                );
276                batch.slice_mut(s![sample, channel, .., ..]).fill(pad);
277            }
278        }
279
280        let mut valid_widths = Vec::with_capacity(batch_size);
281
282        for (index, region) in regions.iter().copied().enumerate() {
283            if region.width == 0 || region.height == 0 {
284                return Err(RecPreProcessorError::ZeroArea { index });
285            }
286
287            if region.x >= img_w
288                || region.y >= img_h
289                || region.x + region.width > img_w
290                || region.y + region.height > img_h
291            {
292                return Err(RecPreProcessorError::RegionOutOfBounds {
293                    index,
294                    image_dims: (img_w, img_h),
295                    region,
296                });
297            }
298
299            let cropped = image.crop_imm(region.x, region.y, region.width, region.height);
300            let aspect_ratio = region.width as f32 / region.height as f32;
301            let mut target_width = (aspect_ratio * target_height as f32)
302                .round()
303                .clamp(1.0, max_width as f32) as u32;
304            if target_width == 0 {
305                target_width = 1;
306            }
307
308            let resized = cropped.resize_exact(target_width, target_height, FilterType::Lanczos3);
309            let rgb_image = resized.to_rgb8();
310
311            for y in 0..target_height as usize {
312                for x in 0..target_width as usize {
313                    let pixel = rgb_image.get_pixel(x as u32, y as u32);
314                    for channel in 0..3 {
315                        let value = pixel[channel] as f32 / 255.0;
316                        let normalized = normalize_value(
317                            value,
318                            self.config.mean[channel],
319                            self.config.std[channel],
320                        );
321                        batch[[index, channel, y, x]] = normalized;
322                    }
323                }
324            }
325
326            valid_widths.push(target_width);
327        }
328
329        let tensor: Tensor = batch.into_dyn().into();
330        Ok(PreprocessedRecBatch {
331            tensor,
332            valid_widths,
333            max_width,
334        })
335    }
336}
337
/// Applies the standard `(value - mean) / std` normalization.
///
/// Returns `0.0` when `std` is zero so a degenerate configuration cannot
/// produce infinities or NaNs.
fn normalize_value(value: f32, mean: f32, std: f32) -> f32 {
    match std {
        s if s == 0.0 => 0.0,
        s => (value - mean) / s,
    }
}
345
#[cfg(test)]
mod tests {
    use super::*;
    use image::{ImageBuffer, Rgb};

    /// Builds a `width`×`height` image where every pixel carries
    /// `(value, value-1, value+1)`, saturating at the `u8` bounds.
    fn solid_image(width: u32, height: u32, value: u8) -> DynamicImage {
        let pixel = Rgb([value, value.saturating_sub(1), value.saturating_add(1)]);
        let buffer = ImageBuffer::from_pixel(width, height, pixel);
        DynamicImage::ImageRgb8(buffer)
    }

    /// Builds a diagonal RGB gradient so crops taken from different places
    /// differ, letting tests distinguish real content from padding.
    fn gradient_image(width: u32, height: u32) -> DynamicImage {
        let mut buffer = ImageBuffer::new(width, height);
        for (x, y, pixel) in buffer.enumerate_pixels_mut() {
            let base = ((x + y) % 256) as u8;
            let green = base.saturating_add(32);
            let blue = base.saturating_add(64);
            *pixel = Rgb([base, green, blue]);
        }
        DynamicImage::ImageRgb8(buffer)
    }

    #[test]
    fn resize_long_side_to_limit() {
        let image = solid_image(1920, 1080, 128);
        let preprocessor = DetPreProcessor::new(DetPreProcessorConfig::default());

        let result = preprocessor.process(&image).unwrap();

        // 1920 -> 960 halves both sides; 540 is then padded up to 544.
        assert_eq!(result.resized_dims, (960, 544));
        assert!((result.scale_ratio - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn keep_original_size_when_within_limit() {
        let image = solid_image(800, 600, 64);
        let preprocessor = DetPreProcessor::new(DetPreProcessorConfig::default());

        let result = preprocessor.process(&image).unwrap();

        // No resize (800 <= 960), but 600 is padded up to 608.
        assert_eq!(result.resized_dims, (800, 608));
        assert!((result.scale_ratio - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn tensor_shape_and_normalization() {
        // All-255 red channel means the max normalized value must be 1.0.
        let image = solid_image(320, 320, 255);
        let preprocessor = DetPreProcessor::new(DetPreProcessorConfig {
            limit_side_len: 320,
        });

        let result = preprocessor.process(&image).unwrap();
        assert_eq!(result.tensor.shape(), &[1, 3, 320, 320]);

        let array = result.tensor.to_array_view::<f32>().unwrap();
        let min = array.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = array.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        assert!(min >= 0.0);
        assert!(max <= 1.0);
        assert!((max - 1.0).abs() < 1e-6);
    }

    #[test]
    fn detection_tensor_dims_are_padded_to_multiple_of_32() {
        let image = solid_image(123, 77, 200);
        let preprocessor = DetPreProcessor::new(DetPreProcessorConfig::default());

        let result = preprocessor.process(&image).unwrap();

        // 123 -> 128, 77 -> 96; the original size fits the limit so ratio is 1.
        assert_eq!(result.resized_dims, (128, 96));
        assert_eq!(result.tensor.shape(), &[1, 3, 96, 128]);
        assert!((result.scale_ratio - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn recognition_single_region_preprocessing() {
        let image = gradient_image(200, 100);
        let config = RecPreProcessorConfig::default();
        // 80x40 crop at height 48 -> aspect 2.0 -> valid width 96.
        let regions = vec![RecTextRegion {
            x: 20,
            y: 10,
            width: 80,
            height: 40,
        }];

        let preprocessor = RecPreProcessor::new(config.clone());
        let batch = preprocessor.process(&image, &regions).unwrap();

        let expected_shape = [
            1,
            3,
            config.target_height as usize,
            config.max_width as usize,
        ];
        assert_eq!(batch.tensor.shape(), &expected_shape);
        assert_eq!(batch.valid_widths, vec![96]);

        let tensor = batch.tensor.to_array_view::<f32>().unwrap();
        let pad = normalize_value(config.pad_value[0], config.mean[0], config.std[0]);
        assert!(
            (tensor[[0, 0, 0, (config.max_width - 1) as usize]] - pad).abs() < 1e-6,
            "padded area should remain at pad value"
        );
        assert!(
            (tensor[[0, 0, 0, 0]] - pad).abs() > 1e-3,
            "cropped content should differ from pad value"
        );

        let ratios = batch.valid_width_ratios();
        assert_eq!(ratios.len(), 1);
        assert!((ratios[0] - 96.0 / config.max_width as f32).abs() < f32::EPSILON);
    }

    #[test]
    fn recognition_multiple_regions_padding() {
        let image = gradient_image(320, 160);
        let config = RecPreProcessorConfig::default();
        // Aspect ratios 2.0 and 0.5 -> valid widths 96 and 24 at height 48.
        let regions = vec![
            RecTextRegion {
                x: 0,
                y: 0,
                width: 120,
                height: 60,
            },
            RecTextRegion {
                x: 150,
                y: 40,
                width: 40,
                height: 80,
            },
        ];

        let preprocessor = RecPreProcessor::new(config.clone());
        let batch = preprocessor.process(&image, &regions).unwrap();

        assert_eq!(batch.valid_widths, vec![96, 24]);

        let tensor = batch.tensor.to_array_view::<f32>().unwrap();
        let pad = normalize_value(config.pad_value[0], config.mean[0], config.std[0]);

        // Ensure padding column for first sample is untouched.
        assert!((tensor[[0, 0, 10, (config.max_width - 1) as usize]] - pad).abs() < 1e-6);
        // Ensure padding column for second sample is untouched.
        assert!((tensor[[1, 1, 20, (config.max_width - 1) as usize]] - pad).abs() < 1e-6);
    }

    #[test]
    fn recognition_region_out_of_bounds_is_error() {
        let image = gradient_image(100, 50);
        let config = RecPreProcessorConfig::default();
        // x + width = 110 exceeds the 100 px image width.
        let regions = vec![RecTextRegion {
            x: 80,
            y: 10,
            width: 30,
            height: 20,
        }];

        let preprocessor = RecPreProcessor::new(config);
        let error = preprocessor.process(&image, &regions).unwrap_err();
        assert!(matches!(
            error,
            RecPreProcessorError::RegionOutOfBounds { index: 0, .. }
        ));
    }

    #[test]
    fn recognition_zero_area_region_is_error() {
        let image = gradient_image(100, 50);
        let config = RecPreProcessorConfig::default();
        // Zero width must be rejected before any cropping happens.
        let regions = vec![RecTextRegion {
            x: 10,
            y: 10,
            width: 0,
            height: 20,
        }];

        let preprocessor = RecPreProcessor::new(config);
        let error = preprocessor.process(&image, &regions).unwrap_err();
        assert!(matches!(error, RecPreProcessorError::ZeroArea { index: 0 }));
    }
}