burn_core/nn/pool/
max_pool2d.rs1use crate as burn;
2use crate::nn::conv::checks::check_same_padding_support;
3
4use crate::config::Config;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::module::{Ignored, Module};
7use crate::nn::PaddingConfig2d;
8use crate::tensor::Tensor;
9use crate::tensor::backend::Backend;
10
11use crate::tensor::module::max_pool2d;
12
13#[derive(Debug, Config)]
15pub struct MaxPool2dConfig {
16 pub kernel_size: [usize; 2],
18 #[config(default = "[1, 1]")]
20 pub strides: [usize; 2],
21 #[config(default = "PaddingConfig2d::Valid")]
27 pub padding: PaddingConfig2d,
28 #[config(default = "[1, 1]")]
30 pub dilation: [usize; 2],
31}
32
33#[derive(Module, Clone, Debug)]
37#[module(custom_display)]
38pub struct MaxPool2d {
39 pub stride: [usize; 2],
41 pub kernel_size: [usize; 2],
43 pub padding: Ignored<PaddingConfig2d>,
45 pub dilation: [usize; 2],
47}
48
49impl ModuleDisplay for MaxPool2d {
50 fn custom_settings(&self) -> Option<DisplaySettings> {
51 DisplaySettings::new()
52 .with_new_line_after_attribute(false)
53 .optional()
54 }
55
56 fn custom_content(&self, content: Content) -> Option<Content> {
57 content
58 .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
59 .add("stride", &alloc::format!("{:?}", &self.stride))
60 .add("padding", &self.padding)
61 .add("dilation", &alloc::format!("{:?}", &self.dilation))
62 .optional()
63 }
64}
65
66impl MaxPool2dConfig {
67 pub fn init(&self) -> MaxPool2d {
69 if self.padding == PaddingConfig2d::Same {
70 check_same_padding_support(&self.kernel_size);
71 }
72 MaxPool2d {
73 stride: self.strides,
74 kernel_size: self.kernel_size,
75 padding: Ignored(self.padding.clone()),
76 dilation: self.dilation,
77 }
78 }
79}
80
81impl MaxPool2d {
82 pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
91 let [_batch_size, _channels_in, height_in, width_in] = input.dims();
92 let padding =
93 self.padding
94 .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
95
96 max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Building a layer with `Same` padding and an even kernel must panic.
    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let even_same = MaxPool2dConfig::new([2, 2]).with_padding(PaddingConfig2d::Same);
        let _layer = even_same.init();
    }

    /// The default-configured layer renders all hyper-parameters on a single line.
    #[test]
    fn display() {
        let expected =
            "MaxPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, dilation: [1, 1]}";

        let layer = MaxPool2dConfig::new([3, 3]).init();
        let rendered = alloc::format!("{}", layer);

        assert_eq!(rendered, expected);
    }
}