Skip to main content

burn_nn/modules/pool/
avg_pool1d.rs

1use burn_core as burn;
2
3use crate::PaddingConfig1d;
4use burn::config::Config;
5use burn::module::Module;
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use burn::tensor::ops::PadMode;
10
11use burn::tensor::module::avg_pool1d;
12
13/// Configuration to create a [1D avg pooling](AvgPool1d) layer using the [init function](AvgPool1dConfig::init).
14#[derive(Config, Debug)]
15pub struct AvgPool1dConfig {
16    /// The size of the kernel.
17    pub kernel_size: usize,
18    /// The stride.
19    #[config(default = "kernel_size")]
20    pub stride: usize,
21    /// The padding configuration.
22    ///
23    /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
24    /// will automatically use asymmetric padding to preserve input dimensions.
25    #[config(default = "PaddingConfig1d::Valid")]
26    pub padding: PaddingConfig1d,
27    /// If the padding is counted in the denominator when computing the average.
28    #[config(default = "true")]
29    pub count_include_pad: bool,
30    /// If true, use ceiling instead of floor for output size calculation.
31    #[config(default = "false")]
32    pub ceil_mode: bool,
33}
34
35/// Applies a 1D avg pooling over input tensors.
36///
37/// Should be created with [AvgPool1dConfig](AvgPool1dConfig).
38///
39/// # Remarks
40///
41/// The zero-padding values will be included in the calculation
42/// of the average. This means that the zeros are counted as
43/// legitimate values, and they contribute to the denominator
44/// when calculating the average. This is equivalent to
45/// `torch.nn.AvgPool2d` with `count_include_pad=True`.
46#[derive(Module, Clone, Debug)]
47#[module(custom_display)]
48pub struct AvgPool1d {
49    /// The stride.
50    pub stride: usize,
51    /// The size of the kernel.
52    pub kernel_size: usize,
53    /// The padding configuration.
54    pub padding: PaddingConfig1d,
55    /// If the padding is counted in the denominator when computing the average.
56    pub count_include_pad: bool,
57    /// If true, use ceiling instead of floor for output size calculation.
58    pub ceil_mode: bool,
59}
60
61impl ModuleDisplay for AvgPool1d {
62    fn custom_settings(&self) -> Option<DisplaySettings> {
63        DisplaySettings::new()
64            .with_new_line_after_attribute(false)
65            .optional()
66    }
67
68    fn custom_content(&self, content: Content) -> Option<Content> {
69        content
70            .add("kernel_size", &self.kernel_size)
71            .add("stride", &self.stride)
72            .add_debug_attribute("padding", &self.padding)
73            .add("count_include_pad", &self.count_include_pad)
74            .add("ceil_mode", &self.ceil_mode)
75            .optional()
76    }
77}
78
79impl AvgPool1dConfig {
80    /// Initialize a new [avg pool 1d](AvgPool1d) module.
81    pub fn init(&self) -> AvgPool1d {
82        AvgPool1d {
83            stride: self.stride,
84            kernel_size: self.kernel_size,
85            padding: self.padding.clone(),
86            count_include_pad: self.count_include_pad,
87            ceil_mode: self.ceil_mode,
88        }
89    }
90}
91
92impl AvgPool1d {
93    /// Applies the forward pass on the input tensor.
94    ///
95    /// See [avg_pool1d](burn::tensor::module::avg_pool1d) for more information.
96    ///
97    /// # Shapes
98    ///
99    /// - input: `[batch_size, channels, length_in]`
100    /// - output: `[batch_size, channels, length_out]`
101    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
102        let [_batch_size, _channels, length] = input.dims();
103
104        // Calculate padding as pair - handles Same, Valid, and Explicit uniformly
105        let (left, right) =
106            self.padding
107                .calculate_padding_1d_pair(length, self.kernel_size, self.stride);
108
109        // TODO: Move asymmetric padding to functional level via PoolOptions
110        // See: https://github.com/tracel-ai/burn/issues/4362
111        // Handle asymmetric padding by applying explicit pad operation first
112        if left != right {
113            // Burn's pad takes (left, right, top, bottom) for the last two dimensions
114            // For 1D (NCL format), we only pad L (last dim), so top/bottom = 0
115            let padded = input.pad((left, right, 0, 0), PadMode::Constant(0.0));
116            // Use zero padding for the pool operation since we already padded
117            avg_pool1d(
118                padded,
119                self.kernel_size,
120                self.stride,
121                0,
122                self.count_include_pad,
123                self.ceil_mode,
124            )
125        } else {
126            // Symmetric padding
127            avg_pool1d(
128                input,
129                self.kernel_size,
130                self.stride,
131                left,
132                self.count_include_pad,
133                self.ceil_mode,
134            )
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use rstest::rstest;

    /// Builds the layer from `config`, runs it on an all-ones tensor of the given
    /// `[batch, channels, length]` shape and returns the shape of the output.
    fn output_dims(config: AvgPool1dConfig, shape: [usize; 3]) -> [usize; 3] {
        let device = Default::default();
        let input = Tensor::<TestBackend, 3>::ones(shape, &device);
        config.init().forward(input).dims()
    }

    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let config = AvgPool1dConfig::new(2)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Same);

        // Input: [batch=1, channels=2, length=5].
        // Same padding should preserve spatial dimensions.
        assert_eq!(output_dims(config, [1, 2, 5]), [1, 2, 5]);
    }

    #[test]
    fn display() {
        let layer = AvgPool1dConfig::new(3).init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "AvgPool1d {kernel_size: 3, stride: 3, padding: Valid, count_include_pad: true, ceil_mode: false}"
        );
    }

    #[rstest]
    #[case(1)]
    #[case(2)]
    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
        let config = AvgPool1dConfig::new(kernel_size);

        assert_eq!(
            config.stride, kernel_size,
            "Expected stride ({:?}) to match kernel size ({:?}) in default AvgPool1dConfig::new constructor",
            config.stride, config.kernel_size
        );
    }

    #[test]
    fn asymmetric_padding_forward() {
        // Avg pool with asymmetric padding: left=1, right=2.
        let config = AvgPool1dConfig::new(3)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Explicit(1, 2));

        // Input: [batch=1, channels=2, length=4].
        // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7.
        // Output length = (7 - 3) / 1 + 1 = 5.
        assert_eq!(output_dims(config, [1, 2, 4]), [1, 2, 5]);
    }

    #[test]
    fn symmetric_explicit_padding_forward() {
        // Avg pool with symmetric explicit padding: left=2, right=2.
        let config = AvgPool1dConfig::new(3)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Explicit(2, 2));

        // Input: [batch=1, channels=2, length=4].
        // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8.
        // Output length = (8 - 3) / 1 + 1 = 6.
        assert_eq!(output_dims(config, [1, 2, 4]), [1, 2, 6]);
    }
}