use std::collections::{HashMap, VecDeque};
use super::ir::*;
use super::ops::OpRegistry;
pub struct GraphExecutor {
graph: Graph,
execution_order: Vec<usize>,
}
impl GraphExecutor {
pub fn new(graph: Graph) -> OnnxResult<Self> {
let execution_order = topological_sort(&graph)?;
Ok(Self {
graph,
execution_order,
})
}
pub fn execution_order(&self) -> &[usize] {
&self.execution_order
}
pub fn graph(&self) -> &Graph {
&self.graph
}
pub fn run(
&self,
inputs: HashMap<String, OnnxTensor>,
) -> OnnxResult<HashMap<String, OnnxTensor>> {
let mut tensors: HashMap<String, OnnxTensor> = HashMap::new();
for (name, tensor) in &self.graph.initializers {
tensors.insert(name.clone(), tensor.clone());
}
for (name, tensor) in inputs {
tensors.insert(name, tensor);
}
let registry = OpRegistry::new();
for &idx in &self.execution_order {
let node = &self.graph.nodes[idx];
let input_refs: Vec<Option<&OnnxTensor>> = node
.inputs
.iter()
.map(|name| {
if name.is_empty() {
None
} else {
tensors.get(name.as_str())
}
})
.collect();
let outputs = registry.execute(&node.op_type, &input_refs, &node.attributes)?;
for (output_name, output_tensor) in node.outputs.iter().zip(outputs) {
if !output_name.is_empty() {
tensors.insert(output_name.clone(), output_tensor);
}
}
}
let mut result = HashMap::new();
for output_info in &self.graph.outputs {
if let Some(tensor) = tensors.remove(&output_info.name) {
result.insert(output_info.name.clone(), tensor);
}
}
Ok(result)
}
}
fn topological_sort(graph: &Graph) -> OnnxResult<Vec<usize>> {
let n = graph.nodes.len();
if n == 0 {
return Ok(vec![]);
}
let mut producer: HashMap<&str, usize> = HashMap::new();
for (i, node) in graph.nodes.iter().enumerate() {
for output in &node.outputs {
if !output.is_empty() {
producer.insert(output.as_str(), i);
}
}
}
let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
let mut in_degree = vec![0u32; n];
for (i, node) in graph.nodes.iter().enumerate() {
for input in &node.inputs {
if let Some(&j) = producer.get(input.as_str()) {
if j != i {
adj[j].push(i);
in_degree[i] += 1;
}
}
}
}
let mut queue: VecDeque<usize> = VecDeque::new();
for (i, °) in in_degree.iter().enumerate() {
if deg == 0 {
queue.push_back(i);
}
}
let mut order = Vec::with_capacity(n);
while let Some(node_idx) = queue.pop_front() {
order.push(node_idx);
for &next in &adj[node_idx] {
in_degree[next] -= 1;
if in_degree[next] == 0 {
queue.push_back(next);
}
}
}
if order.len() != n {
return Err(OnnxError::InvalidGraph(format!(
"cycle detected: sorted {} of {} nodes",
order.len(),
n
)));
}
Ok(order)
}
#[cfg(test)]
mod tests {
use super::*;
fn make_node(op: &str, name: &str, inputs: &[&str], outputs: &[&str]) -> Node {
Node {
op_type: op.into(),
name: name.into(),
inputs: inputs.iter().map(|s| s.to_string()).collect(),
outputs: outputs.iter().map(|s| s.to_string()).collect(),
attributes: HashMap::new(),
}
}
fn make_info(name: &str) -> TensorInfo {
TensorInfo {
name: name.into(),
dtype: DataType::Float32,
shape: TensorShape::fixed(vec![2]),
}
}
#[test]
fn test_topological_sort_linear() {
let graph = Graph {
name: "linear".into(),
nodes: vec![
make_node("Relu", "relu0", &["A"], &["B"]),
make_node("Relu", "relu1", &["B"], &["C"]),
],
inputs: vec![make_info("A")],
outputs: vec![make_info("C")],
initializers: HashMap::new(),
};
let order = topological_sort(&graph).unwrap();
assert_eq!(order, vec![0, 1]);
}
#[test]
fn test_topological_sort_diamond() {
let graph = Graph {
name: "diamond".into(),
nodes: vec![
make_node("Relu", "left", &["X"], &["L"]),
make_node("Neg", "right", &["X"], &["R"]),
make_node("Add", "merge", &["L", "R"], &["Y"]),
],
inputs: vec![make_info("X")],
outputs: vec![make_info("Y")],
initializers: HashMap::new(),
};
let order = topological_sort(&graph).unwrap();
assert_eq!(*order.last().unwrap_or(&999), 2);
assert_eq!(order.len(), 3);
}
#[test]
fn test_topological_sort_cycle() {
let graph = Graph {
name: "cycle".into(),
nodes: vec![
make_node("Relu", "a", &["B"], &["A"]),
make_node("Relu", "b", &["A"], &["B"]),
],
inputs: vec![],
outputs: vec![],
initializers: HashMap::new(),
};
assert!(topological_sort(&graph).is_err());
}
#[test]
fn test_executor_linear_chain() {
let graph = Graph {
name: "test".into(),
nodes: vec![make_node("Relu", "relu0", &["X"], &["Y"])],
inputs: vec![make_info("X")],
outputs: vec![make_info("Y")],
initializers: HashMap::new(),
};
let exec = GraphExecutor::new(graph).unwrap();
let mut inputs = HashMap::new();
inputs.insert("X".into(), OnnxTensor::from_f32(&[-1.0, 2.0], vec![2]));
let result = exec.run(inputs).unwrap();
let y = result.get("Y").unwrap();
assert_eq!(y.as_f32().unwrap(), vec![0.0, 2.0]);
}
#[test]
fn test_executor_diamond() {
let graph = Graph {
name: "diamond".into(),
nodes: vec![
make_node("Relu", "relu", &["X"], &["L"]),
make_node("Neg", "neg", &["X"], &["R"]),
make_node("Add", "add", &["L", "R"], &["Y"]),
],
inputs: vec![TensorInfo {
name: "X".into(),
dtype: DataType::Float32,
shape: TensorShape::fixed(vec![3]),
}],
outputs: vec![TensorInfo {
name: "Y".into(),
dtype: DataType::Float32,
shape: TensorShape::fixed(vec![3]),
}],
initializers: HashMap::new(),
};
let exec = GraphExecutor::new(graph).unwrap();
let mut inputs = HashMap::new();
inputs.insert("X".into(), OnnxTensor::from_f32(&[-1.0, 2.0, 0.0], vec![3]));
let result = exec.run(inputs).unwrap();
let y = result.get("Y").unwrap();
assert_eq!(y.as_f32().unwrap(), vec![1.0, 0.0, 0.0]);
}
#[test]
fn test_executor_skip_connection() {
let graph = Graph {
name: "skip".into(),
nodes: vec![
make_node("Relu", "relu", &["X"], &["T1"]),
make_node("Add", "add", &["T1", "X"], &["Y"]),
],
inputs: vec![TensorInfo {
name: "X".into(),
dtype: DataType::Float32,
shape: TensorShape::fixed(vec![3]),
}],
outputs: vec![TensorInfo {
name: "Y".into(),
dtype: DataType::Float32,
shape: TensorShape::fixed(vec![3]),
}],
initializers: HashMap::new(),
};
let exec = GraphExecutor::new(graph).unwrap();
let mut inputs = HashMap::new();
inputs.insert("X".into(), OnnxTensor::from_f32(&[-2.0, 3.0, 1.0], vec![3]));
let result = exec.run(inputs).unwrap();
let y = result.get("Y").unwrap();
assert_eq!(y.as_f32().unwrap(), vec![-2.0, 6.0, 2.0]);
}
#[test]
fn test_executor_with_initializer() {
let mut initializers = HashMap::new();
initializers.insert("W".into(), OnnxTensor::from_f32(&[10.0, 20.0], vec![2]));
let graph = Graph {
name: "with_init".into(),
nodes: vec![make_node("Add", "add", &["X", "W"], &["Y"])],
inputs: vec![make_info("X")],
outputs: vec![make_info("Y")],
initializers,
};
let exec = GraphExecutor::new(graph).unwrap();
let mut inputs = HashMap::new();
inputs.insert("X".into(), OnnxTensor::from_f32(&[1.0, 2.0], vec![2]));
let result = exec.run(inputs).unwrap();
let y = result.get("Y").unwrap();
assert_eq!(y.as_f32().unwrap(), vec![11.0, 22.0]);
}
#[test]
fn test_empty_graph() {
let graph = Graph {
name: "empty".into(),
nodes: vec![],
inputs: vec![],
outputs: vec![],
initializers: HashMap::new(),
};
let exec = GraphExecutor::new(graph).unwrap();
let result = exec.run(HashMap::new()).unwrap();
assert!(result.is_empty());
}
}