Skip to main content

ocr_rs/
preprocess.rs

1//! Image Preprocessing Utilities
2//!
3//! Provides various image preprocessing functions required for OCR
4
5use image::{DynamicImage, GenericImageView, RgbImage};
6use ndarray::{Array4, ArrayBase, Dim, OwnedRepr};
7
8use crate::error::{OcrError, OcrResult};
9
/// Per-channel normalization statistics applied to RGB pixels before inference
///
/// The preprocessing functions in this module map every channel value `v` of an
/// 8-bit pixel to `(v / 255 - mean[c]) / std[c]`, with `c` indexing R, G, B.
#[derive(Debug, Clone)]
pub struct NormalizeParams {
    /// Channel means, in R, G, B order
    pub mean: [f32; 3],
    /// Channel standard deviations, in R, G, B order
    pub std: [f32; 3],
}

impl NormalizeParams {
    /// ImageNet channel means (shared by `Default` and `paddle_det`)
    const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
    /// ImageNet channel standard deviations (shared by `Default` and `paddle_det`)
    const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];

    /// Statistics used for the PaddleOCR detection model (ImageNet values)
    pub fn paddle_det() -> Self {
        Self {
            mean: Self::IMAGENET_MEAN,
            std: Self::IMAGENET_STD,
        }
    }

    /// Statistics used for the PaddleOCR recognition model (0.5 for every channel)
    pub fn paddle_rec() -> Self {
        Self {
            mean: [0.5; 3],
            std: [0.5; 3],
        }
    }
}

impl Default for NormalizeParams {
    /// Defaults to the ImageNet statistics, identical to [`NormalizeParams::paddle_det`]
    fn default() -> Self {
        Self::paddle_det()
    }
}
46
/// Round `size` up to the nearest multiple of 32
///
/// Zero stays zero; values that are already multiples of 32 are returned unchanged.
#[inline]
pub fn get_padded_size(size: u32) -> u32 {
    size.div_ceil(32) * 32
}
52
53/// Scale image to specified maximum side length
54///
55/// Maintains aspect ratio, scales longest side to max_side_len
56pub fn resize_to_max_side(img: &DynamicImage, max_side_len: u32) -> OcrResult<DynamicImage> {
57    let (w, h) = img.dimensions();
58    let max_dim = w.max(h);
59
60    if max_dim <= max_side_len {
61        return Ok(img.clone());
62    }
63
64    let scale = max_side_len as f64 / max_dim as f64;
65    let new_w = (w as f64 * scale).round() as u32;
66    let new_h = (h as f64 * scale).round() as u32;
67
68    fast_resize(img, new_w, new_h)
69}
70
71/// Scale image to specified height (for recognition model)
72///
73/// Scales maintaining aspect ratio
74pub fn resize_to_height(img: &DynamicImage, target_height: u32) -> OcrResult<DynamicImage> {
75    let (w, h) = img.dimensions();
76
77    if h == target_height {
78        return Ok(img.clone());
79    }
80
81    let scale = target_height as f64 / h as f64;
82    let new_w = (w as f64 * scale).round() as u32;
83
84    fast_resize(img, new_w, target_height)
85}
86
87/// Fast image resizing using fast_image_resize
88/// Can pass DynamicImage directly when "image" feature is enabled
89fn fast_resize(img: &DynamicImage, new_w: u32, new_h: u32) -> OcrResult<DynamicImage> {
90    use fast_image_resize::{images::Image, IntoImageView, PixelType, Resizer};
91
92    // Only U8x3 (RGB) and U8x4 (RGBA) are handled end-to-end.
93    // Grayscale (U8), 16-bit, and other formats must be converted to RGB first;
94    // otherwise the output buffer byte count would not match the expected channel count.
95    let converted: DynamicImage;
96    let (src, pixel_type) = match img.pixel_type() {
97        Some(PixelType::U8x3) => (img, PixelType::U8x3),
98        Some(PixelType::U8x4) => (img, PixelType::U8x4),
99        _ => {
100            converted = DynamicImage::ImageRgb8(img.to_rgb8());
101            (&converted, PixelType::U8x3)
102        }
103    };
104
105    // Create destination image container
106    let mut dst_image = Image::new(new_w, new_h, pixel_type);
107
108    // Resize using Resizer
109    let mut resizer = Resizer::new();
110    resizer
111        .resize(src, &mut dst_image, None)
112        .map_err(|e| OcrError::PreprocessError(format!("Image resize failed: {e}")))?;
113
114    // Convert result back to DynamicImage
115    match pixel_type {
116        PixelType::U8x3 => RgbImage::from_raw(new_w, new_h, dst_image.into_vec())
117            .map(DynamicImage::ImageRgb8)
118            .ok_or_else(|| {
119                OcrError::PreprocessError("RGB buffer size mismatch after resize".into())
120            }),
121        PixelType::U8x4 => image::RgbaImage::from_raw(new_w, new_h, dst_image.into_vec())
122            .map(DynamicImage::ImageRgba8)
123            .ok_or_else(|| {
124                OcrError::PreprocessError("RGBA buffer size mismatch after resize".into())
125            }),
126        _ => unreachable!("pixel_type is constrained to U8x3 or U8x4 above"),
127    }
128}
129
130/// Convert image to detection model input tensor
131///
132/// Output format: [1, 3, H, W] (NCHW)
133pub fn preprocess_for_det(
134    img: &DynamicImage,
135    params: &NormalizeParams,
136) -> OcrResult<ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>> {
137    let (w, h) = img.dimensions();
138    let pad_w = get_padded_size(w) as usize;
139    let pad_h = get_padded_size(h) as usize;
140
141    let mut input = Array4::<f32>::zeros((1, 3, pad_h, pad_w));
142    let rgb_img = img.to_rgb8();
143
144    // Normalize and pad
145    for y in 0..h as usize {
146        for x in 0..w as usize {
147            let pixel = rgb_img.get_pixel(x as u32, y as u32);
148            let [r, g, b] = pixel.0;
149
150            input[[0, 0, y, x]] = (r as f32 / 255.0 - params.mean[0]) / params.std[0];
151            input[[0, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
152            input[[0, 2, y, x]] = (b as f32 / 255.0 - params.mean[2]) / params.std[2];
153        }
154    }
155
156    Ok(input)
157}
158
159/// Convert image to recognition model input tensor
160///
161/// Output format: [1, 3, H, W] (NCHW)
162/// Height is fixed at 48 (or specified value), width scaled proportionally
163pub fn preprocess_for_rec(
164    img: &DynamicImage,
165    target_height: u32,
166    params: &NormalizeParams,
167) -> OcrResult<ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>> {
168    let (w, h) = img.dimensions();
169
170    // Calculate scaled width
171    let scale = target_height as f64 / h as f64;
172    let target_width = (w as f64 * scale).round() as u32;
173
174    // Scale image
175    let resized = if h != target_height {
176        img.resize_exact(
177            target_width,
178            target_height,
179            image::imageops::FilterType::Lanczos3,
180        )
181    } else {
182        img.clone()
183    };
184
185    let rgb_img = resized.to_rgb8();
186    let (w, h) = (target_width as usize, target_height as usize);
187
188    let mut input = Array4::<f32>::zeros((1, 3, h, w));
189
190    for y in 0..h {
191        for x in 0..w {
192            let pixel = rgb_img.get_pixel(x as u32, y as u32);
193            let [r, g, b] = pixel.0;
194
195            input[[0, 0, y, x]] = (r as f32 / 255.0 - params.mean[0]) / params.std[0];
196            input[[0, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
197            input[[0, 2, y, x]] = (b as f32 / 255.0 - params.mean[2]) / params.std[2];
198        }
199    }
200
201    Ok(input)
202}
203
204/// Batch preprocess recognition images
205///
206/// Process multiple images into batch tensor, all images padded to same width
207pub fn preprocess_batch_for_rec(
208    images: &[DynamicImage],
209    target_height: u32,
210    params: &NormalizeParams,
211) -> OcrResult<ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>> {
212    if images.is_empty() {
213        return Ok(Array4::<f32>::zeros((0, 3, target_height as usize, 0)));
214    }
215
216    // Calculate scaled width for all images
217    let widths: Vec<u32> = images
218        .iter()
219        .map(|img| {
220            let (w, h) = img.dimensions();
221            let scale = target_height as f64 / h as f64;
222            (w as f64 * scale).round() as u32
223        })
224        .collect();
225
226    // widths is non-empty because images is non-empty (checked above)
227    let max_width = *widths.iter().max().unwrap() as usize;
228    let batch_size = images.len();
229
230    let mut batch = Array4::<f32>::zeros((batch_size, 3, target_height as usize, max_width));
231
232    for (i, (img, &w)) in images.iter().zip(widths.iter()).enumerate() {
233        let resized = resize_to_height(img, target_height)?;
234        let rgb_img = resized.to_rgb8();
235
236        for y in 0..target_height as usize {
237            for x in 0..w as usize {
238                let pixel = rgb_img.get_pixel(x as u32, y as u32);
239                let [r, g, b] = pixel.0;
240
241                batch[[i, 0, y, x]] = (r as f32 / 255.0 - params.mean[0]) / params.std[0];
242                batch[[i, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
243                batch[[i, 2, y, x]] = (b as f32 / 255.0 - params.mean[2]) / params.std[2];
244            }
245        }
246    }
247
248    Ok(batch)
249}
250
251/// Crop image region
252pub fn crop_image(img: &DynamicImage, x: u32, y: u32, width: u32, height: u32) -> DynamicImage {
253    img.crop_imm(x, y, width, height)
254}
255
256/// Split image into blocks (for high precision mode)
257///
258/// # Parameters
259/// - `img`: Input image
260/// - `block_size`: Block size
261/// - `overlap`: Overlap region size
262///
263/// # Returns
264/// List of block images and their positions in original image (x, y)
265pub fn split_into_blocks(
266    img: &DynamicImage,
267    block_size: u32,
268    overlap: u32,
269) -> Vec<(DynamicImage, u32, u32)> {
270    let (width, height) = img.dimensions();
271    let mut blocks = Vec::new();
272
273    let step = block_size - overlap;
274
275    let mut y = 0u32;
276    while y < height {
277        let mut x = 0u32;
278        while x < width {
279            let block_w = (block_size).min(width - x);
280            let block_h = (block_size).min(height - y);
281
282            let block = img.crop_imm(x, y, block_w, block_h);
283            blocks.push((block, x, y));
284
285            x += step;
286            if x + overlap >= width && x < width {
287                break;
288            }
289        }
290
291        y += step;
292        if y + overlap >= height && y < height {
293            break;
294        }
295    }
296
297    blocks
298}
299
/// Binarize a probability mask
///
/// Values strictly greater than `threshold` become `255`; everything else
/// (including values equal to the threshold, and NaN) becomes `0`. The output has
/// the same length and order as `mask`.
pub fn threshold_mask(mask: &[f32], threshold: f32) -> Vec<u8> {
    let mut binary = Vec::with_capacity(mask.len());
    for &value in mask {
        binary.push(if value > threshold { 255 } else { 0 });
    }
    binary
}
306
307/// Create grayscale image
308pub fn create_gray_image(data: &[u8], width: u32, height: u32) -> image::GrayImage {
309    image::GrayImage::from_raw(width, height, data.to_vec())
310        .unwrap_or_else(|| image::GrayImage::new(width, height))
311}
312
313/// Convert image to RGB
314pub fn to_rgb(img: &DynamicImage) -> RgbImage {
315    img.to_rgb8()
316}
317
318/// Create image from RGB data
319pub fn rgb_to_image(data: &[u8], width: u32, height: u32) -> DynamicImage {
320    let rgb = RgbImage::from_raw(width, height, data.to_vec())
321        .unwrap_or_else(|| RgbImage::new(width, height));
322    DynamicImage::ImageRgb8(rgb)
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_padded_size() {
        // (input, expected): every size rounds up to the next multiple of 32
        let cases = [
            (100, 128),
            (32, 32),
            (33, 64),
            (0, 0),
            (1, 32),
            (31, 32),
            (64, 64),
            (65, 96),
        ];
        for (size, expected) in cases {
            assert_eq!(get_padded_size(size), expected, "get_padded_size({size})");
        }
    }

    #[test]
    fn test_normalize_params() {
        let default_params = NormalizeParams::default();
        assert_eq!(default_params.mean[0], 0.485);

        let det_params = NormalizeParams::paddle_det();
        assert_eq!(det_params.mean[0], 0.485);
        assert_eq!(det_params.std[0], 0.229);
    }

    #[test]
    fn test_normalize_params_paddle_rec() {
        let rec_params = NormalizeParams::paddle_rec();
        assert_eq!(rec_params.mean, [0.5, 0.5, 0.5]);
        assert_eq!(rec_params.std, [0.5, 0.5, 0.5]);
    }

    #[test]
    fn test_resize_to_max_side_no_resize() {
        // The longest side already fits within the limit, so nothing is scaled
        let img = DynamicImage::new_rgb8(100, 50);
        let resized = resize_to_max_side(&img, 200).unwrap();
        assert_eq!((resized.width(), resized.height()), (100, 50));
    }

    #[test]
    fn test_resize_to_max_side_width_limited() {
        // Width is the longest side and is scaled down to 500
        let img = DynamicImage::new_rgb8(1000, 500);
        let resized = resize_to_max_side(&img, 500).unwrap();
        assert_eq!((resized.width(), resized.height()), (500, 250));
    }

    #[test]
    fn test_resize_to_max_side_height_limited() {
        // Height is the longest side and is scaled down to 500
        let img = DynamicImage::new_rgb8(500, 1000);
        let resized = resize_to_max_side(&img, 500).unwrap();
        assert_eq!((resized.width(), resized.height()), (250, 500));
    }

    #[test]
    fn test_resize_to_height() {
        // Width scales proportionally: 200 * 48 / 100 = 96
        let img = DynamicImage::new_rgb8(200, 100);
        let resized = resize_to_height(&img, 48).unwrap();
        assert_eq!((resized.width(), resized.height()), (96, 48));
    }

    #[test]
    fn test_resize_to_height_no_resize() {
        // Already at the target height, so the image is returned as-is
        let img = DynamicImage::new_rgb8(200, 48);
        let resized = resize_to_height(&img, 48).unwrap();
        assert_eq!((resized.width(), resized.height()), (200, 48));
    }

    #[test]
    fn test_preprocess_for_det_shape() {
        // [1, 3, H, W] with H and W padded to multiples of 32: 50 -> 64, 100 -> 128
        let img = DynamicImage::new_rgb8(100, 50);
        let tensor = preprocess_for_det(&img, &NormalizeParams::paddle_det()).unwrap();
        assert_eq!(tensor.shape(), &[1, 3, 64, 128]);
    }

    #[test]
    fn test_preprocess_for_rec_shape() {
        // Height fixed at 48; width scales proportionally: 200 * 48 / 100 = 96
        let img = DynamicImage::new_rgb8(200, 100);
        let tensor = preprocess_for_rec(&img, 48, &NormalizeParams::paddle_rec()).unwrap();
        assert_eq!(tensor.shape(), &[1, 3, 48, 96]);
    }

    #[test]
    fn test_preprocess_batch_for_rec_empty() {
        let images: Vec<DynamicImage> = Vec::new();
        let tensor =
            preprocess_batch_for_rec(&images, 48, &NormalizeParams::paddle_rec()).unwrap();
        assert_eq!(tensor.shape()[0], 0);
    }

    #[test]
    fn test_preprocess_batch_for_rec_single() {
        let images = [DynamicImage::new_rgb8(200, 100)];
        let tensor =
            preprocess_batch_for_rec(&images, 48, &NormalizeParams::paddle_rec()).unwrap();
        assert_eq!(&tensor.shape()[..3], &[1, 3, 48]);
    }

    #[test]
    fn test_preprocess_batch_for_rec_multiple() {
        // The batch width is the widest resized image: max(96, 144) = 144
        let images = [
            DynamicImage::new_rgb8(200, 100),
            DynamicImage::new_rgb8(300, 100),
        ];
        let tensor =
            preprocess_batch_for_rec(&images, 48, &NormalizeParams::paddle_rec()).unwrap();
        assert_eq!(tensor.shape(), &[2, 3, 48, 144]);
    }

    #[test]
    fn test_crop_image() {
        let img = DynamicImage::new_rgb8(200, 100);
        let cropped = crop_image(&img, 50, 25, 100, 50);
        assert_eq!((cropped.width(), cropped.height()), (100, 50));
    }

    #[test]
    fn test_split_into_blocks() {
        let img = DynamicImage::new_rgb8(500, 500);
        let blocks = split_into_blocks(&img, 200, 50);

        // An image larger than one block must be split into blocks
        assert!(!blocks.is_empty());

        // Every block fits the block size and starts inside the image
        for (block, x, y) in &blocks {
            assert!(block.width() <= 200 && block.height() <= 200);
            assert!(*x < 500 && *y < 500);
        }
    }

    #[test]
    fn test_split_into_blocks_small_image() {
        // An image smaller than the block size yields exactly one block at the origin
        let img = DynamicImage::new_rgb8(100, 100);
        let blocks = split_into_blocks(&img, 200, 50);
        assert_eq!(blocks.len(), 1);
        assert_eq!((blocks[0].1, blocks[0].2), (0, 0));
    }

    #[test]
    fn test_threshold_mask() {
        // Only values strictly above the threshold map to 255; 0.5 itself maps to 0
        let binary = threshold_mask(&[0.1, 0.3, 0.5, 0.7, 0.9], 0.5);
        assert_eq!(binary, vec![0, 0, 0, 255, 255]);
    }

    #[test]
    fn test_threshold_mask_all_below() {
        let binary = threshold_mask(&[0.1, 0.2, 0.3, 0.4], 0.5);
        assert_eq!(binary, vec![0; 4]);
    }

    #[test]
    fn test_threshold_mask_all_above() {
        let binary = threshold_mask(&[0.6, 0.7, 0.8, 0.9], 0.5);
        assert_eq!(binary, vec![255; 4]);
    }

    #[test]
    fn test_create_gray_image() {
        let gray = create_gray_image(&[128u8; 100], 10, 10);
        assert_eq!((gray.width(), gray.height()), (10, 10));
    }

    #[test]
    fn test_to_rgb() {
        let rgb = to_rgb(&DynamicImage::new_rgb8(100, 50));
        assert_eq!((rgb.width(), rgb.height()), (100, 50));
    }

    #[test]
    fn test_rgb_to_image() {
        // A 10x10 RGB image needs 10 * 10 * 3 = 300 bytes
        let img = rgb_to_image(&[128u8; 300], 10, 10);
        assert_eq!((img.width(), img.height()), (10, 10));
    }
}