burn_core/nn/pool/
adaptive_avg_pool2d.rs1use crate as burn;
2
3use crate::config::Config;
4use crate::module::Module;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::tensor::Tensor;
7use crate::tensor::backend::Backend;
8
9use crate::tensor::module::adaptive_avg_pool2d;
10
11#[derive(Config)]
13pub struct AdaptiveAvgPool2dConfig {
14 pub output_size: [usize; 2],
16}
17
18#[derive(Module, Clone, Debug)]
22#[module(custom_display)]
23pub struct AdaptiveAvgPool2d {
24 pub output_size: [usize; 2],
26}
27
28impl ModuleDisplay for AdaptiveAvgPool2d {
29 fn custom_settings(&self) -> Option<DisplaySettings> {
30 DisplaySettings::new()
31 .with_new_line_after_attribute(false)
32 .optional()
33 }
34
35 fn custom_content(&self, content: Content) -> Option<Content> {
36 let output_size = alloc::format!("{:?}", self.output_size);
37
38 content.add("output_size", &output_size).optional()
39 }
40}
41
42impl AdaptiveAvgPool2dConfig {
43 pub fn init(&self) -> AdaptiveAvgPool2d {
45 AdaptiveAvgPool2d {
46 output_size: self.output_size,
47 }
48 }
49}
50
51impl AdaptiveAvgPool2d {
52 pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
61 adaptive_avg_pool2d(input, self.output_size)
62 }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// The `Display` output shows the layer name and `output_size` on a single line.
    #[test]
    fn display() {
        let layer = AdaptiveAvgPool2dConfig::new([3, 3]).init();
        let rendered = alloc::format!("{}", layer);

        assert_eq!(rendered, "AdaptiveAvgPool2d {output_size: [3, 3]}");
    }
}
79}