use std::error::Error;
use std::iter::zip;
use std::path::Path;
use rten_tensor::errors::FromDataError;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut};
/// Standardizes a single pixel value using the per-channel ImageNet mean and
/// standard deviation.
///
/// Returns `(value - mean[channel]) / std_dev[channel]`.
///
/// `value` is presumably expected to already be scaled to `[0, 1]` (the
/// statistics below are on that scale) — this is not checked.
///
/// # Panics
///
/// Panics if `channel` is not 0, 1 or 2.
pub fn normalize_pixel(value: f32, channel: usize) -> f32 {
    // Per-channel ImageNet statistics, indexed by channel.
    const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
    const STD_DEV: [f32; 3] = [0.229, 0.224, 0.225];

    assert!(channel < 3, "channel index is invalid");

    let centered = value - MEAN[channel];
    centered / STD_DEV[channel]
}
/// Normalizes every element of a 3-dimensional tensor in place.
///
/// Each element is replaced by `normalize_pixel(element, c)`, where `c` is
/// the element's index along the first axis, i.e. the first axis is treated
/// as the channel axis.
///
/// # Panics
///
/// Panics (inside `normalize_pixel`) if any element has a first-axis index
/// of 3 or more.
pub fn normalize_image(mut img: NdTensorViewMut<f32, 3>) {
    // Pair each element's index with a mutable reference to that element.
    let coords_and_values = zip(img.indices(), img.iter_mut());
    coords_and_values.for_each(|([channel, _row, _col], value)| {
        *value = normalize_pixel(*value, channel);
    });
}
/// Errors reported when reading an image into a tensor.
#[derive(Debug)]
pub enum ReadImageError {
/// The `image` crate reported an error (produced by `read_image` when
/// `image::open` fails).
ImageError(image::ImageError),
/// A tensor view could not be created over the decoded pixel data
/// (produced by `image_to_tensor`).
ConvertError(FromDataError),
}
impl std::fmt::Display for ReadImageError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ReadImageError::ImageError(e) => write!(f, "failed to read image: {}", e),
ReadImageError::ConvertError(e) => write!(f, "failed to create tensor: {}", e),
}
}
}
// Marks `ReadImageError` as a standard error type. No trait methods are
// overridden; the wrapped error's message is already part of the `Display`
// output.
impl Error for ReadImageError {}
/// Converts a decoded image into a `[3, height, width]` `f32` tensor.
///
/// The image is first converted to 8-bit RGB. Its raw sample buffer is then
/// viewed as a `[height, width, 3]` tensor using the strides reported by the
/// buffer's sample layout, permuted so the channel axis comes first, and
/// each sample is divided by 255.
///
/// # Errors
///
/// Returns [`ReadImageError::ConvertError`] if the tensor view cannot be
/// created from the pixel buffer with the given shape and strides.
pub fn image_to_tensor(image: image::DynamicImage) -> Result<NdTensor<f32, 3>, ReadImageError> {
    let rgb = image.into_rgb8();
    let (width, height) = rgb.dimensions();
    let layout = rgb.sample_layout();

    // Shape and strides of the decoded buffer in height/width/channel order.
    let hwc_shape = [height as usize, width as usize, 3];
    let hwc_strides = [
        layout.height_stride,
        layout.width_stride,
        layout.channel_stride,
    ];

    let hwc_view =
        NdTensorView::from_data_with_strides(hwc_shape, rgb.as_raw().as_slice(), hwc_strides)
            .map_err(ReadImageError::ConvertError)?;

    // HWC => CHW, then rescale samples from [0, 255] to [0, 1].
    let chw = hwc_view.permuted([2, 0, 1]).map(|sample| *sample as f32 / 255.);
    Ok(chw)
}
pub fn read_image<P: AsRef<Path>>(path: P) -> Result<NdTensor<f32, 3>, ReadImageError> {
image::open(path)
.map_err(ReadImageError::ImageError)
.and_then(image_to_tensor)
}
/// Errors reported when writing a tensor to an image file.
#[derive(Debug)]
pub enum WriteImageError {
/// The tensor's channel count is not one that `write_image` can map to a
/// color type (it accepts 1, 3 or 4).
UnsupportedChannelCount,
/// The `image` crate reported an error while saving the buffer.
ImageError(image::ImageError),
}
impl std::fmt::Display for WriteImageError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ImageError(e) => write!(f, "failed to write image: {}", e),
Self::UnsupportedChannelCount => write!(f, "image has unsupported number of channels"),
}
}
}
// Marks `WriteImageError` as a standard error type. No trait methods are
// overridden; the wrapped error's message is already part of the `Display`
// output.
impl Error for WriteImageError {}
/// Writes a `[channels, height, width]` tensor to an image file at `path`.
///
/// The channel count selects the color type: 1 => 8-bit luma, 3 => 8-bit RGB,
/// 4 => 8-bit RGBA. Each value is clamped to `[0, 1]`, multiplied by 255 and
/// cast to `u8` (the cast truncates the fractional part).
///
/// # Errors
///
/// * [`WriteImageError::UnsupportedChannelCount`] if the first dimension is
///   not 1, 3 or 4.
/// * [`WriteImageError::ImageError`] if `image::save_buffer` fails.
///
/// # Panics
///
/// Panics if the converted tensor does not expose its data as a slice; the
/// freshly mapped tensor is presumably always contiguous, so this is not
/// expected in practice.
pub fn write_image(path: &str, img: NdTensorView<f32, 3>) -> Result<(), WriteImageError> {
    let [depth, rows, cols] = img.shape();

    let color_type = match depth {
        1 => Ok(image::ColorType::L8),
        3 => Ok(image::ColorType::Rgb8),
        4 => Ok(image::ColorType::Rgba8),
        _ => Err(WriteImageError::UnsupportedChannelCount),
    }?;

    // CHW => HWC, then convert [0, 1] floats to 8-bit samples.
    let interleaved = img.permuted([1, 2, 0]);
    let bytes = interleaved.map(|v| (v.clamp(0., 1.) * 255.0) as u8);

    image::save_buffer(
        path,
        bytes.data().unwrap(),
        cols as u32,
        rows as u32,
        color_type,
    )
    .map_err(WriteImageError::ImageError)
}