burn_core/nn/padding.rs

1use crate as burn;
2
3use crate::tensor::ops::conv::calculate_conv_padding;
4
5use crate::config::Config;
6
7/// Padding configuration for 1D operators.
8#[derive(Config, Debug, PartialEq)]
9pub enum PaddingConfig1d {
10    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
11    /// the same as the input.
12    Same,
13    /// Same as no padding.
14    Valid,
15    /// Applies the specified amount of padding to all inputs.
16    Explicit(usize),
17}
18
19impl PaddingConfig1d {
20    pub(crate) fn calculate_padding_1d(
21        &self,
22        length: usize,
23        kernel_size: usize,
24        stride: usize,
25    ) -> usize {
26        let same_padding = || calculate_conv_padding(kernel_size, stride, length, length);
27        match self {
28            Self::Valid => 0,
29            Self::Same => same_padding(),
30            Self::Explicit(value) => *value,
31        }
32    }
33}
34
35/// Padding configuration for 2D operators.
36#[derive(Config, Debug, PartialEq)]
37pub enum PaddingConfig2d {
38    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
39    /// the same as the input.
40    Same,
41    /// Same as no padding.
42    Valid,
43    /// Applies the specified amount of padding to all inputs.
44    Explicit(usize, usize),
45}
46
47impl PaddingConfig2d {
48    pub(crate) fn calculate_padding_2d(
49        &self,
50        height: usize,
51        width: usize,
52        kernel_size: &[usize; 2],
53        stride: &[usize; 2],
54    ) -> [usize; 2] {
55        let same_padding = || {
56            let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
57            let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
58
59            [p1, p2]
60        };
61
62        match self {
63            Self::Same => same_padding(),
64            Self::Valid => [0, 0],
65            Self::Explicit(v1, v2) => [*v1, *v2],
66        }
67    }
68}
69
70/// Padding configuration for 3D operators.
71#[derive(Config, Debug, PartialEq)]
72pub enum PaddingConfig3d {
73    /// Dynamically calculate the amount of padding necessary to ensure that the output size will be
74    /// the same as the input.
75    Same,
76    /// Same as no padding.
77    Valid,
78    /// Applies the specified amount of padding to all inputs.
79    Explicit(usize, usize, usize),
80}
81
82impl PaddingConfig3d {
83    pub(crate) fn calculate_padding_3d(
84        &self,
85        depth: usize,
86        height: usize,
87        width: usize,
88        kernel_size: &[usize; 3],
89        stride: &[usize; 3],
90    ) -> [usize; 3] {
91        let same_padding = || {
92            let p1 = calculate_conv_padding(kernel_size[0], stride[0], depth, depth);
93            let p2 = calculate_conv_padding(kernel_size[1], stride[1], height, height);
94            let p3 = calculate_conv_padding(kernel_size[2], stride[2], width, width);
95
96            [p1, p2, p3]
97        };
98
99        match self {
100            Self::Same => same_padding(),
101            Self::Valid => [0, 0, 0],
102            Self::Explicit(v1, v2, v3) => [*v1, *v2, *v3],
103        }
104    }
105}