rten_imageio/
lib.rs

1//! Provides utilities for loading, saving and preprocessing images for use with
2//! [RTen](https://github.com/robertknight/rten).
3//!
4//! The APIs are limited to keep them simple for the most common use cases.
5//! If you need more flexibility from a function, copy and adjust the
6//! implementation.
7
8use std::error::Error;
9use std::path::Path;
10
11use rten_tensor::errors::FromDataError;
12use rten_tensor::prelude::*;
13use rten_tensor::{NdTensor, NdTensorView};
14
15/// Errors reported when creating a tensor from an image.
16#[derive(Debug)]
17pub enum ReadImageError {
18    /// The image could not be loaded.
19    ImageError(image::ImageError),
20    /// The loaded image could not be converted to a tensor.
21    ConvertError(FromDataError),
22}
23
24impl std::fmt::Display for ReadImageError {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        match self {
27            ReadImageError::ImageError(e) => write!(f, "failed to read image: {}", e),
28            ReadImageError::ConvertError(e) => write!(f, "failed to create tensor: {}", e),
29        }
30    }
31}
32
33impl Error for ReadImageError {}
34
35/// Convert an image into a CHW tensor with 3 channels and values in the range
36/// [0, 1].
37pub fn image_to_tensor(image: image::DynamicImage) -> Result<NdTensor<f32, 3>, ReadImageError> {
38    let image = image.into_rgb8();
39    let (width, height) = image.dimensions();
40    let layout = image.sample_layout();
41
42    let chw_tensor = NdTensorView::from_data_with_strides(
43        [height as usize, width as usize, 3],
44        image.as_raw().as_slice(),
45        [
46            layout.height_stride,
47            layout.width_stride,
48            layout.channel_stride,
49        ],
50    )
51    .map_err(ReadImageError::ConvertError)?
52    .permuted([2, 0, 1]) // HWC => CHW
53    .map(|x| *x as f32 / 255.); // Rescale from [0, 255] to [0, 1]
54
55    Ok(chw_tensor)
56}
57
58/// Read an image from a file into a CHW tensor.
59///
60/// To load an image from a byte buffer or other source, use [`image::open`]
61/// and pass the result to [`image_to_tensor`].
62pub fn read_image<P: AsRef<Path>>(path: P) -> Result<NdTensor<f32, 3>, ReadImageError> {
63    image::open(path)
64        .map_err(ReadImageError::ImageError)
65        .and_then(image_to_tensor)
66}
67
68/// Errors returned when writing a tensor to an image.
69#[derive(Debug)]
70pub enum WriteImageError {
71    /// The number of channels in the image tensor is unsupported.
72    UnsupportedChannelCount,
73    /// The image could not be written.
74    ImageError(image::ImageError),
75}
76
77impl std::fmt::Display for WriteImageError {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        match self {
80            Self::ImageError(e) => write!(f, "failed to write image: {}", e),
81            Self::UnsupportedChannelCount => write!(f, "image has unsupported number of channels"),
82        }
83    }
84}
85
86impl Error for WriteImageError {}
87
88/// Convert a CHW tensor to an image and write it to a PNG file.
89pub fn write_image(path: &str, img: NdTensorView<f32, 3>) -> Result<(), WriteImageError> {
90    let [channels, height, width] = img.shape();
91    let color_type = match channels {
92        1 => image::ColorType::L8,
93        3 => image::ColorType::Rgb8,
94        4 => image::ColorType::Rgba8,
95        _ => return Err(WriteImageError::UnsupportedChannelCount),
96    };
97
98    let hwc_img = img
99        .permuted([1, 2, 0]) // CHW => HWC
100        .map(|x| (x.clamp(0., 1.) * 255.0) as u8);
101
102    image::save_buffer(
103        path,
104        hwc_img.data().unwrap(),
105        width as u32,
106        height as u32,
107        color_type,
108    )
109    .map_err(WriteImageError::ImageError)?;
110
111    Ok(())
112}