burn_core/nn/pool/
max_pool1d.rs1use crate as burn;
2use crate::nn::conv::checks::check_same_padding_support;
3
4use crate::config::Config;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::module::{Ignored, Module};
7use crate::nn::PaddingConfig1d;
8use crate::tensor::Tensor;
9use crate::tensor::backend::Backend;
10
11use crate::tensor::module::max_pool1d;
12
13#[derive(Config, Debug)]
15pub struct MaxPool1dConfig {
16 pub kernel_size: usize,
18 #[config(default = "kernel_size")]
20 pub stride: usize,
21 #[config(default = "PaddingConfig1d::Valid")]
27 pub padding: PaddingConfig1d,
28 #[config(default = "1")]
30 pub dilation: usize,
31}
32
33#[derive(Module, Clone, Debug)]
37#[module(custom_display)]
38pub struct MaxPool1d {
39 pub stride: usize,
41 pub kernel_size: usize,
43 pub padding: Ignored<PaddingConfig1d>,
45 pub dilation: usize,
47}
48
49impl ModuleDisplay for MaxPool1d {
50 fn custom_settings(&self) -> Option<DisplaySettings> {
51 DisplaySettings::new()
52 .with_new_line_after_attribute(false)
53 .optional()
54 }
55
56 fn custom_content(&self, content: Content) -> Option<Content> {
57 content
58 .add("kernel_size", &self.kernel_size)
59 .add("stride", &self.stride)
60 .add("padding", &self.padding)
61 .add("dilation", &self.dilation)
62 .optional()
63 }
64}
65
66impl MaxPool1dConfig {
67 pub fn init(&self) -> MaxPool1d {
69 if self.padding == PaddingConfig1d::Same {
70 check_same_padding_support(&[self.kernel_size]);
71 }
72 MaxPool1d {
73 stride: self.stride,
74 kernel_size: self.kernel_size,
75 padding: Ignored(self.padding.clone()),
76 dilation: self.dilation,
77 }
78 }
79}
80
81impl MaxPool1d {
82 pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
91 let [_batch_size, _channels, length] = input.dims();
92 let padding = self
93 .padding
94 .calculate_padding_1d(length, self.kernel_size, self.stride);
95
96 max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// `Same` padding with an even kernel size must make `init` panic.
    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        let _ = MaxPool1dConfig::new(2)
            .with_padding(PaddingConfig1d::Same)
            .init();
    }

    /// The layer's display output lists every attribute on a single line.
    #[test]
    fn display() {
        let layer = MaxPool1dConfig::new(3).init();

        let rendered = alloc::format!("{layer}");
        let expected = "MaxPool1d {kernel_size: 3, stride: 3, padding: Valid, dilation: 1}";

        assert_eq!(rendered, expected);
    }

    /// Without an explicit stride, the config falls back to the kernel size.
    #[rstest]
    #[case(1)]
    #[case(2)]
    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
        let config = MaxPool1dConfig::new(kernel_size);
        let (actual_stride, configured_kernel) = (config.stride, config.kernel_size);

        assert_eq!(
            actual_stride, kernel_size,
            "Expected stride ({:?}) to match kernel size ({:?}) in default MaxPool1dConfig::new constructor",
            actual_stride, configured_kernel
        );
    }
}