ndarray_conv/dilation/
mod.rs

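//! Kernel dilation support: attaching per-axis dilation factors and a flip (reverse) flag to a
//! kernel, and turning that into the list of input offsets used by the convolution.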
2
3use ndarray::{
4    ArrayBase, Data, Dim, Dimension, IntoDimension, Ix, RawData, SliceArg, SliceInfo, SliceInfoElem,
5};
6
7/// Represents a kernel along with its dilation factors for each dimension.
8pub struct KernelWithDilation<'a, S: RawData, const N: usize> {
9    pub(crate) kernel: &'a ArrayBase<S, Dim<[Ix; N]>>,
10    pub(crate) dilation: [usize; N],
11    pub(crate) reverse: bool,
12}
13
14impl<'a, S: RawData, const N: usize, T> KernelWithDilation<'a, S, N>
15where
16    T: num::traits::NumAssign + Copy,
17    S: Data<Elem = T>,
18    Dim<[Ix; N]>: Dimension,
19    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
20        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
21{
22    /// Generates a list of offsets and corresponding kernel values for efficient convolution.
23    ///
24    /// This method calculates the offsets into the input array that need to be accessed
25    /// during the convolution operation, taking into account the kernel's dilation.
26    /// It filters out elements where the kernel value is zero to optimize the computation.
27    ///
28    /// # Arguments
29    ///
30    /// * `pds_strides`: The strides of the padded input array.
31    ///
32    /// # Returns
33    /// A `Vec` of tuples, where each tuple contains an offset and the corresponding kernel value.
34    pub fn gen_offset_list(&self, pds_strides: &[isize]) -> Vec<(isize, T)> {
35        let buffer_slice = self.kernel.slice(unsafe {
36            SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
37                start: 0,
38                end: Some(self.kernel.raw_dim()[i] as isize),
39                step: if self.reverse { -1 } else { 1 },
40            }))
41            .unwrap()
42        });
43
44        let strides: [isize; N] =
45            std::array::from_fn(|i| self.dilation[i] as isize * pds_strides[i]);
46
47        buffer_slice
48            .indexed_iter()
49            .filter(|(_, v)| **v != T::zero())
50            .map(|(index, v)| {
51                let index = index.into_dimension();
52                (
53                    (0..N)
54                        .map(|n| index[n] as isize * strides[n])
55                        .sum::<isize>(),
56                    *v,
57                )
58            })
59            .collect()
60    }
61}
62
/// Conversion into a per-axis dilation array of length `N`.
///
/// Implemented here for a single `usize` (the same factor on every axis) and for
/// `[usize; N]` (one factor per axis, passed through unchanged).
pub trait IntoDilation<const N: usize> {
    /// Returns the dilation factor for each of the `N` axes.
    fn into_dilation(self) -> [usize; N];
}

impl<const N: usize> IntoDilation<N> for [usize; N] {
    /// Returns the per-axis factors exactly as given.
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        self
    }
}

impl<const N: usize> IntoDilation<N> for usize {
    /// Broadcasts this single factor to all `N` axes.
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        std::array::from_fn(|_| self)
    }
}
81
82/// Trait for adding dilation information to a kernel.
83pub trait WithDilation<S: RawData, const N: usize> {
84    fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<S, N>;
85}
86
87impl<S: RawData, const N: usize> WithDilation<S, N> for ArrayBase<S, Dim<[Ix; N]>> {
88    #[inline]
89    fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<S, N> {
90        KernelWithDilation {
91            kernel: self,
92            dilation: dilation.into_dilation(),
93            reverse: true,
94        }
95    }
96}
97
98pub trait ReverseKernel<'a, S: RawData, const N: usize> {
99    fn reverse(self) -> KernelWithDilation<'a, S, N>;
100    fn no_reverse(self) -> KernelWithDilation<'a, S, N>;
101}
102
103impl<'a, S: RawData, K, const N: usize> ReverseKernel<'a, S, N> for K
104where
105    K: IntoKernelWithDilation<'a, S, N>,
106{
107    #[inline]
108    fn reverse(self) -> KernelWithDilation<'a, S, N> {
109        let mut kwd = self.into_kernel_with_dilation();
110
111        kwd.reverse = true;
112
113        kwd
114    }
115
116    #[inline]
117    fn no_reverse(self) -> KernelWithDilation<'a, S, N> {
118        let mut kwd = self.into_kernel_with_dilation();
119
120        kwd.reverse = false;
121
122        kwd
123    }
124}
125
126/// Trait for converting a reference to a `KernelWithDilation`.
127pub trait IntoKernelWithDilation<'a, S: RawData, const N: usize> {
128    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N>;
129}
130
131impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
132    for &'a ArrayBase<S, Dim<[Ix; N]>>
133{
134    #[inline]
135    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
136        self.with_dilation(1)
137    }
138}
139
140impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
141    for KernelWithDilation<'a, S, N>
142{
143    #[inline]
144    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
145        self
146    }
147}
148
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    /// Compile-time check that every supported way of passing a kernel satisfies
    /// `IntoKernelWithDilation`.
    #[test]
    fn check_trait_impl() {
        fn conv_example<'a, S: RawData + 'a, const N: usize>(
            kernel: impl IntoKernelWithDilation<'a, S, N>,
        ) {
            let _ = kernel.into_kernel_with_dilation();
        }

        let kernel = array![1, 0, 1];

        conv_example(&kernel);

        let kernel = array![1, 0, 1];

        conv_example(kernel.with_dilation(2));

        let kernel = array![[1, 0, 1], [0, 1, 0]];

        conv_example(kernel.with_dilation([1, 2]));

        // for convolution (default)
        conv_example(&kernel);
        // for convolution (explicit)
        conv_example(kernel.reverse());
        // for cross-correlation
        conv_example(kernel.with_dilation(2).no_reverse());
    }

    /// Offsets for a 1-D kernel, with and without flipping.
    #[test]
    fn offset_list_1d() {
        let kernel = array![1, 0, 2];

        // Unflipped: taps at indices 0 and 2 spaced by dilation 2; the zero tap is dropped.
        assert_eq!(
            kernel.with_dilation(2).no_reverse().gen_offset_list(&[1]),
            vec![(0, 1), (4, 2)]
        );
        // Flipped (default): the kernel is read as [2, 0, 1].
        assert_eq!(
            kernel.with_dilation(2).gen_offset_list(&[1]),
            vec![(0, 2), (4, 1)]
        );
    }

    /// Offsets for a 2-D kernel with per-axis dilation, with and without flipping.
    #[test]
    fn offset_list_2d() {
        let kernel = array![[1, 2], [0, 3]];
        // Row stride 10, column stride 1; column dilation 2 => per-axis steps [10, 2].
        let strides: [isize; 2] = [10, 1];

        assert_eq!(
            kernel
                .with_dilation([1, 2])
                .no_reverse()
                .gen_offset_list(&strides),
            vec![(0, 1), (2, 2), (12, 3)]
        );
        // Flipped along both axes, the kernel is read as [[3, 0], [2, 1]].
        assert_eq!(
            kernel.with_dilation([1, 2]).gen_offset_list(&strides),
            vec![(0, 3), (10, 2), (12, 1)]
        );
    }

    /// A kernel with no non-zero taps produces no offsets.
    #[test]
    fn offset_list_skips_all_zero_kernel() {
        let kernel = array![[0, 0], [0, 0]];
        assert!(kernel.with_dilation(3).gen_offset_list(&[4, 1]).is_empty());
    }
}