1use ndarray::{
4 ArrayBase, Data, Dim, Dimension, IntoDimension, Ix, RawData, SliceArg, SliceInfo, SliceInfoElem,
5};
6
7pub struct KernelWithDilation<'a, S: RawData, const N: usize> {
9 pub(crate) kernel: &'a ArrayBase<S, Dim<[Ix; N]>>,
10 pub(crate) dilation: [usize; N],
11 pub(crate) reverse: bool,
12}
13
14impl<'a, S: RawData, const N: usize, T> KernelWithDilation<'a, S, N>
15where
16 T: num::traits::NumAssign + Copy,
17 S: Data<Elem = T>,
18 Dim<[Ix; N]>: Dimension,
19 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
20 SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
21{
22 pub fn gen_offset_list(&self, pds_strides: &[isize]) -> Vec<(isize, T)> {
35 let buffer_slice = self.kernel.slice(unsafe {
36 SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
37 start: 0,
38 end: Some(self.kernel.raw_dim()[i] as isize),
39 step: if self.reverse { -1 } else { 1 },
40 }))
41 .unwrap()
42 });
43
44 let strides: [isize; N] =
45 std::array::from_fn(|i| self.dilation[i] as isize * pds_strides[i]);
46
47 buffer_slice
48 .indexed_iter()
49 .filter(|(_, v)| **v != T::zero())
50 .map(|(index, v)| {
51 let index = index.into_dimension();
52 (
53 (0..N)
54 .map(|n| index[n] as isize * strides[n])
55 .sum::<isize>(),
56 *v,
57 )
58 })
59 .collect()
60 }
61}
62
/// Conversion into a per-axis dilation array of length `N`.
///
/// Implemented for a single `usize` (one factor shared by every axis) and for
/// `[usize; N]` (an explicit factor for each axis).
pub trait IntoDilation<const N: usize> {
    /// Returns the dilation factor for each of the `N` axes.
    fn into_dilation(self) -> [usize; N];
}

/// A scalar dilation is repeated across all `N` axes.
impl<const N: usize> IntoDilation<N> for usize {
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        std::array::from_fn(|_| self)
    }
}

/// An explicit per-axis dilation array is returned unchanged.
impl<const N: usize> IntoDilation<N> for [usize; N] {
    #[inline]
    fn into_dilation(self) -> Self {
        self
    }
}
81
82pub trait WithDilation<S: RawData, const N: usize> {
84 fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<S, N>;
85}
86
87impl<S: RawData, const N: usize> WithDilation<S, N> for ArrayBase<S, Dim<[Ix; N]>> {
88 #[inline]
89 fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<S, N> {
90 KernelWithDilation {
91 kernel: self,
92 dilation: dilation.into_dilation(),
93 reverse: true,
94 }
95 }
96}
97
98pub trait ReverseKernel<'a, S: RawData, const N: usize> {
99 fn reverse(self) -> KernelWithDilation<'a, S, N>;
100 fn no_reverse(self) -> KernelWithDilation<'a, S, N>;
101}
102
103impl<'a, S: RawData, K, const N: usize> ReverseKernel<'a, S, N> for K
104where
105 K: IntoKernelWithDilation<'a, S, N>,
106{
107 #[inline]
108 fn reverse(self) -> KernelWithDilation<'a, S, N> {
109 let mut kwd = self.into_kernel_with_dilation();
110
111 kwd.reverse = true;
112
113 kwd
114 }
115
116 #[inline]
117 fn no_reverse(self) -> KernelWithDilation<'a, S, N> {
118 let mut kwd = self.into_kernel_with_dilation();
119
120 kwd.reverse = false;
121
122 kwd
123 }
124}
125
126pub trait IntoKernelWithDilation<'a, S: RawData, const N: usize> {
128 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N>;
129}
130
131impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
132 for &'a ArrayBase<S, Dim<[Ix; N]>>
133{
134 #[inline]
135 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
136 self.with_dilation(1)
137 }
138}
139
140impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
141 for KernelWithDilation<'a, S, N>
142{
143 #[inline]
144 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
145 self
146 }
147}
148
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    /// Checks that every supported kernel form is accepted by an
    /// `impl IntoKernelWithDilation` parameter.
    #[test]
    fn check_trait_impl() {
        fn conv_example<'a, S: RawData + 'a, const N: usize>(
            kernel: impl IntoKernelWithDilation<'a, S, N>,
        ) {
            let _ = kernel.into_kernel_with_dilation();
        }

        let kernel = array![1, 0, 1];

        conv_example(&kernel);

        let kernel = array![1, 0, 1];

        conv_example(kernel.with_dilation(2));

        let kernel = array![[1, 0, 1], [0, 1, 0]];

        conv_example(kernel.with_dilation([1, 2]));

        conv_example(&kernel);

        conv_example(kernel.reverse());

        conv_example(kernel.with_dilation(2).no_reverse());
    }

    #[test]
    fn offsets_1d_reversed_by_default() {
        let kernel = array![1, 0, 2];
        // Flipped view is [2, 0, 1]; index * dilation(2) * stride(1) gives offsets 0 and 4.
        let expected: Vec<(isize, i32)> = vec![(0, 2), (4, 1)];
        assert_eq!(kernel.with_dilation(2).gen_offset_list(&[1]), expected);
    }

    #[test]
    fn offsets_1d_no_reverse() {
        let kernel = array![1, 0, 2];
        let expected: Vec<(isize, i32)> = vec![(0, 1), (4, 2)];
        assert_eq!(
            kernel.with_dilation(2).no_reverse().gen_offset_list(&[1]),
            expected
        );
    }

    #[test]
    fn offsets_2d_per_axis_dilation() {
        let kernel = array![[1, 0], [0, 3]];
        // Effective strides are [1 * 10, 2 * 1] = [10, 2], so index (1, 1) maps to 12.
        let reversed: Vec<(isize, i32)> = vec![(0, 3), (12, 1)];
        let forward: Vec<(isize, i32)> = vec![(0, 1), (12, 3)];
        assert_eq!(
            kernel.with_dilation([1, 2]).gen_offset_list(&[10, 1]),
            reversed
        );
        assert_eq!(
            kernel
                .with_dilation([1, 2])
                .no_reverse()
                .gen_offset_list(&[10, 1]),
            forward
        );
    }

    #[test]
    fn zero_elements_are_skipped() {
        let kernel = array![0, 0, 0];
        assert!(kernel.with_dilation(1).gen_offset_list(&[1]).is_empty());
    }
}
183}