//! `ndarray_conv` crate root (`lib.rs`).
mod conv;
mod conv_fft;
mod dilation;
mod padding;

pub(crate) use padding::ExplicitPadding;

pub use conv::ConvExt;
pub use conv_fft::{ConvFFTExt, Processor as FftProcessor};
pub use dilation::WithDilation;

/// Output-size / padding policy for an `N`-dimensional convolution.
///
/// NOTE(review): `Full`, `Same` and `Valid` presumably follow the usual
/// NumPy/SciPy convolution conventions — confirm against the padding logic
/// in the `conv` / `padding` modules before relying on exact output shapes.
///
/// Besides `Debug`/`Clone`/`Copy`, the mode is comparable and hashable so
/// callers can test for a particular mode (e.g. `mode == ConvMode::Same`)
/// or use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvMode<const N: usize> {
    /// "Full" convolution mode (see the note on the enum).
    Full,
    /// "Same" convolution mode (see the note on the enum).
    Same,
    /// "Valid" convolution mode (see the note on the enum).
    Valid,
    /// User-specified padding and strides, one entry per axis.
    ///
    /// `padding[i]` is a single padding amount for axis `i` (presumably
    /// applied to both sides of that axis — confirm in `padding`), and
    /// `strides[i]` is the stride along axis `i`.
    Custom {
        padding: [usize; N],
        strides: [usize; N],
    },
    /// Like `Custom`, but with an independent padding amount for each of the
    /// two sides of every axis: `padding[i] == [a, b]` (presumably
    /// `[before, after]`), plus one stride per axis.
    Explicit {
        padding: [[usize; 2]; N],
        strides: [usize; N],
    },
}

/// Border-fill policy used when padding an `N`-dimensional input.
///
/// The uniform variants (`Zeros`, `Const`, `Reflect`, `Replicate`,
/// `Circular`) apply a single border handling to every side of every axis.
/// `Custom` supplies one [`BorderType`] per axis, and `Explicit` supplies a
/// pair of [`BorderType`]s per axis (presumably `[before, after]` — confirm
/// in the `padding` module).
///
/// Equality is derived so modes can be compared; `Eq` is only available when
/// the element type `T` itself is `Eq` (e.g. integers, not floats).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingMode<const N: usize, T: num::traits::NumAssign + Copy> {
    /// Fill the border with zeros.
    Zeros,
    /// Fill the border with the given constant value.
    Const(T),
    /// Reflect border handling on all sides.
    Reflect,
    /// Replicate border handling on all sides.
    Replicate,
    /// Circular border handling on all sides.
    Circular,
    /// One border type per axis, applied to both sides of that axis.
    Custom([BorderType<T>; N]),
    /// Two border types per axis, one for each side of that axis.
    Explicit([[BorderType<T>; 2]; N]),
}

/// Border-fill policy for a single axis or side.
///
/// The variants mirror the uniform variants of [`PaddingMode`]; this type is
/// the per-axis / per-side element of `PaddingMode::Custom` and
/// `PaddingMode::Explicit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BorderType<T: num::traits::NumAssign + Copy> {
    /// Fill with zeros.
    Zeros,
    /// Fill with the given constant value.
    Const(T),
    /// Reflect border handling.
    Reflect,
    /// Replicate border handling.
    Replicate,
    /// Circular border handling.
    Circular,
}

use thiserror::Error;

#[derive(Error, Debug)]
pub enum Error<const N: usize> {
    #[error("Data shape shouldn't have ZERO. {0:?}")]
    DataShape(ndarray::Dim<[ndarray::Ix; N]>),
    #[error("Kernel shape shouldn't have ZERO. {0:?}")]
    KernelShape(ndarray::Dim<[ndarray::Ix; N]>),
    #[error("ConvMode {0:?} does not match KernelWithDilation Size {1:?}")]
    MismatchShape(ConvMode<N>, [ndarray::Ix; N]),
}