burn_nn/modules/pool/
max_pool1d.rs1use burn_core as burn;
2
3use crate::PaddingConfig1d;
4use burn::config::Config;
5use burn::module::Module;
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use burn::tensor::ops::PadMode;
10
11use burn::tensor::module::max_pool1d;
12
/// Configuration to create a [1D max pooling](MaxPool1d) layer, built via
/// [`MaxPool1dConfig::init`].
#[derive(Config, Debug)]
pub struct MaxPool1dConfig {
    /// Size of the pooling window.
    pub kernel_size: usize,
    /// Stride of the pooling window. Defaults to `kernel_size`.
    #[config(default = "kernel_size")]
    pub stride: usize,
    /// Padding applied to the input before pooling. Defaults to no padding (`Valid`).
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// Spacing between elements of the pooling window. Defaults to 1.
    #[config(default = "1")]
    pub dilation: usize,
    /// Forwarded to `max_pool1d`; presumably uses ceiling instead of floor when
    /// computing the output length — confirm against the backend pooling docs.
    /// Defaults to `false`.
    #[config(default = "false")]
    pub ceil_mode: bool,
}
34
/// Applies 1D max pooling over input tensors.
///
/// Should be created with [MaxPool1dConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool1d {
    /// Stride of the pooling window.
    pub stride: usize,
    /// Size of the pooling window.
    pub kernel_size: usize,
    /// Padding applied to the input before pooling.
    pub padding: PaddingConfig1d,
    /// Spacing between elements of the pooling window.
    pub dilation: usize,
    /// Forwarded to `max_pool1d`; see [MaxPool1dConfig::ceil_mode].
    pub ceil_mode: bool,
}
52
53impl ModuleDisplay for MaxPool1d {
54 fn custom_settings(&self) -> Option<DisplaySettings> {
55 DisplaySettings::new()
56 .with_new_line_after_attribute(false)
57 .optional()
58 }
59
60 fn custom_content(&self, content: Content) -> Option<Content> {
61 content
62 .add("kernel_size", &self.kernel_size)
63 .add("stride", &self.stride)
64 .add_debug_attribute("padding", &self.padding)
65 .add("dilation", &self.dilation)
66 .add("ceil_mode", &self.ceil_mode)
67 .optional()
68 }
69}
70
71impl MaxPool1dConfig {
72 pub fn init(&self) -> MaxPool1d {
74 MaxPool1d {
75 stride: self.stride,
76 kernel_size: self.kernel_size,
77 padding: self.padding.clone(),
78 dilation: self.dilation,
79 ceil_mode: self.ceil_mode,
80 }
81 }
82}
83
84impl MaxPool1d {
85 pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
94 let [_batch_size, _channels, length] = input.dims();
95
96 let (left, right) =
98 self.padding
99 .calculate_padding_1d_pair(length, self.kernel_size, self.stride);
100
101 if left != right {
105 let padded = input.pad((left, right, 0, 0), PadMode::Constant(f32::NEG_INFINITY));
109 max_pool1d(
111 padded,
112 self.kernel_size,
113 self.stride,
114 0,
115 self.dilation,
116 self.ceil_mode,
117 )
118 } else {
119 max_pool1d(
121 input,
122 self.kernel_size,
123 self.stride,
124 left,
125 self.dilation,
126 self.ceil_mode,
127 )
128 }
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use rstest::rstest;

    // "Same" padding with an even kernel needs an asymmetric (left, right) pair;
    // the output length must still equal the input length (5).
    #[test]
    fn same_with_even_kernel_uses_asymmetric_padding() {
        let device = Default::default();
        let config = MaxPool1dConfig::new(2)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Same);
        let pool = config.init();

        let input = Tensor::<TestBackend, 3>::ones([1, 2, 5], &device);
        let output = pool.forward(input);

        assert_eq!(output.dims(), [1, 2, 5]);
    }

    // Pins the single-line display format and the attribute order produced by
    // the `ModuleDisplay` implementation.
    #[test]
    fn display() {
        let config = MaxPool1dConfig::new(3);

        let layer = config.init();

        assert_eq!(
            alloc::format!("{layer}"),
            "MaxPool1d {kernel_size: 3, stride: 3, padding: Valid, dilation: 1, ceil_mode: false}"
        );
    }

    // The `#[config(default = "kernel_size")]` attribute makes the default
    // stride track the kernel size for any kernel size.
    #[rstest]
    #[case(1)]
    #[case(2)]
    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
        let config = MaxPool1dConfig::new(kernel_size);

        assert_eq!(
            config.stride, kernel_size,
            "Expected stride ({:?}) to match kernel size ({:?}) in default MaxPool1dConfig::new constructor",
            config.stride, config.kernel_size
        );
    }

    // Explicit (1, 2) padding exercises the asymmetric branch of `forward`:
    // length 4 + 1 + 2 = 7, kernel 3, stride 1 -> output length 5.
    #[test]
    fn asymmetric_padding_forward() {
        let device = Default::default();
        let config = MaxPool1dConfig::new(3)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Explicit(1, 2));
        let pool = config.init();

        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
        let output = pool.forward(input);

        assert_eq!(output.dims(), [1, 2, 5]);
    }

    // Explicit (2, 2) padding takes the symmetric branch of `forward`:
    // length 4 + 2 + 2 = 8, kernel 3, stride 1 -> output length 6.
    #[test]
    fn symmetric_explicit_padding_forward() {
        let device = Default::default();
        let config = MaxPool1dConfig::new(3)
            .with_stride(1)
            .with_padding(PaddingConfig1d::Explicit(2, 2));
        let pool = config.init();

        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);
        let output = pool.forward(input);

        assert_eq!(output.dims(), [1, 2, 6]);
    }
}