tract-core 0.2.0

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
Documentation
use crate::model::Node;
use crate::TractResult;
use bit_set;

pub fn eval_order(model: &super::Model) -> TractResult<Vec<usize>> {
    let inputs = model
        .inputs()?
        .iter()
        .map(|n| n.node)
        .collect::<Vec<usize>>();
    let targets = model
        .outputs()?
        .iter()
        .map(|n| n.node)
        .collect::<Vec<usize>>();
    eval_order_for_nodes(model.nodes(), &inputs, &targets)
}

pub fn eval_order_for_nodes(
    nodes: &[Node],
    inputs: &[usize],
    targets: &[usize],
) -> TractResult<Vec<usize>> {
    let mut done = bit_set::BitSet::with_capacity(nodes.len());
    let mut needed: Vec<usize> = vec![];
    let mut order: Vec<usize> = vec![];
    for &t in targets {
        needed.push(t);
    }
    while let Some(&node) = needed.last() {
        if done.contains(node) {
            needed.pop();
            continue;
        }
        if inputs.contains(&node) || nodes[node].inputs.iter().all(|i| done.contains(i.node)) {
            order.push(node);
            needed.pop();
            done.insert(node);
        } else {
            for input in nodes[node].inputs.iter().rev() {
                if !done.contains(input.node) {
                    needed.push(input.node);
                }
            }
        }
    }
    Ok(order)
}

#[cfg(test)]
mod tests {
    use crate::model::dsl::ModelDsl;
    use crate::model::*;
    use crate::ops::math::Add;
    use crate::*;

    #[test]
    fn test_simple() {
        let mut model = Model::default();
        model.add_source("a").unwrap();
        model.chain("add", Box::new(Add::default())).unwrap();
        model.add_const("b", Tensor::from(12.0f32).into()).unwrap();
        model
            .add_edge(OutletId::new(2, 0), InletId::new(1, 1))
            .unwrap();
        assert_eq!(model.eval_order().unwrap(), vec!(0, 2, 1));
    }

    #[test]
    fn test_diamond() {
        let mut model = Model::default();
        model.add_source("a").unwrap();
        model.chain("add", Box::new(Add::default())).unwrap();
        model
            .add_edge(OutletId::new(0, 0), InletId::new(1, 1))
            .unwrap();
        assert_eq!(model.eval_order().unwrap(), vec!(0, 1));
    }
}