1use burn_core as burn;
2
3use crate::PaddingConfig2d;
4use burn::config::Config;
5use burn::module::Module;
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use burn::tensor::ops::PadMode;
10
11use burn::tensor::module::max_pool2d;
12
/// Configuration to create a [2D max pooling](MaxPool2d) layer using the
/// [init function](MaxPool2dConfig::init).
#[derive(Debug, Config)]
pub struct MaxPool2dConfig {
    /// The size of the pooling window, ordered `[height, width]`.
    pub kernel_size: [usize; 2],
    /// The step of the pooling window, ordered `[height, width]`.
    /// Defaults to `kernel_size` (non-overlapping windows).
    #[config(default = "kernel_size")]
    pub strides: [usize; 2],
    /// How the input is padded before pooling; see [`PaddingConfig2d`].
    /// Defaults to `PaddingConfig2d::Valid`.
    #[config(default = "PaddingConfig2d::Valid")]
    pub padding: PaddingConfig2d,
    /// The spacing between kernel elements, ordered `[height, width]`. Defaults to `[1, 1]`.
    #[config(default = "[1, 1]")]
    pub dilation: [usize; 2],
    /// Forwarded to `max_pool2d` as its `ceil_mode` flag (ceil instead of floor
    /// rounding when computing the output size). Defaults to `false`.
    #[config(default = "false")]
    pub ceil_mode: bool,
}
34
/// Applies 2D max pooling over a 4D input tensor.
///
/// Should be created with [`MaxPool2dConfig`] via [`MaxPool2dConfig::init`].
/// Display output is customized by the `ModuleDisplay` implementation.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool2d {
    /// The step of the pooling window, ordered `[height, width]`.
    pub stride: [usize; 2],
    /// The size of the pooling window, ordered `[height, width]`.
    pub kernel_size: [usize; 2],
    /// The padding strategy, resolved against the actual input size in `forward`.
    pub padding: PaddingConfig2d,
    /// The spacing between kernel elements, ordered `[height, width]`.
    pub dilation: [usize; 2],
    /// Forwarded to `max_pool2d` as its `ceil_mode` flag.
    pub ceil_mode: bool,
}
52
53impl ModuleDisplay for MaxPool2d {
54 fn custom_settings(&self) -> Option<DisplaySettings> {
55 DisplaySettings::new()
56 .with_new_line_after_attribute(false)
57 .optional()
58 }
59
60 fn custom_content(&self, content: Content) -> Option<Content> {
61 content
62 .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
63 .add("stride", &alloc::format!("{:?}", &self.stride))
64 .add_debug_attribute("padding", &self.padding)
65 .add("dilation", &alloc::format!("{:?}", &self.dilation))
66 .add("ceil_mode", &self.ceil_mode)
67 .optional()
68 }
69}
70
71impl MaxPool2dConfig {
72 pub fn init(&self) -> MaxPool2d {
74 MaxPool2d {
75 stride: self.strides,
76 kernel_size: self.kernel_size,
77 padding: self.padding.clone(),
78 dilation: self.dilation,
79 ceil_mode: self.ceil_mode,
80 }
81 }
82}
83
84impl MaxPool2d {
85 pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
94 let [_batch_size, _channels_in, height_in, width_in] = input.dims();
95
96 let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
98 height_in,
99 width_in,
100 &self.kernel_size,
101 &self.stride,
102 );
103
104 if top != bottom || left != right {
108 let padded = input.pad(
111 (left, right, top, bottom),
112 PadMode::Constant(f32::NEG_INFINITY),
113 );
114 max_pool2d(
116 padded,
117 self.kernel_size,
118 self.stride,
119 [0, 0],
120 self.dilation,
121 self.ceil_mode,
122 )
123 } else {
124 max_pool2d(
126 input,
127 self.kernel_size,
128 self.stride,
129 [top, left],
130 self.dilation,
131 self.ceil_mode,
132 )
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use rstest::rstest;

    /// Builds a stride-1 pool with the given window size and padding strategy.
    fn unit_stride_pool(kernel_size: [usize; 2], padding: PaddingConfig2d) -> MaxPool2d {
        MaxPool2dConfig::new(kernel_size)
            .with_strides([1, 1])
            .with_padding(padding)
            .init()
    }

    /// Creates an all-ones input tensor of the given shape on the default device.
    fn ones_input(shape: [usize; 4]) -> Tensor<TestBackend, 4> {
        Tensor::<TestBackend, 4>::ones(shape, &Default::default())
    }

    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let pool = unit_stride_pool([2, 2], PaddingConfig2d::Same);

        let output = pool.forward(ones_input([1, 2, 5, 5]));

        assert_eq!(output.dims(), [1, 2, 5, 5]);
    }

    #[test]
    fn display() {
        let layer = MaxPool2dConfig::new([3, 3]).init();

        assert_eq!(
            alloc::format!("{layer}"),
            "MaxPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, dilation: [1, 1], ceil_mode: false}"
        );
    }

    #[rstest]
    #[case([2, 2])]
    #[case([1, 2])]
    fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {
        let config = MaxPool2dConfig::new(kernel_size);

        assert_eq!(
            config.strides, kernel_size,
            "Expected strides ({:?}) to match kernel size ({:?}) in default MaxPool2dConfig::new constructor",
            config.strides, config.kernel_size
        );
    }

    #[test]
    fn asymmetric_padding_forward() {
        let pool = unit_stride_pool([3, 3], PaddingConfig2d::Explicit(1, 2, 3, 4));

        let output = pool.forward(ones_input([1, 2, 4, 5]));

        assert_eq!(output.dims(), [1, 2, 6, 9]);
    }

    #[test]
    fn symmetric_explicit_padding_forward() {
        let pool = unit_stride_pool([3, 3], PaddingConfig2d::Explicit(2, 2, 2, 2));

        let output = pool.forward(ones_input([1, 2, 4, 5]));

        assert_eq!(output.dims(), [1, 2, 6, 7]);
    }
}