Skip to main content

burn_nn/
padding.rs

1use burn_core as burn;
2
3use burn::config::Config;
4
/// Computes the `(start, end)` padding a "same" convolution needs along one axis.
///
/// The target output length is `ceil(size_in / stride)`. The total padding is the amount
/// by which the sliding windows overrun `size_in` when producing that many outputs
/// (clamped at zero), split as evenly as possible between both sides.
///
/// `start` is applied first (top/left). When the total is odd, the extra element goes to
/// `end` (bottom/right), following the ONNX convention.
///
/// # Panics
///
/// Panics if `stride` is zero, because the ceiling division divides by `stride`.
fn calculate_same_padding(kernel_size: usize, stride: usize, size_in: usize) -> (usize, usize) {
    let total = match size_in.div_ceil(stride) {
        // An empty output needs no padding at all.
        0 => 0,
        // Extent covered by `out_len` windows, minus what the input already provides.
        out_len => ((out_len - 1) * stride + kernel_size).saturating_sub(size_in),
    };

    // Integer division rounds down, so any odd remainder lands on the trailing side.
    let leading = total / 2;
    (leading, total - leading)
}
20
21/// Padding configuration for 1D operators.
22#[derive(Config, Debug, PartialEq)]
23pub enum PaddingConfig1d {
24    /// Dynamically calculates padding to ensure output size matches input size.
25    Same,
26    /// No padding applied.
27    Valid,
28    /// Applies explicit padding values.
29    /// Format: (left, right)
30    /// For symmetric padding, use the same value for both (e.g., `Explicit(1, 1)`).
31    Explicit(usize, usize),
32}
33
34impl PaddingConfig1d {
35    /// Calculate padding as (left, right) pair for 1D operations.
36    /// For `Same` padding, this computes the actual asymmetric padding if needed.
37    pub(crate) fn calculate_padding_1d_pair(
38        &self,
39        length: usize,
40        kernel_size: usize,
41        stride: usize,
42    ) -> (usize, usize) {
43        match self {
44            Self::Valid => (0, 0),
45            Self::Same => calculate_same_padding(kernel_size, stride, length),
46            Self::Explicit(left, right) => (*left, *right),
47        }
48    }
49}
50
51/// Padding configuration for 2D operators.
52#[derive(Config, Debug, PartialEq)]
53pub enum PaddingConfig2d {
54    /// Dynamically calculates padding to preserve input dimensions in output.
55    Same,
56    /// No padding applied.
57    Valid,
58    /// Applies explicit padding values.
59    /// Format: (top, left, bottom, right)
60    /// For symmetric padding, use matching values (e.g., `Explicit(1, 1, 1, 1)`).
61    Explicit(usize, usize, usize, usize),
62}
63
64impl PaddingConfig2d {
65    /// Calculate padding as ((top, bottom), (left, right)) pairs for 2D operations.
66    /// For `Same` padding, this computes the actual asymmetric padding if needed.
67    pub(crate) fn calculate_padding_2d_pairs(
68        &self,
69        height: usize,
70        width: usize,
71        kernel_size: &[usize; 2],
72        stride: &[usize; 2],
73    ) -> ((usize, usize), (usize, usize)) {
74        match self {
75            Self::Valid => ((0, 0), (0, 0)),
76            Self::Same => {
77                let (top, bottom) = calculate_same_padding(kernel_size[0], stride[0], height);
78                let (left, right) = calculate_same_padding(kernel_size[1], stride[1], width);
79                ((top, bottom), (left, right))
80            }
81            Self::Explicit(top, left, bottom, right) => ((*top, *bottom), (*left, *right)),
82        }
83    }
84
85    /// Calculate symmetric padding for 2D operations.
86    /// Returns padding values [height, width] (same for both sides).
87    /// Panics if asymmetric padding is detected.
88    pub(crate) fn calculate_padding_2d(
89        &self,
90        height: usize,
91        width: usize,
92        kernel_size: &[usize; 2],
93        stride: &[usize; 2],
94    ) -> [usize; 2] {
95        let ((top, bottom), (left, right)) =
96            self.calculate_padding_2d_pairs(height, width, kernel_size, stride);
97        if top != bottom || left != right {
98            panic!("Asymmetric padding should be handled via calculate_padding_2d_pairs()")
99        }
100        [top, left]
101    }
102}
103
104/// Padding configuration for 3D operators.
105#[derive(Config, Debug, PartialEq)]
106pub enum PaddingConfig3d {
107    /// Dynamically calculates padding to preserve input dimensions in output.
108    Same,
109    /// No padding applied.
110    Valid,
111    /// Applies explicit symmetric padding values.
112    /// Format: (depth, height, width) — same padding on both sides of each dimension.
113    Explicit(usize, usize, usize),
114}
115
116impl PaddingConfig3d {
117    /// Calculate symmetric padding for 3D operations.
118    /// Returns padding values [depth, height, width] (same for both sides).
119    pub(crate) fn calculate_padding_3d(
120        &self,
121        depth: usize,
122        height: usize,
123        width: usize,
124        kernel_size: &[usize; 3],
125        stride: &[usize; 3],
126    ) -> [usize; 3] {
127        match self {
128            Self::Valid => [0, 0, 0],
129            Self::Same => {
130                let (front, back) = calculate_same_padding(kernel_size[0], stride[0], depth);
131                let (top, bottom) = calculate_same_padding(kernel_size[1], stride[1], height);
132                let (left, right) = calculate_same_padding(kernel_size[2], stride[2], width);
133                if front != back || top != bottom || left != right {
134                    panic!(
135                        "Asymmetric 3D 'Same' padding is not supported. \
136                        Use odd kernel sizes for symmetric padding."
137                    )
138                }
139                [front, top, left]
140            }
141            Self::Explicit(depth, height, width) => [*depth, *height, *width],
142        }
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== calculate_same_padding Tests ====================

    #[test]
    fn test_calculate_same_padding_odd_kernel_is_symmetric() {
        // kernel=3, stride=1, size=10: total=2, start=1, end=1
        assert_eq!(calculate_same_padding(3, 1, 10), (1, 1));
    }

    #[test]
    fn test_calculate_same_padding_even_kernel_extra_goes_to_end() {
        // kernel=4, stride=1, size=10: total=3, start=1, end=2
        assert_eq!(calculate_same_padding(4, 1, 10), (1, 2));
    }

    #[test]
    fn test_calculate_same_padding_with_stride() {
        // kernel=3, stride=2, size=10: out=5, needed=11, total=1, start=0, end=1
        assert_eq!(calculate_same_padding(3, 2, 10), (0, 1));
        // kernel=3, stride=2, size=9: out=5, needed=11, total=2, start=1, end=1
        assert_eq!(calculate_same_padding(3, 2, 9), (1, 1));
    }

    #[test]
    fn test_calculate_same_padding_saturates_at_zero() {
        // kernel=1, stride=4, size=10: out=3, needed=9 < 10, so no padding is required
        assert_eq!(calculate_same_padding(1, 4, 10), (0, 0));
    }

    #[test]
    fn test_calculate_same_padding_empty_input() {
        // size=0 gives an empty output, which needs no padding
        assert_eq!(calculate_same_padding(3, 1, 0), (0, 0));
    }

    // ==================== PaddingConfig1d Tests ====================

    #[test]
    fn test_padding_config_1d_calculate_pair_valid() {
        let padding = PaddingConfig1d::Valid;
        assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (0, 0));
    }

    #[test]
    fn test_padding_config_1d_calculate_pair_explicit() {
        let padding = PaddingConfig1d::Explicit(1, 2);
        assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (1, 2));
    }

    #[test]
    fn test_padding_config_1d_calculate_pair_same() {
        let padding = PaddingConfig1d::Same;
        // kernel=3, stride=1, length=10: total=2, start=1, end=1
        assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (1, 1));
    }

    #[test]
    fn test_padding_config_1d_calculate_pair_same_even_kernel() {
        let padding = PaddingConfig1d::Same;
        // kernel=4, stride=1, length=10: total=3, start=1, end=2
        assert_eq!(padding.calculate_padding_1d_pair(10, 4, 1), (1, 2));
    }

    // ==================== PaddingConfig2d Tests ====================

    #[test]
    fn test_padding_config_2d_calculate_pairs_valid() {
        let padding = PaddingConfig2d::Valid;
        assert_eq!(
            padding.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]),
            ((0, 0), (0, 0))
        );
    }

    #[test]
    fn test_padding_config_2d_calculate_pairs_explicit() {
        let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);
        assert_eq!(
            padding.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]),
            ((1, 3), (2, 4))
        );
    }

    #[test]
    fn test_padding_config_2d_calculate_pairs_same() {
        let padding = PaddingConfig2d::Same;
        // height: kernel=3 -> (1, 1); width: kernel=4 -> (1, 2)
        assert_eq!(
            padding.calculate_padding_2d_pairs(10, 10, &[3, 4], &[1, 1]),
            ((1, 1), (1, 2))
        );
    }

    #[test]
    fn test_padding_config_2d_calculate_symmetric_valid() {
        let padding = PaddingConfig2d::Valid;
        assert_eq!(
            padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),
            [0, 0]
        );
    }

    #[test]
    fn test_padding_config_2d_calculate_symmetric_explicit() {
        let padding = PaddingConfig2d::Explicit(2, 3, 2, 3);
        assert_eq!(
            padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),
            [2, 3]
        );
    }

    #[test]
    fn test_padding_config_2d_calculate_symmetric_same_odd_kernel() {
        let padding = PaddingConfig2d::Same;
        // kernel=3, stride=1: total=2, symmetric (1,1) per dim
        assert_eq!(
            padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),
            [1, 1]
        );
    }

    #[test]
    #[should_panic(
        expected = "Asymmetric padding should be handled via calculate_padding_2d_pairs"
    )]
    fn test_padding_config_2d_calculate_symmetric_asymmetric_panics() {
        let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);
        let _ = padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]);
    }

    #[test]
    #[should_panic(
        expected = "Asymmetric padding should be handled via calculate_padding_2d_pairs"
    )]
    fn test_padding_config_2d_calculate_symmetric_same_even_kernel_panics() {
        let padding = PaddingConfig2d::Same;
        // kernel=4, stride=1: total=3 -> (1, 2), which is not symmetric
        let _ = padding.calculate_padding_2d(10, 10, &[4, 4], &[1, 1]);
    }

    // ==================== PaddingConfig3d Tests ====================

    #[test]
    fn test_padding_config_3d_calculate_valid() {
        let padding = PaddingConfig3d::Valid;
        assert_eq!(
            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
            [0, 0, 0]
        );
    }

    #[test]
    fn test_padding_config_3d_calculate_explicit() {
        let padding = PaddingConfig3d::Explicit(1, 2, 3);
        assert_eq!(
            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
            [1, 2, 3]
        );
    }

    #[test]
    fn test_padding_config_3d_calculate_same_odd_kernel() {
        let padding = PaddingConfig3d::Same;
        // kernel=3, stride=1: total=2, symmetric (1,1) per dim
        assert_eq!(
            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),
            [1, 1, 1]
        );
    }

    #[test]
    #[should_panic(expected = "Asymmetric 3D 'Same' padding is not supported")]
    fn test_padding_config_3d_calculate_same_even_kernel_panics() {
        let padding = PaddingConfig3d::Same;
        // depth kernel=4, stride=1: total=3 -> (1, 2), which is not symmetric
        let _ = padding.calculate_padding_3d(10, 10, 10, &[4, 3, 3], &[1, 1, 1]);
    }
}