image_ndarray/
traits.rs

1//! Implementations for ndarray casting and conversions for the ImageBuffer
2
3use num_traits::{AsPrimitive, ToPrimitive};
4
5#[cfg(feature = "image")]
6use crate::error::{Error, Result};
7#[cfg(feature = "image")]
8use image::{ImageBuffer, Pixel};
9#[cfg(feature = "image")]
10use ndarray::{Array3, ArrayView3, ArrayViewMut, ArrayViewMut3};
11
#[cfg(feature = "image")]
/// Conversion methods between an `ImageBuffer` and ndarray arrays.
///
/// Every array produced or accepted by this trait is indexed as `array[[y, x, z]]`:
///
/// * `y` is the row (`0..height`),
/// * `x` is the column (`0..width`),
/// * `z` is the channel (`0..P::CHANNEL_COUNT`).
///
/// `Subpixel` is the element type of the array, i.e. the image's channel (subpixel) type.
pub trait ImageArray<P: image::Pixel, Subpixel> {
    /// Borrow the image as an `ArrayView3` of shape `(height, width, channels)`.
    ///
    /// The view refers directly to the samples stored in the buffer; no data is copied.
    fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, Subpixel>;

    /// Mutably borrow the image as an `ArrayViewMut3` of shape `(height, width, channels)`.
    ///
    /// Writes through the view modify the image in place; no data is copied.
    fn as_ndarray_mut<'a>(&'a mut self) -> ArrayViewMut3<'a, Subpixel>;

    /// Convert the image into an owned `Array3` of shape `(height, width, channels)`.
    ///
    /// The image is consumed and its sample storage is moved into the array rather than
    /// copied.
    fn to_ndarray(self) -> Array3<Subpixel>;

    /// Build an image from an array of shape `(height, width, channels)`.
    ///
    /// The array is consumed.
    ///
    /// # Errors
    ///
    /// * `Error::ChannelMismatch` if the length of the channel axis is not
    ///   `P::CHANNEL_COUNT`.
    /// * `Error::ImageConstructFailed` if an `ImageBuffer` could not be created from the
    ///   array's data.
    fn from_ndarray(array: Array3<Subpixel>) -> Result<ImageBuffer<P, Vec<Subpixel>>>;
}
65
#[cfg(feature = "image")]
impl<P, C> ImageArray<P, C> for ImageBuffer<P, Vec<C>>
where
    P: Pixel<Subpixel = C>,
    C: Clone + Copy,
{
    fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, C> {
        let (width, height) = self.dimensions();
        let shape = array_shape::<P>(width, height);
        let len = shape.0 * shape.1 * shape.2;
        // Only the first `len` samples belong to the image; the container may be longer.
        ArrayView3::from_shape(shape, &self.as_raw()[..len]).expect(BUFFER_INVARIANT)
    }

    fn to_ndarray(self) -> Array3<C> {
        let (width, height) = self.dimensions();
        let shape = array_shape::<P>(width, height);
        let mut samples = self.into_raw();
        // Drop any trailing samples so the array owns exactly its elements;
        // `truncate` never reallocates, so the storage is still reused without copying.
        samples.truncate(shape.0 * shape.1 * shape.2);
        Array3::from_shape_vec(shape, samples).expect(BUFFER_INVARIANT)
    }

    fn from_ndarray(array: Array3<C>) -> Result<ImageBuffer<P, Vec<C>>> {
        let (height, width, channels) = array.dim();

        if channels != usize::from(P::CHANNEL_COUNT) {
            return Err(Error::ChannelMismatch);
        }

        // `ImageBuffer` dimensions are `u32`; reject axes that would be silently truncated.
        let width = image_dimension(width)?;
        let height = image_dimension(height)?;

        Self::from_raw(width, height, into_row_major_vec(array)).ok_or(Error::ImageConstructFailed)
    }

    fn as_ndarray_mut<'a>(&'a mut self) -> ArrayViewMut3<'a, C> {
        let (width, height) = self.dimensions();
        let shape = array_shape::<P>(width, height);
        let len = shape.0 * shape.1 * shape.2;
        let samples: &mut [C] = &mut **self;
        // Only the first `len` samples belong to the image; the container may be longer.
        ArrayViewMut::from_shape(shape, &mut samples[..len]).expect(BUFFER_INVARIANT)
    }
}

/// Message for `expect` calls relying on an `ImageBuffer` holding at least
/// `width * height * channels` samples.
#[cfg(feature = "image")]
const BUFFER_INVARIANT: &str = "ImageBuffer holds at least width * height * channels samples";

/// Shape `(rows, columns, channels)` of the array matching a `width` x `height` image of `P`.
#[cfg(feature = "image")]
fn array_shape<P: Pixel>(width: u32, height: u32) -> (usize, usize, usize) {
    (height as usize, width as usize, usize::from(P::CHANNEL_COUNT))
}

/// Convert an array axis length into an image dimension, failing if it does not fit in `u32`.
#[cfg(feature = "image")]
fn image_dimension(len: usize) -> Result<u32> {
    std::convert::TryFrom::try_from(len).map_err(|_| Error::ImageConstructFailed)
}

/// Take the elements of `array` as a `Vec` in logical row-major (`[y][x][z]`) order.
///
/// The array's allocation is reused when the array is in standard layout; otherwise the
/// elements are copied into a new `Vec`.
#[cfg(feature = "image")]
fn into_row_major_vec<C: Copy>(array: Array3<C>) -> Vec<C> {
    if !array.is_standard_layout() {
        // Transposed, reversed or strided arrays do not keep their elements in row-major
        // order in memory, so their storage cannot be handed to the image as-is.
        return array.iter().copied().collect();
    }

    let len = array.len();
    let first_element = array.as_ptr() as usize;
    #[allow(deprecated)] // `into_raw_vec` may be deprecated depending on the ndarray version.
    let mut storage = array.into_raw_vec();

    // A sliced array can start part-way into its allocation and be followed by unused
    // elements. In standard layout its `len` elements are contiguous from `first_element`,
    // so keep exactly that range.
    let element_size = std::mem::size_of::<C>();
    let offset = if len == 0 || element_size == 0 {
        0
    } else {
        (first_element - storage.as_ptr() as usize) / element_size
    };
    storage.truncate(offset + len);
    storage.drain(..offset);
    storage
}
119
120/// Trait for converting the provided value to a normalized float.
121///
122/// This is used for image processing where a lot of operations rely on floating values.
123pub trait NormalizedFloat<T>
124where
125    T: AsPrimitive<f32> + AsPrimitive<f64>,
126{
127    /// Convert the value to a 32 bit float.
128    ///
129    /// The value will be in a normalized range according to color depths.
130    ///
131    /// For example in u8, a value of 255 would be represented as 1.0.
132    ///
133    /// Returns None if it overflows and could not be represented.
134    fn to_f32_normalized(&self) -> Option<f32>;
135    /// Convert the value to a 64 bit float
136    ///
137    /// The value will be in a normalized range according to color depths.
138    ///
139    /// For example in u8, a value of 255 would be represented as 1.0.
140    ///
141    /// Returns None if it overflows and could not be represented.
142    fn to_f64_normalized(&self) -> Option<f64>;
143
144    /// Converts the f32 value to the provided type
145    ///
146    /// Returns None if it overflows and could not be represented.
147    fn from_f32_normalized(value: f32) -> Option<T>;
148
149    /// Converts the f64 value to the provided type
150    ///
151    /// Returns None if it overflows and could not be represented.
152    fn from_f64_normalized(value: f64) -> Option<T>;
153}
154
155impl NormalizedFloat<f32> for f32 {
156    fn to_f32_normalized(&self) -> Option<f32> {
157        Some(*self)
158    }
159
160    fn to_f64_normalized(&self) -> Option<f64> {
161        self.to_f64()
162    }
163    fn from_f32_normalized(value: f32) -> Option<f32> {
164        Some(value)
165    }
166
167    fn from_f64_normalized(value: f64) -> Option<f32> {
168        value.to_f32()
169    }
170}
171
172impl NormalizedFloat<f64> for f64 {
173    fn to_f32_normalized(&self) -> Option<f32> {
174        self.to_f32()
175    }
176
177    fn to_f64_normalized(&self) -> Option<f64> {
178        Some(*self)
179    }
180    fn from_f32_normalized(value: f32) -> Option<f64> {
181        value.to_f64()
182    }
183
184    fn from_f64_normalized(value: f64) -> Option<f64> {
185        Some(value)
186    }
187}
188
/// Implement `NormalizedFloat` for an integer type by scaling with the type's `MAX`.
///
/// Conversions to float divide by `MAX`, so `MAX` maps to `1.0`. Conversions from float
/// multiply by `MAX` and round to the nearest integer. Out-of-range results saturate to the
/// type's bounds and NaN maps to `0`, which is the behavior of the `as_()` cast.
///
/// The expansion expects `NormalizedFloat`, `num_traits::ToPrimitive` and
/// `num_traits::AsPrimitive` to be in scope at the invocation site.
#[macro_export]
macro_rules! impl_as_float {
    ($type:ty) => {
        impl NormalizedFloat<$type> for $type {
            fn to_f32_normalized(&self) -> Option<f32> {
                self.to_f32()
                    .map(|converted| converted / <$type>::MAX as f32)
            }

            fn to_f64_normalized(&self) -> Option<f64> {
                self.to_f64()
                    .map(|converted| converted / <$type>::MAX as f64)
            }

            fn from_f32_normalized(value: f32) -> Option<$type> {
                // Round to nearest: a bare cast truncates toward zero, so a scaled value such
                // as 254.745 (0.999 for u8) would quantize down to 254 instead of 255.
                Some((value * <$type>::MAX as f32).round().as_())
            }

            fn from_f64_normalized(value: f64) -> Option<$type> {
                // Round to nearest for the same reason as `from_f32_normalized`.
                Some((value * <$type>::MAX as f64).round().as_())
            }
        }
    };
}
213
214impl_as_float!(i32);
215impl_as_float!(u32);
216impl_as_float!(i16);
217impl_as_float!(u16);
218impl_as_float!(i8);
219impl_as_float!(u8);
220
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    /// Tests for the `ImageArray` conversions, which only exist with the `image` feature.
    #[cfg(feature = "image")]
    mod image_array {
        use super::*;
        use image::{ImageBuffer, Luma, Pixel, Rgb32FImage, Rgba32FImage};
        use ndarray::Array3;

        const WIDTH: usize = 256;
        const HEIGHT: usize = 128;

        /// Build `width * height * channels` distinct samples `1.0, 2.0, ...`.
        fn create_test_data(width: usize, height: usize, channels: usize) -> Vec<f32> {
            (1..=width * height * channels).map(|x| x as f32).collect()
        }

        /// Build a `WIDTH` x `HEIGHT` RGBA image filled with `create_test_data`.
        fn create_rgba_image() -> Rgba32FImage {
            let data = create_test_data(WIDTH, HEIGHT, 4);
            Rgba32FImage::from_vec(WIDTH as u32, HEIGHT as u32, data).unwrap()
        }

        #[test]
        fn test_as_ndarray_rgba() {
            let test_image = create_rgba_image();

            let array = test_image.as_ndarray();

            assert_eq!(array.dim(), (HEIGHT, WIDTH, 4));
            for ((y, x, channel), value) in array.indexed_iter() {
                assert_eq!(test_image.get_pixel(x as u32, y as u32)[channel], *value);
            }
        }

        #[test]
        fn test_as_ndarray_luma() {
            let data = create_test_data(WIDTH, HEIGHT, 1);
            let test_image: ImageBuffer<Luma<f32>, Vec<f32>> =
                ImageBuffer::from_vec(WIDTH as u32, HEIGHT as u32, data).unwrap();

            let array = test_image.as_ndarray();

            assert_eq!(array.dim(), (HEIGHT, WIDTH, 1));
            for ((y, x, channel), value) in array.indexed_iter() {
                assert_eq!(test_image.get_pixel(x as u32, y as u32)[channel], *value);
            }
        }

        #[test]
        fn test_as_ndarray_mut() {
            let mut test_image = create_rgba_image();
            let compare = test_image.clone();

            let mut array = test_image.as_ndarray_mut();
            array += 1.0;

            for (x, y, pixel) in test_image.enumerate_pixels() {
                let compare_pixel = compare.get_pixel(x, y);
                for (channel, value) in pixel.channels().iter().enumerate() {
                    assert_eq!(*value, compare_pixel[channel] + 1.0);
                }
            }
        }

        #[test]
        fn test_to_ndarray() {
            let test_image = create_rgba_image();

            let mut array = test_image.clone().to_ndarray();
            array += 1.0;

            for ((y, x, channel), value) in array.indexed_iter() {
                assert_eq!(
                    test_image.get_pixel(x as u32, y as u32)[channel] + 1.0,
                    *value
                );
            }
        }

        #[test]
        fn test_from_ndarray() {
            let channels = 4;
            let data = create_test_data(WIDTH, HEIGHT, channels);
            let test_image = Array3::from_shape_vec((HEIGHT, WIDTH, channels), data).unwrap();
            let compare_data = test_image.clone();

            let result = Rgba32FImage::from_ndarray(test_image).unwrap();

            for (x, y, pixel) in result.enumerate_pixels() {
                for (channel, value) in pixel.channels().iter().enumerate() {
                    assert_eq!(*value, compare_data[[y as usize, x as usize, channel]]);
                }
            }
        }

        #[test]
        fn test_from_ndarray_with_invalid_channels() {
            // An RGBA-shaped array cannot become an RGB image.
            let channels = 4;
            let data = create_test_data(WIDTH, HEIGHT, channels);
            let test_image = Array3::from_shape_vec((HEIGHT, WIDTH, channels), data).unwrap();

            let result = Rgb32FImage::from_ndarray(test_image).unwrap_err();

            assert_eq!(result, Error::ChannelMismatch);
        }
    }

    #[rstest]
    #[case(1.0)]
    #[case(255.0)]
    #[case(0.5)]
    #[case(-1.0)]
    #[case(-255.0)]
    fn test_f32(#[case] float: f32) {
        assert_eq!(float.to_f32_normalized().unwrap(), float);
        assert_eq!(f32::from_f32_normalized(float).unwrap(), float);

        let float_64: f64 = float.as_();
        assert_eq!(float_64.to_f64_normalized().unwrap(), float_64);
        assert_eq!(f64::from_f64_normalized(float_64).unwrap(), float_64);

        let converted_to_float64 = float.to_f64_normalized().unwrap();
        assert_eq!(converted_to_float64, float as f64);

        let converted_back_to_float32 = float_64.to_f32_normalized().unwrap();
        assert_eq!(converted_back_to_float32, float);
    }

    /// Check the normalized conversions of `$int` (an integer of type `$type`) against the
    /// reference formulas: divide by `MAX` to normalize, multiply by `MAX` to convert back.
    macro_rules! assert_normalized_conversions {
        ($int:expr, $type:ty) => {{
            let int: $type = $int;

            let normalized_f32 = int.to_f32_normalized().unwrap();
            let expected_normalized_f32 = int as f32 / <$type>::MAX as f32;
            assert_eq!(normalized_f32, expected_normalized_f32);

            let int_from_float32 = <$type>::from_f32_normalized(expected_normalized_f32).unwrap();
            let expected_int_from_float32 =
                (expected_normalized_f32 * <$type>::MAX as f32) as $type;
            assert_eq!(int_from_float32, expected_int_from_float32);

            let normalized_f64 = int.to_f64_normalized().unwrap();
            let expected_normalized_f64 = int as f64 / <$type>::MAX as f64;
            assert_eq!(normalized_f64, expected_normalized_f64);

            let int_from_float64 = <$type>::from_f64_normalized(expected_normalized_f64).unwrap();
            let expected_int_from_float64 =
                (expected_normalized_f64 * <$type>::MAX as f64) as $type;
            assert_eq!(int_from_float64, expected_int_from_float64);
        }};
    }

    /// Generate normalized-conversion tests for an unsigned integer type.
    macro_rules! test_unsigned_ints {
        ($name:ident, $type:ident) => {
            #[rstest]
            #[case(0)]
            #[case(1)]
            #[case($type::MAX)]
            #[case($type::MIN)]
            fn $name(#[case] int: $type) {
                assert_normalized_conversions!(int, $type);
            }
        };
    }

    /// Generate normalized-conversion tests for a signed integer type, including negatives.
    macro_rules! test_signed_ints {
        ($name:ident, $type:ident) => {
            #[rstest]
            #[case(0)]
            #[case(1)]
            #[case($type::MAX)]
            #[case($type::MIN)]
            #[case(-1)]
            #[case(-$type::MAX)]
            fn $name(#[case] int: $type) {
                assert_normalized_conversions!(int, $type);
            }
        };
    }

    test_signed_ints!(test_i32, i32);
    test_signed_ints!(test_i16, i16);
    test_signed_ints!(test_i8, i8);
    test_unsigned_ints!(test_u32, u32);
    test_unsigned_ints!(test_u16, u16);
    test_unsigned_ints!(test_u8, u8);
}