//! 2D max pooling layer (`burn_nn/modules/pool/max_pool2d.rs`).
1use burn_core as burn;
2
3use crate::PaddingConfig2d;
4use burn::config::Config;
5use burn::module::Module;
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use burn::tensor::ops::PadMode;
10
11use burn::tensor::module::max_pool2d;
12
/// Configuration to create a [2D max pooling](MaxPool2d) layer using the [init function](MaxPool2dConfig::init).
#[derive(Debug, Config)]
pub struct MaxPool2dConfig {
    /// The size of the kernel, as `[height, width]`.
    pub kernel_size: [usize; 2],
    /// The strides, as `[height, width]`. Defaults to `kernel_size`
    /// (non-overlapping windows) when not set explicitly.
    #[config(default = "kernel_size")]
    pub strides: [usize; 2],
    /// The padding configuration. Defaults to `Valid` (no padding).
    ///
    /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
    /// will automatically use asymmetric padding to preserve input dimensions.
    #[config(default = "PaddingConfig2d::Valid")]
    pub padding: PaddingConfig2d,
    /// The dilation, as `[height, width]`. Defaults to `[1, 1]` (dense kernel).
    #[config(default = "[1, 1]")]
    pub dilation: [usize; 2],
    /// If true, use ceiling instead of floor for output size calculation.
    #[config(default = "false")]
    pub ceil_mode: bool,
}
34
/// Applies a 2D max pooling over input tensors.
///
/// Should be created with [MaxPool2dConfig](MaxPool2dConfig).
///
/// This module holds no learnable parameters; it only stores the pooling
/// hyperparameters copied from the config at [init](MaxPool2dConfig::init) time.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool2d {
    /// The strides, as `[height, width]`.
    pub stride: [usize; 2],
    /// The size of the kernel, as `[height, width]`.
    pub kernel_size: [usize; 2],
    /// The padding configuration.
    pub padding: PaddingConfig2d,
    /// The dilation, as `[height, width]`.
    pub dilation: [usize; 2],
    /// If true, use ceiling instead of floor for output size calculation.
    pub ceil_mode: bool,
}
52
53impl ModuleDisplay for MaxPool2d {
54    fn custom_settings(&self) -> Option<DisplaySettings> {
55        DisplaySettings::new()
56            .with_new_line_after_attribute(false)
57            .optional()
58    }
59
60    fn custom_content(&self, content: Content) -> Option<Content> {
61        content
62            .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
63            .add("stride", &alloc::format!("{:?}", &self.stride))
64            .add_debug_attribute("padding", &self.padding)
65            .add("dilation", &alloc::format!("{:?}", &self.dilation))
66            .add("ceil_mode", &self.ceil_mode)
67            .optional()
68    }
69}
70
71impl MaxPool2dConfig {
72    /// Initialize a new [max pool 2d](MaxPool2d) module.
73    pub fn init(&self) -> MaxPool2d {
74        MaxPool2d {
75            stride: self.strides,
76            kernel_size: self.kernel_size,
77            padding: self.padding.clone(),
78            dilation: self.dilation,
79            ceil_mode: self.ceil_mode,
80        }
81    }
82}
83
84impl MaxPool2d {
85    /// Applies the forward pass on the input tensor.
86    ///
87    /// See [max_pool2d](burn::tensor::module::max_pool2d) for more information.
88    ///
89    /// # Shapes
90    ///
91    /// - input: `[batch_size, channels, height_in, width_in]`
92    /// - output: `[batch_size, channels, height_out, width_out]`
93    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
94        let [_batch_size, _channels_in, height_in, width_in] = input.dims();
95
96        // Calculate padding as pairs - handles Same, Valid, and Explicit uniformly
97        let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
98            height_in,
99            width_in,
100            &self.kernel_size,
101            &self.stride,
102        );
103
104        // TODO: Move asymmetric padding to functional level via PoolOptions
105        // See: https://github.com/tracel-ai/burn/issues/4362
106        // Handle asymmetric padding by applying explicit pad operation first
107        if top != bottom || left != right {
108            // Burn's pad takes (left, right, top, bottom) for the last two dimensions
109            // Use -inf for max pooling so padded values don't affect the max
110            let padded = input.pad(
111                (left, right, top, bottom),
112                PadMode::Constant(f32::NEG_INFINITY),
113            );
114            // Use zero padding for the pool operation since we already padded
115            max_pool2d(
116                padded,
117                self.kernel_size,
118                self.stride,
119                [0, 0],
120                self.dilation,
121                self.ceil_mode,
122            )
123        } else {
124            // Symmetric padding
125            max_pool2d(
126                input,
127                self.kernel_size,
128                self.stride,
129                [top, left],
130                self.dilation,
131                self.ceil_mode,
132            )
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use rstest::rstest;

    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let pool = MaxPool2dConfig::new([2, 2])
            .with_strides([1, 1])
            .with_padding(PaddingConfig2d::Same)
            .init();

        // Input: [batch=1, channels=2, height=5, width=5]
        let output = pool.forward(Tensor::<TestBackend, 4>::ones([1, 2, 5, 5], &device));

        // Same padding should preserve spatial dimensions.
        assert_eq!(output.dims(), [1, 2, 5, 5]);
    }

    #[test]
    fn display() {
        let layer = MaxPool2dConfig::new([3, 3]).init();

        let rendered = alloc::format!("{layer}");
        assert_eq!(
            rendered,
            "MaxPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, dilation: [1, 1], ceil_mode: false}"
        );
    }

    #[rstest]
    #[case([2, 2])]
    #[case([1, 2])]
    fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {
        let config = MaxPool2dConfig::new(kernel_size);

        assert_eq!(
            config.strides, kernel_size,
            "Expected strides ({:?}) to match kernel size ({:?}) in default MaxPool2dConfig::new constructor",
            config.strides, config.kernel_size
        );
    }

    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();
        // Asymmetric padding: top=1, left=2, bottom=3, right=4.
        let pool = MaxPool2dConfig::new([3, 3])
            .with_strides([1, 1])
            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4))
            .init();

        // Input: [batch=1, channels=2, height=4, width=5]
        let output = pool.forward(Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device));

        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6
        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9
        assert_eq!(output.dims(), [1, 2, 6, 9]);
    }

    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();
        // Symmetric explicit padding: top=2, left=2, bottom=2, right=2.
        let pool = MaxPool2dConfig::new([3, 3])
            .with_strides([1, 1])
            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))
            .init();

        // Input: [batch=1, channels=2, height=4, width=5]
        let output = pool.forward(Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device));

        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6
        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7
        assert_eq!(output.dims(), [1, 2, 6, 7]);
    }
}