1use image::{DynamicImage, GenericImageView, RgbImage};
6use ndarray::{Array4, ArrayBase, Dim, OwnedRepr};
7
8use crate::error::{OcrError, OcrResult};
9
/// Per-channel normalisation applied to 8-bit RGB pixels before inference.
///
/// The preprocessing functions in this module map each channel value `v`
/// to `(v / 255 - mean[c]) / std[c]`, with channels ordered R, G, B.
#[derive(Debug, Clone, PartialEq)]
pub struct NormalizeParams {
    /// Per-channel mean (R, G, B) subtracted from the `[0, 1]`-scaled value.
    pub mean: [f32; 3],
    /// Per-channel divisor (R, G, B) applied after subtracting the mean.
    pub std: [f32; 3],
}
18
19impl Default for NormalizeParams {
20 fn default() -> Self {
21 Self {
23 mean: [0.485, 0.456, 0.406],
24 std: [0.229, 0.224, 0.225],
25 }
26 }
27}
28
29impl NormalizeParams {
30 pub fn paddle_det() -> Self {
32 Self {
33 mean: [0.485, 0.456, 0.406],
34 std: [0.229, 0.224, 0.225],
35 }
36 }
37
38 pub fn paddle_rec() -> Self {
40 Self {
41 mean: [0.5, 0.5, 0.5],
42 std: [0.5, 0.5, 0.5],
43 }
44 }
45}
46
/// Rounds `size` up to the next multiple of 32; `0` stays `0` and values
/// already divisible by 32 are returned unchanged.
#[inline]
pub fn get_padded_size(size: u32) -> u32 {
    // 32 is a power of two, so clearing the low five bits after adding 31
    // rounds up exactly like `((size + 31) / 32) * 32`.
    const LOW_BITS: u32 = 32 - 1;
    (size + LOW_BITS) & !LOW_BITS
}
52
53pub fn resize_to_max_side(img: &DynamicImage, max_side_len: u32) -> OcrResult<DynamicImage> {
57 let (w, h) = img.dimensions();
58 let max_dim = w.max(h);
59
60 if max_dim <= max_side_len {
61 return Ok(img.clone());
62 }
63
64 let scale = max_side_len as f64 / max_dim as f64;
65 let new_w = (w as f64 * scale).round() as u32;
66 let new_h = (h as f64 * scale).round() as u32;
67
68 fast_resize(img, new_w, new_h)
69}
70
71pub fn resize_to_height(img: &DynamicImage, target_height: u32) -> OcrResult<DynamicImage> {
75 let (w, h) = img.dimensions();
76
77 if h == target_height {
78 return Ok(img.clone());
79 }
80
81 let scale = target_height as f64 / h as f64;
82 let new_w = (w as f64 * scale).round() as u32;
83
84 fast_resize(img, new_w, target_height)
85}
86
87fn fast_resize(img: &DynamicImage, new_w: u32, new_h: u32) -> OcrResult<DynamicImage> {
90 use fast_image_resize::{images::Image, IntoImageView, PixelType, Resizer};
91
92 let converted: DynamicImage;
96 let (src, pixel_type) = match img.pixel_type() {
97 Some(PixelType::U8x3) => (img, PixelType::U8x3),
98 Some(PixelType::U8x4) => (img, PixelType::U8x4),
99 _ => {
100 converted = DynamicImage::ImageRgb8(img.to_rgb8());
101 (&converted, PixelType::U8x3)
102 }
103 };
104
105 let mut dst_image = Image::new(new_w, new_h, pixel_type);
107
108 let mut resizer = Resizer::new();
110 resizer
111 .resize(src, &mut dst_image, None)
112 .map_err(|e| OcrError::PreprocessError(format!("Image resize failed: {e}")))?;
113
114 match pixel_type {
116 PixelType::U8x3 => RgbImage::from_raw(new_w, new_h, dst_image.into_vec())
117 .map(DynamicImage::ImageRgb8)
118 .ok_or_else(|| {
119 OcrError::PreprocessError("RGB buffer size mismatch after resize".into())
120 }),
121 PixelType::U8x4 => image::RgbaImage::from_raw(new_w, new_h, dst_image.into_vec())
122 .map(DynamicImage::ImageRgba8)
123 .ok_or_else(|| {
124 OcrError::PreprocessError("RGBA buffer size mismatch after resize".into())
125 }),
126 _ => unreachable!("pixel_type is constrained to U8x3 or U8x4 above"),
127 }
128}
129
130pub fn preprocess_for_det(
134 img: &DynamicImage,
135 params: &NormalizeParams,
136) -> OcrResult<ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>> {
137 let (w, h) = img.dimensions();
138 let pad_w = get_padded_size(w) as usize;
139 let pad_h = get_padded_size(h) as usize;
140
141 let mut input = Array4::<f32>::zeros((1, 3, pad_h, pad_w));
142 let rgb_img = img.to_rgb8();
143
144 for y in 0..h as usize {
146 for x in 0..w as usize {
147 let pixel = rgb_img.get_pixel(x as u32, y as u32);
148 let [r, g, b] = pixel.0;
149
150 input[[0, 0, y, x]] = (r as f32 / 255.0 - params.mean[0]) / params.std[0];
151 input[[0, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
152 input[[0, 2, y, x]] = (b as f32 / 255.0 - params.mean[2]) / params.std[2];
153 }
154 }
155
156 Ok(input)
157}
158
159pub fn preprocess_for_rec(
164 img: &DynamicImage,
165 target_height: u32,
166 params: &NormalizeParams,
167) -> OcrResult<ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>> {
168 let (w, h) = img.dimensions();
169
170 let scale = target_height as f64 / h as f64;
172 let target_width = (w as f64 * scale).round() as u32;
173
174 let resized = if h != target_height {
176 img.resize_exact(
177 target_width,
178 target_height,
179 image::imageops::FilterType::Lanczos3,
180 )
181 } else {
182 img.clone()
183 };
184
185 let rgb_img = resized.to_rgb8();
186 let (w, h) = (target_width as usize, target_height as usize);
187
188 let mut input = Array4::<f32>::zeros((1, 3, h, w));
189
190 for y in 0..h {
191 for x in 0..w {
192 let pixel = rgb_img.get_pixel(x as u32, y as u32);
193 let [r, g, b] = pixel.0;
194
195 input[[0, 0, y, x]] = (r as f32 / 255.0 - params.mean[0]) / params.std[0];
196 input[[0, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
197 input[[0, 2, y, x]] = (b as f32 / 255.0 - params.mean[2]) / params.std[2];
198 }
199 }
200
201 Ok(input)
202}
203
204pub fn preprocess_batch_for_rec(
208 images: &[DynamicImage],
209 target_height: u32,
210 params: &NormalizeParams,
211) -> OcrResult<ArrayBase<OwnedRepr<f32>, Dim<[usize; 4]>>> {
212 if images.is_empty() {
213 return Ok(Array4::<f32>::zeros((0, 3, target_height as usize, 0)));
214 }
215
216 let widths: Vec<u32> = images
218 .iter()
219 .map(|img| {
220 let (w, h) = img.dimensions();
221 let scale = target_height as f64 / h as f64;
222 (w as f64 * scale).round() as u32
223 })
224 .collect();
225
226 let max_width = *widths.iter().max().unwrap() as usize;
228 let batch_size = images.len();
229
230 let mut batch = Array4::<f32>::zeros((batch_size, 3, target_height as usize, max_width));
231
232 for (i, (img, &w)) in images.iter().zip(widths.iter()).enumerate() {
233 let resized = resize_to_height(img, target_height)?;
234 let rgb_img = resized.to_rgb8();
235
236 for y in 0..target_height as usize {
237 for x in 0..w as usize {
238 let pixel = rgb_img.get_pixel(x as u32, y as u32);
239 let [r, g, b] = pixel.0;
240
241 batch[[i, 0, y, x]] = (r as f32 / 255.0 - params.mean[0]) / params.std[0];
242 batch[[i, 1, y, x]] = (g as f32 / 255.0 - params.mean[1]) / params.std[1];
243 batch[[i, 2, y, x]] = (b as f32 / 255.0 - params.mean[2]) / params.std[2];
244 }
245 }
246 }
247
248 Ok(batch)
249}
250
251pub fn crop_image(img: &DynamicImage, x: u32, y: u32, width: u32, height: u32) -> DynamicImage {
253 img.crop_imm(x, y, width, height)
254}
255
256pub fn split_into_blocks(
266 img: &DynamicImage,
267 block_size: u32,
268 overlap: u32,
269) -> Vec<(DynamicImage, u32, u32)> {
270 let (width, height) = img.dimensions();
271 let mut blocks = Vec::new();
272
273 let step = block_size - overlap;
274
275 let mut y = 0u32;
276 while y < height {
277 let mut x = 0u32;
278 while x < width {
279 let block_w = (block_size).min(width - x);
280 let block_h = (block_size).min(height - y);
281
282 let block = img.crop_imm(x, y, block_w, block_h);
283 blocks.push((block, x, y));
284
285 x += step;
286 if x + overlap >= width && x < width {
287 break;
288 }
289 }
290
291 y += step;
292 if y + overlap >= height && y < height {
293 break;
294 }
295 }
296
297 blocks
298}
299
/// Binarises a float mask: values strictly greater than `threshold` become
/// `255`; everything else (including NaN) becomes `0`. Output length equals
/// input length.
pub fn threshold_mask(mask: &[f32], threshold: f32) -> Vec<u8> {
    let mut binary = Vec::with_capacity(mask.len());
    for &value in mask {
        binary.push(if value > threshold { u8::MAX } else { u8::MIN });
    }
    binary
}
306
307pub fn create_gray_image(data: &[u8], width: u32, height: u32) -> image::GrayImage {
309 image::GrayImage::from_raw(width, height, data.to_vec())
310 .unwrap_or_else(|| image::GrayImage::new(width, height))
311}
312
313pub fn to_rgb(img: &DynamicImage) -> RgbImage {
315 img.to_rgb8()
316}
317
318pub fn rgb_to_image(data: &[u8], width: u32, height: u32) -> DynamicImage {
320 let rgb = RgbImage::from_raw(width, height, data.to_vec())
321 .unwrap_or_else(|| RgbImage::new(width, height));
322 DynamicImage::ImageRgb8(rgb)
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_padded_size() {
        assert_eq!(get_padded_size(100), 128);
        assert_eq!(get_padded_size(32), 32);
        assert_eq!(get_padded_size(33), 64);
        assert_eq!(get_padded_size(0), 0);
        assert_eq!(get_padded_size(1), 32);
        assert_eq!(get_padded_size(31), 32);
        assert_eq!(get_padded_size(64), 64);
        assert_eq!(get_padded_size(65), 96);
    }

    #[test]
    fn test_normalize_params() {
        let params = NormalizeParams::default();
        assert_eq!(params.mean[0], 0.485);

        let paddle = NormalizeParams::paddle_det();
        assert_eq!(paddle.mean[0], 0.485);
        assert_eq!(paddle.std[0], 0.229);
    }

    #[test]
    fn test_normalize_params_paddle_rec() {
        let params = NormalizeParams::paddle_rec();
        assert_eq!(params.mean[0], 0.5);
        assert_eq!(params.mean[1], 0.5);
        assert_eq!(params.mean[2], 0.5);
        assert_eq!(params.std[0], 0.5);
        assert_eq!(params.std[1], 0.5);
        assert_eq!(params.std[2], 0.5);
    }

    #[test]
    fn test_resize_to_max_side_no_resize() {
        let img = DynamicImage::new_rgb8(100, 50);
        let resized = resize_to_max_side(&img, 200).unwrap();

        assert_eq!(resized.width(), 100);
        assert_eq!(resized.height(), 50);
    }

    #[test]
    fn test_resize_to_max_side_width_limited() {
        let img = DynamicImage::new_rgb8(1000, 500);
        let resized = resize_to_max_side(&img, 500).unwrap();

        assert_eq!(resized.width(), 500);
        assert_eq!(resized.height(), 250);
    }

    #[test]
    fn test_resize_to_max_side_height_limited() {
        let img = DynamicImage::new_rgb8(500, 1000);
        let resized = resize_to_max_side(&img, 500).unwrap();

        assert_eq!(resized.width(), 250);
        assert_eq!(resized.height(), 500);
    }

    #[test]
    fn test_resize_to_height() {
        let img = DynamicImage::new_rgb8(200, 100);
        let resized = resize_to_height(&img, 48).unwrap();

        assert_eq!(resized.height(), 48);
        assert_eq!(resized.width(), 96);
    }

    #[test]
    fn test_resize_to_height_no_resize() {
        let img = DynamicImage::new_rgb8(200, 48);
        let resized = resize_to_height(&img, 48).unwrap();

        assert_eq!(resized.height(), 48);
        assert_eq!(resized.width(), 200);
    }

    #[test]
    fn test_preprocess_for_det_shape() {
        let img = DynamicImage::new_rgb8(100, 50);
        let params = NormalizeParams::paddle_det();
        let tensor = preprocess_for_det(&img, &params).unwrap();

        assert_eq!(tensor.shape()[0], 1);
        assert_eq!(tensor.shape()[1], 3);
        // 50 rows -> 64 and 100 columns -> 128 after padding to multiples of 32.
        assert_eq!(tensor.shape()[2], 64);
        assert_eq!(tensor.shape()[3], 128);
    }

    #[test]
    fn test_preprocess_for_det_values_and_padding() {
        let img = DynamicImage::new_rgb8(100, 50);
        let params = NormalizeParams::paddle_det();
        let tensor = preprocess_for_det(&img, &params).unwrap();

        // A black pixel normalises to -mean / std; padded cells stay 0.0.
        let expected = (0.0 - params.mean[0]) / params.std[0];
        assert_eq!(tensor[[0, 0, 0, 0]], expected);
        assert_eq!(tensor[[0, 0, 49, 99]], expected);
        assert_eq!(tensor[[0, 0, 60, 120]], 0.0);
    }

    #[test]
    fn test_preprocess_for_rec_shape() {
        let img = DynamicImage::new_rgb8(200, 100);
        let params = NormalizeParams::paddle_rec();
        let tensor = preprocess_for_rec(&img, 48, &params).unwrap();

        assert_eq!(tensor.shape()[0], 1);
        assert_eq!(tensor.shape()[1], 3);
        assert_eq!(tensor.shape()[2], 48);
        assert_eq!(tensor.shape()[3], 96);
        // Black pixels normalise to (0 - 0.5) / 0.5 = -1.
        assert!(tensor.iter().all(|&v| v == -1.0));
    }

    #[test]
    fn test_preprocess_batch_for_rec_empty() {
        let images: Vec<DynamicImage> = vec![];
        let params = NormalizeParams::paddle_rec();
        let tensor = preprocess_batch_for_rec(&images, 48, &params).unwrap();

        assert_eq!(tensor.shape()[0], 0);
    }

    #[test]
    fn test_preprocess_batch_for_rec_single() {
        let images = vec![DynamicImage::new_rgb8(200, 100)];
        let params = NormalizeParams::paddle_rec();
        let tensor = preprocess_batch_for_rec(&images, 48, &params).unwrap();

        assert_eq!(tensor.shape()[0], 1);
        assert_eq!(tensor.shape()[1], 3);
        assert_eq!(tensor.shape()[2], 48);
        assert_eq!(tensor.shape()[3], 96);
    }

    #[test]
    fn test_preprocess_batch_for_rec_multiple() {
        let images = vec![
            DynamicImage::new_rgb8(200, 100),
            DynamicImage::new_rgb8(300, 100),
        ];
        let params = NormalizeParams::paddle_rec();
        let tensor = preprocess_batch_for_rec(&images, 48, &params).unwrap();

        assert_eq!(tensor.shape()[0], 2);
        assert_eq!(tensor.shape()[1], 3);
        assert_eq!(tensor.shape()[2], 48);
        assert_eq!(tensor.shape()[3], 144);
    }

    #[test]
    fn test_preprocess_batch_for_rec_pads_narrow_images() {
        let images = vec![
            DynamicImage::new_rgb8(200, 100),
            DynamicImage::new_rgb8(300, 100),
        ];
        let params = NormalizeParams::paddle_rec();
        let tensor = preprocess_batch_for_rec(&images, 48, &params).unwrap();

        // The first image is 96 columns wide after resizing; the rest of its
        // slot is zero padding. Black pixels normalise to -1.
        assert_eq!(tensor[[0, 0, 0, 95]], -1.0);
        assert_eq!(tensor[[0, 0, 0, 96]], 0.0);
        assert_eq!(tensor[[1, 2, 47, 143]], -1.0);
    }

    #[test]
    fn test_crop_image() {
        let img = DynamicImage::new_rgb8(200, 100);
        let cropped = crop_image(&img, 50, 25, 100, 50);

        assert_eq!(cropped.width(), 100);
        assert_eq!(cropped.height(), 50);
    }

    #[test]
    fn test_split_into_blocks() {
        let img = DynamicImage::new_rgb8(500, 500);
        let blocks = split_into_blocks(&img, 200, 50);

        assert!(!blocks.is_empty());
        // Stride 150: offsets 0, 150, 300 per axis (the tile at 300 reaches 500).
        assert_eq!(blocks.len(), 9);

        for (block, x, y) in &blocks {
            assert!(block.width() <= 200);
            assert!(block.height() <= 200);
            assert!(*x < 500);
            assert!(*y < 500);
            assert!([0, 150, 300].contains(x));
            assert!([0, 150, 300].contains(y));
        }
    }

    #[test]
    fn test_split_into_blocks_small_image() {
        let img = DynamicImage::new_rgb8(100, 100);
        let blocks = split_into_blocks(&img, 200, 50);

        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].1, 0);
        assert_eq!(blocks[0].2, 0);
    }

    #[test]
    fn test_threshold_mask() {
        let mask = vec![0.1, 0.3, 0.5, 0.7, 0.9];
        let binary = threshold_mask(&mask, 0.5);

        assert_eq!(binary, vec![0, 0, 0, 255, 255]);
    }

    #[test]
    fn test_threshold_mask_all_below() {
        let mask = vec![0.1, 0.2, 0.3, 0.4];
        let binary = threshold_mask(&mask, 0.5);

        assert_eq!(binary, vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_threshold_mask_all_above() {
        let mask = vec![0.6, 0.7, 0.8, 0.9];
        let binary = threshold_mask(&mask, 0.5);

        assert_eq!(binary, vec![255, 255, 255, 255]);
    }

    #[test]
    fn test_create_gray_image() {
        let data = vec![128u8; 100];
        let gray = create_gray_image(&data, 10, 10);

        assert_eq!(gray.width(), 10);
        assert_eq!(gray.height(), 10);
    }

    #[test]
    fn test_to_rgb() {
        let img = DynamicImage::new_rgb8(100, 50);
        let rgb = to_rgb(&img);

        assert_eq!(rgb.width(), 100);
        assert_eq!(rgb.height(), 50);
    }

    #[test]
    fn test_rgb_to_image() {
        let data = vec![128u8; 300];
        let img = rgb_to_image(&data, 10, 10);

        assert_eq!(img.width(), 10);
        assert_eq!(img.height(), 10);
    }
}