Skip to main content

captcha_engine/
image_ops.rs

1//! Image preprocessing for the captcha model.
2
3use image::{DynamicImage, imageops::FilterType};
4use ndarray::Array4;
5use std::path::Path;
6
/// Height, in pixels, that every image is resized to before inference.
///
/// Together with [`IMG_WIDTH`] this fixes the input of the custom finetuned
/// model, which expects a tensor of shape `[1, 3, 80, 215]` (NCHW format).
pub const IMG_HEIGHT: u32 = 80;
/// Width, in pixels, that every image is resized to before inference.
///
/// See [`IMG_HEIGHT`] for the full expected tensor shape.
pub const IMG_WIDTH: u32 = 215;
11
/// `ImageNet` per-channel means (used during training), in R, G, B order.
///
/// Subtracted from each pixel value after it has been scaled to `0.0..=1.0`.
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
/// `ImageNet` per-channel standard deviations (used during training), in
/// R, G, B order.
///
/// Each mean-centred pixel value is divided by the entry for its channel.
const STD: [f32; 3] = [0.229, 0.224, 0.225];
15
16/// Preprocess an image for the captcha model.
17///
18/// Model expects: `[1, 3, 80, 215]` (NCHW format)
19#[must_use]
20pub fn preprocess(img: &DynamicImage) -> Array4<f32> {
21    // Resize to model input dimensions
22    let img = img.resize_exact(IMG_WIDTH, IMG_HEIGHT, FilterType::Triangle);
23
24    // Convert to RGB
25    let img = img.to_rgb8();
26
27    // Create tensor in NCHW format: [1, 3, height, width]
28    let mut tensor = Array4::<f32>::zeros((1, 3, IMG_HEIGHT as usize, IMG_WIDTH as usize));
29
30    // Fill tensor with normalized pixel values (ImageNet normalization)
31    for y in 0..IMG_HEIGHT {
32        for x in 0..IMG_WIDTH {
33            let pixel = img.get_pixel(x, y);
34            for c in 0..3 {
35                let pixel_value = f32::from(pixel[c]) / 255.0;
36                // ImageNet normalization: (value - mean) / std
37                let normalized = (pixel_value - MEAN[c]) / STD[c];
38                tensor[[0, c, y as usize, x as usize]] = normalized;
39            }
40        }
41    }
42
43    tensor
44}
45
46/// Load and preprocess an image from a file path.
47///
48/// # Errors
49///
50/// Returns an error if the image cannot be loaded.
51pub fn preprocess_file<P: AsRef<Path>>(path: P) -> crate::Result<Array4<f32>> {
52    let img = image::open(path)?;
53    Ok(preprocess(&img))
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use image::RgbImage;

    /// An input of arbitrary size must come out in the fixed model shape.
    #[test]
    fn test_preprocess_output_shape() {
        // 100x50 dummy image, deliberately different from the target size.
        let source = DynamicImage::ImageRgb8(RgbImage::new(100, 50));

        let output = preprocess(&source);

        // Expected shape: [1, 3, 80, 215]
        assert_eq!(output.shape(), &[1, 3, 80, 215]);
    }

    /// A pure white image must map to the normalized value of 1.0 per channel.
    #[test]
    fn test_preprocess_normalization() {
        // Every pixel is (255, 255, 255), so resizing cannot change the values.
        let white = RgbImage::from_pixel(10, 10, image::Rgb([255, 255, 255]));
        let output = preprocess(&DynamicImage::ImageRgb8(white));

        let close = |got: f32, want: f32| (got - want).abs() < 0.001;

        // Expected values for 1.0 (255/255):
        //   Red:   (1.0 - 0.485) / 0.229 ≈ 2.2489
        //   Green: (1.0 - 0.456) / 0.224 ≈ 2.4286
        //   Blue:  (1.0 - 0.406) / 0.225 ≈ 2.6400
        assert!(close(output[[0, 0, 0, 0]], 2.2489));
        assert!(close(output[[0, 1, 0, 0]], 2.4286));
        assert!(close(output[[0, 2, 0, 0]], 2.6400));
    }
}