1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68

use super::*;

/// State of the max-pooling operator: the pooling window configuration plus the
/// index buffer that `forward` fills and `backward` reads.
#[derive(Clone, Serialize, Deserialize)]
struct MaxPool {
    /// Window size; `[0]` is applied to dimension 2 (height) and `[1]` to dimension 3 (width).
    kernel: [usize; 2],
    /// Step between windows; `[0]` for the height dimension, `[1]` for the width dimension.
    stride: [usize; 2],
    /// Height padding; both entries are summed when the output height is computed.
    /// NOTE(review): presumably `[before, after]` -- confirm against `bmls::max_pool`.
    padh: [usize; 2],
    /// Width padding; both entries are summed when the output width is computed.
    /// NOTE(review): presumably `[before, after]` -- confirm against `bmls::max_pool`.
    padw: [usize; 2],
    /// One slot per output element. Written by `bmls::max_pool` during the forward pass
    /// and handed to `bmls::max_pool_wrt_a` during the backward pass.
    /// NOTE(review): presumably the position of each window's maximum -- confirm against `bmls`.
    indices: Vec<usize>,
}

impl Operator for MaxPool {
    /// Pools the node's input into its output via `bmls::max_pool`.
    ///
    /// The call also receives `self.indices` mutably; `backward` later passes the
    /// same buffer to `bmls::max_pool_wrt_a`.
    ///
    /// # Errors
    /// Propagates any error returned by `bmls::max_pool`.
    fn forward(&mut self, node: &Node) -> Result<()> {
        let (output, _) = node.y();
        let (input, _) = node.x(1);

        // Acquire the buffers in the same order the call consumes them:
        // input for reading first, then output for writing, then the input shape.
        let src = input.read();
        let mut dst = output.write();
        let in_shape = input.shape4();

        bmls::max_pool(
            &src,
            &mut dst,
            &mut self.indices,
            in_shape,
            self.kernel,
            self.stride,
            self.padh,
            self.padw,
        )?;

        Ok(())
    }

    /// Passes the output gradient and the stored `indices` to
    /// `bmls::max_pool_wrt_a`, which writes into the input gradient.
    ///
    /// # Errors
    /// Propagates any error returned by `bmls::max_pool_wrt_a`.
    fn backward(&mut self, node: &Node) -> Result<()> {
        let (_, grad_out) = node.y();
        let (_, grad_in) = node.x(1);

        let upstream = grad_out.read();
        let mut downstream = grad_in.write();

        bmls::max_pool_wrt_a(&self.indices, &upstream, &mut downstream)?;

        Ok(())
    }

    /// Resizes the index buffer to the element count of the new shape,
    /// zero-filling any newly added slots.
    fn reshape(&mut self, new: Shape) {
        let element_count = new.len();
        self.indices.resize(element_count, 0);
    }
}

impl Display for MaxPool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Max Pool Operator. Kernel: {:?}, Stride: {:?}, padh: {:?}, padw: {:?}", 
            self.kernel, self.stride, self.padh, self.padw)
    }
}

/// Output extent of one pooled dimension: `(input + pad[0] + pad[1] - kernel) / stride + 1`.
///
/// The padding is added *before* the kernel is subtracted. Subtracting first underflows
/// `usize` whenever the kernel is larger than the unpadded input, even though the padded
/// input is large enough for the window to fit.
///
/// # Panics
/// Panics if `stride` is zero. A kernel larger than the padded input is not a valid
/// configuration: the subtraction underflows (a panic when overflow checks are enabled).
fn pooled_extent(input: usize, kernel: usize, pad: [usize; 2], stride: usize) -> usize {
    (input + pad[0] + pad[1] - kernel) / stride + 1
}

/// Extends `x` with a `MaxPool` node configured from `p` and returns the resulting `Var`.
///
/// The new node's shape keeps the first two entries of `x.shape()` and replaces the last
/// two with the pooled height and width, computed from dimensions 2 and 3 of `x.shape4()`
/// together with the kernel, padding and stride in `p`.
///
/// # Panics
/// Panics if a stride in `p` is zero, or if `x.shape()` has fewer than two entries.
pub fn max_pool<'t>(x: Var<'t>, p: PoolParams) -> Var<'t> {
    let dims = x.shape4();
    let h = pooled_extent(dims[2], p.kernel[0], p.padh, p.stride[0]);
    let w = pooled_extent(dims[3], p.kernel[1], p.padw, p.stride[1]);
    let shape = [x.shape()[0], x.shape()[1], h, w].to_shape();

    x.extend(NodeBuilder {
        op: Box::new(MaxPool {
            kernel: p.kernel,
            stride: p.stride,
            padh: p.padh,
            padw: p.padw,
            // One index slot per output element.
            indices: vec![0; shape.len()],
        }),
        deps: vec![x.index],
        shape,
        skip: false,
        init: None,
        is_batched: x.is_batched,
    })
}