burn_nn/modules/pool/
avg_pool1d.rs

1use crate::conv::checks::check_same_padding_support;
2use burn_core as burn;
3
4use crate::PaddingConfig1d;
5use burn::config::Config;
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::module::{Ignored, Module};
8use burn::tensor::Tensor;
9use burn::tensor::backend::Backend;
10
11use burn::tensor::module::avg_pool1d;
12
/// Configuration to create a [1D avg pooling](AvgPool1d) layer using the [init function](AvgPool1dConfig::init).
#[derive(Config, Debug)]
pub struct AvgPool1dConfig {
    /// The size of the kernel.
    pub kernel_size: usize,
    /// The stride of the pooling window. Defaults to `kernel_size` (non-overlapping windows).
    #[config(default = "kernel_size")]
    pub stride: usize,
    /// The padding configuration. Defaults to `Valid` (no padding).
    ///
    /// ### Warning
    /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
    /// size is not supported as it will not produce the same output size.
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// If the padding is counted in the denominator when computing the average.
    /// Defaults to `true`.
    #[config(default = "true")]
    pub count_include_pad: bool,
    /// If true, use ceiling instead of floor for output size calculation.
    /// Defaults to `false`.
    #[config(default = "false")]
    pub ceil_mode: bool,
}
35
/// Applies a 1D avg pooling over input tensors.
///
/// Should be created with [AvgPool1dConfig](AvgPool1dConfig).
///
/// # Remarks
///
/// When `count_include_pad` is `true` (the default), the zero-padding
/// values are included in the calculation of the average. This means
/// that the zeros are counted as legitimate values, and they contribute
/// to the denominator when calculating the average. This is equivalent
/// to `torch.nn.AvgPool1d` with `count_include_pad=True`.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct AvgPool1d {
    /// The stride.
    pub stride: usize,
    /// The size of the kernel.
    pub kernel_size: usize,
    /// The padding configuration.
    pub padding: Ignored<PaddingConfig1d>,
    /// If the padding is counted in the denominator when computing the average.
    pub count_include_pad: bool,
    /// If true, use ceiling instead of floor for output size calculation.
    pub ceil_mode: bool,
}
61
impl ModuleDisplay for AvgPool1d {
    // Render all attributes on a single line (e.g. `AvgPool1d {kernel_size: 3, ...}`).
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    // List every configuration attribute in display output; the order here
    // determines the order they appear in the formatted string.
    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("kernel_size", &self.kernel_size)
            .add("stride", &self.stride)
            .add("padding", &self.padding)
            .add("count_include_pad", &self.count_include_pad)
            .add("ceil_mode", &self.ceil_mode)
            .optional()
    }
}
79
80impl AvgPool1dConfig {
81    /// Initialize a new [avg pool 1d](AvgPool1d) module.
82    pub fn init(&self) -> AvgPool1d {
83        if self.padding == PaddingConfig1d::Same {
84            check_same_padding_support(&[self.kernel_size]);
85        }
86        AvgPool1d {
87            stride: self.stride,
88            kernel_size: self.kernel_size,
89            padding: Ignored(self.padding.clone()),
90            count_include_pad: self.count_include_pad,
91            ceil_mode: self.ceil_mode,
92        }
93    }
94}
95
96impl AvgPool1d {
97    /// Applies the forward pass on the input tensor.
98    ///
99    /// See [avg_pool1d](burn::tensor::module::avg_pool1d) for more information.
100    ///
101    /// # Shapes
102    ///
103    /// - input: `[batch_size, channels, length_in]`
104    /// - output: `[batch_size, channels, length_out]`
105    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
106        let [_batch_size, _channels, length] = input.dims();
107        let padding = self
108            .padding
109            .calculate_padding_1d(length, self.kernel_size, self.stride);
110
111        avg_pool1d(
112            input,
113            self.kernel_size,
114            self.stride,
115            padding,
116            self.count_include_pad,
117            self.ceil_mode,
118        )
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    // An even kernel with `Same` padding cannot be padded symmetrically,
    // so `init` must reject it.
    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let _ = AvgPool1dConfig::new(2)
            .with_padding(PaddingConfig1d::Same)
            .init();
    }

    // The module display output should list every attribute on one line.
    #[test]
    fn display() {
        let layer = AvgPool1dConfig::new(3).init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "AvgPool1d {kernel_size: 3, stride: 3, padding: Valid, count_include_pad: true, ceil_mode: false}"
        );
    }

    // By default the stride should equal the kernel size (non-overlapping windows).
    #[rstest]
    #[case(1)]
    #[case(2)]
    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
        let config = AvgPool1dConfig::new(kernel_size);

        assert_eq!(
            config.stride, kernel_size,
            "Expected stride ({:?}) to match kernel size ({:?}) in default AvgPool1dConfig::new constructor",
            config.stride, config.kernel_size
        );
    }
}
157}