Skip to main content

oximedia_ml/
preprocess.rs

1//! Image preprocessing for ML inference.
2//!
//! [`ImagePreprocessor`] is a builder for converting raw interleaved pixel
//! data (`u8` samples in RGB or BGR order) into a normalised `f32` tensor in either NCHW
5//! or NHWC layout. The design is deliberately minimal — it covers the
6//! overwhelmingly common case of ImageNet-style classifiers and shot-
7//! boundary detectors, where input is "u8 RGB, scaled to a fixed size,
8//! normalised by per-channel mean/std".
9//!
//! Scaling is done with nearest-neighbour (pure Rust, zero deps); richer
11//! interpolation is intentionally left to `oximedia-scaling` so the
12//! default build stays small.
13//!
14//! ## Builder flow
15//!
16//! 1. Start with [`ImagePreprocessor::new`] to fix the output size.
17//! 2. Chain [`with_pixel_layout`](ImagePreprocessor::with_pixel_layout),
18//!    [`with_tensor_layout`](ImagePreprocessor::with_tensor_layout), and
19//!    [`with_input_range`](ImagePreprocessor::with_input_range) to match
20//!    the source data.
21//! 3. Apply per-channel normalisation via
22//!    [`with_mean`](ImagePreprocessor::with_mean) /
23//!    [`with_std`](ImagePreprocessor::with_std), or use the
24//!    [`with_imagenet_normalization`](ImagePreprocessor::with_imagenet_normalization)
25//!    shortcut.
26//! 4. Call [`process_u8_rgb`](ImagePreprocessor::process_u8_rgb) to get
27//!    a flattened `Vec<f32>` ready to feed into ONNX.
28//!
29//! ## Example
30//!
31//! ```
32//! use oximedia_ml::{ImagePreprocessor, TensorLayout};
33//!
34//! # fn main() -> oximedia_ml::MlResult<()> {
35//! let preproc = ImagePreprocessor::new(224, 224)
36//!     .with_tensor_layout(TensorLayout::Nchw)
37//!     .with_imagenet_normalization();
38//!
39//! assert_eq!(preproc.batch_shape(), vec![1, 3, 224, 224]);
40//!
41//! let white = vec![255_u8; 224 * 224 * 3];
42//! let tensor = preproc.process_u8_rgb(&white, 224, 224)?;
43//! assert_eq!(tensor.len(), 3 * 224 * 224);
44//! # Ok(())
45//! # }
46//! ```
47
48use crate::error::{MlError, MlResult};
49
/// Channel order of the interleaved source pixels.
///
/// [`PixelLayout::default`] is [`PixelLayout::Rgb`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum PixelLayout {
    /// Red, Green, Blue channel order (default).
    #[default]
    Rgb,
    /// Blue, Green, Red channel order (OpenCV convention).
    Bgr,
}
58
/// Memory layout of the output tensor.
///
/// [`TensorLayout::default`] is [`TensorLayout::Nchw`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum TensorLayout {
    /// Batch, Channel, Height, Width (default).
    #[default]
    Nchw,
    /// Batch, Height, Width, Channel.
    Nhwc,
}
67
/// Scalar range of the source image samples.
///
/// [`InputRange::default`] is [`InputRange::U8`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum InputRange {
    /// Values are in `[0, 255]` u8 range; they are divided by 255 before
    /// normalisation (default).
    #[default]
    U8,
    /// Values are already in `[0.0, 1.0]` f32 range; no 1/255 scaling is
    /// applied before normalisation.
    UnitFloat,
}
76
77/// Builder for an image preprocessing pipeline.
78///
79/// See the [module-level docs][self] for the intended flow. An instance
80/// is cheap to clone so a single pre-built preprocessor can be shared
81/// across threads.
82#[derive(Clone, Debug)]
83pub struct ImagePreprocessor {
84    target_width: u32,
85    target_height: u32,
86    pixel_layout: PixelLayout,
87    tensor_layout: TensorLayout,
88    input_range: InputRange,
89    mean: [f32; 3],
90    std: [f32; 3],
91    swap_to_rgb: bool,
92}
93
94impl ImagePreprocessor {
95    /// Create a new preprocessor for a given output size.
96    #[must_use]
97    pub fn new(target_width: u32, target_height: u32) -> Self {
98        Self {
99            target_width,
100            target_height,
101            pixel_layout: PixelLayout::Rgb,
102            tensor_layout: TensorLayout::Nchw,
103            input_range: InputRange::U8,
104            mean: [0.0, 0.0, 0.0],
105            std: [1.0, 1.0, 1.0],
106            swap_to_rgb: false,
107        }
108    }
109
110    /// Set the input pixel layout. Default is [`PixelLayout::Rgb`].
111    #[must_use]
112    pub fn with_pixel_layout(mut self, layout: PixelLayout) -> Self {
113        self.pixel_layout = layout;
114        self.swap_to_rgb = layout == PixelLayout::Bgr;
115        self
116    }
117
118    /// Set the output tensor layout. Default is [`TensorLayout::Nchw`].
119    #[must_use]
120    pub fn with_tensor_layout(mut self, layout: TensorLayout) -> Self {
121        self.tensor_layout = layout;
122        self
123    }
124
125    /// Set the scalar range of the source image. Default is [`InputRange::U8`].
126    #[must_use]
127    pub fn with_input_range(mut self, range: InputRange) -> Self {
128        self.input_range = range;
129        self
130    }
131
132    /// Set the per-channel normalisation mean.
133    #[must_use]
134    pub fn with_mean(mut self, mean: [f32; 3]) -> Self {
135        self.mean = mean;
136        self
137    }
138
139    /// Set the per-channel normalisation std-dev.
140    #[must_use]
141    pub fn with_std(mut self, std: [f32; 3]) -> Self {
142        self.std = std;
143        self
144    }
145
146    /// Apply ImageNet mean `[0.485, 0.456, 0.406]` and std `[0.229, 0.224, 0.225]`.
147    #[must_use]
148    pub fn with_imagenet_normalization(self) -> Self {
149        self.with_mean([0.485, 0.456, 0.406])
150            .with_std([0.229, 0.224, 0.225])
151    }
152
153    /// Target width in pixels.
154    #[must_use]
155    pub fn target_width(&self) -> u32 {
156        self.target_width
157    }
158
159    /// Target height in pixels.
160    #[must_use]
161    pub fn target_height(&self) -> u32 {
162        self.target_height
163    }
164
165    /// Process a raw u8 RGB(3-channel) buffer of size `src_w * src_h * 3`.
166    ///
167    /// The output is a flattened `Vec<f32>` containing a single image
168    /// (no batch dimension). Call [`ImagePreprocessor::batch_shape`] to
169    /// learn the logical shape.
170    ///
171    /// # Errors
172    ///
173    /// Returns [`MlError::Preprocess`] if:
174    ///
175    /// * `pixels.len() != src_w * src_h * 3`,
176    /// * either dimension (source or target) is zero.
177    pub fn process_u8_rgb(&self, pixels: &[u8], src_w: u32, src_h: u32) -> MlResult<Vec<f32>> {
178        let expected = (src_w as usize) * (src_h as usize) * 3;
179        if pixels.len() != expected {
180            return Err(MlError::preprocess(format!(
181                "expected {expected} bytes for {src_w}x{src_h} RGB, got {}",
182                pixels.len()
183            )));
184        }
185        if src_w == 0 || src_h == 0 {
186            return Err(MlError::preprocess("source image has zero extent"));
187        }
188        if self.target_width == 0 || self.target_height == 0 {
189            return Err(MlError::preprocess("target size has zero extent"));
190        }
191
192        let tw = self.target_width as usize;
193        let th = self.target_height as usize;
194        let mut out = vec![0.0_f32; tw * th * 3];
195
196        let x_ratio = (src_w as f32) / (self.target_width as f32);
197        let y_ratio = (src_h as f32) / (self.target_height as f32);
198
199        for y in 0..th {
200            let src_y = ((y as f32) * y_ratio) as usize;
201            let src_y = src_y.min((src_h as usize).saturating_sub(1));
202            for x in 0..tw {
203                let src_x = ((x as f32) * x_ratio) as usize;
204                let src_x = src_x.min((src_w as usize).saturating_sub(1));
205                let src_idx = (src_y * (src_w as usize) + src_x) * 3;
206                let (r_src, g_src, b_src) =
207                    (pixels[src_idx], pixels[src_idx + 1], pixels[src_idx + 2]);
208                let (r_raw, g_raw, b_raw) = if self.swap_to_rgb {
209                    (b_src, g_src, r_src)
210                } else {
211                    (r_src, g_src, b_src)
212                };
213
214                let (r, g, b) = match self.input_range {
215                    InputRange::U8 => (
216                        (r_raw as f32) / 255.0,
217                        (g_raw as f32) / 255.0,
218                        (b_raw as f32) / 255.0,
219                    ),
220                    InputRange::UnitFloat => (r_raw as f32, g_raw as f32, b_raw as f32),
221                };
222
223                let r = (r - self.mean[0]) / self.std[0];
224                let g = (g - self.mean[1]) / self.std[1];
225                let b = (b - self.mean[2]) / self.std[2];
226
227                match self.tensor_layout {
228                    TensorLayout::Nhwc => {
229                        let dst = (y * tw + x) * 3;
230                        out[dst] = r;
231                        out[dst + 1] = g;
232                        out[dst + 2] = b;
233                    }
234                    TensorLayout::Nchw => {
235                        let plane = tw * th;
236                        let pixel = y * tw + x;
237                        out[pixel] = r;
238                        out[plane + pixel] = g;
239                        out[(plane * 2) + pixel] = b;
240                    }
241                }
242            }
243        }
244
245        Ok(out)
246    }
247
248    /// Return the logical shape of the output tensor with a leading
249    /// batch dim of 1. Matches the flat buffer returned by
250    /// [`ImagePreprocessor::process_u8_rgb`].
251    #[must_use]
252    pub fn batch_shape(&self) -> Vec<usize> {
253        let tw = self.target_width as usize;
254        let th = self.target_height as usize;
255        match self.tensor_layout {
256            TensorLayout::Nchw => vec![1, 3, th, tw],
257            TensorLayout::Nhwc => vec![1, th, tw, 3],
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every `f32` comparison below.
    const EPS: f32 = 1e-5;

    /// Assert that two `f32` values agree to within [`EPS`].
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn builder_defaults() {
        let p = ImagePreprocessor::new(224, 224);
        assert_eq!(p.target_width(), 224);
        assert_eq!(p.target_height(), 224);
        assert_eq!(p.batch_shape(), vec![1, 3, 224, 224]);
    }

    #[test]
    fn nhwc_batch_shape() {
        let p = ImagePreprocessor::new(64, 32).with_tensor_layout(TensorLayout::Nhwc);
        assert_eq!(p.batch_shape(), vec![1, 32, 64, 3]);
    }

    #[test]
    fn mismatched_buffer_errors() {
        let p = ImagePreprocessor::new(4, 4);
        let pixels = vec![0u8; 10];
        let err = p.process_u8_rgb(&pixels, 2, 2).expect_err("must fail");
        assert!(matches!(err, MlError::Preprocess(_)));
    }

    #[test]
    fn zero_target_errors() {
        let p = ImagePreprocessor::new(0, 4);
        let pixels = vec![0u8; 4 * 4 * 3];
        let err = p.process_u8_rgb(&pixels, 4, 4).expect_err("must fail");
        assert!(matches!(err, MlError::Preprocess(_)));
    }

    #[test]
    fn zero_source_errors() {
        // A 0×4 source legitimately has an empty buffer, so the length check
        // passes and the zero-extent check must be the one that rejects it.
        let p = ImagePreprocessor::new(4, 4);
        let err = p.process_u8_rgb(&[], 0, 4).expect_err("must fail");
        assert!(matches!(err, MlError::Preprocess(_)));
    }

    #[test]
    fn imagenet_white_pixel_is_normalized() {
        // white pixel, 1×1 input, 1×1 target, ImageNet normalisation.
        let p = ImagePreprocessor::new(1, 1).with_imagenet_normalization();
        let pixels = vec![255u8, 255u8, 255u8];
        let out = p.process_u8_rgb(&pixels, 1, 1).expect("ok");
        assert_eq!(out.len(), 3);
        assert_close(out[0], (1.0 - 0.485) / 0.229);
        assert_close(out[1], (1.0 - 0.456) / 0.224);
        assert_close(out[2], (1.0 - 0.406) / 0.225);
    }

    #[test]
    fn bgr_swaps_to_rgb() {
        let p = ImagePreprocessor::new(1, 1)
            .with_pixel_layout(PixelLayout::Bgr)
            .with_input_range(InputRange::U8);
        let pixels = vec![10u8, 20u8, 30u8];
        let out = p.process_u8_rgb(&pixels, 1, 1).expect("ok");
        // BGR→RGB swap => R=30/255, G=20/255, B=10/255 (no mean/std).
        assert_close(out[0], 30.0 / 255.0);
        assert_close(out[1], 20.0 / 255.0);
        assert_close(out[2], 10.0 / 255.0);
    }

    #[test]
    fn nchw_layout_plane_major() {
        // `InputRange::UnitFloat` disables the 1/255 scaling, so each u8
        // sample comes out as the same number in f32. That keeps the expected
        // values exact and lets the test focus purely on element ordering.
        let p = ImagePreprocessor::new(2, 1).with_input_range(InputRange::UnitFloat);
        // 2×1 image. Pixel 0: (25, 51, 76); pixel 1: (102, 128, 153).
        let pixels = vec![25u8, 51, 76, 102, 128, 153];
        let out = p.process_u8_rgb(&pixels, 2, 1).expect("ok");
        // Two pixels × three channels.
        assert_eq!(out.len(), 2 * 3);
        // Plane 0 = R: [25, 102]
        assert_close(out[0], 25.0);
        assert_close(out[1], 102.0);
        // Plane 1 = G: [51, 128]
        assert_close(out[2], 51.0);
        assert_close(out[3], 128.0);
        // Plane 2 = B: [76, 153]
        assert_close(out[4], 76.0);
        assert_close(out[5], 153.0);
    }

    #[test]
    fn nhwc_layout_is_interleaved() {
        // With identical source and target sizes and no mean/std, NHWC output
        // is the input samples in their original order, each divided by 255.
        let p = ImagePreprocessor::new(2, 1).with_tensor_layout(TensorLayout::Nhwc);
        let pixels = vec![0u8, 51, 102, 153, 204, 255];
        let out = p.process_u8_rgb(&pixels, 2, 1).expect("ok");
        assert_eq!(out.len(), pixels.len());
        for (actual, raw) in out.iter().zip(&pixels) {
            assert_close(*actual, f32::from(*raw) / 255.0);
        }
    }

    #[test]
    fn nearest_neighbour_downscale_picks_strided_columns() {
        // 4×1 source → 2×1 target: the ratio is exactly 2.0, so target
        // columns 0 and 1 sample source columns 0 and 2.
        let p = ImagePreprocessor::new(2, 1);
        let pixels = vec![10u8, 0, 0, 20, 0, 0, 30, 0, 0, 40, 0, 0];
        let out = p.process_u8_rgb(&pixels, 4, 1).expect("ok");
        assert_eq!(out.len(), 2 * 3);
        // NCHW: the R plane comes first.
        assert_close(out[0], 10.0 / 255.0);
        assert_close(out[1], 30.0 / 255.0);
    }
}