Skip to main content

oxicuda_vision/augment/
normalize.rs

//! Channel-wise normalisation for CHW image tensors.
//!
//! Provides the standard `(x - mean) / std` transformation used before
//! feeding images to neural networks, applied independently per channel.
5
6use crate::error::{VisionError, VisionResult};
7
// ─── Constants ───────────────────────────────────────────────────────────────
9
/// ImageNet per-channel mean (RGB order, values pre-scaled to [0, 1]).
///
/// Intended to be passed as the `mean` argument of [`normalize_chw`] for
/// 3-channel images whose pixel values have already been divided by 255.
pub const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];

/// ImageNet per-channel standard deviation (RGB order, same [0, 1] scale as
/// [`IMAGENET_MEAN`]).
///
/// Every entry is strictly positive, so it is always accepted as the `std`
/// argument of [`normalize_chw`].
pub const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];
15
// ─── normalize_chw ───────────────────────────────────────────────────────────
17
18/// Normalize a CHW image channel-wise: `output[c, h, w] = (input[c, h, w] - mean[c]) / std[c]`.
19///
20/// # Parameters
21/// - `img`: flat `[channels × h × w]` input buffer.
22/// - `channels`: number of channels; must equal `mean.len()` and `std.len()`.
23/// - `h`: image height in pixels.
24/// - `w`: image width in pixels.
25/// - `mean`: per-channel mean values; length must equal `channels`.
26/// - `std`: per-channel standard deviation values; length must equal `channels`.
27///   Each element must be positive.
28///
29/// # Errors
30/// Returns [`VisionError::InvalidImageSize`] if any dimension is zero.
31/// Returns [`VisionError::DimensionMismatch`] if `img.len() != channels * h * w`.
32/// Returns [`VisionError::ShapeMismatch`] if `mean.len() != channels` or `std.len() != channels`.
33/// Returns [`VisionError::NonFinite`] if any `std[c] <= 0` (would produce NaN/Inf).
34pub fn normalize_chw(
35    img: &[f32],
36    channels: usize,
37    h: usize,
38    w: usize,
39    mean: &[f32],
40    std: &[f32],
41) -> VisionResult<Vec<f32>> {
42    if channels == 0 || h == 0 || w == 0 {
43        return Err(VisionError::InvalidImageSize {
44            height: h,
45            width: w,
46            channels,
47        });
48    }
49    let expected_len = channels * h * w;
50    if img.len() != expected_len {
51        return Err(VisionError::DimensionMismatch {
52            expected: expected_len,
53            got: img.len(),
54        });
55    }
56    if mean.len() != channels {
57        return Err(VisionError::ShapeMismatch {
58            lhs: vec![channels],
59            rhs: vec![mean.len()],
60        });
61    }
62    if std.len() != channels {
63        return Err(VisionError::ShapeMismatch {
64            lhs: vec![channels],
65            rhs: vec![std.len()],
66        });
67    }
68    // Validate std values before proceeding (avoid silently producing NaN/Inf).
69    for (c, &s) in std.iter().enumerate() {
70        if s <= 0.0 || !s.is_finite() {
71            return Err(VisionError::NonFinite(
72                // We use a single static string; the channel index is
73                // implicit (detailed validation error).
74                if c == 0 {
75                    "std[0] non-positive"
76                } else if c == 1 {
77                    "std[1] non-positive"
78                } else if c == 2 {
79                    "std[2] non-positive"
80                } else {
81                    "std[c] non-positive"
82                },
83            ));
84        }
85    }
86
87    let hw = h * w;
88    let mut out = vec![0.0f32; expected_len];
89
90    for c in 0..channels {
91        let m = mean[c];
92        let s = std[c];
93        let inv_s = 1.0 / s;
94        let base = c * hw;
95        for i in 0..hw {
96            out[base + i] = (img[base + i] - m) * inv_s;
97        }
98    }
99
100    Ok(out)
101}
102
// ─── Tests ───────────────────────────────────────────────────────────────────
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a simple single-channel, 4-pixel image: values [1, 2, 3, 4].
    fn make_single_channel_img() -> (Vec<f32>, usize, usize, usize) {
        let img = vec![1.0f32, 2.0, 3.0, 4.0];
        (img, 1, 2, 2) // (data, channels, h, w)
    }

    /// Population mean and standard deviation (dividing by `n`) of `values`.
    fn mean_and_std(values: &[f32]) -> (f32, f32) {
        let n = values.len() as f32;
        let mean = values.iter().sum::<f32>() / n;
        let variance = values.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
        (mean, variance.sqrt())
    }

    // ── Numerical behaviour ─────────────────────────────────────────────────

    #[test]
    fn normalized_mean_approx_zero() {
        // For a channel with values [1,2,3,4], mean=2.5, std=1.118...
        // After normalization, sample mean ≈ 0.
        let (img, channels, h, w) = make_single_channel_img();
        let (sample_mean, sample_std) = mean_and_std(&img);

        let out = normalize_chw(&img, channels, h, w, &[sample_mean], &[sample_std])
            .expect("normalize_chw ok");

        let (out_mean, _) = mean_and_std(&out);
        assert!(
            out_mean.abs() < 1e-5,
            "expected near-zero mean after normalization, got {out_mean}"
        );
    }

    #[test]
    fn normalized_std_approx_one() {
        let (img, channels, h, w) = make_single_channel_img();
        let (sample_mean, sample_std) = mean_and_std(&img);

        let out = normalize_chw(&img, channels, h, w, &[sample_mean], &[sample_std])
            .expect("normalize_chw ok");

        let (_, out_std) = mean_and_std(&out);
        assert!(
            (out_std - 1.0).abs() < 1e-5,
            "expected std ≈ 1.0 after normalization, got {out_std}"
        );
    }

    #[test]
    fn multi_channel_normalization_per_channel() {
        // 3 channels, 1×1 spatial (trivial sizes to verify math).
        // Channel c contains value (c as f32 + 1) * 10.
        let img = vec![10.0f32, 20.0, 30.0]; // 3 × 1 × 1
        let mean = [5.0f32, 15.0, 25.0];
        let std = [2.5f32, 2.5, 2.5];

        let out = normalize_chw(&img, 3, 1, 1, &mean, &std).expect("ok");
        // channel 0: (10 - 5) / 2.5 = 2.0
        // channel 1: (20 - 15) / 2.5 = 2.0
        // channel 2: (30 - 25) / 2.5 = 2.0
        assert!((out[0] - 2.0).abs() < 1e-6, "c0: {}", out[0]);
        assert!((out[1] - 2.0).abs() < 1e-6, "c1: {}", out[1]);
        assert!((out[2] - 2.0).abs() < 1e-6, "c2: {}", out[2]);
    }

    #[test]
    fn multi_channel_planes_use_their_own_statistics() {
        // 2 channels, 1×2 spatial: more than one pixel per plane, so a wrong
        // plane offset would mix values from the two channels.
        let img = vec![1.0f32, 2.0, 3.0, 4.0];
        let out = normalize_chw(&img, 2, 1, 2, &[1.0, 3.0], &[1.0, 0.5]).expect("ok");

        // channel 0: (1 - 1) / 1 = 0, (2 - 1) / 1 = 1
        // channel 1: (3 - 3) / 0.5 = 0, (4 - 3) / 0.5 = 2
        let expected = [0.0f32, 1.0, 0.0, 2.0];
        assert_eq!(out.len(), expected.len());
        for (i, (got, want)) in out.iter().zip(expected.iter()).enumerate() {
            assert!((got - want).abs() < 1e-6, "index {i}: got {got}, want {want}");
        }
    }

    // ── Argument validation ─────────────────────────────────────────────────

    #[test]
    fn error_on_zero_height() {
        let img = vec![1.0f32; 3];
        let r = normalize_chw(&img, 3, 0, 1, &[0.0, 0.0, 0.0], &[1.0, 1.0, 1.0]);
        assert!(
            matches!(r, Err(VisionError::InvalidImageSize { .. })),
            "expected InvalidImageSize, got {:?}",
            r
        );
    }

    #[test]
    fn error_on_zero_channels() {
        let img: Vec<f32> = vec![];
        let r = normalize_chw(&img, 0, 4, 4, &[], &[]);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    #[test]
    fn error_on_wrong_image_length() {
        let img = vec![1.0f32; 10]; // should be 3*4*4=48
        let r = normalize_chw(&img, 3, 4, 4, &[0.0, 0.0, 0.0], &[1.0, 1.0, 1.0]);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn error_on_mean_length_mismatch() {
        let img = vec![0.0f32; 3 * 2 * 2];
        let r = normalize_chw(&img, 3, 2, 2, &[0.0, 0.0], &[1.0, 1.0, 1.0]);
        assert!(matches!(r, Err(VisionError::ShapeMismatch { .. })));
    }

    #[test]
    fn error_on_std_length_mismatch() {
        let img = vec![0.0f32; 3 * 2 * 2];
        let r = normalize_chw(&img, 3, 2, 2, &[0.0, 0.0, 0.0], &[1.0, 1.0]);
        assert!(matches!(r, Err(VisionError::ShapeMismatch { .. })));
    }

    #[test]
    fn error_on_nonpositive_std() {
        let img = vec![1.0f32; 2 * 2];
        let r = normalize_chw(&img, 1, 2, 2, &[0.5], &[0.0]);
        assert!(matches!(r, Err(VisionError::NonFinite(_))));
    }

    #[test]
    fn error_on_negative_std() {
        let img = vec![1.0f32; 2 * 2];
        let r = normalize_chw(&img, 1, 2, 2, &[0.5], &[-0.5]);
        assert!(matches!(r, Err(VisionError::NonFinite(_))));
    }

    #[test]
    fn error_on_nan_std() {
        // NaN fails every ordering comparison, so it must be caught by the
        // finiteness check rather than by `<= 0`.
        let img = vec![1.0f32; 2 * 2];
        let r = normalize_chw(&img, 1, 2, 2, &[0.5], &[f32::NAN]);
        assert!(
            matches!(r, Err(VisionError::NonFinite(_))),
            "expected NonFinite, got {:?}",
            r
        );
    }

    #[test]
    fn error_on_infinite_std() {
        let img = vec![1.0f32; 2 * 2];
        let r = normalize_chw(&img, 1, 2, 2, &[0.5], &[f32::INFINITY]);
        assert!(
            matches!(r, Err(VisionError::NonFinite(_))),
            "expected NonFinite, got {:?}",
            r
        );
    }

    // ── ImageNet constants ──────────────────────────────────────────────────

    #[test]
    fn imagenet_constants_valid() {
        assert_eq!(IMAGENET_MEAN.len(), 3);
        assert_eq!(IMAGENET_STD.len(), 3);
        // All positive
        assert!(IMAGENET_STD.iter().all(|&v| v > 0.0));
        // Mean in [0, 1]
        assert!(IMAGENET_MEAN.iter().all(|&v| (0.0..=1.0).contains(&v)));
    }

    #[test]
    fn imagenet_normalization_output_finite() {
        // Typical ImageNet input after /255 rescaling.
        let img: Vec<f32> = (0..3 * 224 * 224)
            .map(|i| ((i % 256) as f32) / 255.0)
            .collect();
        let out = normalize_chw(&img, 3, 224, 224, &IMAGENET_MEAN, &IMAGENET_STD)
            .expect("imagenet normalize ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite after imagenet normalize"
        );
    }
}