1use crate as burn;
2
3use crate::tensor::ops::conv::calculate_conv_padding;
4
5use crate::config::Config;
6
/// Padding configuration for 1D operators.
///
/// Resolved to a concrete amount of padding by `PaddingConfig1d::calculate_padding_1d`.
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig1d {
    /// Derive the padding from the kernel size and stride, requesting an output length
    /// equal to the input length.
    ///
    /// NOTE(review): with an even kernel size a symmetric padding presumably cannot keep
    /// the length exactly unchanged — confirm against `calculate_conv_padding`.
    Same,
    /// No padding (resolves to `0`).
    Valid,
    /// Use exactly the given amount of padding.
    Explicit(usize),
}
18
19impl PaddingConfig1d {
20 pub(crate) fn calculate_padding_1d(
21 &self,
22 length: usize,
23 kernel_size: usize,
24 stride: usize,
25 ) -> usize {
26 let same_padding = || calculate_conv_padding(kernel_size, stride, length, length);
27 match self {
28 Self::Valid => 0,
29 Self::Same => same_padding(),
30 Self::Explicit(value) => *value,
31 }
32 }
33}
34
/// Padding configuration for 2D operators.
///
/// Resolved to concrete `[height, width]` padding by `PaddingConfig2d::calculate_padding_2d`.
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig2d {
    /// Derive the padding per axis from the kernel size and stride, requesting an output
    /// size equal to the input size along each axis.
    ///
    /// NOTE(review): with an even kernel size a symmetric padding presumably cannot keep
    /// the size exactly unchanged — confirm against `calculate_conv_padding`.
    Same,
    /// No padding (resolves to `[0, 0]`).
    Valid,
    /// Use exactly the given padding, ordered as (height, width).
    Explicit(usize, usize),
}
46
47impl PaddingConfig2d {
48 pub(crate) fn calculate_padding_2d(
49 &self,
50 height: usize,
51 width: usize,
52 kernel_size: &[usize; 2],
53 stride: &[usize; 2],
54 ) -> [usize; 2] {
55 let same_padding = || {
56 let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
57 let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
58
59 [p1, p2]
60 };
61
62 match self {
63 Self::Same => same_padding(),
64 Self::Valid => [0, 0],
65 Self::Explicit(v1, v2) => [*v1, *v2],
66 }
67 }
68}
69
/// Padding configuration for 3D operators.
///
/// Resolved to concrete `[depth, height, width]` padding by
/// `PaddingConfig3d::calculate_padding_3d`.
#[derive(Config, Debug, PartialEq)]
pub enum PaddingConfig3d {
    /// Derive the padding per axis from the kernel size and stride, requesting an output
    /// size equal to the input size along each axis.
    ///
    /// NOTE(review): with an even kernel size a symmetric padding presumably cannot keep
    /// the size exactly unchanged — confirm against `calculate_conv_padding`.
    Same,
    /// No padding (resolves to `[0, 0, 0]`).
    Valid,
    /// Use exactly the given padding, ordered as (depth, height, width).
    Explicit(usize, usize, usize),
}
81
82impl PaddingConfig3d {
83 pub(crate) fn calculate_padding_3d(
84 &self,
85 depth: usize,
86 height: usize,
87 width: usize,
88 kernel_size: &[usize; 3],
89 stride: &[usize; 3],
90 ) -> [usize; 3] {
91 let same_padding = || {
92 let p1 = calculate_conv_padding(kernel_size[0], stride[0], depth, depth);
93 let p2 = calculate_conv_padding(kernel_size[1], stride[1], height, height);
94 let p3 = calculate_conv_padding(kernel_size[2], stride[2], width, width);
95
96 [p1, p2, p3]
97 };
98
99 match self {
100 Self::Same => same_padding(),
101 Self::Valid => [0, 0, 0],
102 Self::Explicit(v1, v2, v3) => [*v1, *v2, *v3],
103 }
104 }
105}