tract-core 0.23.0-dev.4

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
Documentation
use crate::internal::*;
use tract_linalg::mmm::{MMMInputValue, PackedMatrixStorage};
use tract_ndarray::prelude::*;

/// Operator that sums the values of packed quantized matrices.
///
/// Evaluation reads a packed matrix input and produces an `i32` tensor whose
/// shape is the input batch shape followed by a trailing dimension of size `n`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct QSumB {
    /// Datum type of the packed values. Once unquantized it must be `I8` or
    /// `U8`, otherwise evaluation fails.
    pub dt: DatumType,
    /// NOTE(review): presumably the panel width of the packing; in this file
    /// it is only reported by `info` — confirm against the code building the op.
    pub r: usize,
    /// Size of the trailing output dimension; may be symbolic and is resolved
    /// against the session symbols at evaluation time.
    pub n: TDim,
    /// NOTE(review): presumably the length of the summed axis; in this file it
    /// is only reported by `info` — confirm against the code building the op.
    pub k: usize,
}

impl Op for QSumB {
    fn name(&self) -> StaticName {
        "QSumB".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("r:{}, n:{}, k:{}", self.r, self.n, self.k)])
    }

    op_as_typed_op!();
}

impl EvalOp for QSumB {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        _node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let n = self.n.eval_to_i64(&session.resolved_symbols)? as usize;
        self.eval(inputs, n)
    }
}

impl TypedOp for QSumB {
    as_op!();

    /// Declares a single `i32` output whose shape is the first input's shape
    /// with `n` appended as a trailing dimension.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let input_fact: &TypedFact = inputs[0];
        let mut dims: TVec<TDim> = input_fact.shape.to_tvec();
        dims.push(self.n.to_dim());
        let output_fact = i32::fact(dims);
        Ok(tvec!(output_fact))
    }
}

impl QSumB {
    /// Sums every packed matrix of the input batch along its `k` axis.
    ///
    /// The single input must be backed by a `PackedMatrixStorage` with a
    /// rank-2 batch shape. The result is an `i32` tensor of shape
    /// `[batch_shape[0], batch_shape[1], n]`.
    ///
    /// # Errors
    ///
    /// Fails if the input is not a `PackedMatrixStorage`, if its batch shape
    /// is not of rank 2, if the unquantized datum type is neither `I8` nor
    /// `U8`, or if a panel cannot be summed.
    fn eval(&self, inputs: TVec<TValue>, n: usize) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let storage = input
            .try_storage_as::<PackedMatrixStorage>()
            .context("Expected PackedMatrixStorage")?;
        let batch_shape = storage.batch_shape();
        // The loops below address payloads with exactly two batch indices;
        // reject any other rank up front instead of panicking on indexing.
        if batch_shape.len() != 2 {
            bail!("QSumB expects a rank 2 batch shape, got {:?}", batch_shape);
        }
        let mut shape: TVec<usize> = batch_shape.into();
        shape.push(n);
        let mut output = ArrayD::<i32>::zeros(&*shape);
        for b in 0..batch_shape[0] {
            let mut output_view = output.index_axis_mut(Axis(0), b);
            for g in 0..batch_shape[1] {
                let mut output_view = output_view.index_axis_mut(Axis(0), g);
                // `output` was freshly allocated by `zeros`, so the innermost
                // row is contiguous and can be viewed as a plain slice.
                let output_slice = output_view.as_slice_mut().unwrap();
                let payload = storage.value_at(&[b, g]);
                match self.dt.unquantized() {
                    DatumType::I8 => self.eval_t::<i8>(payload, output_slice)?,
                    DatumType::U8 => self.eval_t::<u8>(payload, output_slice)?,
                    dt => bail!("Unsupported input type in quantized operation ({:?})", dt),
                }
            }
        }
        Ok(tvec!(output.into_tvalue()))
    }

    /// Sums one packed matrix along `k`, writing one `i32` per vector into
    /// the first `input.mn()` entries of `output`.
    ///
    /// Each panel is read as `k` consecutive groups of `r` values of `T`;
    /// the values sharing a position within their group are accumulated
    /// together. Positions of the last panel beyond `input.mn()` are dropped.
    ///
    /// # Errors
    ///
    /// Fails if `output` is shorter than `input.mn()` or if a panel cannot be
    /// obtained from the input.
    fn eval_t<T: Datum + tract_num_traits::AsPrimitive<i32>>(
        &self,
        input: &dyn MMMInputValue,
        output: &mut [i32],
    ) -> TractResult<()> {
        let (r, k, n) = (input.format().r(), input.k(), input.mn());
        // Guard the slicing below: a short output would otherwise panic.
        if output.len() < n {
            bail!("QSumB output holds {} values but packed input has {} vectors", output.len(), n);
        }
        let panels = n.divceil(r);
        // One accumulator buffer reused (and reset) for every panel.
        let mut sums = vec![0i32; r];
        for ipanel in 0..panels {
            let panel = input.panel_bytes(ipanel, None)?;
            // SAFETY: relies on `panel_bytes` returning a pointer to at least
            // `r * k` initialized values of `T` that stay valid while `input`
            // is borrowed. NOTE(review): this layout contract is assumed, not
            // checked here — confirm against the packing format.
            let panel: &[T] = unsafe { std::slice::from_raw_parts(panel as *const T, r * k) };
            sums.fill(0);
            for row in panel.chunks_exact(r) {
                for (acc, &value) in sums.iter_mut().zip(row) {
                    *acc += value.as_();
                }
            }
            // The last panel may be only partially filled with real vectors.
            let len = r.min(n - r * ipanel);
            output[r * ipanel..][..len].copy_from_slice(&sums[..len]);
        }
        Ok(())
    }
}