yarnn/layers/
maxpool2d.rs

1use crate::tensor::TensorShape;
2use crate::layer::Layer;
3use crate::backend::{Backend, PaddingKind, BackendMaxPool2d, Conv2dInfo};
4use core::marker::PhantomData;
5
/// Configuration for a [`MaxPool2d`] layer.
///
/// Derives `Copy`/`Clone` so a single config can be reused for several layers
/// (layer construction takes the config by value) and `Debug`/`PartialEq` so it
/// can be logged and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxPool2dConfig {
    /// Pooling window size as `(rows, cols)`.
    pub pool: (u32, u32),
    /// Step between consecutive windows as `(rows, cols)`.
    /// `None` means "use `pool`", i.e. non-overlapping windows.
    pub strides: Option<(u32, u32)>,
}
10
11impl Default for MaxPool2dConfig {
12    fn default() -> Self {
13        Self {
14            pool: (2, 2),
15            strides: None,
16        }
17    }
18}
19
20pub struct MaxPool2d<N, B> 
21    where B: Backend<N>
22{
23    input_shape: TensorShape,
24    conv_info: Conv2dInfo,
25    _m: PhantomData<fn(N, B)>
26}
27
28impl <N, B> Layer<N, B> for MaxPool2d<N, B> 
29    where B: Backend<N> + BackendMaxPool2d<N>,
30{
31    type Config = MaxPool2dConfig;
32    
33    fn name(&self) -> &str {
34        "MaxPool2d"
35    }
36    
37    fn create(input_shape: TensorShape, config: Self::Config) -> Self {
38        assert!(input_shape.dims == 3);
39
40        MaxPool2d {
41            input_shape,
42            conv_info: Conv2dInfo {
43                kernel: config.pool,
44                strides: config.strides.unwrap_or(config.pool),
45                padding: PaddingKind::Valid,
46            },
47            _m: Default::default(),
48        }
49    }
50
51    #[inline]
52    fn input_shape(&self) -> TensorShape {
53        self.input_shape.clone()
54    }
55
56    #[inline]
57    fn output_shape(&self) -> TensorShape {
58        let is = self.input_shape.as_slice();
59        
60        // O = (W - K + 2P) / S + 1
61
62        let rows = (is[1] - self.conv_info.kernel.0) / self.conv_info.strides.0 + 1;
63        let cols = (is[2] - self.conv_info.kernel.1) / self.conv_info.strides.1 + 1;
64
65        TensorShape::new3d(
66            is[0],
67            rows,
68            cols,
69        )
70    }
71    
72    #[inline]
73    fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
74        backend.max_pool2d(y, x, &self.conv_info)
75    }
76    
77    #[inline]
78    fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, _: &B::Tensor) {
79        backend.max_pool2d_backprop(dx, dy, x, &self.conv_info);
80    }
81}