1use rten_tensor::prelude::*;
19use rten_tensor::{NdTensor, NdTensorView};
20use std::fmt;
21
/// Baseline value of every greyscale output pixel.
///
/// The pixel conversion used by [`prepare_image`] starts each output pixel at
/// this value and adds the weighted channel intensities on top, so an input
/// pixel whose channels are all zero maps to exactly `BLACK_VALUE`.
pub const BLACK_VALUE: f32 = -0.5;
24
/// Reasons an [`ImageSource`] cannot be built from the supplied data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImageSourceError {
    /// The input does not have 1, 3 or 4 channels.
    ///
    /// `ImageSource::from_bytes` also reports this variant when
    /// `width * height` is zero.
    UnsupportedChannelCount,
    /// The byte buffer's length is not a whole multiple of `width * height`.
    InvalidDataLength,
}

impl fmt::Display for ImageSourceError {
    /// Writes a short, fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the message first so there is a single write call.
        let message = match self {
            Self::UnsupportedChannelCount => "channel count is not 1, 3 or 4",
            Self::InvalidDataLength => "data length is not a multiple of width * height",
        };
        f.write_str(message)
    }
}

// No underlying source error to expose, so the default trait methods suffice.
impl std::error::Error for ImageSourceError {}
44
/// Axis layout of a rank-3 image tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DimOrder {
    /// Channels-last: the axes are `(height, width, channels)`.
    Hwc,
    /// Channels-first: the axes are `(channels, height, width)`.
    Chw,
}
51
52enum ImagePixels<'a> {
53 #[allow(dead_code)]
54 Floats(NdTensorView<'a, f32, 3>),
55 Bytes(NdTensorView<'a, u8, 3>),
56 FloatsOwned(NdTensor<f32, 3>),
57}
58
59pub struct ImageSource<'a> {
61 data: ImagePixels<'a>,
62 order: DimOrder,
63}
64
65impl<'a> ImageSource<'a> {
66 pub fn from_bytes(bytes: &'a [u8], dimensions: (u32, u32)) -> Result<Self, ImageSourceError> {
68 let (width, height) = dimensions;
69 let channel_len = (width as usize).saturating_mul(height as usize);
70 if channel_len == 0 {
71 return Err(ImageSourceError::UnsupportedChannelCount);
72 }
73 if !bytes.len().is_multiple_of(channel_len) {
74 return Err(ImageSourceError::InvalidDataLength);
75 }
76 let chans = bytes.len() / channel_len;
77 if !matches!(chans, 1 | 3 | 4) {
78 return Err(ImageSourceError::UnsupportedChannelCount);
79 }
80 let view = NdTensorView::from_data([height as usize, width as usize, chans], bytes);
81 Ok(Self {
82 data: ImagePixels::Bytes(view),
83 order: DimOrder::Hwc,
84 })
85 }
86
87 pub fn from_tensor(
89 tensor: NdTensorView<'_, f32, 3>,
90 order: DimOrder,
91 ) -> Result<ImageSource<'static>, ImageSourceError> {
92 let chans = match order {
93 DimOrder::Hwc => tensor.size(2),
94 DimOrder::Chw => tensor.size(0),
95 };
96 if chans == 0 || !matches!(chans, 1 | 3 | 4) {
97 return Err(ImageSourceError::UnsupportedChannelCount);
98 }
99 let owned = NdTensor::from_data(tensor.shape(), tensor.to_vec());
100 Ok(ImageSource {
101 data: ImagePixels::FloatsOwned(owned),
102 order,
103 })
104 }
105}
106
107pub fn prepare_image(img: ImageSource<'_>) -> NdTensor<f32, 3> {
109 match (&img.data, img.order) {
110 (ImagePixels::Floats(f), DimOrder::Hwc) => prepare_floats::<true>(f.view()),
111 (ImagePixels::Floats(f), DimOrder::Chw) => prepare_floats::<false>(f.view()),
112 (ImagePixels::FloatsOwned(f), DimOrder::Hwc) => prepare_floats::<true>(f.view()),
113 (ImagePixels::FloatsOwned(f), DimOrder::Chw) => prepare_floats::<false>(f.view()),
114 (ImagePixels::Bytes(b), DimOrder::Hwc) => prepare_bytes::<true>(b.view()),
115 (ImagePixels::Bytes(b), DimOrder::Chw) => prepare_bytes::<false>(b.view()),
116 }
117}
118
119fn prepare_floats<const CHANS_LAST: bool>(floats: NdTensorView<'_, f32, 3>) -> NdTensor<f32, 3> {
120 const ITU: [f32; 3] = [0.299, 0.587, 0.114];
121 let n = if CHANS_LAST {
122 floats.shape()[2]
123 } else {
124 floats.shape()[0]
125 };
126 match n {
127 1 => convert_pixels::<f32, 1, 1, CHANS_LAST>(floats, [1.]),
128 3 => convert_pixels::<f32, 3, 3, CHANS_LAST>(floats, ITU),
129 4 => convert_pixels::<f32, 4, 3, CHANS_LAST>(floats, ITU),
130 _ => panic!("expected greyscale, RGB or RGBA input image"),
131 }
132}
133
134fn prepare_bytes<const CHANS_LAST: bool>(bytes: NdTensorView<'_, u8, 3>) -> NdTensor<f32, 3> {
135 const ITU: [f32; 3] = [0.299, 0.587, 0.114];
136 let weights = ITU.map(|w| w / 255.0);
137 let n = if CHANS_LAST {
138 bytes.shape()[2]
139 } else {
140 bytes.shape()[0]
141 };
142 match n {
143 1 => convert_pixels::<u8, 1, 1, CHANS_LAST>(bytes, [1. / 255.0]),
144 3 => convert_pixels::<u8, 3, 3, CHANS_LAST>(bytes, weights),
145 4 => convert_pixels::<u8, 4, 3, CHANS_LAST>(bytes, weights),
146 _ => panic!("expected greyscale, RGB or RGBA input image"),
147 }
148}
149
150fn convert_pixels<
151 T: Copy + Into<f32>,
152 const PIXEL_STRIDE: usize,
153 const CHANS: usize,
154 const CHANS_LAST: bool,
155>(
156 src: NdTensorView<'_, T, 3>,
157 chan_weights: [f32; CHANS],
158) -> NdTensor<f32, 3> {
159 let [height, width, chans] = if CHANS_LAST {
160 src.shape()
161 } else {
162 let [c, h, w] = src.shape();
163 [h, w, c]
164 };
165 assert_eq!(chans, PIXEL_STRIDE);
166
167 let mut out_pixels = Vec::with_capacity(height * width);
168 if CHANS_LAST {
169 let src = src.to_contiguous();
170 let mut iter = src.data().chunks_exact(PIXEL_STRIDE);
171 debug_assert!(iter.remainder().is_empty());
172 for in_pixel in iter.by_ref() {
173 let mut pixel = BLACK_VALUE;
174 for (c, &w) in chan_weights.iter().enumerate() {
175 pixel += in_pixel[c].into() * w;
176 }
177 out_pixels.push(pixel);
178 }
179 } else {
180 for y in 0..height {
181 out_pixels.extend((0..width).map(|x| {
182 let mut pixel = BLACK_VALUE;
183 for (c, &w) in chan_weights.iter().enumerate() {
184 pixel += src[[c, y, x]].into() * w;
185 }
186 pixel
187 }));
188 }
189 }
190 NdTensor::from_data([1, height, width], out_pixels)
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x2 single-channel byte image maps to `BLACK_VALUE + byte / 255`.
    #[test]
    fn preprocess_greyscale_bytes() {
        let pixels = [0u8, 128, 255, 64];
        let source = ImageSource::from_bytes(&pixels, (2, 2)).unwrap();
        let out = prepare_image(source);
        assert_eq!(out.shape(), [1, 2, 2]);

        // Expected values for the first row, left to right.
        let first_row = [BLACK_VALUE, BLACK_VALUE + 128.0 / 255.0];
        for (x, expected) in first_row.into_iter().enumerate() {
            assert!((out[[0, 0, x]] - expected).abs() < 1e-5);
        }
    }
}