1use burn_core as burn;
2
3use burn::config::Config;
4
/// Computes the `(start, end)` padding for "same" padding along a single dimension.
///
/// The target output size is `ceil(size_in / stride)`. The total padding is the extra input
/// needed so that `size_out` windows of width `kernel_size`, placed `stride` apart, fit
/// (never negative). When the total is odd, the extra element goes to the end, so
/// `start <= end` always holds.
///
/// # Panics
///
/// Panics if `stride` is zero.
fn calculate_same_padding(kernel_size: usize, stride: usize, size_in: usize) -> (usize, usize) {
    // `div_ceil` would otherwise fail with a bare "attempt to divide by zero".
    assert!(stride > 0, "'Same' padding requires a non-zero stride");
    let size_out = size_in.div_ceil(stride);
    let total_padding = match size_out {
        // Empty input: nothing to cover, so no padding.
        0 => 0,
        // Input span covered by `n` windows; pad only the shortfall.
        n => ((n - 1) * stride + kernel_size).saturating_sub(size_in),
    };
    let pad_start = total_padding / 2;
    (pad_start, total_padding - pad_start)
}
20
/// Padding configuration for 1D operations.
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig1d {
    /// Padding derived from the input length, kernel size and stride, targeting an output
    /// length of `ceil(length / stride)`. An odd total puts the extra element on the right.
    Same,
    /// No padding.
    Valid,
    /// Explicit padding given as `(left, right)`.
    Explicit(usize, usize),
}
33
34impl PaddingConfig1d {
35 pub(crate) fn calculate_padding_1d_pair(
38 &self,
39 length: usize,
40 kernel_size: usize,
41 stride: usize,
42 ) -> (usize, usize) {
43 match self {
44 Self::Valid => (0, 0),
45 Self::Same => calculate_same_padding(kernel_size, stride, length),
46 Self::Explicit(left, right) => (*left, *right),
47 }
48 }
49}
50
/// Padding configuration for 2D operations.
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig2d {
    /// Padding derived from the input size, kernel size and stride, targeting an output of
    /// `ceil(input / stride)` per dimension. The result may be asymmetric (e.g. even kernels).
    Same,
    /// No padding.
    Valid,
    /// Explicit padding given as `(top, left, bottom, right)`.
    /// Note the order: both "start" sides come first, then both "end" sides.
    Explicit(usize, usize, usize, usize),
}
63
64impl PaddingConfig2d {
65 pub(crate) fn calculate_padding_2d_pairs(
68 &self,
69 height: usize,
70 width: usize,
71 kernel_size: &[usize; 2],
72 stride: &[usize; 2],
73 ) -> ((usize, usize), (usize, usize)) {
74 match self {
75 Self::Valid => ((0, 0), (0, 0)),
76 Self::Same => {
77 let (top, bottom) = calculate_same_padding(kernel_size[0], stride[0], height);
78 let (left, right) = calculate_same_padding(kernel_size[1], stride[1], width);
79 ((top, bottom), (left, right))
80 }
81 Self::Explicit(top, left, bottom, right) => ((*top, *bottom), (*left, *right)),
82 }
83 }
84
85 pub(crate) fn calculate_padding_2d(
89 &self,
90 height: usize,
91 width: usize,
92 kernel_size: &[usize; 2],
93 stride: &[usize; 2],
94 ) -> [usize; 2] {
95 let ((top, bottom), (left, right)) =
96 self.calculate_padding_2d_pairs(height, width, kernel_size, stride);
97 if top != bottom || left != right {
98 panic!("Asymmetric padding should be handled via calculate_padding_2d_pairs()")
99 }
100 [top, left]
101 }
102}
103
/// Padding configuration for 3D operations.
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig3d {
    /// Padding derived from the input size, kernel size and stride, targeting an output of
    /// `ceil(input / stride)` per dimension. Only symmetric results are accepted;
    /// `calculate_padding_3d` panics otherwise.
    Same,
    /// No padding.
    Valid,
    /// Explicit padding given as one amount per dimension, `(depth, height, width)`.
    /// NOTE(review): presumably the same amount is applied to both sides of each
    /// dimension — confirm against callers.
    Explicit(usize, usize, usize),
}
115
116impl PaddingConfig3d {
117 pub(crate) fn calculate_padding_3d(
120 &self,
121 depth: usize,
122 height: usize,
123 width: usize,
124 kernel_size: &[usize; 3],
125 stride: &[usize; 3],
126 ) -> [usize; 3] {
127 match self {
128 Self::Valid => [0, 0, 0],
129 Self::Same => {
130 let (front, back) = calculate_same_padding(kernel_size[0], stride[0], depth);
131 let (top, bottom) = calculate_same_padding(kernel_size[1], stride[1], height);
132 let (left, right) = calculate_same_padding(kernel_size[2], stride[2], width);
133 if front != back || top != bottom || left != right {
134 panic!(
135 "Asymmetric 3D 'Same' padding is not supported. \
136 Use odd kernel sizes for symmetric padding."
137 )
138 }
139 [front, top, left]
140 }
141 Self::Explicit(depth, height, width) => [*depth, *height, *width],
142 }
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_same_padding_values() {
        // Odd kernel, stride 1: total padding is `kernel_size - 1`, split evenly.
        assert_eq!(calculate_same_padding(3, 1, 10), (1, 1));
        // Even kernel: odd total padding, the extra element goes to the end.
        assert_eq!(calculate_same_padding(4, 1, 10), (1, 2));
        // Stride > 1 can yield asymmetric padding even with an odd kernel.
        assert_eq!(calculate_same_padding(3, 2, 10), (0, 1));
        assert_eq!(calculate_same_padding(3, 2, 7), (1, 1));
        // No padding needed, and empty input.
        assert_eq!(calculate_same_padding(1, 1, 10), (0, 0));
        assert_eq!(calculate_same_padding(3, 1, 0), (0, 0));
    }

    #[test]
    fn test_padding_config_1d_calculate_pair_valid() {
        let padding = PaddingConfig1d::Valid;
        assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (0, 0));
    }

    #[test]
    fn test_padding_config_1d_calculate_pair_explicit() {
        let padding = PaddingConfig1d::Explicit(1, 2);
        assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (1, 2));
    }

    #[test]
    fn test_padding_config_1d_calculate_pair_same() {
        let padding = PaddingConfig1d::Same;
        assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (1, 1));
    }

    #[test]
    fn test_padding_config_1d_calculate_pair_same_even_kernel() {
        let padding = PaddingConfig1d::Same;
        assert_eq!(padding.calculate_padding_1d_pair(10, 4, 1), (1, 2));
    }

    #[test]
    fn test_padding_config_2d_calculate_pairs_valid() {
        let padding = PaddingConfig2d::Valid;
        assert_eq!(
            padding.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]),
            ((0, 0), (0, 0))
        );
    }

    #[test]
    fn test_padding_config_2d_calculate_pairs_explicit() {
        let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);
        assert_eq!(
            padding.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]),
            ((1, 3), (2, 4))
        );
    }

    #[test]
    fn test_padding_config_2d_calculate_pairs_same() {
        let padding = PaddingConfig2d::Same;
        assert_eq!(
            padding.calculate_padding_2d_pairs(10, 9, &[3, 4], &[1, 1]),
            ((1, 1), (1, 2))
        );
    }

    #[test]
    fn test_padding_config_2d_calculate_symmetric_valid() {
        let padding = PaddingConfig2d::Valid;
        assert_eq!(
            padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),
            [0, 0]
        );
    }

    #[test]
    fn test_padding_config_2d_calculate_symmetric_explicit() {
        let padding = PaddingConfig2d::Explicit(2, 3, 2, 3);
        assert_eq!(
            padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),
            [2, 3]
        );
    }

    #[test]
    fn test_padding_config_2d_calculate_symmetric_same() {
        let padding = PaddingConfig2d::Same;
        assert_eq!(
            padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),
            [1, 1]
        );
    }

    #[test]
    #[should_panic(
        expected = "Asymmetric padding should be handled via calculate_padding_2d_pairs"
    )]
    fn test_padding_config_2d_calculate_symmetric_asymmetric_panics() {
        let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);
        let _ = padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]);
    }

    #[test]
    #[should_panic(
        expected = "Asymmetric padding should be handled via calculate_padding_2d_pairs"
    )]
    fn test_padding_config_2d_calculate_symmetric_same_even_kernel_panics() {
        let padding = PaddingConfig2d::Same;
        let _ = padding.calculate_padding_2d(10, 10, &[2, 2], &[1, 1]);
    }

    #[test]
    fn test_padding_config_3d_calculate_valid() {
        let padding = PaddingConfig3d::Valid;
        assert_eq!(
            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
            [0, 0, 0]
        );
    }

    #[test]
    fn test_padding_config_3d_calculate_explicit() {
        let padding = PaddingConfig3d::Explicit(1, 2, 3);
        assert_eq!(
            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
            [1, 2, 3]
        );
    }

    #[test]
    fn test_padding_config_3d_calculate_same_odd_kernel() {
        let padding = PaddingConfig3d::Same;
        assert_eq!(
            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
            [1, 1, 1]
        );
    }

    #[test]
    fn test_padding_config_3d_calculate_same_odd_kernel_strided_symmetric() {
        let padding = PaddingConfig3d::Same;
        assert_eq!(
            padding.calculate_padding_3d(7, 7, 7, &[3, 3, 3], &[2, 2, 2]),
            [1, 1, 1]
        );
    }

    #[test]
    #[should_panic(expected = "Asymmetric 3D 'Same' padding is not supported")]
    fn test_padding_config_3d_calculate_same_even_kernel_panics() {
        let padding = PaddingConfig3d::Same;
        let _ = padding.calculate_padding_3d(10, 10, 10, &[2, 3, 3], &[1, 1, 1]);
    }

    #[test]
    #[should_panic(expected = "Asymmetric 3D 'Same' padding is not supported")]
    fn test_padding_config_3d_calculate_same_odd_kernel_strided_panics() {
        // An odd kernel alone does not guarantee symmetry once stride > 1.
        let padding = PaddingConfig3d::Same;
        let _ = padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[2, 2, 2]);
    }
}