tract-core 0.23.0-dev.4

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
Documentation
mod block_quant;
#[allow(clippy::module_inception)]
mod conv;
mod depth_wise;
mod im2col;
mod lazy_im2col;
mod q_sum_b;

use crate::internal::*;
use crate::ops::cnn::Deconv;

pub use self::conv::Conv;
pub use self::im2col::Im2Col;
pub(crate) use self::q_sum_b::QSumB;

/// Axis ordering of a convolution kernel tensor.
///
/// Each letter names an axis, in order: `O` is the output-channel axis, `I` the
/// input-channel axis, and `HW` stands for the run of spatial axes (every axis
/// that is neither `O` nor `I`).
///
/// For grouped kernels the accessors in this module treat the formats
/// differently: in `OIHW` the `I` axis is multiplied by `group` to obtain the
/// input channel count, whereas in `HWIO` and `OHWI` it is the `O` axis that is
/// multiplied by `group` to obtain the output channel count.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
pub enum KernelFormat {
    /// Output channels, input channels, then the spatial axes. This is the default.
    #[default]
    OIHW,
    /// Spatial axes first, then input channels, then output channels.
    HWIO,
    /// Output channels, then the spatial axes, then input channels.
    OHWI,
}

impl KernelFormat {
    /// Index of the first spatial axis in a kernel shape using this format.
    pub fn h_axis(&self) -> usize {
        match *self {
            KernelFormat::HWIO => 0,
            KernelFormat::OHWI => 1,
            KernelFormat::OIHW => 2,
        }
    }

    /// Sub-slice of `full_shape` holding the spatial dimensions.
    ///
    /// The spatial rank is `full_shape.len() - 2`, i.e. everything but the `O`
    /// and `I` axes.
    ///
    /// # Panics
    ///
    /// Panics if `full_shape` is too short to contain the spatial run.
    pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
        let from_h = &full_shape[self.h_axis()..];
        let geo_rank = full_shape.len() - 2;
        &from_h[..geo_rank]
    }

    /// Same as [`KernelFormat::spatial_shape`].
    pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
        self.spatial_shape(full_shape)
    }

    /// Extent of the `I` axis, exactly as stored in `full_shape` (no group adjustment).
    pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
        let axis = self.i_axis(full_shape);
        &full_shape[axis]
    }

    /// Index of the `O` axis in `full_shape`.
    pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
        match *self {
            KernelFormat::HWIO => full_shape.len() - 1,
            KernelFormat::OIHW | KernelFormat::OHWI => 0,
        }
    }

    /// Index of the `I` axis in `full_shape`.
    pub fn i_axis<D>(&self, full_shape: &[D]) -> usize {
        let rank = full_shape.len();
        match *self {
            KernelFormat::OIHW => 1,
            KernelFormat::HWIO => rank - 2,
            KernelFormat::OHWI => rank - 1,
        }
    }

    /// Extent of the `O` axis, exactly as stored in `full_shape` (no group adjustment).
    pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
        let axis = self.o_axis(full_shape);
        &full_shape[axis]
    }

    /// Input channel count derived from the kernel shape and `group`.
    ///
    /// For `OIHW` the `I` extent is multiplied by `group` (an owned value is
    /// returned); for `HWIO` and `OHWI` the `I` extent is returned as is,
    /// borrowed from `full_kernel_shape`.
    pub fn input_channels<'s, D: DimLike>(
        &self,
        full_kernel_shape: &'s [D],
        group: usize,
    ) -> Cow<'s, D> {
        let kernel_i = self.i(full_kernel_shape);
        match self {
            KernelFormat::HWIO | KernelFormat::OHWI => Cow::Borrowed(kernel_i),
            KernelFormat::OIHW => Cow::Owned(kernel_i.clone() * group),
        }
    }

    /// Output channel count derived from the kernel shape and `group`.
    ///
    /// For `OIHW` the `O` extent is returned as is, borrowed from
    /// `full_kernel_shape`; for `HWIO` and `OHWI` it is multiplied by `group`
    /// (an owned value is returned).
    pub fn output_channels<'s, D: DimLike>(
        &self,
        full_kernel_shape: &'s [D],
        group: usize,
    ) -> Cow<'s, D> {
        let kernel_o = self.o(full_kernel_shape);
        match self {
            KernelFormat::HWIO | KernelFormat::OHWI => Cow::Owned(kernel_o.clone() * group),
            KernelFormat::OIHW => Cow::Borrowed(kernel_o),
        }
    }

    /// Axis operations rearranging a kernel of shape `full_shape` into the
    /// `g o i h w` axis order (group, output, input, then spatial axes).
    ///
    /// Every variant starts by reshaping one channel axis of extent `d` into
    /// the pair `(group, d / group)`; the non-`OIHW` variants then move axes
    /// into place. The trailing comments track the axis order after each step.
    pub fn kernel_as_group_o_i_h_w_ops(
        &self,
        full_shape: &[impl DimLike],
        group: usize,
    ) -> TVec<AxisOp> {
        let geo_rank = full_shape.len() - 2;
        // Reshape splitting the axis at `axis`, of extent `dim`, into (group, dim / group).
        let split_into_groups = |axis: usize, dim: TDim| -> AxisOp {
            AxisOp::Reshape(axis, tvec!(dim.clone()), tvec!(group.to_dim(), dim / group))
        };
        match self {
            // g is on o
            KernelFormat::OIHW => {
                tvec!(split_into_groups(0, self.o(full_shape).to_dim()))
            }
            // g is on i
            KernelFormat::HWIO => {
                tvec!(
                    split_into_groups(geo_rank, self.i(full_shape).to_dim()), // h w g i o
                    AxisOp::Move(geo_rank, 0),                                // g h w i o
                    AxisOp::Move(geo_rank + 2, 1),                            // g o h w i
                    AxisOp::Move(geo_rank + 2, 2)                             // g o i h w
                )
            }
            // g is on i
            KernelFormat::OHWI => {
                tvec!(
                    split_into_groups(geo_rank + 1, self.i(full_shape).to_dim()), // o h w g i
                    AxisOp::Move(geo_rank + 1, 0),                                // g o h w i
                    AxisOp::Move(geo_rank + 2, 2)                                 // g o i h w
                )
            }
        }
    }

    /// Like [`KernelFormat::kernel_as_group_o_i_h_w_ops`], with the spatial
    /// axes additionally merged into a single axis (`g o i hw`).
    ///
    /// The merging reshape is only appended when there is more than one
    /// spatial axis.
    pub fn kernel_as_group_o_i_hw_ops(
        &self,
        full_shape: &[impl DimLike],
        group: usize,
    ) -> TVec<AxisOp> {
        let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
        let spatial = self.hw(full_shape);
        if spatial.len() > 1 {
            let dims: TVec<TDim> = spatial.iter().map(|d| d.to_dim()).collect();
            let volume: TDim = dims.iter().cloned().product();
            // After the g, o and i axes, the spatial run starts at axis 3.
            ops.push(AxisOp::Reshape(3, dims, tvec!(volume)));
        }
        ops
    }

    /// Like [`KernelFormat::kernel_as_group_o_i_hw_ops`], with the input
    /// channel axis and the merged spatial axis further fused into one
    /// (`g o ihw`).
    pub fn kernel_as_group_o_ihw_ops(
        &self,
        full_shape: &[impl DimLike],
        group: usize,
    ) -> TVec<AxisOp> {
        // Input channels per group, and the product of all spatial extents.
        let i_per_group = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
        let hw = self.hw(full_shape).iter().map(|d| d.to_dim()).product::<TDim>();
        let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
        let ihw = i_per_group.clone() * hw.clone();
        ops.push(AxisOp::Reshape(2, tvec!(i_per_group, hw), tvec!(ihw)));
        ops
    }

    /// Applies the operations from [`KernelFormat::kernel_as_group_o_i_hw_ops`]
    /// to a copy of `kernel` and returns the transformed tensor.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned while applying an operation.
    pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
        let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
        let mut reshaped = kernel.clone();
        ops.iter().try_for_each(|op| op.change_tensor(&mut reshaped, false))?;
        Ok(reshaped)
    }

    /// Builds the `g o i hw` form of `kernel`, then collapses axis 2 with the
    /// following one to obtain `g o ihw`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`KernelFormat::kernel_as_group_o_i_hw`].
    pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
        self.kernel_as_group_o_i_hw(kernel, group)
            .map(|group_o_i_hw| group_o_i_hw.collapse_axis_with_next(2))
    }
}

pub fn rewrite_kernel_conv_in_oihw(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    conv: &Conv,
) -> TractResult<Option<TypedModelPatch>> {
    rewrite_kernel_in_oihw(
        model,
        node,
        name,
        conv.kernel_fmt,
        conv.group,
        Box::new(Conv { kernel_fmt: KernelFormat::OIHW, ..conv.clone() }),
    )
}

pub fn rewrite_kernel_deconv_in_oihw(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    conv: &Deconv,
) -> TractResult<Option<TypedModelPatch>> {
    rewrite_kernel_in_oihw(
        model,
        node,
        name,
        conv.kernel_format,
        conv.group,
        Box::new(Deconv { kernel_format: KernelFormat::OIHW, ..conv.clone() }),
    )
}

/// Builds a patch that replaces `node` with `new`, feeding it a kernel
/// (input #1) rearranged from `fmt` through the `g o i h w` order and then
/// collapsed on axis 0.
///
/// The rule is guarded by `rule_if!(fmt != KernelFormat::OIHW)`; presumably it
/// yields no patch when the kernel is already `OIHW` (confirm against the
/// macro definition).
///
/// # Errors
///
/// Propagates any error raised while tapping inputs, wiring nodes or shunting
/// the original node's output.
fn rewrite_kernel_in_oihw(
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    fmt: KernelFormat,
    group: usize,
    new: Box<dyn TypedOp>,
) -> TractResult<Option<TypedModelPatch>> {
    rule_if!(fmt != KernelFormat::OIHW);
    let mut patch = TypedModelPatch::default();
    let mut inputs = patch.taps(model, &node.inputs)?;
    let prefix = format!("{name}.kernel_reorg");
    let reorg_ops =
        fmt.kernel_as_group_o_i_h_w_ops(&patch.outlet_fact(inputs[1])?.shape, group);

    // Chain the axis operations on the kernel outlet, one node per operation.
    let mut kernel = inputs[1];
    for (ix, op) in reorg_ops.into_iter().enumerate() {
        kernel = patch.wire_node(format!("{prefix}.{ix}"), op, &[kernel])?[0];
    }
    // Collapse axis 0 (g) with the next one (o).
    inputs[1] =
        AxisOp::wire_collapse_axis(&mut patch, format!("{name}.kernel_reorg_go"), kernel, 0)?[0];

    let outputs = patch.wire_node(name, new, &inputs)?;
    patch.shunt_outside(model, node.id.into(), outputs[0])?;
    Ok(Some(patch))
}