ndarray_conv/dilation/
mod.rs

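//! Kernel dilation support: pairs a kernel with per-axis dilation factors and a
//! flip (reverse) flag, and turns them into the offset/value list used during convolution.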
2
3use ndarray::{
4    ArrayBase, Data, Dim, Dimension, IntoDimension, Ix, RawData, SliceArg, SliceInfo, SliceInfoElem,
5};
6
7/// Represents a kernel along with its dilation factors for each dimension.
8pub struct KernelWithDilation<'a, S: RawData, const N: usize> {
9    pub(crate) kernel: &'a ArrayBase<S, Dim<[Ix; N]>>,
10    pub(crate) dilation: [usize; N],
11    pub(crate) reverse: bool,
12}
13
14impl<'a, S: RawData, const N: usize, T> KernelWithDilation<'a, S, N>
15where
16    T: num::traits::NumAssign + Copy,
17    S: Data<Elem = T>,
18    Dim<[Ix; N]>: Dimension,
19    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
20        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
21{
22    /// Generates a list of offsets and corresponding kernel values for efficient convolution.
23    ///
24    /// This method calculates the offsets into the input array that need to be accessed
25    /// during the convolution operation, taking into account the kernel's dilation.
26    /// It filters out elements where the kernel value is zero to optimize the computation.
27    ///
28    /// # Arguments
29    ///
30    /// * `pds_strides`: The strides of the padded input array.
31    ///
32    /// # Returns
33    /// A `Vec` of tuples, where each tuple contains an offset and the corresponding kernel value.
34    pub fn gen_offset_list(&self, pds_strides: &[isize]) -> Vec<(isize, T)> {
35        let buffer_slice = self.kernel.slice(unsafe {
36            SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
37                start: 0,
38                end: Some(self.kernel.raw_dim()[i] as isize),
39                step: if self.reverse { -1 } else { 1 },
40            }))
41            .unwrap()
42        });
43
44        let strides: [isize; N] =
45            std::array::from_fn(|i| self.dilation[i] as isize * pds_strides[i]);
46
47        buffer_slice
48            .indexed_iter()
49            .filter(|(_, v)| **v != T::zero())
50            .map(|(index, v)| {
51                let index = index.into_dimension();
52                (
53                    (0..N)
54                        .map(|n| index[n] as isize * strides[n])
55                        .sum::<isize>(),
56                    *v,
57                )
58            })
59            .collect()
60    }
61}
62
/// Conversion into a per-axis dilation array of length `N`.
///
/// Lets dilation-taking APIs accept either one factor for all axes or an explicit
/// factor per axis.
pub trait IntoDilation<const N: usize> {
    /// Consumes `self` and returns one dilation factor per axis.
    fn into_dilation(self) -> [usize; N];
}
67
68impl<const N: usize> IntoDilation<N> for usize {
69    #[inline]
70    fn into_dilation(self) -> [usize; N] {
71        [self; N]
72    }
73}
74
75impl<const N: usize> IntoDilation<N> for [usize; N] {
76    #[inline]
77    fn into_dilation(self) -> [usize; N] {
78        self
79    }
80}
81
82/// Trait for adding dilation information to a kernel.
83pub trait WithDilation<S: RawData, const N: usize> {
84    fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<'_, S, N>;
85}
86
87impl<S: RawData, const N: usize> WithDilation<S, N> for ArrayBase<S, Dim<[Ix; N]>> {
88    #[inline]
89    fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<'_, S, N> {
90        KernelWithDilation {
91            kernel: self,
92            dilation: dilation.into_dilation(),
93            reverse: true,
94        }
95    }
96}
97
98pub trait ReverseKernel<'a, S: RawData, const N: usize> {
99    fn reverse(self) -> KernelWithDilation<'a, S, N>;
100    fn no_reverse(self) -> KernelWithDilation<'a, S, N>;
101}
102
103impl<'a, S: RawData, K, const N: usize> ReverseKernel<'a, S, N> for K
104where
105    K: IntoKernelWithDilation<'a, S, N>,
106{
107    #[inline]
108    fn reverse(self) -> KernelWithDilation<'a, S, N> {
109        let mut kwd = self.into_kernel_with_dilation();
110
111        kwd.reverse = true;
112
113        kwd
114    }
115
116    #[inline]
117    fn no_reverse(self) -> KernelWithDilation<'a, S, N> {
118        let mut kwd = self.into_kernel_with_dilation();
119
120        kwd.reverse = false;
121
122        kwd
123    }
124}
125
126/// Trait for converting a reference to a `KernelWithDilation`.
127pub trait IntoKernelWithDilation<'a, S: RawData, const N: usize> {
128    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N>;
129}
130
131impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
132    for &'a ArrayBase<S, Dim<[Ix; N]>>
133{
134    #[inline]
135    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
136        self.with_dilation(1)
137    }
138}
139
140impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
141    for KernelWithDilation<'a, S, N>
142{
143    #[inline]
144    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
145        self
146    }
147}
148
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    // ===== Trait Implementation Tests =====

    mod trait_implementation {
        use super::*;

        /// Compile-time check that every supported kernel form is accepted by a
        /// generic convolution-style entry point.
        #[test]
        fn check_trait_impl() {
            fn conv_example<'a, S: RawData + 'a, const N: usize>(
                kernel: impl IntoKernelWithDilation<'a, S, N>,
            ) {
                let _ = kernel.into_kernel_with_dilation();
            }

            let kernel = array![1, 0, 1];
            conv_example(&kernel);

            let kernel = array![1, 0, 1];
            conv_example(kernel.with_dilation(2));

            let kernel = array![[1, 0, 1], [0, 1, 0]];
            conv_example(kernel.with_dilation([1, 2]));

            // for convolution (default)
            conv_example(&kernel);
            // for convolution (explicit)
            conv_example(kernel.reverse());
            // for cross-correlation
            conv_example(kernel.with_dilation(2).no_reverse());
        }
    }

    // ===== Basic API Tests =====

    mod basic_api {
        use super::*;

        #[test]
        fn dilation_and_reverse_settings() {
            let kernel = array![1, 2, 3];

            // Test dilation is set correctly for different dimensions
            assert_eq!(kernel.with_dilation(2).dilation, [2]);
            assert_eq!(array![[1, 2]].with_dilation([2, 3]).dilation, [2, 3]);
            assert_eq!(array![[[1]]].with_dilation([1, 2, 3]).dilation, [1, 2, 3]);

            // Test reverse behavior (default is true, can be toggled)
            assert!(kernel.with_dilation(1).reverse);
            assert!(!kernel.with_dilation(1).no_reverse().reverse);
            assert!(kernel.with_dilation(1).no_reverse().reverse().reverse);
        }

        #[test]
        fn plain_reference_defaults() {
            let kernel = array![[1, 2], [3, 4]];

            // A bare reference converts with dilation 1 on every axis and reverse = true.
            let kwd = (&kernel).into_kernel_with_dilation();
            assert_eq!(kwd.dilation, [1, 1]);
            assert!(kwd.reverse);

            // Toggling reverse leaves the dilation untouched.
            let kwd = kernel.with_dilation([2, 3]).no_reverse();
            assert_eq!(kwd.dilation, [2, 3]);
            assert!(!kwd.reverse);
        }
    }

    // ===== Offset Generation Tests =====

    mod offset_generation {
        use super::*;

        #[test]
        fn gen_offset_1d_no_dilation() {
            let kernel = array![1.0, 2.0, 3.0];
            let kwd = kernel.with_dilation(1);

            // Stride = 1 for 1D
            let offsets = kwd.gen_offset_list(&[1]);

            // Should have 3 offsets (all kernel elements)
            assert_eq!(offsets.len(), 3);

            // With reverse=true, kernel is reversed: [3, 2, 1]
            // Offsets: [0, 1, 2] * stride[1] = [0, 1, 2]
            assert_eq!(offsets[0], (0, 3.0));
            assert_eq!(offsets[1], (1, 2.0));
            assert_eq!(offsets[2], (2, 1.0));
        }

        #[test]
        fn gen_offset_1d_with_dilation() {
            let kernel = array![1.0, 2.0, 3.0];
            let kwd = kernel.with_dilation(2);

            // Stride = 1, but dilation = 2
            let offsets = kwd.gen_offset_list(&[1]);

            assert_eq!(offsets.len(), 3);

            // Effective kernel: [1, 0, 2, 0, 3]
            // Reversed values [3, 2, 1] land at dilated positions [0, 2, 4]
            assert_eq!(offsets[0], (0, 3.0));
            assert_eq!(offsets[1], (2, 2.0));
            assert_eq!(offsets[2], (4, 1.0));
        }

        #[test]
        fn gen_offset_1d_no_reverse() {
            let kernel = array![1.0, 2.0, 3.0];
            let kwd = kernel.with_dilation(2).no_reverse();

            let offsets = kwd.gen_offset_list(&[1]);

            assert_eq!(offsets.len(), 3);

            // No reverse: [1, 2, 3] at positions [0, 2, 4]
            assert_eq!(offsets[0], (0, 1.0));
            assert_eq!(offsets[1], (2, 2.0));
            assert_eq!(offsets[2], (4, 3.0));
        }

        #[test]
        fn gen_offset_2d_no_dilation() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let kwd = kernel.with_dilation(1);

            // Strides for 2D: [row_stride, col_stride]
            let offsets = kwd.gen_offset_list(&[10, 1]);

            assert_eq!(offsets.len(), 4);

            // With reverse, kernel becomes [[4, 3], [2, 1]]
            // Row-major order over the reversed kernel:
            // (0,0)=4 at offset 0, (0,1)=3 at offset 1, (1,0)=2 at offset 10, (1,1)=1 at offset 11
            assert_eq!(offsets[0], (0, 4.0));
            assert_eq!(offsets[1], (1, 3.0));
            assert_eq!(offsets[2], (10, 2.0));
            assert_eq!(offsets[3], (11, 1.0));
        }

        #[test]
        fn gen_offset_2d_with_dilation() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let kwd = kernel.with_dilation([2, 3]);

            let offsets = kwd.gen_offset_list(&[10, 1]);

            assert_eq!(offsets.len(), 4);

            // Dilation [2, 3] means:
            // - row spacing = 2 (kernel rows are 0 and 2*10=20 apart)
            // - col spacing = 3 (kernel cols are 0 and 3*1=3 apart)
            // With reverse, kernel [[4,3],[2,1]] at effective positions:
            // (0,0)=4 at 0, (0,3)=3 at 3, (2,0)=2 at 20, (2,3)=1 at 23
            assert_eq!(offsets[0], (0, 4.0));
            assert_eq!(offsets[1], (3, 3.0));
            assert_eq!(offsets[2], (20, 2.0));
            assert_eq!(offsets[3], (23, 1.0));
        }

        #[test]
        fn gen_offset_2d_no_reverse() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let kwd = kernel.with_dilation([2, 3]).no_reverse();

            let offsets = kwd.gen_offset_list(&[10, 1]);

            // Orientation is kept; effective strides are [2 * 10, 3 * 1] = [20, 3].
            assert_eq!(offsets, vec![(0, 1.0), (3, 2.0), (20, 3.0), (23, 4.0)]);
        }

        #[test]
        fn gen_offset_filters_zeros() {
            let kernel = array![1.0, 0.0, 2.0, 0.0, 3.0];
            let kwd = kernel.with_dilation(1);

            let offsets = kwd.gen_offset_list(&[1]);

            // Reversed kernel is [3, 0, 2, 0, 1]; zeros are dropped while the
            // remaining taps keep their positions.
            assert_eq!(offsets, vec![(0, 3.0), (2, 2.0), (4, 1.0)]);
        }

        #[test]
        fn gen_offset_ignores_extra_strides() {
            let kernel = array![1.0, 2.0];
            let kwd = kernel.with_dilation(3);

            // Only the first N strides are consulted.
            assert_eq!(kwd.gen_offset_list(&[1, 99]), kwd.gen_offset_list(&[1]));
        }

        #[test]
        #[should_panic]
        fn gen_offset_too_few_strides_panics() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let kwd = kernel.with_dilation(1);

            // A 2-D kernel needs two strides.
            let _ = kwd.gen_offset_list(&[10]);
        }
    }

    // ===== Edge Cases =====

    mod edge_cases {
        use super::*;

        #[test]
        fn single_element_kernel() {
            let kernel = array![42.0];
            let kwd = kernel.with_dilation(5);

            assert_eq!(kwd.dilation, [5]);

            let offsets = kwd.gen_offset_list(&[1]);
            assert_eq!(offsets.len(), 1);
            assert_eq!(offsets[0], (0, 42.0));
        }

        #[test]
        fn all_zeros_kernel() {
            let kernel = array![0.0, 0.0, 0.0];
            let kwd = kernel.with_dilation(2);

            let offsets = kwd.gen_offset_list(&[1]);
            // Should filter out all zeros
            assert_eq!(offsets.len(), 0);
        }

        #[test]
        fn large_dilation_value() {
            let kernel = array![1, 2];
            let kwd = kernel.with_dilation(100);

            assert_eq!(kwd.dilation, [100]);
            // Effective size: 2 + (2-1)*99 = 101, so the two taps sit 100 apart.
            assert_eq!(kwd.gen_offset_list(&[1]), vec![(0, 2), (100, 1)]);
        }

        #[test]
        fn asymmetric_2d_dilation() {
            let kernel = array![[1, 2, 3], [4, 5, 6]];
            let kwd = kernel.with_dilation([1, 5]);

            assert_eq!(kwd.dilation, [1, 5]);
            // dim 0: no dilation (keeps 2 rows)
            // dim 1: dilation=5 (3 + (3-1)*4 = 11 effective cols)
            // With row stride 20 the effective strides are [1 * 20, 5 * 1] = [20, 5];
            // the reversed kernel is [[6, 5, 4], [3, 2, 1]].
            assert_eq!(
                kwd.gen_offset_list(&[20, 1]),
                vec![(0, 6), (5, 5), (10, 4), (20, 3), (25, 2), (30, 1)]
            );
        }
    }

    // ===== Integration Tests =====

    mod integration_with_padding {
        use super::*;

        #[test]
        fn effective_kernel_size_calculation() {
            // This tests the concept used in padding calculations
            let kernel = array![1, 2, 3];

            for (dilation, expected) in [(1_usize, 3_usize), (2, 5), (3, 7)] {
                let kwd = kernel.with_dilation(dilation);
                let effective_size = kernel.len() + (kernel.len() - 1) * (kwd.dilation[0] - 1);
                assert_eq!(effective_size, expected);

                // With a unit input stride, the farthest tap lands on the last element
                // of the effective (dilated) kernel.
                let max_offset = kwd
                    .gen_offset_list(&[1])
                    .into_iter()
                    .map(|(offset, _)| offset)
                    .max();
                assert_eq!(max_offset, Some(effective_size as isize - 1));
            }
        }
    }
}