tract-core 0.2.0

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use crate::ops::prelude::*;
use ndarray::prelude::*;

use num_traits::AsPrimitive;
use num_traits::Float;

/// General matrix multiplication operator taking `a`, `b` and, optionally, `c`
/// as inputs.
///
/// At evaluation time `alpha` and `beta` are handed to
/// `ndarray::linalg::general_mat_mul` as the scaling factors of the matrix
/// product and of `c` respectively.
#[derive(Debug, Clone, new)]
pub struct Gemm {
    /// Scaling factor applied to the matrix product.
    alpha: f32,
    /// Scaling factor applied to `c`; not used when `have_c` is false.
    beta: f32,
    /// When true, `a` is transposed before the product.
    trans_a: bool,
    /// When true, `b` is transposed before the product.
    trans_b: bool,
    /// Whether a third input `c` is expected (selects the three-input evaluation).
    have_c: bool,
}

impl Gemm {
    /// Three-input evaluation: `inputs` holds `a`, `b` and `c`, in that order.
    ///
    /// `a` and `b` are viewed as two-dimensional arrays of `T` and transposed
    /// according to `trans_a` / `trans_b`. `c` is used as is when it already
    /// has the output shape, and is broadcast to it otherwise. The single
    /// output is what `general_mat_mul` leaves in `c` when called with
    /// `alpha` and `beta`.
    ///
    /// # Errors
    ///
    /// Fails when an input cannot be viewed as a two-dimensional array of
    /// `T`, when the inner dimensions of the two operands differ, or when `c`
    /// cannot be broadcast to the output shape.
    fn eval_t_3<T: Datum + Float>(
        &self,
        mut inputs: TVec<SharedTensor>,
    ) -> TractResult<TVec<SharedTensor>>
    where
        f32: AsPrimitive<T>,
    {
        let (a, b, c) = args_3!(inputs);
        let a = a.to_array_view::<T>()?.into_dimensionality()?;
        let at = if self.trans_a { a.t() } else { a };
        let b = b.to_array_view::<T>()?.into_dimensionality()?;
        let bt = if self.trans_b { b.t() } else { b };
        // `general_mat_mul` panics on mismatching operands: report an error instead.
        if at.cols() != bt.rows() {
            return Err(format!(
                "Incompatible matrix product: {:?} by {:?}",
                at.shape(),
                bt.shape()
            )
            .into());
        }
        let c_shape = (at.rows(), bt.cols());
        let mut c = if c.shape() == &[c_shape.0, c_shape.1] {
            c.to_array::<T>()?.into_dimensionality::<Ix2>()?.to_owned()
        } else {
            c.to_array_view::<T>()?
                .broadcast(c_shape)
                .ok_or_else(|| format!("Incompatible broadcast: {:?} to {:?}", c.shape(), c_shape))?
                .to_owned()
        };
        ::ndarray::linalg::general_mat_mul(self.alpha.as_(), &at, &bt, self.beta.as_(), &mut c);
        Ok(tvec!(c.into()))
    }

    /// Two-input evaluation: `inputs` holds `a` and `b`, in that order.
    ///
    /// Same as the three-input form without the `c` term: `general_mat_mul`
    /// is called with a zero factor for the accumulator, so `beta` is ignored.
    ///
    /// # Errors
    ///
    /// Fails when an input cannot be viewed as a two-dimensional array of
    /// `T`, or when the inner dimensions of the two operands differ.
    fn eval_t_2<T: Datum + Float>(
        &self,
        mut inputs: TVec<SharedTensor>,
    ) -> TractResult<TVec<SharedTensor>>
    where
        f32: AsPrimitive<T>,
    {
        let (a, b) = args_2!(inputs);
        let a = a.to_array_view::<T>()?.into_dimensionality()?;
        let at = if self.trans_a { a.t() } else { a };
        let b = b.to_array_view::<T>()?.into_dimensionality()?;
        let bt = if self.trans_b { b.t() } else { b };
        // `general_mat_mul` panics on mismatching operands: report an error instead.
        if at.cols() != bt.rows() {
            return Err(format!(
                "Incompatible matrix product: {:?} by {:?}",
                at.shape(),
                bt.shape()
            )
            .into());
        }
        // A zero-filled accumulator keeps this path free of `unsafe` and of
        // uninitialized memory; its content is scaled by zero anyway.
        let mut c = Array::zeros((at.rows(), bt.cols()));
        ::ndarray::linalg::general_mat_mul(self.alpha.as_(), &at, &bt, T::zero(), &mut c);
        Ok(tvec!(c.into()))
    }
}

impl Op for Gemm {
    /// Name under which this operator is reported.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Gemm")
    }
}

impl StatelessOp for Gemm {
    /// Runs the two- or three-input evaluation, depending on `have_c`, through
    /// `dispatch_floatlike!` keyed on the datum type of the first input.
    fn eval(&self, inputs: TVec<SharedTensor>) -> TractResult<TVec<SharedTensor>> {
        let datum_type = inputs[0].datum_type();
        if !self.have_c {
            return dispatch_floatlike!(Self::eval_t_2(datum_type)(self, inputs));
        }
        dispatch_floatlike!(Self::eval_t_3(datum_type)(self, inputs))
    }
}

impl InferenceRulesOp for Gemm {
    /// Registers the inference rules of the operator on the solver.
    ///
    /// The rules fix the number of inputs (three with `c`, two without) and
    /// outputs (one), require rank 2 for `a`, `b` and the output, tie every
    /// datum type to the output's, and relate the output shape and the inner
    /// dimensions of `a` and `b`, taking `trans_a` / `trans_b` into account.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p SharedTensorsProxy,
        outputs: &'p SharedTensorsProxy,
    ) -> InferenceResult {
        if !self.have_c {
            s.equals(&inputs.len, 2)?;
        } else {
            s.equals(&inputs.len, 3)?;
            s.equals(&inputs[2].datum_type, &outputs[0].datum_type)?;
        }

        let a = &inputs[0];
        let b = &inputs[1];
        let out = &outputs[0];

        s.equals(&a.rank, 2)?;
        s.equals(&b.rank, 2)?;
        s.equals(&outputs.len, 1)?;
        s.equals(&out.rank, 2)?;
        s.equals(&a.datum_type, &out.datum_type)?;
        s.equals(&b.datum_type, &out.datum_type)?;

        // Axis of each operand holding its outer and inner dimension once the
        // requested transposition is accounted for.
        let (a_outer, a_inner) = if self.trans_a { (1, 0) } else { (0, 1) };
        let (b_inner, b_outer) = if self.trans_b { (1, 0) } else { (0, 1) };

        s.equals(&a.shape[a_outer], &out.shape[0])?;
        s.equals(&a.shape[a_inner], &b.shape[b_inner])?;
        s.equals(&b.shape[b_outer], &out.shape[1])?;
        Ok(())
    }
}

/// Matrix multiplication operator whose only input is `a`: the `b` and `c`
/// operands are tensors stored in the operator itself.
#[derive(Debug, Clone, new)]
pub struct GemmUnaryA {
    /// Scaling factor applied to the matrix product.
    alpha: f32,
    /// Scaling factor applied to `c`.
    beta: f32,
    /// When true, the `a` input is transposed before the product.
    trans_a: bool,
    /// When true, `b` is transposed before the product.
    trans_b: bool,
    /// Stored right-hand operand of the product.
    b: Tensor,
    /// Stored tensor the result is accumulated onto.
    c: Tensor,
}

impl GemmUnaryA {
    /// Evaluates the operator on its single input `a`, using the stored `b`
    /// and `c`.
    ///
    /// `a` and `b` are viewed as two-dimensional arrays of `T` and transposed
    /// according to `trans_a` / `trans_b`. The stored `c` is broadcast to the
    /// output shape, as the three-input `Gemm` evaluation does, so that it
    /// does not have to be a full matrix. The single output is what
    /// `general_mat_mul` leaves in that copy of `c` when called with `alpha`
    /// and `beta`.
    ///
    /// # Errors
    ///
    /// Fails when a tensor cannot be viewed as an array of `T` with the
    /// expected rank, when the inner dimensions of the two operands differ,
    /// or when `c` cannot be broadcast to the output shape.
    fn eval_t<T: Datum + Float>(
        &self,
        mut inputs: TVec<SharedTensor>,
    ) -> TractResult<TVec<SharedTensor>>
    where
        f32: AsPrimitive<T>,
    {
        let a = args_1!(inputs);
        let a = a.to_array_view::<T>()?.into_dimensionality()?;
        let at = if self.trans_a { a.t() } else { a };
        let b = self.b.to_array_view::<T>()?.into_dimensionality()?;
        let bt = if self.trans_b { b.t() } else { b };
        // `general_mat_mul` panics on mismatching operands: report an error instead.
        if at.cols() != bt.rows() {
            return Err(format!(
                "Incompatible matrix product: {:?} by {:?}",
                at.shape(),
                bt.shape()
            )
            .into());
        }
        let c_shape = (at.rows(), bt.cols());
        // Broadcasting also covers the case where `c` already has the output shape.
        let mut c = self
            .c
            .to_array_view::<T>()?
            .broadcast(c_shape)
            .ok_or_else(|| {
                format!("Incompatible broadcast: {:?} to {:?}", self.c.shape(), c_shape)
            })?
            .to_owned();
        ::ndarray::linalg::general_mat_mul(self.alpha.as_(), &at, &bt, self.beta.as_(), &mut c);
        Ok(tvec!(c.into()))
    }
}

impl Op for GemmUnaryA {
    /// Name under which this operator is reported.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("GemmUnaryA")
    }
}

impl StatelessOp for GemmUnaryA {
    /// Runs `eval_t` through `dispatch_floatlike!`, keyed on the datum type
    /// of the first input.
    fn eval(&self, inputs: TVec<SharedTensor>) -> TractResult<TVec<SharedTensor>> {
        let datum_type = inputs[0].datum_type();
        dispatch_floatlike!(Self::eval_t(datum_type)(self, inputs))
    }
}

impl InferenceRulesOp for GemmUnaryA {
    /// Registers the inference rules of the operator on the solver.
    ///
    /// The operator has exactly one input (`a`) and one output, both of rank
    /// 2 and sharing the same datum type. The output shape and the inner
    /// dimension of `a` are tied to the shape of the stored `b`, taking
    /// `trans_a` / `trans_b` into account.
    ///
    /// No rule refers to inputs beyond the first one: `b` and `c` are stored
    /// in the operator, so no such inputs exist.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p SharedTensorsProxy,
        outputs: &'p SharedTensorsProxy,
    ) -> InferenceResult {
        s.equals(&inputs.len, 1)?;
        s.equals(&inputs[0].rank, 2)?;
        s.equals(&outputs.len, 1)?;
        s.equals(&outputs[0].rank, 2)?;
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        // (column axis, row axis) of each operand once transposition is applied.
        let (ca, ra) = if self.trans_a { (0, 1) } else { (1, 0) };
        let (cb, rb) = if self.trans_b { (0, 1) } else { (1, 0) };
        s.equals(&inputs[0].shape[ra], &outputs[0].shape[0])?;
        s.equals(&inputs[0].shape[ca], self.b.shape()[rb].to_dim())?;
        s.equals(self.b.shape()[cb].to_dim(), &outputs[0].shape[1])?;
        Ok(())
    }
}

/// Matrix multiplication operator whose only input is `b`: the `a` and `c`
/// operands are tensors stored in the operator itself.
#[derive(Debug, Clone, new)]
pub struct GemmUnaryB {
    /// Scaling factor applied to the matrix product.
    alpha: f32,
    /// Scaling factor applied to `c`.
    beta: f32,
    /// When true, `a` is transposed before the product.
    trans_a: bool,
    /// When true, the `b` input is transposed before the product.
    trans_b: bool,
    /// Stored left-hand operand of the product.
    a: Tensor,
    /// Stored tensor the result is accumulated onto.
    c: Tensor,
}

impl GemmUnaryB {
    /// Evaluates the operator on its single input `b`, using the stored `a`
    /// and `c`.
    ///
    /// `a` and `b` are viewed as two-dimensional arrays of `T` and transposed
    /// according to `trans_a` / `trans_b`. The stored `c` is broadcast to the
    /// output shape, as the three-input `Gemm` evaluation does, so that it
    /// does not have to be a full matrix. The single output is what
    /// `general_mat_mul` leaves in that copy of `c` when called with `alpha`
    /// and `beta`.
    ///
    /// # Errors
    ///
    /// Fails when a tensor cannot be viewed as an array of `T` with the
    /// expected rank, when the inner dimensions of the two operands differ,
    /// or when `c` cannot be broadcast to the output shape.
    fn eval_t<T: Datum + Float>(
        &self,
        mut inputs: TVec<SharedTensor>,
    ) -> TractResult<TVec<SharedTensor>>
    where
        f32: AsPrimitive<T>,
    {
        let b = args_1!(inputs);
        let b = b.to_array_view::<T>()?.into_dimensionality()?;
        let a = self.a.to_array_view::<T>()?.into_dimensionality()?;
        let at = if self.trans_a { a.t() } else { a };
        let bt = if self.trans_b { b.t() } else { b };
        // `general_mat_mul` panics on mismatching operands: report an error instead.
        if at.cols() != bt.rows() {
            return Err(format!(
                "Incompatible matrix product: {:?} by {:?}",
                at.shape(),
                bt.shape()
            )
            .into());
        }
        let c_shape = (at.rows(), bt.cols());
        // Broadcasting also covers the case where `c` already has the output shape.
        let mut c = self
            .c
            .to_array_view::<T>()?
            .broadcast(c_shape)
            .ok_or_else(|| {
                format!("Incompatible broadcast: {:?} to {:?}", self.c.shape(), c_shape)
            })?
            .to_owned();
        ::ndarray::linalg::general_mat_mul(self.alpha.as_(), &at, &bt, self.beta.as_(), &mut c);
        Ok(tvec!(c.into()))
    }
}

impl Op for GemmUnaryB {
    /// Name under which this operator is reported.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("GemmUnaryB")
    }
}

impl StatelessOp for GemmUnaryB {
    /// Runs `eval_t` through `dispatch_floatlike!`, keyed on the datum type
    /// of the first input.
    fn eval(&self, inputs: TVec<SharedTensor>) -> TractResult<TVec<SharedTensor>> {
        let datum_type = inputs[0].datum_type();
        dispatch_floatlike!(Self::eval_t(datum_type)(self, inputs))
    }
}

impl InferenceRulesOp for GemmUnaryB {
    /// Registers the inference rules of the operator on the solver.
    ///
    /// The operator has exactly one input (`b`) and one output, both of rank
    /// 2 and sharing the same datum type. The output shape and the inner
    /// dimension of `b` are tied to the shape of the stored `a`, taking
    /// `trans_a` / `trans_b` into account.
    ///
    /// No rule refers to inputs beyond the first one: `a` and `c` are stored
    /// in the operator, so no such inputs exist.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p SharedTensorsProxy,
        outputs: &'p SharedTensorsProxy,
    ) -> InferenceResult {
        s.equals(&inputs.len, 1)?;
        s.equals(&inputs[0].rank, 2)?;
        s.equals(&outputs.len, 1)?;
        s.equals(&outputs[0].rank, 2)?;
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        // (column axis, row axis) of each operand once transposition is applied.
        let (ca, ra) = if self.trans_a { (0, 1) } else { (1, 0) };
        let (cb, rb) = if self.trans_b { (0, 1) } else { (1, 0) };
        s.equals(self.a.shape()[ra].to_dim(), &outputs[0].shape[0])?;
        s.equals(self.a.shape()[ca].to_dim(), &inputs[0].shape[rb])?;
        s.equals(&inputs[0].shape[cb], &outputs[0].shape[1])?;
        Ok(())
    }
}