image_ndarray/
prelude.rs

1//! Implementations for ndarray casting and conversions for the ImageBuffer
2
3use image::{ImageBuffer, Pixel};
4use ndarray::{Array3, ArrayView3, ArrayViewMut, ArrayViewMut3};
5use crate::error::{Error, Result};
6
7
8/// Conversion methods for working with ndarrays.
9/// 
10/// All methods work without copying any data.
11pub trait ImageArray<P: image::Pixel, ImageContainer> {
12    /// Cast the ImageBuffer as an ArrayView3.
13    /// 
14    /// * `Y` index is the row
15    /// * `X` index is the columns
16    /// * `Z` index is the channel 
17    /// 
18    /// So when referencing:
19    /// `array[[y, x, z]]`
20    /// 
21    /// This does not copy the data, as it is a reference to the actual data in the buffer.
22    fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, ImageContainer>;
23
24
25    /// Cast the ImageBuffer as an ArrayViewMut3.
26    /// 
27    /// * `Y` index is the row
28    /// * `X` index is the columns
29    /// * `Z` index is the channel 
30    /// 
31    /// So when referencing:
32    /// `array[[y, x, z]]`
33    /// 
34    /// This does not copy the data, as it is a reference to the actual data in the buffer.
35    fn as_mut_ndarray<'a>(&'a mut self) -> ArrayViewMut3<'a, ImageContainer>;
36
37    /// Interpret the ImageBuffer as an Array3.
38    /// 
39    /// * `Y` index is the row
40    /// * `X` index is the columns
41    /// * `Z` index is the channel 
42    /// 
43    /// So when referencing:
44    /// `array[[y, x, z]]`
45    /// 
46    /// This does not copy the data, but it does consume the buffer.
47    fn to_ndarray(self) -> Array3<ImageContainer>;
48
49
50    /// Convert the provided array into the ImageBuffer
51    /// 
52    /// * `Y` index is the row
53    /// * `X` index is the columns
54    /// * `Z` index is the channel 
55    /// 
56    /// So when referencing:
57    /// `array[[y, x, z]]`
58    /// 
59    /// This does not copy the data, but it does consume the buffer.
60    fn from_ndarray(array: Array3<ImageContainer>) -> Result<ImageBuffer<P, Vec<ImageContainer>>>;
61
62
63
64
65
66}
67
68impl<P, C> ImageArray<P, C> for ImageBuffer<P, Vec<C>>
69where
70    P: Pixel<Subpixel = C>,
71    C: Clone + Copy,
72{
73    fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, C> {
74        let (width, height) = self.dimensions();
75        unsafe {
76            ArrayView3::from_shape_ptr(
77                (height as usize, width as usize, P::CHANNEL_COUNT as usize),
78                self.as_raw().as_ptr(),
79            )
80        }
81    }
82
83    fn to_ndarray(self) -> Array3<C>{
84        let (width, height) = self.dimensions();
85        unsafe {
86            Array3::from_shape_vec_unchecked(
87                (height as usize, width as usize, P::CHANNEL_COUNT as usize),
88                self.into_raw(),
89            )
90        }
91    }
92
93    fn from_ndarray(mut array: Array3<C>) -> Result<ImageBuffer<P, Vec<C>>> {
94        let (height, width, channels) = array.dim();
95
96        if channels != P::CHANNEL_COUNT.into() {
97            return Err(Error::ChannelMismatch);
98        }
99
100        let data = array.as_mut_ptr();
101         
102        std::mem::forget(array);
103        let size = height * width * channels;
104
105        let vec_data = unsafe {
106             Vec::from_raw_parts(data, size, size)
107        };
108        Self::from_raw(width as u32, height as u32, vec_data).ok_or(Error::ImageConstructFailed)
109    }
110
111    fn as_mut_ndarray<'a>(&'a mut self) -> ArrayViewMut3<'a, C> {
112        let (width, height) = self.dimensions();
113        
114        unsafe {
115            ArrayViewMut::from_shape_ptr(
116                (height as usize, width as usize, P::CHANNEL_COUNT as usize),
117                self.as_mut_ptr(),
118            )
119        }
120    }
121}
122
123
#[cfg(test)]
mod test {
    use super::*;
    use image::{Rgb32FImage, Rgba32FImage};

    // Dimensions shared by the tests: a 256 x 128 RGBA image.
    const WIDTH: usize = 256;
    const HEIGHT: usize = 128;
    const CHANNELS: usize = 4;

    /// Returns `width * height * channels` distinct values `1.0, 2.0, ...`, so any
    /// transposition or offset error shows up as a value mismatch.
    fn create_test_data(width: usize, height: usize, channels: usize) -> Vec<f32> {
        (1..=width * height * channels).map(|x| x as f32).collect()
    }

    /// Builds the `WIDTH` x `HEIGHT` RGBA test image filled with `create_test_data`.
    fn create_test_image() -> Rgba32FImage {
        let data = create_test_data(WIDTH, HEIGHT, CHANNELS);
        Rgba32FImage::from_vec(WIDTH as u32, HEIGHT as u32, data)
            .expect("test data length matches the image dimensions")
    }

    #[test]
    fn test_as_ndarray() {
        let test_image = create_test_image();

        let array = test_image.as_ndarray();

        assert_eq!(array.dim(), (HEIGHT, WIDTH, CHANNELS));
        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(test_image.get_pixel(x as u32, y as u32)[channel], *value);
        }
    }

    #[test]
    fn test_as_mut_ndarray() {
        let mut test_image = create_test_image();
        let compare = test_image.clone();

        let mut array = test_image.as_mut_ndarray();
        assert_eq!(array.dim(), (HEIGHT, WIDTH, CHANNELS));
        array += 1.0;

        for (x, y, pixel) in test_image.enumerate_pixels() {
            let compare_pixel = compare.get_pixel(x, y);
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, compare_pixel[channel] + 1.0);
            }
        }
    }

    #[test]
    fn test_to_ndarray() {
        let test_image = create_test_image();

        let mut array = test_image.clone().to_ndarray();
        assert_eq!(array.dim(), (HEIGHT, WIDTH, CHANNELS));

        array += 1.0;
        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(test_image.get_pixel(x as u32, y as u32)[channel] + 1.0, *value);
        }
    }

    #[test]
    fn test_from_ndarray() {
        let data = create_test_data(WIDTH, HEIGHT, CHANNELS);
        let array = Array3::from_shape_vec((HEIGHT, WIDTH, CHANNELS), data).unwrap();
        let compare_data = array.clone();

        let result = Rgba32FImage::from_ndarray(array).unwrap();

        assert_eq!(result.dimensions(), (WIDTH as u32, HEIGHT as u32));
        for (x, y, pixel) in result.enumerate_pixels() {
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, compare_data[[y as usize, x as usize, channel]]);
            }
        }
    }

    #[test]
    fn test_from_ndarray_with_invalid_channels() {
        // A 4-channel array cannot become a 3-channel RGB image.
        let data = create_test_data(WIDTH, HEIGHT, CHANNELS);
        let array = Array3::from_shape_vec((HEIGHT, WIDTH, CHANNELS), data).unwrap();

        let error = Rgb32FImage::from_ndarray(array)
            .err()
            .expect("a 4-channel array must be rejected for an RGB image");

        assert_eq!(error, Error::ChannelMismatch);
    }
}
213}