1use crate::error::{VisionError, VisionResult};
7
/// Per-channel mean used for ImageNet-style normalization; pass it as `mean`
/// to [`normalize_chw`] for 3-channel images.
///
/// NOTE(review): these are the widely published ImageNet statistics, which are
/// conventionally listed in R, G, B order and applied to pixel values scaled to
/// `[0, 1]` — confirm callers feed channels in that order and range.
pub const IMAGENET_MEAN: [f32; 3] = [
    0.485, // channel 0
    0.456, // channel 1
    0.406, // channel 2
];

/// Per-channel standard deviation paired with [`IMAGENET_MEAN`]; pass it as
/// `std` to [`normalize_chw`]. Every entry is strictly positive, so it always
/// passes that function's `std` validation.
pub const IMAGENET_STD: [f32; 3] = [
    0.229, // channel 0
    0.224, // channel 1
    0.225, // channel 2
];
15
16pub fn normalize_chw(
35 img: &[f32],
36 channels: usize,
37 h: usize,
38 w: usize,
39 mean: &[f32],
40 std: &[f32],
41) -> VisionResult<Vec<f32>> {
42 if channels == 0 || h == 0 || w == 0 {
43 return Err(VisionError::InvalidImageSize {
44 height: h,
45 width: w,
46 channels,
47 });
48 }
49 let expected_len = channels * h * w;
50 if img.len() != expected_len {
51 return Err(VisionError::DimensionMismatch {
52 expected: expected_len,
53 got: img.len(),
54 });
55 }
56 if mean.len() != channels {
57 return Err(VisionError::ShapeMismatch {
58 lhs: vec![channels],
59 rhs: vec![mean.len()],
60 });
61 }
62 if std.len() != channels {
63 return Err(VisionError::ShapeMismatch {
64 lhs: vec![channels],
65 rhs: vec![std.len()],
66 });
67 }
68 for (c, &s) in std.iter().enumerate() {
70 if s <= 0.0 || !s.is_finite() {
71 return Err(VisionError::NonFinite(
72 if c == 0 {
75 "std[0] non-positive"
76 } else if c == 1 {
77 "std[1] non-positive"
78 } else if c == 2 {
79 "std[2] non-positive"
80 } else {
81 "std[c] non-positive"
82 },
83 ));
84 }
85 }
86
87 let hw = h * w;
88 let mut out = vec![0.0f32; expected_len];
89
90 for c in 0..channels {
91 let m = mean[c];
92 let s = std[c];
93 let inv_s = 1.0 / s;
94 let base = c * hw;
95 for i in 0..hw {
96 out[base + i] = (img[base + i] - m) * inv_s;
97 }
98 }
99
100 Ok(out)
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Population mean and standard deviation of `values` (divides by `n`,
    /// not `n - 1`), matching how the normalization statistics are derived.
    fn population_mean_std(values: &[f32]) -> (f32, f32) {
        let n = values.len() as f32;
        let mean = values.iter().sum::<f32>() / n;
        let variance = values.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
        (mean, variance.sqrt())
    }

    /// A single-channel 2x2 fixture: `(pixels, channels, h, w)`.
    fn make_single_channel_img() -> (Vec<f32>, usize, usize, usize) {
        let img = vec![1.0f32, 2.0, 3.0, 4.0];
        (img, 1, 2, 2)
    }

    /// Normalizes the single-channel fixture with its own mean and std.
    fn self_normalized_single_channel() -> Vec<f32> {
        let (img, channels, h, w) = make_single_channel_img();
        let (sample_mean, sample_std) = population_mean_std(&img);
        normalize_chw(&img, channels, h, w, &[sample_mean], &[sample_std])
            .expect("normalize_chw ok")
    }

    #[test]
    fn normalized_mean_approx_zero() {
        let out = self_normalized_single_channel();
        let (out_mean, _) = population_mean_std(&out);
        assert!(
            out_mean.abs() < 1e-5,
            "expected near-zero mean after normalization, got {out_mean}"
        );
    }

    #[test]
    fn normalized_std_approx_one() {
        let out = self_normalized_single_channel();
        let (_, out_std) = population_mean_std(&out);
        assert!(
            (out_std - 1.0).abs() < 1e-5,
            "expected std ≈ 1.0 after normalization, got {out_std}"
        );
    }

    #[test]
    fn multi_channel_normalization_per_channel() {
        let img = vec![10.0f32, 20.0, 30.0];
        let mean = [5.0f32, 15.0, 25.0];
        let std = [2.5f32, 2.5, 2.5];

        let out = normalize_chw(&img, 3, 1, 1, &mean, &std).expect("ok");
        assert!((out[0] - 2.0).abs() < 1e-6, "c0: {}", out[0]);
        assert!((out[1] - 2.0).abs() < 1e-6, "c1: {}", out[1]);
        assert!((out[2] - 2.0).abs() < 1e-6, "c2: {}", out[2]);
    }

    /// Planes wider than one pixel catch channel-offset mistakes that the
    /// 1x1 case above cannot. Power-of-two stds keep the expected values exact.
    #[test]
    fn multi_pixel_planes_use_their_own_channel_stats() {
        let img = vec![1.0f32, 3.0, 10.0, 20.0];
        let out = normalize_chw(&img, 2, 1, 2, &[1.0, 10.0], &[2.0, 4.0]).expect("ok");
        assert_eq!(out, vec![0.0, 1.0, 0.0, 2.5]);
    }

    #[test]
    fn error_on_zero_height() {
        let img = vec![1.0f32; 3];
        let r = normalize_chw(&img, 3, 0, 1, &[0.0, 0.0, 0.0], &[1.0, 1.0, 1.0]);
        assert!(
            matches!(r, Err(VisionError::InvalidImageSize { .. })),
            "expected InvalidImageSize, got {:?}",
            r
        );
    }

    #[test]
    fn error_on_zero_channels() {
        let img: Vec<f32> = vec![];
        let r = normalize_chw(&img, 0, 4, 4, &[], &[]);
        assert!(matches!(r, Err(VisionError::InvalidImageSize { .. })));
    }

    #[test]
    fn error_on_wrong_image_length() {
        let img = vec![1.0f32; 10];
        let r = normalize_chw(&img, 3, 4, 4, &[0.0, 0.0, 0.0], &[1.0, 1.0, 1.0]);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn error_on_mean_length_mismatch() {
        let img = vec![0.0f32; 3 * 2 * 2];
        let r = normalize_chw(&img, 3, 2, 2, &[0.0, 0.0], &[1.0, 1.0, 1.0]);
        assert!(matches!(r, Err(VisionError::ShapeMismatch { .. })));
    }

    #[test]
    fn error_on_std_length_mismatch() {
        let img = vec![0.0f32; 3 * 2 * 2];
        let r = normalize_chw(&img, 3, 2, 2, &[0.0, 0.0, 0.0], &[1.0, 1.0]);
        assert!(matches!(r, Err(VisionError::ShapeMismatch { .. })));
    }

    #[test]
    fn error_on_nonpositive_std() {
        let img = vec![1.0f32; 2 * 2];
        let r = normalize_chw(&img, 1, 2, 2, &[0.5], &[0.0]);
        assert!(matches!(r, Err(VisionError::NonFinite(_))));
    }

    #[test]
    fn error_on_negative_std() {
        let img = vec![1.0f32; 2 * 2];
        let r = normalize_chw(&img, 1, 2, 2, &[0.5], &[-0.5]);
        assert!(matches!(r, Err(VisionError::NonFinite(_))));
    }

    /// NaN fails `s <= 0.0`, so only the `is_finite` check rejects it.
    #[test]
    fn error_on_nan_std() {
        let img = vec![1.0f32; 2 * 2];
        let r = normalize_chw(&img, 1, 2, 2, &[0.5], &[f32::NAN]);
        assert!(matches!(r, Err(VisionError::NonFinite(_))), "got {:?}", r);
    }

    #[test]
    fn error_on_infinite_std() {
        let img = vec![1.0f32; 2 * 2];
        let r = normalize_chw(&img, 1, 2, 2, &[0.5], &[f32::INFINITY]);
        assert!(matches!(r, Err(VisionError::NonFinite(_))), "got {:?}", r);
    }

    /// Exercises the fallback branch for channels beyond index 2.
    #[test]
    fn error_on_nonpositive_std_beyond_third_channel() {
        let img = vec![1.0f32; 4];
        let r = normalize_chw(&img, 4, 1, 1, &[0.0; 4], &[1.0, 1.0, 1.0, 0.0]);
        assert!(matches!(r, Err(VisionError::NonFinite(_))), "got {:?}", r);
    }

    #[test]
    fn imagenet_constants_valid() {
        assert_eq!(IMAGENET_MEAN.len(), 3);
        assert_eq!(IMAGENET_STD.len(), 3);
        assert!(IMAGENET_STD.iter().all(|&v| v > 0.0));
        assert!(IMAGENET_MEAN.iter().all(|&v| (0.0..=1.0).contains(&v)));
    }

    #[test]
    fn imagenet_normalization_output_finite() {
        let img: Vec<f32> = (0..3 * 224 * 224)
            .map(|i| ((i % 256) as f32) / 255.0)
            .collect();
        let out = normalize_chw(&img, 3, 224, 224, &IMAGENET_MEAN, &IMAGENET_STD)
            .expect("imagenet normalize ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite after imagenet normalize"
        );
    }
}