1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
use crate::model::ParsingContext;
use crate::tfpb::tensorflow::NodeDef;
use tract_hir::internal::*;
use tract_hir::ops::cnn::*;

pub fn avgpool(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
    let ksize: Vec<usize> = pb.get_attr_list_int("ksize")?;
    let data_format = super::data_format(pb)?;
    let kshape = data_format.shape(ksize)?;
    let strides = super::strides(pb)?;
    let padding = super::padding(pb)?;
    Ok(Box::new(AvgPool::new(
        PoolSpec::new(
            data_format,
            kshape.hw_dims().into(),
            padding,
            None,
            Some(strides[kshape.hw_axes()].into()),
            None,
        ),
        false,
    )))
}

pub fn maxpool(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
    let ksize: Vec<usize> = pb.get_attr_list_int("ksize")?;
    let data_format = super::data_format(pb)?;
    let kshape = data_format.shape(ksize)?;
    let strides = super::strides(pb)?;
    let padding = super::padding(pb)?;
    Ok(Box::new(MaxPool::new(
        PoolSpec::new(
            data_format,
            kshape.hw_dims().into(),
            padding,
            None,
            Some(strides[kshape.hw_axes()].into()),
            None,
        ),
        None,
    )))
}