//! 1D max pooling layer (`burn_core/nn/pool/max_pool1d.rs`).

1use crate as burn;
2use crate::nn::conv::checks::check_same_padding_support;
3
4use crate::config::Config;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::module::{Ignored, Module};
7use crate::nn::PaddingConfig1d;
8use crate::tensor::Tensor;
9use crate::tensor::backend::Backend;
10
11use crate::tensor::module::max_pool1d;
12
/// Configuration to create a [1D max pooling](MaxPool1d) layer using the [init function](MaxPool1dConfig::init).
#[derive(Config, Debug)]
pub struct MaxPool1dConfig {
    /// The size of the kernel (pooling window), in elements along the length dimension.
    pub kernel_size: usize,
    /// The stride of the pooling window. Defaults to `1`.
    #[config(default = "1")]
    pub stride: usize,
    /// The padding configuration. Defaults to `Valid` (no padding).
    ///
    /// ### Warning
    /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel
    /// size is not supported as it will not produce the same output size.
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// The dilation (spacing between kernel elements). Defaults to `1` (dense kernel).
    #[config(default = "1")]
    pub dilation: usize,
}
32
/// Applies a 1D max pooling over input tensors.
///
/// Should be created with [MaxPool1dConfig](MaxPool1dConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool1d {
    /// The stride of the pooling window.
    pub stride: usize,
    /// The size of the kernel (pooling window).
    pub kernel_size: usize,
    /// The padding configuration. Wrapped in [Ignored] so it is excluded
    /// from the module's parameter handling (it holds no learnable state).
    pub padding: Ignored<PaddingConfig1d>,
    /// The dilation (spacing between kernel elements).
    pub dilation: usize,
}
48
49impl ModuleDisplay for MaxPool1d {
50    fn custom_settings(&self) -> Option<DisplaySettings> {
51        DisplaySettings::new()
52            .with_new_line_after_attribute(false)
53            .optional()
54    }
55
56    fn custom_content(&self, content: Content) -> Option<Content> {
57        content
58            .add("kernel_size", &self.kernel_size)
59            .add("stride", &self.stride)
60            .add("padding", &self.padding)
61            .add("dilation", &self.dilation)
62            .optional()
63    }
64}
65
66impl MaxPool1dConfig {
67    /// Initialize a new [max pool 1d](MaxPool1d) module.
68    pub fn init(&self) -> MaxPool1d {
69        if self.padding == PaddingConfig1d::Same {
70            check_same_padding_support(&[self.kernel_size]);
71        }
72        MaxPool1d {
73            stride: self.stride,
74            kernel_size: self.kernel_size,
75            padding: Ignored(self.padding.clone()),
76            dilation: self.dilation,
77        }
78    }
79}
80
81impl MaxPool1d {
82    /// Applies the forward pass on the input tensor.
83    ///
84    /// See [max_pool1d](crate::tensor::module::max_pool1d) for more information.
85    ///
86    /// # Shapes
87    ///
88    /// - input: `[batch_size, channels, length_in]`
89    /// - output: `[batch_size, channels, length_out]`
90    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
91        let [_batch_size, _channels, length] = input.dims();
92        let padding = self
93            .padding
94            .calculate_padding_1d(length, self.kernel_size, self.stride);
95
96        max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation)
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// `Same` padding cannot be symmetric with an even kernel, so `init` must panic.
    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let cfg = MaxPool1dConfig::new(2).with_padding(PaddingConfig1d::Same);
        let _ = cfg.init();
    }

    /// The custom display renders every hyperparameter on one line.
    #[test]
    fn display() {
        let pool = MaxPool1dConfig::new(3).init();

        let rendered = alloc::format!("{}", pool);
        assert_eq!(
            rendered,
            "MaxPool1d {kernel_size: 3, stride: 1, padding: Valid, dilation: 1}"
        );
    }
}