Skip to main content

burn_nn/
padding.rs

1use burn_core as burn;
2
3use burn::config::Config;
4use burn::tensor::ops::conv::calculate_conv_padding;
5
6/// Padding configuration for 1D operators.
7#[derive(Config, Debug, PartialEq)]
8pub enum PaddingConfig1d {
9    /// Dynamically calculates padding to ensure output size matches input size.
10    Same,
11    /// No padding applied.
12    Valid,
13    /// Applies a specific amount of padding to all inputs.
14    Explicit(usize),
15}
16
17impl PaddingConfig1d {
18    pub(crate) fn calculate_padding_1d(
19        &self,
20        length: usize,
21        kernel_size: usize,
22        stride: usize,
23    ) -> usize {
24        let same_padding = || calculate_conv_padding(kernel_size, stride, length, length);
25        match self {
26            Self::Valid => 0,
27            Self::Same => same_padding(),
28            Self::Explicit(value) => *value,
29        }
30    }
31}
32
33/// Padding configuration for 2D operators.
34#[derive(Config, Debug, PartialEq)]
35pub enum PaddingConfig2d {
36    /// Dynamically calculates padding to preserve input dimensions in output.
37    Same,
38    /// No padding applied.
39    Valid,
40    /// Applies specified padding values to height and width dimensions.
41    Explicit(usize, usize),
42}
43
44impl PaddingConfig2d {
45    pub(crate) fn calculate_padding_2d(
46        &self,
47        height: usize,
48        width: usize,
49        kernel_size: &[usize; 2],
50        stride: &[usize; 2],
51    ) -> [usize; 2] {
52        let same_padding = || {
53            let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
54            let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
55
56            [p1, p2]
57        };
58
59        match self {
60            Self::Same => same_padding(),
61            Self::Valid => [0, 0],
62            Self::Explicit(v1, v2) => [*v1, *v2],
63        }
64    }
65}
66
67/// Padding configuration for 3D operators.
68#[derive(Config, Debug, PartialEq)]
69pub enum PaddingConfig3d {
70    /// Dynamically calculates padding to preserve input dimensions in output.
71    Same,
72    /// No padding applied.
73    Valid,
74    /// Applies specified padding values to depth, height, and width dimensions.
75    Explicit(usize, usize, usize),
76}
77
78impl PaddingConfig3d {
79    pub(crate) fn calculate_padding_3d(
80        &self,
81        depth: usize,
82        height: usize,
83        width: usize,
84        kernel_size: &[usize; 3],
85        stride: &[usize; 3],
86    ) -> [usize; 3] {
87        let same_padding = || {
88            let p1 = calculate_conv_padding(kernel_size[0], stride[0], depth, depth);
89            let p2 = calculate_conv_padding(kernel_size[1], stride[1], height, height);
90            let p3 = calculate_conv_padding(kernel_size[2], stride[2], width, width);
91
92            [p1, p2, p3]
93        };
94
95        match self {
96            Self::Same => same_padding(),
97            Self::Valid => [0, 0, 0],
98            Self::Explicit(v1, v2, v3) => [*v1, *v2, *v3],
99        }
100    }
101}