1use crate::error::{MlError, MlResult};
49
/// Channel order of the interleaved 3-byte source pixels.
///
/// Selecting [`PixelLayout::Bgr`] makes the preprocessor swap the first and
/// third byte of every pixel so that the output is always emitted in
/// red, green, blue order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PixelLayout {
    /// Bytes are stored as red, green, blue.
    Rgb,
    /// Bytes are stored as blue, green, red.
    Bgr,
}
58
/// Memory layout of the tensor produced by the preprocessor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorLayout {
    /// Planar output: all values of the first channel, then the second,
    /// then the third (shape `[1, 3, height, width]`).
    Nchw,
    /// Interleaved output: the three channel values of each pixel are
    /// adjacent (shape `[1, height, width, 3]`).
    Nhwc,
}
67
/// Selects how a source byte is converted to `f32` before mean/std
/// normalization is applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputRange {
    /// The byte is divided by `255.0`, giving a value in `0.0..=1.0`.
    U8,
    /// The byte is converted to `f32` as-is, without any rescaling.
    // NOTE(review): the variant name suggests data already in the unit
    // range, but the conversion itself performs no scaling — confirm intent.
    UnitFloat,
}
76
77#[derive(Clone, Debug)]
83pub struct ImagePreprocessor {
84 target_width: u32,
85 target_height: u32,
86 pixel_layout: PixelLayout,
87 tensor_layout: TensorLayout,
88 input_range: InputRange,
89 mean: [f32; 3],
90 std: [f32; 3],
91 swap_to_rgb: bool,
92}
93
94impl ImagePreprocessor {
95 #[must_use]
97 pub fn new(target_width: u32, target_height: u32) -> Self {
98 Self {
99 target_width,
100 target_height,
101 pixel_layout: PixelLayout::Rgb,
102 tensor_layout: TensorLayout::Nchw,
103 input_range: InputRange::U8,
104 mean: [0.0, 0.0, 0.0],
105 std: [1.0, 1.0, 1.0],
106 swap_to_rgb: false,
107 }
108 }
109
110 #[must_use]
112 pub fn with_pixel_layout(mut self, layout: PixelLayout) -> Self {
113 self.pixel_layout = layout;
114 self.swap_to_rgb = layout == PixelLayout::Bgr;
115 self
116 }
117
118 #[must_use]
120 pub fn with_tensor_layout(mut self, layout: TensorLayout) -> Self {
121 self.tensor_layout = layout;
122 self
123 }
124
125 #[must_use]
127 pub fn with_input_range(mut self, range: InputRange) -> Self {
128 self.input_range = range;
129 self
130 }
131
132 #[must_use]
134 pub fn with_mean(mut self, mean: [f32; 3]) -> Self {
135 self.mean = mean;
136 self
137 }
138
139 #[must_use]
141 pub fn with_std(mut self, std: [f32; 3]) -> Self {
142 self.std = std;
143 self
144 }
145
146 #[must_use]
148 pub fn with_imagenet_normalization(self) -> Self {
149 self.with_mean([0.485, 0.456, 0.406])
150 .with_std([0.229, 0.224, 0.225])
151 }
152
153 #[must_use]
155 pub fn target_width(&self) -> u32 {
156 self.target_width
157 }
158
159 #[must_use]
161 pub fn target_height(&self) -> u32 {
162 self.target_height
163 }
164
165 pub fn process_u8_rgb(&self, pixels: &[u8], src_w: u32, src_h: u32) -> MlResult<Vec<f32>> {
178 let expected = (src_w as usize) * (src_h as usize) * 3;
179 if pixels.len() != expected {
180 return Err(MlError::preprocess(format!(
181 "expected {expected} bytes for {src_w}x{src_h} RGB, got {}",
182 pixels.len()
183 )));
184 }
185 if src_w == 0 || src_h == 0 {
186 return Err(MlError::preprocess("source image has zero extent"));
187 }
188 if self.target_width == 0 || self.target_height == 0 {
189 return Err(MlError::preprocess("target size has zero extent"));
190 }
191
192 let tw = self.target_width as usize;
193 let th = self.target_height as usize;
194 let mut out = vec![0.0_f32; tw * th * 3];
195
196 let x_ratio = (src_w as f32) / (self.target_width as f32);
197 let y_ratio = (src_h as f32) / (self.target_height as f32);
198
199 for y in 0..th {
200 let src_y = ((y as f32) * y_ratio) as usize;
201 let src_y = src_y.min((src_h as usize).saturating_sub(1));
202 for x in 0..tw {
203 let src_x = ((x as f32) * x_ratio) as usize;
204 let src_x = src_x.min((src_w as usize).saturating_sub(1));
205 let src_idx = (src_y * (src_w as usize) + src_x) * 3;
206 let (r_src, g_src, b_src) =
207 (pixels[src_idx], pixels[src_idx + 1], pixels[src_idx + 2]);
208 let (r_raw, g_raw, b_raw) = if self.swap_to_rgb {
209 (b_src, g_src, r_src)
210 } else {
211 (r_src, g_src, b_src)
212 };
213
214 let (r, g, b) = match self.input_range {
215 InputRange::U8 => (
216 (r_raw as f32) / 255.0,
217 (g_raw as f32) / 255.0,
218 (b_raw as f32) / 255.0,
219 ),
220 InputRange::UnitFloat => (r_raw as f32, g_raw as f32, b_raw as f32),
221 };
222
223 let r = (r - self.mean[0]) / self.std[0];
224 let g = (g - self.mean[1]) / self.std[1];
225 let b = (b - self.mean[2]) / self.std[2];
226
227 match self.tensor_layout {
228 TensorLayout::Nhwc => {
229 let dst = (y * tw + x) * 3;
230 out[dst] = r;
231 out[dst + 1] = g;
232 out[dst + 2] = b;
233 }
234 TensorLayout::Nchw => {
235 let plane = tw * th;
236 let pixel = y * tw + x;
237 out[pixel] = r;
238 out[plane + pixel] = g;
239 out[(plane * 2) + pixel] = b;
240 }
241 }
242 }
243 }
244
245 Ok(out)
246 }
247
248 #[must_use]
252 pub fn batch_shape(&self) -> Vec<usize> {
253 let tw = self.target_width as usize;
254 let th = self.target_height as usize;
255 match self.tensor_layout {
256 TensorLayout::Nchw => vec![1, 3, th, tw],
257 TensorLayout::Nhwc => vec![1, th, tw, 3],
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` tensor values.
    const TOLERANCE: f32 = 1e-5;

    /// Asserts that each value in `actual` is within [`TOLERANCE`] of the
    /// value at the same position in `expected`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        for (got, want) in actual.iter().zip(expected) {
            assert!((got - want).abs() < TOLERANCE);
        }
    }

    /// A freshly built preprocessor reports its target size and defaults to
    /// the NCHW shape.
    #[test]
    fn builder_defaults() {
        let preprocessor = ImagePreprocessor::new(224, 224);
        assert_eq!(preprocessor.target_width(), 224);
        assert_eq!(preprocessor.target_height(), 224);
        assert_eq!(preprocessor.batch_shape(), vec![1, 3, 224, 224]);
    }

    /// NHWC puts height before width and channels last.
    #[test]
    fn nhwc_batch_shape() {
        let preprocessor = ImagePreprocessor::new(64, 32).with_tensor_layout(TensorLayout::Nhwc);
        assert_eq!(preprocessor.batch_shape(), vec![1, 32, 64, 3]);
    }

    /// A buffer whose length does not match the declared dimensions is
    /// rejected.
    #[test]
    fn mismatched_buffer_errors() {
        let preprocessor = ImagePreprocessor::new(4, 4);
        let too_short = vec![0u8; 10];
        let err = preprocessor
            .process_u8_rgb(&too_short, 2, 2)
            .expect_err("must fail");
        assert!(matches!(err, MlError::Preprocess(_)));
    }

    /// A target with a zero dimension is rejected even for valid input.
    #[test]
    fn zero_target_errors() {
        let preprocessor = ImagePreprocessor::new(0, 4);
        let source = vec![0u8; 4 * 4 * 3];
        let err = preprocessor
            .process_u8_rgb(&source, 4, 4)
            .expect_err("must fail");
        assert!(matches!(err, MlError::Preprocess(_)));
    }

    /// A white pixel is scaled to 1.0 and then normalized per channel.
    #[test]
    fn imagenet_white_pixel_is_normalized() {
        let preprocessor = ImagePreprocessor::new(1, 1).with_imagenet_normalization();
        let white = vec![255u8, 255u8, 255u8];
        let out = preprocessor.process_u8_rgb(&white, 1, 1).expect("ok");
        assert_eq!(out.len(), 3);
        let expected: [f32; 3] = [
            (1.0 - 0.485) / 0.229,
            (1.0 - 0.456) / 0.224,
            (1.0 - 0.406) / 0.225,
        ];
        assert_close(&out[..3], &expected);
    }

    /// BGR input comes out with the first and third channel exchanged.
    #[test]
    fn bgr_swaps_to_rgb() {
        let preprocessor = ImagePreprocessor::new(1, 1)
            .with_pixel_layout(PixelLayout::Bgr)
            .with_input_range(InputRange::U8);
        let bgr = vec![10u8, 20u8, 30u8];
        let out = preprocessor.process_u8_rgb(&bgr, 1, 1).expect("ok");
        let expected: [f32; 3] = [30.0 / 255.0, 20.0 / 255.0, 10.0 / 255.0];
        assert_close(&out[..3], &expected);
    }

    /// NCHW stores each channel as a contiguous plane.
    #[test]
    fn nchw_layout_plane_major() {
        let preprocessor = ImagePreprocessor::new(2, 1).with_input_range(InputRange::UnitFloat);
        let source = vec![25u8, 51, 76, 102, 128, 153];
        let out = preprocessor.process_u8_rgb(&source, 2, 1).expect("ok");
        assert_eq!(out.len(), 6);
        // Red plane, then green plane, then blue plane.
        let expected: [f32; 6] = [25.0, 102.0, 51.0, 128.0, 76.0, 153.0];
        assert_close(&out[..6], &expected);
    }
}