//! 1D max pooling (`MaxPool1d`) — `burn_nn/modules/pool/max_pool1d.rs`.
1use burn_core as burn;
2
3use crate::PaddingConfig1d;
4use burn::config::Config;
5use burn::module::Module;
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use burn::tensor::ops::PadMode;
10
11use burn::tensor::module::max_pool1d;
12
/// Configuration to create a [1D max pooling](MaxPool1d) layer using the [init function](MaxPool1dConfig::init).
#[derive(Config, Debug)]
pub struct MaxPool1dConfig {
    /// The size of the pooling window along the length dimension.
    pub kernel_size: usize,
    /// The stride between consecutive pooling windows.
    ///
    /// Defaults to `kernel_size` (non-overlapping windows).
    #[config(default = "kernel_size")]
    pub stride: usize,
    /// The padding configuration.
    ///
    /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes
    /// will automatically use asymmetric padding to preserve input dimensions.
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// The spacing between kernel elements (1 = dense kernel).
    #[config(default = "1")]
    pub dilation: usize,
    /// If true, use ceiling instead of floor for output size calculation.
    #[config(default = "false")]
    pub ceil_mode: bool,
}
34
/// Applies a 1D max pooling over input tensors.
///
/// Should be created with [MaxPool1dConfig](MaxPool1dConfig).
///
/// All fields mirror [MaxPool1dConfig]; see [MaxPool1d::forward] for the
/// expected input/output shapes.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool1d {
    /// The stride between consecutive pooling windows.
    pub stride: usize,
    /// The size of the pooling window.
    pub kernel_size: usize,
    /// The padding configuration.
    pub padding: PaddingConfig1d,
    /// The spacing between kernel elements.
    pub dilation: usize,
    /// If true, use ceiling instead of floor for output size calculation.
    pub ceil_mode: bool,
}
52
53impl ModuleDisplay for MaxPool1d {
54    fn custom_settings(&self) -> Option<DisplaySettings> {
55        DisplaySettings::new()
56            .with_new_line_after_attribute(false)
57            .optional()
58    }
59
60    fn custom_content(&self, content: Content) -> Option<Content> {
61        content
62            .add("kernel_size", &self.kernel_size)
63            .add("stride", &self.stride)
64            .add_debug_attribute("padding", &self.padding)
65            .add("dilation", &self.dilation)
66            .add("ceil_mode", &self.ceil_mode)
67            .optional()
68    }
69}
70
71impl MaxPool1dConfig {
72    /// Initialize a new [max pool 1d](MaxPool1d) module.
73    pub fn init(&self) -> MaxPool1d {
74        MaxPool1d {
75            stride: self.stride,
76            kernel_size: self.kernel_size,
77            padding: self.padding.clone(),
78            dilation: self.dilation,
79            ceil_mode: self.ceil_mode,
80        }
81    }
82}
83
84impl MaxPool1d {
85    /// Applies the forward pass on the input tensor.
86    ///
87    /// See [max_pool1d](burn::tensor::module::max_pool1d) for more information.
88    ///
89    /// # Shapes
90    ///
91    /// - input: `[batch_size, channels, length_in]`
92    /// - output: `[batch_size, channels, length_out]`
93    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
94        let [_batch_size, _channels, length] = input.dims();
95
96        // Calculate padding as pair - handles Same, Valid, and Explicit uniformly
97        let (left, right) =
98            self.padding
99                .calculate_padding_1d_pair(length, self.kernel_size, self.stride);
100
101        // TODO: Move asymmetric padding to functional level via PoolOptions
102        // See: https://github.com/tracel-ai/burn/issues/4362
103        // Handle asymmetric padding by applying explicit pad operation first
104        if left != right {
105            // For 1D (NCL format), pad the length dimension with (left, right)
106            // and no padding for channel dimension (top=0, bottom=0)
107            // Use -inf for max pooling so padded values don't affect the max
108            let padded = input.pad((left, right, 0, 0), PadMode::Constant(f32::NEG_INFINITY));
109            // Use zero padding for the pool operation since we already padded
110            max_pool1d(
111                padded,
112                self.kernel_size,
113                self.stride,
114                0,
115                self.dilation,
116                self.ceil_mode,
117            )
118        } else {
119            // Symmetric padding
120            max_pool1d(
121                input,
122                self.kernel_size,
123                self.stride,
124                left,
125                self.dilation,
126                self.ceil_mode,
127            )
128        }
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use rstest::rstest;

    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        // Even kernel with `Same` padding forces the asymmetric code path.
        let pool = MaxPool1dConfig::new(2)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Same)
            .init();

        // Input: [batch=1, channels=2, length=5]
        let output = pool.forward(Tensor::<TestBackend, 3>::ones([1, 2, 5], &device));

        // `Same` padding must preserve the spatial dimension.
        assert_eq!(output.dims(), [1, 2, 5]);
    }

    #[test]
    fn display() {
        let layer = MaxPool1dConfig::new(3).init();

        assert_eq!(
            alloc::format!("{layer}"),
            "MaxPool1d {kernel_size: 3, stride: 3, padding: Valid, dilation: 1, ceil_mode: false}"
        );
    }

    #[rstest]
    #[case(1)]
    #[case(2)]
    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
        let config = MaxPool1dConfig::new(kernel_size);

        assert_eq!(
            config.stride, kernel_size,
            "Expected stride ({:?}) to match kernel size ({:?}) in default MaxPool1dConfig::new constructor",
            config.stride, config.kernel_size
        );
    }

    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();
        // Asymmetric explicit padding: left=1, right=2.
        let pool = MaxPool1dConfig::new(3)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Explicit(1, 2))
            .init();

        // Input: [batch=1, channels=2, length=4]
        let output = pool.forward(Tensor::<TestBackend, 3>::ones([1, 2, 4], &device));

        // Padded length = 4 + 1 + 2 = 7, so output length = (7 - 3) / 1 + 1 = 5.
        assert_eq!(output.dims(), [1, 2, 5]);
    }

    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();
        // Symmetric explicit padding: left=2, right=2.
        let pool = MaxPool1dConfig::new(3)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Explicit(2, 2))
            .init();

        // Input: [batch=1, channels=2, length=4]
        let output = pool.forward(Tensor::<TestBackend, 3>::ones([1, 2, 4], &device));

        // Padded length = 4 + 2 + 2 = 8, so output length = (8 - 3) / 1 + 1 = 6.
        assert_eq!(output.dims(), [1, 2, 6]);
    }
}