burn_nn/modules/pool/
max_pool2d.rs1use crate::conv::checks::check_same_padding_support;
2use burn_core as burn;
3
4use crate::PaddingConfig2d;
5use burn::config::Config;
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::module::{Ignored, Module};
8use burn::tensor::Tensor;
9use burn::tensor::backend::Backend;
10
11use burn::tensor::module::max_pool2d;
12
13#[derive(Debug, Config)]
15pub struct MaxPool2dConfig {
16 pub kernel_size: [usize; 2],
18 #[config(default = "kernel_size")]
20 pub strides: [usize; 2],
21 #[config(default = "PaddingConfig2d::Valid")]
27 pub padding: PaddingConfig2d,
28 #[config(default = "[1, 1]")]
30 pub dilation: [usize; 2],
31 #[config(default = "false")]
33 pub ceil_mode: bool,
34}
35
36#[derive(Module, Clone, Debug)]
40#[module(custom_display)]
41pub struct MaxPool2d {
42 pub stride: [usize; 2],
44 pub kernel_size: [usize; 2],
46 pub padding: Ignored<PaddingConfig2d>,
48 pub dilation: [usize; 2],
50 pub ceil_mode: bool,
52}
53
54impl ModuleDisplay for MaxPool2d {
55 fn custom_settings(&self) -> Option<DisplaySettings> {
56 DisplaySettings::new()
57 .with_new_line_after_attribute(false)
58 .optional()
59 }
60
61 fn custom_content(&self, content: Content) -> Option<Content> {
62 content
63 .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
64 .add("stride", &alloc::format!("{:?}", &self.stride))
65 .add("padding", &self.padding)
66 .add("dilation", &alloc::format!("{:?}", &self.dilation))
67 .add("ceil_mode", &self.ceil_mode)
68 .optional()
69 }
70}
71
72impl MaxPool2dConfig {
73 pub fn init(&self) -> MaxPool2d {
75 if self.padding == PaddingConfig2d::Same {
76 check_same_padding_support(&self.kernel_size);
77 }
78 MaxPool2d {
79 stride: self.strides,
80 kernel_size: self.kernel_size,
81 padding: Ignored(self.padding.clone()),
82 dilation: self.dilation,
83 ceil_mode: self.ceil_mode,
84 }
85 }
86}
87
88impl MaxPool2d {
89 pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
98 let [_batch_size, _channels_in, height_in, width_in] = input.dims();
99 let padding =
100 self.padding
101 .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
102
103 max_pool2d(
104 input,
105 self.kernel_size,
106 self.stride,
107 padding,
108 self.dilation,
109 self.ceil_mode,
110 )
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// `Same` padding combined with an even kernel must be rejected at init time.
    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let _ = MaxPool2dConfig::new([2, 2])
            .with_padding(PaddingConfig2d::Same)
            .init();
    }

    /// The custom display lists every attribute on a single line.
    #[test]
    fn display() {
        let layer = MaxPool2dConfig::new([3, 3]).init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "MaxPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, dilation: [1, 1], ceil_mode: false}"
        );
    }

    /// Without an explicit override, the strides fall back to the kernel size.
    #[rstest]
    #[case([2, 2])]
    #[case([1, 2])]
    fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {
        let config = MaxPool2dConfig::new(kernel_size);
        let default_strides = config.strides;

        assert_eq!(
            default_strides, kernel_size,
            "Expected strides ({:?}) to match kernel size ({:?}) in default MaxPool2dConfig::new constructor",
            config.strides, config.kernel_size
        );
    }
}