tract-tensorflow 0.3.3

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use tract_core::internal::*;

use crate::model::TfOpRegister;
use crate::tfpb::node_def::NodeDef;

#[macro_use]
mod macros;

pub mod array;
pub mod logic;
pub mod math;
pub mod nn;
pub mod quant;
pub mod rec;
pub mod vars;

/// Populates `reg` with every op builder exposed by this module tree.
///
/// The `register_all_ops` function of each submodule is run first, in the
/// order the submodules are declared, and then the builders defined directly
/// in this file are inserted under the names "Cast", "Const", "Identity",
/// "NoOp" and "Placeholder".
pub fn register_all_ops(reg: &mut TfOpRegister) {
    // Table of per-submodule registration entry points, kept in the same
    // order as the original sequence of calls.
    let submodule_registrars: [fn(&mut TfOpRegister); 7] = [
        array::register_all_ops,
        logic::register_all_ops,
        math::register_all_ops,
        nn::register_all_ops,
        quant::register_all_ops,
        rec::register_all_ops,
        vars::register_all_ops,
    ];
    for register in submodule_registrars.iter() {
        register(&mut *reg);
    }

    // Builders implemented in this file.
    reg.insert("Cast", cast);
    reg.insert("Const", konst);
    reg.insert("Identity", |_| Ok(Box::new(Identity::default())));
    reg.insert("NoOp", |_| Ok(Box::new(Noop)));
    reg.insert("Placeholder", placeholder);
}

/// Builds the op for a `Cast` node.
///
/// The destination datum type is read from the node's `DstT` attribute and
/// handed to the `tract_core` `Cast` op constructor.
///
/// # Errors
///
/// Propagates the failure of `get_attr_datum_type` when `DstT` cannot be read.
pub fn cast(node: &NodeDef) -> TractResult<Box<Op>> {
    let target_type = node.get_attr_datum_type("DstT")?;
    let op = ::tract_core::ops::cast::Cast::new(target_type);
    Ok(Box::new(op))
}

/// Builds a constant op from a `Const` node.
///
/// The tensor stored in the node's `value` attribute must have the datum type
/// announced by its `dtype` attribute; when it does, the tensor is wrapped in
/// a `tract_core` `Const` op.
///
/// # Errors
///
/// Fails if either attribute cannot be read, or if the tensor's datum type
/// differs from `dtype`.
pub fn konst(node: &NodeDef) -> TractResult<Box<Op>> {
    let expected = node.get_attr_datum_type("dtype")?;
    let tensor = node.get_attr_tensor("value")?;

    if tensor.datum_type() == expected {
        let op = ::tract_core::ops::konst::Const::for_tensor(tensor);
        return Ok(Box::new(op));
    }

    // Declared and actual types disagree: report both.
    bail!("Const node {:?} doesn't have the expected {:?} type.", tensor, expected);
}

/// Builds the op for a `Placeholder` node: a fresh `tract_core` `Source`.
///
/// The node definition itself is not inspected.
pub fn placeholder(_node: &NodeDef) -> TractResult<Box<Op>> {
    let source = ::tract_core::ops::source::Source::new();
    Ok(Box::new(source))
}

/// Op registered under the name "NoOp" by `register_all_ops`.
///
/// It carries no state. Its evaluation ignores the inputs and produces a
/// single boolean tensor holding `false`, which its inference rules describe
/// as a rank-0 `bool`.
#[derive(Clone, Debug, new)]
pub struct Noop;

impl Op for Noop {
    fn name(&self) -> Cow<str> {
        "tf.Noop".into()
    }
}

impl StatelessOp for Noop {
    fn eval(&self, _inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        Ok(tvec!(Tensor::from(false).into()))
    }
}

impl InferenceRulesOp for Noop {
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        _inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        s.equals(&outputs[0].datum_type, bool::datum_type())?;
        s.equals(&outputs[0].rank, 0)?;
        Ok(())
    }
}

/// Op registered under the name "Identity" by `register_all_ops`.
///
/// Its evaluation forwards the first input as the single output. During
/// declutter, a node with exactly one input is patched to use the
/// `tract_core` identity op instead.
#[derive(Clone, Debug, new, Default)]
struct Identity;

impl Op for Identity {
    fn name(&self) -> Cow<str> {
        "tf.Identity".into()
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if node.inputs.len() == 1 {
            Ok(Some(TypedModelPatch::single_unary_op(
                model,
                node,
                tract_core::ops::identity::Identity::default(),
            )?))
        } else {
            Ok(None)
        }
    }
}

impl StatelessOp for Identity {
    /// Returns the first input as the only output; any further inputs are
    /// dropped along with the input vector.
    ///
    /// NOTE(review): `remove(0)` presumably panics on an empty input vector —
    /// confirm against the `TVec` implementation.
    fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let forwarded = inputs.remove(0);
        Ok(tvec![forwarded])
    }
}

impl InferenceRulesOp for Identity {
    /// Declares the inference constraints for `Identity`.
    ///
    /// Exactly one output is required; its shape and datum type are tied to
    /// those of the first input. The number of inputs is not checked here.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_output_arity(&outputs, 1)?;

        let (source, target) = (&inputs[0], &outputs[0]);
        s.equals(&target.shape, &source.shape)?;
        s.equals(&target.datum_type, &source.datum_type)?;
        Ok(())
    }
}