1use image::{ImageBuffer, Pixel};
4use ndarray::{Array3, ArrayView3, ArrayViewMut, ArrayViewMut3};
5use crate::error::{Error, Result};
6
7
8pub trait ImageArray<P: image::Pixel, ImageContainer> {
12 fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, ImageContainer>;
23
24
25 fn as_mut_ndarray<'a>(&'a mut self) -> ArrayViewMut3<'a, ImageContainer>;
36
37 fn to_ndarray(self) -> Array3<ImageContainer>;
48
49
50 fn from_ndarray(array: Array3<ImageContainer>) -> Result<ImageBuffer<P, Vec<ImageContainer>>>;
61
62
63
64
65
66}
67
68impl<P, C> ImageArray<P, C> for ImageBuffer<P, Vec<C>>
69where
70 P: Pixel<Subpixel = C>,
71 C: Clone + Copy,
72{
73 fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, C> {
74 let (width, height) = self.dimensions();
75 unsafe {
76 ArrayView3::from_shape_ptr(
77 (height as usize, width as usize, P::CHANNEL_COUNT as usize),
78 self.as_raw().as_ptr(),
79 )
80 }
81 }
82
83 fn to_ndarray(self) -> Array3<C>{
84 let (width, height) = self.dimensions();
85 unsafe {
86 Array3::from_shape_vec_unchecked(
87 (height as usize, width as usize, P::CHANNEL_COUNT as usize),
88 self.into_raw(),
89 )
90 }
91 }
92
93 fn from_ndarray(mut array: Array3<C>) -> Result<ImageBuffer<P, Vec<C>>> {
94 let (height, width, channels) = array.dim();
95
96 if channels != P::CHANNEL_COUNT.into() {
97 return Err(Error::ChannelMismatch);
98 }
99
100 let data = array.as_mut_ptr();
101
102 std::mem::forget(array);
103 let size = height * width * channels;
104
105 let vec_data = unsafe {
106 Vec::from_raw_parts(data, size, size)
107 };
108 Self::from_raw(width as u32, height as u32, vec_data).ok_or(Error::ImageConstructFailed)
109 }
110
111 fn as_mut_ndarray<'a>(&'a mut self) -> ArrayViewMut3<'a, C> {
112 let (width, height) = self.dimensions();
113
114 unsafe {
115 ArrayViewMut::from_shape_ptr(
116 (height as usize, width as usize, P::CHANNEL_COUNT as usize),
117 self.as_mut_ptr(),
118 )
119 }
120 }
121}
122
123
#[cfg(test)]
mod test {
    use super::*;
    use image::{Rgb32FImage, Rgba32FImage};

    /// Produces `width * height * channels` samples counting up from 1.0, so
    /// every position in the buffer holds a distinct value.
    fn create_test_data(width: usize, height: usize, channels: usize) -> Vec<f32> {
        (1..=width * height * channels).map(|n| n as f32).collect()
    }

    /// 256x128 RGBA float image filled with the ramp from `create_test_data`.
    fn sample_rgba_image() -> Rgba32FImage {
        Rgba32FImage::from_vec(256, 128, create_test_data(256, 128, 4)).unwrap()
    }

    /// Every element of the borrowed view must equal the matching pixel channel.
    #[test]
    fn test_as_ndarray() {
        let image = sample_rgba_image();

        let view = image.as_ndarray();

        for ((row, col, ch), sample) in view.indexed_iter() {
            let expected = image.get_pixel(col as u32, row as u32)[ch];
            assert_eq!(expected, *sample);
        }
    }

    /// Adding to the mutable view must be visible in the image itself.
    #[test]
    fn test_as_mut_ndarray() {
        let mut image = sample_rgba_image();
        let untouched = image.clone();

        {
            let mut view = image.as_mut_ndarray();
            view += 1.0;
        }

        for (x, y, pixel) in image.enumerate_pixels() {
            let before = untouched.get_pixel(x, y);
            for (ch, sample) in pixel.channels().iter().enumerate() {
                assert_eq!(*sample, before[ch] + 1.0);
            }
        }
    }

    /// The owned array must carry the image's samples in (y, x, channel) order.
    #[test]
    fn test_to_ndarray() {
        let image = sample_rgba_image();

        let mut owned = image.clone().to_ndarray();
        owned += 1.0;

        for ((row, col, ch), sample) in owned.indexed_iter() {
            let original = image.get_pixel(col as u32, row as u32)[ch];
            assert_eq!(original + 1.0, *sample);
        }
    }

    /// An image built from an array must reproduce the array element for element.
    #[test]
    fn test_from_ndarray() {
        let (width, height, channels) = (256, 128, 4);
        let samples = create_test_data(width, height, channels);
        let source = Array3::from_shape_vec((height, width, channels), samples).unwrap();
        let expected = source.clone();

        let image = Rgba32FImage::from_ndarray(source).unwrap();

        for (x, y, pixel) in image.enumerate_pixels() {
            for (ch, sample) in pixel.channels().iter().enumerate() {
                assert_eq!(*sample, expected[[y as usize, x as usize, ch]]);
            }
        }
    }

    /// A four-channel array must be rejected when an RGB image is requested.
    #[test]
    fn test_from_ndarray_with_invalid_channels() {
        let (width, height, channels) = (256, 128, 4);
        let samples = create_test_data(width, height, channels);
        let four_channel = Array3::from_shape_vec((height, width, channels), samples).unwrap();

        let outcome = Rgb32FImage::from_ndarray(four_channel).err();

        assert_eq!(outcome, Some(Error::ChannelMismatch));
    }
}