use image::{DynamicImage, imageops::FilterType};
use rten_tensor::Tensor;
use std::path::Path;
/// Height, in pixels, that [`preprocess`] resizes every input image to.
pub const IMG_HEIGHT: u32 = 80;
/// Width, in pixels, that [`preprocess`] resizes every input image to.
pub const IMG_WIDTH: u32 = 215;
/// Per-channel (R, G, B) mean subtracted from `[0, 1]`-scaled pixel values in [`preprocess`].
// NOTE(review): these values look like the commonly used ImageNet statistics — confirm they
// match the preprocessing the model was trained with.
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
/// Per-channel (R, G, B) divisor applied to mean-centered pixel values in [`preprocess`].
const STD: [f32; 3] = [0.229, 0.224, 0.225];
/// Converts an image into a normalized `[1, 3, IMG_HEIGHT, IMG_WIDTH]` tensor.
///
/// Processing steps:
/// 1. Resize the image to exactly `IMG_WIDTH` x `IMG_HEIGHT` with a triangle filter. Aspect
///    ratio is not preserved.
/// 2. Convert it to 8-bit RGB.
/// 3. Scale each channel value to `[0, 1]` by dividing by 255.
/// 4. Standardize each value as `(v - MEAN[c]) / STD[c]`.
///
/// The output is laid out channel-major (all R values, then all G, then all B). Each channel
/// plane is stored row by row, and the tensor has a leading batch dimension of 1.
#[must_use]
pub fn preprocess(img: &DynamicImage) -> Tensor<f32> {
    let rgb = img
        .resize_exact(IMG_WIDTH, IMG_HEIGHT, FilterType::Triangle)
        .to_rgb8();
    let plane_len = (IMG_HEIGHT * IMG_WIDTH) as usize;
    let mut data = Vec::with_capacity(3 * plane_len);
    for (channel, (&mean, &sigma)) in MEAN.iter().zip(STD.iter()).enumerate() {
        // `pixels()` walks the image row by row, left to right, so each pass fills one
        // complete channel plane in the same order as a y-outer / x-inner loop.
        data.extend(rgb.pixels().map(|px| {
            let scaled = f32::from(px[channel]) / 255.0;
            (scaled - mean) / sigma
        }));
    }
    Tensor::from_data(&[1, 3, IMG_HEIGHT as usize, IMG_WIDTH as usize], data)
}
/// Opens and decodes the image at `path`, then runs it through [`preprocess`].
///
/// # Errors
///
/// Returns an error if `image::open` fails to open or decode the file. The `image` error is
/// converted into the crate's error type.
pub fn preprocess_file<P: AsRef<Path>>(path: P) -> crate::Result<Tensor<f32>> {
    image::open(path)
        .map(|decoded| preprocess(&decoded))
        .map_err(Into::into)
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::{Rgb, RgbImage};
    use rten_tensor::prelude::*;

    /// Number of elements in one channel plane of the output tensor.
    const PLANE: usize = (IMG_HEIGHT * IMG_WIDTH) as usize;

    /// Asserts that every element of channel plane `c` in `data` is within 1e-3 of `expected`.
    fn assert_plane_close(data: &[f32], c: usize, expected: f32) {
        let plane = &data[c * PLANE..(c + 1) * PLANE];
        for (i, &v) in plane.iter().enumerate() {
            assert!(
                (v - expected).abs() < 0.001,
                "channel {c}, element {i}: got {v}, expected {expected}"
            );
        }
    }

    /// Runs `preprocess` on a 10x10 image filled with `rgb` and returns the flattened output.
    fn preprocess_uniform(rgb: [u8; 3]) -> Vec<f32> {
        let img = DynamicImage::ImageRgb8(RgbImage::from_pixel(10, 10, Rgb(rgb)));
        let tensor = preprocess(&img);
        tensor.data().expect("should be contiguous").to_vec()
    }

    #[test]
    fn test_preprocess_output_shape() {
        let img = DynamicImage::ImageRgb8(RgbImage::new(100, 50));
        let tensor = preprocess(&img);
        assert_eq!(tensor.shape(), &[1, 3, 80, 215]);
    }

    #[test]
    fn test_preprocess_normalization() {
        // White maps to (1.0 - mean) / std in every element of each plane, not just the first.
        let data = preprocess_uniform([255, 255, 255]);
        assert_eq!(data.len(), 3 * PLANE);
        assert_plane_close(&data, 0, 2.2489);
        assert_plane_close(&data, 1, 2.4286);
        assert_plane_close(&data, 2, 2.6400);
    }

    #[test]
    fn test_preprocess_channel_order() {
        // Pure red: only plane 0 reads 255. Planes 1 and 2 read 0, giving -mean / std.
        // A uniform white image cannot catch a source-channel mix-up; this can.
        let data = preprocess_uniform([255, 0, 0]);
        assert_plane_close(&data, 0, 2.2489);
        assert_plane_close(&data, 1, -2.0357);
        assert_plane_close(&data, 2, -1.8044);
    }
}