1use num_traits::{AsPrimitive, ToPrimitive};
4
5#[cfg(feature = "image")]
6use crate::error::{Error, Result};
7#[cfg(feature = "image")]
8use image::{ImageBuffer, Pixel};
9#[cfg(feature = "image")]
10use ndarray::{Array3, ArrayView3, ArrayViewMut, ArrayViewMut3};
11
#[cfg(feature = "image")]
/// Conversions between an image buffer and a 3-D [`ndarray`] array.
///
/// Arrays use the axis order `(height, width, channels)`, so a sample is addressed as
/// `array[[y, x, channel]]`, and the channel axis has length `P::CHANNEL_COUNT`.
///
/// * `P` — the pixel type of the image.
/// * `Sample` — the scalar type of one channel value, which is the array element type.
pub trait ImageArray<P: image::Pixel, Sample> {
    /// Borrows the image samples as a `(height, width, channels)` array view.
    fn as_ndarray(&self) -> ArrayView3<'_, Sample>;

    /// Mutably borrows the image samples as a `(height, width, channels)` array view.
    ///
    /// Writes through the view modify the image in place.
    fn as_ndarray_mut(&mut self) -> ArrayViewMut3<'_, Sample>;

    /// Consumes the image and returns its samples as an owned `(height, width, channels)`
    /// array.
    fn to_ndarray(self) -> Array3<Sample>;

    /// Builds an image from an owned `(height, width, channels)` array.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ChannelMismatch`] when the channel axis length differs from
    /// `P::CHANNEL_COUNT`. Returns [`Error::ImageConstructFailed`] when an image cannot be
    /// constructed from the array.
    fn from_ndarray(array: Array3<Sample>) -> Result<ImageBuffer<P, Vec<Sample>>>;
}
65
/// Returns the `(height, width, channels)` array shape for an image with the given
/// `(width, height)` dimensions and pixel type `P`.
#[cfg(feature = "image")]
fn image_array_shape<P: Pixel>((width, height): (u32, u32)) -> (usize, usize, usize) {
    (height as usize, width as usize, usize::from(P::CHANNEL_COUNT))
}

#[cfg(feature = "image")]
impl<P, C> ImageArray<P, C> for ImageBuffer<P, Vec<C>>
where
    P: Pixel<Subpixel = C>,
    C: Clone + Copy,
{
    /// Borrows the pixel data as a `(height, width, channels)` view without copying.
    ///
    /// # Panics
    ///
    /// Panics if the backing container holds fewer than `height * width * channels`
    /// samples. `ImageBuffer::from_raw` refuses to build such a buffer.
    fn as_ndarray(&self) -> ArrayView3<'_, C> {
        let shape = image_array_shape::<P>(self.dimensions());
        let len = shape.0 * shape.1 * shape.2;
        // `from_raw` only requires the container to be *at least* `len` long, so view
        // exactly the prefix that holds the pixels.
        ArrayView3::from_shape(shape, &self.as_raw()[..len])
            .expect("an ImageBuffer holds at least width * height * channels samples")
    }

    /// Consumes the image and moves its samples into a `(height, width, channels)` array.
    ///
    /// Any surplus elements at the end of the container are dropped, so the returned array
    /// owns exactly `height * width * channels` samples.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as `as_ndarray`.
    fn to_ndarray(self) -> Array3<C> {
        let shape = image_array_shape::<P>(self.dimensions());
        let len = shape.0 * shape.1 * shape.2;
        let mut raw = self.into_raw();
        // Trim the surplus so the vector length matches the shape exactly.
        raw.truncate(len);
        Array3::from_shape_vec(shape, raw)
            .expect("an ImageBuffer holds at least width * height * channels samples")
    }

    /// Builds an image from a `(height, width, channels)` array.
    ///
    /// Samples are copied out in logical (row-major) order. Arrays in any memory layout are
    /// therefore accepted, including transposed, permuted or sliced arrays.
    ///
    /// # Errors
    ///
    /// * [`Error::ChannelMismatch`] if the last axis length differs from `P::CHANNEL_COUNT`.
    /// * [`Error::ImageConstructFailed`] if the width or height does not fit in a `u32`, or
    ///   if `ImageBuffer::from_raw` rejects the data.
    fn from_ndarray(array: Array3<C>) -> Result<ImageBuffer<P, Vec<C>>> {
        use std::convert::TryFrom;

        let (height, width, channels) = array.dim();

        if channels != usize::from(P::CHANNEL_COUNT) {
            return Err(Error::ChannelMismatch);
        }

        // `as u32` would silently wrap an oversized axis into a smaller, wrong image.
        let width = u32::try_from(width).map_err(|_| Error::ImageConstructFailed)?;
        let height = u32::try_from(height).map_err(|_| Error::ImageConstructFailed)?;

        // Reusing the array's allocation with `Vec::from_raw_parts` is only sound when all of
        // these hold:
        // - the array is in standard layout;
        // - it starts at the beginning of its allocation;
        // - its capacity equals its length.
        // None of that is guaranteed for an arbitrary `Array3`, so copy the samples in
        // logical order instead.
        let data: Vec<C> = array.iter().copied().collect();

        Self::from_raw(width, height, data).ok_or(Error::ImageConstructFailed)
    }

    /// Mutably borrows the pixel data as a `(height, width, channels)` view without copying.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as `as_ndarray`.
    fn as_ndarray_mut(&mut self) -> ArrayViewMut3<'_, C> {
        let shape = image_array_shape::<P>(self.dimensions());
        let len = shape.0 * shape.1 * shape.2;
        // Deref-coerce to the underlying sample slice; only the pixel prefix is viewed.
        let subpixels: &mut [C] = self;
        ArrayViewMut::from_shape(shape, &mut subpixels[..len])
            .expect("an ImageBuffer holds at least width * height * channels samples")
    }
}

/// Regression tests for layouts and container sizes that the previous raw-pointer
/// implementation mishandled.
#[cfg(all(test, feature = "image"))]
mod image_array_layout_tests {
    use super::*;
    use image::{Luma, Rgb32FImage};

    #[test]
    fn from_ndarray_accepts_non_standard_layout() {
        let (height, width, channels) = (3, 4, 3);
        let data: Vec<f32> = (0..height * width * channels).map(|v| v as f32).collect();
        // The array is built as (channels, width, height) and permuted to
        // (height, width, channels), so its memory order is not row-major.
        let array = Array3::from_shape_vec((channels, width, height), data)
            .unwrap()
            .permuted_axes([2, 1, 0]);
        assert!(!array.is_standard_layout());
        let expected = array.clone();

        let rebuilt = Rgb32FImage::from_ndarray(array).unwrap();

        assert_eq!(rebuilt.dimensions(), (width as u32, height as u32));
        for (x, y, pixel) in rebuilt.enumerate_pixels() {
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, expected[[y as usize, x as usize, channel]]);
            }
        }
    }

    #[test]
    fn to_ndarray_ignores_surplus_container_elements() {
        let source: ImageBuffer<Luma<f32>, Vec<f32>> =
            ImageBuffer::from_raw(2, 1, vec![1.0, 2.0, 9.0, 9.0]).unwrap();

        let array = source.to_ndarray();
        assert_eq!(array.dim(), (1, 2, 1));
        assert_eq!(array.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0]);

        let round_trip = ImageBuffer::<Luma<f32>, Vec<f32>>::from_ndarray(array).unwrap();
        assert_eq!(round_trip.as_raw(), &vec![1.0, 2.0]);
    }
}
119
120pub trait NormalizedFloat<T>
124where
125 T: AsPrimitive<f32> + AsPrimitive<f64>,
126{
127 fn to_f32_normalized(&self) -> Option<f32>;
135 fn to_f64_normalized(&self) -> Option<f64>;
143
144 fn from_f32_normalized(value: f32) -> Option<T>;
148
149 fn from_f64_normalized(value: f64) -> Option<T>;
153}
154
155impl NormalizedFloat<f32> for f32 {
156 fn to_f32_normalized(&self) -> Option<f32> {
157 Some(*self)
158 }
159
160 fn to_f64_normalized(&self) -> Option<f64> {
161 self.to_f64()
162 }
163 fn from_f32_normalized(value: f32) -> Option<f32> {
164 Some(value)
165 }
166
167 fn from_f64_normalized(value: f64) -> Option<f32> {
168 value.to_f32()
169 }
170}
171
172impl NormalizedFloat<f64> for f64 {
173 fn to_f32_normalized(&self) -> Option<f32> {
174 self.to_f32()
175 }
176
177 fn to_f64_normalized(&self) -> Option<f64> {
178 Some(*self)
179 }
180 fn from_f32_normalized(value: f32) -> Option<f64> {
181 value.to_f64()
182 }
183
184 fn from_f64_normalized(value: f64) -> Option<f64> {
185 Some(value)
186 }
187}
188
/// Implements `NormalizedFloat` for a primitive integer type by scaling against the type's
/// `MAX` value.
///
/// The `to_*_normalized` methods divide by `MAX`. The `from_*_normalized` methods multiply
/// by `MAX` and convert back with an `as` cast, which has these effects:
/// - fractional results are truncated toward zero;
/// - out-of-range results saturate at the type's bounds;
/// - NaN becomes 0.
///
/// As a result, the `from_*` methods always return `Some`.
///
/// The expansion names `NormalizedFloat` and the `to_f32`/`to_f64` methods of
/// `num_traits::ToPrimitive` unqualified, so both traits must be in scope at the
/// invocation site.
#[macro_export]
macro_rules! impl_as_float {
    ($type:ty) => {
        impl NormalizedFloat<$type> for $type {
            fn to_f32_normalized(&self) -> Option<f32> {
                let full_scale = <$type>::MAX as f32;
                self.to_f32().map(|sample| sample / full_scale)
            }

            fn to_f64_normalized(&self) -> Option<f64> {
                let full_scale = <$type>::MAX as f64;
                self.to_f64().map(|sample| sample / full_scale)
            }

            fn from_f32_normalized(value: f32) -> Option<$type> {
                let scaled = value * <$type>::MAX as f32;
                Some(scaled as $type)
            }

            fn from_f64_normalized(value: f64) -> Option<$type> {
                let scaled = value * <$type>::MAX as f64;
                Some(scaled as $type)
            }
        }
    };
}
213
214impl_as_float!(i32);
215impl_as_float!(u32);
216impl_as_float!(i16);
217impl_as_float!(u16);
218impl_as_float!(i8);
219impl_as_float!(u8);
220
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "image")]
    use image::{Luma, Rgb32FImage, Rgba32FImage};
    use rstest::*;

    // `ImageArray`, the `image`/`ndarray` imports and the `Error` import only exist with the
    // `image` feature. Every test that touches them is gated on it individually, so the
    // `NormalizedFloat` tests still build and run without that feature.

    #[cfg(feature = "image")]
    #[test]
    fn test_as_ndarray_rgba() {
        let (width, height, channels) = (256, 128, 4);
        let data = create_test_data(width, height, channels);
        let test_image = Rgba32FImage::from_vec(width as u32, height as u32, data).unwrap();

        let array = test_image.as_ndarray();
        assert_eq!(array.dim(), (height, width, channels));

        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(test_image.get_pixel(x as u32, y as u32)[channel], *value);
        }
    }

    #[cfg(feature = "image")]
    #[test]
    fn test_as_ndarray_luma() {
        let (width, height, channels) = (256, 128, 1);
        let data = create_test_data(width, height, channels);
        let test_image: ImageBuffer<Luma<f32>, Vec<f32>> =
            ImageBuffer::from_vec(width as u32, height as u32, data).unwrap();

        let array = test_image.as_ndarray();
        assert_eq!(array.dim(), (height, width, channels));

        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(test_image.get_pixel(x as u32, y as u32)[channel], *value);
        }
    }

    #[cfg(feature = "image")]
    #[test]
    fn test_as_ndarray_mut() {
        let (width, height, channels) = (256, 128, 4);
        let data = create_test_data(width, height, channels);
        let mut test_image = Rgba32FImage::from_vec(width as u32, height as u32, data).unwrap();
        let compare = test_image.clone();

        let mut array = test_image.as_ndarray_mut();
        array += 1.0;

        for (x, y, pixel) in test_image.enumerate_pixels() {
            let compare_pixel = compare.get_pixel(x, y);
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, compare_pixel[channel] + 1.0);
            }
        }
    }

    #[cfg(feature = "image")]
    #[test]
    fn test_to_ndarray() {
        let (width, height, channels) = (256, 128, 4);
        let data = create_test_data(width, height, channels);
        let test_image = Rgba32FImage::from_vec(width as u32, height as u32, data).unwrap();

        let mut array = test_image.clone().to_ndarray();
        assert_eq!(array.dim(), (height, width, channels));

        array += 1.0;
        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(
                test_image.get_pixel(x as u32, y as u32)[channel] + 1.0,
                *value
            );
        }
    }

    #[cfg(feature = "image")]
    #[test]
    fn test_from_ndarray() {
        let (width, height, channels) = (256, 128, 4);
        let data = create_test_data(width, height, channels);
        let test_image = Array3::from_shape_vec((height, width, channels), data).unwrap();
        let compare_data = test_image.clone();

        let result = Rgba32FImage::from_ndarray(test_image).unwrap();
        assert_eq!(result.dimensions(), (width as u32, height as u32));

        for (x, y, pixel) in result.enumerate_pixels() {
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, compare_data[[y as usize, x as usize, channel]]);
            }
        }
    }

    /// Builds `width * height * channels` distinct, non-zero samples (1.0, 2.0, ...), so any
    /// misplaced element shows up as a mismatch.
    #[cfg(feature = "image")]
    fn create_test_data(width: usize, height: usize, channels: usize) -> Vec<f32> {
        let total_elements = width * height * channels;
        (0..total_elements).map(|x| (x + 1) as f32).collect()
    }

    #[cfg(feature = "image")]
    #[test]
    fn test_from_ndarray_with_invalid_channels() {
        // Four channels of data, but `Rgb32FImage` expects three.
        let (width, height, channels) = (256, 128, 4);
        let data = create_test_data(width, height, channels);
        let test_image = Array3::from_shape_vec((height, width, channels), data).unwrap();

        let result = Rgb32FImage::from_ndarray(test_image).err().unwrap();

        assert_eq!(result, Error::ChannelMismatch);
    }

    #[rstest]
    #[case(1.0)]
    #[case(255.0)]
    #[case(0.5)]
    #[case(-1.0)]
    #[case(-255.0)]
    fn test_f32(#[case] float: f32) {
        assert_eq!(float.to_f32_normalized().unwrap(), float);
        assert_eq!(f32::from_f32_normalized(float).unwrap(), float);

        let float_64: f64 = float.as_();
        assert_eq!(float_64.to_f64_normalized().unwrap(), float_64);
        assert_eq!(f64::from_f64_normalized(float_64).unwrap(), float_64);

        let converted_to_float64 = float.to_f64_normalized().unwrap();
        assert_eq!(converted_to_float64, float as f64);

        let converted_back_to_float32 = float_64.to_f32_normalized().unwrap();
        assert_eq!(converted_back_to_float32, float);
    }

    #[macro_export]
    macro_rules! test_unsigned_ints {
        ($name:ident, $type:ty) => {
            #[rstest]
            #[case(0)]
            #[case(1)]
            #[case($type::MAX)]
            #[case($type::MIN)]
            fn $name(#[case] int: $type) {
                let normalized_f32 = int.to_f32_normalized().unwrap();
                let expected_normalized_f32 = int as f32 / <$type>::MAX as f32;
                assert_eq!(normalized_f32, expected_normalized_f32);

                let int_from_float32 =
                    <$type>::from_f32_normalized(expected_normalized_f32).unwrap();
                let expected_int_from_float32 =
                    (expected_normalized_f32 * <$type>::MAX as f32) as $type;
                assert_eq!(int_from_float32, expected_int_from_float32);

                let normalized_f64 = int.to_f64_normalized().unwrap();
                let expected_normalized_f64 = int as f64 / <$type>::MAX as f64;
                assert_eq!(normalized_f64, expected_normalized_f64);

                let int_from_float64 =
                    <$type>::from_f64_normalized(expected_normalized_f64).unwrap();
                let expected_int_from_float64 =
                    (expected_normalized_f64 * <$type>::MAX as f64) as $type;
                assert_eq!(int_from_float64, expected_int_from_float64);
            }
        };
    }

    #[macro_export]
    macro_rules! test_signed_ints {
        ($name:ident, $type:ty) => {
            #[rstest]
            #[case(0)]
            #[case(1)]
            #[case($type::MAX)]
            #[case($type::MIN)]
            #[case(-1)]
            #[case(-$type::MAX)]
            fn $name(#[case] int: $type) {
                let normalized_f32 = int.to_f32_normalized().unwrap();
                let expected_normalized_f32 = int as f32 / <$type>::MAX as f32;
                assert_eq!(normalized_f32, expected_normalized_f32);

                let int_from_float32 =
                    <$type>::from_f32_normalized(expected_normalized_f32).unwrap();
                let expected_int_from_float32 =
                    (expected_normalized_f32 * <$type>::MAX as f32) as $type;
                assert_eq!(int_from_float32, expected_int_from_float32);

                let normalized_f64 = int.to_f64_normalized().unwrap();
                let expected_normalized_f64 = int as f64 / <$type>::MAX as f64;
                assert_eq!(normalized_f64, expected_normalized_f64);

                let int_from_float64 =
                    <$type>::from_f64_normalized(expected_normalized_f64).unwrap();
                let expected_int_from_float64 =
                    (expected_normalized_f64 * <$type>::MAX as f64) as $type;
                assert_eq!(int_from_float64, expected_int_from_float64);
            }
        };
    }
    test_signed_ints!(test_i32, i32);
    test_signed_ints!(test_i16, i16);
    test_signed_ints!(test_i8, i8);
    test_unsigned_ints!(test_u32, u32);
    test_unsigned_ints!(test_u16, u16);
    test_unsigned_ints!(test_u8, u8);
}