tract-core 0.23.0-dev.4

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use super::AxesMapping;
use crate::internal::*;
use ndarray::{ArrayViewD, Zip};
use tract_data::itertools::Itertools;
use tract_linalg::block_quant::{BlockQuantStorage, block_quant_slice};
use tract_ndarray::{Axis, Dimension};
use tract_num_traits::{One, Zero};

pub fn output_shape<D: DimLike>(
    expr: &AxesMapping,
    inputs: &[impl AsRef<[D]>],
) -> TractResult<TVec<D>> {
    Ok(expr
        .iter_all_axes()
        .filter(|a| a.outputs[0].len() > 0)
        .sorted_by_key(|axis| axis.outputs[0][0])
        .map(|axis| {
            axis.inputs[0..inputs.len()]
                .iter()
                .enumerate()
                .flat_map(|(input_id, positions)| {
                    positions.iter().map(move |p| inputs[input_id].as_ref()[*p].clone())
                })
                .find(|x| x != &1.into())
                .unwrap_or_else(|| 1.into())
        })
        .collect())
}

pub fn dequant_inputs(acc: DatumType, input: TVec<TValue>) -> TractResult<TVec<TValue>> {
    input
        .into_iter()
        .map(|i| {
            if i.is_plain() && i.datum_type().is_number() {
                Ok(i)
            } else {
                let s = i.shape();
                let k = *s.last().unwrap();
                // Leading dims are group/batch dims; last two are [M, K]
                let num_groups: usize =
                    if s.len() > 2 { s[..s.len() - 2].iter().product() } else { 1 };
                let m_per_group: usize = if s.len() >= 2 { s[s.len() - 2] } else { 1 };
                let bqs = i.try_storage_as::<BlockQuantStorage>()?;
                let mut unpacked: Vec<Tensor> = if acc.is::<f16>() {
                    (0..num_groups)
                        .map(|g| {
                            let slice =
                                block_quant_slice(bqs.value(), bqs.format(), m_per_group, k, g);
                            bqs.format().dequant_f16(slice)
                        })
                        .collect::<TractResult<_>>()?
                } else if acc.is::<f32>() {
                    (0..num_groups)
                        .map(|g| {
                            let slice =
                                block_quant_slice(bqs.value(), bqs.format(), m_per_group, k, g);
                            bqs.format().dequant_f32(slice)
                        })
                        .collect::<TractResult<_>>()?
                } else {
                    bail!(
                        "Only f32 and f16 accumulators are compatible with BlockQuantValue inputs"
                    );
                };
                unpacked.iter_mut().try_for_each(|t| t.insert_axis(0))?;
                let stacked = if unpacked.len() > 1 {
                    Tensor::stack_tensors(0, &unpacked)?
                } else {
                    unpacked.into_iter().next().unwrap()
                };
                Ok(stacked.into_shape(s)?.into_tvalue())
            }
        })
        .collect::<TractResult<TVec<TValue>>>()
}

/// Reference evaluation of the einsum described by `expr`, accumulating in `Acc`.
///
/// Block-quantized inputs are first dequantized with `dequant_inputs`, then every input is cast
/// to `Acc`. Each output element is the sum of the products of the matching input elements.
/// The sum runs over every combination of the summed axes, i.e. axes present in at least one
/// input but absent from the output. Evaluation iterates explicitly over every output coordinate
/// and every summed coordinate.
///
/// Size-1 input dimensions broadcast on both output axes and summed axes. An input whose
/// dimension along an axis is 1 is always read at index 0. The extent of a summed axis is the
/// first dimension different from 1 among the inputs carrying it (1 if there is none). This is
/// the same rule `output_shape` uses for output axes.
///
/// # Errors
/// Propagates failures from dequantization, casting to `Acc`, and array-view extraction.
pub fn eval_t<Acc: Datum + Zero + One>(
    expr: &AxesMapping,
    inputs: TVec<TValue>,
) -> TractResult<Tensor> {
    /// Narrows `view` to the single index `index` along `axis`. If the view has size 1 on that
    /// axis it broadcasts, so index 0 is used instead.
    fn select_broadcast<T>(view: &mut tract_ndarray::ArrayViewD<'_, T>, axis: usize, index: usize) {
        let index = if view.shape()[axis] == 1 { 0 } else { index };
        view.slice_axis_inplace(Axis(axis), (index..=index).into());
    }

    let inputs = dequant_inputs(Acc::datum_type(), inputs)?;
    let shapes: TVec<_> = inputs.iter().map(|t| t.shape()).collect();
    let output_shape = output_shape(expr, &shapes)?;
    let inputs: TVec<Cow<Tensor>> =
        inputs.iter().map(|t| t.cast_to::<Acc>()).collect::<TractResult<_>>()?;
    let inputs: TVec<tract_ndarray::ArrayViewD<Acc>> =
        inputs.iter().map(|t| t.to_plain_array_view::<Acc>()).collect::<TractResult<_>>()?;
    let input_count = inputs.len();

    // Output axes in output order. This is loop-invariant, so it is computed once here rather
    // than once per output element.
    let output_axes: TVec<_> = expr
        .iter_all_axes()
        .filter(|a| a.outputs[0].len() > 0)
        .sorted_by_key(|axis| axis.outputs[0][0])
        .collect();
    let summing_axes: TVec<_> = expr
        .iter_all_axes()
        .filter(|a| {
            a.outputs[0].len() == 0 && a.inputs[0..input_count].iter().any(|i| i.len() > 0)
        })
        .collect();
    // Take the first non-1 extent, so that a size-1 dimension in an earlier input does not
    // shrink the sum.
    let summing_shape: TVec<usize> = summing_axes
        .iter()
        .map(|axis| {
            axis.inputs[0..input_count]
                .iter()
                .enumerate()
                .flat_map(|(input_id, positions)| {
                    let shape = inputs[input_id].shape();
                    positions.iter().map(move |&p| shape[p])
                })
                .find(|&d| d != 1)
                .unwrap_or(1)
        })
        .collect();

    let output = tract_ndarray::ArrayD::<Acc>::from_shape_fn(&*output_shape, |coords| {
        let coords = coords.as_array_view();
        // Pin every output axis to this element's coordinate in each input view.
        let mut views = inputs.clone();
        for (axis, x) in output_axes.iter().zip(coords) {
            for (input_id, positions) in axis.inputs[0..input_count].iter().enumerate() {
                for position in positions {
                    select_broadcast(&mut views[input_id], *position, *x);
                }
            }
        }
        let mut sum: Acc = Acc::zero();
        for sum_coords in tract_ndarray::indices(&*summing_shape) {
            // Pin every summed axis too; each view is then reduced to exactly one element.
            let mut views = views.clone();
            let sum_coords = sum_coords.as_array_view();
            for (axis, x) in summing_axes.iter().zip(sum_coords) {
                for (input_id, positions) in axis.inputs[0..input_count].iter().enumerate() {
                    for position in positions {
                        select_broadcast(&mut views[input_id], *position, *x);
                    }
                }
            }
            let mut product = Acc::one();
            for v in &views {
                debug_assert_eq!(v.len(), 1);
                product = product * v.iter().next().unwrap().clone();
            }
            sum = sum + product;
        }
        sum
    });
    Ok(output.into_tensor())
}

/// Reference evaluation of a quantized einsum.
///
/// `inputs` must be exactly `[a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale]`.
///
/// 1. `a` and `b` are cast through `i32` to `f32` and dequantized as `scale * (x - zero_point)`.
/// 2. The float einsum is evaluated with `eval_t::<f32>`.
/// 3. The result is requantized as `round((c + bias * a_scale * b_scale) / c_scale + c0)`.
/// 4. The result is clamped to the `i8`/`u8` range when `qp`'s unquantized type is one of those.
/// 5. The result is cast through `i32` to `qp`.
///
/// All quantization parameters must have rank 0 or 1. A rank-1 parameter is laid along the
/// axis it maps to in the tensor it applies to.
pub fn eval_q(expr: &AxesMapping, qp: DatumType, inputs: TVec<TValue>) -> TractResult<Tensor> {
    // Views the parameter `param` (held in slot `param_slot`) so it broadcasts against the
    // tensor in `data_slot`. A scalar is returned unchanged. A vector is reshaped to the data
    // tensor's rank, with its length placed on the axis it shares with the data tensor.
    fn reshape_param<'a>(
        expr: &AxesMapping,
        data_slot: InOut,
        param: &'a Tensor,
        param_slot: InOut,
    ) -> TractResult<ArrayViewD<'a, f32>> {
        if param.rank() != 0 {
            let mut shape = vec![1; expr.rank(data_slot)];
            // Handle case where axis is not present in input (qp.len is necessarily 1)
            let axis_pos = expr
                .axis((param_slot, 0))?
                .interface(data_slot)
                .first()
                .cloned()
                .unwrap_or(0);
            shape[axis_pos] = param.len();
            return Ok(param.try_as_plain()?.to_array_view()?.into_shape_with_order(shape)?);
        }
        param.try_as_plain()?.to_array_view()
    }

    // Dequantizes `data` (an f32 tensor in `data_slot`) in place: x <- scale * (x - zero_point).
    fn dequantize_in_place(
        expr: &AxesMapping,
        data: &mut Tensor,
        data_slot: InOut,
        zero_point: &Tensor,
        zero_point_slot: InOut,
        scale: &Tensor,
        scale_slot: InOut,
    ) -> TractResult<()> {
        Zip::from(data.to_plain_array_view_mut::<f32>()?)
            .and_broadcast(reshape_param(expr, data_slot, zero_point, zero_point_slot)?)
            .and_broadcast(reshape_param(expr, data_slot, scale, scale_slot)?)
            .for_each(|x, zp, s| *x = s * (*x - zp));
        Ok(())
    }

    let [a, b, bias, a0, a_scale, b0, b_scale, c0, c_scale] = &*inputs else {
        bail!("Expect exactly 9 inputs")
    };

    // Going through i32 first drops any quantized datum type before moving to f32.
    let mut a = a.cast_to::<i32>()?.cast_to::<f32>()?.into_owned();
    let mut b = b.cast_to::<i32>()?.cast_to::<f32>()?.into_owned();

    let a0 = a0.cast_to::<f32>()?;
    let b0 = b0.cast_to::<f32>()?;
    let c0 = c0.cast_to::<f32>()?;
    let a_scale = a_scale.cast_to::<f32>()?;
    let b_scale = b_scale.cast_to::<f32>()?;
    let c_scale = c_scale.cast_to::<f32>()?;
    let bias = bias.cast_to::<f32>()?;
    ensure!(a0.rank() < 2);
    ensure!(b0.rank() < 2);
    ensure!(c0.rank() < 2);
    ensure!(a_scale.rank() < 2);
    ensure!(b_scale.rank() < 2);
    ensure!(c_scale.rank() < 2);
    ensure!(bias.rank() < 2);

    dequantize_in_place(expr, &mut a, InOut::In(0), &a0, InOut::In(3), &a_scale, InOut::In(4))?;
    dequantize_in_place(expr, &mut b, InOut::In(1), &b0, InOut::In(5), &b_scale, InOut::In(6))?;

    let mut output =
        eval_t::<f32>(expr, tvec!(a.into_tvalue(), b.into_tvalue()))?.into_plain_array::<f32>()?;

    // Requantize. The bias is added after scaling by a_scale * b_scale.
    let out_slot = InOut::Out(0);
    Zip::from(&mut output)
        .and_broadcast(reshape_param(expr, out_slot, &bias, InOut::In(2))?)
        .and_broadcast(reshape_param(expr, out_slot, &c0, InOut::In(7))?)
        .and_broadcast(reshape_param(expr, out_slot, &c_scale, InOut::In(8))?)
        .and_broadcast(reshape_param(expr, out_slot, &a_scale, InOut::In(4))?)
        .and_broadcast(reshape_param(expr, out_slot, &b_scale, InOut::In(6))?)
        .for_each(|c, bias, c0, c_scale, a_scale, b_scale| {
            let biased = *c + bias * a_scale * b_scale;
            *c = (biased / c_scale + c0).round();
        });

    // Saturate to the 8-bit range before the integer casts below.
    match qp.unquantized() {
        DatumType::I8 => output.mapv_inplace(|x| x.clamp(i8::MIN as f32, i8::MAX as f32)),
        DatumType::U8 => output.mapv_inplace(|x| x.clamp(u8::MIN as f32, u8::MAX as f32)),
        _ => {}
    }
    Ok(output.into_tensor().cast_to::<i32>()?.cast_to_dt(qp)?.into_owned())
}