// tract-core 0.19.2: tiny, no-nonsense, self-contained TensorFlow and ONNX inference.
use ndarray::*;
use tract_data::itertools::izip;

use crate::internal::*;
use crate::model::*;
use crate::ops;
use crate::ops::array::Pad;
use crate::ops::array::PadMode;
use crate::ops::cnn::PaddingSpec;
use crate::ops::matmul::mir_quant::wire_offset_u8_as_i8;
use crate::ops::matmul::mir_quant::QParamKind;
use crate::ops::matmul::MatMulAxes;

use super::depth_wise::DepthWise;
use super::im2col::Im2Col;
use crate::ops::cnn::conv::KernelFormat;
use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
use crate::ops::matmul::lir_unary::{
    ConcreteMatMulGeometry, LirMatMulUnary, MatMulGeometry, ProtoFusedSpec, SymbolicMatMulGeometry,
};
use crate::ops::matmul::MatMulQParams;
use crate::ops::nn::{BaseDataShape, DataFormat, DataShape};

use tract_linalg::frame::Packer;
use tract_linalg::mmm::MatMatMul;

use std::iter::Sum;

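/// Convolution with a constant kernel and optional constant bias. `group` > 1 covers grouped
/// (and depthwise) convolutions; `q_params` carries the output type and the zero-point/scale
/// parameters for quantized convolutions.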
#[derive(Debug, Clone, new, Hash)]
pub struct ConvUnary {
    pub pool_spec: PoolSpec,
    pub kernel_fmt: KernelFormat,
    pub kernel: Arc<Tensor>,

    pub group: usize,

    pub bias: Option<Arc<Tensor>>,

    pub q_params: Option<(DatumType, MatMulQParams)>,
}

impl_dyn_hash!(ConvUnary);

impl ConvUnary {
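    // Channel counts are recovered from the kernel shape; along the axis split by grouping,
    // the kernel only stores the per-group extent, hence the `* self.group`.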
    fn input_channels(&self) -> usize {
        match self.kernel_fmt {
            KernelFormat::OIHW => self.kernel.shape()[1] * self.group,
            KernelFormat::HWIO => self.kernel.shape()[self.kernel.shape().len() - 2],
        }
    }

    fn output_channels(&self) -> usize {
        let kshape = self.kernel.shape();
        match self.kernel_fmt {
            KernelFormat::OIHW => kshape[0],
            KernelFormat::HWIO => kshape[kshape.len() - 1] * self.group,
        }
    }

    pub fn kernel_as_group_o_ihw(&self) -> TractResult<Arc<Tensor>> {
        self.kernel_fmt.kernel_as_group_o_ihw(
            &self.kernel,
            self.group,
            self.input_channels(),
            self.output_channels(),
        )
    }

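    // Pre-pack the kernel, one packed tensor per group, in the layout expected by the selected
    // matrix-matrix multiplier, then adjust the array rank (group axis, optional N axis,
    // trailing unit axes) to match the micro-ops layout used by LirMatMulUnary.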
    fn kernel_as_packed_as(
        &self,
        packer: &Packer,
        k: usize,
        m: usize,
    ) -> TractResult<ArrayD<Arc<Tensor>>> {
        let kernel = self.kernel_as_group_o_ihw()?;
        unsafe {
            let mut packed_as = Array1::from(
                (0..self.group)
                    .map(|g| {
                        let packed = Tensor::uninitialized_aligned_dt(
                            kernel.datum_type(),
                            &[packer.len(k, m)],
                            packer.alignment(),
                        )?;
                        packer.pack(
                            &mut TensorView::at_prefix(&packed, &[])?,
                            &kernel.view_at_prefix(&[g])?,
                            1,
                            0,
                        );
                        Ok(packed.into_arc_tensor())
                    })
                    .collect::<TractResult<Vec<_>>>()?,
            )
            .into_dyn();
            if self.group == 1 {
                packed_as.index_axis_inplace(Axis(0), 0);
            }
            if self.pool_spec.data_format.has_n() {
                packed_as.insert_axis_inplace(Axis(0));
            }
            packed_as.insert_axis_inplace(Axis(packed_as.ndim()));
            packed_as.insert_axis_inplace(Axis(packed_as.ndim()));
            Ok(packed_as)
        }
    }

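    // If the kernel is stored as u8, rewrite the op with the kernel offset to i8 (values
    // shifted by -128) and patch the a0 zero-point, whether it is an attribute or a wired
    // input, so the quantized result is unchanged. Returns None when nothing needs doing.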
    fn kernel_offset_u8_as_i8(
        &self,
        inputs: &mut [OutletId],
        model: &mut TypedModel,
    ) -> TractResult<Option<Self>> {
        if let DatumType::U8 = self.kernel.datum_type().unquantized() {
            let new_op = Self {
                kernel: self.kernel.offset_u8_as_i8(),
                q_params: self
                    .q_params
                    .as_ref()
                    .map(|(dt, qp)| -> TractResult<_> {
                        let a0 = match &qp.a0 {
                            QParamKind::Attr(_) | QParamKind::FromQType => {
                                qp.a0.offset_u8_as_i8(model, &[])?
                            }
                            QParamKind::FromInput(i) => {
                                match model.outlet_fact(inputs[*i])?.datum_type.unquantized() {
                                    DatumType::U8 => {
                                        inputs[*i] = model.wire_node(
                                            format!(
                                                "{}.offset_{}_as_i8",
                                                model.node(inputs[*i].node).name,
                                                "a0"
                                            ),
                                            ops::quant::offset_u8_as_i8(),
                                            &[inputs[*i]],
                                        )?[0];
                                    }
                                    DatumType::I32 => {
                                        let cst = model.add_const(
                                            format!(
                                                "{}.offset_{}_as_i8.cst",
                                                &model.node(inputs[*i].node).name,
                                                "a0"
                                            ),
                                            rctensor0(-128i32),
                                        )?;
                                        inputs[*i] = model.wire_node(
                                            format!(
                                                "{}.offset_{}_as_i8",
                                                model.node(inputs[*i].node).name,
                                                "a0"
                                            ),
                                            ops::math::add(),
                                            &[inputs[*i], cst],
                                        )?[0];
                                    }
                                    _ => (),
                                }
                                QParamKind::FromInput(*i)
                            }
                        };
                        Ok((*dt, MatMulQParams { a0, ..qp.clone() }))
                    })
                    .transpose()?,
                ..self.clone()
            };
            Ok(Some(new_op))
        } else {
            Ok(None)
        }
    }

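    // Build the per-group list of post-ops fused into the matmul: when a bias is present it
    // is added row-wise, one value per output channel of the group.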
    fn bias_as_non_linear<T>(&self) -> TractResult<ArrayD<Vec<ProtoFusedSpec>>>
    where
        T: Datum + Copy,
    {
        let mut ops = Array1::from_elem(self.group, vec![]);

        if let Some(bias) = &self.bias {
            let bias = bias.cast_to::<T>()?;
            let bias = bias.as_slice::<T>()?;
            ops.iter_mut().zip(bias.chunks(self.output_channels() / self.group)).for_each(
                |(ops, bias)| {
                    ops.push(ProtoFusedSpec::BinPerRow(
                        rctensor1(bias).into(),
                        tract_linalg::mmm::BinOp::Add,
                    ));
                },
            )
        }
        let mut ops = ops.into_dyn();

        if self.group == 1 {
            ops.index_axis_inplace(Axis(0), 0);
        }
        if self.pool_spec.data_format.has_n() {
            ops.insert_axis_inplace(Axis(0));
        }
        Ok(ops)
    }

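    // Quantized pipeline: im2col, then an i32 LirMatMulUnary, then compensation of the a0/b0
    // zero-points using the kernel row sums and input column sums, and finally requantization
    // to the output type.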
    pub unsafe fn wire_as_quant_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        b_dt: DatumType,
        wires: &[OutletId],
    ) -> TractResult<OutletId> {
        use crate::ops::matmul::mir_quant as qmm;

        let c_dt = self.q_params.as_ref().unwrap().0;

        let params = self.q_params.as_ref().unwrap().1.as_outlet_ids(
            model,
            name,
            wires,
            self.kernel.datum_type(),
            b_dt,
            c_dt,
        )?;

        let a0 = params[0];
        let a_scale = params[1];
        let mut b0 = params[2];
        let b_scale = params[3];
        let c0 = params[4];
        let c_scale = params[5];

        let b = wire_offset_u8_as_i8(model, name, wires[0], "b", &mut b0, "b0")?;
        let b_fact = model.outlet_fact(b)?.clone();
        let (_, m, k, n, mmm) = self.compute_geo(&b_fact)?;
        let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

        let abc_scale = qmm::combine_scales(model, name, a_scale, b_scale, c_scale)?;

        let im2col = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(self.pool_spec.clone(), self.group, k, &b_fact.shape, mmm.clone())?,
            &[b, b0],
        )?[0];

        let a = self.kernel_as_group_o_ihw()?.into_tensor();
        let a = a.cast_to_dt(i32::datum_type())?;
        let a = a.to_array_view::<i32>()?;
        let mut sum_a = a.sum_axis(Axis(a.ndim() - 1));
        if self.group == 1 {
            sum_a.index_axis_inplace(Axis(0), 0);
        }

        if self.pool_spec.data_format.has_n() {
            sum_a.insert_axis_inplace(Axis(0));
        }
        let sum_a = model.add_const(format!("{name}.sum_a"), sum_a)?;

        let mut sum_b = model.wire_node(
            format!("{name}.sum_b"),
            super::QSumB { n: n.clone(), r: mmm.b_pack().panel_width(), k },
            &[im2col],
        )?[0];

        if self.group > 1 && self.pool_spec.data_format.c_is_last() {
            let has_n = self.pool_spec.data_format.has_n() as usize;
            sum_b = model.wire_node(
                format!("{name}.transpose_sum_b"),
                AxisOp::Move(has_n, 1 + has_n),
                &[sum_b],
            )?[0];
        }

        let b_dt = model.outlet_fact(b)?.datum_type;
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
        let mut geometry = MatMulGeometry::from(SymbolicMatMulGeometry {
            b_datum_type: b_dt,
            m: m.to_dim(),
            k: k.to_dim(),
            n: n.clone(),
            mmm: mmm.clone(),
        });
        if n.to_usize().is_ok() {
            geometry = geometry.optimize_if(Some(&SymbolValues::default()))?;
        }
        let wire = self.wire_lir_matmatmul(
            model,
            name,
            im2col,
            mmm,
            i32::datum_type(),
            mmm_output_shape.clone().into(),
            m,
            k,
            geometry,
            c_axis,
            h_axis,
        )?;
        let has_n = self.pool_spec.data_format.has_n() as usize;
        let has_group = (self.group > 1) as usize;
        let (m_axis, n_axis) = if self.pool_spec.data_format.c_is_last() {
            (1 + has_group + has_n, has_n)
        } else {
            (has_group + has_n, 1 + has_n + has_group)
        };
        let wire = qmm::compensate_zero_points(
            model,
            name,
            wire,
            k.to_dim(),
            a0,
            b0,
            sum_a,
            sum_b,
            m_axis,
            n_axis,
        )?;

        let mut wire = qmm::requant(model, name, wire, c_dt, abc_scale, c0)?;
        if self.group > 1 {
            wire = model.wire_node(
                format!("{name}.reshape_group"),
                AxisOp::Reshape(
                    c_axis - 1,
                    mmm_output_shape[c_axis - 1..][..2].iter().map(|d| d.to_dim()).collect(),
                    tvec!((m * self.group).to_dim()),
                ),
                &[wire],
            )?[0];
        }
        let wire = Self::wire_geo_reshape(model, name, wire, &output_shape)?;
        Ok(wire)
    }

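    // Non-quantized pipeline: im2col (padded with a zero scalar), LirMatMulUnary, then a
    // reshape back to the geometric output shape.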
    pub unsafe fn wire_as_im2col_pair(
        &self,
        model: &mut TypedModel,
        name: &str,
        mut wire: OutletId,
    ) -> TractResult<OutletId> {
        let b_fact = model.outlet_fact(wire)?.clone();
        let b_dt = b_fact.datum_type;
        let c_dt = crate::ops::matmul::output_type(b_fact.datum_type);

        let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;
        let (_, m, k, n, mmm) = self.compute_geo(model.outlet_fact(wire)?)?;
        let padding = model.add_const(format!("{name}.b0"), Tensor::zero_dt(b_dt, &[])?)?;

        wire = model.wire_node(
            format!("{name}.im2col"),
            Im2Col::new(self.pool_spec.clone(), self.group, k, &b_fact.shape, mmm.clone())?,
            &[wire, padding],
        )?[0];

        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
        let mut geometry = MatMulGeometry::from(SymbolicMatMulGeometry {
            b_datum_type: b_dt,
            m: m.to_dim(),
            k: k.to_dim(),
            n: n.clone(),
            mmm: mmm.clone(),
        });
        if n.to_usize().is_ok() {
            geometry = geometry.optimize_if(Some(&SymbolValues::default()))?;
        }
        let mut wire = self.wire_lir_matmatmul(
            model,
            name,
            wire,
            mmm,
            c_dt,
            mmm_output_shape.clone().into(),
            m.to_usize().unwrap(),
            k.to_usize().unwrap(),
            geometry,
            c_axis,
            h_axis,
        )?;

        if self.group > 1 {
            wire = model.wire_node(
                format!("{name}.reshape_group"),
                AxisOp::Reshape(
                    c_axis - 1,
                    mmm_output_shape[c_axis - 1..][..2].iter().map(|d| d.to_dim()).collect(),
                    tvec!((m * self.group).to_dim()),
                ),
                &[wire],
            )?[0];
        }
        let wire = Self::wire_geo_reshape(model, name, wire, &output_shape)?;
        Ok(wire)
    }

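    // Shape of the matmul output before geometry is restored: all HW dims collapsed into one
    // axis, plus a group axis inserted before the channel axis when group > 1. Also returns
    // the resulting c and h axis positions.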
    fn mmm_output_shape<D: DimLike>(
        &self,
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<(TVec<D>, usize, usize)> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        let shape: BaseDataShape<D, TVec<D>> = output_shape.fmt.from_n_c_hw(
            output_shape.n().cloned().unwrap_or_else(|| 1.into()),
            output_shape.c().clone(),
            tvec!(geo_collapsed_out),
        )?;
        let mut mmm_output_shape: TVec<D> = shape.shape.clone();
        let mut c_axis = shape.c_axis();
        let mut h_axis = shape.h_axis();
        if self.group > 1 {
            mmm_output_shape[shape.c_axis()] =
                mmm_output_shape[shape.c_axis()].clone() / self.group;
            mmm_output_shape.insert(shape.c_axis(), self.group.into());
            if h_axis > c_axis {
                h_axis += 1;
            }
            c_axis += 1;
        }
        Ok((mmm_output_shape, c_axis, h_axis))
    }

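    // Expand the collapsed geometry axis back into the full HW output dimensions.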
    fn wire_geo_reshape<D: DimLike>(
        model: &mut TypedModel,
        name: &str,
        wire: OutletId,
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<OutletId> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        let wire = model.wire_node(
            name,
            AxisOp::Reshape(
                output_shape.h_axis(),
                tvec!(geo_collapsed_out.to_dim()),
                output_shape.hw_dims().iter().map(|d| d.to_dim()).collect(),
            ),
            &[wire],
        )?;
        Ok(wire[0])
    }

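    // Lazy variant: no materialized im2col buffer. Padding, if any, is applied up front with
    // an explicit Pad, then the matmul reads its B input through a virtual input described by
    // per-n and per-k byte offsets into the (padded) input tensor.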
    pub unsafe fn wire_as_lazy_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        mut wire: OutletId,
    ) -> TractResult<OutletId> {
        let mut b_fact = model.outlet_fact(wire)?.clone();
        let (geo, m, k, n, mmm) = self.compute_geo(&b_fact)?;
        let input_shape = b_fact.shape.as_concrete().unwrap().to_vec();
        let mut geo = geo.to_concrete(&input_shape)?.into_owned();
        let mut input_shape: DataShape = self.pool_spec.data_format.shape(input_shape.into())?;
        let padding = self.pool_spec.computed_padding(input_shape.hw_dims());
        if padding.iter().any(|axis| axis.pad_before != 0 || axis.pad_after != 0) {
            let mut pads = vec![(0, 0); b_fact.rank()];
            for (ix, ax) in padding.iter().enumerate() {
                pads[input_shape.h_axis() + ix] = (ax.pad_before, ax.pad_after);
            }
            let op = crate::ops::array::Pad {
                mode: crate::ops::array::PadMode::Constant(
                    Tensor::zero_scalar_dt(b_fact.datum_type)?.into_arc_tensor(),
                ),
                pads,
            };
            wire = model.wire_node(format!("{name}.pad"), op, &[wire])?[0];
            let valid_pool_spec =
                PoolSpec { padding: ops::cnn::PaddingSpec::Valid, ..self.pool_spec.clone() };
            b_fact = model.outlet_fact(wire)?.clone();
            let concrete_shape = b_fact.shape.as_concrete().unwrap();
            input_shape = valid_pool_spec.data_format.shape(concrete_shape.into())?;
            geo = valid_pool_spec
                .compute_geo(&b_fact.shape)?
                .to_concrete(concrete_shape)?
                .into_owned();
        }
        let c_dt = crate::ops::matmul::output_type(b_fact.datum_type);
        let c_stride = input_shape.c_stride();
        let size_of_b = b_fact.datum_type.size_of() as isize;
        let n_bytes_offsets: Vec<isize> =
            geo.patch.centers_offsets().into_iter().map(|x| x * size_of_b).collect();
        let k_bytes_offsets: Vec<isize> = (0..self.input_channels())
            .flat_map(|ici| {
                geo.patch
                    .standard_layout_data_field
                    .iter()
                    .map(move |x| (x + (ici * c_stride) as isize) * size_of_b)
            })
            .collect();
        let virtual_input = super::lazy_im2col::LazyIm2colSpec { n_bytes_offsets, k_bytes_offsets };
        let b_storage = mmm.b_virtual_input(Box::new(virtual_input), k);
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;

        let geometry = MatMulGeometry::Concrete(ConcreteMatMulGeometry {
            m,
            k,
            n: n.to_usize().unwrap(),
            b_storage,
        });
        let wire = self.wire_lir_matmatmul(
            model,
            name,
            wire,
            mmm,
            c_dt,
            mmm_output_shape.into(),
            m.to_usize().unwrap(),
            k,
            geometry,
            c_axis,
            h_axis,
        )?;

        let wire = Self::wire_geo_reshape(model, name, wire, &geo.output_shape)?;
        Ok(wire)
    }

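    // Compute the pooling geometry and the m/k/n dimensions of the equivalent matmul
    // (m: output channels per group, k: input channels per group times kernel footprint,
    // n: number of output spatial positions), and select a matrix-matrix multiplier.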
    #[allow(clippy::type_complexity)]
    fn compute_geo(
        &self,
        input_fact: &TypedFact,
    ) -> TractResult<(PoolGeometry, usize, usize, TDim, Box<dyn MatMatMul>)> {
        let a_dt = self.kernel.datum_type();
        let b_dt = input_fact.datum_type;
        let c_dt = crate::ops::matmul::output_type(b_dt);

        let geo = self.pool_spec.compute_geo(&input_fact.shape)?;

        trace!("output channels: {:?}", self.output_channels());
        let m = self.output_channels() / self.group;
        let k = self.kernel.len() / self.output_channels();
        let n: TDim =
            self.pool_spec.output_shape(&input_fact.shape)?.hw_dims().iter().cloned().product();

        let mmm = tract_linalg::ops()
            .mmm(a_dt, b_dt, c_dt, Some(m), Some(k), n.to_usize().ok())
            .with_context(|| format!("No multiplier for {a_dt:?}x{b_dt:?} to {c_dt:?}",))?;

        Ok((geo, m, k, n, mmm))
    }

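    // Wire the LirMatMulUnary node: each pre-packed kernel is paired with its fused post-ops
    // (bias addition if any, then the final store).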
    #[allow(clippy::too_many_arguments)]
    fn wire_lir_matmatmul(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: OutletId,
        mmm: Box<dyn MatMatMul>,
        c_datum_type: DatumType,
        mmm_output_shape: ShapeFact,
        m: usize,
        k: usize,
        geometry: MatMulGeometry,
        c_m_axis: usize,
        c_n_axis: usize,
    ) -> TractResult<OutletId> {
        let kernels = self.kernel_as_packed_as(&mmm.a_pack(), k, m)?;
        let shape = kernels.shape();
        let mut fused_ops = dispatch_copy!(Self::bias_as_non_linear(mmm.internal_type())(self))?;
        for fo in &mut fused_ops {
            fo.push(ProtoFusedSpec::Store);
        }
        let mut iter = kernels.iter().cloned().zip(fused_ops.iter().cloned());
        let micro_ops = ArrayD::from_shape_fn(shape, |_| iter.next().unwrap());

        let wire = model.wire_node(
            format!("{name}.matmatmul"),
            LirMatMulUnary {
                c_fact: c_datum_type.fact(mmm_output_shape.clone()),
                micro_ops,
                c_m_axis,
                c_n_axis,
                c_final_shape: mmm_output_shape,
                reshape_post: vec![],
                geometry,
                mmm,
            },
            &[wire],
        )?[0];
        Ok(wire)
    }

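    // Depthwise lowering, used by codegen when group == input channels == output channels and
    // the input shape is concrete.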
    pub fn to_depth_wise<T>(&self, input: &TypedFact) -> TractResult<Box<dyn TypedOp>>
    where
        T: Datum + Clone + ::ndarray::LinalgScalar + PartialEq + Sum,
    {
        let input_shape = input.shape.as_concrete().unwrap();
        let ConcretePoolGeometry { input_shape, patch, output_shape } =
            self.pool_spec.compute_geo(&input.shape)?.to_concrete(input_shape)?.into_owned();
        let bias = if let Some(b) = &self.bias {
            b.clone()
        } else {
            Tensor::zero::<T>(&[*input_shape.c()])?.into_arc_tensor()
        };
        let op = DepthWise::new(
            patch,
            input_shape,
            output_shape,
            self.kernel_as_group_o_ihw().context("in kernel_as_group_o_ihw")?,
            bias,
        );
        Ok(Box::new(op))
    }

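    // When a spatial axis is strided but the kernel does not actually need the skipped
    // positions (kernel size 1 on that axis, or valid padding with a dilation that is a
    // multiple of the stride), pull the stride out into a preceding Downsample node.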
    fn declutter_stride_slice_to_downsample(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input_fact = model.outlet_fact(node.inputs[0])?;
        let spatial_rank = self.kernel.rank() - 2;
        if let Some(axis) = (0..spatial_rank).find(|&ax| {
            self.pool_spec.stride(ax) > 1
                && (self.pool_spec.kernel_shape[ax] == 1
                    || (self.pool_spec.padding.valid_dim(ax, self.pool_spec.stride(ax) == 1)
                        && self.pool_spec.dilation(ax) % self.pool_spec.stride(ax) == 0))
        }) {
            let downsample_factor = self.pool_spec.stride(axis);
            let mut new_op = self.clone();
            if new_op.pool_spec.dilation(axis) > 1 {
                new_op.pool_spec.dilations.as_mut().unwrap()[axis] /= downsample_factor;
            }
            new_op.pool_spec.strides.as_mut().unwrap()[axis] /= downsample_factor;
            let mut patch = TypedModelPatch::default();
            let tap = patch.tap_model(model, node.inputs[0])?;
            let shape = self
                .pool_spec
                .data_format
                .shape(input_fact.shape.iter().collect::<TVec<TDim>>())?;
            let down = patch.wire_node(
                format!("{}.downsample.{}", node.name, axis),
                crate::ops::Downsample::new(axis + shape.h_axis(), downsample_factor as isize, 0),
                &[tap],
            )?;
            let id = patch.wire_node(&*node.name, new_op, &down)?[0];
            patch.shunt_outside(model, OutletId::new(node.id, 0), id)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

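    // A 1D, ungrouped, stride-1 convolution with a pointwise kernel is just a matrix product:
    // rewrite it as (Q)MatMul, with the kernel (and bias, in the quantized case) wired as
    // constant inputs.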
    fn declutter_as_matmul(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops::matmul::*;
        let input_fact = model.outlet_fact(node.inputs[0])?;
        let full_input_shape = input_fact.shape.to_tvec();
        let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
        if input_shape.hw_rank() == 1
            && self.group == 1
            && self.pool_spec.stride(0) == 1
            && self.kernel.len() == self.input_channels() * self.output_channels()
        {
            let ci = self.input_channels();
            let co = self.output_channels();
            let ker = self.kernel.clone().into_tensor();
            let (a_shape, a_trans) = if self.kernel_fmt == KernelFormat::HWIO {
                ([ci, co], true)
            } else {
                ([co, ci], false)
            };
            let a = ker
                .into_shape(&a_shape)?
                .broadcast_into_rank(full_input_shape.len())?
                .into_arc_tensor();
            let trans_data = self.pool_spec.data_format == DataFormat::HWC
                || self.pool_spec.data_format == DataFormat::NHWC;
            let mut patch = TypedModelPatch::new("declutter_as_matmul");
            let a = patch.add_const(format!("{}.filters", &node.name), a)?;
            let mut inputs = node
                .inputs
                .iter()
                .map(|i| patch.tap_model(model, *i))
                .collect::<TractResult<TVec<_>>>()?;
            inputs.insert(0, a);
            let axes = MatMulAxes::default_for_rank(full_input_shape.len())
                .transposing(a_trans, trans_data, trans_data);
            // in Q case, the bias has to be injected inside the QMatMul (as it
            // must be added before requantization)
            let wire = if let Some(q_params) = &self.q_params {
                let mut params = q_params.1.clone();
                params.insert_input(0); // kernel as input
                params.insert_input(2); // bias as input
                let bias = self.bias.clone().unwrap_or_else(|| rctensor0(0i32));
                anyhow::ensure!(bias.rank() == 0 || bias.rank() == 1);
                let bias = patch.add_const(format!("{}.bias", &node.name), bias)?;
                inputs.insert(2, bias);
                let op = QMatMul { axes, output_type: q_params.0, params: q_params.1.clone() };
                patch.wire_node(&*node.name, op, &inputs)?[0]
            } else {
                let op = MatMul { axes };
                let mut wire = patch.wire_node(format!("{}.matmul", node.name), op, &inputs)?[0];
                if let Some(b) = self.bias.as_ref().filter(|_| self.q_params.is_none()) {
                    anyhow::ensure!(b.rank() == 0 || b.rank() == 1);
                    let mut bias_shape = tvec!(1; input_shape.rank());
                    bias_shape[input_shape.c_axis()] = co;
                    let b = b.clone().into_tensor().into_shape(&bias_shape)?;
                    let b =
                        patch.add_const(format!("{}.bias.cst", node.name), b.into_arc_tensor())?;
                    wire = patch.wire_node(
                        format!("{}.bias", node.name),
                        crate::ops::math::add(),
                        &[wire, b],
                    )?[0];
                }
                wire
            };
            patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

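    // Absorb a preceding constant zero Pad over the spatial axes into this convolution's
    // explicit padding.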
    fn declutter_precursor_padding(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.pool_spec.padding != PaddingSpec::Valid
            && !matches!(self.pool_spec.padding, PaddingSpec::Explicit(_, _, _))
        {
            return Ok(None);
        }
        let prec = model.node(node.inputs[0].node);
        let pad = if let Some(pad) = prec.op_as::<Pad>() { pad } else { return Ok(None) };
        let value = if let PadMode::Constant(c) = &pad.mode {
            c
        } else {
            return Ok(None);
        };
        let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
        if value.cast_to_scalar::<i64>()? != 0
            || (self.pool_spec.data_format.has_n() && pad.pads[0] != (0, 0))
            || pad.pads[shape.c_axis()] != (0, 0)
        {
            return Ok(None);
        }
        let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
        let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
        if let PaddingSpec::Explicit(bef, aft, false) = &self.pool_spec.padding {
            izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
            izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
        }
        let padding = PaddingSpec::Explicit(before, after, false);
        let mut new = self.clone();
        new.pool_spec.padding = padding;
        let mut patch = TypedModelPatch::default();
        let wire = patch.tap_model(model, prec.inputs[0])?;
        let wire = patch.wire_node(&node.name, new, &[wire])?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }
}

impl Op for ConvUnary {
    fn name(&self) -> Cow<str> {
        "ConvUnary".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut info = self.pool_spec.info();
        info.push(format!(
            "Kernel {:?} (groups:{}), {:?}",
            self.kernel_fmt, self.group, self.kernel
        ));
        if let Some(b) = &self.bias {
            info.push(format!("Bias: {b:?}"))
        }
        Ok(info)
    }

    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}

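// Evaluation builds a small ad-hoc TypedModel around the wiring helpers above and runs it.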
impl EvalOp for ConvUnary {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut model = TypedModel::default();

        let mut wires: TVec<OutletId> = inputs
            .iter()
            .enumerate()
            .map(|(ix, v)| {
                model.add_source(format!("source.{ix}"), v.datum_type().fact(v.shape()))
            })
            .collect::<TractResult<_>>()?;
        let new_op = self.kernel_offset_u8_as_i8(&mut wires, &mut model)?;
        let wire = unsafe {
            if self.q_params.is_some() {
                let op_ref = if let Some(op) = new_op.as_ref() { op } else { self };
                op_ref.wire_as_quant_im2col(
                    &mut model,
                    "im2col-adhoc",
                    inputs[0].datum_type(),
                    &wires,
                )?
            } else {
                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", wires[0])?
            }
        };
        model.set_output_outlets(&[wire])?;
        model.into_runnable()?.run(inputs)
    }
}

impl TypedOp for ConvUnary {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let q_inputs = self.q_params.as_ref().map(|(_, qp)| qp.input_count()).unwrap_or(0);
        if inputs.len() != 1 + q_inputs {
            bail!("Wrong number of inputs: expected {} got {}", 1 + q_inputs, inputs.len());
        }
        if self.pool_spec.data_format.shape(&*inputs[0].shape)?.c()
            != &self.input_channels().to_dim()
        {
            bail!(
                "Inconsistent convolution: input is {:?}, kernel expects {} input channels, {:?}",
                inputs[0],
                self.input_channels(),
                self
            );
        }
        if self.pool_spec.output_channel_override != Some(self.output_channels()) {
            bail!(
                "Inconsistent convolution: output channels from pool spec is {:?}, kernel expects {} output channels, {:?}",
                self.pool_spec.output_channel_override,
                self.output_channels(),
                self
                );
        }
        if let Some(bias) = &self.bias {
            ensure!(
                bias.rank() == 0 || (bias.rank() == 1 && bias.len() == self.output_channels()),
                "Bias should be scalar or a vector with one value per output channel, got:{:?}",
                bias
            );
        }

        let mut fact = self.pool_spec.output_facts(inputs)?.remove(0);
        if let Some((dt, _qp)) = self.q_params.as_ref() {
            fact.datum_type = *dt;
        } else {
            ensure!(
                inputs[0].datum_type == self.kernel.datum_type(),
                "Convolution input and weights must have the same type. (resp {:?} and {:?})",
                inputs[0].datum_type,
                self.kernel.datum_type(),
            )
        }
        Ok(tvec!(fact))
    }

    fn invariants(
        &self,
        inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<Invariants> {
        let fact = &inputs[0];
        let shape = self.pool_spec.data_format.shape(fact.shape.iter().collect::<Vec<TDim>>())?;
        let mut axes = vec![];
        if let Some(n_axis) = shape.n_axis() {
            let mut info = AxisInfo::simple(n_axis).disposable(true);
            info.inputs.extend(std::iter::repeat(None).take(inputs.len() - 1));
            axes.push(info);
        }
        let kernel_spatial_shape =
            &self.kernel.shape()[self.kernel_fmt.h_axis()..][..shape.hw_rank()];
        let h_axis = shape.h_axis();
        for (ix, &dim) in kernel_spatial_shape.iter().enumerate() {
            if dim == 1 && self.pool_spec.stride(ix) == 1 {
                let mut info =
                    AxisInfo::simple(ix + h_axis).disposable(kernel_spatial_shape.len() > 1);
                info.inputs.extend(std::iter::repeat(None).take(inputs.len() - 1));
                axes.push(info)
            }
        }
        Ok(axes.into_iter().collect())
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some((_, qp)) = self.q_params.as_ref() {
            if let Some((inputs, qp)) = qp.inline_static(model, node)? {
                let mut op = self.clone();
                op.q_params.as_mut().unwrap().1 = qp;
                let patch = TypedModelPatch::replace_single_op(model, node, &inputs, op)?
                    .with_context("inlining quantized conv params");
                return Ok(Some(patch));
            }
        }
        for d in &[Self::declutter_stride_slice_to_downsample, Self::declutter_as_matmul] {
            if let Some(p) = d(self, model, node)? {
                return Ok(Some(p));
            }
        }
        if let Some(p) = self.declutter_precursor_padding(model, node)? {
            return Ok(Some(p));
        }
        Ok(None)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
        let kernel_spatial_shape =
            &self.kernel.shape()[self.kernel_fmt.h_axis()..][..shape.hw_rank()];
        let output_dims = self.pool_spec.padding.compute(
            shape.hw_dims(),
            kernel_spatial_shape,
            &self
                .pool_spec
                .dilations
                .clone()
                .unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
            &self.pool_spec.strides.clone().unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
        );
        let n_output_points: TDim =
            output_dims.iter().map(|d| d.convoluted.clone()).product::<TDim>();
        let n_output_channels = self.output_channels().to_dim();
        let kernel_surface = kernel_spatial_shape.iter().product::<usize>().to_dim();
        let one = 1.to_dim();
        Ok(tvec!(
            (
                Cost::Params(inputs[0].datum_type.unquantized()),
                (self.kernel.len() + self.bias.as_ref().map(|b| b.len()).unwrap_or(0)).to_dim()
            ),
            (
                Cost::FMA(inputs[0].datum_type),
                shape.n().cloned().unwrap_or(one)
                    * shape.c()
                    * n_output_channels
                    * n_output_points
                    * kernel_surface
                    / self.group
            )
        ))
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
        let shape = self.pool_spec.data_format.shape(full_input_shape.clone())?;
        // remove n
        if let Some(n) = shape.n_axis() {
            assert_eq!(n, 0);
            if change == &AxisOp::Rm(n) {
                let op = ConvUnary { pool_spec: self.pool_spec.dispose_n_axis(), ..self.clone() };
                return Ok(Some(AxisChangeConsequence::new(
                    model,
                    node,
                    Some(Box::new(op)),
                    change,
                )));
            }
            if change.transform_axis(n).map(|axis| axis > 0).unwrap_or(true) {
                return Ok(None);
            }
        }
        // format swap: chw <-> hwc
        let (new_format, axis_move) = match self.pool_spec.data_format {
            DataFormat::NCHW => {
                (DataFormat::NHWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::CHW => {
                (DataFormat::HWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::NHWC => (DataFormat::NCHW, AxisOp::Move(shape.c_axis(), 1)),
            DataFormat::HWC => (DataFormat::CHW, AxisOp::Move(shape.c_axis(), 0)),
        };
        if *change == axis_move {
            let mut new_op = self.clone();
            new_op.pool_spec.data_format = new_format;
            return Ok(Some(AxisChangeConsequence {
                substitute_op: Some(Box::new(new_op)),
                wire_changes: tvec!(
                    (InOut::In(0), change.clone()),
                    (InOut::Out(0), change.clone())
                ),
            }));
        }
        // geo axis manips
        use AxisOp::*;
        let h_axis = shape.h_axis();
        let hw_axes = shape.hw_axes();
        let kh_axis = if self.kernel_fmt == KernelFormat::OIHW { 2 } else { 0 };
        let (geo_adjusted, kernel_adjusted) = match change {
            Rm(a)
                if hw_axes.contains(a)
                    && hw_axes.len() > 1
                    && self.pool_spec.dilation(a - h_axis) == 1
                    && self.pool_spec.stride(a - h_axis) == 1
                    && self.pool_spec.kernel_shape[a - h_axis] == 1 =>
            {
                (Rm(a - h_axis), Rm(a - h_axis + kh_axis))
            }
            Add(a) if hw_axes.contains(a) => (Add(a - h_axis), Add(a - h_axis + kh_axis)),
            Move(f, t) if hw_axes.contains(f) && hw_axes.contains(t) => {
                (Move(f - h_axis, t - h_axis), Move(f - h_axis + kh_axis, t - h_axis + kh_axis))
            }
            _ => return Ok(None),
        };
        let mut kernel = self.kernel.clone().into_tensor();
        kernel_adjusted.change_tensor(&mut kernel, false)?;
        let mut dilations = self.pool_spec.dilations().into_owned().into();
        geo_adjusted.change_shape_array(&mut dilations, false)?;
        let mut kernel_shape = self.pool_spec.kernel_shape.clone();
        geo_adjusted.change_shape_array(&mut kernel_shape, false)?;
        let mut strides = self.pool_spec.strides().into_owned().into();
        geo_adjusted.change_shape_array(&mut strides, false)?;
        let new_op = ConvUnary {
            pool_spec: PoolSpec {
                data_format: self.pool_spec.data_format,
                padding: self.pool_spec.padding.clone(), // fixme (explicit padding)
                dilations: Some(dilations),
                kernel_shape,
                strides: Some(strides),
                output_channel_override: self.pool_spec.output_channel_override,
            },
            kernel_fmt: self.kernel_fmt,
            kernel: kernel.into_arc_tensor(),
            group: self.group,
            bias: self.bias.clone(),
            q_params: self.q_params.clone(),
        };
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(new_op)),
            wire_changes: tvec!((InOut::In(0), change.clone()), (InOut::Out(0), change.clone())),
        }))
    }

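    // Codegen dispatch: convert a u8 kernel to i8 first, then choose between the quantized
    // im2col, the pointwise convolution-as-matmul, the lazy im2col, the depthwise and the
    // plain im2col + matmul lowerings.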
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let DatumType::U8 = self.kernel.datum_type().unquantized() {
            let mut patch = TypedModelPatch::default();
            let mut inputs = node
                .inputs
                .iter()
                .map(|w| patch.tap_model(model, *w))
                .collect::<TractResult<TVec<_>>>()?;
            let new_op = self.kernel_offset_u8_as_i8(&mut inputs, &mut patch)?.unwrap();
            let wire = patch.wire_node(&node.name, new_op, &inputs)?;
            patch.shunt_outside(model, node.id.into(), wire[0])?;
            patch.obliterate(node.id)?;
            return Ok(Some(patch.with_context("kernel-u8-to-i8")));
        }

        let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
        let input_fact = model.outlet_fact(node.inputs[0])?;
        let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
        let spatial_rank = input_shape.hw_rank();
        let kernel_spatial_shape = &self.kernel.shape()[self.kernel_fmt.h_axis()..][..spatial_rank];
        unsafe {
            let dt = input_fact.datum_type;
            if self.q_params.is_some() {
                let mut patch = TypedModelPatch::default();
                let inputs = node
                    .inputs
                    .iter()
                    .map(|w| patch.tap_model(model, *w))
                    .collect::<TractResult<TVec<_>>>()?;
                let wire = self.wire_as_quant_im2col(
                    &mut patch,
                    &node.name,
                    model.node_input_facts(node.id)?[0].datum_type,
                    &inputs,
                )?;
                patch.shunt_outside(model, node.id.into(), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch.with_context("quantized-codegen")))
            } else if kernel_spatial_shape.iter().product::<usize>() == 1
                && (0..spatial_rank)
                    .all(|i| self.pool_spec.stride(i) == 1 && self.pool_spec.dilation(i) == 1)
                && self.group == 1
            {
                use crate::ops::matmul::MatMulUnary;
                let mut patch = TypedModelPatch::default();
                let mut wire = patch.tap_model(model, node.inputs[0])?;
                let input_c_is_last = input_shape.c_axis() == input_shape.rank() - 1;
                let geo_dim: TDim = input_shape.hw_dims().iter().product();
                wire = patch.wire_node(
                    format!("{}.reshape_input", &*node.name),
                    AxisOp::Reshape(
                        input_shape.h_axis(),
                        input_shape.hw_dims().into(),
                        tvec!(geo_dim.clone()),
                    ),
                    &[wire],
                )?[0];
                let kernel_shape = match self.kernel_fmt {
                    KernelFormat::HWIO => &self.kernel.shape()[spatial_rank..],
                    KernelFormat::OIHW => &self.kernel.shape()[..2],
                };
                let operating_rank = input_fact.rank() + 1 - kernel_spatial_shape.len();
                let kernel = self
                    .kernel
                    .as_ref()
                    .clone()
                    .into_shape(kernel_shape)?
                    .broadcast_into_rank(operating_rank)?;
                wire = patch.wire_node(
                    &format!("{}.matmul", &node.name),
                    MatMulUnary::new(
                        kernel.into_arc_tensor(),
                        MatMulAxes::default_for_rank(operating_rank).transposing(
                            self.kernel_fmt == KernelFormat::HWIO,
                            input_c_is_last,
                            input_c_is_last,
                        ),
                    ),
                    &[wire],
                )?[0];
                if let Some(ref bias) = self.bias {
                    let bias_shape =
                        if input_c_is_last { [1, bias.len()] } else { [bias.len(), 1] };
                    let bias = bias
                        .clone()
                        .into_tensor()
                        .into_shape(&bias_shape)?
                        .broadcast_into_rank(operating_rank)?
                        .into_arc_tensor();
                    let bias = patch.add_const(format!("{}.bias.cst", node.name), bias)?;
                    wire = patch.wire_node(
                        format!("{}.bias", node.name),
                        crate::ops::math::add(),
                        &[wire, bias],
                    )?[0];
                }
                wire = patch.wire_node(
                    &*node.name,
                    AxisOp::Reshape(
                        input_shape.h_axis(),
                        tvec!(geo_dim),
                        input_shape.hw_dims().into(),
                    ),
                    &[wire],
                )?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if input_fact
                .shape
                .as_concrete()
                .map(|s| {
                    should_use_lazy(
                        &self.pool_spec.data_format.shape(s.into()).unwrap(),
                        &self.pool_spec,
                        self.group,
                    )
                })
                .unwrap_or(false)
            {
                let mut patch = TypedModelPatch::new("wire_as_lazy_im2col");
                let mut wire = patch.tap_model(model, node.inputs[0])?;
                wire = self.wire_as_lazy_im2col(&mut patch, &node.name, wire)?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if self.group != 1
                && self.group == self.output_channels()
                && self.group == self.input_channels()
                && input_fact.shape.as_concrete().is_some()
            {
                let op = dispatch_floatlike!(Self::to_depth_wise(dt)(self, input_fact))
                    .context("in to_depth_wise")?;
                Ok(Some(TypedModelPatch::single_unary_op(model, node, op)?))
            } else {
                let mut patch = TypedModelPatch::default();
                let wire = patch.tap_model(model, node.inputs[0])?;
                let wire = self
                    .wire_as_im2col_pair(&mut patch, &node.name, wire)
                    .context("in wire_as_im2col_pair")?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            }
        }
    }

    as_op!();
}

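// Heuristic for the lazy im2col path: ungrouped convolution with a kernel footprint above 5.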
fn should_use_lazy(_input_shape: &DataShape, pool_spec: &PoolSpec, group: usize) -> bool {
    group == 1 && pool_spec.kernel_shape.iter().product::<usize>() > 5
}

#[allow(non_snake_case)]
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::array::Pad;
    use crate::ops::cnn::PaddingSpec;
    use DataFormat::*;

    #[test]
    fn onnx_basic_convinteger() {
        let op = ConvUnary {
            pool_spec: PoolSpec {
                data_format: NCHW,
                kernel_shape: tvec!(2, 2),
                padding: PaddingSpec::Valid,
                dilations: None,
                strides: None,
                output_channel_override: Some(1),
            },
            kernel_fmt: KernelFormat::OIHW,
            kernel: rctensor4(&[[[[1u8, 1], [1, 1]]]]),
            group: 1,
            bias: None,
            q_params: Some((i32::datum_type(), MatMulQParams::all_dynamic(1))),
        };
        let input = tvec!(
            rctensor4(&[[[[1u8, 2, 3], [4, 5, 6], [7, 8, 9]]]]),
            rctensor0(0u8),
            rctensor0(1.0f32),
            rctensor0(1u8),
            rctensor0(1.0f32),
            rctensor0(0i32),
            rctensor0(1.0f32),
        );
        let input = input.into_iter().map(IntoTValue::into_tvalue).collect::<TVec<_>>();
        let output = op.eval(input).unwrap();
        assert_eq!(*output[0], tensor4(&[[[[8i32, 12], [20, 24]]]]));
    }

    #[test]
    fn valid_conv_absorbs_precursor_pad() -> TractResult<()> {
        let mut model = TypedModel::default();
        let wire = tvec!(model.add_source("source", f32::fact(dims!(1, 10)))?);
        let wire = model.wire_node(
            "pad",
            Pad {
                pads: vec![(0, 0), (1, 0)],
                mode: ops::array::PadMode::Constant(rctensor0(0f32)),
            },
            &wire,
        )?;
        let wire = model.wire_node(
            "conv",
            ConvUnary {
                pool_spec: PoolSpec {
                    data_format: crate::ops::nn::DataFormat::CHW,
                    dilations: None,
                    strides: None,
                    kernel_shape: tvec![2],
                    padding: crate::ops::cnn::PaddingSpec::Explicit(tvec![0], tvec![0], false),
                    output_channel_override: Some(1),
                },
                kernel_fmt: crate::ops::cnn::KernelFormat::OIHW,
                kernel: rctensor3(&[[[1f32, 2f32]]]),
                group: 1,
                bias: None,
                q_params: None,
            },
            &wire,
        )?;
        model.set_output_outlets(&wire)?;
        model.declutter()?;
        assert_eq!(model.nodes().len(), 2); // source + conv
        let cv = model.nodes()[1].op_as::<ConvUnary>().unwrap();
        assert_eq!(
            cv.pool_spec.padding,
            crate::ops::cnn::PaddingSpec::Explicit(tvec![1], tvec![0], false)
        ); // the pad's (1, 0) was absorbed into the conv's explicit padding
        Ok(())
    }
}