image_ndarray/
traits.rs

1//! Implementations for ndarray casting and conversions for the ImageBuffer
2
3#[cfg(feature = "image")]
4use crate::error::{Error, Result};
5#[cfg(feature = "image")]
6use image::{ImageBuffer, Pixel};
7#[cfg(feature = "image")]
8use ndarray::{Array3, ArrayView3, ArrayViewMut, ArrayViewMut3};
9use num_traits::{AsPrimitive, ToPrimitive};
10
/// Bridges between an image buffer and three-dimensional ndarrays.
///
/// Every array handed out or accepted by this trait is indexed as
/// `array[[y, x, z]]`, where
///
/// * `y` is the row,
/// * `x` is the column,
/// * `z` is the channel.
///
/// The conversions are meant to reuse the existing pixel storage rather than
/// duplicating it.
#[cfg(feature = "image")]
pub trait ImageArray<P, ImageContainer>
where
    P: image::Pixel,
{
    /// Borrows the image as a read-only [`ArrayView3`].
    ///
    /// The view is laid out as `(rows, columns, channels)`, so a single
    /// subpixel is addressed with `array[[y, x, z]]`.
    ///
    /// The view refers to the buffer's own data; nothing is copied.
    fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, ImageContainer>;

    /// Borrows the image as a writable [`ArrayViewMut3`].
    ///
    /// The view is laid out as `(rows, columns, channels)`, so a single
    /// subpixel is addressed with `array[[y, x, z]]`.
    ///
    /// The view refers to the buffer's own data, so writes through it are
    /// visible in the image afterwards.
    fn as_ndarray_mut<'a>(&'a mut self) -> ArrayViewMut3<'a, ImageContainer>;

    /// Turns the image into an owned [`Array3`].
    ///
    /// The array is laid out as `(rows, columns, channels)`, so a single
    /// subpixel is addressed with `array[[y, x, z]]`.
    ///
    /// The image is consumed and its storage is handed over to the array.
    fn to_ndarray(self) -> Array3<ImageContainer>;

    /// Builds an image from an owned [`Array3`].
    ///
    /// The array is read as `(rows, columns, channels)`, i.e. the subpixel at
    /// `array[[y, x, z]]` becomes channel `z` of the pixel at column `x`,
    /// row `y`.
    ///
    /// The array is consumed. An error is returned when the array cannot be
    /// turned into an image of pixel type `P`.
    fn from_ndarray(array: Array3<ImageContainer>) -> Result<ImageBuffer<P, Vec<ImageContainer>>>;
}
64
/// [`ImageArray`] for `Vec`-backed image buffers.
///
/// The array shape is always `(height, width, P::CHANNEL_COUNT)` in row-major
/// (C) order, which is the same order in which `ImageBuffer` stores its
/// subpixels.
#[cfg(feature = "image")]
impl<P, C> ImageArray<P, C> for ImageBuffer<P, Vec<C>>
where
    P: Pixel<Subpixel = C>,
    C: Clone + Copy,
{
    /// Returns a read-only `(height, width, channels)` view over the pixel data.
    fn as_ndarray<'a>(&'a self) -> ArrayView3<'a, C> {
        let (width, height) = self.dimensions();
        let shape = (height as usize, width as usize, P::CHANNEL_COUNT as usize);

        // SAFETY: an `ImageBuffer` can only be constructed over a container that
        // holds at least `width * height * CHANNEL_COUNT` subpixels, stored row by
        // row with interleaved channels, which is exactly the C-order layout of
        // `shape`. The view borrows `self` for `'a`, so the data outlives it.
        unsafe { ArrayView3::from_shape_ptr(shape, self.as_raw().as_ptr()) }
    }

    /// Returns a mutable `(height, width, channels)` view over the pixel data.
    fn as_ndarray_mut<'a>(&'a mut self) -> ArrayViewMut3<'a, C> {
        let (width, height) = self.dimensions();
        let shape = (height as usize, width as usize, P::CHANNEL_COUNT as usize);

        // SAFETY: same layout argument as in `as_ndarray`; the exclusive borrow
        // of `self` for `'a` guarantees nothing else aliases the data while the
        // mutable view is alive.
        unsafe { ArrayViewMut::from_shape_ptr(shape, self.as_mut_ptr()) }
    }

    /// Consumes the image and reuses its `Vec` as the storage of an `Array3`.
    fn to_ndarray(self) -> Array3<C> {
        let (width, height) = self.dimensions();
        let shape = (height as usize, width as usize, P::CHANNEL_COUNT as usize);

        // SAFETY: the raw container holds at least `width * height * CHANNEL_COUNT`
        // elements (an `ImageBuffer` invariant), so every index of `shape` with
        // default C-order strides is in bounds of the vector.
        unsafe { Array3::from_shape_vec_unchecked(shape, self.into_raw()) }
    }

    /// Consumes `array` (indexed `[y, x, channel]`) and builds an image from it.
    ///
    /// When the array is in standard (row-major, contiguous) layout its backing
    /// `Vec` is reused without copying. Arrays in any other layout (transposed,
    /// Fortran order, strided slices) are copied once in logical order so that
    /// the resulting pixels match `array[[y, x, channel]]`.
    ///
    /// # Errors
    ///
    /// * [`Error::ChannelMismatch`] if the last axis does not have
    ///   `P::CHANNEL_COUNT` entries.
    /// * [`Error::ImageConstructFailed`] if a dimension does not fit into `u32`
    ///   or the image buffer rejects the data.
    fn from_ndarray(array: Array3<C>) -> Result<ImageBuffer<P, Vec<C>>> {
        let (height, width, channels) = array.dim();

        if channels != usize::from(P::CHANNEL_COUNT) {
            return Err(Error::ChannelMismatch);
        }

        // Checked conversions: a plain `as u32` would silently wrap huge
        // dimensions and build an image with the wrong size.
        let width = u32::try_from(width).map_err(|_| Error::ImageConstructFailed)?;
        let height = u32::try_from(height).map_err(|_| Error::ImageConstructFailed)?;

        let len = array.len();
        let data: Vec<C> = if array.is_standard_layout() {
            // Address of the first logical element, recorded before the array is
            // taken apart so a sliced view can be located inside its backing Vec.
            let first_addr = array.as_ptr() as usize;

            // Let ndarray hand back its own Vec instead of rebuilding one with
            // `Vec::from_raw_parts`: the real capacity and allocation start are
            // preserved, which the raw-parts approach could not guarantee.
            // `into_raw_vec` is deprecated in newer ndarray releases in favour of
            // `into_raw_vec_and_offset`; it is kept for compatibility with older
            // releases, and the offset is recovered manually below.
            #[allow(deprecated)]
            let mut raw = array.into_raw_vec();

            if len == 0 {
                raw.clear();
            } else if raw.len() != len {
                // The array is a contiguous window into a larger allocation
                // (e.g. after slicing); trim the Vec down to that window.
                // `C` is a pixel subpixel type, i.e. never zero-sized.
                let offset = (first_addr - raw.as_ptr() as usize) / std::mem::size_of::<C>();
                raw.truncate(offset + len);
                raw.drain(..offset);
            }
            raw
        } else {
            // Memory order differs from logical order: gather the elements in
            // `[y, x, channel]` order.
            array.iter().copied().collect()
        };

        Self::from_raw(width, height, data).ok_or(Error::ImageConstructFailed)
    }
}
118
119/// Trait for converting the provided value to a normalized float.
120///
121/// This is used for image processing where a lot of operations rely on floating values.
122pub trait NormalizedFloat<T>
123where
124    T: AsPrimitive<f32> + AsPrimitive<f64>,
125{
126    /// Convert the value to a 32 bit float.
127    ///
128    /// The value will be in a normalized range according to color depths.
129    ///
130    /// For example in u8, a value of 255 would be represented as 1.0.
131    ///
132    /// Returns None if it overflows and could not be represented.
133    fn to_f32_normalized(&self) -> Option<f32>;
134    /// Convert the value to a 64 bit float
135    ///
136    /// The value will be in a normalized range according to color depths.
137    ///
138    /// For example in u8, a value of 255 would be represented as 1.0.
139    ///
140    /// Returns None if it overflows and could not be represented.
141    fn to_f64_normalized(&self) -> Option<f64>;
142
143    /// Converts the f32 value to the provided type
144    ///
145    /// Returns None if it overflows and could not be represented.
146    fn from_f32_normalized(value: f32) -> Option<T>;
147
148    /// Converts the f64 value to the provided type
149    ///
150    /// Returns None if it overflows and could not be represented.
151    fn from_f64_normalized(value: f64) -> Option<T>;
152}
153
154impl NormalizedFloat<f32> for f32 {
155    fn to_f32_normalized(&self) -> Option<f32> {
156        Some(*self)
157    }
158
159    fn to_f64_normalized(&self) -> Option<f64> {
160        self.to_f64()
161    }
162    fn from_f32_normalized(value: f32) -> Option<f32> {
163        Some(value)
164    }
165
166    fn from_f64_normalized(value: f64) -> Option<f32> {
167        value.to_f32()
168    }
169}
170
171impl NormalizedFloat<f64> for f64 {
172    fn to_f32_normalized(&self) -> Option<f32> {
173        self.to_f32()
174    }
175
176    fn to_f64_normalized(&self) -> Option<f64> {
177        Some(*self)
178    }
179    fn from_f32_normalized(value: f32) -> Option<f64> {
180        value.to_f64()
181    }
182
183    fn from_f64_normalized(value: f64) -> Option<f64> {
184        Some(value)
185    }
186}
187
/// Implements `NormalizedFloat<$type>` for the integer type `$type`.
///
/// Normalizing divides the value by `<$type>::MAX`; the reverse direction
/// multiplies by `<$type>::MAX` and casts the product back with `as_()`.
/// The expansion calls `to_f32`, `to_f64` and `as_` with method syntax, so the
/// traits providing them must be in scope where the macro is invoked.
#[macro_export]
macro_rules! impl_as_float {
    ($type:ty) => {
        impl NormalizedFloat<$type> for $type {
            fn to_f32_normalized(&self) -> Option<f32> {
                let as_float = self.to_f32()?;
                Some(as_float / (<$type>::MAX as f32))
            }

            fn to_f64_normalized(&self) -> Option<f64> {
                let as_float = self.to_f64()?;
                Some(as_float / (<$type>::MAX as f64))
            }

            fn from_f32_normalized(value: f32) -> Option<$type> {
                let scaled = value * (<$type>::MAX as f32);
                Some(scaled.as_())
            }

            fn from_f64_normalized(value: f64) -> Option<$type> {
                let scaled = value * (<$type>::MAX as f64);
                Some(scaled.as_())
            }
        }
    };
}
212
213impl_as_float!(i32);
214impl_as_float!(u32);
215impl_as_float!(i16);
216impl_as_float!(u16);
217impl_as_float!(i8);
218impl_as_float!(u8);
219
#[cfg(feature = "image")]
#[cfg(test)]
mod tests {
    use super::*;
    use image::{Luma, Rgb32FImage, Rgba32FImage};
    use rstest::*;

    /// Produces `width * height * channels` samples counting up from `1.0`, so
    /// every subpixel carries a distinct, predictable value.
    fn create_test_data(width: usize, height: usize, channels: usize) -> Vec<f32> {
        let sample_count = width * height * channels;
        (1..=sample_count).map(|n| n as f32).collect()
    }

    /// A read-only view over an RGBA image exposes each subpixel at `[y, x, channel]`.
    #[test]
    fn test_as_ndarray_rgba() {
        let (width, height, channels) = (256usize, 128usize, 4usize);
        let samples = create_test_data(width, height, channels);
        let image = Rgba32FImage::from_vec(width as u32, height as u32, samples).unwrap();

        let view = image.as_ndarray();

        view.indexed_iter()
            .for_each(|((row, col, channel), sample)| {
                assert_eq!(image.get_pixel(col as u32, row as u32)[channel], *sample);
            });
    }

    /// A read-only view over a single-channel image exposes each subpixel at `[y, x, 0]`.
    #[test]
    fn test_as_ndarray_luma() {
        let (width, height, channels) = (256usize, 128usize, 1usize);
        let samples = create_test_data(width, height, channels);
        let image: ImageBuffer<Luma<f32>, Vec<f32>> =
            ImageBuffer::from_vec(width as u32, height as u32, samples).unwrap();

        let view = image.as_ndarray();

        view.indexed_iter()
            .for_each(|((row, col, channel), sample)| {
                assert_eq!(image.get_pixel(col as u32, row as u32)[channel], *sample);
            });
    }

    /// Writing through the mutable view changes the image itself.
    #[test]
    fn test_as_ndarray_mut() {
        let (width, height, channels) = (256usize, 128usize, 4usize);
        let samples = create_test_data(width, height, channels);
        let mut image = Rgba32FImage::from_vec(width as u32, height as u32, samples).unwrap();
        let untouched = image.clone();

        let mut view = image.as_ndarray_mut();
        view += 1.0;

        for (col, row, pixel) in image.enumerate_pixels() {
            let reference = untouched.get_pixel(col, row);
            pixel
                .channels()
                .iter()
                .enumerate()
                .for_each(|(channel, sample)| {
                    assert_eq!(*sample, reference[channel] + 1.0);
                });
        }
    }

    /// The owned array keeps the `[y, x, channel]` ordering and can be modified freely.
    #[test]
    fn test_to_ndarray() {
        let (width, height, channels) = (256usize, 128usize, 4usize);
        let samples = create_test_data(width, height, channels);
        let image = Rgba32FImage::from_vec(width as u32, height as u32, samples).unwrap();

        let mut owned = image.clone().to_ndarray();
        owned += 1.0;

        owned
            .indexed_iter()
            .for_each(|((row, col, channel), sample)| {
                assert_eq!(
                    image.get_pixel(col as u32, row as u32)[channel] + 1.0,
                    *sample
                );
            });
    }

    /// An array with a matching channel count becomes an image with identical samples.
    #[test]
    fn test_from_ndarray() {
        let (width, height, channels) = (256usize, 128usize, 4usize);
        let samples = create_test_data(width, height, channels);
        let source = Array3::from_shape_vec((height, width, channels), samples).unwrap();
        let reference = source.clone();

        let image = Rgba32FImage::from_ndarray(source).unwrap();

        for (col, row, pixel) in image.enumerate_pixels() {
            pixel
                .channels()
                .iter()
                .enumerate()
                .for_each(|(channel, sample)| {
                    assert_eq!(*sample, reference[[row as usize, col as usize, channel]]);
                });
        }
    }

    /// A four-channel array cannot be turned into a three-channel image.
    #[test]
    fn test_from_ndarray_with_invalid_channels() {
        let (width, height, channels) = (256usize, 128usize, 4usize);
        let samples = create_test_data(width, height, channels);
        let source = Array3::from_shape_vec((height, width, channels), samples).unwrap();

        let failure = Rgb32FImage::from_ndarray(source).err().unwrap();

        assert_eq!(failure, Error::ChannelMismatch);
    }

    // Floats pass through every normalized conversion unchanged, in both widths.
    #[rstest]
    #[case(1.0)]
    #[case(255.0)]
    #[case(0.5)]
    #[case(-1.0)]
    #[case(-255.0)]
    fn test_f32(#[case] float: f32) {
        // f32 <-> f32 is the identity.
        assert_eq!(float.to_f32_normalized().unwrap(), float);
        assert_eq!(f32::from_f32_normalized(float).unwrap(), float);

        // f64 <-> f64 is the identity as well.
        let wide: f64 = float.as_();
        assert_eq!(wide.to_f64_normalized().unwrap(), wide);
        assert_eq!(f64::from_f64_normalized(wide).unwrap(), wide);

        // Widening f32 -> f64 matches a plain cast.
        let widened = float.to_f64_normalized().unwrap();
        assert_eq!(widened, float as f64);

        // Narrowing f64 -> f32 restores the original value.
        let narrowed = wide.to_f32_normalized().unwrap();
        assert_eq!(narrowed, float);
    }

    /// Generates an `rstest` named `$name` that checks the normalized-float
    /// conversions of the unsigned integer type `$type` against plain casts.
    #[macro_export]
    macro_rules! test_unsigned_ints {
        ($name:ident, $type:ty) => {
            #[rstest]
            #[case(0)]
            #[case(1)]
            #[case($type::MAX)]
            #[case($type::MIN)]
            fn $name(#[case] int: $type) {
                // 32-bit float path: normalize, then convert back.
                let expected_normalized_f32 = int as f32 / <$type>::MAX as f32;
                assert_eq!(int.to_f32_normalized().unwrap(), expected_normalized_f32);

                let expected_int_from_float32 =
                    (expected_normalized_f32 * <$type>::MAX as f32) as $type;
                assert_eq!(
                    <$type>::from_f32_normalized(expected_normalized_f32).unwrap(),
                    expected_int_from_float32
                );

                // 64-bit float path: normalize, then convert back.
                let expected_normalized_f64 = int as f64 / <$type>::MAX as f64;
                assert_eq!(int.to_f64_normalized().unwrap(), expected_normalized_f64);

                let expected_int_from_float64 =
                    (expected_normalized_f64 * <$type>::MAX as f64) as $type;
                assert_eq!(
                    <$type>::from_f64_normalized(expected_normalized_f64).unwrap(),
                    expected_int_from_float64
                );
            }
        };
    }

    /// Generates an `rstest` named `$name` that checks the normalized-float
    /// conversions of the signed integer type `$type` against plain casts,
    /// including negative inputs.
    #[macro_export]
    macro_rules! test_signed_ints {
        ($name:ident, $type:ty) => {
            #[rstest]
            #[case(0)]
            #[case(1)]
            #[case($type::MAX)]
            #[case($type::MIN)]
            #[case(-1)]
            #[case(-$type::MAX)]
            fn $name(#[case] int: $type) {
                // 32-bit float path: normalize, then convert back.
                let expected_normalized_f32 = int as f32 / <$type>::MAX as f32;
                assert_eq!(int.to_f32_normalized().unwrap(), expected_normalized_f32);

                let expected_int_from_float32 =
                    (expected_normalized_f32 * <$type>::MAX as f32) as $type;
                assert_eq!(
                    <$type>::from_f32_normalized(expected_normalized_f32).unwrap(),
                    expected_int_from_float32
                );

                // 64-bit float path: normalize, then convert back.
                let expected_normalized_f64 = int as f64 / <$type>::MAX as f64;
                assert_eq!(int.to_f64_normalized().unwrap(), expected_normalized_f64);

                let expected_int_from_float64 =
                    (expected_normalized_f64 * <$type>::MAX as f64) as $type;
                assert_eq!(
                    <$type>::from_f64_normalized(expected_normalized_f64).unwrap(),
                    expected_int_from_float64
                );
            }
        };
    }

    // One generated test per integer type that implements `NormalizedFloat`.
    test_signed_ints!(test_i32, i32);
    test_signed_ints!(test_i16, i16);
    test_signed_ints!(test_i8, i8);
    test_unsigned_ints!(test_u32, u32);
    test_unsigned_ints!(test_u16, u16);
    test_unsigned_ints!(test_u8, u8);
}