burn_nn/modules/pool/
avg_pool1d.rs1use crate::conv::checks::check_same_padding_support;
2use burn_core as burn;
3
4use crate::PaddingConfig1d;
5use burn::config::Config;
6use burn::module::{Content, DisplaySettings, ModuleDisplay};
7use burn::module::{Ignored, Module};
8use burn::tensor::Tensor;
9use burn::tensor::backend::Backend;
10
11use burn::tensor::module::avg_pool1d;
12
13#[derive(Config, Debug)]
15pub struct AvgPool1dConfig {
16 pub kernel_size: usize,
18 #[config(default = "kernel_size")]
20 pub stride: usize,
21 #[config(default = "PaddingConfig1d::Valid")]
27 pub padding: PaddingConfig1d,
28 #[config(default = "true")]
30 pub count_include_pad: bool,
31 #[config(default = "false")]
33 pub ceil_mode: bool,
34}
35
36#[derive(Module, Clone, Debug)]
48#[module(custom_display)]
49pub struct AvgPool1d {
50 pub stride: usize,
52 pub kernel_size: usize,
54 pub padding: Ignored<PaddingConfig1d>,
56 pub count_include_pad: bool,
58 pub ceil_mode: bool,
60}
61
62impl ModuleDisplay for AvgPool1d {
63 fn custom_settings(&self) -> Option<DisplaySettings> {
64 DisplaySettings::new()
65 .with_new_line_after_attribute(false)
66 .optional()
67 }
68
69 fn custom_content(&self, content: Content) -> Option<Content> {
70 content
71 .add("kernel_size", &self.kernel_size)
72 .add("stride", &self.stride)
73 .add("padding", &self.padding)
74 .add("count_include_pad", &self.count_include_pad)
75 .add("ceil_mode", &self.ceil_mode)
76 .optional()
77 }
78}
79
80impl AvgPool1dConfig {
81 pub fn init(&self) -> AvgPool1d {
83 if self.padding == PaddingConfig1d::Same {
84 check_same_padding_support(&[self.kernel_size]);
85 }
86 AvgPool1d {
87 stride: self.stride,
88 kernel_size: self.kernel_size,
89 padding: Ignored(self.padding.clone()),
90 count_include_pad: self.count_include_pad,
91 ceil_mode: self.ceil_mode,
92 }
93 }
94}
95
96impl AvgPool1d {
97 pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
106 let [_batch_size, _channels, length] = input.dims();
107 let padding = self
108 .padding
109 .calculate_padding_1d(length, self.kernel_size, self.stride);
110
111 avg_pool1d(
112 input,
113 self.kernel_size,
114 self.stride,
115 padding,
116 self.count_include_pad,
117 self.ceil_mode,
118 )
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// `Same` padding is only supported for odd kernel sizes; `init` must reject the rest.
    #[test]
    #[should_panic(expected = "Same padding with an even kernel size is not supported")]
    fn same_with_even_kernel_is_invalid() {
        let _ = AvgPool1dConfig::new(2)
            .with_padding(PaddingConfig1d::Same)
            .init();
    }

    /// The default layer renders all of its hyper-parameters on a single line.
    #[test]
    fn display() {
        let expected = "AvgPool1d {kernel_size: 3, stride: 3, padding: Valid, count_include_pad: true, ceil_mode: false}";
        let pool = AvgPool1dConfig::new(3).init();

        assert_eq!(alloc::format!("{pool}"), expected);
    }

    /// `AvgPool1dConfig::new` defaults the stride to the kernel size.
    #[rstest]
    #[case(1)]
    #[case(2)]
    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {
        let cfg = AvgPool1dConfig::new(kernel_size);

        assert_eq!(
            cfg.stride, kernel_size,
            "Expected stride ({:?}) to match kernel size ({:?}) in default AvgPool1dConfig::new constructor",
            cfg.stride, cfg.kernel_size
        );
    }
}