1use burn_core as burn;
2
3use crate::PaddingConfig2d;
4use burn::config::Config;
5use burn::module::Module;
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use burn::tensor::ops::PadMode;
10
11use burn::tensor::module::avg_pool2d;
12
/// Configuration to create a [2D average pooling](AvgPool2d) layer using the
/// [init function](AvgPool2dConfig::init).
#[derive(Config, Debug)]
pub struct AvgPool2dConfig {
    /// The size of the pooling window, one entry per spatial dimension
    /// (index 0 pairs with the input height, index 1 with the width).
    pub kernel_size: [usize; 2],
    /// The stride of the pooling window per spatial dimension.
    /// Defaults to `kernel_size`, i.e. non-overlapping windows.
    #[config(default = "kernel_size")]
    pub strides: [usize; 2],
    /// The padding configuration. Defaults to `PaddingConfig2d::Valid`.
    #[config(default = "PaddingConfig2d::Valid")]
    pub padding: PaddingConfig2d,
    /// Whether padded positions are counted in the averaging divisor.
    /// Defaults to `true`.
    #[config(default = "true")]
    pub count_include_pad: bool,
    /// Forwarded to `avg_pool2d` as `ceil_mode`: when `true`, the output size is
    /// rounded up instead of down. Defaults to `false`.
    #[config(default = "false")]
    pub ceil_mode: bool,
}
34
/// Applies a 2D average pooling over 4D input tensors.
///
/// Build it from an [`AvgPool2dConfig`] with [`AvgPool2dConfig::init`].
/// The module has no learnable parameters; all fields are pooling hyper-parameters.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AvgPool2d {
    /// The stride of the pooling window per spatial dimension.
    pub stride: [usize; 2],
    /// The size of the pooling window per spatial dimension.
    pub kernel_size: [usize; 2],
    /// The padding configuration, resolved to per-side amounts at each forward pass.
    pub padding: PaddingConfig2d,
    /// Whether padded positions are counted in the averaging divisor.
    pub count_include_pad: bool,
    /// Whether the output size is rounded up (ceil) instead of down (floor).
    pub ceil_mode: bool,
}
60
61impl ModuleDisplay for AvgPool2d {
62 fn custom_settings(&self) -> Option<DisplaySettings> {
63 DisplaySettings::new()
64 .with_new_line_after_attribute(false)
65 .optional()
66 }
67
68 fn custom_content(&self, content: Content) -> Option<Content> {
69 content
70 .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
71 .add("stride", &alloc::format!("{:?}", &self.stride))
72 .add_debug_attribute("padding", &self.padding)
73 .add("count_include_pad", &self.count_include_pad)
74 .add("ceil_mode", &self.ceil_mode)
75 .optional()
76 }
77}
78
79impl AvgPool2dConfig {
80 pub fn init(&self) -> AvgPool2d {
82 AvgPool2d {
83 stride: self.strides,
84 kernel_size: self.kernel_size,
85 padding: self.padding.clone(),
86 count_include_pad: self.count_include_pad,
87 ceil_mode: self.ceil_mode,
88 }
89 }
90}
91
92impl AvgPool2d {
93 pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
102 let [_batch_size, _channels_in, height_in, width_in] = input.dims();
103
104 let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(
106 height_in,
107 width_in,
108 &self.kernel_size,
109 &self.stride,
110 );
111
112 if top != bottom || left != right {
116 let padded = input.pad((left, right, top, bottom), PadMode::Constant(0.0));
118 avg_pool2d(
120 padded,
121 self.kernel_size,
122 self.stride,
123 [0, 0],
124 self.count_include_pad,
125 self.ceil_mode,
126 )
127 } else {
128 avg_pool2d(
130 input,
131 self.kernel_size,
132 self.stride,
133 [top, left],
134 self.count_include_pad,
135 self.ceil_mode,
136 )
137 }
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use rstest::rstest;

    /// Builds a stride-1 pool with the given kernel and padding, runs it on a tensor of
    /// ones of the given `shape`, and returns the output dimensions.
    fn stride_one_output_dims(
        kernel_size: [usize; 2],
        padding: PaddingConfig2d,
        shape: [usize; 4],
    ) -> [usize; 4] {
        let device = Default::default();
        let pool = AvgPool2dConfig::new(kernel_size)
            .with_strides([1, 1])
            .with_padding(padding)
            .init();

        let ones = Tensor::<TestBackend, 4>::ones(shape, &device);
        pool.forward(ones).dims()
    }

    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let dims = stride_one_output_dims([2, 2], PaddingConfig2d::Same, [1, 2, 5, 5]);

        assert_eq!(dims, [1, 2, 5, 5]);
    }

    #[test]
    fn display() {
        let layer = AvgPool2dConfig::new([3, 3]).init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "AvgPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, count_include_pad: true, ceil_mode: false}"
        );
    }

    #[rstest]
    #[case([2, 2])]
    #[case([1, 2])]
    fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {
        let cfg = AvgPool2dConfig::new(kernel_size);

        assert_eq!(
            cfg.strides, kernel_size,
            "Expected strides ({:?}) to match kernel size ({:?}) in default AvgPool2dConfig::new constructor",
            cfg.strides, cfg.kernel_size
        );
    }

    #[test]
    fn asymmetric_padding_forward() {
        let dims = stride_one_output_dims(
            [3, 3],
            PaddingConfig2d::Explicit(1, 2, 3, 4),
            [1, 2, 4, 5],
        );

        assert_eq!(dims, [1, 2, 6, 9]);
    }

    #[test]
    fn symmetric_explicit_padding_forward() {
        let dims = stride_one_output_dims(
            [3, 3],
            PaddingConfig2d::Explicit(2, 2, 2, 2),
            [1, 2, 4, 5],
        );

        assert_eq!(dims, [1, 2, 6, 7]);
    }
}