tract-core 0.23.0-dev.4

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
Documentation
use crate::ops::binary::{BinMiniOp, TypedBinOp};
use crate::ops::konst::Const;
use crate::prelude::*;
use tract_data::internal::Approximation;

/// Small graph-navigation and pattern-matching helpers for [`TypedModel`], mostly used when
/// writing rewrite rules that need to inspect a node's direct neighbours.
pub trait TypedModelHelpers {
    /// Returns the node consuming `node`'s output when `node` has exactly one successor in
    /// total (counted across all of its outputs); returns `None` otherwise.
    fn next_node(&self, node: &TypedNode) -> Option<&TypedNode>;
    /// Returns the node producing `node`'s only input, or `None` when `node` does not have
    /// exactly one input.
    fn previous_node(&self, node: &TypedNode) -> Option<&TypedNode>;
    /// Returns the producer node of each of `node`'s inputs, in input order. A producer that
    /// feeds several inputs appears once per input.
    fn previous_nodes(&self, node: &TypedNode) -> TVec<&TypedNode>;
    /// Returns the `Const` ops of the nodes feeding `node`'s inputs, in input order. Inputs
    /// produced by non-`Const` nodes are skipped.
    fn collect_const_inputs<'a>(&'a self, node: &TypedNode) -> TVec<&'a Const>;
    /// If exactly one of `node`'s inputs is produced by a node whose op is `O`, returns that
    /// input's index together with the producing node; otherwise returns `None`.
    fn single_prev_node_as<O: TypedOp>(&self, node: &TypedNode) -> Option<(usize, &TypedNode)>;
    /// Returns `true` when exactly one input of `node` comes from a `Const` op and that
    /// constant, cast to `f32` and reduced to a scalar tensor, is approximately equal to
    /// `konst` (`Approximation::Approximate`). Any conversion failure yields `false`.
    fn matches_single_input_const(&self, node: &TypedNode, konst: f32) -> bool;
    /// Returns the single successor of `node` if it is a `TypedBinOp` whose mini-op is `B`
    /// and whose only constant input approximately equals `konst`
    /// (see [`TypedModelHelpers::matches_single_input_const`]).
    fn find_succ_bin_with_const<B: BinMiniOp>(
        &self,
        node: &TypedNode,
        konst: f32,
    ) -> Option<&TypedNode>;
    /// Returns the single successor of `node` if it is a `TypedBinOp` whose mini-op is `B`
    /// and one of whose inputs is `outlet_id`.
    fn find_succ_bin_with_outlet<B: BinMiniOp>(
        &self,
        node: &TypedNode,
        outlet_id: &OutletId,
    ) -> Option<&TypedNode>;
}

impl TypedModelHelpers for TypedModel {
    fn next_node(&self, node: &TypedNode) -> Option<&TypedNode> {
        if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
            return None;
        }
        let succ = node.outputs[0].successors[0];
        Some(&self.nodes()[succ.node])
    }

    fn previous_node(&self, node: &TypedNode) -> Option<&TypedNode> {
        if node.inputs.len() != 1 {
            return None;
        }
        Some(&self.nodes()[node.inputs[0].node])
    }

    fn previous_nodes(&self, node: &TypedNode) -> TVec<&TypedNode> {
        node.inputs.iter().map(|n| &self.nodes()[n.node]).collect()
    }

    fn collect_const_inputs<'a>(&'a self, node: &TypedNode) -> TVec<&'a Const> {
        node.inputs
            .iter()
            .filter_map(|i| {
                let prec = &self.nodes()[i.node];
                prec.op_as::<Const>()
            })
            .collect::<TVec<_>>()
    }

    fn single_prev_node_as<O: TypedOp>(&self, node: &TypedNode) -> Option<(usize, &TypedNode)> {
        let prev_nodes = node
            .inputs
            .iter()
            .enumerate()
            .filter_map(|(in_idx, i)| {
                let prec = &self.nodes()[i.node];
                prec.op_is::<O>().then_some((in_idx, prec))
            })
            .collect::<TVec<_>>();

        if prev_nodes.len() != 1 { None } else { Some(prev_nodes[0]) }
    }

    fn matches_single_input_const(&self, node: &TypedNode, konst: f32) -> bool {
        let consts = self.collect_const_inputs(node);
        if consts.len() != 1 {
            return false;
        }
        let Ok(in_const) = consts[0].val().cast_to_dt(DatumType::F32) else {
            return false;
        };
        let Ok(in_const) = in_const.to_scalar_tensor() else {
            return false;
        };
        in_const
            .close_enough(&tract_data::prelude::tensor0(konst), Approximation::Approximate)
            .is_ok()
    }

    fn find_succ_bin_with_const<B: BinMiniOp>(
        &self,
        node: &TypedNode,
        konst: f32,
    ) -> Option<&TypedNode> {
        let succ = self.single_succ(node.id).ok()??;
        let succ_op = succ.op_as::<TypedBinOp>()?;
        (succ_op.0.is::<B>() && self.matches_single_input_const(succ, konst)).then_some(succ)
    }

    fn find_succ_bin_with_outlet<B: BinMiniOp>(
        &self,
        node: &TypedNode,
        outlet_id: &OutletId,
    ) -> Option<&TypedNode> {
        let succ = self.single_succ(node.id).ok()??;
        let succ_op = succ.op_as::<TypedBinOp>()?;
        (succ_op.0.is::<B>() && succ.inputs.contains(outlet_id)).then_some(succ)
    }
}