use crate::{
error::{OnnxError, Result},
graph::Graph,
operators,
tensor::Tensor,
};
use std::collections::HashMap;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Executes a [`Graph`] against a set of named input tensors.
///
/// The type is cheap to copy around (two plain fields), so it derives `Clone`;
/// this lets a caller keep a runtime after handing a copy to a consuming
/// method. `Debug` is derived so the configuration can be logged.
#[derive(Debug, Clone)]
pub struct Runtime {
    /// When `true`, `execute` emits `log::debug!` traces for the run.
    debug: bool,
    /// Requested concurrency limit, set through `with_max_concurrency`.
    /// NOTE(review): `execute` in this file never reads this value — confirm
    /// whether it is meant to bound the parallel path.
    max_concurrency: usize,
}
/// Mutable state for a single graph run: the named tensors produced or
/// supplied so far, plus the statistics gathered along the way.
pub struct ExecutionContext {
/// Tensors keyed by name; inserting an existing name replaces the old tensor.
tensors: HashMap<String, Tensor>,
/// Counters and timings updated while tensors are added and nodes run.
stats: ExecutionStats,
}
/// Counters and timings collected while a graph is executed.
///
/// `Clone` and `PartialEq` are derived so that a caller holding the borrowed
/// statistics of an `ExecutionContext` can keep an owned snapshot and compare
/// snapshots with each other.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ExecutionStats {
    /// Wall-clock duration of the run, in milliseconds.
    pub total_time_ms: f64,
    /// Number of graph nodes whose outputs were stored.
    pub ops_executed: usize,
    /// Running byte estimate for the stored tensors; `ExecutionContext::add_tensor`
    /// maintains it as `tensor.len() * size_of::<f32>()` per tensor.
    pub memory_usage_bytes: usize,
    /// Per-operator timings in milliseconds.
    /// NOTE(review): nothing in this file writes to this map — confirm whether
    /// it is populated elsewhere or still to be implemented.
    pub op_times: HashMap<String, f64>,
}
impl Runtime {
    /// Creates a runtime with debug logging off and `max_concurrency` set to 1.
    pub fn new() -> Self {
        Self {
            debug: false,
            max_concurrency: 1,
        }
    }

    /// Creates a runtime with debug logging on and `max_concurrency` set to 1.
    pub fn with_debug() -> Self {
        Self {
            debug: true,
            max_concurrency: 1,
        }
    }

    /// Builder-style setter for `max_concurrency`; returns the updated runtime.
    ///
    /// NOTE(review): the value is only stored. `execute` below does not read
    /// it, and the `parallel` path calls `into_par_iter` without configuring
    /// a thread count — confirm the intended effect.
    pub fn with_max_concurrency(mut self, max_concurrency: usize) -> Self {
        self.max_concurrency = max_concurrency;
        self
    }

    /// Runs `graph` on `inputs` and returns the tensors named in `graph.outputs`.
    ///
    /// The graph is processed level by level, using the node-index groups
    /// returned by `graph.topological_levels()`. For each level the input
    /// tensors of every node are cloned out of the context first, then all
    /// nodes of the level are executed (through rayon when the `parallel`
    /// feature is enabled, sequentially otherwise), and finally their outputs
    /// are stored back into the context.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// * `graph.validate()` or `graph.topological_levels()` fails,
    /// * a declared graph input is missing or does not match its spec,
    /// * a node references a tensor that is not in the context,
    /// * an operator cannot be resolved or fails to execute,
    /// * a node yields a different number of outputs than it declares,
    /// * a declared graph output was never produced.
    pub fn execute(
        &self,
        graph: &Graph,
        inputs: HashMap<String, Tensor>,
    ) -> Result<HashMap<String, Tensor>> {
        let start_time = std::time::Instant::now();
        if self.debug {
            log::debug!("Starting execution of graph '{}'", graph.name);
        }

        graph.validate()?;
        self.validate_inputs(graph, &inputs)?;

        let mut context = ExecutionContext::new();
        context.add_tensors(inputs);
        // NOTE(review): initializers are inserted after the caller's inputs,
        // so an initializer replaces a caller-supplied tensor of the same
        // name. ONNX treats an initializer that shares a name with a graph
        // input as a default value — confirm this precedence is intended.
        for (name, tensor) in &graph.initializers {
            context.add_tensor(name.clone(), tensor.clone());
        }

        let levels = graph.topological_levels()?;
        // Copied into a local so the per-node closure does not capture `self`.
        let debug = self.debug;

        for level_nodes in &levels {
            // Phase 1: snapshot the inputs of every node in this level. The
            // context is only read here, so the owned work items can be handed
            // to the (possibly parallel) execution phase.
            let work: Vec<(usize, Vec<Tensor>)> = level_nodes
                .iter()
                .map(|&node_idx| {
                    let node = &graph.nodes[node_idx];
                    let inputs = node
                        .inputs
                        .iter()
                        .map(|name| {
                            context.get_tensor(name).cloned().ok_or_else(|| {
                                OnnxError::runtime_error(format!(
                                    "Node '{}' references unknown tensor '{}'",
                                    node.name, name
                                ))
                            })
                        })
                        .collect::<Result<Vec<_>>>()?;
                    Ok((node_idx, inputs))
                })
                .collect::<Result<Vec<_>>>()?;

            // Phase 2: execute one node. Errors are returned alongside the
            // node index rather than propagated, so they can be surfaced in
            // order once the whole level has been collected.
            let run = |(node_idx, inputs): (usize, Vec<Tensor>)| -> (usize, Result<Vec<Tensor>>) {
                let node = &graph.nodes[node_idx];
                if debug {
                    log::debug!("Executing node '{}' ({})", node.name, node.op_type);
                    for (i, t) in inputs.iter().enumerate() {
                        log::debug!("  Input {}: shape {:?}", i, t.shape());
                    }
                }
                let result = node.get_operator_type().and_then(|op_type| {
                    operators::execute_operator(&op_type, &inputs, &node.attributes).map_err(|e| {
                        OnnxError::runtime_error(format!(
                            "Failed to execute {:?} ({}): {}",
                            op_type, node.name, e
                        ))
                    })
                });
                (node_idx, result)
            };

            #[cfg(feature = "parallel")]
            let results: Vec<(usize, Result<Vec<Tensor>>)> =
                work.into_par_iter().map(run).collect();
            #[cfg(not(feature = "parallel"))]
            let results: Vec<(usize, Result<Vec<Tensor>>)> = work.into_iter().map(run).collect();

            // Phase 3: publish the outputs; the first failing node aborts the run.
            for (node_idx, outputs_result) in results {
                let node = &graph.nodes[node_idx];
                let output_tensors = outputs_result?;
                if output_tensors.len() != node.outputs.len() {
                    return Err(OnnxError::runtime_error(format!(
                        "Node '{}' produced {} outputs but expected {}",
                        node.name,
                        output_tensors.len(),
                        node.outputs.len()
                    )));
                }
                if debug {
                    for (i, t) in output_tensors.iter().enumerate() {
                        log::debug!("  Output {}: shape {:?}", i, t.shape());
                    }
                }
                for (name, tensor) in node.outputs.iter().zip(output_tensors) {
                    context.add_tensor(name.clone(), tensor);
                }
                context.stats.ops_executed += 1;
            }
        }

        let outputs = self.extract_outputs(graph, &context)?;

        // `as_millis()` yields whole milliseconds, which would make the
        // two-decimal log line below (and any average derived from this field)
        // meaningless for short runs. Keep the fractional part instead.
        context.stats.total_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
        if self.debug {
            log::debug!(
                "Execution completed in {:.2}ms",
                context.stats.total_time_ms
            );
            log::debug!("Operations executed: {}", context.stats.ops_executed);
        }
        Ok(outputs)
    }

    /// Runs [`Runtime::execute`] on tokio's blocking thread pool.
    ///
    /// Takes `self`, `graph` and `inputs` by value because the closure passed
    /// to `spawn_blocking` must own everything it uses.
    ///
    /// # Errors
    ///
    /// Returns the error from `execute`, or a runtime error carrying the join
    /// error's message if the blocking task could not be joined.
    #[cfg(feature = "async")]
    pub async fn execute_async(
        self,
        graph: Graph,
        inputs: HashMap<String, Tensor>,
    ) -> Result<HashMap<String, Tensor>> {
        tokio::task::spawn_blocking(move || self.execute(&graph, inputs))
            .await
            .map_err(|e| OnnxError::runtime_error(e.to_string()))?
    }

    /// Checks that every input declared by `graph` is present in `inputs` and
    /// is accepted by its spec's `matches_tensor`.
    ///
    /// Entries in `inputs` that the graph does not declare are not inspected.
    /// In the shape-mismatch error, dimensions the spec leaves unspecified
    /// (`None`) are reported as `0`.
    fn validate_inputs(&self, graph: &Graph, inputs: &HashMap<String, Tensor>) -> Result<()> {
        for input_spec in &graph.inputs {
            let tensor = inputs.get(&input_spec.name).ok_or_else(|| {
                OnnxError::runtime_error(format!("Missing required input: {}", input_spec.name))
            })?;
            if !input_spec.matches_tensor(tensor) {
                let expected_dims = input_spec
                    .dimensions
                    .iter()
                    .map(|dim| dim.unwrap_or(0))
                    .collect::<Vec<_>>();
                return Err(OnnxError::shape_mismatch(&expected_dims, tensor.shape()));
            }
        }
        Ok(())
    }

    /// Clones the tensors named in `graph.outputs` out of `context`.
    ///
    /// # Errors
    ///
    /// Fails on the first declared output that is absent from the context.
    fn extract_outputs(
        &self,
        graph: &Graph,
        context: &ExecutionContext,
    ) -> Result<HashMap<String, Tensor>> {
        graph
            .outputs
            .iter()
            .map(|output_spec| {
                context
                    .get_tensor(&output_spec.name)
                    .map(|tensor| (output_spec.name.clone(), tensor.clone()))
                    .ok_or_else(|| {
                        OnnxError::runtime_error(format!(
                            "Graph output '{}' not found in execution context",
                            output_spec.name
                        ))
                    })
            })
            .collect()
    }
}
impl Default for Runtime {
fn default() -> Self {
Self::new()
}
}
impl ExecutionContext {
    /// Creates a context with no tensors and zeroed statistics.
    pub fn new() -> Self {
        let tensors = HashMap::new();
        let stats = ExecutionStats::default();
        Self { tensors, stats }
    }

    /// Stores `tensor` under `name`, replacing any tensor already stored there,
    /// and keeps the memory estimate in step.
    ///
    /// The estimate counts `len() * size_of::<f32>()` bytes per tensor; the
    /// share of a replaced tensor is removed before the new one is added.
    pub fn add_tensor(&mut self, name: String, tensor: Tensor) {
        const ELEM_BYTES: usize = std::mem::size_of::<f32>();

        // Zero when the name is new, which makes the subtraction below a no-op.
        let released = self
            .tensors
            .get(&name)
            .map_or(0, |previous| previous.len() * ELEM_BYTES);
        let after_release = self.stats.memory_usage_bytes.saturating_sub(released);

        self.stats.memory_usage_bytes = after_release + tensor.len() * ELEM_BYTES;
        self.tensors.insert(name, tensor);
    }

    /// Stores every entry of `tensors` through [`ExecutionContext::add_tensor`].
    pub fn add_tensors(&mut self, tensors: HashMap<String, Tensor>) {
        tensors
            .into_iter()
            .for_each(|(name, tensor)| self.add_tensor(name, tensor));
    }

    /// Looks up the tensor stored under `name`, if any.
    pub fn get_tensor(&self, name: &str) -> Option<&Tensor> {
        let found = self.tensors.get(name);
        found
    }

    /// Lists the names of all stored tensors, in the map's iteration order.
    pub fn tensor_names(&self) -> Vec<&str> {
        self.tensors.keys().map(String::as_str).collect()
    }

    /// Borrows the statistics gathered so far.
    pub fn stats(&self) -> &ExecutionStats {
        &self.stats
    }
}
impl Default for ExecutionContext {
fn default() -> Self {
Self::new()
}
}
impl ExecutionStats {
    /// Mean time per executed operation, in milliseconds.
    ///
    /// Returns `0.0` when no operation has been executed, avoiding a division
    /// by zero.
    pub fn avg_op_time(&self) -> f64 {
        match self.ops_executed {
            0 => 0.0,
            count => self.total_time_ms / count as f64,
        }
    }

    /// The memory estimate expressed in mebibytes (bytes / 1024 / 1024).
    pub fn memory_usage_mb(&self) -> f64 {
        const STEP: f64 = 1024.0;
        let kibibytes = self.memory_usage_bytes as f64 / STEP;
        kibibytes / STEP
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Graph, Tensor};
    use ndarray::Array1;
    #[cfg(feature = "async")]
    use ndarray::Array2;

    /// Both constructors set the documented defaults.
    #[test]
    fn test_runtime_creation() {
        let runtime = Runtime::new();
        assert!(!runtime.debug);
        assert_eq!(runtime.max_concurrency, 1);

        let debug_runtime = Runtime::with_debug();
        assert!(debug_runtime.debug);
    }

    /// A non-default configuration is actually stored, not just the defaults.
    #[test]
    fn test_runtime_with_custom_config() {
        let defaults = Runtime::with_debug();
        assert!(defaults.debug);
        assert_eq!(defaults.max_concurrency, 1);

        let tuned = Runtime::with_debug().with_max_concurrency(4);
        assert!(tuned.debug);
        assert_eq!(tuned.max_concurrency, 4);
    }

    /// A stored tensor can be looked up by name; unknown names yield `None`.
    #[test]
    fn test_execution_context() {
        let mut context = ExecutionContext::new();
        let tensor = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
        context.add_tensor("test".to_string(), tensor.clone());

        assert!(context.get_tensor("test").is_some());
        assert!(context.get_tensor("missing").is_none());
        assert_eq!(context.tensor_names(), vec!["test"]);

        let retrieved = context.get_tensor("test").unwrap();
        assert_eq!(retrieved.shape(), tensor.shape());
    }

    /// Several tensors can coexist under distinct names.
    #[test]
    fn test_execution_context_multiple_tensors() {
        let mut context = ExecutionContext::new();
        let tensor1 = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
        let tensor2 = Tensor::from_array(Array1::from_vec(vec![4.0, 5.0, 6.0]));
        context.add_tensor("tensor1".to_string(), tensor1);
        context.add_tensor("tensor2".to_string(), tensor2);

        // The backing map has no defined order, so sort before comparing.
        let mut names = context.tensor_names();
        names.sort();
        assert_eq!(names, vec!["tensor1", "tensor2"]);

        assert!(context.get_tensor("tensor1").is_some());
        assert!(context.get_tensor("tensor2").is_some());
    }

    /// End-to-end run of the sample linear graph with debug logging enabled.
    #[test]
    fn test_simple_execution() {
        // Another test may already have installed the logger; ignore that.
        env_logger::try_init().ok();
        let runtime = Runtime::with_debug();
        let graph = Graph::create_simple_linear();

        let mut inputs = HashMap::new();
        inputs.insert(
            "input".to_string(),
            Tensor::from_shape_vec(&[1, 3], vec![1.0, 2.0, 3.0]).unwrap(),
        );

        let outputs = runtime.execute(&graph, inputs).unwrap();
        assert!(outputs.contains_key("output"));

        let output = outputs.get("output").unwrap();
        assert_eq!(output.shape(), &[1, 2]);

        let data = output.data();
        let expected = [1.3, 3.1];
        for (actual, &expected) in data.iter().zip(expected.iter()) {
            assert!(
                (actual - expected).abs() < 1e-6,
                "Expected {expected}, got {actual}"
            );
        }
    }

    /// The same graph also runs with debug logging disabled.
    #[test]
    fn test_runtime_non_debug_execution() {
        let runtime = Runtime::new();
        let graph = Graph::create_simple_linear();

        let mut inputs = HashMap::new();
        inputs.insert(
            "input".to_string(),
            Tensor::from_shape_vec(&[1, 3], vec![1.0, 2.0, 3.0]).unwrap(),
        );

        let outputs = runtime.execute(&graph, inputs).unwrap();
        assert!(outputs.contains_key("output"));
    }

    /// Omitting a declared graph input is reported as an error.
    #[test]
    fn test_missing_input() {
        let runtime = Runtime::new();
        let graph = Graph::create_simple_linear();
        let inputs = HashMap::new();

        let result = runtime.execute(&graph, inputs);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Missing required input"));
    }

    /// An input whose shape disagrees with the graph's spec is rejected.
    #[test]
    fn test_input_shape_validation_error() {
        let runtime = Runtime::new();
        let graph = Graph::create_simple_linear();

        let mut inputs = HashMap::new();
        inputs.insert(
            "input".to_string(),
            Tensor::from_shape_vec(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
        );

        let result = runtime.execute(&graph, inputs);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Shape mismatch") || error_msg.contains("shape"));
    }

    /// A node that consumes tensors nobody produces makes execution fail.
    #[test]
    fn test_unknown_tensor_reference_error() {
        let runtime = Runtime::new();
        let mut graph = Graph::new("invalid_graph".to_string());

        let node = crate::graph::Node::new(
            "invalid_node".to_string(),
            "Add".to_string(),
            vec![
                "nonexistent_tensor".to_string(),
                "another_nonexistent".to_string(),
            ],
            vec!["output".to_string()],
        );
        graph.add_node(node);

        let input_spec = crate::graph::TensorSpec {
            name: "input".to_string(),
            dtype: "float32".to_string(),
            dimensions: vec![Some(1), Some(3)],
        };
        graph.add_input(input_spec);

        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), Tensor::zeros(&[1, 3]));

        let result = runtime.execute(&graph, inputs);
        assert!(result.is_err());
        // The failure may come from graph validation or from the executor, so
        // accept either wording.
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("unknown tensor") || error_msg.contains("nonexistent"));
    }

    /// Relu -> Add chain: the intermediate tensor is produced by one node and
    /// consumed (twice) by the next.
    #[test]
    fn test_execution_with_intermediate_tensors() {
        let runtime = Runtime::with_debug();
        let mut graph = Graph::new("complex_graph".to_string());

        let input_spec = crate::graph::TensorSpec {
            name: "input".to_string(),
            dtype: "float32".to_string(),
            dimensions: vec![Some(1), Some(2)],
        };
        graph.add_input(input_spec);

        let output_spec = crate::graph::TensorSpec {
            name: "output".to_string(),
            dtype: "float32".to_string(),
            dimensions: vec![Some(1), Some(2)],
        };
        graph.add_output(output_spec);

        let relu_node = crate::graph::Node::new(
            "relu".to_string(),
            "Relu".to_string(),
            vec!["input".to_string()],
            vec!["intermediate".to_string()],
        );
        graph.add_node(relu_node);

        let add_node = crate::graph::Node::new(
            "add".to_string(),
            "Add".to_string(),
            vec!["intermediate".to_string(), "intermediate".to_string()],
            vec!["output".to_string()],
        );
        graph.add_node(add_node);

        let mut inputs = HashMap::new();
        inputs.insert(
            "input".to_string(),
            Tensor::from_shape_vec(&[1, 2], vec![-1.0, 2.0]).unwrap(),
        );

        let outputs = runtime.execute(&graph, inputs).unwrap();
        assert!(outputs.contains_key("output"));

        // relu([-1, 2]) = [0, 2]; adding it to itself gives [0, 4].
        let output = outputs.get("output").unwrap();
        let data = output.data();
        assert!((data[[0, 0]] - 0.0).abs() < 1e-6);
        assert!((data[[0, 1]] - 4.0).abs() < 1e-6);
    }

    /// The async wrapper returns the same outputs as the blocking call.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_execution() {
        let runtime = Runtime::new();
        let graph = Graph::create_simple_linear();

        let mut inputs = HashMap::new();
        inputs.insert(
            "input".to_string(),
            Tensor::from_array(
                Array2::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0])
                    .unwrap()
                    .into_dyn(),
            ),
        );

        let outputs = runtime.execute_async(graph, inputs).await.unwrap();
        assert!(outputs.contains_key("output"));
    }

    /// Errors from the blocking call surface through the async wrapper.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_execution_error() {
        let runtime = Runtime::new();
        let graph = Graph::create_simple_linear();
        let inputs = HashMap::new();

        let result = runtime.execute_async(graph, inputs).await;
        assert!(result.is_err());
    }

    /// Derived figures are computed from the raw counters.
    #[test]
    fn test_execution_stats() {
        let stats = ExecutionStats {
            total_time_ms: 100.0,
            ops_executed: 5,
            memory_usage_bytes: 1024 * 1024,
            ..Default::default()
        };
        assert_eq!(stats.avg_op_time(), 20.0);
        assert_eq!(stats.memory_usage_mb(), 1.0);
    }

    /// With zero executed operations the average is 0 rather than NaN/inf.
    #[test]
    fn test_execution_stats_zero_ops() {
        let stats = ExecutionStats {
            total_time_ms: 100.0,
            ops_executed: 0,
            memory_usage_bytes: 0,
            ..Default::default()
        };
        assert_eq!(stats.avg_op_time(), 0.0);
        assert_eq!(stats.memory_usage_mb(), 0.0);
    }

    /// Default statistics are all zero.
    #[test]
    fn test_execution_stats_default() {
        let stats = ExecutionStats::default();
        assert_eq!(stats.total_time_ms, 0.0);
        assert_eq!(stats.ops_executed, 0);
        assert_eq!(stats.memory_usage_bytes, 0);
        assert_eq!(stats.avg_op_time(), 0.0);
        assert_eq!(stats.memory_usage_mb(), 0.0);
    }

    /// Constructors and the chained setter compose; the last setter call wins.
    #[test]
    fn test_runtime_builder_pattern() {
        let runtime = Runtime::new();
        assert!(!runtime.debug);

        let runtime2 = Runtime::with_debug();
        assert!(runtime2.debug);

        let chained = Runtime::new()
            .with_max_concurrency(2)
            .with_max_concurrency(8);
        assert!(!chained.debug);
        assert_eq!(chained.max_concurrency, 8);
    }

    /// NOTE(review): despite the name, this feeds the same 1x3 input as the
    /// other execution tests — confirm whether a larger batch was intended.
    #[test]
    fn test_large_batch_execution() {
        let runtime = Runtime::new();
        let graph = Graph::create_simple_linear();

        let mut inputs = HashMap::new();
        inputs.insert(
            "input".to_string(),
            Tensor::from_shape_vec(&[1, 3], vec![1.0, 2.0, 3.0]).unwrap(),
        );

        let outputs = runtime.execute(&graph, inputs).unwrap();
        assert!(outputs.contains_key("output"));

        let output = outputs.get("output").unwrap();
        assert_eq!(output.shape(), &[1, 2]);
    }

    /// Adding a tensor under an existing name replaces the stored value.
    #[test]
    fn test_execution_context_tensor_overwrite() {
        let mut context = ExecutionContext::new();
        let tensor1 = Tensor::from_array(Array1::from_vec(vec![1.0, 2.0, 3.0]));
        let tensor2 = Tensor::from_array(Array1::from_vec(vec![4.0, 5.0, 6.0]));

        context.add_tensor("test".to_string(), tensor1);
        context.add_tensor("test".to_string(), tensor2);

        let retrieved = context.get_tensor("test").unwrap();
        let data = retrieved.data();
        assert!((data[0] - 4.0).abs() < 1e-6);
        assert!((data[1] - 5.0).abs() < 1e-6);
        assert!((data[2] - 6.0).abs() < 1e-6);
    }
}