1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
mod conv;
mod conv_fft;
mod dilation;
mod padding;

pub use conv::ConvExt;
pub use conv_fft::ConvFFTExt;
pub use padding::ExplicitPadding;

/// How the output extent of an `N`-dimensional convolution is determined.
///
/// Per-dimension settings are given as `[_; N]` arrays indexed by dimension.
///
/// NOTE(review): the padding/stride arithmetic itself lives in the `conv` /
/// `padding` modules; the variant descriptions below follow the conventional
/// meaning of these names and should be confirmed there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConvMode<const N: usize> {
    /// Conventionally: the output covers every position where the kernel
    /// overlaps the input at least partially.
    Full,
    /// Conventionally: the output has the same extent as the input.
    Same,
    /// Conventionally: the output covers only positions where the kernel fits
    /// entirely inside the input (no padding).
    Valid,
    /// One padding amount and one stride per dimension.
    ///
    /// Presumably `padding[d]` is applied to both sides of dimension `d`
    /// (contrast [`ConvMode::Explicit`]) — confirm against `padding`.
    Custom {
        /// Padding amount for each dimension.
        padding: [usize; N],
        /// Stride for each dimension.
        strides: [usize; N],
    },
    /// A pair of padding amounts and one stride per dimension.
    ///
    /// Presumably `padding[d] == [before, after]` for dimension `d` — confirm
    /// against `padding`.
    Explicit {
        /// Two padding amounts (one per side) for each dimension.
        padding: [[usize; 2]; N],
        /// Stride for each dimension.
        strides: [usize; N],
    },
}

/// Padding (border-handling) mode for an `N`-dimensional convolution.
///
/// The uniform variants (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`)
/// apply a single border rule to every side of every dimension; their names
/// mirror the variants of [`BorderType`]. `Custom` selects one [`BorderType`]
/// per dimension, and `Explicit` selects one per side of each dimension.
#[derive(Debug, Clone, Copy)]
pub enum PaddingMode<const N: usize, T: num::traits::NumAssign + Copy> {
    /// Uniform counterpart of [`BorderType::Zeros`].
    Zeros,
    /// Uniform counterpart of [`BorderType::Const`], using this value.
    Const(T),
    /// Uniform counterpart of [`BorderType::Reflect`].
    Reflect,
    /// Uniform counterpart of [`BorderType::Replicate`].
    Replicate,
    /// Uniform counterpart of [`BorderType::Circular`].
    Circular,
    /// One [`BorderType`] per dimension; presumably applied to both sides of
    /// that dimension — NOTE(review): confirm against the `padding` module.
    Custom([BorderType<T>; N]),
    /// Two [`BorderType`]s per dimension; presumably `[before, after]` —
    /// NOTE(review): confirm against the `padding` module.
    Explicit([[BorderType<T>; 2]; N]),
}

/// Border-handling rule for a single dimension (or a single side of one).
///
/// Variant names follow the common convolution padding-mode vocabulary
/// (cf. PyTorch `padding_mode`). NOTE(review): the exact index arithmetic is
/// implemented in the `padding` module — confirm the descriptions below there.
///
/// `PartialEq` is always available because `NumAssign` implies `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BorderType<T: num::traits::NumAssign + Copy> {
    /// Pad with zero.
    Zeros,
    /// Pad with the given constant value.
    Const(T),
    /// Pad by mirroring the input across its edge.
    Reflect,
    /// Pad by repeating the edge value.
    Replicate,
    /// Pad by wrapping around to the opposite edge.
    Circular,
}