image_ndarray/
prelude.rs

1//! Implementations for ndarray casting and conversions for the ImageBuffer
2
3use crate::error::{Error, Result};
4use image::{ImageBuffer, Pixel};
5use ndarray::{Array3, ArrayView3, ArrayViewMut, ArrayViewMut3};
6
7/// Conversion methods for working with ndarrays.
8///
9/// All methods work without copying any data.
10pub trait ImageArray<P: image::Pixel, ImageContainer> {
11    /// Cast the ImageBuffer as an ArrayView3.
12    ///
13    /// * `Y` index is the row
14    /// * `X` index is the columns
15    /// * `Z` index is the channel
16    ///
17    /// So when referencing:
18    /// `array[[y, x, z]]`
19    ///
20    /// This does not copy the data, as it is a reference to the actual data in the buffer.
21    fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, ImageContainer>;
22
23    /// Cast the ImageBuffer as an ArrayViewMut3.
24    ///
25    /// * `Y` index is the row
26    /// * `X` index is the columns
27    /// * `Z` index is the channel
28    ///
29    /// So when referencing:
30    /// `array[[y, x, z]]`
31    ///
32    /// This does not copy the data, as it is a reference to the actual data in the buffer.
33    fn as_ndarray_mut<'a>(&'a mut self) -> ArrayViewMut3<'a, ImageContainer>;
34
35    /// Interpret the ImageBuffer as an Array3.
36    ///
37    /// * `Y` index is the row
38    /// * `X` index is the columns
39    /// * `Z` index is the channel
40    ///
41    /// So when referencing:
42    /// `array[[y, x, z]]`
43    ///
44    /// This does not copy the data, but it does consume the buffer.
45    fn to_ndarray(self) -> Array3<ImageContainer>;
46
47    /// Convert the provided array into the ImageBuffer
48    ///
49    /// * `Y` index is the row
50    /// * `X` index is the columns
51    /// * `Z` index is the channel
52    ///
53    /// So when referencing:
54    /// `array[[y, x, z]]`
55    ///
56    /// This does not copy the data, but it does consume the buffer.
57    fn from_ndarray(array: Array3<ImageContainer>) -> Result<ImageBuffer<P, Vec<ImageContainer>>>;
58}
59
60impl<P, C> ImageArray<P, C> for ImageBuffer<P, Vec<C>>
61where
62    P: Pixel<Subpixel = C>,
63    C: Clone + Copy,
64{
65    fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, C> {
66        let (width, height) = self.dimensions();
67        unsafe {
68            ArrayView3::from_shape_ptr(
69                (height as usize, width as usize, P::CHANNEL_COUNT as usize),
70                self.as_raw().as_ptr(),
71            )
72        }
73    }
74
75    fn to_ndarray(self) -> Array3<C> {
76        let (width, height) = self.dimensions();
77        unsafe {
78            Array3::from_shape_vec_unchecked(
79                (height as usize, width as usize, P::CHANNEL_COUNT as usize),
80                self.into_raw(),
81            )
82        }
83    }
84
85    fn from_ndarray(mut array: Array3<C>) -> Result<ImageBuffer<P, Vec<C>>> {
86        let (height, width, channels) = array.dim();
87
88        if channels != P::CHANNEL_COUNT.into() {
89            return Err(Error::ChannelMismatch);
90        }
91
92        let data = array.as_mut_ptr();
93
94        std::mem::forget(array);
95        let size = height * width * channels;
96
97        let vec_data = unsafe { Vec::from_raw_parts(data, size, size) };
98        Self::from_raw(width as u32, height as u32, vec_data).ok_or(Error::ImageConstructFailed)
99    }
100
101    fn as_ndarray_mut<'a>(&'a mut self) -> ArrayViewMut3<'a, C> {
102        let (width, height) = self.dimensions();
103
104        unsafe {
105            ArrayViewMut::from_shape_ptr(
106                (height as usize, width as usize, P::CHANNEL_COUNT as usize),
107                self.as_mut_ptr(),
108            )
109        }
110    }
111}
112
#[cfg(test)]
mod test {
    use super::*;
    use image::{Rgb32FImage, Rgba32FImage};

    // Dimensions shared by every test image.
    const WIDTH: usize = 256;
    const HEIGHT: usize = 128;
    const CHANNELS: usize = 4;

    /// Returns `width * height * channels` distinct values `1.0, 2.0, 3.0, ...`.
    fn create_test_data(width: usize, height: usize, channels: usize) -> Vec<f32> {
        let total_elements = width * height * channels;
        (0..total_elements).map(|x| (x + 1) as f32).collect()
    }

    /// Builds a `WIDTH` x `HEIGHT` RGBA image filled with [`create_test_data`] values.
    fn create_test_image() -> Rgba32FImage {
        let data = create_test_data(WIDTH, HEIGHT, CHANNELS);
        Rgba32FImage::from_vec(WIDTH as u32, HEIGHT as u32, data).unwrap()
    }

    /// Builds a standard-layout `(HEIGHT, WIDTH, CHANNELS)` array of test data.
    fn create_test_array() -> Array3<f32> {
        let data = create_test_data(WIDTH, HEIGHT, CHANNELS);
        Array3::from_shape_vec((HEIGHT, WIDTH, CHANNELS), data).unwrap()
    }

    #[test]
    fn test_as_ndarray() {
        let test_image = create_test_image();

        let array = test_image.as_ndarray();

        assert_eq!(array.dim(), (HEIGHT, WIDTH, CHANNELS));
        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(test_image.get_pixel(x as u32, y as u32)[channel], *value);
        }
    }

    #[test]
    fn test_as_ndarray_mut() {
        let mut test_image = create_test_image();
        let compare = test_image.clone();

        let mut array = test_image.as_ndarray_mut();
        array += 1.0;

        for (x, y, pixel) in test_image.enumerate_pixels() {
            let compare_pixel = compare.get_pixel(x, y);
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, compare_pixel[channel] + 1.0);
            }
        }
    }

    #[test]
    fn test_to_ndarray() {
        let test_image = create_test_image();

        let mut array = test_image.clone().to_ndarray();
        assert_eq!(array.dim(), (HEIGHT, WIDTH, CHANNELS));

        array += 1.0;
        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(
                test_image.get_pixel(x as u32, y as u32)[channel] + 1.0,
                *value
            );
        }
    }

    #[test]
    fn test_from_ndarray() {
        let test_array = create_test_array();
        let compare_data = test_array.clone();

        let result = Rgba32FImage::from_ndarray(test_array).unwrap();

        assert_eq!(result.dimensions(), (WIDTH as u32, HEIGHT as u32));
        for (x, y, pixel) in result.enumerate_pixels() {
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, compare_data[[y as usize, x as usize, channel]]);
            }
        }
    }

    #[test]
    fn test_round_trip() {
        let test_image = create_test_image();

        let restored = Rgba32FImage::from_ndarray(test_image.clone().to_ndarray()).unwrap();

        assert_eq!(restored.dimensions(), test_image.dimensions());
        assert_eq!(restored.as_raw(), test_image.as_raw());
    }

    #[test]
    fn test_from_ndarray_with_invalid_channels() {
        // Four channels of data cannot become an RGB (three-channel) image.
        let test_array = create_test_array();

        let error = Rgb32FImage::from_ndarray(test_array).err();

        assert_eq!(error, Some(Error::ChannelMismatch));
    }
}