burn_core/nn/pool/adaptive_avg_pool1d.rs

1use crate as burn;
2
3use crate::config::Config;
4use crate::module::Module;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::tensor::Tensor;
7use crate::tensor::backend::Backend;
8
9use crate::tensor::module::adaptive_avg_pool1d;
10
/// Configuration to create a [1D adaptive average pooling](AdaptiveAvgPool1d) layer using the
/// [init function](AdaptiveAvgPool1dConfig::init).
///
/// The only setting is the target output size, e.g. `AdaptiveAvgPool1dConfig::new(3)`.
#[derive(Config)]
pub struct AdaptiveAvgPool1dConfig {
    /// The size of the output, i.e. the length of the last dimension produced by the layer's
    /// forward pass (input `[batch_size, channels, length]` -> output
    /// `[batch_size, channels, output_size]`).
    pub output_size: usize,
}
17
/// Applies a 1D adaptive average pooling over input tensors.
///
/// Should be created with [AdaptiveAvgPool1dConfig].
///
/// The layer stores no tensors, only the target output size that is forwarded to
/// [adaptive_avg_pool1d](crate::tensor::module::adaptive_avg_pool1d) on every call.
#[derive(Module, Clone, Debug)]
// Formatting is provided by the hand-written `ModuleDisplay` impl for this type.
#[module(custom_display)]
pub struct AdaptiveAvgPool1d {
    /// The size of the output (length of the pooled last dimension).
    pub output_size: usize,
}
27
28impl ModuleDisplay for AdaptiveAvgPool1d {
29    fn custom_settings(&self) -> Option<DisplaySettings> {
30        DisplaySettings::new()
31            .with_new_line_after_attribute(false)
32            .optional()
33    }
34
35    fn custom_content(&self, content: Content) -> Option<Content> {
36        content.add("output_size", &self.output_size).optional()
37    }
38}
39
40impl AdaptiveAvgPool1dConfig {
41    /// Initialize a new [adaptive avg pool 1d](AdaptiveAvgPool1d) module.
42    pub fn init(&self) -> AdaptiveAvgPool1d {
43        AdaptiveAvgPool1d {
44            output_size: self.output_size,
45        }
46    }
47}
48
49impl AdaptiveAvgPool1d {
50    /// Applies the forward pass on the input tensor.
51    ///
52    /// See [adaptive_avg_pool1d](crate::tensor::module::adaptive_avg_pool1d) for more information.
53    ///
54    /// # Shapes
55    ///
56    /// - input: `[batch_size, channels, length]`
57    /// - output: `[batch_size, channels, length_out]`
58    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
59        adaptive_avg_pool1d(input, self.output_size)
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the single-line `Display` rendering of a layer built from its config.
    #[test]
    fn display() {
        let layer = AdaptiveAvgPool1dConfig::new(3).init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(rendered, "AdaptiveAvgPool1d {output_size: 3}");
    }
}