burn_core/nn/pool/
max_pool1d.rs1use crate as burn;
2use crate::nn::conv::checks::check_same_padding_support;
3
4use crate::config::Config;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::module::{Ignored, Module};
7use crate::nn::PaddingConfig1d;
8use crate::tensor::Tensor;
9use crate::tensor::backend::Backend;
10
11use crate::tensor::module::max_pool1d;
12
/// Configuration to create a [1D max pooling](MaxPool1d) layer using the
/// [init function](MaxPool1dConfig::init).
#[derive(Config, Debug)]
pub struct MaxPool1dConfig {
    /// The size of the pooling kernel. This is the only field without a default,
    /// so it is the argument of `MaxPool1dConfig::new`.
    pub kernel_size: usize,
    /// The stride forwarded to `max_pool1d`. Default: 1.
    #[config(default = "1")]
    pub stride: usize,
    /// The padding configuration. Default: `PaddingConfig1d::Valid`.
    ///
    /// `PaddingConfig1d::Same` is only accepted with an odd `kernel_size`;
    /// `init` panics with "Same padding with an even kernel size is not supported"
    /// otherwise.
    #[config(default = "PaddingConfig1d::Valid")]
    pub padding: PaddingConfig1d,
    /// The dilation forwarded to `max_pool1d`. Default: 1.
    #[config(default = "1")]
    pub dilation: usize,
}
32
/// Applies 1D max pooling over a rank-3 input tensor
/// (`[batch_size, channels, length]`, as destructured in `forward`).
///
/// Should be created with [MaxPool1dConfig](MaxPool1dConfig).
#[derive(Module, Clone, Debug)]
// Display output is produced by this type's `ModuleDisplay` implementation.
#[module(custom_display)]
pub struct MaxPool1d {
    /// The stride forwarded to `max_pool1d`.
    pub stride: usize,
    /// The size of the pooling kernel.
    pub kernel_size: usize,
    /// The padding configuration. It is wrapped in `Ignored`, presumably so the
    /// `Module` derive does not treat it as module state (TODO confirm).
    pub padding: Ignored<PaddingConfig1d>,
    /// The dilation forwarded to `max_pool1d`.
    pub dilation: usize,
}
48
49impl ModuleDisplay for MaxPool1d {
50 fn custom_settings(&self) -> Option<DisplaySettings> {
51 DisplaySettings::new()
52 .with_new_line_after_attribute(false)
53 .optional()
54 }
55
56 fn custom_content(&self, content: Content) -> Option<Content> {
57 content
58 .add("kernel_size", &self.kernel_size)
59 .add("stride", &self.stride)
60 .add("padding", &self.padding)
61 .add("dilation", &self.dilation)
62 .optional()
63 }
64}
65
66impl MaxPool1dConfig {
67 pub fn init(&self) -> MaxPool1d {
69 if self.padding == PaddingConfig1d::Same {
70 check_same_padding_support(&[self.kernel_size]);
71 }
72 MaxPool1d {
73 stride: self.stride,
74 kernel_size: self.kernel_size,
75 padding: Ignored(self.padding.clone()),
76 dilation: self.dilation,
77 }
78 }
79}
80
81impl MaxPool1d {
82 pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
91 let [_batch_size, _channels, length] = input.dims();
92 let padding = self
93 .padding
94 .calculate_padding_1d(length, self.kernel_size, self.stride);
95
96 max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[should_panic = "Same padding with an even kernel size is not supported"]
    fn same_with_even_kernel_is_invalid() {
        // Requesting `Same` padding with an even kernel must be rejected by `init`.
        let _ = MaxPool1dConfig::new(2)
            .with_padding(PaddingConfig1d::Same)
            .init();
    }

    #[test]
    fn display() {
        // Defaults: stride 1, Valid padding, dilation 1, rendered on a single line.
        let layer = MaxPool1dConfig::new(3).init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "MaxPool1d {kernel_size: 3, stride: 1, padding: Valid, dilation: 1}"
        );
    }
}