ndarray_conv/dilation/
mod.rs

//! Provides functionality for kernel dilation.
2
3use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix, RawData};
4
5/// Represents a kernel along with its dilation factors for each dimension.
6pub struct KernelWithDilation<'a, S: RawData, const N: usize> {
7    pub kernel: &'a ArrayBase<S, Dim<[Ix; N]>>,
8    pub dilation: [usize; N],
9}
10
11impl<'a, S: RawData, const N: usize, T> KernelWithDilation<'a, S, N>
12where
13    T: num::traits::NumAssign + Copy,
14    S: Data<Elem = T>,
15    Dim<[Ix; N]>: Dimension,
16{
17    /// Generates a list of offsets and corresponding kernel values for efficient convolution.
18    ///
19    /// This method calculates the offsets into the input array that need to be accessed
20    /// during the convolution operation, taking into account the kernel's dilation.
21    /// It filters out elements where the kernel value is zero to optimize the computation.
22    ///
23    /// # Arguments
24    ///
25    /// * `pds_strides`: The strides of the padded input array.
26    ///
27    /// # Returns
28    /// A `Vec` of tuples, where each tuple contains an offset and the corresponding kernel value.
29    pub fn gen_offset_list(&self, pds_strides: &[isize]) -> Vec<(isize, T)> {
30        let strides: [isize; N] =
31            std::array::from_fn(|i| self.dilation[i] as isize * pds_strides[i]);
32
33        self.kernel
34            .indexed_iter()
35            .filter(|(_, v)| **v != T::zero())
36            .map(|(index, v)| {
37                let index = index.into_dimension();
38                (
39                    (0..N)
40                        .map(|n| index[n] as isize * strides[n])
41                        .sum::<isize>(),
42                    *v,
43                )
44            })
45            .collect()
46
47        // let first = self.kernel.as_ptr();
48        // self.kernel
49        //     .iter()
50        //     .filter(|v| **v != T::zero())
51        //     .map(|v| (unsafe { (v as *const T).offset_from(first) }, *v))
52        //     .collect()
53    }
54}
55
56impl<'a, S: RawData, const N: usize> From<&'a ArrayBase<S, Dim<[Ix; N]>>>
57    for KernelWithDilation<'a, S, N>
58{
59    fn from(kernel: &'a ArrayBase<S, Dim<[Ix; N]>>) -> Self {
60        Self {
61            kernel,
62            dilation: [1; N],
63        }
64    }
65}
66
/// Conversion of a value into a per-axis dilation array of length `N`.
///
/// Implemented here for a single `usize`, which repeats the same factor on every axis,
/// and for `[usize; N]`, which supplies one factor per axis.
pub trait IntoDilation<const N: usize> {
    /// Produces the dilation factor for each of the `N` axes.
    fn into_dilation(self) -> [usize; N];
}

impl<const N: usize> IntoDilation<N> for usize {
    /// Repeats this single factor across all `N` axes.
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        std::array::from_fn(|_axis| self)
    }
}

impl<const N: usize> IntoDilation<N> for [usize; N] {
    /// Returns the per-axis factors unchanged.
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        self
    }
}
85
86/// Trait for adding dilation information to a kernel.
87pub trait WithDilation<S: RawData, const N: usize> {
88    fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<S, N>;
89}
90
91impl<S: RawData, const N: usize> WithDilation<S, N> for ArrayBase<S, Dim<[Ix; N]>> {
92    #[inline]
93    fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<S, N> {
94        KernelWithDilation {
95            kernel: self,
96            dilation: dilation.into_dilation(),
97        }
98    }
99}
100
101/// Trait for converting a reference to a `KernelWithDilation`.
102pub trait IntoKernelWithDilation<'a, S: RawData, const N: usize> {
103    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N>;
104}
105
106impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
107    for &'a ArrayBase<S, Dim<[Ix; N]>>
108{
109    #[inline]
110    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
111        self.with_dilation(1)
112    }
113}
114
115impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
116    for KernelWithDilation<'a, S, N>
117{
118    #[inline]
119    fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
120        self
121    }
122}
123
#[cfg(test)]
mod tests {
    use ndarray::{array, Array2};

    use super::*;

    #[test]
    fn check_trait_impl() {
        fn conv_example<'a, S: RawData + 'a, const N: usize>(
            kernel: impl IntoKernelWithDilation<'a, S, N>,
        ) {
            let _ = kernel.into_kernel_with_dilation();
        }

        let kernel = array![1, 0, 1];

        conv_example(&kernel);

        let kernel = array![1, 0, 1];

        conv_example(kernel.with_dilation(2));

        let kernel = array![[1, 0, 1], [0, 1, 0]];

        conv_example(kernel.with_dilation([1, 2]));
    }

    #[test]
    fn plain_reference_is_undilated() {
        let kernel = array![[1, 2], [3, 4]];
        assert_eq!((&kernel).into_kernel_with_dilation().dilation, [1, 1]);
        assert_eq!(KernelWithDilation::from(&kernel).dilation, [1, 1]);
    }

    #[test]
    fn check_ndarray_strides() {
        // A standard-layout padded input with 10 columns has strides [10, 1].
        let padded = Array2::<i32>::zeros((4, 10));
        assert_eq!(padded.strides(), &[10, 1]);

        let kernel = array![[1, 0, 2], [0, 3, 0]];
        let offsets = kernel
            .with_dilation([1, 2])
            .gen_offset_list(padded.strides());

        // Effective strides are [1 * 10, 2 * 1] = [10, 2], so the non-zero elements map to
        // (0, 0) -> 0, (0, 2) -> 2 * 2 = 4, and (1, 1) -> 10 + 2 = 12.
        let expected: Vec<(isize, i32)> = vec![(0, 1), (4, 2), (12, 3)];
        assert_eq!(offsets, expected);
    }

    #[test]
    fn gen_offset_list_scalar_dilation_1d() {
        let kernel = array![1, 2, 3];
        let offsets = kernel.with_dilation(2).gen_offset_list(&[1]);
        let expected: Vec<(isize, i32)> = vec![(0, 1), (2, 2), (4, 3)];
        assert_eq!(offsets, expected);
    }

    #[test]
    fn gen_offset_list_skips_all_zero_kernel() {
        let kernel = Array2::<i32>::zeros((2, 2));
        assert!(kernel.with_dilation(3).gen_offset_list(&[5, 1]).is_empty());
    }
}