burn_core/nn/pool/
max_pool2d.rs1use crate as burn;
2use crate::nn::conv::checks::check_same_padding_support;
3
4use crate::config::Config;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::module::{Ignored, Module};
7use crate::nn::PaddingConfig2d;
8use crate::tensor::Tensor;
9use crate::tensor::backend::Backend;
10
11use crate::tensor::module::max_pool2d;
12
/// Configuration to create a [2D max pooling](MaxPool2d) layer using the
/// [init function](MaxPool2dConfig::init).
#[derive(Debug, Config)]
pub struct MaxPool2dConfig {
    /// The size of the pooling window along each of the two spatial dimensions.
    pub kernel_size: [usize; 2],
    /// The stride of the pooling window.
    /// Defaults to `kernel_size`, i.e. non-overlapping windows when dilation is `[1, 1]`.
    #[config(default = "kernel_size")]
    pub strides: [usize; 2],
    /// The padding configuration. Defaults to [`PaddingConfig2d::Valid`].
    ///
    /// [`PaddingConfig2d::Same`] is only accepted for kernel sizes that pass
    /// `check_same_padding_support` (even kernel sizes are rejected when calling `init`).
    #[config(default = "PaddingConfig2d::Valid")]
    pub padding: PaddingConfig2d,
    /// The dilation of the pooling window. Defaults to `[1, 1]` (no dilation).
    #[config(default = "[1, 1]")]
    pub dilation: [usize; 2],
}
32
/// Applies a 2D max pooling over input tensors.
///
/// Should be created with [`MaxPool2dConfig`] via [`MaxPool2dConfig::init`].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool2d {
    /// The stride of the pooling window (taken from `MaxPool2dConfig::strides`).
    pub stride: [usize; 2],
    /// The size of the pooling window along each of the two spatial dimensions.
    pub kernel_size: [usize; 2],
    /// The padding configuration; it is resolved into concrete padding values
    /// against the input size on every `forward` call.
    // NOTE(review): wrapped in `Ignored`, presumably so the module derive does not
    // treat it as part of the module record — confirm against `Ignored`'s docs.
    pub padding: Ignored<PaddingConfig2d>,
    /// The dilation of the pooling window.
    pub dilation: [usize; 2],
}
48
49impl ModuleDisplay for MaxPool2d {
50 fn custom_settings(&self) -> Option<DisplaySettings> {
51 DisplaySettings::new()
52 .with_new_line_after_attribute(false)
53 .optional()
54 }
55
56 fn custom_content(&self, content: Content) -> Option<Content> {
57 content
58 .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
59 .add("stride", &alloc::format!("{:?}", &self.stride))
60 .add("padding", &self.padding)
61 .add("dilation", &alloc::format!("{:?}", &self.dilation))
62 .optional()
63 }
64}
65
66impl MaxPool2dConfig {
67 pub fn init(&self) -> MaxPool2d {
69 if self.padding == PaddingConfig2d::Same {
70 check_same_padding_support(&self.kernel_size);
71 }
72 MaxPool2d {
73 stride: self.strides,
74 kernel_size: self.kernel_size,
75 padding: Ignored(self.padding.clone()),
76 dilation: self.dilation,
77 }
78 }
79}
80
81impl MaxPool2d {
82 pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
91 let [_batch_size, _channels_in, height_in, width_in] = input.dims();
92 let padding =
93 self.padding
94 .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
95
96 max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let config = MaxPool2dConfig::new([2, 2]).with_padding(PaddingConfig2d::Same);
        let _ = config.init();
    }

    /// Odd kernel sizes are accepted with `Same` padding: `init` must not panic
    /// and must copy the configuration into the module.
    #[test]
    fn same_with_odd_kernel_is_valid() {
        let config = MaxPool2dConfig::new([3, 5]).with_padding(PaddingConfig2d::Same);
        let layer = config.init();

        assert_eq!(layer.kernel_size, [3, 5]);
        assert_eq!(layer.stride, [3, 5]);
        assert_eq!(layer.dilation, [1, 1]);
    }

    #[test]
    fn display() {
        let config = MaxPool2dConfig::new([3, 3]);

        let layer = config.init();

        assert_eq!(
            alloc::format!("{}", layer),
            "MaxPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, dilation: [1, 1]}"
        );
    }

    /// The documented defaults: `Valid` padding and no dilation.
    #[test]
    fn default_padding_and_dilation() {
        let config = MaxPool2dConfig::new([3, 3]);

        assert_eq!(config.padding, PaddingConfig2d::Valid);
        assert_eq!(config.dilation, [1, 1]);
    }

    /// Explicitly configured strides and dilation must reach the module unchanged
    /// (`strides` on the config maps to `stride` on the module).
    #[test]
    fn init_copies_explicit_strides_and_dilation() {
        let layer = MaxPool2dConfig::new([3, 3])
            .with_strides([1, 2])
            .with_dilation([2, 1])
            .init();

        assert_eq!(layer.kernel_size, [3, 3]);
        assert_eq!(layer.stride, [1, 2]);
        assert_eq!(layer.dilation, [2, 1]);
    }

    #[rstest]
    #[case([2, 2])]
    #[case([1, 2])]
    fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {
        let config = MaxPool2dConfig::new(kernel_size);

        assert_eq!(
            config.strides, kernel_size,
            "Expected strides ({:?}) to match kernel size ({:?}) in default MaxPool2dConfig::new constructor",
            config.strides, config.kernel_size
        );
    }
}