1use ndarray::{
4 ArrayBase, Data, Dim, Dimension, IntoDimension, Ix, RawData, SliceArg, SliceInfo, SliceInfoElem,
5};
6
7pub struct KernelWithDilation<'a, S: RawData, const N: usize> {
9 pub(crate) kernel: &'a ArrayBase<S, Dim<[Ix; N]>>,
10 pub(crate) dilation: [usize; N],
11 pub(crate) reverse: bool,
12}
13
14impl<'a, S: RawData, const N: usize, T> KernelWithDilation<'a, S, N>
15where
16 T: num::traits::NumAssign + Copy,
17 S: Data<Elem = T>,
18 Dim<[Ix; N]>: Dimension,
19 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
20 SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
21{
22 pub fn gen_offset_list(&self, pds_strides: &[isize]) -> Vec<(isize, T)> {
35 let buffer_slice = self.kernel.slice(unsafe {
36 SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
37 start: 0,
38 end: Some(self.kernel.raw_dim()[i] as isize),
39 step: if self.reverse { -1 } else { 1 },
40 }))
41 .unwrap()
42 });
43
44 let strides: [isize; N] =
45 std::array::from_fn(|i| self.dilation[i] as isize * pds_strides[i]);
46
47 buffer_slice
48 .indexed_iter()
49 .filter(|(_, v)| **v != T::zero())
50 .map(|(index, v)| {
51 let index = index.into_dimension();
52 (
53 (0..N)
54 .map(|n| index[n] as isize * strides[n])
55 .sum::<isize>(),
56 *v,
57 )
58 })
59 .collect()
60 }
61}
62
/// Conversion into a per-axis dilation array for an `N`-dimensional kernel.
///
/// Implemented for a single `usize` (the same dilation on every axis) and for
/// `[usize; N]` (an explicit dilation per axis).
pub trait IntoDilation<const N: usize> {
    /// Produces one dilation value for each of the `N` axes.
    fn into_dilation(self) -> [usize; N];
}

/// A scalar dilation is repeated across all `N` axes.
impl<const N: usize> IntoDilation<N> for usize {
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        std::array::from_fn(|_| self)
    }
}

/// An explicit per-axis dilation is passed through unchanged.
impl<const N: usize> IntoDilation<N> for [usize; N] {
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        self
    }
}
81
82pub trait WithDilation<S: RawData, const N: usize> {
84 fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<'_, S, N>;
85}
86
87impl<S: RawData, const N: usize> WithDilation<S, N> for ArrayBase<S, Dim<[Ix; N]>> {
88 #[inline]
89 fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<'_, S, N> {
90 KernelWithDilation {
91 kernel: self,
92 dilation: dilation.into_dilation(),
93 reverse: true,
94 }
95 }
96}
97
98pub trait ReverseKernel<'a, S: RawData, const N: usize> {
99 fn reverse(self) -> KernelWithDilation<'a, S, N>;
100 fn no_reverse(self) -> KernelWithDilation<'a, S, N>;
101}
102
103impl<'a, S: RawData, K, const N: usize> ReverseKernel<'a, S, N> for K
104where
105 K: IntoKernelWithDilation<'a, S, N>,
106{
107 #[inline]
108 fn reverse(self) -> KernelWithDilation<'a, S, N> {
109 let mut kwd = self.into_kernel_with_dilation();
110
111 kwd.reverse = true;
112
113 kwd
114 }
115
116 #[inline]
117 fn no_reverse(self) -> KernelWithDilation<'a, S, N> {
118 let mut kwd = self.into_kernel_with_dilation();
119
120 kwd.reverse = false;
121
122 kwd
123 }
124}
125
126pub trait IntoKernelWithDilation<'a, S: RawData, const N: usize> {
128 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N>;
129}
130
131impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
132 for &'a ArrayBase<S, Dim<[Ix; N]>>
133{
134 #[inline]
135 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
136 self.with_dilation(1)
137 }
138}
139
140impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
141 for KernelWithDilation<'a, S, N>
142{
143 #[inline]
144 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
145 self
146 }
147}
148
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    mod trait_implementation {
        use super::*;

        /// Every supported way of passing a kernel must satisfy `IntoKernelWithDilation`.
        #[test]
        fn check_trait_impl() {
            fn conv_example<'a, S: RawData + 'a, const N: usize>(
                kernel: impl IntoKernelWithDilation<'a, S, N>,
            ) {
                let _ = kernel.into_kernel_with_dilation();
            }

            let kernel = array![1, 0, 1];
            conv_example(&kernel);
            conv_example(kernel.with_dilation(2));

            let kernel = array![[1, 0, 1], [0, 1, 0]];
            conv_example(kernel.with_dilation([1, 2]));
            conv_example(&kernel);
            conv_example(kernel.reverse());
            conv_example(kernel.with_dilation(2).no_reverse());
        }
    }

    mod basic_api {
        use super::*;

        #[test]
        fn dilation_and_reverse_settings() {
            let kernel = array![1, 2, 3];

            assert_eq!(kernel.with_dilation(2).dilation, [2]);
            assert_eq!(array![[1, 2]].with_dilation([2, 3]).dilation, [2, 3]);
            assert_eq!(array![[[1]]].with_dilation([1, 2, 3]).dilation, [1, 2, 3]);

            // Reversal is on by default and can be toggled either way.
            assert!(kernel.with_dilation(1).reverse);
            assert!(!kernel.with_dilation(1).no_reverse().reverse);
            assert!(kernel.with_dilation(1).no_reverse().reverse().reverse);
        }
    }

    mod offset_generation {
        use super::*;

        #[test]
        fn gen_offset_1d_no_dilation() {
            let kernel = array![1.0, 2.0, 3.0];
            let offsets = kernel.with_dilation(1).gen_offset_list(&[1]);

            // Reversed: the last weight lands at offset 0.
            assert_eq!(offsets, vec![(0, 3.0), (1, 2.0), (2, 1.0)]);
        }

        #[test]
        fn gen_offset_1d_with_dilation() {
            let kernel = array![1.0, 2.0, 3.0];
            let offsets = kernel.with_dilation(2).gen_offset_list(&[1]);

            assert_eq!(offsets, vec![(0, 3.0), (2, 2.0), (4, 1.0)]);
        }

        #[test]
        fn gen_offset_1d_no_reverse() {
            let kernel = array![1.0, 2.0, 3.0];
            let offsets = kernel.with_dilation(2).no_reverse().gen_offset_list(&[1]);

            assert_eq!(offsets, vec![(0, 1.0), (2, 2.0), (4, 3.0)]);
        }

        #[test]
        fn gen_offset_2d_no_dilation() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let offsets = kernel.with_dilation(1).gen_offset_list(&[10, 1]);

            assert_eq!(offsets, vec![(0, 4.0), (1, 3.0), (10, 2.0), (11, 1.0)]);
        }

        #[test]
        fn gen_offset_2d_with_dilation() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let offsets = kernel.with_dilation([2, 3]).gen_offset_list(&[10, 1]);

            // Effective strides are [2 * 10, 3 * 1] = [20, 3].
            assert_eq!(offsets, vec![(0, 4.0), (3, 3.0), (20, 2.0), (23, 1.0)]);
        }

        #[test]
        fn gen_offset_2d_no_reverse() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];

            let offsets = kernel.with_dilation(1).no_reverse().gen_offset_list(&[10, 1]);
            assert_eq!(offsets, vec![(0, 1.0), (1, 2.0), (10, 3.0), (11, 4.0)]);

            let offsets = kernel
                .with_dilation([2, 3])
                .no_reverse()
                .gen_offset_list(&[10, 1]);
            assert_eq!(offsets, vec![(0, 1.0), (3, 2.0), (20, 3.0), (23, 4.0)]);
        }

        #[test]
        fn gen_offset_filters_zeros() {
            let kernel = array![1.0, 0.0, 2.0, 0.0, 3.0];

            // Zero weights are dropped, but the surviving offsets keep their positions.
            let offsets = kernel.with_dilation(1).gen_offset_list(&[1]);
            assert_eq!(offsets, vec![(0, 3.0), (2, 2.0), (4, 1.0)]);

            let offsets = kernel.with_dilation(1).no_reverse().gen_offset_list(&[1]);
            assert_eq!(offsets, vec![(0, 1.0), (2, 2.0), (4, 3.0)]);
        }

        #[test]
        fn gen_offset_ignores_extra_strides() {
            let kernel = array![1.0, 2.0];
            let offsets = kernel.with_dilation(1).gen_offset_list(&[1, 999]);

            assert_eq!(offsets, vec![(0, 2.0), (1, 1.0)]);
        }

        #[test]
        #[should_panic]
        fn gen_offset_panics_on_too_few_strides() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let _ = kernel.with_dilation(1).gen_offset_list(&[1]);
        }
    }

    mod edge_cases {
        use super::*;

        #[test]
        fn single_element_kernel() {
            let kernel = array![42.0];
            let kwd = kernel.with_dilation(5);

            assert_eq!(kwd.dilation, [5]);
            assert_eq!(kwd.gen_offset_list(&[1]), vec![(0, 42.0)]);
        }

        #[test]
        fn all_zeros_kernel() {
            let kernel = array![0.0, 0.0, 0.0];
            let offsets = kernel.with_dilation(2).gen_offset_list(&[1]);

            assert!(offsets.is_empty());
        }

        #[test]
        fn large_dilation_value() {
            let kernel = array![1, 2];
            let kwd = kernel.with_dilation(100);

            assert_eq!(kwd.dilation, [100]);
            assert_eq!(kwd.gen_offset_list(&[1]), vec![(0, 2), (100, 1)]);
        }

        #[test]
        fn asymmetric_2d_dilation() {
            let kernel = array![[1, 2, 3], [4, 5, 6]];
            let kwd = kernel.with_dilation([1, 5]);

            assert_eq!(kwd.dilation, [1, 5]);
            // Effective strides are [1 * 100, 5 * 1] = [100, 5].
            assert_eq!(
                kwd.gen_offset_list(&[100, 1]),
                vec![(0, 6), (5, 5), (10, 4), (100, 3), (105, 2), (110, 1)]
            );
        }
    }

    mod integration_with_padding {
        use super::*;

        /// Effective (dilated) extent: `len + (len - 1) * (dilation - 1)`.
        #[test]
        fn effective_kernel_size_calculation() {
            let kernel = array![1, 2, 3];

            for (dilation, expected) in [(1, 3), (2, 5), (3, 7)] {
                let kwd = kernel.with_dilation(dilation);
                let effective = kernel.len() + (kernel.len() - 1) * (kwd.dilation[0] - 1);
                assert_eq!(effective, expected);
            }
        }
    }
}