1use std::fmt::Debug;
2
3use rten_tensor::prelude::*;
4use rten_tensor::{NdTensor, NdTensorView};
5use thiserror::Error;
6
7pub enum ImagePixels<'a> {
10 Floats(NdTensorView<'a, f32, 3>),
12 Bytes(NdTensorView<'a, u8, 3>),
14}
15
16impl<'a> From<NdTensorView<'a, f32, 3>> for ImagePixels<'a> {
17 fn from(value: NdTensorView<'a, f32, 3>) -> Self {
18 ImagePixels::Floats(value)
19 }
20}
21
22impl<'a> From<NdTensorView<'a, u8, 3>> for ImagePixels<'a> {
23 fn from(value: NdTensorView<'a, u8, 3>) -> Self {
24 ImagePixels::Bytes(value)
25 }
26}
27
28impl ImagePixels<'_> {
29 fn shape(&self) -> [usize; 3] {
30 match self {
31 ImagePixels::Floats(f) => f.shape(),
32 ImagePixels::Bytes(b) => b.shape(),
33 }
34 }
35}
36
37#[derive(Error, Clone, Debug, PartialEq)]
39pub enum ImageSourceError {
40 #[error("channel count is not 1, 3 or 4")]
42 UnsupportedChannelCount,
43 #[error("data length is not a multiple of `width * height`")]
45 InvalidDataLength,
46}
47
/// Axis order of a 3D image tensor.
///
/// Used by [`ImageSource::from_tensor`] to decide which axis holds the
/// channels.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum DimOrder {
    /// Channels last: `(height, width, channels)`.
    Hwc,

    /// Channels first: `(channels, height, width)`.
    Chw,
}
58
59pub struct ImageSource<'a> {
62 data: ImagePixels<'a>,
63 order: DimOrder,
64}
65
66impl<'a> ImageSource<'a> {
67 pub fn from_bytes(
82 bytes: &'a [u8],
83 dimensions: (u32, u32),
84 ) -> Result<ImageSource<'a>, ImageSourceError> {
85 let (width, height) = dimensions;
86 let channel_len = (width * height) as usize;
87
88 if channel_len == 0 {
89 return Err(ImageSourceError::UnsupportedChannelCount);
90 }
91
92 if !bytes.len().is_multiple_of(channel_len) {
93 return Err(ImageSourceError::InvalidDataLength);
94 }
95
96 let channels = bytes.len() / channel_len;
97 Self::from_tensor(
98 NdTensorView::from_data([height as usize, width as usize, channels], bytes),
99 DimOrder::Hwc,
100 )
101 }
102
103 pub fn from_tensor<T>(
106 data: NdTensorView<'a, T, 3>,
107 order: DimOrder,
108 ) -> Result<ImageSource<'a>, ImageSourceError>
109 where
110 NdTensorView<'a, T, 3>: Into<ImagePixels<'a>>,
111 {
112 let channels = match order {
113 DimOrder::Hwc => data.size(2),
114 DimOrder::Chw => data.size(0),
115 };
116 match channels {
117 1 | 3 | 4 => Ok(ImageSource {
118 data: data.into(),
119 order,
120 }),
121 _ => Err(ImageSourceError::UnsupportedChannelCount),
122 }
123 }
124}
125
/// Greyscale output value for a fully black pixel.
///
/// [`prepare_image`] adds this offset to each pixel's weighted channel sum.
/// For inputs in `0.0..=1.0` (bytes are divided by 255 first), black
/// therefore maps to `-0.5`.
pub const BLACK_VALUE: f32 = -0.5_f32;
129
/// Channel layout of a source image, chosen from its channel count.
enum Channels {
    /// One channel (greyscale).
    Grey,

    /// Three channels (red, green, blue).
    Rgb,

    /// Four channels (red, green, blue, alpha). Alpha is ignored when
    /// converting to greyscale.
    Rgba,
}
136
137pub fn prepare_image(img: ImageSource) -> NdTensor<f32, 3> {
150 match img.order {
151 DimOrder::Hwc => prepare_image_impl::<true>(img.data),
152 DimOrder::Chw => prepare_image_impl::<false>(img.data),
153 }
154}
155
156fn prepare_image_impl<const CHANS_LAST: bool>(pixels: ImagePixels) -> NdTensor<f32, 3> {
157 let n_chans = if CHANS_LAST {
158 pixels.shape()[2]
159 } else {
160 pixels.shape()[0]
161 };
162 let src_chans = match n_chans {
163 1 => Channels::Grey,
164 3 => Channels::Rgb,
165 4 => Channels::Rgba,
166 _ => panic!("expected greyscale, RGB or RGBA input image"),
167 };
168
169 const ITU_WEIGHTS: [f32; 3] = [0.299, 0.587, 0.114];
172
173 match pixels {
174 ImagePixels::Floats(floats) => match src_chans {
175 Channels::Grey => convert_pixels::<_, 1, _, CHANS_LAST>(floats.view(), [1.]),
176 Channels::Rgb => convert_pixels::<_, 3, _, CHANS_LAST>(floats.view(), ITU_WEIGHTS),
177 Channels::Rgba => convert_pixels::<_, 4, _, CHANS_LAST>(floats.view(), ITU_WEIGHTS),
178 },
179 ImagePixels::Bytes(bytes) => {
180 let weights = ITU_WEIGHTS.map(|w| w / 255.0);
183 match src_chans {
184 Channels::Grey => convert_pixels::<_, 1, _, CHANS_LAST>(bytes.view(), [1. / 255.]),
185 Channels::Rgb => convert_pixels::<_, 3, _, CHANS_LAST>(bytes.view(), weights),
186 Channels::Rgba => convert_pixels::<_, 4, _, CHANS_LAST>(bytes.view(), weights),
187 }
188 }
189 }
190}
191
192fn convert_pixels<
202 T: AsF32,
203 const PIXEL_STRIDE: usize,
204 const CHANS: usize,
205 const CHANS_LAST: bool,
206>(
207 src: NdTensorView<T, 3>,
208 chan_weights: [f32; CHANS],
209) -> NdTensor<f32, 3> {
210 let [height, width, chans] = if CHANS_LAST {
211 src.shape()
212 } else {
213 let [c, h, w] = src.shape();
214 [h, w, c]
215 };
216 assert_eq!(chans, PIXEL_STRIDE);
217 let mut out_pixels = Vec::with_capacity(height * width);
218
219 if CHANS_LAST {
220 let src = src.to_contiguous();
225 let (src_pixels, remainder) = src.data().as_chunks::<PIXEL_STRIDE>();
226 debug_assert!(remainder.is_empty());
227
228 out_pixels.extend(src_pixels.iter().map(|in_pixel| {
229 let mut pixel = BLACK_VALUE;
230 for c in 0..chan_weights.len() {
231 pixel += in_pixel[c].as_f32() * chan_weights[c]
232 }
233 pixel
234 }));
235 } else {
236 for y in 0..height {
237 out_pixels.extend((0..width).map(|x| {
238 let mut pixel = BLACK_VALUE;
239 for c in 0..chan_weights.len() {
240 pixel += src[[c, y, x]].as_f32() * chan_weights[c]
241 }
242 pixel
243 }));
244 }
245 }
246
247 NdTensor::from_data([1, height, width], out_pixels)
248}
249
/// Conversion of a pixel component to `f32`.
///
/// This lets the greyscale conversion be written once for every supported
/// element type.
trait AsF32: Copy {
    /// Returns this component as an `f32`.
    fn as_f32(self) -> f32;
}

impl AsF32 for f32 {
    fn as_f32(self) -> f32 {
        // Already the target type.
        self
    }
}

impl AsF32 for u8 {
    fn as_f32(self) -> f32 {
        // Lossless: every byte value is exactly representable as `f32`.
        f32::from(self)
    }
}
266
#[cfg(test)]
mod tests {
    use rten_tensor::prelude::*;
    use rten_tensor::NdTensor;

    use super::{prepare_image, DimOrder, ImageSource, ImageSourceError, BLACK_VALUE};

    /// Returns `len` bytes of filler data counting up from zero.
    ///
    /// Only the length matters to the callers. Iterating over `usize`, rather
    /// than `0u8..len as u8`, keeps the length correct for `len >= 256`; the
    /// byte values simply wrap around.
    fn filler_bytes(len: usize) -> Vec<u8> {
        (0..len).map(|i| i as u8).collect()
    }

    #[test]
    fn test_image_source_from_bytes() {
        struct Case {
            len: usize,
            width: u32,
            height: u32,
            error: Option<ImageSourceError>,
        }

        let cases = [
            // Greyscale.
            Case {
                len: 100,
                width: 10,
                height: 10,
                error: None,
            },
            // RGB.
            Case {
                len: 8 * 8 * 3,
                width: 8,
                height: 8,
                error: None,
            },
            // RGBA. At 256 bytes, this length does not fit in a `u8`.
            Case {
                len: 8 * 8 * 4,
                width: 8,
                height: 8,
                error: None,
            },
            // Length is not a multiple of `width * height`.
            Case {
                len: 50,
                width: 10,
                height: 10,
                error: Some(ImageSourceError::InvalidDataLength),
            },
            // Two channels are not supported.
            Case {
                len: 8 * 8 * 2,
                width: 8,
                height: 8,
                error: Some(ImageSourceError::UnsupportedChannelCount),
            },
            // Zero-area image.
            Case {
                len: 0,
                width: 0,
                height: 10,
                error: Some(ImageSourceError::UnsupportedChannelCount),
            },
        ];

        for Case {
            len,
            width,
            height,
            error,
        } in cases
        {
            let data = filler_bytes(len);
            let source = ImageSource::from_bytes(&data, (width, height));
            assert_eq!(source.as_ref().err(), error.as_ref());
        }
    }

    #[test]
    fn test_image_source_from_data() {
        struct Case {
            shape: [usize; 3],
            error: Option<ImageSourceError>,
            order: DimOrder,
        }

        let cases = [
            // Channels-first greyscale.
            Case {
                shape: [1, 5, 5],
                error: None,
                order: DimOrder::Chw,
            },
            // Channels-first RGBA.
            Case {
                shape: [4, 5, 5],
                error: None,
                order: DimOrder::Chw,
            },
            // Channels-last RGB.
            Case {
                shape: [5, 5, 3],
                error: None,
                order: DimOrder::Hwc,
            },
            // Read as channels-last, this shape has 5 channels.
            Case {
                shape: [1, 5, 5],
                error: Some(ImageSourceError::UnsupportedChannelCount),
                order: DimOrder::Hwc,
            },
            // Channels-last with 2 channels.
            Case {
                shape: [5, 5, 2],
                error: Some(ImageSourceError::UnsupportedChannelCount),
                order: DimOrder::Hwc,
            },
            // No channels at all.
            Case {
                shape: [0, 5, 5],
                error: Some(ImageSourceError::UnsupportedChannelCount),
                order: DimOrder::Chw,
            },
        ];

        for Case {
            shape,
            error,
            order,
        } in cases
        {
            let len: usize = shape.iter().product();
            let tensor = NdTensor::from_data(shape, filler_bytes(len));
            let source = ImageSource::from_tensor(tensor.view(), order);
            assert_eq!(source.as_ref().err(), error.as_ref());
        }
    }

    /// ITU-R BT.601 luma coefficients for red, green and blue.
    const ITU_WEIGHTS: [f32; 3] = [0.299, 0.587, 0.114];

    /// Reference greyscale value for normalized (`0.0..=1.0`) RGB components.
    fn expected_grey_from_rgb(r: f32, g: f32, b: f32) -> f32 {
        BLACK_VALUE + r * ITU_WEIGHTS[0] + g * ITU_WEIGHTS[1] + b * ITU_WEIGHTS[2]
    }

    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-5,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_prepare_image_greyscale_u8() {
        struct Case {
            shape: [usize; 3],
            order: DimOrder,
        }

        let cases = [
            Case {
                shape: [2, 2, 1],
                order: DimOrder::Hwc,
            },
            Case {
                shape: [1, 2, 2],
                order: DimOrder::Chw,
            },
        ];

        for Case { shape, order } in cases {
            let data: Vec<u8> = vec![0, 128, 255, 64];
            let tensor = NdTensor::from_data(shape, data);
            let source = ImageSource::from_tensor(tensor.view(), order).unwrap();

            let result = prepare_image(source);

            assert_eq!(result.shape(), [1, 2, 2]);
            assert_close(result[[0, 0, 0]], BLACK_VALUE + 0.0);
            assert_close(result[[0, 0, 1]], BLACK_VALUE + 128.0 / 255.0);
            assert_close(result[[0, 1, 0]], BLACK_VALUE + 1.0);
            assert_close(result[[0, 1, 1]], BLACK_VALUE + 64.0 / 255.0);
        }
    }

    #[test]
    fn test_prepare_image_greyscale_f32() {
        struct Case {
            shape: [usize; 3],
            order: DimOrder,
        }

        let cases = [
            Case {
                shape: [2, 2, 1],
                order: DimOrder::Hwc,
            },
            Case {
                shape: [1, 2, 2],
                order: DimOrder::Chw,
            },
        ];

        for Case { shape, order } in cases {
            let data: Vec<f32> = vec![0.0, 0.5, 1.0, 0.25];
            let tensor = NdTensor::from_data(shape, data);
            let source = ImageSource::from_tensor(tensor.view(), order).unwrap();

            let result = prepare_image(source);

            assert_eq!(result.shape(), [1, 2, 2]);
            assert_close(result[[0, 0, 0]], BLACK_VALUE + 0.0);
            assert_close(result[[0, 0, 1]], BLACK_VALUE + 0.5);
            assert_close(result[[0, 1, 0]], BLACK_VALUE + 1.0);
            assert_close(result[[0, 1, 1]], BLACK_VALUE + 0.25);
        }
    }

    #[test]
    fn test_prepare_image_rgb_rgba_u8() {
        struct Case {
            data: Vec<u8>,
            shape: [usize; 3],
            order: DimOrder,
            rgb: [u8; 3],
        }

        let cases = [
            // RGB, channels last.
            Case {
                data: vec![100, 150, 200],
                shape: [1, 1, 3],
                order: DimOrder::Hwc,
                rgb: [100, 150, 200],
            },
            // RGB, channels first.
            Case {
                data: vec![100, 150, 200],
                shape: [3, 1, 1],
                order: DimOrder::Chw,
                rgb: [100, 150, 200],
            },
            // RGBA, channels last. Alpha must not affect the result.
            Case {
                data: vec![50, 100, 150, 255],
                shape: [1, 1, 4],
                order: DimOrder::Hwc,
                rgb: [50, 100, 150],
            },
            // RGBA, channels first. Alpha must not affect the result.
            Case {
                data: vec![50, 100, 150, 255],
                shape: [4, 1, 1],
                order: DimOrder::Chw,
                rgb: [50, 100, 150],
            },
        ];

        for Case {
            data,
            shape,
            order,
            rgb: [r, g, b],
        } in cases
        {
            let tensor = NdTensor::from_data(shape, data);
            let source = ImageSource::from_tensor(tensor.view(), order).unwrap();

            let result = prepare_image(source);

            assert_eq!(result.shape(), [1, 1, 1]);
            let expected =
                expected_grey_from_rgb(r as f32 / 255.0, g as f32 / 255.0, b as f32 / 255.0);
            assert_close(result[[0, 0, 0]], expected);
        }
    }

    #[test]
    fn test_prepare_image_rgb_f32() {
        struct Case {
            shape: [usize; 3],
            order: DimOrder,
        }

        let cases = [
            Case {
                shape: [1, 1, 3],
                order: DimOrder::Hwc,
            },
            Case {
                shape: [3, 1, 1],
                order: DimOrder::Chw,
            },
        ];

        let (r, g, b) = (0.4, 0.6, 0.8);

        for Case { shape, order } in cases {
            let data: Vec<f32> = vec![r, g, b];
            let tensor = NdTensor::from_data(shape, data);
            let source = ImageSource::from_tensor(tensor.view(), order).unwrap();

            let result = prepare_image(source);

            assert_eq!(result.shape(), [1, 1, 1]);
            let expected = expected_grey_from_rgb(r, g, b);
            assert_close(result[[0, 0, 0]], expected);
        }
    }

    #[test]
    fn test_prepare_image_multi_pixel_rgb() {
        struct Case {
            // A 2x2 image of red, green, blue and mid-grey pixels.
            data: Vec<u8>,
            shape: [usize; 3],
            order: DimOrder,
        }

        let cases = [
            // Channels last: each pixel's components are adjacent.
            Case {
                #[rustfmt::skip]
                data: vec![
                    255, 0, 0,     // red
                    0, 255, 0,     // green
                    0, 0, 255,     // blue
                    128, 128, 128, // grey
                ],
                shape: [2, 2, 3],
                order: DimOrder::Hwc,
            },
            // Channels first: one 2x2 plane per channel.
            Case {
                #[rustfmt::skip]
                data: vec![
                    // R plane
                    255, 0,
                    0, 128,
                    // G plane
                    0, 255,
                    0, 128,
                    // B plane
                    0, 0,
                    255, 128,
                ],
                shape: [3, 2, 2],
                order: DimOrder::Chw,
            },
        ];

        let expected_red = expected_grey_from_rgb(1.0, 0.0, 0.0);
        let expected_green = expected_grey_from_rgb(0.0, 1.0, 0.0);
        let expected_blue = expected_grey_from_rgb(0.0, 0.0, 1.0);
        let expected_grey = expected_grey_from_rgb(128.0 / 255.0, 128.0 / 255.0, 128.0 / 255.0);

        for Case { data, shape, order } in cases {
            let tensor = NdTensor::from_data(shape, data);
            let source = ImageSource::from_tensor(tensor.view(), order).unwrap();

            let result = prepare_image(source);

            assert_eq!(result.shape(), [1, 2, 2]);
            assert_close(result[[0, 0, 0]], expected_red);
            assert_close(result[[0, 0, 1]], expected_green);
            assert_close(result[[0, 1, 0]], expected_blue);
            assert_close(result[[0, 1, 1]], expected_grey);
        }
    }
}