1use crate as burn;
2
3use crate::tensor::ops::conv::calculate_conv_padding;
4
5use crate::config::Config;
6
7#[derive(Config, Debug, PartialEq)]
9pub enum PaddingConfig1d {
10 Same,
12 Valid,
14 Explicit(usize),
16}
17
18impl PaddingConfig1d {
19 pub(crate) fn calculate_padding_1d(
20 &self,
21 length: usize,
22 kernel_size: usize,
23 stride: usize,
24 ) -> usize {
25 let same_padding = || calculate_conv_padding(kernel_size, stride, length, length);
26 match self {
27 Self::Valid => 0,
28 Self::Same => same_padding(),
29 Self::Explicit(value) => *value,
30 }
31 }
32}
33
34#[derive(Config, Debug, PartialEq)]
36pub enum PaddingConfig2d {
37 Same,
39 Valid,
41 Explicit(usize, usize),
43}
44
45impl PaddingConfig2d {
46 pub(crate) fn calculate_padding_2d(
47 &self,
48 height: usize,
49 width: usize,
50 kernel_size: &[usize; 2],
51 stride: &[usize; 2],
52 ) -> [usize; 2] {
53 let same_padding = || {
54 let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
55 let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
56
57 [p1, p2]
58 };
59
60 match self {
61 Self::Same => same_padding(),
62 Self::Valid => [0, 0],
63 Self::Explicit(v1, v2) => [*v1, *v2],
64 }
65 }
66}
67
68#[derive(Config, Debug, PartialEq)]
70pub enum PaddingConfig3d {
71 Same,
73 Valid,
75 Explicit(usize, usize, usize),
77}
78
79impl PaddingConfig3d {
80 pub(crate) fn calculate_padding_3d(
81 &self,
82 depth: usize,
83 height: usize,
84 width: usize,
85 kernel_size: &[usize; 3],
86 stride: &[usize; 3],
87 ) -> [usize; 3] {
88 let same_padding = || {
89 let p1 = calculate_conv_padding(kernel_size[0], stride[0], depth, depth);
90 let p2 = calculate_conv_padding(kernel_size[1], stride[1], height, height);
91 let p3 = calculate_conv_padding(kernel_size[2], stride[2], width, width);
92
93 [p1, p2, p3]
94 };
95
96 match self {
97 Self::Same => same_padding(),
98 Self::Valid => [0, 0, 0],
99 Self::Explicit(v1, v2, v3) => [*v1, *v2, *v3],
100 }
101 }
102}