use std::error::Error;
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::{Arc, Mutex};
use rten_tensor::prelude::*;
use rten_tensor::test_util::{expect_equal, expect_equal_with_tolerance};
use rten_tensor::{Tensor, TensorView};
use smallvec::{SmallVec, smallvec};
use super::{CachedPlan, CaptureEnv, PlanOptions};
use crate::graph::{
Dimension, Graph, Node, NodeId, RunError, RunErrorKind, RunOptions, TypedConstant,
};
use crate::operator::{
IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputTypeList, OutputTypesContext,
PrepackedInput, SubgraphOperator,
};
use crate::ops::{Add, Concat, Conv, Identity, If, MatMul, Mul, Relu, Shape};
use crate::timing::Profiler;
use crate::value::{DataType, Value, ValueType, ValueView};
use crate::weight_cache::WeightCache;
/// Counters recording how many times an operator was executed via the
/// standard (`run`) vs in-place (`run_in_place`) paths. Shared between a
/// `TrackUsage` wrapper and the test that inspects it.
#[derive(Clone, Debug, Default)]
struct Metrics {
    // Number of `Operator::run` invocations.
    run_count: u32,
    // Number of `Operator::run_in_place` invocations.
    run_in_place_count: u32,
}
/// Operator wrapper that delegates all behavior to `inner` while counting
/// which execution path (`run` vs `run_in_place`) the graph executor used.
#[derive(Debug)]
struct TrackUsage<Op: Operator> {
    inner: Op,
    // Shared so tests can keep a handle after the operator is moved into a graph.
    metrics: Arc<Mutex<Metrics>>,
}
impl<Op: Operator> TrackUsage<Op> {
fn new(inner: Op) -> Self {
TrackUsage {
inner,
metrics: Default::default(),
}
}
fn metrics(&self) -> Arc<Mutex<Metrics>> {
self.metrics.clone()
}
}
/// Forwards every `Operator` method to the wrapped operator, incrementing
/// the shared metrics when `run` or `run_in_place` is invoked.
impl<Op: Operator> Operator for TrackUsage<Op> {
    fn name(&self) -> &str {
        self.inner.name()
    }

    fn can_run_in_place(&self) -> bool {
        self.inner.can_run_in_place()
    }

    fn is_commutative(&self) -> bool {
        self.inner.is_commutative()
    }

    fn max_inputs(&self) -> Option<usize> {
        self.inner.max_inputs()
    }

    // Fix: the parameter was named `_ctx` (signalling "unused") even though
    // it is forwarded to the inner operator. Renamed to `ctx`.
    fn output_types(&self, ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        self.inner.output_types(ctx)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        // Scoped block drops the mutex guard before delegating.
        {
            let mut m = self.metrics.lock().unwrap();
            m.run_count += 1;
        }
        self.inner.run(ctx)
    }

    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        {
            let mut m = self.metrics.lock().unwrap();
            m.run_in_place_count += 1;
        }
        self.inner.run_in_place(input, ctx)
    }
}
/// Test operator whose `run` behavior is supplied as a closure, letting
/// tests define one-off operators without a dedicated struct.
struct RunFn<V: Into<Value>, F: Fn(&OpRunContext) -> Result<V, OpError> + 'static> {
    run: F,
}
impl<V: Into<Value>, F: Fn(&OpRunContext) -> Result<V, OpError>> RunFn<V, F> {
fn new(run: F) -> Self {
Self { run }
}
}
impl<V: Into<Value>, F: Fn(&OpRunContext) -> Result<V, OpError>> std::fmt::Debug for RunFn<V, F> {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The closure itself is not printable, so emit a fixed type name.
        fmt.write_str("RunFn")
    }
}
impl<V: Into<Value> + 'static, F: Fn(&OpRunContext) -> Result<V, OpError>> Operator
    for RunFn<V, F>
{
    fn name(&self) -> &str {
        "RunFn"
    }

    // No limit on input count.
    fn max_inputs(&self) -> Option<usize> {
        None
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        None
    }

    /// Invoke the stored closure and wrap its single result value in an
    /// output list.
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let value = (self.run)(ctx)?;
        Ok([value.into()].into())
    }
}
#[test]
fn test_graph_run() -> Result<(), Box<dyn Error>> {
    // End-to-end run of a small Conv -> Relu graph against precomputed
    // reference values.
    let mut g = Graph::new();

    // 1x1x3x3 convolution kernel.
    let weights = Tensor::from_data(
        &[1, 1, 3, 3],
        vec![
            0.3230, 0.7632, 0.4616, 0.8837, 0.5898, 0.3424, 0.2101, 0.7821, 0.6861,
        ],
    )
    .into_arc();
    let weights_id = g.add_constant(Some("weight"), weights);
    let input_id = g.add_value(Some("input"), None, None);

    let (_, conv_out) = g.add_simple_op(
        "conv",
        Conv {
            dilations: vec![1, 1],
            groups: 1,
            // Same-padding so the output spatial size matches the input.
            padding: [1, 1, 1, 1].into(),
            strides: vec![1, 1],
        },
        &[input_id, weights_id],
    );
    let (_, relu_out) = g.add_simple_op("relu", Relu {}, &[conv_out]);

    let input = Tensor::from_data(
        &[1, 1, 3, 3],
        vec![
            0.5946, 0.8249, 0.0448, 0.9552, 0.2041, 0.2501, 0.2693, 0.1007, 0.8862,
        ],
    );
    let results = g
        .run(vec![(input_id, input.into())], &[relu_out], None, None)
        .unwrap();

    // Reference output computed externally for this kernel/input pair.
    let expected = Tensor::from_data(
        &[1, 1, 3, 3],
        vec![
            1.5202, 1.5592, 0.9939, 1.7475, 2.6358, 1.3428, 1.0165, 1.1806, 0.8685,
        ],
    );
    assert_eq!(results.len(), 1);
    // Loose tolerance since the expected values are rounded to 4 decimals.
    expect_equal_with_tolerance(
        &results[0].as_tensor_view().unwrap(),
        &expected.view(),
        1e-4,
        0.,
    )?;
    Ok(())
}
#[test]
fn test_graph_node_debug_names() {
    // Named nodes report their name; anonymous nodes fall back to "[ID: n]".
    let mut g = Graph::new();
    let weights = Tensor::from([0.3230]).into_arc();
    let weights_id = g.add_constant(Some("weights"), weights.clone());
    let input_id = g.add_value(Some("input"), None, None);
    let relu_out_id = g.add_value(Some("relu_out"), None, None);
    let relu_op_id = g.add_op(
        Some("relu"),
        Arc::new(Relu {}),
        &[Some(input_id)],
        &[Some(relu_out_id)],
    );

    assert_eq!(g.node_name(weights_id), "weights");
    assert_eq!(g.node_name(input_id), "input");
    assert_eq!(g.node_name(relu_op_id), "relu");

    // Same node kinds, but created without names.
    let anon_weights_id = g.add_constant(None, weights);
    let anon_input_id = g.add_value(None, None, None);
    let anon_out_id = g.add_value(None, None, None);
    let anon_op_id = g.add_op(
        None,
        Arc::new(Relu {}),
        &[Some(input_id)],
        &[Some(anon_out_id)],
    );

    assert_eq!(
        g.node_name(anon_weights_id),
        format!("[ID: {}]", anon_weights_id)
    );
    assert_eq!(
        g.node_name(anon_input_id),
        format!("[ID: {}]", anon_input_id)
    );
    assert_eq!(g.node_name(anon_op_id), format!("[ID: {}]", anon_op_id));
}
#[test]
fn test_graph_node_shapes() {
    // Constants report their concrete shape, value nodes report the
    // (possibly symbolic) shape they were declared with, and operator nodes
    // have no shape.
    let mut g = Graph::new();
    let weights = Tensor::from_data(&[1, 1, 2], vec![0.3230, 0.5]).into_arc();
    let weights_id = g.add_constant(Some("weights"), weights.clone());
    let input_id = g.add_value(
        Some("input"),
        Some(
            [
                // Leading dim is symbolic ("batch"), the rest are fixed.
                Dimension::Symbolic("batch".to_string()),
                Dimension::Fixed(3),
                Dimension::Fixed(5),
                Dimension::Fixed(5),
            ]
            .to_vec(),
        ),
        None,
    );
    let (relu_op_id, _) = g.add_simple_op("relu", Relu {}, &[input_id]);

    assert_eq!(
        g.get_node(weights_id).and_then(|n| n.shape()),
        Some([1, 1, 2].map(Dimension::Fixed).as_slice().into())
    );
    assert_eq!(
        g.get_node(input_id).and_then(|n| n.shape()),
        Some(
            [
                Dimension::Symbolic("batch".to_string()),
                Dimension::Fixed(3),
                Dimension::Fixed(5),
                Dimension::Fixed(5),
            ]
            .as_slice()
            .into()
        )
    );
    // Operator nodes are not values and have no shape.
    assert_eq!(g.get_node(relu_op_id).and_then(|n| n.shape()), None);
}
#[test]
fn test_graph_value_dtype() {
    // A value node reports back exactly the dtype it was created with.
    let mut g = Graph::new();
    let dtypes = [
        DataType::Float,
        DataType::Int32,
        DataType::UInt8,
        DataType::Int8,
    ];
    for dtype in dtypes {
        let input_id = g.add_value(None, None, Some(ValueType::Tensor(dtype)));
        let reported = g.get_node(input_id).and_then(|n| n.dtype());
        assert_eq!(reported, Some(ValueType::Tensor(dtype)));
    }
}
/// Test operator that adds 1.0 to every element of its single f32 input.
#[derive(Debug)]
struct AddOne {}
impl Operator for AddOne {
    fn name(&self) -> &str {
        "AddOne"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(1)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        None
    }

    /// Elementwise `x + 1` into a freshly-allocated tensor of the same shape.
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let input: TensorView<f32> = ctx.inputs().require_as(0)?;
        let incremented: Vec<f32> = input.iter().map(|&x| x + 1.0).collect();
        Tensor::<f32>::from_data(input.shape().into(), incremented).into_op_result()
    }
}
#[test]
fn test_graph_planning_order() -> Result<(), Box<dyn Error>> {
    // The Concat inputs are ordered differently in op_c vs op_d; the results
    // show the executor respects operator input order, not insertion order.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);

    let (_, op_a_out) = g.add_simple_op("op_a", AddOne {}, &[input_id]);
    let (_, op_b_out) = g.add_simple_op("op_b", AddOne {}, &[op_a_out]);
    // op_c and op_d concatenate the same values in opposite orders.
    let (_, op_c_out) = g.add_simple_op("op_c", Concat { axis: 0 }, &[op_a_out, op_b_out]);
    let (_, op_d_out) = g.add_simple_op("op_d", Concat { axis: 0 }, &[op_b_out, op_a_out]);

    let input = Tensor::from([1.]);

    let results = g
        .run(
            vec![(input_id, input.view().into())],
            &[op_c_out],
            None,
            None,
        )
        .unwrap();
    // op_a = input+1 = 2, op_b = op_a+1 = 3; op_c = [op_a, op_b].
    let expected = Tensor::from([2., 3.]);
    expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?;

    let results = g
        .run(vec![(input_id, input.into())], &[op_d_out], None, None)
        .unwrap();
    // op_d = [op_b, op_a], i.e. the reverse order.
    let expected = Tensor::from([3., 2.]);
    expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?;
    Ok(())
}
#[test]
fn test_runs_non_in_place_ops_first() -> Result<(), Box<dyn Error>> {
    // The planner should schedule ops that cannot run in place (Shape) before
    // in-place-capable ops (Add) that share an input, regardless of the order
    // the outputs were requested in.
    let mut g = Graph::new();

    let input_a_id = g.add_value(Some("input_a"), None, None);
    let input_b_id = g.add_value(Some("input_b"), None, None);

    let (add_op, add_out) = g.add_simple_op("add", Add {}, &[input_a_id, input_b_id]);
    let (shape_op, shape_out) = g.add_simple_op("shape", Shape::default(), &[input_a_id]);

    let plan = g.execution_plan(
        &[input_a_id, input_b_id],
        &[add_out, shape_out],
        PlanOptions::default(),
    )?;
    assert_eq!(plan, &[shape_op, add_op]);

    // Same plan even when outputs are requested in the opposite order.
    let plan = g.execution_plan(
        &[input_a_id, input_b_id],
        &[shape_out, add_out],
        PlanOptions::default(),
    )?;
    assert_eq!(plan, &[shape_op, add_op]);

    Ok(())
}
#[test]
fn test_graph_intermediate_output() {
    // Both an intermediate value (op_a_out) and the final value (op_b_out)
    // can be requested as outputs of the same run.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let (_, op_a_out) = g.add_simple_op("op_a", AddOne {}, &[input_id]);
    let (_, op_b_out) = g.add_simple_op("op_b", AddOne {}, &[op_a_out]);

    let input = Tensor::from(0.);
    let results = g
        .run(
            vec![(input_id, input.into())],
            &[op_a_out, op_b_out],
            None,
            None,
        )
        .unwrap();
    // Results come back in the same order the outputs were requested.
    assert_eq!(
        &results[0].as_tensor_view().unwrap(),
        &Tensor::from(1.).view()
    );
    assert_eq!(
        &results[1].as_tensor_view().unwrap(),
        &Tensor::from(2.).view()
    );
}
#[test]
fn test_graph_many_steps() -> Result<(), Box<dyn Error>> {
    // Chain 100 AddOne operators and check the accumulated result.
    let mut g = Graph::new();
    let input = Tensor::from([1., 2., 3., 4., 5.]);
    let input_id = g.add_value(Some("input"), None, None);

    // Thread each op's output into the next op's input, 100 times.
    let final_output = (0..100).fold(input_id, |prev, _| {
        let next = g.add_value(None, None, None);
        g.add_op(None, Arc::new(AddOne {}), &[Some(prev)], &[Some(next)]);
        next
    });

    let results = g
        .run(vec![(input_id, input.into())], &[final_output], None, None)
        .unwrap();
    let expected = Tensor::from([101., 102., 103., 104., 105.]);
    expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?;
    Ok(())
}
#[test]
fn test_noop_graph() -> Result<(), Box<dyn Error>> {
    // Requesting a graph input directly as an output returns it unchanged.
    let mut g = Graph::new();
    let input = Tensor::from([1., 2., 3., 4., 5.]);
    let input_id = g.add_value(Some("input"), None, None);

    let inputs = vec![(input_id, input.view().into())];
    let results = g.run(inputs, &[input_id], None, None).unwrap();

    expect_equal(&results[0].as_tensor_view().unwrap(), &input.view())?;
    Ok(())
}
#[test]
fn test_constant_graph() -> Result<(), Box<dyn Error>> {
    // A constant node can be requested as an output with no inputs at all.
    let mut g = Graph::new();
    let value = Tensor::from([1., 2., 3., 4., 5.]).into_arc();
    let const_id = g.add_constant(Some("weight"), value.clone());

    let results = g.run(vec![], &[const_id], None, None).unwrap();

    expect_equal(&results[0].as_tensor_view().unwrap(), &value.view())?;
    Ok(())
}
#[test]
fn test_typed_constant() {
    // `as_scalar` / `as_vector` on constant nodes are typed accessors: they
    // return `Some` only when both the element type and rank match.
    let mut g = Graph::new();
    let scalar_id = g.add_constant(None, Tensor::from(42.).into_arc());
    let vec_id = g.add_constant(None, Tensor::from([1, 2, 3]).into_arc());
    let scalar_node = match g.get_node(scalar_id) {
        Some(Node::Constant(c)) => Some(c),
        _ => None,
    }
    .unwrap();
    let vec_node = match g.get_node(vec_id) {
        Some(Node::Constant(c)) => Some(c),
        _ => None,
    }
    .unwrap();

    // f32 scalar: matches as f32, not as i32.
    assert_eq!(scalar_node.as_scalar(), Some(42.0));
    assert_ne!(scalar_node.as_scalar(), Some(42));
    // i32 vector: not a scalar of any type, but a vector of i32.
    assert_eq!(vec_node.as_scalar(), None::<i32>);
    assert_eq!(vec_node.as_vector(), Some([1, 2, 3].as_slice()));
    assert_eq!(vec_node.as_scalar(), None::<f32>);
}
#[test]
fn test_total_params() {
    // `total_params` counts constant elements across all dtypes and includes
    // constants inside subgraphs: 100 floats + 100 ints + 100 subgraph floats.
    let mut g = Graph::new();
    g.add_constant(Some("floats"), Tensor::<f32>::zeros(&[10, 10]).into_arc());
    g.add_constant(Some("ints"), Tensor::<i32>::zeros(&[10, 10]).into_arc());

    let mut subgraph = Graph::new();
    subgraph.add_constant(
        Some("sg_floats"),
        Tensor::<f32>::zeros(&[10, 10]).into_arc(),
    );
    g.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[]);

    assert_eq!(g.total_params(), 300);
}
#[test]
fn test_no_outputs() {
    // Running with no requested outputs succeeds and yields nothing.
    let results = Graph::new().run(vec![], &[], None, None).unwrap();
    assert!(results.is_empty());
}
#[test]
fn test_duplicate_inputs() {
    // Supplying the same input node twice is rejected at planning time with a
    // message naming the duplicated input.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let input = Tensor::from([1.]);
    let err = g
        .run(
            vec![
                (input_id, input.view().into()),
                (input_id, input.view().into()),
            ],
            &[input_id],
            None,
            None,
        )
        .err()
        .unwrap();
    assert_eq!(err.kind(), RunErrorKind::PlanningError);
    assert_eq!(
        err.to_string(),
        "planning error: Inputs are not unique. Input \"input\" is duplicated."
    );
}
#[test]
fn test_duplicate_outputs() {
    // Requesting the same output node twice is rejected at planning time with
    // a message naming the duplicated output.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let (_, op_a_out) = g.add_simple_op("op_a", AddOne {}, &[input_id]);
    let input = Tensor::from([1.]);
    let err = g
        .run(
            vec![(input_id, input.into())],
            &[op_a_out, op_a_out],
            None,
            None,
        )
        .err()
        .unwrap();
    assert_eq!(err.kind(), RunErrorKind::PlanningError);
    assert_eq!(
        err.to_string(),
        "planning error: Outputs are not unique. Output \"op_a_out\" is duplicated."
    );
}
#[test]
fn test_no_source_for_output() {
    // A value with no producing operator and no supplied input cannot be an
    // output; planning reports which output lacks a source.
    let mut g = Graph::new();
    let output_id = g.add_value(Some("output"), None, None);
    let err = g.run(vec![], &[output_id], None, None).err().unwrap();

    assert_eq!(err.kind(), RunErrorKind::PlanningError);
    let expected = "planning error: Source node not found for output \"output\"";
    assert_eq!(err.to_string(), expected);
}
#[test]
fn test_invalid_input_id() {
    // Passing either an operator node ID or a nonexistent ID as a run input
    // produces a planning error naming the offending node.
    let mut g = Graph::new();
    let (op_id, op_out) = g.add_simple_op("op", AddOne {}, &[]);
    let input = Tensor::from([1.]);
    // ID that does not correspond to any node in the graph.
    let invalid_id = NodeId::from_u32(1234);
    for wrong_input_id in [op_id, invalid_id] {
        let err = g
            .run(
                [(wrong_input_id, input.view().into())].into(),
                &[op_out],
                None,
                None,
            )
            .err()
            .unwrap();
        let name = g.node_name(wrong_input_id);
        assert_eq!(err.kind(), RunErrorKind::PlanningError);
        assert_eq!(
            err.to_string(),
            format!(
                "planning error: Input 0 (\"{}\") is not a value node in the graph.",
                name
            )
        );
    }
}
#[test]
fn test_invalid_output_id() {
    // Mirror of `test_invalid_input_id` for the outputs list: operator node
    // IDs and unknown IDs are both rejected.
    let mut g = Graph::new();
    let input_id = g.add_value(None, None, None);
    let (op_id, _op_out) = g.add_simple_op("op", AddOne {}, &[input_id]);
    let input = Tensor::from([1.]);
    let invalid_id = NodeId::from_u32(1234);
    for wrong_output_id in [op_id, invalid_id] {
        let err = g
            .run(
                [(input_id, input.view().into())].into(),
                &[wrong_output_id],
                None,
                None,
            )
            .err()
            .unwrap();
        let name = g.node_name(wrong_output_id);
        assert_eq!(err.kind(), RunErrorKind::PlanningError);
        assert_eq!(
            err.to_string(),
            format!(
                "planning error: Output 0 (\"{}\") is not a value node in the graph.",
                name
            )
        );
    }
}
#[test]
fn test_cycle() {
    // An operator whose output feeds its own input forms a cycle; planning
    // must detect this and report it instead of looping forever.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let output_id = g.add_value(Some("output"), None, None);
    let _op_id = g.add_op(
        Some("identity_0"),
        Arc::new(Identity {}),
        // Input and output are the same node, creating the cycle.
        &[Some(output_id)],
        &[Some(output_id)],
    );
    let input = Tensor::from([1.]);
    let err = g
        .run(
            [(input_id, input.view().into())].into(),
            &[output_id],
            None,
            None,
        )
        .err()
        .unwrap();
    assert_eq!(err.kind(), RunErrorKind::PlanningError);
    // Fix: the expected message had no interpolations, so the `format!` call
    // was redundant (clippy: useless_format). Compare against the literal.
    assert_eq!(
        err.to_string(),
        "planning error: Encountered cycle visiting dependency \"output\" of operator \"identity_0\""
    );
}
#[test]
fn test_call_op_with_missing_input() {
    // An operator wired with an empty (`None`) input slot fails at run time
    // with an operator error attributed to that op.
    let mut g = Graph::new();
    let output = g.add_value(None, None, None);
    let shape_op = Arc::new(Shape::default());
    g.add_op(Some("shape"), shape_op, &[None], &[Some(output)]);

    let err = g.run(vec![], &[output], None, None).err().unwrap();

    assert_eq!(err.kind(), RunErrorKind::OperatorError);
    assert_eq!(err.node_path(), [Some("shape")]);
}
#[test]
fn test_err_if_missing_operator_input() {
    // An op wired to a node ID that does not exist in the graph is a
    // planning error; the missing node is reported by its fallback name.
    let mut g = Graph::new();
    let (_, output) = g.add_simple_op("op", Relu {}, &[NodeId::from_u32(42)]);
    let err = g.run(vec![], &[output], None, None).err().unwrap();
    assert_eq!(err.kind(), RunErrorKind::PlanningError);
    assert_eq!(
        err.to_string(),
        "planning error: Missing input \"[ID: 42]\" for op \"op\""
    );
}
/// Operator whose in-place path adds 1.0 per element but whose copying path
/// returns the input unchanged, so tests can detect which path ran.
#[derive(Debug)]
struct AddOneInPlace {}
/// The two execution paths deliberately differ: `run` copies the input
/// unchanged while `run_in_place` adds one. Tests inspect the output values
/// to determine which path the graph executor chose.
impl Operator for AddOneInPlace {
    fn name(&self) -> &str {
        "AddOneInPlace"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(1)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        None
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        // Intentionally does NOT add one — a copied run leaves values
        // untouched so the in-place path is observable.
        let input: TensorView<f32> = ctx.inputs().require_as(0)?;
        input.to_tensor().into_op_result()
    }

    fn run_in_place(&self, input: Value, _ctx: &OpRunContext) -> Result<Value, OpError> {
        let mut output = input.into_tensor::<f32>().unwrap();
        for x in output.iter_mut() {
            // Fix: use compound assignment instead of `*x = *x + 1.0`
            // (clippy: assign_op_pattern).
            *x += 1.0;
        }
        Ok(output.into())
    }
}
#[test]
fn test_runs_op_in_place() {
    // AddOneInPlace only increments when run in place, so the output values
    // reveal which ops the executor ran in place vs by copy.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);

    let (_, op1_out) = g.add_simple_op("op1", AddOneInPlace {}, &[input_id]);
    let (_, op2_out) = g.add_simple_op("op2", AddOneInPlace {}, &[op1_out]);
    // op3 and op4 both consume op2_out, so at most one can take it in place.
    let (_, op3_out) = g.add_simple_op("op3", AddOneInPlace {}, &[op2_out]);
    let (_, op4_out) = g.add_simple_op("op4", AddOneInPlace {}, &[op2_out]);
    let input = Tensor::<f32>::zeros(&[1, 1]);

    // op1 gets a borrowed view, so it cannot run in place: value stays 0.
    let results = g
        .run(
            vec![(input_id, input.view().into())],
            &[op1_out],
            None,
            None,
        )
        .unwrap();
    assert_eq!(results[0].as_tensor_view::<f32>().unwrap()[[0, 0]], 0.0);

    // op2 owns op1's temporary output and runs in place: one increment.
    let results = g
        .run(
            vec![(input_id, input.view().into())],
            &[op2_out],
            None,
            None,
        )
        .unwrap();
    assert_eq!(results[0].as_tensor_view::<f32>().unwrap()[[0, 0]], 1.0);

    // op2_out is shared by op3 and op4: only the last consumer can take it
    // in place, so exactly one of them increments.
    let results = g
        .run(
            vec![(input_id, input.view().into())],
            &[op3_out, op4_out],
            None,
            None,
        )
        .unwrap();
    assert_eq!(results[0].as_tensor_view::<f32>().unwrap()[[0, 0]], 1.0);
    assert_eq!(results[1].as_tensor_view::<f32>().unwrap()[[0, 0]], 2.0);
}
#[test]
fn test_runs_commutative_op_in_place() {
    // For a commutative op (Add), the executor may run in place on either
    // operand; here op2's second operand (op1's temporary output) is owned,
    // so op2 runs in place even though that operand is not the first.
    use crate::ops::Add;

    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let bias_id = g.add_value(Some("bias"), None, None);
    let op1 = TrackUsage::new(Add {});
    let op1_metrics = op1.metrics();
    let op2 = TrackUsage::new(Add {});
    let op2_metrics = op2.metrics();

    let (_, op1_out) = g.add_simple_op("op1", op1, &[input_id, bias_id]);
    // Note the owned temporary (op1_out) is the second operand here.
    let (_, op2_out) = g.add_simple_op(
        "op2",
        op2,
        &[bias_id, op1_out],
    );
    let input = Tensor::<f32>::zeros(&[2, 2]);
    let bias = Tensor::from(1.5);
    let results = g
        .run(
            vec![(input_id, input.view().into()), (bias_id, bias.into())],
            &[op2_out],
            None,
            None,
        )
        .unwrap();

    // (0 + 1.5) + 1.5 = 3 for each of the four elements.
    assert_eq!(
        results[0]
            .as_tensor_view::<f32>()
            .unwrap()
            .iter()
            .copied()
            .collect::<Vec<_>>(),
        &[3., 3., 3., 3.]
    );

    // op1's inputs are both borrowed, so it ran the copying path; op2 could
    // reuse op1's owned output and ran in place.
    let op1_metrics = op1_metrics.lock().unwrap();
    assert_eq!(op1_metrics.run_count, 1);
    assert_eq!(op1_metrics.run_in_place_count, 0);

    let op2_metrics = op2_metrics.lock().unwrap();
    assert_eq!(op2_metrics.run_count, 0);
    assert_eq!(op2_metrics.run_in_place_count, 1);
}
/// Test operator producing two outputs (first and second half of the input),
/// with a shared counter recording how many times it was run.
#[derive(Debug)]
struct Split {
    // Shared so tests can observe the count after moving the op into a graph.
    run_count: Arc<Mutex<u32>>,
}
impl Split {
fn new() -> Split {
Split {
run_count: Arc::new(Mutex::new(0)),
}
}
}
impl Operator for Split {
    fn name(&self) -> &str {
        "Split"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(1)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        None
    }

    /// Count the invocation, then split the flattened input into a first
    /// half of `len / 2` elements and a second half with the remainder.
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        *self.run_count.lock().unwrap() += 1;

        let input: TensorView<f32> = ctx.inputs().require_as(0)?;
        let elements: Vec<f32> = input.iter().copied().collect();
        let (left, right) = elements.split_at(elements.len() / 2);
        Ok(smallvec![
            Tensor::from_vec(left.to_vec()).into(),
            Tensor::from_vec(right.to_vec()).into()
        ])
    }
}
#[test]
fn test_multiple_outputs() {
    // An op with two output slots runs once and fills both output values.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let left_split_out = g.add_value(Some("left_split"), None, None);
    let right_split_out = g.add_value(Some("right_split"), None, None);

    let split_op = Arc::new(Split::new());
    let run_count = split_op.run_count.clone();

    g.add_op(
        Some("split"),
        split_op,
        &[Some(input_id)],
        &[left_split_out, right_split_out].map(Some),
    );

    let input = Tensor::from([1.0, 2.0, 3.0, 4.0, 5.0]);
    let mut results = g
        .run(
            vec![(input_id, input.into())],
            &[left_split_out, right_split_out],
            None,
            None,
        )
        .unwrap();

    // The op must not be re-run per requested output.
    assert_eq!(*run_count.lock().unwrap(), 1);
    assert_eq!(results.len(), 2);
    let left_split = results.remove(0).into_tensor::<f32>().unwrap();
    // After the first `remove`, the second output is now at index 0.
    let right_split = results.remove(0).into_tensor::<f32>().unwrap();
    assert_eq!(left_split.to_vec(), &[1.0, 2.0]);
    assert_eq!(right_split.to_vec(), &[3.0, 4.0, 5.0]);
}
#[test]
fn test_ignore_unused_outputs() {
    // Only one of Split's two outputs is wired into the graph; the surplus
    // output is discarded without error.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let left_split_out = g.add_value(Some("left_split"), None, None);

    let split_op = Split::new();
    g.add_op(
        Some("split"),
        Arc::new(split_op),
        &[Some(input_id)],
        &[Some(left_split_out)],
    );
    let input = Tensor::from([1.0, 2.0, 3.0, 4.0, 5.0]);
    let mut results = g
        .run(
            vec![(input_id, input.into())],
            &[left_split_out],
            None,
            None,
        )
        .unwrap();
    assert_eq!(results.len(), 1);
    let left_split = results.remove(0).into_tensor::<f32>().unwrap();
    assert_eq!(left_split.to_vec(), &[1.0, 2.0]);
}
#[test]
fn test_not_enough_outputs() {
    // Split produces two outputs but is wired with three output slots; the
    // mismatch is reported as an operator error attributed to the op.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let out_1 = g.add_value(Some("out-1"), None, None);
    let out_2 = g.add_value(Some("out-2"), None, None);
    let out_3 = g.add_value(Some("out-3"), None, None);

    let split_op = Split::new();
    g.add_op(
        Some("split"),
        Arc::new(split_op),
        &[Some(input_id)],
        &[out_1, out_2, out_3].map(Some),
    );
    let input = Tensor::from([1.0, 2.0, 3.0, 4.0, 5.0]);
    let err = g
        .run(
            vec![(input_id, input.into())],
            &[out_1, out_2, out_3],
            None,
            None,
        )
        .err()
        .unwrap();

    assert_eq!(err.kind(), RunErrorKind::OperatorError);
    assert_eq!(err.node_path(), [Some("split")]);
    assert!(
        err.to_string()
            .contains("operator returned 2 outputs but expected 3")
    );
}
#[test]
fn test_partial_run() -> Result<(), Box<dyn Error>> {
    // `partial_run` evaluates as much of the graph as the supplied inputs
    // allow and returns the frontier of computed values.
    //
    // Graph: Add_2(Add_0(c0, i0), Add_1(c1, i1)) with constants c0=3, c1=4.
    let mut g = Graph::new();

    let const_0 = g.add_constant(Some("c0"), Tensor::from(3.).into_arc());
    let val_0 = g.add_value(Some("i0"), None, None);
    let const_1 = g.add_constant(Some("c1"), Tensor::from(4.).into_arc());
    let val_1 = g.add_value(Some("i1"), None, None);

    let (_, op_0_out) = g.add_simple_op("Add_0", Add {}, &[const_0, val_0]);
    let (_, op_1_out) = g.add_simple_op("Add_1", Add {}, &[const_1, val_1]);
    let (_, op_2_out) = g.add_simple_op("Add_2", Add {}, &[op_0_out, op_1_out]);

    // No inputs: nothing can be computed.
    let partial_outs = g.partial_run(vec![], &[op_2_out], None)?;
    assert_eq!(partial_outs.len(), 0);

    // With i0: Add_0 can run (3 + 2 = 5) but Add_1/Add_2 cannot.
    let input = Tensor::from(2.);
    let partial_outs = g.partial_run(vec![(val_0, input.view().into())], &[op_2_out], None)?;
    assert_eq!(partial_outs.len(), 1);
    assert_eq!(partial_outs[0].0, op_0_out);
    assert_eq!(partial_outs[0].1, Value::FloatTensor(Tensor::from(5.)));

    // With i1 instead: only Add_1 can run (4 + 2 = 6).
    let input = Tensor::from(2.);
    let partial_outs = g.partial_run(vec![(val_1, input.view().into())], &[op_2_out], None)?;
    assert_eq!(partial_outs.len(), 1);
    assert_eq!(partial_outs[0].0, op_1_out);
    assert_eq!(partial_outs[0].1, Value::FloatTensor(Tensor::from(6.)));

    // With both inputs, the full graph evaluates: 5 + 6 = 11.
    let partial_outs = g.partial_run(
        vec![(val_1, input.view().into()), (val_0, input.view().into())],
        &[op_2_out],
        None,
    )?;
    assert_eq!(partial_outs.len(), 1);
    assert_eq!(partial_outs[0].0, op_2_out);
    assert_eq!(partial_outs[0].1, Value::FloatTensor(Tensor::from(11.)));

    Ok(())
}
/// Non-deterministic test operator: each run returns the previous run count,
/// so repeated runs produce different values.
#[derive(Debug)]
struct Counter {
    count: AtomicI32,
}
impl Operator for Counter {
    fn name(&self) -> &str {
        "Counter"
    }

    // Takes no inputs at all.
    fn max_inputs(&self) -> Option<usize> {
        Some(0)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        None
    }

    /// Each run yields a different value, so results must not be cached.
    fn is_deterministic(&self) -> bool {
        false
    }

    fn run(&self, _ctx: &OpRunContext) -> Result<OutputList, OpError> {
        // Return the pre-increment count as a scalar tensor.
        let previous = self.count.fetch_add(1, Ordering::SeqCst);
        let outputs: OutputList = [Tensor::from(previous).into()].into();
        Ok(outputs)
    }
}
#[test]
fn test_partial_run_non_deterministic_ops() -> Result<(), Box<dyn Error>> {
    // `partial_run` must not pre-compute ops downstream of a
    // non-deterministic operator: only the const-only Add_0 is evaluated.
    let mut g = Graph::new();
    let const_val = g.add_constant(Some("c0"), Tensor::from(3).into_arc());
    let (_, add_op_0_out) = g.add_simple_op("Add_0", Add {}, &[const_val, const_val]);
    let (_, count_op_out) = g.add_simple_op(
        "Count",
        Counter {
            count: AtomicI32::new(0),
        },
        &[],
    );
    let (_, add_op_1_out) = g.add_simple_op("Add_1", Add {}, &[add_op_0_out, count_op_out]);
    let partial_outs = g.partial_run(vec![], &[add_op_1_out], None)?;
    // Add_1 depends on the Counter, so evaluation stops at Add_0.
    assert_eq!(partial_outs.len(), 1);
    assert_eq!(partial_outs[0].0, add_op_0_out);
    Ok(())
}
#[test]
fn test_cached_plan_matches() {
    // A cached plan matches any permutation of the same input/output ID
    // sets, but not different IDs.
    let input_ids = &[3, 1, 2].map(NodeId::from_u32);
    let output_ids = &[6, 4, 5].map(NodeId::from_u32);
    let op_ids = &[10, 11, 12].map(NodeId::from_u32);

    let plan = CachedPlan::new(input_ids, output_ids, op_ids.to_vec());

    // Exact same order.
    assert!(plan.matches(input_ids, output_ids));
    // Reordered IDs still match (order-insensitive comparison).
    assert!(plan.matches(
        &[1, 2, 3].map(NodeId::from_u32),
        &[4, 5, 6].map(NodeId::from_u32)
    ));
    assert!(plan.matches(
        &[3, 2, 1].map(NodeId::from_u32),
        &[6, 5, 4].map(NodeId::from_u32)
    ));
    // Different IDs on either side do not match.
    assert!(!plan.matches(&[20, 21, 22].map(NodeId::from_u32), output_ids));
    assert!(!plan.matches(input_ids, &[20, 21, 22].map(NodeId::from_u32)));
}
/// Test operator that wraps a `Graph` and executes it via the
/// `SubgraphOperator` protocol.
struct Subgraph {
    graph: Graph,
}
impl std::fmt::Debug for Subgraph {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        // The inner graph is not worth printing; emit a placeholder.
        f.write_str("Subgraph { ... }")
    }
}
impl Operator for Subgraph {
    fn name(&self) -> &str {
        "Subgraph"
    }

    fn max_inputs(&self) -> Option<usize> {
        None
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        None
    }

    // Subgraph operators must be executed via `run_subgraph`; the plain
    // `run` entry point is an error by design.
    fn run(&self, _ctx: &OpRunContext) -> Result<OutputList, OpError> {
        Err(OpError::InvalidValue(
            "operator must be run with `run_subgraph`",
        ))
    }

    // Advertise this operator as a subgraph operator to the executor.
    fn as_subgraph_op(&self) -> Option<&dyn SubgraphOperator> {
        Some(self as &dyn SubgraphOperator)
    }
}
impl SubgraphOperator for Subgraph {
    fn subgraphs(&self) -> SmallVec<[&Graph; 2]> {
        SmallVec::from_slice(&[&self.graph])
    }

    /// Run the wrapped graph: the operator's positional inputs are paired
    /// with the subgraph's declared input IDs in order, and the subgraph's
    /// declared outputs become the operator's outputs.
    fn run_subgraph<'a>(
        &'a self,
        ctx: &OpRunContext,
        captures: CaptureEnv,
        weight_caches: Option<&[WeightCache]>,
        profiler: Option<&mut Profiler<'a>>,
        options: Option<RunOptions>,
    ) -> Result<OutputList, RunError> {
        // Zip subgraph input IDs with the operator's (present) inputs;
        // `flatten` drops empty/optional input slots.
        let inputs = self
            .graph
            .input_ids()
            .iter()
            .copied()
            .zip(ctx.inputs().iter().flatten().map(|i| i.into()))
            .collect();
        self.graph
            .run_subgraph(
                inputs,
                self.graph.output_ids(),
                captures,
                ctx.pool(),
                // This operator has exactly one subgraph, so it owns the
                // first weight cache if any were provided.
                weight_caches.map(|wcs| &wcs[0]),
                profiler,
                options,
            )
            .map(|xs| xs.into_iter().collect())
    }
}
#[test]
fn test_subgraph() {
    // An `If` operator selects between two subgraph branches based on its
    // condition input; both branches capture "input" from the parent graph.
    let mut g = Graph::new();
    let input = g.add_value(Some("input"), None, None);

    // then-branch: input * 2.
    let mut then_branch = Graph::new();
    let tb_input = then_branch.add_value(Some("input"), None, None);
    let two = then_branch.add_constant(None, Tensor::from(2.).into_arc());
    let (_, tb_output) = then_branch.add_simple_op("Mul", Mul {}, &[tb_input, two]);
    then_branch.set_captures(&[tb_input]);
    then_branch.set_output_ids(&[tb_output]);

    // else-branch: input * 3.
    let mut else_branch = Graph::new();
    let eb_input = else_branch.add_value(Some("input"), None, None);
    let three = else_branch.add_constant(None, Tensor::from(3.).into_arc());
    let (_, eb_output) = else_branch.add_simple_op("Mul", Mul {}, &[eb_input, three]);
    else_branch.set_captures(&[eb_input]);
    else_branch.set_output_ids(&[eb_output]);

    let cond = g.add_value(Some("cond"), None, None);
    let branch = If {
        then_branch,
        else_branch,
    };
    let (_, if_out) = g.add_simple_op("If", branch, &[cond]);

    // cond != 0 takes the then-branch: 2 * 2 = 4.
    let mut result = g
        .run(
            vec![
                (input, Tensor::from(2.).into()),
                (cond, Tensor::from(1).into()),
            ],
            &[if_out],
            None,
            None,
        )
        .unwrap();
    let result: Tensor<f32> = result.remove(0).try_into().unwrap();
    assert_eq!(result, Tensor::from(4.));

    // cond == 0 takes the else-branch: 2 * 3 = 6.
    let mut result = g
        .run(
            vec![
                (input, Tensor::from(2.).into()),
                (cond, Tensor::from(0).into()),
            ],
            &[if_out],
            None,
            None,
        )
        .unwrap();
    let result: Tensor<f32> = result.remove(0).try_into().unwrap();
    assert_eq!(result, Tensor::from(6.));
}
#[test]
fn test_nested_subgraph() {
    // A capture in a doubly-nested subgraph resolves against the top-level
    // graph's "input" value, two levels up.
    let mut g = Graph::new();
    let input = g.add_value(Some("input"), None, None);

    let mut subgraph = Graph::new();

    // The innermost graph just passes its captured input straight through.
    let mut nested_subgraph = Graph::new();
    let ns_input = nested_subgraph.add_value(Some("input"), None, None);
    nested_subgraph.set_captures(&[ns_input]);
    nested_subgraph.set_output_ids(&[ns_input]);

    let (_, ns_out) = subgraph.add_simple_op(
        "Subgraph",
        Subgraph {
            graph: nested_subgraph,
        },
        &[],
    );
    subgraph.set_output_ids(&[ns_out]);

    let (_, sg_out) = g.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[]);
    let mut result = g
        .run(
            vec![(input, Tensor::from(2.).into())],
            &[sg_out],
            None,
            None,
        )
        .unwrap();
    let result: Tensor<f32> = result.remove(0).try_into().unwrap();
    // The input value survives the round trip through both subgraphs.
    assert_eq!(result, Tensor::from(2.));
}
#[test]
fn test_captures_not_available_when_subgraph_is_run_directly() {
    // Captured values only exist when a subgraph runs inside a parent.
    // Running the subgraph standalone leaves the capture unresolved.
    let mut subgraph = Graph::new();
    let sg_input = subgraph.add_value(Some("input"), None, None);
    subgraph.set_captures(&[sg_input]);
    let (_, sg_add) = subgraph.add_simple_op("Id", Identity {}, &[sg_input]);
    subgraph.set_output_ids(&[sg_add]);

    // `partial_run` tolerates the missing capture and computes nothing.
    let result = subgraph.partial_run(Vec::new(), &[sg_add], None).unwrap();
    assert_eq!(result.len(), 0);

    // A full `run` fails because the capture has no value.
    let err = subgraph
        .run(Vec::new(), &[sg_add], None, None)
        .err()
        .unwrap();
    assert_eq!(err.kind(), RunErrorKind::PlanningError);
    assert_eq!(
        err.to_string(),
        "planning error: Missing input \"input\" for op \"Id\""
    );
}
#[test]
fn test_partial_run_considers_subgraph_captures() {
    // A subgraph op whose subgraph captures a value by name depends on that
    // value; `partial_run` only evaluates the op once the value is supplied.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let mut subgraph = Graph::new();
    // Captured by name: resolves to the parent's "input" value.
    let sg_input = subgraph.add_value(Some("input"), None, None);
    subgraph.set_captures(&[sg_input]);
    let (_, sg_add) = subgraph.add_simple_op("Id", Identity {}, &[sg_input]);
    subgraph.set_output_ids(&[sg_add]);
    let (_, out) = g.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[]);

    // Without the captured value, nothing can run.
    let result = g.partial_run(Vec::new(), &[out], None).unwrap();
    assert_eq!(result.len(), 0);

    // Supplying "input" unblocks the subgraph op.
    let result = g
        .partial_run([(input_id, Tensor::from(4.).into())].into(), &[out], None)
        .unwrap();
    assert_eq!(result.len(), 1);
}
#[test]
fn test_plan_considers_capture_dependencies() {
    // The subgraph captures "Add_out" by name, so the planner must schedule
    // the parent's Add op even though no direct edge connects them.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    // `add_simple_op` names the output "<op>_out", i.e. "Add_out" here.
    let (_, _) = g.add_simple_op("Add", Add {}, &[input_id, input_id]);
    let mut subgraph = Graph::new();
    let sg_input = subgraph.add_value(Some("Add_out"), None, None);
    subgraph.set_captures(&[sg_input]);
    let (_, sg_out) = subgraph.add_simple_op("Id", Identity {}, &[sg_input]);
    subgraph.set_output_ids(&[sg_out]);
    let (_, out) = g.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[]);
    let input = Tensor::from(3.);
    let mut result = g
        .run(vec![(input_id, input.into())], &[out], None, None)
        .unwrap();
    let result: Tensor<f32> = result.remove(0).try_into().unwrap();
    // 3 + 3 from the parent's Add, passed through the subgraph's Identity.
    assert_eq!(result.item(), Some(&6.));
}
#[test]
fn test_plan_considers_transitive_capture_dependencies() {
    // Like `test_plan_considers_capture_dependencies`, but the capture of
    // "Add_out" sits two subgraph levels deep — the planner must still
    // schedule the parent's Add op.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let (_, _) = g.add_simple_op("Add", Add {}, &[input_id, input_id]);
    let mut subgraph = Graph::new();
    let mut nested_subgraph = Graph::new();
    // Capture resolves by name against the top-level graph.
    let ns_input = nested_subgraph.add_value(Some("Add_out"), None, None);
    nested_subgraph.set_captures(&[ns_input]);
    let (_, ns_out) = nested_subgraph.add_simple_op("Id", Identity {}, &[ns_input]);
    nested_subgraph.set_output_ids(&[ns_out]);
    let (_, sg_out) = subgraph.add_simple_op(
        "Subgraph",
        Subgraph {
            graph: nested_subgraph,
        },
        &[],
    );
    subgraph.set_output_ids(&[sg_out]);
    let (_, out) = g.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[]);
    let input = Tensor::from(3.);
    let mut result = g
        .run(vec![(input_id, input.into())], &[out], None, None)
        .unwrap();
    let result: Tensor<f32> = result.remove(0).try_into().unwrap();
    assert_eq!(result.item(), Some(&6.));
}
#[test]
fn test_keeps_temp_value_needed_as_subgraph_capture() {
    // "Id_out" is consumed by the parent's Mul, but the subgraph also
    // captures it by name — the executor must not free the temporary after
    // Mul runs.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let (_, id_out) = g.add_simple_op("Id", Identity {}, &[input_id]);
    let (_, mul_out) = g.add_simple_op("Mul", Mul {}, &[id_out, id_out]);
    let mut subgraph = Graph::new();
    let sg_input = subgraph.add_value(Some("Id_out"), None, None);
    subgraph.set_captures(&[sg_input]);
    let (_, sg_out) = subgraph.add_simple_op("Id", Identity {}, &[sg_input]);
    subgraph.set_output_ids(&[sg_out]);
    // The subgraph op takes mul_out as a regular input, so it runs after Mul.
    let (_, out) = g.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[mul_out]);
    let input = Tensor::from(3.);
    let mut result = g
        .run(vec![(input_id, input.into())], &[out], None, None)
        .unwrap();
    let result: Tensor<f32> = result.remove(0).try_into().unwrap();
    // The subgraph outputs the captured Id_out value (3), not Mul's result.
    assert_eq!(result.item(), Some(&3.));
}
#[test]
fn test_captures_by_value_if_possible() {
    // When the parent graph owns the captured value (input passed by value),
    // the subgraph op receives it by value and can run in place; when the
    // input is only a borrowed view, the copying path runs instead.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let mut subgraph = Graph::new();
    let sg_input = subgraph.add_value(Some("input"), None, None);
    subgraph.set_captures(&[sg_input]);
    let id_op = TrackUsage::new(Identity {});
    let id_op_metrics = id_op.metrics();
    let (_, id_out) = subgraph.add_simple_op("Id", id_op, &[sg_input]);
    subgraph.set_output_ids(&[id_out]);
    let (_, out) = g.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[]);

    // Pass the input by value: the capture is owned, Identity runs in place.
    let input = Tensor::from(42.);
    let mut result = g
        .run(vec![(input_id, input.into())], &[out], None, None)
        .unwrap();
    let result: Tensor<f32> = result.remove(0).try_into().unwrap();
    assert_eq!(result.item(), Some(&42.));
    {
        let id_op_metrics = id_op_metrics.lock().unwrap();
        assert_eq!(id_op_metrics.run_count, 0);
        assert_eq!(id_op_metrics.run_in_place_count, 1);
    }

    // Pass a borrowed view: the capture cannot be owned, so the copying
    // `run` path is used (counters accumulate across both runs).
    let input = Tensor::from(42.);
    let mut result = g
        .run(vec![(input_id, input.view().into())], &[out], None, None)
        .unwrap();
    let result: Tensor<f32> = result.remove(0).try_into().unwrap();
    assert_eq!(result.item(), Some(&42.));
    {
        let id_op_metrics = id_op_metrics.lock().unwrap();
        assert_eq!(id_op_metrics.run_count, 1);
        assert_eq!(id_op_metrics.run_in_place_count, 1);
    }
}
/// Wrapper around [`MatMul`] which asserts, when run, that its weight input
/// (index 1) was supplied in pre-packed form.
#[derive(Debug)]
struct MatMulExpectPacked {
    inner: MatMul,
}

impl MatMulExpectPacked {
    fn new() -> Self {
        Self { inner: MatMul {} }
    }
}
impl Operator for MatMulExpectPacked {
fn name(&self) -> &str {
"MatMulExpectPacked"
}
fn max_inputs(&self) -> Option<usize> {
self.inner.max_inputs()
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
None
}
fn prepack_inputs(&self) -> SmallVec<[usize; 1]> {
[1].into()
}
fn prepack(&self, index: usize, input: ValueView) -> Option<PrepackedInput> {
self.inner.prepack(index, input)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let prepacked = ctx.inputs().get_prepacked(1);
assert!(prepacked.is_some());
self.inner.run(ctx)
}
}
#[test]
fn test_prepack_weights() {
    // Main graph: a MatMul whose weight operand is a constant, so it is a
    // candidate for pre-packing.
    let mut graph = Graph::new();
    let input = graph.add_value(Some("input"), None, None);
    let weights = graph.add_constant(None, Tensor::<f32>::zeros(&[10, 7]).into_arc());
    let (_, matmul_out) =
        graph.add_simple_op("MatMul", MatMulExpectPacked::new(), &[input, weights]);

    // Subgraph with its own constant weight, which should end up in a
    // separate per-subgraph cache.
    let mut subgraph = Graph::new();
    let sg_input = subgraph.add_value(Some("sg-input"), None, None);
    let sg_weights = subgraph.add_constant(None, Tensor::<f32>::zeros(&[7, 5]).into_arc());
    let (_, sg_matmul_out) = subgraph.add_simple_op(
        "sg-MatMul",
        MatMulExpectPacked::new(),
        &[sg_input, sg_weights],
    );
    subgraph.set_input_ids(&[sg_input]);
    subgraph.set_output_ids(&[sg_matmul_out]);
    let (subgraph_op, subgraph_out) =
        graph.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[matmul_out]);
    graph.set_input_ids(&[input]);
    graph.set_output_ids(&[subgraph_out]);

    // Pre-pack everything and verify both the top-level and subgraph
    // caches were populated.
    let mut cache = WeightCache::new();
    graph.prepack_weights(&mut cache);
    assert_eq!(cache.len(), 2);
    assert!(cache.get(weights).is_some());
    let sg_cache = cache
        .get_subgraph_caches(subgraph_op)
        .map(|caches| &caches[0])
        .unwrap();
    assert!(sg_cache.get(sg_weights).is_some());

    // Run with the cache. `MatMulExpectPacked::run` asserts that the
    // pre-packed weights actually reach each operator.
    graph
        .run(
            [(input, Tensor::<f32>::zeros(&[3, 10]).into())].into(),
            &[subgraph_out],
            Some(&cache),
            None,
        )
        .unwrap();
}
#[test]
fn test_run_context_num_outputs() {
    // The run context should report how many outputs the executor expects
    // from the operator.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let check_num_outputs = RunFn::new(|ctx| {
        assert_eq!(ctx.num_outputs(), Some(1));
        Ok(Tensor::from_scalar(0.))
    });
    let (_, op_out) = g.add_simple_op("test_op", check_num_outputs, &[input_id]);
    g.run(
        vec![(input_id, Tensor::from([1, 2, 3]).into())],
        &[op_out],
        None,
        None,
    )
    .unwrap();
}
#[test]
fn test_run_context_name() {
    // The run context should expose the name of the operator node being
    // executed.
    let mut g = Graph::new();
    let input_id = g.add_value(Some("input"), None, None);
    let check_name = RunFn::new(|ctx| {
        assert_eq!(ctx.name(), Some("test_op"));
        Ok(Tensor::from_scalar(0.))
    });
    let (_, op_out) = g.add_simple_op("test_op", check_name, &[input_id]);
    g.run(
        vec![(input_id, Tensor::from([1, 2, 3]).into())],
        &[op_out],
        None,
        None,
    )
    .unwrap();
}
#[test]
fn test_remove_nodes() {
    let mut graph = Graph::new();

    // Removing a value node should drop it from the node table, the
    // name lookup, and the graph's input/output ID lists.
    let value_id = graph.add_value(Some("value"), None, None);
    graph.set_input_ids(&[value_id]);
    graph.set_output_ids(&[value_id]);
    assert!(graph.get_node(value_id).is_some());
    assert!(graph.get_node_id("value").is_some());

    graph.remove_nodes(&[value_id]);

    assert!(graph.get_node(value_id).is_none());
    assert!(graph.get_node_id("value").is_none());
    assert!(graph.input_ids().is_empty());
    assert!(graph.output_ids().is_empty());

    // Removing an operator node should also unlink it as the source of
    // its output value.
    let value_id = graph.add_value(Some("value2"), None, None);
    let (op_id, out_id) = graph.add_simple_op("Mul", Mul {}, &[value_id, value_id]);
    assert!(graph.get_source_node(out_id).is_some());
    graph.remove_nodes(&[op_id]);
    assert!(graph.get_source_node(out_id).is_none());
}