use std::error::Error;
use std::iter::zip;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut};
/// Standardize a single pixel value for the given color channel.
///
/// Subtracts the channel's mean and divides by the channel's standard
/// deviation, using fixed three-entry tables (the widely used ImageNet
/// statistics). `value` is used as-is; no range check is applied to it.
///
/// # Panics
///
/// Panics with "channel index is invalid" if `channel` is not 0, 1 or 2.
pub fn normalize_pixel(value: f32, channel: usize) -> f32 {
    // Per-channel statistics, indexed by `channel`.
    const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
    const STD_DEV: [f32; 3] = [0.229, 0.224, 0.225];

    assert!(channel < 3, "channel index is invalid");

    let centered = value - MEAN[channel];
    centered / STD_DEV[channel]
}
/// Normalize every element of `img` in place.
///
/// Each element is replaced by `normalize_pixel(element, chan)`, where `chan`
/// is the element's index along the first axis of the tensor. The remaining
/// two index components are ignored.
///
/// # Panics
///
/// Panics if any element has a first-axis index of 3 or more, because
/// `normalize_pixel` rejects such channel indices.
pub fn normalize_image(mut img: NdTensorViewMut<f32, 3>) {
    let coords = img.indices();
    let pixels = img.iter_mut();

    // Pair each element with its index; this relies on both iterators
    // visiting elements in the same order.
    zip(coords, pixels).for_each(|(coord, pixel)| {
        let [chan, _row, _col] = coord;
        *pixel = normalize_pixel(*pixel, chan);
    });
}
/// Load the image at `path` into a `[3, height, width]` float tensor.
///
/// The image is opened with the `image` crate and converted to 8-bit RGB.
/// The raw buffer is viewed as a `[height, width, 3]` tensor using the strides
/// reported by the image's sample layout, permuted so the channel axis comes
/// first, copied into an owned tensor, and each sample is divided by 255.
///
/// # Errors
///
/// Returns an error if the image cannot be opened or decoded, or if the
/// tensor view cannot be constructed from the buffer, shape and strides.
pub fn read_image(path: &str) -> Result<NdTensor<f32, 3>, Box<dyn Error>> {
    let rgb = image::open(path)?.into_rgb8();
    let (width, height) = rgb.dimensions();
    let sample_layout = rgb.sample_layout();

    // Describe the decoded buffer as height x width x channel.
    let hwc_shape = [height as usize, width as usize, 3];
    let hwc_strides = [
        sample_layout.height_stride,
        sample_layout.width_stride,
        sample_layout.channel_stride,
    ];
    let hwc_view =
        NdTensorView::from_data_with_strides(hwc_shape, rgb.as_raw().as_slice(), hwc_strides)?;

    // Reorder axes to channel x height x width, then scale the u8 samples.
    let chw_tensor = hwc_view
        .permuted([2, 0, 1])
        .to_tensor()
        .map(|sample| f32::from(*sample) / 255.);

    Ok(chw_tensor)
}
pub fn write_image(path: &str, img: NdTensorView<f32, 3>) -> Result<(), Box<dyn Error>> {
let [channels, height, width] = img.shape();
let color_type = match channels {
1 => image::ColorType::L8,
3 => image::ColorType::Rgb8,
4 => image::ColorType::Rgba8,
_ => return Err("Unsupported channel count".into()),
};
let hwc_img = img
.permuted([1, 2, 0]) .map(|x| (x.clamp(0., 1.) * 255.0) as u8)
.to_tensor();
image::save_buffer(
path,
hwc_img.data().unwrap(),
width as u32,
height as u32,
color_type,
)?;
Ok(())
}