1use crate::error::{Error, Result};
4use image::{ImageBuffer, Pixel};
5use ndarray::{Array3, ArrayView3, ArrayViewMut, ArrayViewMut3};
6
7pub trait ImageArray<P: image::Pixel, ImageContainer> {
11 fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, ImageContainer>;
22
23 fn as_ndarray_mut<'a>(&'a mut self) -> ArrayViewMut3<'a, ImageContainer>;
34
35 fn to_ndarray(self) -> Array3<ImageContainer>;
46
47 fn from_ndarray(array: Array3<ImageContainer>) -> Result<ImageBuffer<P, Vec<ImageContainer>>>;
58}
59
60impl<P, C> ImageArray<P, C> for ImageBuffer<P, Vec<C>>
61where
62 P: Pixel<Subpixel = C>,
63 C: Clone + Copy,
64{
65 fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, C> {
66 let (width, height) = self.dimensions();
67 unsafe {
68 ArrayView3::from_shape_ptr(
69 (height as usize, width as usize, P::CHANNEL_COUNT as usize),
70 self.as_raw().as_ptr(),
71 )
72 }
73 }
74
75 fn to_ndarray(self) -> Array3<C> {
76 let (width, height) = self.dimensions();
77 unsafe {
78 Array3::from_shape_vec_unchecked(
79 (height as usize, width as usize, P::CHANNEL_COUNT as usize),
80 self.into_raw(),
81 )
82 }
83 }
84
85 fn from_ndarray(mut array: Array3<C>) -> Result<ImageBuffer<P, Vec<C>>> {
86 let (height, width, channels) = array.dim();
87
88 if channels != P::CHANNEL_COUNT.into() {
89 return Err(Error::ChannelMismatch);
90 }
91
92 let data = array.as_mut_ptr();
93
94 std::mem::forget(array);
95 let size = height * width * channels;
96
97 let vec_data = unsafe { Vec::from_raw_parts(data, size, size) };
98 Self::from_raw(width as u32, height as u32, vec_data).ok_or(Error::ImageConstructFailed)
99 }
100
101 fn as_ndarray_mut<'a>(&'a mut self) -> ArrayViewMut3<'a, C> {
102 let (width, height) = self.dimensions();
103
104 unsafe {
105 ArrayViewMut::from_shape_ptr(
106 (height as usize, width as usize, P::CHANNEL_COUNT as usize),
107 self.as_mut_ptr(),
108 )
109 }
110 }
111}
112
#[cfg(test)]
mod test {
    use super::*;
    use image::{Rgb32FImage, Rgba32FImage};

    // Dimensions shared by every test: a 256 x 128 RGBA image.
    const WIDTH: usize = 256;
    const HEIGHT: usize = 128;
    const CHANNELS: usize = 4;

    /// Builds the `WIDTH` x `HEIGHT` RGBA test image filled by `create_test_data`.
    fn create_test_image() -> Rgba32FImage {
        let data = create_test_data(WIDTH, HEIGHT, CHANNELS);
        Rgba32FImage::from_vec(WIDTH as u32, HEIGHT as u32, data)
            .expect("test data holds exactly WIDTH * HEIGHT * CHANNELS subpixels")
    }

    #[test]
    fn test_as_ndarray() {
        let test_image = create_test_image();

        let array = test_image.as_ndarray();

        assert_eq!(array.dim(), (HEIGHT, WIDTH, CHANNELS));
        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(test_image.get_pixel(x as u32, y as u32)[channel], *value);
        }
    }

    #[test]
    fn test_as_ndarray_mut() {
        let mut test_image = create_test_image();
        let compare = test_image.clone();

        let mut array = test_image.as_ndarray_mut();
        array += 1.0;

        for (x, y, pixel) in test_image.enumerate_pixels() {
            let compare_pixel = compare.get_pixel(x, y);
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, compare_pixel[channel] + 1.0);
            }
        }
    }

    #[test]
    fn test_to_ndarray() {
        let test_image = create_test_image();

        let mut array = test_image.clone().to_ndarray();
        assert_eq!(array.dim(), (HEIGHT, WIDTH, CHANNELS));

        array += 1.0;
        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(
                test_image.get_pixel(x as u32, y as u32)[channel] + 1.0,
                *value
            );
        }
    }

    #[test]
    fn test_from_ndarray() {
        let data = create_test_data(WIDTH, HEIGHT, CHANNELS);
        let test_image = Array3::from_shape_vec((HEIGHT, WIDTH, CHANNELS), data).unwrap();
        let compare_data = test_image.clone();

        let result = Rgba32FImage::from_ndarray(test_image).unwrap();

        assert_eq!(result.dimensions(), (WIDTH as u32, HEIGHT as u32));
        for (x, y, pixel) in result.enumerate_pixels() {
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, compare_data[[y as usize, x as usize, channel]]);
            }
        }
    }

    /// Row-major subpixel data `1.0, 2.0, 3.0, ...` for a `width` x `height`
    /// image with `channels` subpixels per pixel.
    fn create_test_data(width: usize, height: usize, channels: usize) -> Vec<f32> {
        let total_elements = width * height * channels;
        (0..total_elements).map(|x| (x + 1) as f32).collect()
    }

    #[test]
    fn test_from_ndarray_with_invalid_channels() {
        // RGBA-shaped data (4 channels) must not be accepted as an RGB (3 channel) image.
        let data = create_test_data(WIDTH, HEIGHT, CHANNELS);
        let test_image = Array3::from_shape_vec((HEIGHT, WIDTH, CHANNELS), data).unwrap();

        let result = Rgb32FImage::from_ndarray(test_image).unwrap_err();

        assert_eq!(result, Error::ChannelMismatch);
    }
}