image_ndarray/
traits.rs

//! Conversions between `image::ImageBuffer` and `ndarray` arrays, plus normalized-float conversions for primitive subpixel types.
2
3#[cfg(feature = "image")]
4use crate::error::{Error, Result};
5#[cfg(feature = "image")]
6use image::{ImageBuffer, Pixel};
7#[cfg(feature = "image")]
8use ndarray::{Array, Array3, ArrayView3, ArrayViewMut, ArrayViewMut3, Dimension};
9use num_traits::{AsPrimitive, ToPrimitive};
10
#[cfg(feature = "image")]
/// Conversions between an [`ImageBuffer`] and `ndarray` arrays.
///
/// Every array produced or consumed by this trait is indexed as `array[[y, x, z]]`:
///
/// * `y` — the row (image height axis)
/// * `x` — the column (image width axis)
/// * `z` — the channel within a pixel
///
/// `P` is the pixel type and `ImageContainer` is its subpixel (element) type.
pub trait ImageArray<P: image::Pixel, ImageContainer> {
    /// Borrows the image as an [`ArrayView3`] of shape `(height, width, channels)`.
    ///
    /// The view refers to the buffer's own storage; no data is copied.
    fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, ImageContainer>;

    /// Mutably borrows the image as an [`ArrayViewMut3`] of shape `(height, width, channels)`.
    ///
    /// Writes through the view modify the image in place; no data is copied.
    fn as_ndarray_mut<'a>(&'a mut self) -> ArrayViewMut3<'a, ImageContainer>;

    /// Consumes the image and returns its subpixels as an [`Array3`] of shape
    /// `(height, width, channels)`.
    ///
    /// The image's allocation is handed over to the array rather than copied.
    fn to_ndarray(self) -> Array3<ImageContainer>;

    /// Consumes `array` and builds an [`ImageBuffer`] from it.
    ///
    /// `array` must be either 2-D with shape `(height, width)`, which is treated as a single
    /// channel, or 3-D with shape `(height, width, channels)`. The array's allocation is reused
    /// (not reallocated) when the array is in standard (row-major) layout.
    ///
    /// # Errors
    ///
    /// * [`Error::Dimensions`] if `array` is neither 2-D nor 3-D.
    /// * [`Error::ChannelMismatch`] if the channel count differs from `P::CHANNEL_COUNT`.
    /// * [`Error::ImageConstructFailed`] if the image buffer cannot be constructed.
    fn from_ndarray<D: Dimension>(
        array: Array<ImageContainer, D>,
    ) -> Result<ImageBuffer<P, Vec<ImageContainer>>>;
}
66
/// Returns the `(height, width, channels)` shape shared by every array view of `image`.
#[cfg(feature = "image")]
fn image_shape<P, C>(image: &ImageBuffer<P, Vec<C>>) -> (usize, usize, usize)
where
    P: Pixel<Subpixel = C>,
{
    let (width, height) = image.dimensions();
    // u32 -> usize widening casts; lossless on 32- and 64-bit targets.
    (height as usize, width as usize, usize::from(P::CHANNEL_COUNT))
}

/// Moves the elements of `array` into a `Vec` in logical (row-major, `iter()`) order.
///
/// An array in standard layout hands over its existing allocation. If it was sliced, the
/// elements outside its view are trimmed off. Any other layout (Fortran order, permuted
/// axes, negative strides) is copied element by element, because its memory order differs
/// from its logical order.
#[cfg(feature = "image")]
fn into_logical_vec<C: Copy, D: Dimension>(array: Array<C, D>) -> Vec<C> {
    if !array.is_standard_layout() {
        return array.iter().copied().collect();
    }

    let len = array.len();
    let first = array.as_ptr();
    // `into_raw_vec` is deprecated from ndarray 0.16 in favour of `into_raw_vec_and_offset`,
    // which older releases lack; allowing the lint keeps both versions building cleanly.
    #[allow(deprecated)]
    let mut raw = array.into_raw_vec();
    // A sliced array may begin part-way into its allocation. `raw` owns that same allocation,
    // so the first logical element is located by address (an empty array falls back to 0).
    let offset = raw
        .iter()
        .position(|elem| std::ptr::eq(elem, first))
        .unwrap_or(0);
    raw.truncate(offset + len);
    raw.drain(..offset);
    raw
}

#[cfg(feature = "image")]
impl<P, C> ImageArray<P, C> for ImageBuffer<P, Vec<C>>
where
    P: Pixel<Subpixel = C>,
    C: Copy,
{
    fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, C> {
        let shape = image_shape(self);
        let (rows, cols, channels) = shape;
        // The container may be longer than the image itself; view only the pixel data.
        ArrayView3::from_shape(shape, &self.as_raw()[..rows * cols * channels])
            .expect("ImageBuffer holds at least width * height * channels subpixels")
    }

    fn as_ndarray_mut<'a>(&'a mut self) -> ArrayViewMut3<'a, C> {
        let shape = image_shape(self);
        let (rows, cols, channels) = shape;
        let subpixels: &'a mut [C] = &mut **self;
        ArrayViewMut::from_shape(shape, &mut subpixels[..rows * cols * channels])
            .expect("ImageBuffer holds at least width * height * channels subpixels")
    }

    fn to_ndarray(self) -> Array3<C> {
        let shape = image_shape(&self);
        let (rows, cols, channels) = shape;
        let mut raw = self.into_raw();
        // Drop any surplus container elements so the length matches the shape exactly.
        raw.truncate(rows * cols * channels);
        Array3::from_shape_vec(shape, raw)
            .expect("ImageBuffer holds at least width * height * channels subpixels")
    }

    fn from_ndarray<D: Dimension>(array: Array<C, D>) -> Result<ImageBuffer<P, Vec<C>>> {
        let (height, width, channels) = match *array.shape() {
            // A 2-D array is a single-channel image.
            [height, width] => (height, width, 1),
            [height, width, channels] => (height, width, channels),
            _ => return Err(Error::Dimensions),
        };

        if channels != usize::from(P::CHANNEL_COUNT) {
            return Err(Error::ChannelMismatch);
        }

        // Reject sizes that an `as u32` cast would silently truncate.
        let width = u32::try_from(width).map_err(|_| Error::Dimensions)?;
        let height = u32::try_from(height).map_err(|_| Error::Dimensions)?;

        Self::from_raw(width, height, into_logical_vec(array)).ok_or(Error::ImageConstructFailed)
    }
}

#[cfg(all(test, feature = "image"))]
mod layout_tests {
    use super::*;
    use image::Rgb32FImage;
    use ndarray::{s, ShapeBuilder};

    /// Builds a `(height, width, 3)` array whose values encode their own index.
    fn indexed_array(height: usize, width: usize) -> Array3<f32> {
        Array3::from_shape_fn((height, width, 3), |(y, x, c)| (y * 100 + x * 10 + c) as f32)
    }

    /// Asserts that `image` has the shape and subpixel values of `expected`.
    fn assert_image_eq(image: &Rgb32FImage, expected: &Array3<f32>) {
        assert_eq!(
            (image.height() as usize, image.width() as usize, 3),
            expected.dim()
        );
        for (x, y, pixel) in image.enumerate_pixels() {
            for (c, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, expected[[y as usize, x as usize, c]]);
            }
        }
    }

    #[test]
    fn from_ndarray_reorders_fortran_layout() {
        let expected = indexed_array(4, 5);
        let mut fortran = Array3::<f32>::zeros((4, 5, 3).f());
        fortran.assign(&expected);
        assert!(!fortran.is_standard_layout());

        let image = Rgb32FImage::from_ndarray(fortran).unwrap();
        assert_image_eq(&image, &expected);
    }

    #[test]
    fn from_ndarray_trims_sliced_array() {
        let full = indexed_array(6, 5);
        let expected = full.slice(s![2..5, .., ..]).to_owned();
        let rows = full.slice_move(s![2..5, .., ..]);

        let image = Rgb32FImage::from_ndarray(rows).unwrap();
        assert_image_eq(&image, &expected);
    }

    #[test]
    fn from_ndarray_accepts_vec_with_spare_capacity() {
        let mut data = Vec::with_capacity(1024);
        data.extend((0..4 * 5 * 3).map(|v| v as f32));
        let array = Array3::from_shape_vec((4, 5, 3), data).unwrap();
        let expected = array.clone();

        let image = Rgb32FImage::from_ndarray(array).unwrap();
        assert_image_eq(&image, &expected);
    }

    #[test]
    fn to_ndarray_ignores_surplus_container_elements() {
        let data: Vec<f32> = (0..2 * 3 * 3 + 7).map(|v| v as f32).collect();
        let image = Rgb32FImage::from_raw(2, 3, data).unwrap();

        let array = image.clone().to_ndarray();
        assert_eq!(array.len(), 2 * 3 * 3);
        assert_image_eq(&image, &array);
    }
}
129
130/// Trait for converting the provided value to a normalized float.
131///
132/// This is used for image processing where a lot of operations rely on floating values.
133pub trait NormalizedFloat<T>
134where
135    T: AsPrimitive<f32> + AsPrimitive<f64>,
136{
137    /// Convert the value to a 32 bit float.
138    ///
139    /// The value will be in a normalized range according to color depths.
140    ///
141    /// For example in u8, a value of 255 would be represented as 1.0.
142    ///
143    /// Returns None if it overflows and could not be represented.
144    fn to_f32_normalized(&self) -> Option<f32>;
145    /// Convert the value to a 64 bit float
146    ///
147    /// The value will be in a normalized range according to color depths.
148    ///
149    /// For example in u8, a value of 255 would be represented as 1.0.
150    ///
151    /// Returns None if it overflows and could not be represented.
152    fn to_f64_normalized(&self) -> Option<f64>;
153
154    /// Converts the f32 value to the provided type
155    ///
156    /// Returns None if it overflows and could not be represented.
157    fn from_f32_normalized(value: f32) -> Option<T>;
158
159    /// Converts the f64 value to the provided type
160    ///
161    /// Returns None if it overflows and could not be represented.
162    fn from_f64_normalized(value: f64) -> Option<T>;
163}
164
165impl NormalizedFloat<f32> for f32 {
166    fn to_f32_normalized(&self) -> Option<f32> {
167        Some(*self)
168    }
169
170    fn to_f64_normalized(&self) -> Option<f64> {
171        self.to_f64()
172    }
173    fn from_f32_normalized(value: f32) -> Option<f32> {
174        Some(value)
175    }
176
177    fn from_f64_normalized(value: f64) -> Option<f32> {
178        value.to_f32()
179    }
180}
181
182impl NormalizedFloat<f64> for f64 {
183    fn to_f32_normalized(&self) -> Option<f32> {
184        self.to_f32()
185    }
186
187    fn to_f64_normalized(&self) -> Option<f64> {
188        Some(*self)
189    }
190    fn from_f32_normalized(value: f32) -> Option<f64> {
191        value.to_f64()
192    }
193
194    fn from_f64_normalized(value: f64) -> Option<f64> {
195        Some(value)
196    }
197}
198
/// Implements [`NormalizedFloat`] for a primitive integer type, scaling by `<$type>::MAX`.
///
/// * `to_f32_normalized` / `to_f64_normalized` divide the value by `MAX`, so `MAX` maps to
///   `1.0` (for signed types `MIN` lands slightly below `-1.0`).
/// * `from_f32_normalized` / `from_f64_normalized` multiply by `MAX` and convert with
///   `AsPrimitive::as_`, i.e. an `as` cast. The fraction is truncated toward zero,
///   out-of-range values saturate and NaN becomes 0, so these always return `Some`.
///
/// NOTE(review): the expansion names `NormalizedFloat`, `to_f32`/`to_f64` (`ToPrimitive`) and
/// `as_` (`AsPrimitive`) unqualified, so it only compiles where those are in scope even
/// though it is `#[macro_export]`ed.
#[macro_export]
macro_rules! impl_as_float {
    ($type:ty) => {
        impl NormalizedFloat<$type> for $type {
            fn to_f32_normalized(&self) -> Option<f32> {
                let scale = <$type>::MAX as f32;
                self.to_f32().map(|raw| raw / scale)
            }

            fn to_f64_normalized(&self) -> Option<f64> {
                let scale = <$type>::MAX as f64;
                self.to_f64().map(|raw| raw / scale)
            }

            fn from_f32_normalized(value: f32) -> Option<$type> {
                let scaled = value * <$type>::MAX as f32;
                Some(scaled.as_())
            }

            fn from_f64_normalized(value: f64) -> Option<$type> {
                let scaled = value * <$type>::MAX as f64;
                Some(scaled.as_())
            }
        }
    };
}
223
224impl_as_float!(i32);
225impl_as_float!(u32);
226impl_as_float!(i16);
227impl_as_float!(u16);
228impl_as_float!(i8);
229impl_as_float!(u8);
230
#[cfg(feature = "image")]
#[cfg(test)]
mod tests {
    use super::*;
    use image::{Luma, Rgb32FImage, Rgba32FImage};
    use ndarray::{Array1, Array2};

    /// Returns `width * height * channels` sequential values starting at `1.0`.
    fn create_test_data(width: usize, height: usize, channels: usize) -> Vec<f32> {
        (1..=width * height * channels).map(|x| x as f32).collect()
    }

    #[test]
    fn test_as_ndarray_rgba() {
        let (width, height, channels) = (256, 128, 4);
        let data = create_test_data(width, height, channels);
        let test_image = Rgba32FImage::from_vec(width as u32, height as u32, data).unwrap();

        let array = test_image.as_ndarray();

        assert_eq!(array.dim(), (height, width, channels));
        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(test_image.get_pixel(x as u32, y as u32)[channel], *value);
        }
    }

    #[test]
    fn test_as_ndarray_luma() {
        let (width, height, channels) = (256, 128, 1);
        let data = create_test_data(width, height, channels);
        let test_image: ImageBuffer<Luma<f32>, Vec<f32>> =
            ImageBuffer::from_vec(width as u32, height as u32, data).unwrap();

        let array = test_image.as_ndarray();

        assert_eq!(array.dim(), (height, width, channels));
        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(test_image.get_pixel(x as u32, y as u32)[channel], *value);
        }
    }

    #[test]
    fn test_as_ndarray_mut() {
        let (width, height, channels) = (256, 128, 4);
        let data = create_test_data(width, height, channels);
        let mut test_image = Rgba32FImage::from_vec(width as u32, height as u32, data).unwrap();
        let compare = test_image.clone();

        let mut array = test_image.as_ndarray_mut();
        array += 1.0;

        for (x, y, pixel) in test_image.enumerate_pixels() {
            let compare_pixel = compare.get_pixel(x, y);
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, compare_pixel[channel] + 1.0);
            }
        }
    }

    #[test]
    fn test_to_ndarray() {
        let (width, height, channels) = (256, 128, 4);
        let data = create_test_data(width, height, channels);
        let test_image = Rgba32FImage::from_vec(width as u32, height as u32, data).unwrap();

        let mut array = test_image.clone().to_ndarray();
        array += 1.0;

        assert_eq!(array.dim(), (height, width, channels));
        for ((y, x, channel), value) in array.indexed_iter() {
            assert_eq!(
                test_image.get_pixel(x as u32, y as u32)[channel] + 1.0,
                *value
            );
        }
    }

    #[test]
    fn test_from_ndarray() {
        let (width, height, channels) = (256, 128, 4);
        let data = create_test_data(width, height, channels);
        let test_image = Array3::from_shape_vec((height, width, channels), data).unwrap();
        let compare_data = test_image.clone();

        let result = Rgba32FImage::from_ndarray(test_image).unwrap();

        for (x, y, pixel) in result.enumerate_pixels() {
            for (channel, value) in pixel.channels().iter().enumerate() {
                assert_eq!(*value, compare_data[[y as usize, x as usize, channel]]);
            }
        }
    }

    #[test]
    fn test_from_ndarray_2d() {
        let (width, height) = (256, 128);
        let data = create_test_data(width, height, 1);
        let test_image = Array2::from_shape_vec((height, width), data).unwrap();
        let compare_data = test_image.clone();

        let result = ImageBuffer::<Luma<f32>, Vec<f32>>::from_ndarray(test_image).unwrap();

        for (x, y, pixel) in result.enumerate_pixels() {
            for value in pixel.channels().iter() {
                assert_eq!(*value, compare_data[[y as usize, x as usize]]);
            }
        }
    }

    #[test]
    fn test_from_ndarray_with_invalid_channels() {
        let (width, height, channels) = (256, 128, 4);
        let data = create_test_data(width, height, channels);
        let test_image = Array3::from_shape_vec((height, width, channels), data).unwrap();

        // An RGBA-shaped array cannot become an RGB image.
        let result = Rgb32FImage::from_ndarray(test_image.into_dyn());

        assert_eq!(result.err(), Some(Error::ChannelMismatch));
    }

    #[test]
    fn test_from_ndarray_with_invalid_dimensions() {
        let test_image = Array1::from_vec(create_test_data(4, 1, 1));

        let result = Rgba32FImage::from_ndarray(test_image);

        assert_eq!(result.err(), Some(Error::Dimensions));
    }
}

// These tests exercise `NormalizedFloat` only, so they do not require the `image` feature.
#[cfg(test)]
mod normalized_float_tests {
    use super::*;
    use rstest::*;

    #[rstest]
    #[case(1.0)]
    #[case(255.0)]
    #[case(0.5)]
    #[case(-1.0)]
    #[case(-255.0)]
    fn test_f32(#[case] float: f32) {
        assert_eq!(float.to_f32_normalized().unwrap(), float);
        assert_eq!(f32::from_f32_normalized(float).unwrap(), float);

        let float_64: f64 = float.as_();
        assert_eq!(float_64.to_f64_normalized().unwrap(), float_64);
        assert_eq!(f64::from_f64_normalized(float_64).unwrap(), float_64);

        let converted_to_float64 = float.to_f64_normalized().unwrap();
        assert_eq!(converted_to_float64, float as f64);

        let converted_back_to_float32 = float_64.to_f32_normalized().unwrap();
        assert_eq!(converted_back_to_float32, float);
    }

    /// Asserts that normalizing `$int: $type` and converting back matches plain `as` casts
    /// scaled by `<$type>::MAX`, for both `f32` and `f64`.
    macro_rules! assert_normalized_round_trip {
        ($type:ty, $int:expr) => {{
            let int: $type = $int;

            let expected_normalized_f32 = int as f32 / <$type>::MAX as f32;
            assert_eq!(int.to_f32_normalized().unwrap(), expected_normalized_f32);
            assert_eq!(
                <$type>::from_f32_normalized(expected_normalized_f32).unwrap(),
                (expected_normalized_f32 * <$type>::MAX as f32) as $type
            );

            let expected_normalized_f64 = int as f64 / <$type>::MAX as f64;
            assert_eq!(int.to_f64_normalized().unwrap(), expected_normalized_f64);
            assert_eq!(
                <$type>::from_f64_normalized(expected_normalized_f64).unwrap(),
                (expected_normalized_f64 * <$type>::MAX as f64) as $type
            );
        }};
    }

    macro_rules! test_unsigned_ints {
        ($name:ident, $type:ty) => {
            #[rstest]
            #[case(0)]
            #[case(1)]
            #[case($type::MAX)]
            #[case($type::MIN)]
            fn $name(#[case] int: $type) {
                assert_normalized_round_trip!($type, int);
            }
        };
    }

    macro_rules! test_signed_ints {
        ($name:ident, $type:ty) => {
            #[rstest]
            #[case(0)]
            #[case(1)]
            #[case($type::MAX)]
            #[case($type::MIN)]
            #[case(-1)]
            #[case(-$type::MAX)]
            fn $name(#[case] int: $type) {
                assert_normalized_round_trip!($type, int);
            }
        };
    }

    // One parameterized test per integer type covered by `impl_as_float!`.
    test_signed_ints!(test_i32, i32);
    test_signed_ints!(test_i16, i16);
    test_signed_ints!(test_i8, i8);
    test_unsigned_ints!(test_u32, u32);
    test_unsigned_ints!(test_u16, u16);
    test_unsigned_ints!(test_u8, u8);
}