// burn_core/nn/pool/adaptive_avg_pool2d.rs

1use crate as burn;
2
3use crate::config::Config;
4use crate::module::Module;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::tensor::Tensor;
7use crate::tensor::backend::Backend;
8
9use crate::tensor::module::adaptive_avg_pool2d;
10
11/// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer using the [init function](AdaptiveAvgPool2dConfig::init).
12#[derive(Config)]
13pub struct AdaptiveAvgPool2dConfig {
14    /// The size of the output.
15    pub output_size: [usize; 2],
16}
17
18/// Applies a 2D adaptive avg pooling over input tensors.
19///
20/// Should be created with [AdaptiveAvgPool2dConfig].
21#[derive(Module, Clone, Debug)]
22#[module(custom_display)]
23pub struct AdaptiveAvgPool2d {
24    /// The size of the output.
25    pub output_size: [usize; 2],
26}
27
28impl ModuleDisplay for AdaptiveAvgPool2d {
29    fn custom_settings(&self) -> Option<DisplaySettings> {
30        DisplaySettings::new()
31            .with_new_line_after_attribute(false)
32            .optional()
33    }
34
35    fn custom_content(&self, content: Content) -> Option<Content> {
36        let output_size = alloc::format!("{:?}", self.output_size);
37
38        content.add("output_size", &output_size).optional()
39    }
40}
41
42impl AdaptiveAvgPool2dConfig {
43    /// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module.
44    pub fn init(&self) -> AdaptiveAvgPool2d {
45        AdaptiveAvgPool2d {
46            output_size: self.output_size,
47        }
48    }
49}
50
51impl AdaptiveAvgPool2d {
52    /// Applies the forward pass on the input tensor.
53    ///
54    /// See [adaptive_avg_pool2d](crate::tensor::module::adaptive_avg_pool2d) for more information.
55    ///
56    /// # Shapes
57    ///
58    /// - input: `[batch_size, channels, height_in, width_in]`
59    /// - output: `[batch_size, channels, height_out, width_out]`
60    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
61        adaptive_avg_pool2d(input, self.output_size)
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display() {
        // `custom_settings` turns off the per-attribute newline, so the whole
        // layer renders on one line.
        let layer = AdaptiveAvgPool2dConfig::new([3, 3]).init();
        let rendered = alloc::format!("{}", layer);

        assert_eq!(rendered, "AdaptiveAvgPool2d {output_size: [3, 3]}");
    }
}