1use ndarray::{
4 ArrayBase, Data, Dim, Dimension, IntoDimension, Ix, RawData, SliceArg, SliceInfo, SliceInfoElem,
5};
6
7pub struct KernelWithDilation<'a, S: RawData, const N: usize> {
9 pub(crate) kernel: &'a ArrayBase<S, Dim<[Ix; N]>>,
10 pub(crate) dilation: [usize; N],
11 pub(crate) reverse: bool,
12}
13
14impl<'a, S: RawData, const N: usize, T> KernelWithDilation<'a, S, N>
15where
16 T: num::traits::NumAssign + Copy,
17 S: Data<Elem = T>,
18 Dim<[Ix; N]>: Dimension,
19 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
20 SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
21{
22 pub fn gen_offset_list(&self, pds_strides: &[isize]) -> Vec<(isize, T)> {
35 let buffer_slice = self.kernel.slice(unsafe {
36 SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
37 start: 0,
38 end: Some(self.kernel.raw_dim()[i] as isize),
39 step: if self.reverse { -1 } else { 1 },
40 }))
41 .unwrap()
42 });
43
44 let strides: [isize; N] =
45 std::array::from_fn(|i| self.dilation[i] as isize * pds_strides[i]);
46
47 buffer_slice
48 .indexed_iter()
49 .filter(|(_, v)| **v != T::zero())
50 .map(|(index, v)| {
51 let index = index.into_dimension();
52 (
53 (0..N)
54 .map(|n| index[n] as isize * strides[n])
55 .sum::<isize>(),
56 *v,
57 )
58 })
59 .collect()
60 }
61}
62
/// Conversion into a per-axis dilation array for an `N`-dimensional kernel.
///
/// In this module it is implemented for a single `usize` (the same dilation
/// on every axis) and for `[usize; N]` (an explicit dilation per axis).
pub trait IntoDilation<const N: usize> {
    /// Produces one dilation factor for each of the `N` axes.
    fn into_dilation(self) -> [usize; N];
}
67
68impl<const N: usize> IntoDilation<N> for usize {
69 #[inline]
70 fn into_dilation(self) -> [usize; N] {
71 [self; N]
72 }
73}
74
75impl<const N: usize> IntoDilation<N> for [usize; N] {
76 #[inline]
77 fn into_dilation(self) -> [usize; N] {
78 self
79 }
80}
81
82pub trait WithDilation<S: RawData, const N: usize> {
104 fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<'_, S, N>;
115}
116
117impl<S: RawData, const N: usize> WithDilation<S, N> for ArrayBase<S, Dim<[Ix; N]>> {
118 #[inline]
119 fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<'_, S, N> {
120 KernelWithDilation {
121 kernel: self,
122 dilation: dilation.into_dilation(),
123 reverse: true,
124 }
125 }
126}
127
128pub trait ReverseKernel<'a, S: RawData, const N: usize> {
156 fn reverse(self) -> KernelWithDilation<'a, S, N>;
160
161 fn no_reverse(self) -> KernelWithDilation<'a, S, N>;
166}
167
168impl<'a, S: RawData, K, const N: usize> ReverseKernel<'a, S, N> for K
169where
170 K: IntoKernelWithDilation<'a, S, N>,
171{
172 #[inline]
173 fn reverse(self) -> KernelWithDilation<'a, S, N> {
174 let mut kwd = self.into_kernel_with_dilation();
175
176 kwd.reverse = true;
177
178 kwd
179 }
180
181 #[inline]
182 fn no_reverse(self) -> KernelWithDilation<'a, S, N> {
183 let mut kwd = self.into_kernel_with_dilation();
184
185 kwd.reverse = false;
186
187 kwd
188 }
189}
190
191pub trait IntoKernelWithDilation<'a, S: RawData, const N: usize> {
193 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N>;
194}
195
196impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
197 for &'a ArrayBase<S, Dim<[Ix; N]>>
198{
199 #[inline]
200 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
201 self.with_dilation(1)
202 }
203}
204
205impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
206 for KernelWithDilation<'a, S, N>
207{
208 #[inline]
209 fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
210 self
211 }
212}
213
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    mod trait_implementation {
        use super::*;

        #[test]
        fn check_trait_impl() {
            // Compile-time check: every supported way of passing a kernel is
            // accepted by a function generic over `IntoKernelWithDilation`.
            fn conv_example<'a, S: RawData + 'a, const N: usize>(
                kernel: impl IntoKernelWithDilation<'a, S, N>,
            ) {
                let _ = kernel.into_kernel_with_dilation();
            }

            let kernel = array![1, 0, 1];
            conv_example(&kernel);

            let kernel = array![1, 0, 1];
            conv_example(kernel.with_dilation(2));

            let kernel = array![[1, 0, 1], [0, 1, 0]];
            conv_example(kernel.with_dilation([1, 2]));

            conv_example(&kernel);
            conv_example(kernel.reverse());
            conv_example(kernel.with_dilation(2).no_reverse());
        }
    }

    mod basic_api {
        use super::*;

        #[test]
        fn dilation_and_reverse_settings() {
            let kernel = array![1, 2, 3];

            assert_eq!(kernel.with_dilation(2).dilation, [2]);
            assert_eq!(array![[1, 2]].with_dilation([2, 3]).dilation, [2, 3]);
            assert_eq!(array![[[1]]].with_dilation([1, 2, 3]).dilation, [1, 2, 3]);

            assert!(kernel.with_dilation(1).reverse);
            assert!(!kernel.with_dilation(1).no_reverse().reverse);
            assert!(kernel.with_dilation(1).no_reverse().reverse().reverse);
        }
    }

    mod offset_generation {
        use super::*;

        #[test]
        fn gen_offset_1d_no_dilation() {
            let kernel = array![1.0, 2.0, 3.0];
            let kwd = kernel.with_dilation(1);

            let offsets = kwd.gen_offset_list(&[1]);

            assert_eq!(offsets.len(), 3);
            assert_eq!(offsets[0], (0, 3.0));
            assert_eq!(offsets[1], (1, 2.0));
            assert_eq!(offsets[2], (2, 1.0));
        }

        #[test]
        fn gen_offset_1d_with_dilation() {
            let kernel = array![1.0, 2.0, 3.0];
            let kwd = kernel.with_dilation(2);

            let offsets = kwd.gen_offset_list(&[1]);

            assert_eq!(offsets.len(), 3);
            assert_eq!(offsets[0], (0, 3.0));
            assert_eq!(offsets[1], (2, 2.0));
            assert_eq!(offsets[2], (4, 1.0));
        }

        #[test]
        fn gen_offset_1d_non_unit_stride() {
            let kernel = array![1.0, 2.0, 3.0];
            let kwd = kernel.with_dilation(2);

            // Consecutive taps are `dilation * stride = 14` elements apart.
            let offsets = kwd.gen_offset_list(&[7]);

            assert_eq!(offsets.len(), 3);
            assert_eq!(offsets[0], (0, 3.0));
            assert_eq!(offsets[1], (14, 2.0));
            assert_eq!(offsets[2], (28, 1.0));
        }

        #[test]
        fn gen_offset_1d_no_reverse() {
            let kernel = array![1.0, 2.0, 3.0];
            let kwd = kernel.with_dilation(2).no_reverse();

            let offsets = kwd.gen_offset_list(&[1]);

            assert_eq!(offsets.len(), 3);
            assert_eq!(offsets[0], (0, 1.0));
            assert_eq!(offsets[1], (2, 2.0));
            assert_eq!(offsets[2], (4, 3.0));
        }

        #[test]
        fn gen_offset_2d_no_dilation() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let kwd = kernel.with_dilation(1);

            let offsets = kwd.gen_offset_list(&[10, 1]);

            assert_eq!(offsets.len(), 4);
            assert_eq!(offsets[0], (0, 4.0));
            assert_eq!(offsets[1], (1, 3.0));
            assert_eq!(offsets[2], (10, 2.0));
            assert_eq!(offsets[3], (11, 1.0));
        }

        #[test]
        fn gen_offset_2d_with_dilation() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let kwd = kernel.with_dilation([2, 3]);

            let offsets = kwd.gen_offset_list(&[10, 1]);

            assert_eq!(offsets.len(), 4);
            assert_eq!(offsets[0], (0, 4.0));
            assert_eq!(offsets[1], (3, 3.0));
            assert_eq!(offsets[2], (20, 2.0));
            assert_eq!(offsets[3], (23, 1.0));
        }

        #[test]
        fn gen_offset_2d_no_reverse() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let kwd = kernel.with_dilation([2, 3]).no_reverse();

            // Effective strides are [2 * 10, 3 * 1] = [20, 3].
            let offsets = kwd.gen_offset_list(&[10, 1]);

            assert_eq!(offsets.len(), 4);
            assert_eq!(offsets[0], (0, 1.0));
            assert_eq!(offsets[1], (3, 2.0));
            assert_eq!(offsets[2], (20, 3.0));
            assert_eq!(offsets[3], (23, 4.0));
        }

        #[test]
        fn gen_offset_filters_zeros() {
            let kernel = array![1.0, 0.0, 2.0, 0.0, 3.0];
            let kwd = kernel.with_dilation(1);

            let offsets = kwd.gen_offset_list(&[1]);

            // Reversed kernel is [3, 0, 2, 0, 1]; zero taps are dropped but keep
            // their slot, so the surviving offsets are not compacted.
            assert_eq!(offsets.len(), 3);
            assert_eq!(offsets[0], (0, 3.0));
            assert_eq!(offsets[1], (2, 2.0));
            assert_eq!(offsets[2], (4, 1.0));
        }
    }

    mod edge_cases {
        use super::*;

        #[test]
        fn single_element_kernel() {
            let kernel = array![42.0];
            let kwd = kernel.with_dilation(5);

            assert_eq!(kwd.dilation, [5]);

            let offsets = kwd.gen_offset_list(&[1]);
            assert_eq!(offsets.len(), 1);
            assert_eq!(offsets[0], (0, 42.0));
        }

        #[test]
        fn all_zeros_kernel() {
            let kernel = array![0.0, 0.0, 0.0];
            let kwd = kernel.with_dilation(2);

            let offsets = kwd.gen_offset_list(&[1]);
            assert_eq!(offsets.len(), 0);
        }

        #[test]
        fn large_dilation_value() {
            let kernel = array![1_i32, 2];
            let kwd = kernel.with_dilation(100);

            assert_eq!(kwd.dilation, [100]);

            // Reversed kernel is [2, 1]; the second tap lands 100 elements away.
            let offsets = kwd.gen_offset_list(&[1]);
            assert_eq!(offsets.len(), 2);
            assert_eq!(offsets[0], (0, 2));
            assert_eq!(offsets[1], (100, 1));
        }

        #[test]
        fn asymmetric_2d_dilation() {
            let kernel = array![[1_i32, 2, 3], [4, 5, 6]];
            let kwd = kernel.with_dilation([1, 5]);

            assert_eq!(kwd.dilation, [1, 5]);

            // Effective strides are [1 * 100, 5 * 1] = [100, 5]; the reversed
            // kernel is [[6, 5, 4], [3, 2, 1]].
            let offsets = kwd.gen_offset_list(&[100, 1]);
            assert_eq!(
                offsets,
                vec![(0, 6), (5, 5), (10, 4), (100, 3), (105, 2), (110, 1)]
            );
        }

        #[test]
        #[should_panic]
        fn too_few_strides_panics() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let _ = kernel.with_dilation(1).gen_offset_list(&[1]);
        }
    }

    mod integration_with_padding {
        use super::*;

        #[test]
        fn effective_kernel_size_calculation() {
            let kernel = array![1_i32, 2, 3];

            for (dilation, expected) in [(1_usize, 3_usize), (2, 5), (3, 7)] {
                let kwd = kernel.with_dilation(dilation);
                let effective_size =
                    kernel.len() + (kernel.len() - 1) * (kwd.dilation[0] - 1);
                assert_eq!(effective_size, expected);

                // The generated offsets must span exactly the dilated kernel size.
                let offsets = kwd.gen_offset_list(&[1]);
                let span = offsets
                    .iter()
                    .map(|&(offset, _)| offset)
                    .max()
                    .expect("kernel has non-zero taps")
                    + 1;
                assert_eq!(span, expected as isize);
            }
        }
    }
}