ndarray_conv/dilation/
mod.rs

1//! Provides functionality for kernel dilation.
2
3use ndarray::{
4    ArrayBase, Data, Dim, Dimension, IntoDimension, Ix, RawData, SliceArg, SliceInfo, SliceInfoElem,
5};
6
7/// Represents a kernel along with its dilation factors for each dimension.
8pub struct KernelWithDilation<'a, S: RawData, const N: usize> {
9    pub(crate) kernel: &'a ArrayBase<S, Dim<[Ix; N]>>,
10    pub(crate) dilation: [usize; N],
11    pub(crate) reverse: bool,
12}
13
14impl<'a, S: RawData, const N: usize, T> KernelWithDilation<'a, S, N>
15where
16    T: num::traits::NumAssign + Copy,
17    S: Data<Elem = T>,
18    Dim<[Ix; N]>: Dimension,
19    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
20        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
21{
22    /// Generates a list of offsets and corresponding kernel values for efficient convolution.
23    ///
24    /// This method calculates the offsets into the input array that need to be accessed
25    /// during the convolution operation, taking into account the kernel's dilation.
26    /// It filters out elements where the kernel value is zero to optimize the computation.
27    ///
28    /// # Arguments
29    ///
30    /// * `pds_strides`: The strides of the padded input array.
31    ///
32    /// # Returns
33    /// A `Vec` of tuples, where each tuple contains an offset and the corresponding kernel value.
34    pub fn gen_offset_list(&self, pds_strides: &[isize]) -> Vec<(isize, T)> {
35        let buffer_slice = self.kernel.slice(unsafe {
36            SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
37                start: 0,
38                end: Some(self.kernel.raw_dim()[i] as isize),
39                step: if self.reverse { -1 } else { 1 },
40            }))
41            .unwrap()
42        });
43
44        let strides: [isize; N] =
45            std::array::from_fn(|i| self.dilation[i] as isize * pds_strides[i]);
46
47        buffer_slice
48            .indexed_iter()
49            .filter(|(_, v)| **v != T::zero())
50            .map(|(index, v)| {
51                let index = index.into_dimension();
52                (
53                    (0..N)
54                        .map(|n| index[n] as isize * strides[n])
55                        .sum::<isize>(),
56                    *v,
57                )
58            })
59            .collect()
60    }
61}
62
/// Conversion into a per-axis dilation array of length `N`.
///
/// Implemented for a single `usize` (the same factor on every axis) and for
/// `[usize; N]` (one explicit factor per axis).
pub trait IntoDilation<const N: usize> {
    /// Consumes `self` and returns one dilation factor per axis.
    fn into_dilation(self) -> [usize; N];
}

/// A scalar factor is repeated for every axis.
impl<const N: usize> IntoDilation<N> for usize {
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        std::array::from_fn(|_axis| self)
    }
}

/// An array already holds one factor per axis and is passed through untouched.
impl<const N: usize> IntoDilation<N> for [usize; N] {
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        let per_axis = self;
        per_axis
    }
}
81
82/// Trait for adding dilation information to a kernel.
83///
84/// Dilation is a parameter that controls the spacing between kernel elements
85/// during convolution. A dilation of 1 means no spacing (standard convolution),
86/// while larger values insert gaps between kernel elements.
87///
88/// # Example
89///
90/// ```rust
91/// use ndarray::array;
92/// use ndarray_conv::{WithDilation, ConvExt, ConvMode, PaddingMode};
93///
94/// let input = array![1, 2, 3, 4, 5];
95/// let kernel = array![1, 1, 1];
96///
97/// // Standard convolution (dilation = 1)
98/// let result1 = input.conv(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
99///
100/// // Dilated convolution (dilation = 2)
101/// let result2 = input.conv(kernel.with_dilation(2), ConvMode::Same, PaddingMode::Zeros).unwrap();
102/// ```
103pub trait WithDilation<S: RawData, const N: usize> {
104    /// Adds dilation information to the kernel.
105    ///
106    /// # Arguments
107    ///
108    /// * `dilation`: The dilation factor(s). Can be a single value (applied to all dimensions)
109    ///   or an array of values (one per dimension).
110    ///
111    /// # Returns
112    ///
113    /// A `KernelWithDilation` instance containing the kernel and dilation information.
114    fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<'_, S, N>;
115}
116
117impl<S: RawData, const N: usize> WithDilation<S, N> for ArrayBase<S, Dim<[Ix; N]>> {
118    #[inline]
119    fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<'_, S, N> {
120        KernelWithDilation {
121            kernel: self,
122            dilation: dilation.into_dilation(),
123            reverse: true,
124        }
125    }
126}
127
128/// Trait for controlling kernel reversal behavior in convolution operations.
129///
130/// In standard convolution, the kernel is reversed (flipped) along all axes.
131/// This trait allows you to control whether the kernel should be reversed or not.
132///
133/// # Convolution vs Cross-Correlation
134///
135/// * **Convolution** (default, `reverse()`): The kernel is reversed, which is the mathematical definition of convolution.
136/// * **Cross-correlation** (`no_reverse()`): The kernel is NOT reversed. This is commonly used in machine learning frameworks.
137///
138/// # Example
139///
140/// ```rust
141/// use ndarray::array;
142/// use ndarray_conv::{WithDilation, ReverseKernel, ConvExt, ConvMode, PaddingMode};
143///
144/// let input = array![1, 2, 3, 4, 5];
145/// let kernel = array![1, 2, 3];
146///
147/// // Standard convolution (kernel is reversed)
148/// let result1 = input.conv(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
149/// // Equivalent to:
150/// let result1_explicit = input.conv(kernel.reverse(), ConvMode::Same, PaddingMode::Zeros).unwrap();
151///
152/// // Cross-correlation (kernel is NOT reversed)
153/// let result2 = input.conv(kernel.no_reverse(), ConvMode::Same, PaddingMode::Zeros).unwrap();
154/// ```
155pub trait ReverseKernel<'a, S: RawData, const N: usize> {
156    /// Explicitly enables kernel reversal (standard convolution).
157    ///
158    /// This is the default behavior, so calling this method is usually not necessary.
159    fn reverse(self) -> KernelWithDilation<'a, S, N>;
160
161    /// Disables kernel reversal (cross-correlation).
162    ///
163    /// Use this when you want the kernel to be applied without flipping,
164    /// which is common in machine learning applications.
165    fn no_reverse(self) -> KernelWithDilation<'a, S, N>;
166}
167
168impl<'a, S: RawData, K, const N: usize> ReverseKernel<'a, S, N> for K
169where
170    K: IntoKernelWithDilation<'a, S, N>,
171{
172    #[inline]
173    fn reverse(self) -> KernelWithDilation<'a, S, N> {
174        let mut kwd = self.into_kernel_with_dilation();
175
176        kwd.reverse = true;
177
178        kwd
179    }
180
181    #[inline]
182    fn no_reverse(self) -> KernelWithDilation<'a, S, N> {
183        let mut kwd = self.into_kernel_with_dilation();
184
185        kwd.reverse = false;
186
187        kwd
188    }
189}
190
191/// Trait for converting a reference to a `KernelWithDilation`.
192pub trait IntoKernelWithDilation<'a, S: RawData, const N: usize> {
193    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N>;
194}
195
196impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
197    for &'a ArrayBase<S, Dim<[Ix; N]>>
198{
199    #[inline]
200    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
201        self.with_dilation(1)
202    }
203}
204
205impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
206    for KernelWithDilation<'a, S, N>
207{
208    #[inline]
209    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
210        self
211    }
212}
213
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    // ===== Trait Implementation Tests =====

    mod trait_implementation {
        use super::*;

        /// Compile-time check that every supported kernel form is accepted by a
        /// function taking `impl IntoKernelWithDilation`.
        #[test]
        fn check_trait_impl() {
            fn conv_example<'a, S: RawData + 'a, const N: usize>(
                kernel: impl IntoKernelWithDilation<'a, S, N>,
            ) {
                let _ = kernel.into_kernel_with_dilation();
            }

            let kernel = array![1, 0, 1];
            conv_example(&kernel);

            let kernel = array![1, 0, 1];
            conv_example(kernel.with_dilation(2));

            let kernel = array![[1, 0, 1], [0, 1, 0]];
            conv_example(kernel.with_dilation([1, 2]));

            // for convolution (default)
            conv_example(&kernel);
            // for convolution (explicit)
            conv_example(kernel.reverse());
            // for cross-correlation
            conv_example(kernel.with_dilation(2).no_reverse());
        }
    }

    // ===== Basic API Tests =====

    mod basic_api {
        use super::*;

        #[test]
        fn dilation_and_reverse_settings() {
            let kernel = array![1, 2, 3];

            // Dilation is stored as given, for 1, 2 and 3 dimensions.
            assert_eq!(kernel.with_dilation(2).dilation, [2]);
            assert_eq!(array![[1, 2]].with_dilation([2, 3]).dilation, [2, 3]);
            assert_eq!(array![[[1]]].with_dilation([1, 2, 3]).dilation, [1, 2, 3]);

            // A scalar dilation is repeated for every axis.
            assert_eq!(array![[1, 2]].with_dilation(4).dilation, [4, 4]);

            // Reverse defaults to true and can be toggled in both directions.
            assert!(kernel.with_dilation(1).reverse);
            assert!(!kernel.with_dilation(1).no_reverse().reverse);
            assert!(kernel.with_dilation(1).no_reverse().reverse().reverse);

            // Toggling the reverse flag must not disturb the dilation.
            assert_eq!(kernel.with_dilation(3).no_reverse().dilation, [3]);
        }
    }

    // ===== Offset Generation Tests =====

    mod offset_generation {
        use super::*;

        #[test]
        fn gen_offset_1d_no_dilation() {
            let kernel = array![1.0, 2.0, 3.0];
            let kwd = kernel.with_dilation(1);

            // Single axis with stride 1.
            let offsets = kwd.gen_offset_list(&[1]);

            // With reverse=true the kernel is walked as [3, 2, 1];
            // offsets are index * dilation * stride = [0, 1, 2].
            assert_eq!(offsets, vec![(0, 3.0), (1, 2.0), (2, 1.0)]);
        }

        #[test]
        fn gen_offset_1d_with_dilation() {
            let kernel = array![1.0, 2.0, 3.0];
            let kwd = kernel.with_dilation(2);

            // Stride = 1, but dilation = 2
            let offsets = kwd.gen_offset_list(&[1]);

            // Effective kernel: [1, 0, 2, 0, 3]; reversed values [3, 2, 1]
            // land at positions [0*2, 1*2, 2*2] = [0, 2, 4].
            assert_eq!(offsets, vec![(0, 3.0), (2, 2.0), (4, 1.0)]);
        }

        #[test]
        fn gen_offset_1d_no_reverse() {
            let kernel = array![1.0, 2.0, 3.0];
            let kwd = kernel.with_dilation(2).no_reverse();

            let offsets = kwd.gen_offset_list(&[1]);

            // No reverse: [1, 2, 3] at positions [0, 2, 4]
            assert_eq!(offsets, vec![(0, 1.0), (2, 2.0), (4, 3.0)]);
        }

        #[test]
        fn gen_offset_2d_no_dilation() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let kwd = kernel.with_dilation(1);

            // Strides for 2D: [row_stride, col_stride]
            let offsets = kwd.gen_offset_list(&[10, 1]);

            // With reverse, the kernel is walked as [[4, 3], [2, 1]] in row-major order:
            // (0,0)=4 at offset 0, (0,1)=3 at offset 1, (1,0)=2 at offset 10, (1,1)=1 at offset 11
            assert_eq!(offsets, vec![(0, 4.0), (1, 3.0), (10, 2.0), (11, 1.0)]);
        }

        #[test]
        fn gen_offset_2d_no_reverse() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let kwd = kernel.with_dilation(1).no_reverse();

            let offsets = kwd.gen_offset_list(&[10, 1]);

            // Without reverse the values keep their stored order.
            assert_eq!(offsets, vec![(0, 1.0), (1, 2.0), (10, 3.0), (11, 4.0)]);
        }

        #[test]
        fn gen_offset_2d_with_dilation() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let kwd = kernel.with_dilation([2, 3]);

            let offsets = kwd.gen_offset_list(&[10, 1]);

            // Dilation [2, 3] means:
            // - row spacing = 2 (kernel rows are 2*10=20 apart)
            // - col spacing = 3 (kernel cols are 3*1=3 apart)
            // With reverse, kernel [[4,3],[2,1]] at effective positions:
            // (0,0)=4 at 0, (0,3)=3 at 3, (2,0)=2 at 20, (2,3)=1 at 23
            assert_eq!(offsets, vec![(0, 4.0), (3, 3.0), (20, 2.0), (23, 1.0)]);
        }

        #[test]
        fn gen_offset_filters_zeros() {
            let kernel = array![1.0, 0.0, 2.0, 0.0, 3.0];
            let kwd = kernel.with_dilation(1);

            let offsets = kwd.gen_offset_list(&[1]);

            // Only the three non-zero taps survive, and each keeps the offset of
            // its position in the reversed kernel [3, 0, 2, 0, 1] rather than
            // being compacted to [0, 1, 2].
            assert_eq!(offsets, vec![(0, 3.0), (2, 2.0), (4, 1.0)]);
        }
    }

    // ===== Edge Cases =====

    mod edge_cases {
        use super::*;

        #[test]
        fn single_element_kernel() {
            let kernel = array![42.0];
            let kwd = kernel.with_dilation(5);

            assert_eq!(kwd.dilation, [5]);

            // A single tap sits at offset 0 regardless of dilation.
            let offsets = kwd.gen_offset_list(&[1]);
            assert_eq!(offsets, vec![(0, 42.0)]);
        }

        #[test]
        fn all_zeros_kernel() {
            let kernel = array![0.0, 0.0, 0.0];
            let kwd = kernel.with_dilation(2);

            let offsets = kwd.gen_offset_list(&[1]);
            // Every element is zero, so nothing is left after filtering.
            assert!(offsets.is_empty());
        }

        #[test]
        fn large_dilation_value() {
            let kernel = array![1, 2];
            let kwd = kernel.with_dilation(100);

            assert_eq!(kwd.dilation, [100]);

            // Effective size: 2 + (2-1)*99 = 101, so the two taps are 100 apart.
            // Reversed values: [2, 1].
            let offsets = kwd.gen_offset_list(&[1]);
            assert_eq!(offsets, vec![(0, 2), (100, 1)]);
        }

        #[test]
        fn asymmetric_2d_dilation() {
            let kernel = array![[1, 2, 3], [4, 5, 6]];
            let kwd = kernel.with_dilation([1, 5]);

            assert_eq!(kwd.dilation, [1, 5]);

            // dim 0: no dilation (rows are 1*20 = 20 apart)
            // dim 1: dilation=5 (3 + (3-1)*4 = 11 effective cols, taps 5*1 = 5 apart)
            // Reversed kernel: [[6, 5, 4], [3, 2, 1]].
            let offsets = kwd.gen_offset_list(&[20, 1]);
            assert_eq!(
                offsets,
                vec![(0, 6), (5, 5), (10, 4), (20, 3), (25, 2), (30, 1)]
            );
        }
    }

    // ===== Integration Tests =====

    mod integration_with_padding {
        use super::*;

        #[test]
        fn effective_kernel_size_calculation() {
            // This tests the concept used in padding calculations:
            // effective = len + (len - 1) * (dilation - 1).
            let kernel = array![1, 2, 3];

            for (dilation, expected) in [(1usize, 3usize), (2, 5), (3, 7)] {
                let kwd = kernel.with_dilation(dilation);
                let effective = kernel.len() + (kernel.len() - 1) * (kwd.dilation[0] - 1);
                assert_eq!(effective, expected, "dilation = {dilation}");
            }
        }
    }
}