use super::super::types::OptLevel;
use super::super::Session;
use crate::graph::{Attributes, Graph, Node, OpKind};
use crate::tensor::Tensor;
use std::collections::HashMap;
/// A single Relu whose input is a graph input (used exactly once) may run in
/// place; the result must still be the element-wise `max(0, x)` with the
/// original shape preserved.
#[test]
fn test_inplace_relu() {
    let node = Node {
        op: OpKind::Relu,
        name: "relu".to_string(),
        inputs: vec!["x".to_string()],
        outputs: vec!["y".to_string()],
        attrs: Attributes::default(),
    };
    let graph = Graph {
        nodes: vec![node],
        input_names: vec!["x".to_string()],
        output_names: vec!["y".to_string()],
        ..Default::default()
    };
    let session = Session::builder()
        .with_optimization_level(OptLevel::None)
        .build_from_graph(graph, HashMap::new())
        .expect("build");
    let input = Tensor::new(vec![-3.0, -1.0, 0.0, 1.0, 3.0], vec![5]);
    let outputs = session.run_one("x", input).expect("run");
    let y = outputs.get("y").expect("y");
    assert_eq!(y.data, vec![0.0, 0.0, 0.0, 1.0, 3.0]);
    // Reusing the input buffer must not disturb the tensor's shape.
    assert_eq!(y.shape, vec![5]);
}
/// Add of a graph input and a same-shaped weight may reuse the input buffer;
/// the sum and the output shape must both be correct.
#[test]
fn test_inplace_add_same_shape() {
    let node = Node {
        op: OpKind::Add,
        name: "add".to_string(),
        inputs: vec!["x".to_string(), "w".to_string()],
        outputs: vec!["y".to_string()],
        attrs: Attributes::default(),
    };
    let graph = Graph {
        nodes: vec![node],
        input_names: vec!["x".to_string()],
        output_names: vec!["y".to_string()],
        ..Default::default()
    };
    let mut weights = HashMap::new();
    weights.insert(
        "w".to_string(),
        Tensor::new(vec![10.0, 20.0, 30.0], vec![3]),
    );
    let session = Session::builder()
        .with_optimization_level(OptLevel::None)
        .build_from_graph(graph, weights)
        .expect("build");
    let input = Tensor::new(vec![1.0, 2.0, 3.0], vec![3]);
    let outputs = session.run_one("x", input).expect("run");
    let y = outputs.get("y").expect("y");
    assert_eq!(y.data, vec![11.0, 22.0, 33.0]);
    // Same-shape operands: the output keeps the operand shape.
    assert_eq!(y.shape, vec![3]);
}
/// When the operands' shapes differ (a `[3]` weight broadcast against a
/// `[2, 3]` input), Add must still produce the correctly broadcast result and
/// output shape.
#[test]
fn test_inplace_fallback_broadcast() {
    let add = Node {
        op: OpKind::Add,
        name: "add".to_string(),
        inputs: vec!["x".to_string(), "w".to_string()],
        outputs: vec!["y".to_string()],
        attrs: Attributes::default(),
    };
    let graph = Graph {
        nodes: vec![add],
        input_names: vec!["x".to_string()],
        output_names: vec!["y".to_string()],
        ..Default::default()
    };

    // Single row-vector weight that gets broadcast across both rows of `x`.
    let weights: HashMap<_, _> = std::iter::once((
        "w".to_string(),
        Tensor::new(vec![10.0, 20.0, 30.0], vec![3]),
    ))
    .collect();

    let session = Session::builder()
        .with_optimization_level(OptLevel::None)
        .build_from_graph(graph, weights)
        .expect("build");

    let x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
    let results = session.run_one("x", x).expect("run");
    let sum = results.get("y").expect("y");

    let want_data = vec![11.0, 22.0, 33.0, 14.0, 25.0, 36.0];
    let want_shape = vec![2, 3];
    assert_eq!(sum.data, want_data);
    assert_eq!(sum.shape, want_shape);
}
/// A tensor consumed by more than one node must not be overwritten in place
/// by any of its consumers while another consumer still needs it.
///
/// The two consumers are deliberately different ops. Relu is idempotent
/// (`relu(relu(x)) == relu(x)`), so with two Relu readers a buggy in-place
/// Relu that clobbered `input` would go unnoticed. With a Relu and an Add,
/// clobbering `input` in either consumer changes the other consumer's output.
#[test]
fn test_inplace_respects_refcount() {
    let relu = Node {
        op: OpKind::Relu,
        name: "relu_a".to_string(),
        inputs: vec!["input".to_string()],
        outputs: vec!["out_a".to_string()],
        attrs: Attributes::default(),
    };
    let add = Node {
        op: OpKind::Add,
        name: "add_b".to_string(),
        inputs: vec!["input".to_string(), "w".to_string()],
        outputs: vec!["out_b".to_string()],
        attrs: Attributes::default(),
    };
    let graph = Graph {
        nodes: vec![relu, add],
        input_names: vec!["input".to_string()],
        output_names: vec!["out_a".to_string(), "out_b".to_string()],
        ..Default::default()
    };
    let mut weights = HashMap::new();
    weights.insert(
        "w".to_string(),
        Tensor::new(vec![10.0, 20.0, 30.0, 40.0], vec![2, 2]),
    );
    let session = Session::builder()
        .with_optimization_level(OptLevel::None)
        .build_from_graph(graph, weights)
        .expect("build");
    let input = Tensor::new(vec![-2.0, 3.0, -1.0, 5.0], vec![2, 2]);
    let outputs = session.run_one("input", input).expect("run");
    let out_a = outputs.get("out_a").expect("out_a");
    let out_b = outputs.get("out_b").expect("out_b");
    // relu(input). If the Add had clobbered `input` first, this would be
    // relu(input + w) = [8, 23, 29, 45] instead.
    assert_eq!(out_a.data, vec![0.0, 3.0, 0.0, 5.0]);
    // input + w. If the Relu had clobbered `input` first, this would be
    // relu(input) + w = [10, 23, 30, 45] instead.
    assert_eq!(out_b.data, vec![8.0, 23.0, 29.0, 45.0]);
}
/// A linear chain `input -> relu1 -> mid -> relu2 -> output` places the first
/// node at depth 0 and its dependent at depth 1.
#[test]
fn test_compute_node_depths() {
    // Small local constructor for a single-input, single-output Relu node.
    fn relu_node(name: &str, src: &str, dst: &str) -> Node {
        Node {
            op: OpKind::Relu,
            name: name.to_string(),
            inputs: vec![src.to_string()],
            outputs: vec![dst.to_string()],
            attrs: Attributes::default(),
        }
    }

    let chain = vec![
        relu_node("relu1", "input", "mid"),
        relu_node("relu2", "mid", "output"),
    ];
    let no_weights = HashMap::new();
    let levels = Session::compute_node_depths(&chain, &no_weights);
    assert_eq!(levels, vec![0, 1]);
}
/// Two Relu nodes that both read only the graph input are independent, so
/// both sit at depth 0.
#[test]
fn test_compute_node_depths_parallel_branches() {
    // Produces nodes "relu_a"/"out_a" and "relu_b"/"out_b", both fed by "input".
    let branches: Vec<Node> = ["a", "b"]
        .iter()
        .map(|suffix| Node {
            op: OpKind::Relu,
            name: format!("relu_{}", suffix),
            inputs: vec!["input".to_string()],
            outputs: vec![format!("out_{}", suffix)],
            attrs: Attributes::default(),
        })
        .collect();
    let no_weights = HashMap::new();
    let levels = Session::compute_node_depths(&branches, &no_weights);
    assert_eq!(levels, vec![0, 0]);
}
/// `group_by_depth` buckets node indices by their depth value, with one group
/// per depth level and indices in ascending order within each group.
#[test]
fn test_group_by_depth() {
    let per_node_depth = vec![0, 0, 1, 2, 1];
    let levels = Session::group_by_depth(&per_node_depth);

    // expected[d] lists the node indices whose depth is d.
    let expected = vec![vec![0, 1], vec![2, 4], vec![3]];
    assert_eq!(levels.len(), expected.len());
    for (depth, members) in expected.iter().enumerate() {
        assert_eq!(&levels[depth], members);
    }
}