onednnl 0.0.1

High-level bindings to the oneDNN deep learning library
Documentation
use onednnl_sys::dnnl_graph_op_attr_t::{
    self, dnnl_graph_op_attr_auto_pad, dnnl_graph_op_attr_data_format,
    dnnl_graph_op_attr_exclude_pad, dnnl_graph_op_attr_kernel, dnnl_graph_op_attr_pads_begin,
    dnnl_graph_op_attr_pads_end, dnnl_graph_op_attr_rounding_type, dnnl_graph_op_attr_strides,
};

use crate::graph::{
    op::{OneDNNGraphOp, OneDNNGraphOpType},
    spec::{OpSpec, RequiredAttrs},
};

/// Marker type for the oneDNN Graph average-pooling (`AvgPool`) operation.
///
/// Carries no data; it exists to tie the op kind [`OneDNNGraphOp::AVG_POOL`]
/// to a Rust type through the [`OpSpec`] trait.
#[derive(Debug, Clone, Copy, Default)]
pub struct AvgPoolSpec;

impl OpSpec for AvgPoolSpec {
    /// The oneDNN Graph op kind this spec describes (`AVG_POOL`).
    const KIND: OneDNNGraphOpType = OneDNNGraphOp::AVG_POOL;
}

/// Attribute ids for the `AvgPool` op that are not carried by [`AvgPoolAttrs`].
///
/// NOTE(review): these are not emitted by the `RequiredAttrs` conversion, so they
/// appear to be the op's optional attributes — confirm against the oneDNN spec.
impl AvgPoolSpec {
    /// Attribute id `dnnl_graph_op_attr_data_format`.
    pub const DATA_FORMAT: dnnl_graph_op_attr_t::Type = dnnl_graph_op_attr_data_format;

    /// Attribute id `dnnl_graph_op_attr_auto_pad`.
    pub const AUTO_PAD: dnnl_graph_op_attr_t::Type = dnnl_graph_op_attr_auto_pad;

    /// Attribute id `dnnl_graph_op_attr_rounding_type`.
    pub const ROUNDING_TYPE: dnnl_graph_op_attr_t::Type = dnnl_graph_op_attr_rounding_type;
}

/// Attribute values that must be supplied when building an `AvgPool` op.
///
/// Each field maps one-to-one onto a `dnnl_graph_op_attr_*` attribute of the
/// same name when converted into `RequiredAttrs`.
///
/// Equality and hashing are derived so attribute sets can be compared,
/// deduplicated, or used as map keys (e.g. when caching built ops).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AvgPoolAttrs {
    /// Value for `dnnl_graph_op_attr_strides`
    /// (presumably one entry per spatial dimension — confirm).
    pub strides: Vec<i64>,
    /// Value for `dnnl_graph_op_attr_pads_begin`.
    pub pads_begin: Vec<i64>,
    /// Value for `dnnl_graph_op_attr_pads_end`.
    pub pads_end: Vec<i64>,
    /// Value for `dnnl_graph_op_attr_exclude_pad`. On conversion it is sent as a
    /// single `u8` (`0` = false, `1` = true).
    pub exclude_pad: bool,
    /// Value for `dnnl_graph_op_attr_kernel`
    /// (presumably the pooling window size per spatial dimension — confirm).
    pub kernel: Vec<i64>,
}

impl From<AvgPoolAttrs> for RequiredAttrs {
    fn from(attrs: AvgPoolAttrs) -> Self {
        RequiredAttrs::Some(vec![
            (dnnl_graph_op_attr_strides, attrs.strides.into()),
            (dnnl_graph_op_attr_pads_begin, attrs.pads_begin.into()),
            (dnnl_graph_op_attr_pads_end, attrs.pads_end.into()),
            (
                dnnl_graph_op_attr_exclude_pad,
                vec![attrs.exclude_pad as u8].into(),
            ),
            (dnnl_graph_op_attr_kernel, attrs.kernel.into()),
        ])
    }
}