1use crate::core::padding::*;
2use crate::core::{kernel_centre, ColourModel, Image, ImageBase};
3use crate::processing::Error;
4use core::mem::MaybeUninit;
5use ndarray::prelude::*;
6use ndarray::{Data, DataMut, Zip};
7use num_traits::{Num, NumAssignOps};
8use std::marker::PhantomData;
9use std::marker::Sized;
10
11pub trait ConvolutionExt<T: Copy>
13where
14 Self: Sized,
15{
16 type Output;
18
19 fn conv2d<U: Data<Elem = T>>(&self, kernel: ArrayBase<U, Ix3>) -> Result<Self::Output, Error>;
22 fn conv2d_inplace<U: Data<Elem = T>>(&mut self, kernel: ArrayBase<U, Ix3>)
25 -> Result<(), Error>;
26 fn conv2d_with_padding<U: Data<Elem = T>>(
29 &self,
30 kernel: ArrayBase<U, Ix3>,
31 strategy: &impl PaddingStrategy<T>,
32 ) -> Result<Self::Output, Error>;
33 fn conv2d_inplace_with_padding<U: Data<Elem = T>>(
36 &mut self,
37 kernel: ArrayBase<U, Ix3>,
38 strategy: &impl PaddingStrategy<T>,
39 ) -> Result<(), Error>;
40}
41
42fn apply_edge_convolution<T>(
43 array: ArrayView3<T>,
44 kernel: ArrayView3<T>,
45 coord: (usize, usize),
46 strategy: &impl PaddingStrategy<T>,
47) -> Vec<T>
48where
49 T: Copy + Num + NumAssignOps,
50{
51 let out_of_bounds =
52 |r, c| r < 0 || c < 0 || r >= array.dim().0 as isize || c >= array.dim().1 as isize;
53 let (row_offset, col_offset) = kernel_centre(kernel.dim().0, kernel.dim().1);
54
55 let top = coord.0 as isize - row_offset as isize;
56 let bottom = (coord.0 + row_offset + 1) as isize;
57 let left = coord.1 as isize - col_offset as isize;
58 let right = (coord.1 + col_offset + 1) as isize;
59 let channels = array.dim().2;
60 let mut res = vec![T::zero(); channels];
61 'processing: for (kr, r) in (top..bottom).enumerate() {
62 for (kc, c) in (left..right).enumerate() {
63 let oob = out_of_bounds(r, c);
64 if oob && !strategy.will_pad(Some((r, c))) {
65 for chan in 0..channels {
66 res[chan] = array[[coord.0, coord.1, chan]];
67 }
68 break 'processing;
69 }
70 for chan in 0..channels {
71 if oob {
73 if let Some(val) = strategy.get_value(array, (r, c, chan)) {
74 res[chan] += kernel[[kr, kc, chan]] * val;
75 } else {
76 unreachable!()
77 }
78 } else {
79 res[chan] += kernel[[kr, kc, chan]] * array[[r as usize, c as usize, chan]];
80 }
81 }
82 }
83 }
84 res
85}
86
87impl<T, U> ConvolutionExt<T> for ArrayBase<U, Ix3>
88where
89 U: DataMut<Elem = T>,
90 T: Copy + Clone + Num + NumAssignOps,
91{
92 type Output = Array<T, Ix3>;
93
94 fn conv2d<B: Data<Elem = T>>(&self, kernel: ArrayBase<B, Ix3>) -> Result<Self::Output, Error> {
95 self.conv2d_with_padding(kernel, &NoPadding {})
96 }
97
98 fn conv2d_inplace<B: Data<Elem = T>>(
99 &mut self,
100 kernel: ArrayBase<B, Ix3>,
101 ) -> Result<(), Error> {
102 self.assign(&self.conv2d_with_padding(kernel, &NoPadding {})?);
103 Ok(())
104 }
105
106 #[inline]
107 fn conv2d_with_padding<B: Data<Elem = T>>(
108 &self,
109 kernel: ArrayBase<B, Ix3>,
110 strategy: &impl PaddingStrategy<T>,
111 ) -> Result<Self::Output, Error> {
112 if self.shape()[2] != kernel.shape()[2] {
113 Err(Error::ChannelDimensionMismatch)
114 } else {
115 let k_s = kernel.shape();
116 let (row_offset, col_offset) = kernel_centre(k_s[0], k_s[1]);
119 let shape = (self.shape()[0], self.shape()[1], self.shape()[2]);
120
121 if shape.0 > 0 && shape.1 > 0 {
122 let mut result = Self::Output::uninit(shape);
123
124 Zip::indexed(self.windows(kernel.dim())).for_each(|(i, j, _), window| {
125 let mut temp;
126 for channel in 0..k_s[2] {
127 temp = T::zero();
128 for r in 0..k_s[0] {
129 for c in 0..k_s[1] {
130 temp += window[[r, c, channel]] * kernel[[r, c, channel]];
131 }
132 }
133 unsafe {
134 *result.uget_mut([i + row_offset, j + col_offset, channel]) =
135 MaybeUninit::new(temp);
136 }
137 }
138 });
139 for c in 0..shape.1 {
140 for r in 0..row_offset {
141 let pixel =
142 apply_edge_convolution(self.view(), kernel.view(), (r, c), strategy);
143 for chan in 0..k_s[2] {
144 unsafe {
145 *result.uget_mut([r, c, chan]) = MaybeUninit::new(pixel[chan]);
146 }
147 }
148 let bottom = shape.0 - r - 1;
149 let pixel = apply_edge_convolution(
150 self.view(),
151 kernel.view(),
152 (bottom, c),
153 strategy,
154 );
155 for chan in 0..k_s[2] {
156 unsafe {
157 *result.uget_mut([bottom, c, chan]) = MaybeUninit::new(pixel[chan]);
158 }
159 }
160 }
161 }
162 for r in (row_offset)..(shape.0 - row_offset) {
163 for c in 0..col_offset {
164 let pixel =
165 apply_edge_convolution(self.view(), kernel.view(), (r, c), strategy);
166 for chan in 0..k_s[2] {
167 unsafe {
168 *result.uget_mut([r, c, chan]) = MaybeUninit::new(pixel[chan]);
169 }
170 }
171 let right = shape.1 - c - 1;
172 let pixel = apply_edge_convolution(
173 self.view(),
174 kernel.view(),
175 (r, right),
176 strategy,
177 );
178 for chan in 0..k_s[2] {
179 unsafe {
180 *result.uget_mut([r, right, chan]) = MaybeUninit::new(pixel[chan]);
181 }
182 }
183 }
184 }
185 Ok(unsafe { result.assume_init() })
186 } else {
187 Err(Error::InvalidDimensions)
188 }
189 }
190 }
191
192 fn conv2d_inplace_with_padding<B: Data<Elem = T>>(
193 &mut self,
194 kernel: ArrayBase<B, Ix3>,
195 strategy: &impl PaddingStrategy<T>,
196 ) -> Result<(), Error> {
197 self.assign(&self.conv2d_with_padding(kernel, strategy)?);
198 Ok(())
199 }
200}
201
202impl<T, U, C> ConvolutionExt<T> for ImageBase<U, C>
203where
204 U: DataMut<Elem = T>,
205 T: Copy + Clone + Num + NumAssignOps,
206 C: ColourModel,
207{
208 type Output = Image<T, C>;
209
210 fn conv2d<B: Data<Elem = T>>(&self, kernel: ArrayBase<B, Ix3>) -> Result<Self::Output, Error> {
211 let data = self.data.conv2d(kernel)?;
212 Ok(Self::Output {
213 data,
214 model: PhantomData,
215 })
216 }
217
218 fn conv2d_inplace<B: Data<Elem = T>>(
219 &mut self,
220 kernel: ArrayBase<B, Ix3>,
221 ) -> Result<(), Error> {
222 self.data.conv2d_inplace(kernel)
223 }
224
225 fn conv2d_with_padding<B: Data<Elem = T>>(
226 &self,
227 kernel: ArrayBase<B, Ix3>,
228 strategy: &impl PaddingStrategy<T>,
229 ) -> Result<Self::Output, Error> {
230 let data = self.data.conv2d_with_padding(kernel, strategy)?;
231 Ok(Self::Output {
232 data,
233 model: PhantomData,
234 })
235 }
236
237 fn conv2d_inplace_with_padding<B: Data<Elem = T>>(
238 &mut self,
239 kernel: ArrayBase<B, Ix3>,
240 strategy: &impl PaddingStrategy<T>,
241 ) -> Result<(), Error> {
242 self.data.conv2d_inplace_with_padding(kernel, strategy)
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::colour_models::{Gray, RGB};
    use ndarray::arr3;

    /// 3x3 single-channel kernel summing the centre and its four diagonal neighbours.
    #[rustfmt::skip]
    fn diagonal_kernel() -> Array3<u8> {
        arr3(
            &[
                [[1], [0], [1]],
                [[0], [1], [0]],
                [[1], [0], [1]]
            ])
    }

    /// 5x5 greyscale input shared by the convolution tests.
    #[rustfmt::skip]
    fn sample_pixels() -> Vec<u8> {
        vec![
            1, 1, 1, 0, 0,
            0, 1, 1, 1, 0,
            0, 0, 1, 1, 1,
            0, 0, 1, 1, 0,
            0, 1, 1, 0, 0,
        ]
    }

    #[test]
    fn bad_dimensions() {
        let mut image = Image::<f64, RGB>::new(5, 5);
        let mismatched = Array3::<f64>::zeros((2, 2, 2));
        assert_eq!(
            image.conv2d(mismatched.view()),
            Err(Error::ChannelDimensionMismatch)
        );

        let before = image.data.clone();
        let outcome = image.conv2d_inplace(mismatched.view());
        assert_eq!(outcome, Err(Error::ChannelDimensionMismatch));
        assert_eq!(image.data, before);

        let matching = Array3::<f64>::zeros((2, 2, RGB::channels()));
        assert!(image.conv2d(matching.view()).is_ok());
        assert!(image.conv2d_inplace(matching.view()).is_ok());
    }

    #[test]
    #[rustfmt::skip]
    fn basic_conv() {
        let wanted_pixels = vec![
            1, 1, 1, 0, 0,
            0, 4, 3, 4, 0,
            0, 2, 4, 3, 1,
            0, 2, 3, 4, 0,
            0, 1, 1, 0, 0,
        ];

        let kernel = diagonal_kernel();
        let source = Image::<u8, Gray>::from_shape_data(5, 5, sample_pixels());
        let wanted = Image::<u8, Gray>::from_shape_data(5, 5, wanted_pixels);

        assert_eq!(Ok(wanted), source.conv2d(kernel.view()));
    }

    #[test]
    #[rustfmt::skip]
    fn basic_conv_inplace() {
        let wanted_pixels = vec![
            2, 2, 3, 1, 1,
            1, 4, 3, 4, 1,
            1, 2, 4, 3, 3,
            1, 2, 3, 4, 1,
            0, 2, 2, 1, 1,
        ];

        let kernel = diagonal_kernel();
        let mut image = Image::<u8, Gray>::from_shape_data(5, 5, sample_pixels());
        let wanted = Image::<u8, Gray>::from_shape_data(5, 5, wanted_pixels);
        image
            .conv2d_inplace_with_padding(kernel.view(), &ZeroPadding {})
            .unwrap();

        assert_eq!(wanted, image);
    }
}