//! Global pooling modules ([`AvgPoolGlobal`], [`MaxPoolGlobal`], [`MinPoolGlobal`]) that reduce the height and width dimensions of image tensors.
use crate::{shapes::*, tensor::*, tensor_ops::*};
use super::{Module, NonMutableModule, ZeroSizedModule};
/// Global average pooling: applies average pooling over an entire image, fully reducing
/// the height and width dimensions.
///
/// - Reduces 3d (C, H, W) to 1d (C, )
/// - Reduces 4d (B, C, H, W) to 2d (B, C)
///
/// This is a unit struct: it holds no parameters or configuration.
///
/// **Pytorch equivalent**: `torch.nn.AdaptiveAvgPool2d(1)` followed by a flatten.
///
/// # Examples
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let m: AvgPoolGlobal = Default::default();
/// let _: Tensor<Rank1<5>, f32, _> = m.forward(dev.zeros::<Rank3<5, 16, 8>>());
/// let _: Tensor<Rank2<10, 5>, f32, _> = m.forward(dev.zeros::<Rank4<10, 5, 16, 8>>());
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct AvgPoolGlobal;
/// Global max pooling: applies max pooling over an entire image, fully reducing the
/// height and width dimensions.
///
/// - Reduces 3d (C, H, W) to 1d (C, )
/// - Reduces 4d (B, C, H, W) to 2d (B, C)
///
/// This is a unit struct: it holds no parameters or configuration.
///
/// **Pytorch equivalent**: `torch.nn.AdaptiveMaxPool2d(1)` followed by a flatten.
///
/// # Examples
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let m: MaxPoolGlobal = Default::default();
/// let _: Tensor<Rank1<5>, f32, _> = m.forward(dev.zeros::<Rank3<5, 16, 8>>());
/// let _: Tensor<Rank2<10, 5>, f32, _> = m.forward(dev.zeros::<Rank4<10, 5, 16, 8>>());
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct MaxPoolGlobal;
/// Global min pooling: applies min pooling over an entire image, fully reducing the
/// height and width dimensions.
///
/// - Reduces 3d (C, H, W) to 1d (C, )
/// - Reduces 4d (B, C, H, W) to 2d (B, C)
///
/// This is a unit struct: it holds no parameters or configuration.
///
/// **Pytorch equivalent**: `torch.nn` has no adaptive min pooling layer; the same
/// reduction can be written as `torch.amin(x, dim=(-2, -1))`.
///
/// # Examples
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let m: MinPoolGlobal = Default::default();
/// let _: Tensor<Rank1<5>, f32, _> = m.forward(dev.zeros::<Rank3<5, 16, 8>>());
/// let _: Tensor<Rank2<10, 5>, f32, _> = m.forward(dev.zeros::<Rank4<10, 5, 16, 8>>());
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct MinPoolGlobal;
// Generates every trait impl a global pooling module needs.
//
// * `$Pool`   - the pooling type to implement the traits for.
// * `$reduce` - name of the zero-argument tensor method that performs the reduction;
//               it is called on the input and its `Result` is returned unchanged.
//
// Two `Module` impls are produced: one for unbatched `(C, H, W)` inputs yielding `(C,)`,
// and one for batched `(B, C, H, W)` inputs yielding `(B, C)`. In both cases the target
// shape of the reduction is determined by the declared `Output` type.
macro_rules! impl_pools {
    ($Pool:ty, $reduce:ident) => {
        // Both traits are implemented with empty bodies.
        impl ZeroSizedModule for $Pool {}
        impl NonMutableModule for $Pool {}

        // Unbatched image: (C, H, W) -> (C,)
        impl<C, H, W, E, D, T> Module<Tensor<(C, H, W), E, D, T>> for $Pool
        where
            C: Dim,
            H: Dim,
            W: Dim,
            E: Dtype,
            D: Device<E>,
            T: Tape<E, D>,
        {
            type Output = Tensor<(C,), E, D, T>;
            type Error = D::Err;

            fn try_forward(
                &self,
                input: Tensor<(C, H, W), E, D, T>,
            ) -> Result<Self::Output, Self::Error> {
                input.$reduce()
            }
        }

        // Batched images: (B, C, H, W) -> (B, C)
        impl<B, C, H, W, E, D, T> Module<Tensor<(B, C, H, W), E, D, T>> for $Pool
        where
            B: Dim,
            C: Dim,
            H: Dim,
            W: Dim,
            E: Dtype,
            D: Device<E>,
            T: Tape<E, D>,
        {
            type Output = Tensor<(B, C), E, D, T>;
            type Error = D::Err;

            fn try_forward(
                &self,
                input: Tensor<(B, C, H, W), E, D, T>,
            ) -> Result<Self::Output, Self::Error> {
                input.$reduce()
            }
        }
    };
}
// Wire each pooling type to the tensor reduction method it forwards to.
impl_pools! { AvgPoolGlobal, try_mean } // average over height and width
impl_pools! { MaxPoolGlobal, try_max } // maximum over height and width
impl_pools! { MinPoolGlobal, try_min } // minimum over height and width