ocr_rs/
preprocess.rs

1//! Image Preprocessing Utilities
2//!
3//! Provides various image preprocessing functions required for OCR
4
5use image::{DynamicImage, GenericImageView, RgbImage};
6use ndarray::{Array4, ArrayBase, Dim, OwnedRepr};
7
/// Per-channel normalization constants for RGB images.
///
/// The preprocessing functions in this module scale each 8-bit channel to
/// `[0, 1]` and then apply `(value - mean) / std` using these constants.
#[derive(Debug, Clone)]
pub struct NormalizeParams {
    /// Mean subtracted from each channel, in R, G, B order
    pub mean: [f32; 3],
    /// Standard deviation each channel is divided by, in R, G, B order
    pub std: [f32; 3],
}

impl Default for NormalizeParams {
    /// Returns the ImageNet normalization parameters.
    fn default() -> Self {
        NormalizeParams {
            mean: [0.485, 0.456, 0.406],
            std: [0.229, 0.224, 0.225],
        }
    }
}

impl NormalizeParams {
    /// Parameters used by the PaddleOCR detection model.
    ///
    /// These are the same values as [`NormalizeParams::default`].
    pub fn paddle_det() -> Self {
        Self::default()
    }

    /// Parameters used by the PaddleOCR recognition model.
    ///
    /// Every channel uses a mean and standard deviation of `0.5`.
    pub fn paddle_rec() -> Self {
        NormalizeParams {
            mean: [0.5; 3],
            std: [0.5; 3],
        }
    }
}
44
/// Round `size` up to the nearest multiple of 32.
///
/// Values that are already a multiple of 32 (including 0) are returned unchanged.
#[inline]
pub fn get_padded_size(size: u32) -> u32 {
    const ALIGN_MASK: u32 = 32 - 1;
    // Adding 31 and then clearing the five low bits rounds up to a multiple of 32.
    (size + ALIGN_MASK) & !ALIGN_MASK
}
50
51/// Scale image to specified maximum side length
52///
53/// Maintains aspect ratio, scales longest side to max_side_len
54pub fn resize_to_max_side(img: &DynamicImage, max_side_len: u32) -> DynamicImage {
55    let (w, h) = img.dimensions();
56    let max_dim = w.max(h);
57
58    if max_dim <= max_side_len {
59        return img.clone();
60    }
61
62    let scale = max_side_len as f64 / max_dim as f64;
63    let new_w = (w as f64 * scale).round() as u32;
64    let new_h = (h as f64 * scale).round() as u32;
65
66    fast_resize(img, new_w, new_h)
67}
68
69/// Scale image to specified height (for recognition model)
70///
71/// Scales maintaining aspect ratio
72pub fn resize_to_height(img: &DynamicImage, target_height: u32) -> DynamicImage {
73    let (w, h) = img.dimensions();
74
75    if h == target_height {
76        return img.clone();
77    }
78
79    let scale = target_height as f64 / h as f64;
80    let new_w = (w as f64 * scale).round() as u32;
81
82    fast_resize(img, new_w, target_height)
83}
84
85/// Fast image resizing using fast_image_resize
86/// Can pass DynamicImage directly when "image" feature is enabled
87fn fast_resize(img: &DynamicImage, new_w: u32, new_h: u32) -> DynamicImage {
88    use fast_image_resize::{images::Image, IntoImageView, PixelType, Resizer};
89
90    // Get source image pixel type
91    let pixel_type = img.pixel_type().unwrap_or(PixelType::U8x3);
92
93    // Create destination image container
94    let mut dst_image = Image::new(new_w, new_h, pixel_type);
95
96    // Resize using Resizer (pass DynamicImage directly, no manual conversion needed)
97    let mut resizer = Resizer::new();
98    resizer.resize(img, &mut dst_image, None).unwrap();
99
100    // Convert result back to DynamicImage
101    match pixel_type {
102        PixelType::U8x3 => {
103            DynamicImage::ImageRgb8(RgbImage::from_raw(new_w, new_h, dst_image.into_vec()).unwrap())
104        }
105        PixelType::U8x4 => DynamicImage::ImageRgba8(
106            image::RgbaImage::from_raw(new_w, new_h, dst_image.into_vec()).unwrap(),
107        ),
108        _ => {
109            // Convert other types to RGB
110            DynamicImage::ImageRgb8(RgbImage::from_raw(new_w, new_h, dst_image.into_vec()).unwrap())
111        }
112    }
113}
114
115/// Convert image to detection model input tensor
116///
117/// Output format: [1, 3, H, W] (NCHW)
118pub fn preprocess_for_det(
119    img: &DynamicImage,
120    params: &NormalizeParams,
121) -> ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>> {
122    let (w, h) = img.dimensions();
123    let pad_w = get_padded_size(w) as usize;
124    let pad_h = get_padded_size(h) as usize;
125
126    let mut input = Array4::<f32>::zeros((1, 3, pad_h, pad_w));
127    let rgb_img = img.to_rgb8();
128
129    // Normalize and pad
130    for y in 0..h as usize {
131        for x in 0..w as usize {
132            let pixel = rgb_img.get_pixel(x as u32, y as u32);
133            let [r, g, b] = pixel.0;
134
135            input[[0, 0, y, x]] = (r as f32 / 255.0 - params.mean[0]) / params.std[0];
136            input[[0, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
137            input[[0, 2, y, x]] = (b as f32 / 255.0 - params.mean[2]) / params.std[2];
138        }
139    }
140
141    input
142}
143
144/// Convert image to recognition model input tensor
145///
146/// Output format: [1, 3, H, W] (NCHW)
147/// Height is fixed at 48 (or specified value), width scaled proportionally
148pub fn preprocess_for_rec(
149    img: &DynamicImage,
150    target_height: u32,
151    params: &NormalizeParams,
152) -> ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>> {
153    let (w, h) = img.dimensions();
154
155    // Calculate scaled width
156    let scale = target_height as f64 / h as f64;
157    let target_width = (w as f64 * scale).round() as u32;
158
159    // Scale image
160    let resized = if h != target_height {
161        img.resize_exact(
162            target_width,
163            target_height,
164            image::imageops::FilterType::Lanczos3,
165        )
166    } else {
167        img.clone()
168    };
169
170    let rgb_img = resized.to_rgb8();
171    let (w, h) = (target_width as usize, target_height as usize);
172
173    let mut input = Array4::<f32>::zeros((1, 3, h, w));
174
175    for y in 0..h {
176        for x in 0..w {
177            let pixel = rgb_img.get_pixel(x as u32, y as u32);
178            let [r, g, b] = pixel.0;
179
180            input[[0, 0, y, x]] = (r as f32 / 255.0 - params.mean[0]) / params.std[0];
181            input[[0, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
182            input[[0, 2, y, x]] = (b as f32 / 255.0 - params.mean[2]) / params.std[2];
183        }
184    }
185
186    input
187}
188
189/// Batch preprocess recognition images
190///
191/// Process multiple images into batch tensor, all images padded to same width
192pub fn preprocess_batch_for_rec(
193    images: &[DynamicImage],
194    target_height: u32,
195    params: &NormalizeParams,
196) -> ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>> {
197    if images.is_empty() {
198        return Array4::<f32>::zeros((0, 3, target_height as usize, 0));
199    }
200
201    // Calculate scaled width for all images
202    let widths: Vec<u32> = images
203        .iter()
204        .map(|img| {
205            let (w, h) = img.dimensions();
206            let scale = target_height as f64 / h as f64;
207            (w as f64 * scale).round() as u32
208        })
209        .collect();
210
211    let max_width = *widths.iter().max().unwrap() as usize;
212    let batch_size = images.len();
213
214    let mut batch = Array4::<f32>::zeros((batch_size, 3, target_height as usize, max_width));
215
216    for (i, (img, &w)) in images.iter().zip(widths.iter()).enumerate() {
217        let resized = resize_to_height(img, target_height);
218        let rgb_img = resized.to_rgb8();
219
220        for y in 0..target_height as usize {
221            for x in 0..w as usize {
222                let pixel = rgb_img.get_pixel(x as u32, y as u32);
223                let [r, g, b] = pixel.0;
224
225                batch[[i, 0, y, x]] = (r as f32 / 255.0 - params.mean[0]) / params.std[0];
226                batch[[i, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
227                batch[[i, 2, y, x]] = (b as f32 / 255.0 - params.mean[2]) / params.std[2];
228            }
229        }
230    }
231
232    batch
233}
234
235/// Crop image region
236pub fn crop_image(img: &DynamicImage, x: u32, y: u32, width: u32, height: u32) -> DynamicImage {
237    img.crop_imm(x, y, width, height)
238}
239
240/// Split image into blocks (for high precision mode)
241///
242/// # Parameters
243/// - `img`: Input image
244/// - `block_size`: Block size
245/// - `overlap`: Overlap region size
246///
247/// # Returns
248/// List of block images and their positions in original image (x, y)
249pub fn split_into_blocks(
250    img: &DynamicImage,
251    block_size: u32,
252    overlap: u32,
253) -> Vec<(DynamicImage, u32, u32)> {
254    let (width, height) = img.dimensions();
255    let mut blocks = Vec::new();
256
257    let step = block_size - overlap;
258
259    let mut y = 0u32;
260    while y < height {
261        let mut x = 0u32;
262        while x < width {
263            let block_w = (block_size).min(width - x);
264            let block_h = (block_size).min(height - y);
265
266            let block = img.crop_imm(x, y, block_w, block_h);
267            blocks.push((block, x, y));
268
269            x += step;
270            if x + overlap >= width && x < width {
271                break;
272            }
273        }
274
275        y += step;
276        if y + overlap >= height && y < height {
277            break;
278        }
279    }
280
281    blocks
282}
283
/// Binarize a floating-point mask.
///
/// Each value strictly greater than `threshold` becomes `255`; every other
/// value (including values equal to the threshold) becomes `0`. The output
/// has the same length and order as `mask`.
pub fn threshold_mask(mask: &[f32], threshold: f32) -> Vec<u8> {
    let mut binary = Vec::with_capacity(mask.len());
    for &value in mask {
        let level: u8 = if value > threshold { 255 } else { 0 };
        binary.push(level);
    }
    binary
}
290
291/// Create grayscale image
292pub fn create_gray_image(data: &[u8], width: u32, height: u32) -> image::GrayImage {
293    image::GrayImage::from_raw(width, height, data.to_vec())
294        .unwrap_or_else(|| image::GrayImage::new(width, height))
295}
296
297/// Convert image to RGB
298pub fn to_rgb(img: &DynamicImage) -> RgbImage {
299    img.to_rgb8()
300}
301
302/// Create image from RGB data
303pub fn rgb_to_image(data: &[u8], width: u32, height: u32) -> DynamicImage {
304    let rgb = RgbImage::from_raw(width, height, data.to_vec())
305        .unwrap_or_else(|| RgbImage::new(width, height));
306    DynamicImage::ImageRgb8(rgb)
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Blank RGB8 image of the given size, used as test input.
    fn blank(width: u32, height: u32) -> DynamicImage {
        DynamicImage::new_rgb8(width, height)
    }

    #[test]
    fn test_padded_size() {
        let cases = [
            (100, 128),
            (32, 32),
            (33, 64),
            (0, 0),
            (1, 32),
            (31, 32),
            (64, 64),
            (65, 96),
        ];
        for &(size, expected) in cases.iter() {
            assert_eq!(get_padded_size(size), expected);
        }
    }

    #[test]
    fn test_normalize_params() {
        assert_eq!(NormalizeParams::default().mean[0], 0.485);

        let det = NormalizeParams::paddle_det();
        assert_eq!(det.mean[0], 0.485);
        assert_eq!(det.std[0], 0.229);
    }

    #[test]
    fn test_normalize_params_paddle_rec() {
        let rec = NormalizeParams::paddle_rec();
        for channel in 0..3 {
            assert_eq!(rec.mean[channel], 0.5);
            assert_eq!(rec.std[channel], 0.5);
        }
    }

    #[test]
    fn test_resize_to_max_side_no_resize() {
        let out = resize_to_max_side(&blank(100, 50), 200);

        // The image is already smaller than the maximum side, so it must not be scaled
        assert_eq!(out.width(), 100);
        assert_eq!(out.height(), 50);
    }

    #[test]
    fn test_resize_to_max_side_width_limited() {
        let out = resize_to_max_side(&blank(1000, 500), 500);

        // Width is the longest side and must be scaled to 500
        assert_eq!(out.width(), 500);
        assert_eq!(out.height(), 250);
    }

    #[test]
    fn test_resize_to_max_side_height_limited() {
        let out = resize_to_max_side(&blank(500, 1000), 500);

        // Height is the longest side and must be scaled to 500
        assert_eq!(out.width(), 250);
        assert_eq!(out.height(), 500);
    }

    #[test]
    fn test_resize_to_height() {
        let out = resize_to_height(&blank(200, 100), 48);

        assert_eq!(out.height(), 48);
        // Width is scaled proportionally: 200 * 48/100 = 96
        assert_eq!(out.width(), 96);
    }

    #[test]
    fn test_resize_to_height_no_resize() {
        let out = resize_to_height(&blank(200, 48), 48);

        // Height already equals the target, so the image must not be scaled
        assert_eq!(out.height(), 48);
        assert_eq!(out.width(), 200);
    }

    #[test]
    fn test_preprocess_for_det_shape() {
        let tensor = preprocess_for_det(&blank(100, 50), &NormalizeParams::paddle_det());

        // Output shape is [1, 3, H, W] with H and W multiples of 32:
        // 50 rounds up to 64 and 100 rounds up to 128
        assert_eq!(tensor.shape(), &[1, 3, 64, 128]);
    }

    #[test]
    fn test_preprocess_for_rec_shape() {
        let tensor = preprocess_for_rec(&blank(200, 100), 48, &NormalizeParams::paddle_rec());

        // Output height is 48; width is scaled proportionally: 200 * 48/100 = 96
        assert_eq!(tensor.shape(), &[1, 3, 48, 96]);
    }

    #[test]
    fn test_preprocess_batch_for_rec_empty() {
        let no_images: Vec<DynamicImage> = Vec::new();
        let tensor = preprocess_batch_for_rec(&no_images, 48, &NormalizeParams::paddle_rec());

        assert_eq!(tensor.shape()[0], 0);
    }

    #[test]
    fn test_preprocess_batch_for_rec_single() {
        let batch = vec![blank(200, 100)];
        let tensor = preprocess_batch_for_rec(&batch, 48, &NormalizeParams::paddle_rec());

        assert_eq!(&tensor.shape()[..3], &[1, 3, 48]);
    }

    #[test]
    fn test_preprocess_batch_for_rec_multiple() {
        let batch = vec![blank(200, 100), blank(300, 100)];
        let tensor = preprocess_batch_for_rec(&batch, 48, &NormalizeParams::paddle_rec());

        // Width is the largest scaled width in the batch: max(96, 144) = 144
        assert_eq!(tensor.shape(), &[2, 3, 48, 144]);
    }

    #[test]
    fn test_crop_image() {
        let region = crop_image(&blank(200, 100), 50, 25, 100, 50);

        assert_eq!(region.width(), 100);
        assert_eq!(region.height(), 50);
    }

    #[test]
    fn test_split_into_blocks() {
        let tiles = split_into_blocks(&blank(500, 500), 200, 50);

        // There must be at least one block
        assert!(!tiles.is_empty());

        // Every block must fit the block size and start inside the image
        for (tile, origin_x, origin_y) in tiles.iter() {
            assert!(tile.width() <= 200);
            assert!(tile.height() <= 200);
            assert!(*origin_x < 500);
            assert!(*origin_y < 500);
        }
    }

    #[test]
    fn test_split_into_blocks_small_image() {
        let tiles = split_into_blocks(&blank(100, 100), 200, 50);

        // The image is smaller than the block size, so there is exactly one block
        assert_eq!(tiles.len(), 1);
        let (_, origin_x, origin_y) = &tiles[0];
        assert_eq!(*origin_x, 0); // x offset
        assert_eq!(*origin_y, 0); // y offset
    }

    #[test]
    fn test_threshold_mask() {
        let binary = threshold_mask(&[0.1, 0.3, 0.5, 0.7, 0.9], 0.5);

        assert_eq!(binary, vec![0, 0, 0, 255, 255]);
    }

    #[test]
    fn test_threshold_mask_all_below() {
        let binary = threshold_mask(&[0.1, 0.2, 0.3, 0.4], 0.5);

        assert_eq!(binary, vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_threshold_mask_all_above() {
        let binary = threshold_mask(&[0.6, 0.7, 0.8, 0.9], 0.5);

        assert_eq!(binary, vec![255, 255, 255, 255]);
    }

    #[test]
    fn test_create_gray_image() {
        let samples = vec![128u8; 100];
        let gray = create_gray_image(&samples, 10, 10);

        assert_eq!(gray.width(), 10);
        assert_eq!(gray.height(), 10);
    }

    #[test]
    fn test_to_rgb() {
        let rgb = to_rgb(&blank(100, 50));

        assert_eq!(rgb.width(), 100);
        assert_eq!(rgb.height(), 50);
    }

    #[test]
    fn test_rgb_to_image() {
        let samples = vec![128u8; 300]; // 10x10 RGB
        let img = rgb_to_image(&samples, 10, 10);

        assert_eq!(img.width(), 10);
        assert_eq!(img.height(), 10);
    }
}