1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
use crate::model::ParsingContext; use crate::tfpb::tensorflow::NodeDef; use tract_hir::internal::*; use tract_hir::ops::cnn::*; pub fn avgpool(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> { let ksize: Vec<usize> = pb.get_attr_list_int("ksize")?; let data_format = super::data_format(pb)?; let kshape = data_format.shape(ksize)?; let strides = super::strides(pb)?; let padding = super::padding(pb)?; Ok(Box::new(AvgPool::new( PoolSpec::new( data_format, kshape.hw_dims().into(), padding, None, Some(strides[kshape.hw_axes()].into()), None, ), false, ))) } pub fn maxpool(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> { let ksize: Vec<usize> = pb.get_attr_list_int("ksize")?; let data_format = super::data_format(pb)?; let kshape = data_format.shape(ksize)?; let strides = super::strides(pb)?; let padding = super::padding(pb)?; Ok(Box::new(MaxPool::new( PoolSpec::new( data_format, kshape.hw_dims().into(), padding, None, Some(strides[kshape.hw_axes()].into()), None, ), None, ))) }