1use std::error::Error;
9use std::path::Path;
10
11use rten_tensor::errors::FromDataError;
12use rten_tensor::prelude::*;
13use rten_tensor::{NdTensor, NdTensorView};
14
15#[derive(Debug)]
17pub enum ReadImageError {
18 ImageError(image::ImageError),
20 ConvertError(FromDataError),
22}
23
24impl std::fmt::Display for ReadImageError {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 match self {
27 ReadImageError::ImageError(e) => write!(f, "failed to read image: {}", e),
28 ReadImageError::ConvertError(e) => write!(f, "failed to create tensor: {}", e),
29 }
30 }
31}
32
33impl Error for ReadImageError {}
34
35pub fn image_to_tensor(image: image::DynamicImage) -> Result<NdTensor<f32, 3>, ReadImageError> {
38 let image = image.into_rgb8();
39 let (width, height) = image.dimensions();
40 let layout = image.sample_layout();
41
42 let chw_tensor = NdTensorView::from_data_with_strides(
43 [height as usize, width as usize, 3],
44 image.as_raw().as_slice(),
45 [
46 layout.height_stride,
47 layout.width_stride,
48 layout.channel_stride,
49 ],
50 )
51 .map_err(ReadImageError::ConvertError)?
52 .permuted([2, 0, 1]) .map(|x| *x as f32 / 255.); Ok(chw_tensor)
56}
57
58pub fn read_image<P: AsRef<Path>>(path: P) -> Result<NdTensor<f32, 3>, ReadImageError> {
63 image::open(path)
64 .map_err(ReadImageError::ImageError)
65 .and_then(image_to_tensor)
66}
67
68#[derive(Debug)]
70pub enum WriteImageError {
71 UnsupportedChannelCount,
73 ImageError(image::ImageError),
75}
76
77impl std::fmt::Display for WriteImageError {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 match self {
80 Self::ImageError(e) => write!(f, "failed to write image: {}", e),
81 Self::UnsupportedChannelCount => write!(f, "image has unsupported number of channels"),
82 }
83 }
84}
85
86impl Error for WriteImageError {}
87
88pub fn write_image(path: &str, img: NdTensorView<f32, 3>) -> Result<(), WriteImageError> {
90 let [channels, height, width] = img.shape();
91 let color_type = match channels {
92 1 => image::ColorType::L8,
93 3 => image::ColorType::Rgb8,
94 4 => image::ColorType::Rgba8,
95 _ => return Err(WriteImageError::UnsupportedChannelCount),
96 };
97
98 let hwc_img = img
99 .permuted([1, 2, 0]) .map(|x| (x.clamp(0., 1.) * 255.0) as u8);
101
102 image::save_buffer(
103 path,
104 hwc_img.data().unwrap(),
105 width as u32,
106 height as u32,
107 color_type,
108 )
109 .map_err(WriteImageError::ImageError)?;
110
111 Ok(())
112}