1use burn_core as burn;
2
3use burn::config::Config;
4use burn::tensor::ops::conv::calculate_conv_padding;
5
6#[derive(Config, Debug, PartialEq)]
8pub enum PaddingConfig1d {
9 Same,
11 Valid,
13 Explicit(usize),
15}
16
17impl PaddingConfig1d {
18 pub(crate) fn calculate_padding_1d(
19 &self,
20 length: usize,
21 kernel_size: usize,
22 stride: usize,
23 ) -> usize {
24 let same_padding = || calculate_conv_padding(kernel_size, stride, length, length);
25 match self {
26 Self::Valid => 0,
27 Self::Same => same_padding(),
28 Self::Explicit(value) => *value,
29 }
30 }
31}
32
33#[derive(Config, Debug, PartialEq)]
35pub enum PaddingConfig2d {
36 Same,
38 Valid,
40 Explicit(usize, usize),
42}
43
44impl PaddingConfig2d {
45 pub(crate) fn calculate_padding_2d(
46 &self,
47 height: usize,
48 width: usize,
49 kernel_size: &[usize; 2],
50 stride: &[usize; 2],
51 ) -> [usize; 2] {
52 let same_padding = || {
53 let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
54 let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
55
56 [p1, p2]
57 };
58
59 match self {
60 Self::Same => same_padding(),
61 Self::Valid => [0, 0],
62 Self::Explicit(v1, v2) => [*v1, *v2],
63 }
64 }
65}
66
67#[derive(Config, Debug, PartialEq)]
69pub enum PaddingConfig3d {
70 Same,
72 Valid,
74 Explicit(usize, usize, usize),
76}
77
78impl PaddingConfig3d {
79 pub(crate) fn calculate_padding_3d(
80 &self,
81 depth: usize,
82 height: usize,
83 width: usize,
84 kernel_size: &[usize; 3],
85 stride: &[usize; 3],
86 ) -> [usize; 3] {
87 let same_padding = || {
88 let p1 = calculate_conv_padding(kernel_size[0], stride[0], depth, depth);
89 let p2 = calculate_conv_padding(kernel_size[1], stride[1], height, height);
90 let p3 = calculate_conv_padding(kernel_size[2], stride[2], width, width);
91
92 [p1, p2, p3]
93 };
94
95 match self {
96 Self::Same => same_padding(),
97 Self::Valid => [0, 0, 0],
98 Self::Explicit(v1, v2, v3) => [*v1, *v2, *v3],
99 }
100 }
101}