tract-tensorflow 0.3.3

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use tract_core::internal::*;
use tract_core::ops as tractops;

use crate::model::TfOpRegister;

/// Registers the TensorFlow ops implemented in this module.
///
/// * `Less` is mapped onto tract's `Lesser` binary comparison via `with_T!`.
/// * `Merge` is built by the local `merge` builder from the node's `N` attribute.
/// * `Switch` takes nothing from the node definition.
pub fn register_all_ops(registry: &mut TfOpRegister) {
    registry.insert("Less", with_T!(tractops::logic::Lesser::Bin));
    registry.insert("Merge", merge);
    // The node definition is ignored: `Switch` is a unit struct with no attributes.
    registry.insert("Switch", |_| Ok(Box::new(Switch)));
}

/// TensorFlow `Switch`: forwards its data input to one of two outputs depending
/// on a scalar boolean predicate; the other output receives a null tensor.
#[derive(Clone, Debug)]
pub struct Switch;

impl Op for Switch {
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("tf.Switch")
    }
}

impl StatelessOp for Switch {
    /// Routes `inputs[0]` (data) according to `inputs[1]` (a boolean scalar).
    ///
    /// Returns two tensors. When the predicate is false, the data goes to slot 0
    /// and slot 1 gets a null tensor. When it is true, slot 0 gets the null tensor
    /// and the data goes to slot 1. The null tensor has the data's datum type and
    /// shape.
    ///
    /// # Errors
    /// Fails if the inputs cannot be unpacked as two tensors, if creating the
    /// null tensor fails, or if the predicate is not a `bool` scalar.
    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let (data, predicate) = args_2!(inputs);
        // SAFETY / NOTE(review): `null_dt` is `unsafe`, presumably because the buffer
        // is left uninitialised. This tensor only marks the branch that is not taken
        // (`Merge` skips tensors for which `is_null()` holds), and its contents are
        // not read here. Confirm no consumer reads its data.
        let dead = unsafe { Tensor::null_dt(data.datum_type(), data.shape())? };
        let dead: Arc<Tensor> = dead.into();
        let taken = *predicate.to_scalar::<bool>()?;
        Ok(if taken { tvec!(dead, data) } else { tvec!(data, dead) })
    }
}

impl InferenceRulesOp for Switch {
    /// Type/shape inference for `Switch`.
    ///
    /// The op requires exactly two inputs, data and predicate. The predicate must
    /// be a `Bool` scalar. Every output present in the graph shares the data
    /// input's datum type and shape.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(&inputs, 2)?;
        let data = &inputs[0];
        let predicate = &inputs[1];
        s.equals(&predicate.datum_type, DatumType::Bool)?;
        s.equals(&predicate.shape, shapefact!())?;
        // NOTE(review): the output count is not checked. This is presumably because a
        // graph may wire only one of the two branches. Confirm against the loader.
        for output in outputs {
            s.equals(&data.datum_type, &output.datum_type)?;
            s.equals(&data.shape, &output.shape)?;
        }
        Ok(())
    }
}

fn merge(pb: &crate::tfpb::node_def::NodeDef) -> TractResult<Box<Op>> {
    let inputs = pb.get_attr_int::<i32>("N")?;
    Ok(Box::new(Merge::new(inputs as usize)))
}

/// TensorFlow `Merge`: forwards the first non-null tensor among its inputs,
/// together with that input's position.
#[derive(Clone, Debug, new)]
pub struct Merge {
    /// Expected number of inputs (the `N` attribute when built by `merge`).
    n: usize,
}

impl Op for Merge {
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("tf.Merge")
    }
}

impl StatelessOp for Merge {
    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let index =
            inputs.iter().position(|t| !t.is_null()).ok_or("No tensor received in merge")?;
        Ok(tvec!(inputs.remove(index), Tensor::from(index as i32).into()))
    }
}

impl InferenceRulesOp for Merge {
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(&inputs, self.n)?;
        check_output_arity(&outputs, 1)?;
        for i in 1..self.n {
            s.equals(&inputs[0].datum_type, &inputs[i].datum_type)?;
            s.equals(&inputs[0].shape, &inputs[i].shape)?;
        }
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        s.equals(&inputs[0].shape, &outputs[0].shape)?;
        Ok(())
    }
}