// Padding-size and border-mode configuration types, together with their expansion into explicit per-dimension, per-side form.
mod conv_2d;
pub use conv_2d::fft::Conv2DFftExt;
pub use conv_2d::Conv2DExt;
/// How much padding to apply along each of the `N` dimensions, and (for the
/// array-carrying variants) which stride to use.
///
/// The shorthand variants are expanded into explicit per-side amounts by
/// `unfold`, which derives the amounts from the kernel size where needed.
#[derive(Debug, Clone, Copy)]
pub enum PaddingSize<const N: usize> {
/// `kernel - 1` elements of padding on both sides of every dimension; stride 1.
Full,
/// A total of `kernel - 1` elements of padding per dimension, split between
/// the two sides; for even kernel sizes the first side receives the extra
/// element. Stride 1.
Same,
/// No padding on any side; stride 1.
Valid,
/// First array: one padding amount per dimension, applied to both sides.
/// Second array: one stride per dimension.
Custom([usize; N], [usize; N]),
/// First array: two padding amounts per dimension, one for each side
/// (presumably `[leading, trailing]` — confirm against the code that
/// consumes them). Second array: one stride per dimension.
Explicit([[usize; 2]; N], [usize; N]),
}
/// Which border rule to use for the padded region of an `N`-dimensional input
/// whose elements have type `T`.
///
/// The first five variants apply one `BorderType` to both sides of every
/// dimension; `Custom` and `Explicit` allow finer-grained choices. `unfold`
/// expands any variant into one `BorderType` per side per dimension.
#[derive(Debug, Clone, Copy)]
pub enum PaddingMode<const N: usize, T: num::traits::NumAssign + Copy> {
/// Use `BorderType::Zeros` on every side.
Zeros,
/// Use `BorderType::Const` with the given value on every side.
Const(T),
/// Use `BorderType::Reflect` on every side.
Reflect,
/// Use `BorderType::Replicate` on every side.
Replicate,
/// Use `BorderType::Circular` on every side.
Circular,
/// One border rule per dimension, applied to both sides of that dimension.
Custom([BorderType<T>; N]),
/// Two border rules per dimension, one for each side.
Explicit([[BorderType<T>; 2]; N]),
}
/// Border rule for a single side of a single dimension.
///
/// NOTE(review): the code that actually fills the padded region is not in this
/// part of the file; the per-variant descriptions below are inferred from the
/// variant names and should be confirmed against that code.
#[derive(Debug, Clone, Copy)]
pub enum BorderType<T: num::traits::NumAssign + Copy> {
/// Presumably fills the border with zeros.
Zeros,
/// Carries a value of type `T`, presumably the constant used to fill the border.
Const(T),
/// Presumably mirrors the input at the edge.
Reflect,
/// Presumably repeats the edge element.
Replicate,
/// Presumably wraps around to the opposite edge of the input.
Circular,
}
/// Fully expanded form of `PaddingSize`, as produced by `PaddingSize::unfold`.
#[derive(Debug)]
pub(crate) struct ExplicitPadding<const N: usize> {
/// Two padding amounts for each of the `N` dimensions, one per side.
pub pad: [[usize; 2]; N],
/// Stride for each of the `N` dimensions.
pub stride: [usize; N],
}
/// Fully expanded form of `PaddingMode`, as produced by `PaddingMode::unfold`.
///
/// NOTE(review): the name is spelled `ExplictMode` (missing an "i"); renaming
/// it would require touching every user in the crate.
#[derive(Debug)]
pub(crate) struct ExplictMode<const N: usize, T: num::traits::NumAssign + Copy>(
/// Two border rules for each of the `N` dimensions, one per side.
pub [[BorderType<T>; 2]; N],
);
impl<const N: usize> PaddingSize<N> {
    /// Expands this padding specification into explicit per-dimension,
    /// per-side padding amounts and per-dimension strides.
    ///
    /// `kernel_size` is the kernel extent along each of the `N` dimensions. It
    /// is only consulted for [`PaddingSize::Full`] and [`PaddingSize::Same`].
    ///
    /// Results per variant:
    /// * `Full` — `kernel - 1` on both sides, stride 1.
    /// * `Same` — `kernel - 1` in total, split as `[kernel / 2, (kernel - 1) / 2]`
    ///   so that even kernels put the extra element on the first side; stride 1.
    /// * `Valid` — no padding, stride 1.
    /// * `Custom` — the given amount on both sides, with the given strides.
    /// * `Explicit` — passed through unchanged.
    ///
    /// A kernel extent of 0 yields zero padding for that dimension instead of
    /// underflowing `kernel - 1`.
    pub(crate) fn unfold(self, kernel_size: &[usize; N]) -> ExplicitPadding<N> {
        match self {
            PaddingSize::Full => ExplicitPadding {
                // saturating_sub: a zero-sized kernel must not wrap around to
                // a near-usize::MAX padding (or panic in debug builds).
                pad: kernel_size.map(|k| [k.saturating_sub(1); 2]),
                stride: [1; N],
            },
            PaddingSize::Same => ExplicitPadding {
                // For k >= 1 this equals the even/odd split
                // [(k - 1) / 2 + 1, (k - 1) / 2] (even) or [(k - 1) / 2; 2] (odd).
                pad: kernel_size.map(|k| [k / 2, k.saturating_sub(1) / 2]),
                stride: [1; N],
            },
            PaddingSize::Valid => ExplicitPadding {
                pad: [[0; 2]; N],
                stride: [1; N],
            },
            PaddingSize::Custom(pads, strides) => ExplicitPadding {
                pad: pads.map(|p| [p; 2]),
                stride: strides,
            },
            PaddingSize::Explicit(pad, stride) => ExplicitPadding { pad, stride },
        }
    }
}
impl<const N: usize, T: num::traits::NumAssign + Copy> PaddingMode<N, T> {
pub(crate) fn unfold(self) -> ExplictMode<N, T> {
match self {
PaddingMode::Zeros => ExplictMode([[BorderType::Zeros; 2]; N]),
PaddingMode::Const(num) => ExplictMode([[BorderType::Const(num); 2]; N]),
PaddingMode::Reflect => ExplictMode([[BorderType::Reflect; 2]; N]),
PaddingMode::Replicate => ExplictMode([[BorderType::Replicate; 2]; N]),
PaddingMode::Circular => ExplictMode([[BorderType::Circular; 2]; N]),
PaddingMode::Custom(borders) => ExplictMode(borders.map(|border| [border; 2])),
PaddingMode::Explicit(borders) => ExplictMode(borders),
}
}
}