ndarray_vision/processing/conv.rs

1use crate::core::padding::*;
2use crate::core::{kernel_centre, ColourModel, Image, ImageBase};
3use crate::processing::Error;
4use core::mem::MaybeUninit;
5use ndarray::prelude::*;
6use ndarray::{Data, DataMut, Zip};
7use num_traits::{Num, NumAssignOps};
8use std::marker::PhantomData;
9use std::marker::Sized;
10
11/// Perform image convolutions
12pub trait ConvolutionExt<T: Copy>
13where
14    Self: Sized,
15{
16    /// Type for the output as data will have to be allocated
17    type Output;
18
19    /// Perform a convolution returning the resultant data
20    /// applies the default padding of zero padding
21    fn conv2d<U: Data<Elem = T>>(&self, kernel: ArrayBase<U, Ix3>) -> Result<Self::Output, Error>;
22    /// Performs the convolution inplace mutating the containers data
23    /// applies the default padding of zero padding
24    fn conv2d_inplace<U: Data<Elem = T>>(&mut self, kernel: ArrayBase<U, Ix3>)
25        -> Result<(), Error>;
26    /// Perform a convolution returning the resultant data
27    /// applies the default padding of zero padding
28    fn conv2d_with_padding<U: Data<Elem = T>>(
29        &self,
30        kernel: ArrayBase<U, Ix3>,
31        strategy: &impl PaddingStrategy<T>,
32    ) -> Result<Self::Output, Error>;
33    /// Performs the convolution inplace mutating the containers data
34    /// applies the default padding of zero padding
35    fn conv2d_inplace_with_padding<U: Data<Elem = T>>(
36        &mut self,
37        kernel: ArrayBase<U, Ix3>,
38        strategy: &impl PaddingStrategy<T>,
39    ) -> Result<(), Error>;
40}
41
42fn apply_edge_convolution<T>(
43    array: ArrayView3<T>,
44    kernel: ArrayView3<T>,
45    coord: (usize, usize),
46    strategy: &impl PaddingStrategy<T>,
47) -> Vec<T>
48where
49    T: Copy + Num + NumAssignOps,
50{
51    let out_of_bounds =
52        |r, c| r < 0 || c < 0 || r >= array.dim().0 as isize || c >= array.dim().1 as isize;
53    let (row_offset, col_offset) = kernel_centre(kernel.dim().0, kernel.dim().1);
54
55    let top = coord.0 as isize - row_offset as isize;
56    let bottom = (coord.0 + row_offset + 1) as isize;
57    let left = coord.1 as isize - col_offset as isize;
58    let right = (coord.1 + col_offset + 1) as isize;
59    let channels = array.dim().2;
60    let mut res = vec![T::zero(); channels];
61    'processing: for (kr, r) in (top..bottom).enumerate() {
62        for (kc, c) in (left..right).enumerate() {
63            let oob = out_of_bounds(r, c);
64            if oob && !strategy.will_pad(Some((r, c))) {
65                for chan in 0..channels {
66                    res[chan] = array[[coord.0, coord.1, chan]];
67                }
68                break 'processing;
69            }
70            for chan in 0..channels {
71                // TODO this doesn't work on no padding
72                if oob {
73                    if let Some(val) = strategy.get_value(array, (r, c, chan)) {
74                        res[chan] += kernel[[kr, kc, chan]] * val;
75                    } else {
76                        unreachable!()
77                    }
78                } else {
79                    res[chan] += kernel[[kr, kc, chan]] * array[[r as usize, c as usize, chan]];
80                }
81            }
82        }
83    }
84    res
85}
86
87impl<T, U> ConvolutionExt<T> for ArrayBase<U, Ix3>
88where
89    U: DataMut<Elem = T>,
90    T: Copy + Clone + Num + NumAssignOps,
91{
92    type Output = Array<T, Ix3>;
93
94    fn conv2d<B: Data<Elem = T>>(&self, kernel: ArrayBase<B, Ix3>) -> Result<Self::Output, Error> {
95        self.conv2d_with_padding(kernel, &NoPadding {})
96    }
97
98    fn conv2d_inplace<B: Data<Elem = T>>(
99        &mut self,
100        kernel: ArrayBase<B, Ix3>,
101    ) -> Result<(), Error> {
102        self.assign(&self.conv2d_with_padding(kernel, &NoPadding {})?);
103        Ok(())
104    }
105
106    #[inline]
107    fn conv2d_with_padding<B: Data<Elem = T>>(
108        &self,
109        kernel: ArrayBase<B, Ix3>,
110        strategy: &impl PaddingStrategy<T>,
111    ) -> Result<Self::Output, Error> {
112        if self.shape()[2] != kernel.shape()[2] {
113            Err(Error::ChannelDimensionMismatch)
114        } else {
115            let k_s = kernel.shape();
116            // Bit icky but handles fact that uncentred convolutions will cross the bounds
117            // otherwise
118            let (row_offset, col_offset) = kernel_centre(k_s[0], k_s[1]);
119            let shape = (self.shape()[0], self.shape()[1], self.shape()[2]);
120
121            if shape.0 > 0 && shape.1 > 0 {
122                let mut result = Self::Output::uninit(shape);
123
124                Zip::indexed(self.windows(kernel.dim())).for_each(|(i, j, _), window| {
125                    let mut temp;
126                    for channel in 0..k_s[2] {
127                        temp = T::zero();
128                        for r in 0..k_s[0] {
129                            for c in 0..k_s[1] {
130                                temp += window[[r, c, channel]] * kernel[[r, c, channel]];
131                            }
132                        }
133                        unsafe {
134                            *result.uget_mut([i + row_offset, j + col_offset, channel]) =
135                                MaybeUninit::new(temp);
136                        }
137                    }
138                });
139                for c in 0..shape.1 {
140                    for r in 0..row_offset {
141                        let pixel =
142                            apply_edge_convolution(self.view(), kernel.view(), (r, c), strategy);
143                        for chan in 0..k_s[2] {
144                            unsafe {
145                                *result.uget_mut([r, c, chan]) = MaybeUninit::new(pixel[chan]);
146                            }
147                        }
148                        let bottom = shape.0 - r - 1;
149                        let pixel = apply_edge_convolution(
150                            self.view(),
151                            kernel.view(),
152                            (bottom, c),
153                            strategy,
154                        );
155                        for chan in 0..k_s[2] {
156                            unsafe {
157                                *result.uget_mut([bottom, c, chan]) = MaybeUninit::new(pixel[chan]);
158                            }
159                        }
160                    }
161                }
162                for r in (row_offset)..(shape.0 - row_offset) {
163                    for c in 0..col_offset {
164                        let pixel =
165                            apply_edge_convolution(self.view(), kernel.view(), (r, c), strategy);
166                        for chan in 0..k_s[2] {
167                            unsafe {
168                                *result.uget_mut([r, c, chan]) = MaybeUninit::new(pixel[chan]);
169                            }
170                        }
171                        let right = shape.1 - c - 1;
172                        let pixel = apply_edge_convolution(
173                            self.view(),
174                            kernel.view(),
175                            (r, right),
176                            strategy,
177                        );
178                        for chan in 0..k_s[2] {
179                            unsafe {
180                                *result.uget_mut([r, right, chan]) = MaybeUninit::new(pixel[chan]);
181                            }
182                        }
183                    }
184                }
185                Ok(unsafe { result.assume_init() })
186            } else {
187                Err(Error::InvalidDimensions)
188            }
189        }
190    }
191
192    fn conv2d_inplace_with_padding<B: Data<Elem = T>>(
193        &mut self,
194        kernel: ArrayBase<B, Ix3>,
195        strategy: &impl PaddingStrategy<T>,
196    ) -> Result<(), Error> {
197        self.assign(&self.conv2d_with_padding(kernel, strategy)?);
198        Ok(())
199    }
200}
201
202impl<T, U, C> ConvolutionExt<T> for ImageBase<U, C>
203where
204    U: DataMut<Elem = T>,
205    T: Copy + Clone + Num + NumAssignOps,
206    C: ColourModel,
207{
208    type Output = Image<T, C>;
209
210    fn conv2d<B: Data<Elem = T>>(&self, kernel: ArrayBase<B, Ix3>) -> Result<Self::Output, Error> {
211        let data = self.data.conv2d(kernel)?;
212        Ok(Self::Output {
213            data,
214            model: PhantomData,
215        })
216    }
217
218    fn conv2d_inplace<B: Data<Elem = T>>(
219        &mut self,
220        kernel: ArrayBase<B, Ix3>,
221    ) -> Result<(), Error> {
222        self.data.conv2d_inplace(kernel)
223    }
224
225    fn conv2d_with_padding<B: Data<Elem = T>>(
226        &self,
227        kernel: ArrayBase<B, Ix3>,
228        strategy: &impl PaddingStrategy<T>,
229    ) -> Result<Self::Output, Error> {
230        let data = self.data.conv2d_with_padding(kernel, strategy)?;
231        Ok(Self::Output {
232            data,
233            model: PhantomData,
234        })
235    }
236
237    fn conv2d_inplace_with_padding<B: Data<Elem = T>>(
238        &mut self,
239        kernel: ArrayBase<B, Ix3>,
240        strategy: &impl PaddingStrategy<T>,
241    ) -> Result<(), Error> {
242        self.data.conv2d_inplace_with_padding(kernel, strategy)
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::colour_models::{Gray, RGB};
    use ndarray::arr3;

    /// 5x5 single-channel image shared by the convolution tests.
    #[rustfmt::skip]
    fn sample_image() -> Image<u8, Gray> {
        let pixels = vec![
            1, 1, 1, 0, 0,
            0, 1, 1, 1, 0,
            0, 0, 1, 1, 1,
            0, 0, 1, 1, 0,
            0, 1, 1, 0, 0,
        ];
        Image::<u8, Gray>::from_shape_data(5, 5, pixels)
    }

    /// 3x3 single-channel kernel with ones on both diagonals.
    #[rustfmt::skip]
    fn x_kernel() -> Array3<u8> {
        arr3(&[
            [[1], [0], [1]],
            [[0], [1], [0]],
            [[1], [0], [1]],
        ])
    }

    /// Expected result of convolving `sample_image` with `x_kernel` under zero padding.
    #[rustfmt::skip]
    fn zero_padded_expected() -> Image<u8, Gray> {
        let pixels = vec![
            2, 2, 3, 1, 1,
            1, 4, 3, 4, 1,
            1, 2, 4, 3, 3,
            1, 2, 3, 4, 1,
            0, 2, 2, 1, 1,
        ];
        Image::<u8, Gray>::from_shape_data(5, 5, pixels)
    }

    #[test]
    fn bad_dimensions() {
        let mut i = Image::<f64, RGB>::new(5, 5);
        let bad_kern = Array3::<f64>::zeros((2, 2, 2));
        assert_eq!(i.conv2d(bad_kern.view()), Err(Error::ChannelDimensionMismatch));

        // A failed in-place convolution must leave the image untouched.
        let data_clone = i.data.clone();
        assert_eq!(
            i.conv2d_inplace(bad_kern.view()),
            Err(Error::ChannelDimensionMismatch)
        );
        assert_eq!(i.data, data_clone);

        let good_kern = Array3::<f64>::zeros((2, 2, RGB::channels()));
        assert!(i.conv2d(good_kern.view()).is_ok());
        assert!(i.conv2d_inplace(good_kern.view()).is_ok());
    }

    #[test]
    fn empty_image_is_rejected() {
        let kern = Array3::<f64>::ones((3, 3, 1));
        let no_rows = Array3::<f64>::zeros((0, 5, 1));
        let no_cols = Array3::<f64>::zeros((5, 0, 1));
        assert_eq!(no_rows.conv2d(kern.view()), Err(Error::InvalidDimensions));
        assert_eq!(no_cols.conv2d(kern.view()), Err(Error::InvalidDimensions));
    }

    #[test]
    #[rustfmt::skip]
    fn basic_conv() {
        // Without padding, pixels whose window crosses the border keep their input value.
        let output_pixels = vec![
            1, 1, 1, 0, 0,
            0, 4, 3, 4, 0,
            0, 2, 4, 3, 1,
            0, 2, 3, 4, 0,
            0, 1, 1, 0, 0,
        ];
        let expected = Image::<u8, Gray>::from_shape_data(5, 5, output_pixels);
        let kern = x_kernel();

        assert_eq!(Ok(expected), sample_image().conv2d(kern.view()));
    }

    #[test]
    fn basic_conv_with_padding() {
        let kern = x_kernel();
        let padding = ZeroPadding {};

        assert_eq!(
            Ok(zero_padded_expected()),
            sample_image().conv2d_with_padding(kern.view(), &padding)
        );
    }

    #[test]
    fn basic_conv_inplace() {
        let mut input = sample_image();
        let kern = x_kernel();
        let padding = ZeroPadding {};
        input
            .conv2d_inplace_with_padding(kern.view(), &padding)
            .unwrap();

        assert_eq!(zero_padded_expected(), input);
    }
}