// burn_nn/modules/pool/max_pool1d.rs
use crate::conv::checks::check_same_padding_support;
use burn_core as burn;

use crate::PaddingConfig1d;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::module::{Ignored, Module};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

use burn::tensor::module::max_pool1d;
12
13#[derive(Config, Debug)]
15pub struct MaxPool1dConfig {
16 pub kernel_size: usize,
18 #[config(default = "kernel_size")]
20 pub stride: usize,
21 #[config(default = "PaddingConfig1d::Valid")]
27 pub padding: PaddingConfig1d,
28 #[config(default = "1")]
30 pub dilation: usize,
31 #[config(default = "false")]
33 pub ceil_mode: bool,
34}
35
36#[derive(Module, Clone, Debug)]
40#[module(custom_display)]
41pub struct MaxPool1d {
42 pub stride: usize,
44 pub kernel_size: usize,
46 pub padding: Ignored<PaddingConfig1d>,
48 pub dilation: usize,
50 pub ceil_mode: bool,
52}
53
54impl ModuleDisplay for MaxPool1d {
55 fn custom_settings(&self) -> Option<DisplaySettings> {
56 DisplaySettings::new()
57 .with_new_line_after_attribute(false)
58 .optional()
59 }
60
61 fn custom_content(&self, content: Content) -> Option<Content> {
62 content
63 .add("kernel_size", &self.kernel_size)
64 .add("stride", &self.stride)
65 .add("padding", &self.padding)
66 .add("dilation", &self.dilation)
67 .add("ceil_mode", &self.ceil_mode)
68 .optional()
69 }
70}
71
72impl MaxPool1dConfig {
73 pub fn init(&self) -> MaxPool1d {
75 if self.padding == PaddingConfig1d::Same {
76 check_same_padding_support(&[self.kernel_size]);
77 }
78 MaxPool1d {
79 stride: self.stride,
80 kernel_size: self.kernel_size,
81 padding: Ignored(self.padding.clone()),
82 dilation: self.dilation,
83 ceil_mode: self.ceil_mode,
84 }
85 }
86}
87
88impl MaxPool1d {
89 pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
98 let [_batch_size, _channels, length] = input.dims();
99 let padding = self
100 .padding
101 .calculate_padding_1d(length, self.kernel_size, self.stride);
102
103 max_pool1d(
104 input,
105 self.kernel_size,
106 self.stride,
107 padding,
108 self.dilation,
109 self.ceil_mode,
110 )
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        // An even window cannot be padded symmetrically to keep the length unchanged.
        let _layer = MaxPool1dConfig::new(2)
            .with_padding(PaddingConfig1d::Same)
            .init();
    }

    #[test]
    fn display() {
        let expected =
            "MaxPool1d {kernel_size: 3, stride: 3, padding: Valid, dilation: 1, ceil_mode: false}";

        let rendered = alloc::format!("{}", MaxPool1dConfig::new(3).init());

        assert_eq!(rendered, expected);
    }

    #[rstest]
    #[case(1)]
    #[case(2)]
    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
        let MaxPool1dConfig {
            stride,
            kernel_size: configured_kernel,
            ..
        } = MaxPool1dConfig::new(kernel_size);

        assert_eq!(
            stride, kernel_size,
            "Expected stride ({:?}) to match kernel size ({:?}) in default MaxPool1dConfig::new constructor",
            stride, configured_kernel
        );
    }
}