burn_core/nn/padding.rs

1use crate as burn;
2
3use crate::tensor::ops::conv::calculate_conv_padding;
4
5use crate::config::Config;
6
7/// Padding configuration for 1D operators.
8#[derive(Config, Debug, PartialEq)]
9pub enum PaddingConfig1d {
10    /// Dynamically calculates padding to ensure output size matches input size.
11    Same,
12    /// No padding applied.
13    Valid,
14    /// Applies a specific amount of padding to all inputs.
15    Explicit(usize),
16}
17
18impl PaddingConfig1d {
19    pub(crate) fn calculate_padding_1d(
20        &self,
21        length: usize,
22        kernel_size: usize,
23        stride: usize,
24    ) -> usize {
25        let same_padding = || calculate_conv_padding(kernel_size, stride, length, length);
26        match self {
27            Self::Valid => 0,
28            Self::Same => same_padding(),
29            Self::Explicit(value) => *value,
30        }
31    }
32}
33
34/// Padding configuration for 2D operators.
35#[derive(Config, Debug, PartialEq)]
36pub enum PaddingConfig2d {
37    /// Dynamically calculates padding to preserve input dimensions in output.
38    Same,
39    /// No padding applied.
40    Valid,
41    /// Applies specified padding values to height and width dimensions.
42    Explicit(usize, usize),
43}
44
45impl PaddingConfig2d {
46    pub(crate) fn calculate_padding_2d(
47        &self,
48        height: usize,
49        width: usize,
50        kernel_size: &[usize; 2],
51        stride: &[usize; 2],
52    ) -> [usize; 2] {
53        let same_padding = || {
54            let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
55            let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
56
57            [p1, p2]
58        };
59
60        match self {
61            Self::Same => same_padding(),
62            Self::Valid => [0, 0],
63            Self::Explicit(v1, v2) => [*v1, *v2],
64        }
65    }
66}
67
68/// Padding configuration for 3D operators.
69#[derive(Config, Debug, PartialEq)]
70pub enum PaddingConfig3d {
71    /// Dynamically calculates padding to preserve input dimensions in output.
72    Same,
73    /// No padding applied.
74    Valid,
75    /// Applies specified padding values to depth, height, and width dimensions.
76    Explicit(usize, usize, usize),
77}
78
79impl PaddingConfig3d {
80    pub(crate) fn calculate_padding_3d(
81        &self,
82        depth: usize,
83        height: usize,
84        width: usize,
85        kernel_size: &[usize; 3],
86        stride: &[usize; 3],
87    ) -> [usize; 3] {
88        let same_padding = || {
89            let p1 = calculate_conv_padding(kernel_size[0], stride[0], depth, depth);
90            let p2 = calculate_conv_padding(kernel_size[1], stride[1], height, height);
91            let p3 = calculate_conv_padding(kernel_size[2], stride[2], width, width);
92
93            [p1, p2, p3]
94        };
95
96        match self {
97            Self::Same => same_padding(),
98            Self::Valid => [0, 0, 0],
99            Self::Explicit(v1, v2, v3) => [*v1, *v2, *v3],
100        }
101    }
102}