yarnn/layers/
maxpool2d.rs1use crate::tensor::TensorShape;
2use crate::layer::Layer;
3use crate::backend::{Backend, PaddingKind, BackendMaxPool2d, Conv2dInfo};
4use core::marker::PhantomData;
5
/// Configuration for a [`MaxPool2d`] layer.
///
/// `MaxPool2dConfig::default()` selects a 2x2 window with strides equal to the window.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxPool2dConfig {
    /// Pooling window size as `(rows, cols)`; `.0` slides along input axis 1, `.1` along axis 2.
    pub pool: (u32, u32),
    /// Step between successive windows as `(rows, cols)`.
    /// `None` means the stride equals `pool`, i.e. non-overlapping windows.
    pub strides: Option<(u32, u32)>,
}
10
11impl Default for MaxPool2dConfig {
12 fn default() -> Self {
13 Self {
14 pool: (2, 2),
15 strides: None,
16 }
17 }
18}
19
20pub struct MaxPool2d<N, B>
21 where B: Backend<N>
22{
23 input_shape: TensorShape,
24 conv_info: Conv2dInfo,
25 _m: PhantomData<fn(N, B)>
26}
27
28impl <N, B> Layer<N, B> for MaxPool2d<N, B>
29 where B: Backend<N> + BackendMaxPool2d<N>,
30{
31 type Config = MaxPool2dConfig;
32
33 fn name(&self) -> &str {
34 "MaxPool2d"
35 }
36
37 fn create(input_shape: TensorShape, config: Self::Config) -> Self {
38 assert!(input_shape.dims == 3);
39
40 MaxPool2d {
41 input_shape,
42 conv_info: Conv2dInfo {
43 kernel: config.pool,
44 strides: config.strides.unwrap_or(config.pool),
45 padding: PaddingKind::Valid,
46 },
47 _m: Default::default(),
48 }
49 }
50
51 #[inline]
52 fn input_shape(&self) -> TensorShape {
53 self.input_shape.clone()
54 }
55
56 #[inline]
57 fn output_shape(&self) -> TensorShape {
58 let is = self.input_shape.as_slice();
59
60 let rows = (is[1] - self.conv_info.kernel.0) / self.conv_info.strides.0 + 1;
63 let cols = (is[2] - self.conv_info.kernel.1) / self.conv_info.strides.1 + 1;
64
65 TensorShape::new3d(
66 is[0],
67 rows,
68 cols,
69 )
70 }
71
72 #[inline]
73 fn forward(&self, backend: &B, y: &mut B::Tensor, x: &B::Tensor) {
74 backend.max_pool2d(y, x, &self.conv_info)
75 }
76
77 #[inline]
78 fn backward(&self, backend: &B, dx: &mut B::Tensor, dy: &B::Tensor, x: &B::Tensor, _: &B::Tensor) {
79 backend.max_pool2d_backprop(dx, dy, x, &self.conv_info);
80 }
81}