1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
mod conv_2d;

pub use conv_2d::fft::Conv2DFftExt;
pub use conv_2d::Conv2DExt;

/// How much padding to add around the input of an `N`-dimensional convolution,
/// together with the stride to use along each axis.
///
/// `Full`, `Same` and `Valid` derive their padding from the kernel size and always
/// use a stride of 1; `Custom` and `Explicit` carry both values directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PaddingSize<const N: usize> {
    /// Pad `kernel_size - 1` on both sides of every axis.
    Full,
    /// Pad `kernel_size - 1` in total per axis, split between the two sides; for an even
    /// kernel size the first entry of the per-axis pair receives the extra element.
    Same,
    /// No padding at all.
    Valid,
    /// `(pad, stride)`: pad `pad[i]` on both sides of axis `i`, with stride `stride[i]`.
    Custom([usize; N], [usize; N]),
    /// `(pad, stride)`: `pad[i]` holds the padding for the two sides of axis `i`
    /// (presumably `[before, after]` — confirm against the convolution code), with
    /// stride `stride[i]`.
    Explicit([[usize; 2]; N], [usize; N]),
}

/// Border-filling strategy used when padding the input of an `N`-dimensional convolution.
///
/// The uniform variants (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`) apply the
/// same [`BorderType`] to both sides of every axis. `Custom` gives one [`BorderType`] per
/// axis, used on both sides of that axis. `Explicit` gives a separate pair of
/// [`BorderType`]s for the two sides of each axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingMode<const N: usize, T: num::traits::NumAssign + Copy> {
    Zeros,
    Const(T),
    Reflect,
    Replicate,
    Circular,
    /// One border type per axis, applied to both sides of that axis.
    Custom([BorderType<T>; N]),
    /// One pair of border types per axis, one entry for each side of that axis.
    Explicit([[BorderType<T>; 2]; N]),
}

/// Border-filling strategy for one side of a single axis.
// `PartialEq` is always available here: `NumAssign` implies `Num`, which has
// `PartialEq` as a supertrait.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BorderType<T: num::traits::NumAssign + Copy> {
    Zeros,
    Const(T),
    Reflect,
    Replicate,
    Circular,
}

/// Fully resolved padding: explicit per-axis padding and strides, as produced by
/// `PaddingSize::unfold`.
#[derive(Debug)]
pub(crate) struct ExplicitPadding<const N: usize> {
    /// Padding for the two sides of each of the `N` axes
    /// (presumably `[before, after]` — TODO confirm against the convolution code).
    pub pad: [[usize; 2]; N],
    /// Stride along each of the `N` axes.
    pub stride: [usize; N],
}

/// Fully resolved border types: one `[BorderType; 2]` pair (one entry per side) for each
/// of the `N` axes, as produced by `PaddingMode::unfold`.
// NOTE(review): "Explict" is a misspelling of "Explicit"; renaming requires updating
// every crate-internal user at the same time.
#[derive(Debug)]
pub(crate) struct ExplictMode<const N: usize, T: num::traits::NumAssign + Copy>(
    pub [[BorderType<T>; 2]; N],
);

impl<const N: usize> PaddingSize<N> {
    /// Resolves this padding specification into explicit per-axis padding and strides.
    ///
    /// `kernel_size` gives the kernel extent along each of the `N` axes. It is only
    /// consulted by [`PaddingSize::Full`] and [`PaddingSize::Same`].
    ///
    /// * `Full` pads `k - 1` on both sides of each axis, with stride 1.
    /// * `Same` pads `k - 1` in total per axis, split as `[k / 2, (k - 1) / 2]`, so for
    ///   an even kernel the first entry of the pair gets the extra element; stride 1.
    /// * `Valid` adds no padding, with stride 1.
    /// * `Custom(pads, strides)` pads `pads[i]` on both sides of axis `i`.
    /// * `Explicit(pad, stride)` is returned unchanged.
    ///
    /// A zero-sized kernel axis yields zero padding on that axis. The original code
    /// underflowed `k - 1` there: it panicked in debug builds and wrapped to an
    /// enormous padding in release builds.
    pub(crate) fn unfold(self, kernel_size: &[usize; N]) -> ExplicitPadding<N> {
        // Modes derived from the kernel shape always use unit stride.
        let unit_stride = [1; N];

        match self {
            PaddingSize::Full => ExplicitPadding {
                pad: kernel_size.map(|k| [k.saturating_sub(1); 2]),
                stride: unit_stride,
            },
            PaddingSize::Same => ExplicitPadding {
                // Equivalent to splitting `k - 1` as `[ceil((k-1)/2), floor((k-1)/2)]`.
                pad: kernel_size.map(|k| [k / 2, k.saturating_sub(1) / 2]),
                stride: unit_stride,
            },
            PaddingSize::Valid => ExplicitPadding {
                pad: [[0; 2]; N],
                stride: unit_stride,
            },
            PaddingSize::Custom(pads, strides) => ExplicitPadding {
                pad: pads.map(|pad| [pad; 2]),
                stride: strides,
            },
            PaddingSize::Explicit(pad, stride) => ExplicitPadding { pad, stride },
        }
    }
}

impl<const N: usize, T: num::traits::NumAssign + Copy> PaddingMode<N, T> {
    pub(crate) fn unfold(self) -> ExplictMode<N, T> {
        match self {
            PaddingMode::Zeros => ExplictMode([[BorderType::Zeros; 2]; N]),
            PaddingMode::Const(num) => ExplictMode([[BorderType::Const(num); 2]; N]),
            PaddingMode::Reflect => ExplictMode([[BorderType::Reflect; 2]; N]),
            PaddingMode::Replicate => ExplictMode([[BorderType::Replicate; 2]; N]),
            PaddingMode::Circular => ExplictMode([[BorderType::Circular; 2]; N]),
            PaddingMode::Custom(borders) => ExplictMode(borders.map(|border| [border; 2])),
            PaddingMode::Explicit(borders) => ExplictMode(borders),
        }
    }
}