Skip to main content

captcha_engine/
image_ops.rs

1//! Image preprocessing for the captcha model.
2
3use image::{DynamicImage, imageops::FilterType};
4use rten_tensor::Tensor;
5use std::path::Path;
6
/// Height in pixels of the model input image.
///
/// Together with [`IMG_WIDTH`] this fixes the input shape of the custom finetuned
/// model: `[1, 3, 80, 215]` (NCHW format).
pub const IMG_HEIGHT: u32 = 80;

/// Width in pixels of the model input image.
///
/// See [`IMG_HEIGHT`] for the full model input shape.
pub const IMG_WIDTH: u32 = 215;

/// Per-channel (R, G, B) mean for `ImageNet` normalization (used during training).
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];

/// Per-channel (R, G, B) standard deviation for `ImageNet` normalization (used during training).
const STD: [f32; 3] = [0.229, 0.224, 0.225];
15
16/// Preprocess an image for the captcha model.
17///
18/// Model expects: `[1, 3, 80, 215]` (NCHW format)
19#[must_use]
20pub fn preprocess(img: &DynamicImage) -> Tensor<f32> {
21    // Resize to model input dimensions
22    let img = img.resize_exact(IMG_WIDTH, IMG_HEIGHT, FilterType::Triangle);
23
24    // Convert to RGB
25    let img = img.to_rgb8();
26
27    // Create tensor data in NCHW format: [1, 3, height, width]
28    let mut data = Vec::with_capacity((3 * IMG_HEIGHT * IMG_WIDTH) as usize);
29
30    // Iterating by channel first to be cache friendly?
31    // No, standard image iteration is y,x then c.
32    // NCHW means we need all Red, then all Green, then all Blue.
33
34    // We can pre-calculate normalized values or better yet, construct plane by plane.
35    for c in 0..3 {
36        for y in 0..IMG_HEIGHT {
37            for x in 0..IMG_WIDTH {
38                let pixel = img.get_pixel(x, y);
39                let pixel_value = f32::from(pixel[c]) / 255.0;
40                // ImageNet normalization: (value - mean) / std
41                let normalized = (pixel_value - MEAN[c]) / STD[c];
42                data.push(normalized);
43            }
44        }
45    }
46
47    Tensor::from_data(&[1, 3, IMG_HEIGHT as usize, IMG_WIDTH as usize], data)
48}
49
50/// Load and preprocess an image from a file path.
51///
52/// # Errors
53///
54/// Returns an error if the image cannot be loaded.
55pub fn preprocess_file<P: AsRef<Path>>(path: P) -> crate::Result<Tensor<f32>> {
56    let img = image::open(path)?;
57    Ok(preprocess(&img))
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use image::RgbImage;
    use rten_tensor::prelude::*;

    /// Number of elements in one channel plane of the output tensor.
    const PLANE: usize = (IMG_HEIGHT * IMG_WIDTH) as usize;

    /// Normalized value expected in channel `c` for the raw byte `value`.
    fn expected(value: u8, c: usize) -> f32 {
        (f32::from(value) / 255.0 - MEAN[c]) / STD[c]
    }

    /// Assert that every element of channel plane `c` in `data` is (almost) `want`.
    fn assert_plane(data: &[f32], c: usize, want: f32) {
        let plane = &data[c * PLANE..(c + 1) * PLANE];
        for (i, &got) in plane.iter().enumerate() {
            assert!(
                (got - want).abs() < 1e-4,
                "channel {c}, element {i}: got {got}, expected {want}"
            );
        }
    }

    #[test]
    fn test_preprocess_output_shape() {
        // A 100x50 input must still be resized to the model shape [1, 3, 80, 215].
        let img = DynamicImage::ImageRgb8(RgbImage::new(100, 50));
        let tensor = preprocess(&img);

        assert_eq!(tensor.shape(), &[1, 3, 80, 215]);
    }

    #[test]
    fn test_preprocess_normalization() {
        // All-white image: every element of channel c should be (1.0 - MEAN[c]) / STD[c].
        let img = DynamicImage::ImageRgb8(RgbImage::from_pixel(
            10,
            10,
            image::Rgb([255, 255, 255]),
        ));
        let tensor = preprocess(&img);
        let data = tensor.data().expect("should be contiguous");
        assert_eq!(data.len(), 3 * PLANE);

        // Hand-computed spot checks on the first element of each plane, which also
        // guard the MEAN/STD constants themselves:
        //   Red:   (1.0 - 0.485) / 0.229 ≈ 2.2489
        //   Green: (1.0 - 0.456) / 0.224 ≈ 2.4286
        //   Blue:  (1.0 - 0.406) / 0.225 ≈ 2.6400
        assert!((data[0] - 2.2489).abs() < 0.001);
        assert!((data[PLANE] - 2.4286).abs() < 0.001);
        assert!((data[2 * PLANE] - 2.6400).abs() < 0.001);

        // Every element, not just the first, must carry the normalized value.
        for c in 0..3 {
            assert_plane(data, c, expected(255, c));
        }
    }

    #[test]
    fn test_preprocess_channel_order() {
        // Pure red image: only plane 0 may hold the full-intensity value. This catches
        // swapped channels or an interleaved (NHWC) layout.
        let img = DynamicImage::ImageRgb8(RgbImage::from_pixel(10, 10, image::Rgb([255, 0, 0])));
        let tensor = preprocess(&img);
        let data = tensor.data().expect("should be contiguous");

        assert_plane(data, 0, expected(255, 0));
        assert_plane(data, 1, expected(0, 1));
        assert_plane(data, 2, expected(0, 2));
    }
}
113}