burn_nn/modules/pool/
avg_pool1d.rs1use burn_core as burn;
2
3use crate::PaddingConfig1d;
4use burn::config::Config;
5use burn::module::Module;
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use burn::tensor::ops::PadMode;
10
11use burn::tensor::module::avg_pool1d;
12
/// Configuration used to build an [`AvgPool1d`] layer via [`AvgPool1dConfig::init`].
#[derive(Config, Debug)]
pub struct AvgPool1dConfig {
    /// Size of the pooling window (forwarded to `avg_pool1d` as `kernel_size`).
    pub kernel_size: usize,
    /// Step between successive pooling windows. Defaults to `kernel_size`.
    #[config(default = "kernel_size")]
    pub stride: usize,
    /// Padding strategy; resolved per input length into a `(left, right)` pair
    /// in [`AvgPool1d::forward`]. Defaults to `PaddingConfig1d::Valid`.
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// Forwarded to `avg_pool1d` as `count_include_pad` (conventionally: whether
    /// padded positions count toward each window's divisor). Defaults to `true`.
    #[config(default = "true")]
    pub count_include_pad: bool,
    /// Forwarded to `avg_pool1d` as `ceil_mode` (conventionally: round the output
    /// length up instead of down). Defaults to `false`.
    #[config(default = "false")]
    pub ceil_mode: bool,
}
34
/// 1D average pooling layer over inputs of rank 3.
///
/// Typically built with [`AvgPool1dConfig::init`]. The layer holds only
/// hyper-parameters (no tensors), and its textual rendering is customized
/// through this type's `ModuleDisplay` impl (hence `custom_display`).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AvgPool1d {
    /// Step between successive pooling windows.
    pub stride: usize,
    /// Size of the pooling window.
    pub kernel_size: usize,
    /// Padding strategy, resolved per input length in `forward`.
    pub padding: PaddingConfig1d,
    /// Forwarded to `avg_pool1d`; conventionally whether padded positions are
    /// counted in each window's divisor.
    pub count_include_pad: bool,
    /// Forwarded to `avg_pool1d`; conventionally selects ceil instead of floor
    /// when computing the output length.
    pub ceil_mode: bool,
}
60
61impl ModuleDisplay for AvgPool1d {
62 fn custom_settings(&self) -> Option<DisplaySettings> {
63 DisplaySettings::new()
64 .with_new_line_after_attribute(false)
65 .optional()
66 }
67
68 fn custom_content(&self, content: Content) -> Option<Content> {
69 content
70 .add("kernel_size", &self.kernel_size)
71 .add("stride", &self.stride)
72 .add_debug_attribute("padding", &self.padding)
73 .add("count_include_pad", &self.count_include_pad)
74 .add("ceil_mode", &self.ceil_mode)
75 .optional()
76 }
77}
78
79impl AvgPool1dConfig {
80 pub fn init(&self) -> AvgPool1d {
82 AvgPool1d {
83 stride: self.stride,
84 kernel_size: self.kernel_size,
85 padding: self.padding.clone(),
86 count_include_pad: self.count_include_pad,
87 ceil_mode: self.ceil_mode,
88 }
89 }
90}
91
92impl AvgPool1d {
93 pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
102 let [_batch_size, _channels, length] = input.dims();
103
104 let (left, right) =
106 self.padding
107 .calculate_padding_1d_pair(length, self.kernel_size, self.stride);
108
109 if left != right {
113 let padded = input.pad((left, right, 0, 0), PadMode::Constant(0.0));
116 avg_pool1d(
118 padded,
119 self.kernel_size,
120 self.stride,
121 0,
122 self.count_include_pad,
123 self.ceil_mode,
124 )
125 } else {
126 avg_pool1d(
128 input,
129 self.kernel_size,
130 self.stride,
131 left,
132 self.count_include_pad,
133 self.ceil_mode,
134 )
135 }
136 }
137}
138
139#[cfg(test)]
140mod tests {
141 use super::*;
142 use crate::TestBackend;
143 use rstest::rstest;
144
145 #[test]
146 fn same_with_even_kernel_uses_asymmetric_padding() {
147 let device = Default::default();
148 let config = AvgPool1dConfig::new(2)
149 .with_stride(1)
150 .with_padding(PaddingConfig1d::Same);
151 let pool = config.init();
152
153 let input = Tensor::<TestBackend, 3>::ones([1, 2, 5], &device);
155 let output = pool.forward(input);
156
157 assert_eq!(output.dims(), [1, 2, 5]);
159 }
160
161 #[test]
162 fn display() {
163 let config = AvgPool1dConfig::new(3);
164 let layer = config.init();
165
166 assert_eq!(
167 alloc::format!("{layer}"),
168 "AvgPool1d {kernel_size: 3, stride: 3, padding: Valid, count_include_pad: true, ceil_mode: false}"
169 );
170 }
171
172 #[rstest]
173 #[case(1)]
174 #[case(2)]
175 fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
176 let config = AvgPool1dConfig::new(kernel_size);
177
178 assert_eq!(
179 config.stride, kernel_size,
180 "Expected stride ({:?}) to match kernel size ({:?}) in default AvgPool1dConfig::new constructor",
181 config.stride, config.kernel_size
182 );
183 }
184
185 #[test]
186 fn asymmetric_padding_forward() {
187 let device = Default::default();
188 let config = AvgPool1dConfig::new(3)
190 .with_stride(1)
191 .with_padding(PaddingConfig1d::Explicit(1, 2));
192 let pool = config.init();
193
194 let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
196 let output = pool.forward(input);
197
198 assert_eq!(output.dims(), [1, 2, 5]);
201 }
202
203 #[test]
204 fn symmetric_explicit_padding_forward() {
205 let device = Default::default();
206 let config = AvgPool1dConfig::new(3)
208 .with_stride(1)
209 .with_padding(PaddingConfig1d::Explicit(2, 2));
210 let pool = config.init();
211
212 let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
214 let output = pool.forward(input);
215
216 assert_eq!(output.dims(), [1, 2, 6]);
219 }
220}