use ndarray::{
ArrayBase, Data, Dim, Dimension, IntoDimension, Ix, RawData, SliceArg, SliceInfo, SliceInfoElem,
};
/// A borrowed kernel array bundled with a per-axis dilation and a flag that
/// controls whether the kernel is traversed in reverse.
///
/// The kernel data is only borrowed, so this type is cheap to copy and can be
/// reused for several operations.
pub struct KernelWithDilation<'a, S: RawData, const N: usize> {
    /// The kernel weights (borrowed, not copied).
    pub(crate) kernel: &'a ArrayBase<S, Dim<[Ix; N]>>,
    /// Multiplier applied to the caller-supplied stride of each axis when
    /// offsets are generated; `1` leaves the strides unchanged.
    pub(crate) dilation: [usize; N],
    /// When `true`, every axis of `kernel` is walked back-to-front when
    /// offsets are generated.
    pub(crate) reverse: bool,
}

// Manual impls: `#[derive(Clone, Copy)]` would add spurious `S: Clone` /
// `S: Copy` bounds, but every field (a shared reference, a `[usize; N]`, a
// `bool`) is `Copy` regardless of the storage type `S`.
impl<S: RawData, const N: usize> Clone for KernelWithDilation<'_, S, N> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<S: RawData, const N: usize> Copy for KernelWithDilation<'_, S, N> {}
impl<'a, S: RawData, const N: usize, T> KernelWithDilation<'a, S, N>
where
    T: num::traits::NumAssign + Copy,
    S: Data<Elem = T>,
    Dim<[Ix; N]>: Dimension,
    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
{
    /// Builds the sparse list of `(flat_offset, weight)` pairs for this kernel.
    ///
    /// Every non-zero kernel element contributes one entry; zero weights are
    /// skipped. An element's offset is the dot product of its multi-index with
    /// the dilated strides `dilation[axis] * pds_strides[axis]`.
    ///
    /// When `self.reverse` is `true`, every axis is traversed back-to-front, so
    /// the kernel's last element receives offset 0. Entries appear in the
    /// logical (row-major) order in which `indexed_iter` visits the
    /// possibly-reversed view.
    ///
    /// # Panics
    ///
    /// Panics if `pds_strides` has fewer than `N` entries.
    pub fn gen_offset_list(&self, pds_strides: &[isize]) -> Vec<(isize, T)> {
        assert!(
            pds_strides.len() >= N,
            "gen_offset_list: expected at least {} strides, got {}",
            N,
            pds_strides.len()
        );

        let step: isize = if self.reverse { -1 } else { 1 };
        // SAFETY: the description holds exactly N `Slice` elements and no
        // `Index`/`NewAxis` elements, so it consumes N input axes and produces
        // N output axes, consistent with the declared `Dim<[Ix; N]>` input and
        // output dimensionalities required by `SliceInfo::new`.
        let info: SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>> = unsafe {
            SliceInfo::new(std::array::from_fn(|_| SliceInfoElem::Slice {
                start: 0,
                end: None,
                step,
            }))
        }
        .expect("N full-range slices always match an N-dimensional kernel");
        let view = self.kernel.slice(info);

        let dilated_strides: [isize; N] =
            std::array::from_fn(|axis| self.dilation[axis] as isize * pds_strides[axis]);

        view.indexed_iter()
            .filter(|(_, w)| **w != T::zero())
            .map(|(pattern, w)| {
                let index = pattern.into_dimension();
                let offset = index
                    .slice()
                    .iter()
                    .zip(&dilated_strides)
                    .map(|(&i, &s)| i as isize * s)
                    .sum::<isize>();
                (offset, *w)
            })
            .collect()
    }
}
/// Conversion into a per-axis dilation array of length `N`.
///
/// A single `usize` applies the same dilation to every axis; a `[usize; N]`
/// specifies one dilation per axis.
pub trait IntoDilation<const N: usize> {
    /// Returns the dilation for each of the `N` axes.
    fn into_dilation(self) -> [usize; N];
}

/// A scalar dilation is repeated on every axis.
impl<const N: usize> IntoDilation<N> for usize {
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        std::array::from_fn(|_axis| self)
    }
}

/// An explicit per-axis dilation is passed through unchanged.
impl<const N: usize> IntoDilation<N> for [usize; N] {
    #[inline]
    fn into_dilation(self) -> [usize; N] {
        self
    }
}
/// Attaches a dilation to a kernel array, producing a [`KernelWithDilation`].
pub trait WithDilation<S: RawData, const N: usize> {
    /// Borrows `self` as a kernel with the given dilation (a scalar applies to
    /// every axis, an array gives one value per axis). The returned kernel is
    /// marked for reverse traversal.
    fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<'_, S, N>;
}

impl<S: RawData, const N: usize> WithDilation<S, N> for ArrayBase<S, Dim<[Ix; N]>> {
    #[inline]
    fn with_dilation(&self, dilation: impl IntoDilation<N>) -> KernelWithDilation<'_, S, N> {
        let per_axis = dilation.into_dilation();
        // Kernels start out reversed; callers opt out with `.no_reverse()`.
        KernelWithDilation {
            reverse: true,
            dilation: per_axis,
            kernel: self,
        }
    }
}
pub trait ReverseKernel<'a, S: RawData, const N: usize> {
fn reverse(self) -> KernelWithDilation<'a, S, N>;
fn no_reverse(self) -> KernelWithDilation<'a, S, N>;
}
impl<'a, S: RawData, K, const N: usize> ReverseKernel<'a, S, N> for K
where
K: IntoKernelWithDilation<'a, S, N>,
{
#[inline]
fn reverse(self) -> KernelWithDilation<'a, S, N> {
let mut kwd = self.into_kernel_with_dilation();
kwd.reverse = true;
kwd
}
#[inline]
fn no_reverse(self) -> KernelWithDilation<'a, S, N> {
let mut kwd = self.into_kernel_with_dilation();
kwd.reverse = false;
kwd
}
}
pub trait IntoKernelWithDilation<'a, S: RawData, const N: usize> {
fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N>;
}
impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
for &'a ArrayBase<S, Dim<[Ix; N]>>
{
#[inline]
fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
self.with_dilation(1)
}
}
impl<'a, S: RawData, const N: usize> IntoKernelWithDilation<'a, S, N>
for KernelWithDilation<'a, S, N>
{
#[inline]
fn into_kernel_with_dilation(self) -> KernelWithDilation<'a, S, N> {
self
}
}
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    // Checks that every supported kernel form is accepted where an
    // `IntoKernelWithDilation` is expected.
    mod trait_implementation {
        use super::*;

        #[test]
        fn check_trait_impl() {
            fn conv_example<'a, S: RawData + 'a, const N: usize>(
                kernel: impl IntoKernelWithDilation<'a, S, N>,
            ) {
                let _ = kernel.into_kernel_with_dilation();
            }
            let kernel = array![1, 0, 1];
            conv_example(&kernel);
            let kernel = array![1, 0, 1];
            conv_example(kernel.with_dilation(2));
            let kernel = array![[1, 0, 1], [0, 1, 0]];
            conv_example(kernel.with_dilation([1, 2]));
            conv_example(&kernel);
            conv_example(kernel.reverse());
            conv_example(kernel.with_dilation(2).no_reverse());
        }
    }

    mod basic_api {
        use super::*;

        #[test]
        fn dilation_and_reverse_settings() {
            let kernel = array![1, 2, 3];
            assert_eq!(kernel.with_dilation(2).dilation, [2]);
            assert_eq!(array![[1, 2]].with_dilation([2, 3]).dilation, [2, 3]);
            assert_eq!(array![[[1]]].with_dilation([1, 2, 3]).dilation, [1, 2, 3]);
            assert!(kernel.with_dilation(1).reverse);
            assert!(!kernel.with_dilation(1).no_reverse().reverse);
            assert!(kernel.with_dilation(1).no_reverse().reverse().reverse);
        }
    }

    mod offset_generation {
        use super::*;

        #[test]
        fn gen_offset_1d_no_dilation() {
            let kernel = array![1.0, 2.0, 3.0];
            let offsets = kernel.with_dilation(1).gen_offset_list(&[1]);
            assert_eq!(offsets, vec![(0, 3.0), (1, 2.0), (2, 1.0)]);
        }

        #[test]
        fn gen_offset_1d_with_dilation() {
            let kernel = array![1.0, 2.0, 3.0];
            let offsets = kernel.with_dilation(2).gen_offset_list(&[1]);
            assert_eq!(offsets, vec![(0, 3.0), (2, 2.0), (4, 1.0)]);
        }

        #[test]
        fn gen_offset_1d_no_reverse() {
            let kernel = array![1.0, 2.0, 3.0];
            let offsets = kernel.with_dilation(2).no_reverse().gen_offset_list(&[1]);
            assert_eq!(offsets, vec![(0, 1.0), (2, 2.0), (4, 3.0)]);
        }

        #[test]
        fn gen_offset_2d_no_dilation() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let offsets = kernel.with_dilation(1).gen_offset_list(&[10, 1]);
            assert_eq!(offsets, vec![(0, 4.0), (1, 3.0), (10, 2.0), (11, 1.0)]);
        }

        #[test]
        fn gen_offset_2d_with_dilation() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let offsets = kernel.with_dilation([2, 3]).gen_offset_list(&[10, 1]);
            assert_eq!(offsets, vec![(0, 4.0), (3, 3.0), (20, 2.0), (23, 1.0)]);
        }

        #[test]
        fn gen_offset_2d_no_reverse() {
            let kernel = array![[1.0, 2.0], [3.0, 4.0]];
            let offsets = kernel
                .with_dilation([2, 3])
                .no_reverse()
                .gen_offset_list(&[10, 1]);
            assert_eq!(offsets, vec![(0, 1.0), (3, 2.0), (20, 3.0), (23, 4.0)]);
        }

        #[test]
        fn gen_offset_filters_zeros() {
            // Zeros are dropped, but surviving taps keep their original
            // (reversed) positions rather than being compacted.
            let kernel = array![1.0, 0.0, 2.0, 0.0, 3.0];
            let offsets = kernel.with_dilation(1).gen_offset_list(&[1]);
            assert_eq!(offsets, vec![(0, 3.0), (2, 2.0), (4, 1.0)]);
        }

        #[test]
        #[should_panic]
        fn gen_offset_panics_on_short_strides() {
            let kernel = array![[1.0, 2.0]];
            let _ = kernel.with_dilation(1).gen_offset_list(&[1]);
        }
    }

    mod edge_cases {
        use super::*;

        #[test]
        fn single_element_kernel() {
            let kernel = array![42.0];
            let kwd = kernel.with_dilation(5);
            assert_eq!(kwd.dilation, [5]);
            assert_eq!(kwd.gen_offset_list(&[1]), vec![(0, 42.0)]);
        }

        #[test]
        fn all_zeros_kernel() {
            let kernel = array![0.0, 0.0, 0.0];
            let offsets = kernel.with_dilation(2).gen_offset_list(&[1]);
            assert!(offsets.is_empty());
        }

        #[test]
        fn large_dilation_value() {
            let kernel = array![1, 2];
            let kwd = kernel.with_dilation(100);
            assert_eq!(kwd.dilation, [100]);
            assert_eq!(kwd.gen_offset_list(&[1]), vec![(0, 2), (100, 1)]);
        }

        #[test]
        fn asymmetric_2d_dilation() {
            let kernel = array![[1, 2, 3], [4, 5, 6]];
            let kwd = kernel.with_dilation([1, 5]);
            assert_eq!(kwd.dilation, [1, 5]);
            // Dilated strides are [1 * 100, 5 * 1] = [100, 5].
            assert_eq!(
                kwd.gen_offset_list(&[100, 1]),
                vec![(0, 6), (5, 5), (10, 4), (100, 3), (105, 2), (110, 1)]
            );
        }
    }

    mod integration_with_padding {
        use super::*;

        #[test]
        fn effective_kernel_size_calculation() {
            let kernel = array![1, 2, 3];
            let kwd1 = kernel.with_dilation(1);
            let effective_size_1 = kernel.len() + (kernel.len() - 1) * (kwd1.dilation[0] - 1);
            assert_eq!(effective_size_1, 3);
            let kwd2 = kernel.with_dilation(2);
            let effective_size_2 = kernel.len() + (kernel.len() - 1) * (kwd2.dilation[0] - 1);
            assert_eq!(effective_size_2, 5);
            let kwd3 = kernel.with_dilation(3);
            let effective_size_3 = kernel.len() + (kernel.len() - 1) * (kwd3.dilation[0] - 1);
            assert_eq!(effective_size_3, 7);
        }
    }
}