use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::fmt;
use std::iter::zip;
use rten_tensor::prelude::*;
use rten_tensor::Tensor;
use crate::ops::{Input, InputList, OpError, Operator, Output};
use crate::timer::Timer;
use crate::timing::{InputShape, RunTiming, TimingRecord, TimingSort};
/// The size of a single dimension of a tensor value in the graph.
#[derive(Clone, Debug, PartialEq)]
pub enum Dimension {
    /// A dimension with a known, fixed size.
    Fixed(usize),
    /// A dimension identified by a symbolic name rather than a fixed size
    /// (e.g. a runtime-determined batch size).
    Symbolic(String),
}
/// A graph node that runs an operator over its input values to produce its
/// output values.
pub struct OperatorNode {
    // Optional debug name for the node.
    name: Option<String>,
    // IDs of the nodes providing this operator's inputs. `None` entries
    // represent missing optional inputs.
    inputs: Vec<Option<NodeId>>,
    // IDs of the nodes that receive this operator's outputs. `None` entries
    // represent unused outputs.
    outputs: Vec<Option<NodeId>>,
    operator: Box<dyn Operator + Sync>,
}

/// A graph node holding a dynamically supplied or operator-computed value.
pub struct ValueNode {
    // Optional debug name for the node.
    name: Option<String>,
    // Expected shape of the value, if declared when the node was added.
    shape: Option<Vec<Dimension>>,
}

/// A graph node holding a constant tensor (e.g. weights), with element type `T`.
pub struct ConstantNode<T> {
    // Optional debug name for the node.
    name: Option<String>,
    data: Tensor<T>,
}

/// Type-erased wrapper over the constant node element types the graph supports.
pub enum Constant {
    Float(ConstantNode<f32>),
    Int(ConstantNode<i32>),
}
impl Constant {
fn len(&self) -> usize {
match self {
Constant::Float(f) => f.data.len(),
Constant::Int(i) => i.data.len(),
}
}
}
impl From<ConstantNode<f32>> for Constant {
fn from(node: ConstantNode<f32>) -> Constant {
Constant::Float(node)
}
}
impl From<ConstantNode<i32>> for Constant {
fn from(node: ConstantNode<i32>) -> Constant {
Constant::Int(node)
}
}
/// A node in a [`Graph`]: either an operator, a constant tensor, or a
/// dynamic value.
pub enum Node {
    Operator(OperatorNode),
    Constant(Constant),
    Value(ValueNode),
}
impl Node {
    /// Return the debug name of this node, if it has one.
    pub fn name(&self) -> Option<&str> {
        let maybe_name = match self {
            Node::Operator(node) => &node.name,
            Node::Constant(constant) => match constant {
                Constant::Float(node) => &node.name,
                Constant::Int(node) => &node.name,
            },
            Node::Value(node) => &node.name,
        };
        // Idiomatic `Option<String>` -> `Option<&str>` conversion.
        maybe_name.as_deref()
    }

    /// Return the tensor shape associated with this node, if known.
    ///
    /// Constants report the actual shape of their data; value nodes report
    /// the shape declared when they were added (if any); operator nodes have
    /// no shape.
    pub fn shape(&self) -> Option<Vec<Dimension>> {
        let dims_from_fixed_shape =
            |shape: &[usize]| shape.iter().copied().map(Dimension::Fixed).collect();
        match self {
            Node::Operator(_) => None,
            Node::Constant(constant) => match constant {
                Constant::Float(node) => Some(dims_from_fixed_shape(node.data.shape())),
                Constant::Int(node) => Some(dims_from_fixed_shape(node.data.shape())),
            },
            Node::Value(node) => node.shape.clone(),
        }
    }
}
/// ID of a node in a [`Graph`]. This is an index into the graph's node list.
pub type NodeId = usize;

/// A dataflow graph of operator, constant and value nodes.
pub struct Graph {
    nodes: Vec<Node>,
}
/// Error raised when running a [`Graph`] fails.
#[derive(Eq, PartialEq, Debug)]
pub enum RunError {
    /// A node ID was out of range or otherwise invalid.
    InvalidNodeId,
    /// No node with the given name exists.
    InvalidNodeName(String),
    /// Execution-order planning failed (e.g. duplicate IDs, missing inputs).
    PlanningError(String),
    /// An operator returned an error; `name` is the operator node's debug name.
    OperatorError { name: String, error: OpError },
    /// An operator produced a different number of outputs than its node declares.
    OutputMismatch(&'static str),
}
impl fmt::Display for RunError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
RunError::InvalidNodeId => write!(f, "node ID is invalid"),
RunError::InvalidNodeName(ref name) => write!(f, "no node found with name {}", name),
RunError::PlanningError(ref err) => write!(f, "planning error {:?}", err),
RunError::OperatorError {
name,
error: ref err,
} => write!(f, "operator \"{}\" failed: {:?}", name, err),
RunError::OutputMismatch(err) => write!(f, "output mismatch {:?}", err),
}
}
}
/// Return true if all elements of `xs` are distinct according to `eq`.
///
/// `eq` is assumed to be an equivalence relation (as with `==`). Compares
/// each distinct pair of elements at most once and short-circuits on the
/// first duplicate, rather than counting matches for every element.
fn all_unique<T, F: Fn(&T, &T) -> bool>(xs: &[T], eq: F) -> bool {
    xs.iter()
        .enumerate()
        .all(|(i, x)| !xs[i + 1..].iter().any(|y| eq(x, y)))
}
/// Tracks how many outstanding uses each node's value has during execution.
struct NodeRefCount {
    rc: HashMap<NodeId, usize>,
}

impl NodeRefCount {
    /// Create an empty set of reference counts.
    fn new() -> NodeRefCount {
        NodeRefCount { rc: HashMap::new() }
    }

    /// Increment the reference count of node `id`.
    fn inc(&mut self, id: NodeId) {
        *self.rc.entry(id).or_insert(0) += 1;
    }

    /// Decrement the reference count of node `id`, returning the new count.
    ///
    /// The entry is removed once the count reaches zero. Decrementing an
    /// untracked node is a no-op that returns zero.
    fn dec(&mut self, id: NodeId) -> usize {
        match self.rc.get_mut(&id) {
            None => 0,
            Some(count) => {
                *count = count.saturating_sub(1);
                let remaining = *count;
                if remaining == 0 {
                    self.rc.remove(&id);
                }
                remaining
            }
        }
    }

    /// Return the current reference count of node `id` (zero if untracked).
    fn count(&self, id: NodeId) -> usize {
        self.rc.get(&id).copied().unwrap_or(0)
    }
}
impl Error for RunError {}
/// Options controlling logging and timing when running a [`Graph`].
#[derive(Default)]
pub struct RunOptions {
    /// Collect per-operator timing records and print a timing summary when
    /// the run completes.
    pub timing: bool,
    /// Sort order used when printing the timing summary.
    pub timing_sort: TimingSort,
    /// Break down the timing summary by operator input shape.
    pub timing_by_shape: bool,
    /// Print details of each operator (name, inputs, outputs, time) as it
    /// executes.
    pub verbose: bool,
}
impl Graph {
    /// Create a new empty graph.
    pub fn new() -> Graph {
        Graph { nodes: Vec::new() }
    }

    /// Add an operator node to the graph and return its ID.
    ///
    /// `name` is an optional debug name. `inputs` and `outputs` are the IDs
    /// of the nodes the operator reads and writes; `None` entries represent
    /// missing optional inputs / unused outputs.
    pub fn add_op(
        &mut self,
        name: Option<&str>,
        op: Box<dyn Operator + Sync>,
        inputs: &[Option<NodeId>],
        outputs: &[Option<NodeId>],
    ) -> NodeId {
        self.nodes.push(Node::Operator(OperatorNode {
            name: name.map(|s| s.to_owned()),
            inputs: Vec::from(inputs),
            outputs: Vec::from(outputs),
            operator: op,
        }));
        // Node IDs are simply indices into `self.nodes`.
        self.nodes.len() - 1
    }

    /// Add a constant tensor node to the graph and return its ID.
    pub fn add_constant<T>(&mut self, name: Option<&str>, value: Tensor<T>) -> NodeId
    where
        ConstantNode<T>: Into<Constant>,
    {
        let node = ConstantNode {
            name: name.map(|s| s.to_owned()),
            data: value,
        };
        self.nodes.push(Node::Constant(node.into()));
        self.nodes.len() - 1
    }

    /// Add a value node (a dynamic input or operator output) to the graph
    /// and return its ID. `shape` optionally declares the expected shape.
    pub fn add_value(&mut self, name: Option<&str>, shape: Option<Vec<Dimension>>) -> NodeId {
        self.nodes.push(Node::Value(ValueNode {
            name: name.map(|s| s.to_owned()),
            shape,
        }));
        self.nodes.len() - 1
    }

    /// Return a name for node `id` suitable for log/error messages: the debug
    /// name if it has one, otherwise a "[ID: n]" placeholder.
    pub fn node_name(&self, id: NodeId) -> String {
        self.get_node(id)
            .and_then(|node| node.name())
            .map(|s| s.to_string())
            .unwrap_or_else(|| format!("[ID: {}]", id))
    }

    /// Look up the node with the given ID, if any.
    pub fn get_node(&self, id: NodeId) -> Option<&Node> {
        self.nodes.get(id)
    }

    /// Return the total number of elements across all constant nodes
    /// (i.e. the parameter count of the model).
    pub fn total_params(&self) -> usize {
        self.nodes
            .iter()
            .map(|node| match node {
                Node::Operator(_) => 0,
                Node::Value(_) => 0,
                Node::Constant(constant) => constant.len(),
            })
            .sum()
    }

    /// Compute the values of the nodes listed in `outputs`, given the dynamic
    /// values supplied in `inputs`, and return them in the requested order.
    ///
    /// Only the operators needed to produce the requested outputs are run.
    /// Intermediate values are freed as soon as their last consumer has run,
    /// and an operator's input buffer may be reused for its output when the
    /// operator supports in-place execution.
    pub fn run(
        &self,
        inputs: &[(NodeId, Input)],
        outputs: &[NodeId],
        opts: Option<RunOptions>,
    ) -> Result<Vec<Output>, RunError> {
        let plan = self.create_plan(inputs, outputs)?;
        let opts = opts.unwrap_or_default();

        let mut run_timer = Timer::new();
        if opts.timing {
            run_timer.start();
        }

        // Map of node ID => borrowed value, seeded with the caller's inputs
        // and all constants. These values live for the whole run.
        let mut values: HashMap<NodeId, Input> = inputs.iter().cloned().collect();
        for (node_id, node) in self.nodes.iter().enumerate() {
            if let Node::Constant(constant) = node {
                let input = match constant {
                    Constant::Float(node) => Input::FloatTensor(node.data.view()),
                    Constant::Int(node) => Input::IntTensor(node.data.view()),
                };
                values.insert(node_id, input);
            }
        }

        // Count the remaining uses of each value consumed by the plan, so
        // temporaries can be freed (or reused in-place) after their last use.
        let mut temp_value_refcount = NodeRefCount::new();
        for (_, op_node) in plan.iter() {
            for node_id in op_node.inputs.iter().filter_map(|node| *node) {
                temp_value_refcount.inc(node_id);
            }
        }
        // Requested outputs get an extra reference so they survive until the
        // end of the run.
        for node_id in outputs {
            temp_value_refcount.inc(*node_id);
        }

        // Map of node ID => owned output produced by an operator during this run.
        let mut temp_values: HashMap<NodeId, Output> = HashMap::new();
        let mut op_elapsed: Vec<TimingRecord> = Vec::new();
        let record_timing = opts.timing || opts.verbose;
        let mut alloc_timer = Timer::new();

        // Execute the plan in order.
        for (step, (op_node_id, op_node)) in plan.iter().enumerate() {
            let mut op_timer = Timer::new();
            if record_timing {
                op_timer.start();
            }

            // Pick which input (if any) is a candidate for in-place execution.
            // For commutative ops any operand may be mutated, so choose the
            // largest buffered temporary to maximize reuse; otherwise only the
            // first input is eligible.
            let in_place_input_id = if op_node.operator.can_run_in_place() {
                if op_node.operator.is_commutative() {
                    op_node
                        .inputs
                        .iter()
                        .max_by_key(|input_id| {
                            input_id
                                .and_then(|id| temp_values.get(&id))
                                .map(|val| val.len())
                                .unwrap_or(0)
                        })
                        .copied()
                        .flatten()
                } else {
                    op_node.inputs.first().copied().flatten()
                }
            } else {
                None
            };

            // Take ownership of the in-place input, but only when it is a
            // temporary whose last use is this operator (refcount == 1).
            let in_place_input = in_place_input_id.and_then(|first_input| {
                if temp_values.contains_key(&first_input)
                    && temp_value_refcount.count(first_input) == 1
                {
                    temp_value_refcount.dec(first_input);
                    Some(temp_values.remove(&first_input).unwrap())
                } else {
                    None
                }
            });

            // Gather borrowed views of the remaining inputs, preserving the
            // operator's input order ( `None` for missing optional inputs).
            let mut op_inputs: Vec<Option<Input>> = Vec::new();
            for node_id in op_node.inputs.iter() {
                // Skip the operand already taken for in-place execution.
                if in_place_input.is_some() && *node_id == in_place_input_id {
                    continue;
                }
                if let Some(node_id) = node_id {
                    if let Some(value) = values.get(node_id) {
                        op_inputs.push(Some(value.clone()));
                    } else if let Some(value) = temp_values.get(node_id) {
                        let input = match value {
                            Output::IntTensor(t) => Input::IntTensor(t.view()),
                            Output::FloatTensor(t) => Input::FloatTensor(t.view()),
                        };
                        op_inputs.push(Some(input));
                    } else {
                        // A valid plan always produces inputs before their
                        // consumers, so this indicates a planner bug.
                        panic!(
                            "Invalid plan did not produce input value {} for operator {}",
                            self.node_name(*node_id),
                            self.node_name(*op_node_id),
                        );
                    }
                } else {
                    op_inputs.push(None);
                }
            }

            // Capture input shapes for timing/verbose reporting before the
            // in-place input is consumed by the operator.
            let input_shapes = if opts.timing_by_shape || opts.verbose {
                let mut shapes: Vec<InputShape> = Vec::new();
                if let Some(ref input) = in_place_input {
                    shapes.push(Some(input.shape().into()));
                }
                for input in &op_inputs {
                    shapes.push(input.as_ref().map(|i| i.shape().into()))
                }
                shapes
            } else {
                Vec::new()
            };

            // Run the operator, in-place if we claimed an input above.
            let op_result = if let Some(input) = in_place_input {
                op_node
                    .operator
                    .run_in_place(input, InputList::from_optional(&op_inputs))
                    .map(|out| [out].into())
            } else {
                op_node
                    .operator
                    .run(InputList::from_optional(&op_inputs[..]))
            };

            if record_timing {
                op_timer.end();
                op_elapsed.push(TimingRecord {
                    name: op_node.operator.name().to_string(),
                    input_shapes: input_shapes.clone(),
                    elapsed_micros: op_timer.elapsed_micros(),
                });
            }

            // Verbose mode: log operator, named inputs/outputs and timing.
            if opts.verbose {
                println!(
                    "#{} {} ({})",
                    step,
                    op_node.operator.name(),
                    op_node.name.as_ref().unwrap_or(&String::new())
                );
                for (index, (id, shape)) in
                    zip(op_node.inputs.iter(), input_shapes.iter()).enumerate()
                {
                    if let (Some(id), Some(shape)) = (id, shape) {
                        let name = self.node_name(*id);
                        println!(" input {}: {} ({:?})", index, name, shape);
                    }
                }
                if let Ok(outputs) = op_result.as_ref() {
                    for (index, (id, output)) in
                        zip(op_node.outputs.iter(), outputs.iter()).enumerate()
                    {
                        let name = id.map(|id| self.node_name(id)).unwrap_or(String::new());
                        println!(" output {}: {} ({:?})", index, name, output.shape());
                    }
                }
                println!(" time: {}ms", op_timer.elapsed_ms());
            }

            // Propagate operator failures, attaching the node's debug name.
            let outputs = match op_result {
                Ok(outputs) => outputs,
                Err(op_error) => {
                    let err = RunError::OperatorError {
                        name: op_node.name.as_deref().unwrap_or("").to_string(),
                        error: op_error,
                    };
                    return Err(err);
                }
            };

            if op_node.outputs.len() != outputs.len() {
                return Err(RunError::OutputMismatch(
                    "operator output count did not match expected count",
                ));
            }

            // Store the produced values for later steps / final results.
            for (&output_id, output) in zip(op_node.outputs.iter(), outputs.into_iter()) {
                if let Some(output_id) = output_id {
                    temp_values.insert(output_id, output);
                }
            }

            // Free temporaries whose last consumer has now run. The time
            // spent here is attributed to "alloc" in the timing summary.
            record_timing.then(|| alloc_timer.start());
            for node_id in op_node.inputs.iter().filter_map(|node| *node) {
                let rc = temp_value_refcount.dec(node_id);
                if rc == 0 {
                    temp_values.remove(&node_id);
                }
            }
            record_timing.then(|| alloc_timer.end());
        }

        if opts.timing {
            run_timer.end();
            println!(
                "Graph run of {} ops finished in {}ms",
                plan.len(),
                run_timer.elapsed_ms()
            );
            let timing = RunTiming {
                records: &op_elapsed,
                alloc_time: alloc_timer.elapsed_ms(),
                total_time: run_timer.elapsed_ms(),
            };
            print!("{}", timing.display(opts.timing_sort, opts.timing_by_shape));
        }

        // Assemble results in the requested order. Outputs that are inputs or
        // constants are copied; computed outputs are moved out of the
        // temporaries map.
        let result = outputs
            .iter()
            .map(|output_id| {
                if let Some(value) = values.get(output_id) {
                    match value {
                        Input::IntTensor(t) => Output::IntTensor(t.to_tensor()),
                        Input::FloatTensor(t) => Output::FloatTensor(t.to_tensor()),
                    }
                } else {
                    temp_values.remove(output_id).expect("missing output value")
                }
            })
            .collect();
        Ok(result)
    }

    /// Build an execution plan: the ordered list of operator nodes that must
    /// run — dependencies first — to produce `outputs` given `inputs`.
    fn create_plan(
        &self,
        inputs: &[(NodeId, Input)],
        outputs: &[NodeId],
    ) -> Result<Vec<(NodeId, &OperatorNode)>, RunError> {
        if !all_unique(outputs, |x, y| x == y) {
            return Err(RunError::PlanningError("output IDs are not unique".into()));
        }
        if !all_unique(inputs, |(x_id, _), (y_id, _)| x_id == y_id) {
            return Err(RunError::PlanningError("input IDs are not unique".into()));
        }

        // Map of output node ID => (operator node ID, operator) producing it.
        let mut operator_nodes = HashMap::new();
        for (node_id, node) in self.nodes.iter().enumerate() {
            if let Node::Operator(op_node) = node {
                for output_id in op_node.outputs.iter().filter_map(|node| *node) {
                    operator_nodes.insert(output_id, (node_id, op_node));
                }
            }
        }

        // Values available before any operator runs: supplied inputs plus all
        // constants.
        let mut resolved_values: HashSet<NodeId> =
            inputs.iter().map(|(node_id, _)| *node_id).collect();
        for (node_id, node) in self.nodes.iter().enumerate() {
            if let Node::Constant(_) = node {
                resolved_values.insert(node_id);
            }
        }

        // Recursive depth-first traversal from the requested outputs back
        // through operator dependencies, appending operators in post-order
        // (so each operator appears after everything it depends on).
        struct PlanBuilder<'a> {
            graph: &'a Graph,
            resolved_values: HashSet<NodeId>,
            plan: Vec<(NodeId, &'a OperatorNode)>,
            operator_nodes: HashMap<NodeId, (NodeId, &'a OperatorNode)>,
        }
        impl<'a> PlanBuilder<'a> {
            /// Append `op_node` to the plan, first visiting the operators
            /// that produce its unresolved inputs.
            fn visit(
                &mut self,
                op_node_id: NodeId,
                op_node: &'a OperatorNode,
            ) -> Result<(), RunError> {
                for input in op_node.inputs.iter().filter_map(|node| *node) {
                    if self.resolved_values.contains(&input) {
                        continue;
                    }
                    if let Some((input_op_id, input_op_node)) =
                        self.operator_nodes.get(&input).copied()
                    {
                        self.visit(input_op_id, input_op_node)?;
                    } else {
                        // No supplied input, constant or operator provides
                        // this value.
                        let msg = format!(
                            "Missing input \"{}\" for op \"{}\"",
                            self.graph.node_name(input),
                            self.graph.node_name(op_node_id)
                        );
                        return Err(RunError::PlanningError(msg));
                    }
                }
                for output_id in op_node.outputs.iter().filter_map(|node| *node) {
                    self.resolved_values.insert(output_id);
                }
                self.plan.push((op_node_id, op_node));
                Ok(())
            }

            /// Consume the builder, returning the plan for `outputs`.
            fn plan(
                mut self,
                outputs: &[NodeId],
            ) -> Result<Vec<(NodeId, &'a OperatorNode)>, RunError> {
                for output_id in outputs.iter() {
                    if self.resolved_values.contains(output_id) {
                        // Already an input or constant; nothing to run.
                        continue;
                    }
                    if let Some((op_node_id, op_node)) = self.operator_nodes.get(output_id).copied()
                    {
                        self.visit(op_node_id, op_node)?;
                    } else {
                        let msg = format!("Missing output {}", output_id);
                        return Err(RunError::PlanningError(msg));
                    }
                }
                Ok(self.plan)
            }
        }

        let builder = PlanBuilder {
            graph: self,
            resolved_values,
            plan: Vec::new(),
            operator_nodes,
        };
        builder.plan(outputs)
    }
}
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::sync::{Arc, Mutex};

    use rten_tensor::prelude::*;
    use rten_tensor::test_util::{expect_equal, expect_equal_with_tolerance};
    use rten_tensor::{tensor, Tensor, TensorView};

    use crate::graph::{Dimension, Graph, RunError};
    use crate::ops::{
        Concat, Conv, InputList, IntoOpResult, OpError, Operator, Output, Relu, Shape,
    };

    /// Counters recording how often a wrapped operator's entry points ran.
    #[derive(Clone, Debug, Default)]
    struct Metrics {
        run_count: u32,
        run_in_place_count: u32,
    }

    /// Operator adapter that delegates to `inner` and counts calls to `run`
    /// and `run_in_place`.
    #[derive(Debug)]
    struct TrackUsage<Op: Operator> {
        inner: Op,
        metrics: Arc<Mutex<Metrics>>,
    }

    impl<Op: Operator> TrackUsage<Op> {
        fn new(inner: Op) -> Self {
            TrackUsage {
                inner,
                metrics: Default::default(),
            }
        }

        /// Return a handle to the shared counters, usable after the operator
        /// has been moved into a graph.
        fn metrics(&self) -> Arc<Mutex<Metrics>> {
            self.metrics.clone()
        }
    }

    impl<Op: Operator> Operator for TrackUsage<Op> {
        fn name(&self) -> &str {
            self.inner.name()
        }
        fn can_run_in_place(&self) -> bool {
            self.inner.can_run_in_place()
        }
        fn is_commutative(&self) -> bool {
            self.inner.is_commutative()
        }
        fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
            {
                let mut m = self.metrics.lock().unwrap();
                m.run_count += 1;
            }
            self.inner.run(inputs)
        }
        fn run_in_place(&self, output: Output, inputs: InputList) -> Result<Output, OpError> {
            {
                let mut m = self.metrics.lock().unwrap();
                m.run_in_place_count += 1;
            }
            self.inner.run_in_place(output, inputs)
        }
    }

    // End-to-end check: a Conv -> Relu graph produces the expected values.
    #[test]
    fn test_graph_run() -> Result<(), Box<dyn Error>> {
        let mut g = Graph::new();

        let weights = Tensor::from_data(
            &[1, 1, 3, 3],
            vec![
                0.3230, 0.7632, 0.4616, 0.8837, 0.5898, 0.3424, 0.2101, 0.7821, 0.6861,
            ],
        );
        let weights_id = g.add_constant(Some("weight"), weights);
        let input_id = g.add_value(Some("input"), None);
        let conv_out = g.add_value(Some("conv_out"), None);
        g.add_op(
            Some("conv"),
            Box::new(Conv {
                dilations: vec![1, 1],
                groups: 1,
                padding: [1, 1, 1, 1].into(),
                strides: vec![1, 1],
            }),
            &[input_id, weights_id].map(Some),
            &[conv_out].map(Some),
        );
        let relu_out = g.add_value(Some("relu_out"), None);
        g.add_op(
            Some("relu"),
            Box::new(Relu {}),
            &[conv_out].map(Some),
            &[relu_out].map(Some),
        );

        let input = Tensor::from_data(
            &[1, 1, 3, 3],
            vec![
                0.5946, 0.8249, 0.0448, 0.9552, 0.2041, 0.2501, 0.2693, 0.1007, 0.8862,
            ],
        );

        let results = g
            .run(&[(input_id, (&input).into())], &[relu_out], None)
            .unwrap();

        let expected = Tensor::from_data(
            &[1, 1, 3, 3],
            vec![
                1.5202, 1.5592, 0.9939, 1.7475, 2.6358, 1.3428, 1.0165, 1.1806, 0.8685,
            ],
        );
        assert_eq!(results.len(), 1);
        expect_equal_with_tolerance(results[0].as_float_ref().unwrap(), &expected, 1e-4, 0.)?;

        Ok(())
    }

    // `node_name` returns the debug name when present and an "[ID: n]"
    // placeholder for anonymous nodes.
    #[test]
    fn test_graph_node_debug_names() {
        let mut g = Graph::new();

        let weights = Tensor::from_data(&[1], vec![0.3230]);
        let weights_id = g.add_constant(Some("weights"), weights.clone());
        let input_id = g.add_value(Some("input"), None);
        let relu_out_id = g.add_value(Some("relu_out"), None);
        let relu_op_id = g.add_op(
            Some("relu"),
            Box::new(Relu {}),
            &[Some(input_id)],
            &[Some(relu_out_id)],
        );

        assert_eq!(g.node_name(weights_id), "weights");
        assert_eq!(g.node_name(input_id), "input");
        assert_eq!(g.node_name(relu_op_id), "relu");

        let anon_weights_id = g.add_constant(None, weights);
        let anon_input_id = g.add_value(None, None);
        let anon_out_id = g.add_value(None, None);
        let anon_op_id = g.add_op(
            None,
            Box::new(Relu {}),
            &[Some(input_id)],
            &[Some(anon_out_id)],
        );

        assert_eq!(
            g.node_name(anon_weights_id),
            format!("[ID: {}]", anon_weights_id)
        );
        assert_eq!(
            g.node_name(anon_input_id),
            format!("[ID: {}]", anon_input_id)
        );
        assert_eq!(g.node_name(anon_op_id), format!("[ID: {}]", anon_op_id));
    }

    // `Node::shape` reports constant shapes, declared value shapes, and
    // `None` for operators.
    #[test]
    fn test_graph_node_shapes() {
        let mut g = Graph::new();

        let weights = Tensor::from_data(&[1, 1, 2], vec![0.3230, 0.5]);
        let weights_id = g.add_constant(Some("weights"), weights.clone());
        let input_id = g.add_value(
            Some("input"),
            Some(
                [
                    Dimension::Symbolic("batch".to_string()),
                    Dimension::Fixed(3),
                    Dimension::Fixed(5),
                    Dimension::Fixed(5),
                ]
                .to_vec(),
            ),
        );
        let relu_out_id = g.add_value(Some("relu_out"), None);
        let relu_op_id = g.add_op(
            Some("relu"),
            Box::new(Relu {}),
            &[Some(input_id)],
            &[Some(relu_out_id)],
        );

        assert_eq!(
            g.get_node(weights_id).and_then(|n| n.shape()),
            Some([1, 1, 2].map(Dimension::Fixed).to_vec())
        );
        assert_eq!(
            g.get_node(input_id).and_then(|n| n.shape()),
            Some(
                [
                    Dimension::Symbolic("batch".to_string()),
                    Dimension::Fixed(3),
                    Dimension::Fixed(5),
                    Dimension::Fixed(5),
                ]
                .to_vec()
            )
        );
        assert_eq!(g.get_node(relu_op_id).and_then(|n| n.shape()), None);
    }

    /// Trivial test operator that adds 1.0 to every element of its input.
    #[derive(Debug)]
    struct AddOne {}
    impl Operator for AddOne {
        fn name(&self) -> &str {
            "AddOne"
        }
        fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
            let input: TensorView<f32> = inputs.require_as(0)?;
            let output_data: Vec<f32> = input.iter().map(|x| x + 1.0).collect();
            Tensor::<f32>::from_data(input.shape().into(), output_data).into_op_result()
        }
    }

    // The planner runs shared dependencies once and respects input order of
    // the requested output's operator.
    #[test]
    fn test_graph_planning_order() -> Result<(), Box<dyn Error>> {
        let mut g = Graph::new();

        let input_id = g.add_value(Some("input"), None);

        let op_a_out = g.add_value(Some("op_a_out"), None);
        g.add_op(
            Some("op_a"),
            Box::new(AddOne {}),
            &[Some(input_id)],
            &[Some(op_a_out)],
        );
        let op_b_out = g.add_value(Some("op_b_out"), None);
        g.add_op(
            Some("op_b"),
            Box::new(AddOne {}),
            &[Some(op_a_out)],
            &[Some(op_b_out)],
        );

        // op_c consumes (op_a_out, op_b_out); op_d the reverse order.
        let op_c_out = g.add_value(Some("op_c_out"), None);
        g.add_op(
            Some("op_c"),
            Box::new(Concat { axis: 0 }),
            &[op_a_out, op_b_out].map(Some),
            &[Some(op_c_out)],
        );

        let op_d_out = g.add_value(Some("op_d_out"), None);
        g.add_op(
            Some("op_d"),
            Box::new(Concat { axis: 0 }),
            &[op_b_out, op_a_out].map(Some),
            &[Some(op_d_out)],
        );

        let input = Tensor::from_data(&[1], vec![1.]);

        let results = g
            .run(&[(input_id, (&input).into())], &[op_c_out], None)
            .unwrap();
        let expected = Tensor::from_data(&[2], vec![2., 3.]);
        expect_equal(results[0].as_float_ref().unwrap(), &expected)?;

        let results = g
            .run(&[(input_id, (&input).into())], &[op_d_out], None)
            .unwrap();
        let expected = Tensor::from_data(&[2], vec![3., 2.]);
        expect_equal(results[0].as_float_ref().unwrap(), &expected)?;

        Ok(())
    }

    // An intermediate value can be requested as an output alongside the
    // final value.
    #[test]
    fn test_graph_intermediate_output() {
        let mut g = Graph::new();

        let input_id = g.add_value(Some("input"), None);
        let op_a_out = g.add_value(Some("op_a_out"), None);
        g.add_op(
            Some("op_a"),
            Box::new(AddOne {}),
            &[Some(input_id)],
            &[Some(op_a_out)],
        );
        let op_b_out = g.add_value(Some("op_b_out"), None);
        g.add_op(
            Some("op_b"),
            Box::new(AddOne {}),
            &[Some(op_a_out)],
            &[Some(op_b_out)],
        );

        let input = tensor!(0.);
        let results = g
            .run(&[(input_id, (&input).into())], &[op_a_out, op_b_out], None)
            .unwrap();

        assert_eq!(results[0].as_float_ref().unwrap(), &tensor!(1.));
        assert_eq!(results[1].as_float_ref().unwrap(), &tensor!(2.));
    }

    // A deep chain of 100 operators runs to completion.
    #[test]
    fn test_graph_many_steps() -> Result<(), Box<dyn Error>> {
        let mut g = Graph::new();

        let input = Tensor::from_data(&[5], vec![1., 2., 3., 4., 5.]);
        let input_id = g.add_value(Some("input"), None);

        let mut prev_output = input_id;
        for _ in 0..100 {
            let next_output = g.add_value(None, None);
            g.add_op(
                None,
                Box::new(AddOne {}),
                &[Some(prev_output)],
                &[Some(next_output)],
            );
            prev_output = next_output;
        }

        let results = g
            .run(&[(input_id, (&input).into())], &[prev_output], None)
            .unwrap();

        let expected = Tensor::from_data(&[5], vec![101., 102., 103., 104., 105.]);
        expect_equal(results[0].as_float_ref().unwrap(), &expected)?;

        Ok(())
    }

    // Requesting an input node as an output copies it straight through.
    #[test]
    fn test_noop_graph() -> Result<(), Box<dyn Error>> {
        let mut g = Graph::new();

        let input = Tensor::from_data(&[5], vec![1., 2., 3., 4., 5.]);
        let input_id = g.add_value(Some("input"), None);

        let results = g
            .run(&[(input_id, (&input).into())], &[input_id], None)
            .unwrap();

        expect_equal(results[0].as_float_ref().unwrap(), &input)?;

        Ok(())
    }

    // Requesting a constant node as an output copies its data.
    #[test]
    fn test_constant_graph() -> Result<(), Box<dyn Error>> {
        let mut g = Graph::new();

        let value = Tensor::from_data(&[5], vec![1., 2., 3., 4., 5.]);
        let const_id = g.add_constant(Some("weight"), value.clone());

        let results = g.run(&[], &[const_id], None).unwrap();

        expect_equal(results[0].as_float_ref().unwrap(), &value)?;

        Ok(())
    }

    // `total_params` sums element counts across all constant types.
    #[test]
    fn test_total_params() {
        let mut g = Graph::new();
        g.add_constant(Some("floats"), Tensor::<f32>::zeros(&[10, 10]));
        g.add_constant(Some("ints"), Tensor::<i32>::zeros(&[10, 10]));
        assert_eq!(g.total_params(), 200);
    }

    // Running with no requested outputs succeeds and returns nothing.
    #[test]
    fn test_no_outputs() {
        let g = Graph::new();
        let results = g.run(&[], &[], None).unwrap();
        assert_eq!(results.len(), 0);
    }

    // Duplicate input IDs are rejected during planning.
    #[test]
    fn test_duplicate_inputs() {
        let mut g = Graph::new();
        let input_id = g.add_value(Some("input"), None);
        let input = tensor!([1.]);
        let result = g.run(
            &[(input_id, (&input).into()), (input_id, (&input).into())],
            &[input_id],
            None,
        );
        assert_eq!(
            result,
            Err(RunError::PlanningError("input IDs are not unique".into()))
        );
    }

    // Duplicate output IDs are rejected during planning.
    #[test]
    fn test_duplicate_outputs() {
        let mut g = Graph::new();

        let input_id = g.add_value(Some("input"), None);
        let op_a_out = g.add_value(Some("op_a_out"), None);
        g.add_op(
            Some("op_a"),
            Box::new(AddOne {}),
            &[Some(input_id)],
            &[Some(op_a_out)],
        );

        let input = tensor!([1.]);

        let result = g.run(&[(input_id, (&input).into())], &[op_a_out, op_a_out], None);

        assert_eq!(
            result,
            Err(RunError::PlanningError("output IDs are not unique".into()))
        );
    }

    // A `None` (missing optional) input is passed through to the operator,
    // which reports the missing-input error itself.
    #[test]
    fn test_call_op_with_missing_input() {
        let mut g = Graph::new();
        let output = g.add_value(None, None);
        g.add_op(Some("shape"), Box::new(Shape {}), &[None], &[Some(output)]);
        let results = g.run(&[], &[output], None);
        assert_eq!(
            results.err(),
            Some(RunError::OperatorError {
                name: "shape".to_string(),
                error: OpError::MissingInputs
            })
        );
    }

    // Requesting a node ID that doesn't exist is a planning error.
    #[test]
    fn test_err_if_invalid_output() {
        let g = Graph::new();
        let result = g.run(&[], &[123], None);
        assert_eq!(
            result.err(),
            Some(RunError::PlanningError("Missing output 123".to_string()))
        );
    }

    // An operator input that nothing produces is a planning error.
    #[test]
    fn test_err_if_missing_operator_input() {
        let mut g = Graph::new();
        let output = g.add_value(None, None);
        g.add_op(Some("op"), Box::new(Relu {}), &[Some(42)], &[Some(output)]);
        let result = g.run(&[], &[output], None);
        assert_eq!(
            result.err(),
            Some(RunError::PlanningError(
                "Missing input \"[ID: 42]\" for op \"op\"".to_string()
            ))
        );
    }

    /// Test operator supporting in-place execution. The in-place path adds
    /// 1.0 to each element; the copying path returns the input unchanged, so
    /// results reveal which path was taken.
    #[derive(Debug)]
    struct AddOneInPlace {}
    impl Operator for AddOneInPlace {
        fn name(&self) -> &str {
            "AddOneInPlace"
        }
        fn can_run_in_place(&self) -> bool {
            true
        }
        fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
            // An operator should normally have the same behavior in both
            // paths; here they deliberately differ so the test can tell them
            // apart.
            let input: TensorView<f32> = inputs.require_as(0)?;
            input.to_tensor().into_op_result()
        }
        fn run_in_place(&self, input: Output, _other: InputList) -> Result<Output, OpError> {
            let mut output = input.into_float().unwrap();
            for x in output.iter_mut() {
                *x = *x + 1.0;
            }
            Ok(output.into())
        }
    }

    // In-place execution is used only when the input is a temporary with a
    // single remaining use — never for graph inputs or multiply-used values.
    #[test]
    fn test_runs_op_in_place() {
        let mut g = Graph::new();
        let input_id = g.add_value(Some("input"), None);

        let op1_out = g.add_value(Some("op1_out"), None);
        g.add_op(
            Some("op1"),
            Box::new(AddOneInPlace {}),
            &[Some(input_id)],
            &[Some(op1_out)],
        );
        let op2_out = g.add_value(Some("op2_out"), None);
        g.add_op(
            Some("op2"),
            Box::new(AddOneInPlace {}),
            &[Some(op1_out)],
            &[Some(op2_out)],
        );
        let op3_out = g.add_value(Some("op3_out"), None);
        g.add_op(
            Some("op3"),
            Box::new(AddOneInPlace {}),
            &[Some(op2_out)],
            &[Some(op3_out)],
        );
        let op4_out = g.add_value(Some("op4_out"), None);
        g.add_op(
            Some("op4"),
            Box::new(AddOneInPlace {}),
            &[Some(op2_out)],
            &[Some(op4_out)],
        );
        let input = Tensor::<f32>::zeros(&[1, 1]);

        // op1 reads the graph input (borrowed), so it cannot run in place.
        let results = g
            .run(&[(input_id, (&input).into())], &[op1_out], None)
            .unwrap();
        assert_eq!(results[0].as_float_ref().unwrap()[[0, 0]], 0.0);

        // op2's input is a single-use temporary, so it runs in place.
        let results = g
            .run(&[(input_id, (&input).into())], &[op2_out], None)
            .unwrap();
        assert_eq!(results[0].as_float_ref().unwrap()[[0, 0]], 1.0);

        // op2_out has two consumers, so op3 cannot take it in place; op4 can
        // once op3 has released its reference.
        let results = g
            .run(&[(input_id, (&input).into())], &[op3_out, op4_out], None)
            .unwrap();
        assert_eq!(results[0].as_float_ref().unwrap()[[0, 0]], 1.0);
        assert_eq!(results[1].as_float_ref().unwrap()[[0, 0]], 2.0);
    }

    // For a commutative op, any operand may be claimed for in-place
    // execution — here op2's second input (a temporary) is used even though
    // its first input is a borrowed graph input.
    #[test]
    fn test_runs_commutative_op_in_place() {
        use crate::ops::Add;

        let mut g = Graph::new();
        let input_id = g.add_value(Some("input"), None);
        let bias_id = g.add_value(Some("bias"), None);

        let op1 = TrackUsage::new(Add {});
        let op1_metrics = op1.metrics();

        let op2 = TrackUsage::new(Add {});
        let op2_metrics = op2.metrics();

        let op1_out = g.add_value(Some("op1_out"), None);
        g.add_op(
            Some("op1"),
            Box::new(op1),
            &[Some(input_id), Some(bias_id)],
            &[Some(op1_out)],
        );
        let op2_out = g.add_value(Some("op2_out"), None);
        g.add_op(
            Some("op2"),
            Box::new(op2),
            // Note the temporary is the *second* operand here.
            &[Some(bias_id), Some(op1_out)],
            &[Some(op2_out)],
        );

        let input = Tensor::<f32>::zeros(&[2, 2]);
        let bias = tensor!(1.5);
        let results = g
            .run(
                &[(input_id, (&input).into()), (bias_id, (&bias).into())],
                &[op2_out],
                None,
            )
            .unwrap();

        assert_eq!(
            results[0]
                .as_float_ref()
                .unwrap()
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            &[3., 3., 3., 3.]
        );

        let op1_metrics = op1_metrics.lock().unwrap();
        assert_eq!(op1_metrics.run_count, 1);
        assert_eq!(op1_metrics.run_in_place_count, 0);

        let op2_metrics = op2_metrics.lock().unwrap();
        assert_eq!(op2_metrics.run_count, 0);
        assert_eq!(op2_metrics.run_in_place_count, 1);
    }

    /// Test operator with two outputs: the first and second halves of the
    /// (flattened) input. Counts how many times it runs.
    #[derive(Debug)]
    struct Split {
        run_count: Arc<Mutex<u32>>,
    }

    impl Split {
        fn new() -> Split {
            Split {
                run_count: Arc::new(Mutex::new(0)),
            }
        }
    }

    impl Operator for Split {
        fn name(&self) -> &str {
            "Split"
        }
        fn run(&self, inputs: InputList) -> Result<Vec<Output>, OpError> {
            {
                let mut rc = self.run_count.lock().unwrap();
                *rc += 1;
            }
            let input: TensorView<f32> = inputs.require_as(0)?;
            let left_split_len = input.len() / 2;
            let left_split = Tensor::from_vec(input.iter().take(left_split_len).copied().collect());
            let right_split =
                Tensor::from_vec(input.iter().skip(left_split_len).copied().collect());
            Ok([left_split.into(), right_split.into()].into())
        }
    }

    // A multi-output operator runs once and delivers both outputs.
    #[test]
    fn test_multiple_outputs() {
        let mut g = Graph::new();
        let input_id = g.add_value(Some("input"), None);
        let left_split_out = g.add_value(Some("left_split"), None);
        let right_split_out = g.add_value(Some("right_split"), None);

        let split_op = Box::new(Split::new());
        let run_count = split_op.run_count.clone();

        g.add_op(
            Some("split"),
            split_op,
            &[Some(input_id)],
            &[left_split_out, right_split_out].map(Some),
        );

        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let mut results = g
            .run(
                &[(input_id, (&input).into())],
                &[left_split_out, right_split_out],
                None,
            )
            .unwrap();

        assert_eq!(*run_count.lock().unwrap(), 1);
        assert_eq!(results.len(), 2);
        let left_split = results.remove(0).into_float().unwrap();
        let right_split = results.remove(0).into_float().unwrap();
        assert_eq!(left_split.to_vec(), &[1.0, 2.0]);
        assert_eq!(right_split.to_vec(), &[3.0, 4.0, 5.0]);
    }
}