use crate::error::{MlError, MlResult};
/// Channel order of the interleaved 8-bit source pixels.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PixelLayout {
    /// Bytes are ordered red, green, blue.
    Rgb,
    /// Bytes are ordered blue, green, red; they are reordered to RGB during
    /// preprocessing.
    Bgr,
}
/// Memory layout of the produced tensor (the batch dimension is always 1).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TensorLayout {
    /// Channel-planar layout with shape `[1, 3, height, width]`.
    Nchw,
    /// Channel-interleaved layout with shape `[1, height, width, 3]`.
    Nhwc,
}
/// How raw pixel bytes are scaled before mean/std normalization.
///
/// NOTE(review): the variant names read as if they described the range the
/// model expects, yet `UnitFloat` leaves values in `0..=255` while `U8` maps
/// them onto `[0, 1]`. Confirm the intended semantics before relying on the
/// names.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InputRange {
    /// Each byte is divided by 255, mapping `0..=255` onto `[0, 1]`.
    U8,
    /// Each byte is cast to `f32` without scaling (values stay in `0..=255`).
    UnitFloat,
}
/// Configuration for turning interleaved 8-bit RGB/BGR images into normalized
/// `f32` tensors of a fixed spatial size.
///
/// Build one with [`ImagePreprocessor::new`] and refine it through the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ImagePreprocessor {
    /// Output width in pixels.
    target_width: u32,
    /// Output height in pixels.
    target_height: u32,
    /// Channel order of the incoming pixel bytes.
    pixel_layout: PixelLayout,
    /// Memory layout of the produced tensor.
    tensor_layout: TensorLayout,
    /// How raw byte values are scaled before normalization.
    input_range: InputRange,
    /// Per-channel value subtracted after scaling.
    mean: [f32; 3],
    /// Per-channel divisor applied after mean subtraction.
    std: [f32; 3],
    /// Cached `pixel_layout == PixelLayout::Bgr`; when set, the first and
    /// third byte of every pixel are swapped.
    swap_to_rgb: bool,
}
impl ImagePreprocessor {
    /// Creates a preprocessor that resizes images to `target_width` x `target_height`.
    ///
    /// Defaults: RGB source pixels, NCHW output, [`InputRange::U8`] scaling
    /// (bytes divided by 255), mean `[0, 0, 0]` and std `[1, 1, 1]`, which
    /// together apply no per-channel normalization.
    #[must_use]
    pub fn new(target_width: u32, target_height: u32) -> Self {
        Self {
            target_width,
            target_height,
            pixel_layout: PixelLayout::Rgb,
            tensor_layout: TensorLayout::Nchw,
            input_range: InputRange::U8,
            mean: [0.0, 0.0, 0.0],
            std: [1.0, 1.0, 1.0],
            swap_to_rgb: false,
        }
    }

    /// Sets the channel order of the source bytes.
    ///
    /// With [`PixelLayout::Bgr`], the first and third byte of every pixel are
    /// swapped during processing, so output channels are always R, G, B.
    #[must_use]
    pub fn with_pixel_layout(mut self, layout: PixelLayout) -> Self {
        self.pixel_layout = layout;
        // Cached so the per-pixel path does not re-derive it.
        self.swap_to_rgb = layout == PixelLayout::Bgr;
        self
    }

    /// Sets the memory layout of the produced tensor (see [`Self::batch_shape`]).
    #[must_use]
    pub fn with_tensor_layout(mut self, layout: TensorLayout) -> Self {
        self.tensor_layout = layout;
        self
    }

    /// Sets how raw byte values are scaled before mean/std normalization.
    #[must_use]
    pub fn with_input_range(mut self, range: InputRange) -> Self {
        self.input_range = range;
        self
    }

    /// Sets the per-channel (R, G, B) mean subtracted after scaling.
    #[must_use]
    pub fn with_mean(mut self, mean: [f32; 3]) -> Self {
        self.mean = mean;
        self
    }

    /// Sets the per-channel (R, G, B) divisor applied after mean subtraction.
    ///
    /// Every component must be finite and non-zero; otherwise
    /// [`Self::process_u8_rgb`] returns an error.
    #[must_use]
    pub fn with_std(mut self, std: [f32; 3]) -> Self {
        self.std = std;
        self
    }

    /// Applies the ImageNet per-channel mean/std constants.
    ///
    /// Pair with the default [`InputRange::U8`] scaling, which maps bytes into
    /// `[0, 1]` before these constants are applied.
    #[must_use]
    pub fn with_imagenet_normalization(self) -> Self {
        self.with_mean([0.485, 0.456, 0.406])
            .with_std([0.229, 0.224, 0.225])
    }

    /// Output width in pixels.
    #[must_use]
    pub fn target_width(&self) -> u32 {
        self.target_width
    }

    /// Output height in pixels.
    #[must_use]
    pub fn target_height(&self) -> u32 {
        self.target_height
    }

    /// Channel order expected for the source bytes.
    #[must_use]
    pub fn pixel_layout(&self) -> PixelLayout {
        self.pixel_layout
    }

    /// Memory layout of the produced tensor.
    #[must_use]
    pub fn tensor_layout(&self) -> TensorLayout {
        self.tensor_layout
    }

    /// Scaling applied to raw bytes before normalization.
    #[must_use]
    pub fn input_range(&self) -> InputRange {
        self.input_range
    }

    /// Resizes an interleaved 8-bit image with nearest-neighbour sampling and
    /// converts it into a normalized `f32` tensor shaped as [`Self::batch_shape`].
    ///
    /// `pixels` must hold exactly `src_w * src_h * 3` bytes in row-major,
    /// channel-interleaved order. Each output pixel samples the source pixel at
    /// `(floor(x * src_w / target_width), floor(y * src_h / target_height))`.
    /// Each channel value `v` is then scaled per [`InputRange`] and normalized
    /// as `(v - mean[c]) / std[c]`.
    ///
    /// # Errors
    ///
    /// Returns a preprocess error in any of these cases:
    /// - the source or target dimensions overflow `usize`;
    /// - `pixels.len()` does not match the declared source size;
    /// - either the source or the target size has a zero extent;
    /// - any `std` component is zero or non-finite, which would otherwise
    ///   yield inf/NaN outputs.
    pub fn process_u8_rgb(&self, pixels: &[u8], src_w: u32, src_h: u32) -> MlResult<Vec<f32>> {
        // Checked: u32::MAX * u32::MAX * 3 exceeds even a 64-bit usize.
        let expected = Self::rgb_len(src_w, src_h).ok_or_else(|| {
            MlError::preprocess(format!("source size {src_w}x{src_h} overflows usize"))
        })?;
        if pixels.len() != expected {
            return Err(MlError::preprocess(format!(
                "expected {expected} bytes for {src_w}x{src_h} RGB, got {}",
                pixels.len()
            )));
        }
        if src_w == 0 || src_h == 0 {
            return Err(MlError::preprocess("source image has zero extent"));
        }
        if self.target_width == 0 || self.target_height == 0 {
            return Err(MlError::preprocess("target size has zero extent"));
        }
        let out_len = Self::rgb_len(self.target_width, self.target_height).ok_or_else(|| {
            MlError::preprocess(format!(
                "target size {}x{} overflows usize",
                self.target_width, self.target_height
            ))
        })?;
        if self.std.iter().any(|s| *s == 0.0 || !s.is_finite()) {
            return Err(MlError::preprocess(format!(
                "std components must be finite and non-zero, got {:?}",
                self.std
            )));
        }

        let src_w = src_w as usize;
        let src_h = src_h as usize;
        let tw = self.target_width as usize;
        let th = self.target_height as usize;
        // Cannot overflow: `out_len == plane * 3` was computed with checked math.
        let plane = tw * th;
        let mut out = vec![0.0_f32; out_len];

        // The source column for a given output column is identical on every row.
        let src_cols: Vec<usize> = (0..tw)
            .map(|x| Self::nearest_source_index(x, tw, src_w))
            .collect();

        for y in 0..th {
            let row_base = Self::nearest_source_index(y, th, src_h) * src_w;
            for (x, &src_x) in src_cols.iter().enumerate() {
                let src_idx = (row_base + src_x) * 3;
                let [r, g, b] = self.normalize_pixel(&pixels[src_idx..src_idx + 3]);
                let pixel = y * tw + x;
                match self.tensor_layout {
                    TensorLayout::Nhwc => {
                        out[pixel * 3..pixel * 3 + 3].copy_from_slice(&[r, g, b]);
                    }
                    TensorLayout::Nchw => {
                        out[pixel] = r;
                        out[plane + pixel] = g;
                        out[2 * plane + pixel] = b;
                    }
                }
            }
        }
        Ok(out)
    }

    /// Shape of the tensor produced by [`Self::process_u8_rgb`], with a
    /// leading batch dimension of 1.
    #[must_use]
    pub fn batch_shape(&self) -> Vec<usize> {
        let tw = self.target_width as usize;
        let th = self.target_height as usize;
        match self.tensor_layout {
            TensorLayout::Nchw => vec![1, 3, th, tw],
            TensorLayout::Nhwc => vec![1, th, tw, 3],
        }
    }

    /// Element count of a `width x height` three-channel image, or `None` if
    /// it does not fit in `usize`.
    fn rgb_len(width: u32, height: u32) -> Option<usize> {
        (width as usize)
            .checked_mul(height as usize)?
            .checked_mul(3)
    }

    /// Maps destination index `dst` (which must be `< dst_len`) to the
    /// nearest-neighbour source index `floor(dst * src_len / dst_len)`.
    ///
    /// Uses exact integer arithmetic; the `f32` ratio it replaces loses
    /// precision for large coordinates. Because `dst < dst_len`, the result
    /// is always `< src_len`.
    fn nearest_source_index(dst: usize, dst_len: usize, src_len: usize) -> usize {
        // Both operands originate from u32 values, so the product fits in u64.
        (((dst as u64) * (src_len as u64)) / (dst_len as u64)) as usize
    }

    /// Converts one 3-byte source pixel into normalized R, G, B values.
    ///
    /// The steps are: apply the BGR->RGB swap if configured, scale according
    /// to the input range, then apply `(v - mean[c]) / std[c]`.
    fn normalize_pixel(&self, px: &[u8]) -> [f32; 3] {
        let (r, g, b) = if self.swap_to_rgb {
            (px[2], px[1], px[0])
        } else {
            (px[0], px[1], px[2])
        };
        let scale = |v: u8| -> f32 {
            match self.input_range {
                InputRange::U8 => f32::from(v) / 255.0,
                InputRange::UnitFloat => f32::from(v),
            }
        };
        [
            (scale(r) - self.mean[0]) / self.std[0],
            (scale(g) - self.mean[1]) / self.std[1],
            (scale(b) - self.mean[2]) / self.std[2],
        ]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree within the tolerance used throughout these tests.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-5,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn builder_defaults() {
        let p = ImagePreprocessor::new(224, 224);
        assert_eq!(p.target_width(), 224);
        assert_eq!(p.target_height(), 224);
        assert_eq!(p.batch_shape(), vec![1, 3, 224, 224]);
    }

    #[test]
    fn nhwc_batch_shape() {
        let p = ImagePreprocessor::new(64, 32).with_tensor_layout(TensorLayout::Nhwc);
        assert_eq!(p.batch_shape(), vec![1, 32, 64, 3]);
    }

    #[test]
    fn mismatched_buffer_errors() {
        let p = ImagePreprocessor::new(4, 4);
        let pixels = [0u8; 10];
        let err = p.process_u8_rgb(&pixels, 2, 2).expect_err("must fail");
        assert!(matches!(err, MlError::Preprocess(_)));
    }

    #[test]
    fn zero_source_errors() {
        let p = ImagePreprocessor::new(4, 4);
        let err = p.process_u8_rgb(&[], 0, 0).expect_err("must fail");
        assert!(matches!(err, MlError::Preprocess(_)));
    }

    #[test]
    fn zero_target_errors() {
        let p = ImagePreprocessor::new(0, 4);
        let pixels = [0u8; 4 * 4 * 3];
        let err = p.process_u8_rgb(&pixels, 4, 4).expect_err("must fail");
        assert!(matches!(err, MlError::Preprocess(_)));
    }

    #[test]
    fn imagenet_white_pixel_is_normalized() {
        let p = ImagePreprocessor::new(1, 1).with_imagenet_normalization();
        let out = p.process_u8_rgb(&[255u8, 255, 255], 1, 1).expect("ok");
        assert_eq!(out.len(), 3);
        assert_close(out[0], (1.0 - 0.485) / 0.229);
        assert_close(out[1], (1.0 - 0.456) / 0.224);
        assert_close(out[2], (1.0 - 0.406) / 0.225);
    }

    #[test]
    fn bgr_swaps_to_rgb() {
        let p = ImagePreprocessor::new(1, 1)
            .with_pixel_layout(PixelLayout::Bgr)
            .with_input_range(InputRange::U8);
        let out = p.process_u8_rgb(&[10u8, 20, 30], 1, 1).expect("ok");
        assert_close(out[0], 30.0 / 255.0);
        assert_close(out[1], 20.0 / 255.0);
        assert_close(out[2], 10.0 / 255.0);
    }

    #[test]
    fn nchw_layout_plane_major() {
        let p = ImagePreprocessor::new(2, 1).with_input_range(InputRange::UnitFloat);
        let pixels = [25u8, 51, 76, 102, 128, 153];
        let out = p.process_u8_rgb(&pixels, 2, 1).expect("ok");
        assert_eq!(out.len(), 6);
        let expected = [25.0, 102.0, 51.0, 128.0, 76.0, 153.0];
        for (got, want) in out.iter().zip(expected.iter()) {
            assert_close(*got, *want);
        }
    }

    #[test]
    fn nhwc_layout_interleaves_channels() {
        let p = ImagePreprocessor::new(2, 1)
            .with_tensor_layout(TensorLayout::Nhwc)
            .with_input_range(InputRange::UnitFloat);
        let pixels = [25u8, 51, 76, 102, 128, 153];
        let out = p.process_u8_rgb(&pixels, 2, 1).expect("ok");
        assert_eq!(out, vec![25.0_f32, 51.0, 76.0, 102.0, 128.0, 153.0]);
    }

    #[test]
    fn nearest_neighbour_downscale_and_upscale() {
        let down = ImagePreprocessor::new(2, 1)
            .with_tensor_layout(TensorLayout::Nhwc)
            .with_input_range(InputRange::UnitFloat);
        let wide = [0u8, 0, 0, 10, 10, 10, 20, 20, 20, 30, 30, 30];
        let out = down.process_u8_rgb(&wide, 4, 1).expect("ok");
        assert_eq!(out, vec![0.0_f32, 0.0, 0.0, 20.0, 20.0, 20.0]);

        let up = ImagePreprocessor::new(4, 1)
            .with_tensor_layout(TensorLayout::Nhwc)
            .with_input_range(InputRange::UnitFloat);
        let narrow = [1u8, 2, 3, 4, 5, 6];
        let out = up.process_u8_rgb(&narrow, 2, 1).expect("ok");
        assert_eq!(
            out,
            vec![1.0_f32, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 4.0, 5.0, 6.0]
        );

        let square = ImagePreprocessor::new(1, 1)
            .with_tensor_layout(TensorLayout::Nhwc)
            .with_input_range(InputRange::UnitFloat);
        let grid = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let out = square.process_u8_rgb(&grid, 2, 2).expect("ok");
        assert_eq!(out, vec![1.0_f32, 2.0, 3.0]);
    }
}