ndarray_conv/dilation/
mod.rs1use ndarray::{ArrayBase, Data, Dim, Dimension, IntoDimension, Ix, RawData};
4
5pub struct KernelWithDilation<'a, S: RawData, const N: usize> {
7 pub kernel: &'a ArrayBase<S, Dim<[Ix; N]>>,
8 pub dilation: [usize; N],
9}
10
11impl<'a, S: RawData, const N: usize, T> KernelWithDilation<'a, S, N>
12where
13 T: num::traits::NumAssign + Copy,
14 S: Data<Elem = T>,
15 Dim<[Ix; N]>: Dimension,
16{
17 pub fn gen_offset_list(&self, pds_strides: &[isize]) -> Vec<(isize, T)> {
30 let strides: [isize; N] =
31 std::array::from_fn(|i| self.dilation[i] as isize * pds_strides[i]);
32
33 self.kernel
34 .indexed_iter()
35 .filter(|(_, v)| **v != T::zero())
36 .map(|(index, v)| {
37 let index = index.into_dimension();
38 (
39 (0..N)
40 .map(|n| index[n] as isize * strides[n])
41 .sum::<isize>(),
42 *v,
43 )
44 })
45 .collect()
46
47 }
54}
55
56impl<'a, S: RawData, const N: usize> From<&'a ArrayBase<S, Dim<[Ix; N]>>>
57 for KernelWithDilation<'a, S, N>
58{
59 fn from(kernel: &'a ArrayBase<S, Dim<[Ix; N]>>) -> Self {
60 Self {
61 kernel,
62 dilation: [1; N],
63 }
64 }
65}
66
/// Conversion into a per-axis dilation array of length `N`.
///
/// Implemented for a single `usize` (the same factor on every axis) and for
/// `[usize; N]` (one explicit factor per axis).
pub trait IntoDilation<const N: usize> {
    /// Consumes `self` and returns the dilation factor for each of the `N` axes.
    fn into_dilation(self) -> [usize; N];
}

/// A scalar dilation is applied uniformly to all `N` axes.
impl<const N: usize> IntoDilation<N> for usize {
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        std::array::from_fn(|_| self)
    }
}

/// An array of factors is already in the target form and is returned unchanged.
impl<const N: usize> IntoDilation<N> for [usize; N] {
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        self
    }
}
85
86pub trait WithDilation<S: RawData, const N: usize> {
88 fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<S, N>;
89}
90
91impl<S: RawData, const N: usize> WithDilation<S, N> for ArrayBase<S, Dim<[Ix; N]>> {
92 #[inline]
93 fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<S, N> {
94 KernelWithDilation {
95 kernel: self,
96 dilation: dilation.into_dilation(),
97 }
98 }
99}
100
101pub trait IntoKernelWithDilation<'a, S: RawData, const N: usize> {
103 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N>;
104}
105
106impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
107 for &'a ArrayBase<S, Dim<[Ix; N]>>
108{
109 #[inline]
110 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
111 self.with_dilation(1)
112 }
113}
114
115impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
116 for KernelWithDilation<'a, S, N>
117{
118 #[inline]
119 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
120 self
121 }
122}
123
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    /// Checks that every supported kernel form is accepted through
    /// `IntoKernelWithDilation`. The test passes if it compiles and runs.
    #[test]
    fn check_trait_impl() {
        // Stand-in for a convolution entry point that takes any kernel form.
        fn conv_example<'a, S: RawData + 'a, const N: usize>(
            kernel: impl IntoKernelWithDilation<'a, S, N>,
        ) {
            let _converted: KernelWithDilation<'a, S, N> = kernel.into_kernel_with_dilation();
        }

        // Plain reference to a 1-D kernel.
        let plain_1d = array![1, 0, 1];
        conv_example(&plain_1d);

        // 1-D kernel with a scalar dilation.
        let dilated_1d = array![1, 0, 1];
        conv_example(dilated_1d.with_dilation(2));

        // 2-D kernel with one dilation factor per axis.
        let dilated_2d = array![[1, 0, 1], [0, 1, 0]];
        conv_example(dilated_2d.with_dilation([1, 2]));
    }

    /// Currently has no body and makes no assertions.
    #[test]
    fn check_ndarray_strides() {}
}