// Adaptive 2D average pooling: configuration and module definitions.
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::adaptive_avg_pool2d;
/// Settings used to build an [AdaptiveAvgPool2d] layer
/// (2D adaptive average pooling).
#[derive(Config)]
pub struct AdaptiveAvgPool2dConfig {
    /// Size of the output produced by the layer, stored as two `usize` values.
    pub output_size: [usize; 2],
}
/// 2D adaptive average pooling layer operating on input tensors.
///
/// Instances are created through [AdaptiveAvgPool2dConfig::init].
#[derive(Module, Clone, Debug)]
pub struct AdaptiveAvgPool2d {
    /// Output size forwarded to the pooling operation on every forward pass.
    output_size: [usize; 2],
}
impl AdaptiveAvgPool2dConfig {
    /// Build an [AdaptiveAvgPool2d] module from this configuration.
    ///
    /// The configured `output_size` is copied into the new module; the
    /// configuration itself is left untouched and can be reused.
    pub fn init(&self) -> AdaptiveAvgPool2d {
        // `[usize; 2]` is `Copy`, so this reads the value without moving out of `self`.
        let output_size = self.output_size;
        AdaptiveAvgPool2d { output_size }
    }
}
impl AdaptiveAvgPool2d {
    /// Runs the forward pass by delegating to `adaptive_avg_pool2d` with the
    /// stored output size.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, channels, height_in, width_in]`
    /// - output: `[batch_size, channels, height_out, width_out]`
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let target_size = self.output_size;
        adaptive_avg_pool2d(input, target_size)
    }
}