use rustorch::execution::{DynamicOp, RuntimeConfig, RuntimeEngine};
use rustorch::tensor::Tensor;
use std::time::Duration;
/// Builds a single bias-free linear layer on a [2, 3] input with a [4, 3]
/// weight and checks that the result has shape [2, 4].
#[test]
fn test_basic_dynamic_execution() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());
    let result = engine.execute_graph(|builder| {
        let x = builder.add_input(Tensor::ones(&[2, 3]))?;
        let w = builder.add_parameter(Tensor::ones(&[4, 3]))?;
        let y = builder.linear(x, w, None)?;
        Ok(y)
    });
    let output = result
        .unwrap_or_else(|e| panic!("Basic dynamic execution test failed with error: {:?}", e));
    assert_eq!(output.shape(), &[2, 4]);
}
/// Runs a three-layer MLP (784 -> 512 -> 256 -> 10, ReLU between layers) on a
/// batch of 16 and checks the output shape plus the execution metrics.
#[test]
fn test_complex_neural_network() {
    let config = RuntimeConfig::default();
    let mut engine = RuntimeEngine::<f32>::new(config);
    let result = engine.execute_graph(|builder| {
        let input = builder.add_input(Tensor::ones(&[16, 784]))?;

        let w1 = builder.add_parameter(Tensor::ones(&[512, 784]))?;
        let b1 = builder.add_parameter(Tensor::ones(&[512]))?;
        let h1 = builder.linear(input, w1, Some(b1))?;
        let a1 = builder.relu(h1)?;

        let w2 = builder.add_parameter(Tensor::ones(&[256, 512]))?;
        let b2 = builder.add_parameter(Tensor::ones(&[256]))?;
        let h2 = builder.linear(a1, w2, Some(b2))?;
        let a2 = builder.relu(h2)?;

        let w3 = builder.add_parameter(Tensor::ones(&[10, 256]))?;
        let b3 = builder.add_parameter(Tensor::ones(&[10]))?;
        let output = builder.linear(a2, w3, Some(b3))?;
        Ok(output)
    });
    // `expect` puts the engine's error into the failure message, which a bare
    // `assert!(result.is_ok())` would discard.
    let output = result.expect("three-layer MLP execution failed");
    assert_eq!(output.shape(), &[16, 10]);

    let metrics = engine.get_metrics();
    assert_eq!(metrics.total_executions, 1);
    assert!(metrics.avg_execution_time > Duration::from_nanos(0));
}
/// Runs two conv + ReLU stages (3 -> 16 -> 32 channels, 3x3 weights) on a
/// [2, 3, 32, 32] batch, flattens, and applies a linear layer to 10 outputs.
#[test]
fn test_convolutional_network() {
    let config = RuntimeConfig::default();
    let mut engine = RuntimeEngine::<f32>::new(config);
    let result = engine.execute_graph(|builder| {
        let input = builder.add_input(Tensor::ones(&[2, 3, 32, 32]))?;

        // NOTE(review): the trailing tuples are presumably (kernel, stride,
        // padding); the flatten below relies on the 32x32 spatial size being
        // preserved by both convolutions.
        let conv1_weight = builder.add_parameter(Tensor::ones(&[16, 3, 3, 3]))?;
        let conv1 = builder.conv2d(input, conv1_weight, (3, 3), (1, 1), (1, 1))?;
        let conv1_relu = builder.relu(conv1)?;

        let conv2_weight = builder.add_parameter(Tensor::ones(&[32, 16, 3, 3]))?;
        let conv2 = builder.conv2d(conv1_relu, conv2_weight, (3, 3), (1, 1), (1, 1))?;
        let conv2_relu = builder.relu(conv2)?;

        // 32 channels * 32 * 32 features per sample.
        let flattened = builder.reshape(conv2_relu, vec![2, 32 * 32 * 32])?;
        let linear_weight = builder.add_parameter(Tensor::ones(&[10, 32 * 32 * 32]))?;
        let linear_bias = builder.add_parameter(Tensor::ones(&[10]))?;
        let output = builder.linear(flattened, linear_weight, Some(linear_bias))?;
        Ok(output)
    });
    // `expect` reports the engine error instead of a bare "is_ok" failure.
    let output = result.expect("convolutional network execution failed");
    assert_eq!(output.shape(), &[2, 10]);
}
/// Runs the same three-ReLU graph five times with JIT enabled and
/// `jit_threshold: 3`. It then expects at least one JIT compilation and
/// five recorded executions.
#[test]
fn test_jit_compilation() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig {
        enable_jit: true,
        jit_threshold: 3,
        ..Default::default()
    });
    for i in 0..5 {
        let result = engine.execute_graph(|builder| {
            let x = builder.add_input(Tensor::ones(&[4, 4]))?;
            let x = builder.relu(x)?;
            let x = builder.relu(x)?;
            let out = builder.relu(x)?;
            Ok(out)
        });
        assert!(result.is_ok(), "Execution {} failed", i);
    }
    let metrics = engine.get_metrics();
    assert!(metrics.jit_stats.total_compilations > 0);
    assert_eq!(metrics.total_executions, 5);
}
/// With memory optimization enabled, runs square matmul + ReLU graphs of
/// increasing size. Checks each output shape, that allocations were recorded,
/// and that four executions were counted.
#[test]
fn test_memory_optimization() {
    let config = RuntimeConfig {
        enable_memory_opt: true,
        ..Default::default()
    };
    let mut engine = RuntimeEngine::<f32>::new(config);
    for &size in &[10, 50, 100, 200] {
        let result = engine.execute_graph(|builder| {
            let input = builder.add_input(Tensor::ones(&[size, size]))?;
            let weight = builder.add_parameter(Tensor::ones(&[size, size]))?;
            let matmul = builder.matmul(input, weight)?;
            let output = builder.relu(matmul)?;
            Ok(output)
        });
        // Report which size failed, and why, instead of a bare "is_ok" failure.
        let output = result.unwrap_or_else(|e| panic!("size {} failed: {:?}", size, e));
        // [size, size] x [size, size] stays [size, size]; ReLU is element-wise.
        assert_eq!(output.shape(), &[size, size]);
    }
    let metrics = engine.get_metrics();
    assert!(metrics.memory_stats.allocations > 0);
    assert_eq!(metrics.total_executions, 4);
}
/// With parallel execution enabled, builds three independent ReLU branches
/// over ones and sums them. Checks the output shape and that every element is 3.
#[test]
fn test_parallel_execution() {
    let config = RuntimeConfig {
        enable_parallel: true,
        ..Default::default()
    };
    let mut engine = RuntimeEngine::<f32>::new(config);
    let result = engine.execute_graph(|builder| {
        let input1 = builder.add_input(Tensor::ones(&[8, 8]))?;
        let input2 = builder.add_input(Tensor::ones(&[8, 8]))?;
        let input3 = builder.add_input(Tensor::ones(&[8, 8]))?;
        let relu1 = builder.relu(input1)?;
        let relu2 = builder.relu(input2)?;
        let relu3 = builder.relu(input3)?;
        let add1 = builder.add(relu1, relu2)?;
        let final_output = builder.add(add1, relu3)?;
        Ok(final_output)
    });
    let output = result.expect("parallel graph execution failed");
    assert_eq!(output.shape(), &[8, 8]);
    // Each branch is relu(1) = 1, so each element of the sum must be 3.
    // A shape-only check would not notice a dropped or duplicated branch.
    if let Some(slice) = output.as_slice() {
        for &value in slice {
            assert!((value - 3.0).abs() < 1e-6, "expected 3.0, got {}", value);
        }
    }
    // Smoke check: metrics remain retrievable after a parallel run.
    let _metrics = engine.get_metrics();
}
/// With fusion enabled, applies two back-to-back ReLUs to a 4x4 tensor of
/// ones and checks the result is still a 4x4 tensor of ones.
#[test]
fn test_operation_fusion() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig {
        enable_fusion: true,
        ..Default::default()
    });
    let result = engine.execute_graph(|builder| {
        let ones = builder.add_input(Tensor::ones(&[4, 4]))?;
        let once = builder.relu(ones)?;
        let twice = builder.relu(once)?;
        Ok(twice)
    });
    assert!(result.is_ok());
    let output = result.unwrap();
    assert_eq!(output.shape(), &[4, 4]);
    if let Some(slice) = output.as_slice() {
        slice
            .iter()
            .for_each(|&value| assert!((value - 1.0).abs() < 1e-6));
    }
}
/// Runs the same 3x3 matmul graph five times with `max_cache_size: 100`.
/// Checks the output shape, the execution count, and that the reported cache
/// hit rate is a valid ratio.
#[test]
fn test_caching_effectiveness() {
    let config = RuntimeConfig {
        max_cache_size: 100,
        ..Default::default()
    };
    let mut engine = RuntimeEngine::<f32>::new(config);
    let pattern_executions = 5;
    for _ in 0..pattern_executions {
        let result = engine.execute_graph(|builder| {
            let input = builder.add_input(Tensor::ones(&[3, 3]))?;
            let weight = builder.add_parameter(Tensor::ones(&[3, 3]))?;
            let output = builder.matmul(input, weight)?;
            Ok(output)
        });
        let output = result.expect("3x3 matmul execution failed");
        assert_eq!(output.shape(), &[3, 3]);
    }
    let metrics = engine.get_metrics();
    assert_eq!(metrics.total_executions, pattern_executions);
    // A hit rate is a ratio, so bound it on both sides, not just below.
    assert!(
        (0.0..=1.0).contains(&metrics.cache_hit_rate),
        "cache_hit_rate out of range: {}",
        metrics.cache_hit_rate
    );
}
/// Multiplies a [2, 3] input by a [5, 6] parameter. The inner dimensions
/// (3 vs 5) disagree, so `execute_graph` is expected to return an error.
#[test]
fn test_error_handling() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());
    let result = engine.execute_graph(|builder| {
        let lhs = builder.add_input(Tensor::ones(&[2, 3]))?;
        let rhs = builder.add_parameter(Tensor::ones(&[5, 6]))?;
        let product = builder.matmul(lhs, rhs)?;
        Ok(product)
    });
    assert!(result.is_err());
}
/// Profiles three executions and checks that the textual summary mentions
/// the run count and an average time.
#[test]
fn test_profiling_functionality() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());
    let runs = 3;
    let profile_result = engine.profile_execution(runs);
    assert!(profile_result.is_ok());
    let report = profile_result.unwrap();
    let summary = report.summary();
    assert!(summary.contains("Executions: 3"));
    assert!(summary.contains("Average time:"));
}
/// Checks that `warmup` succeeds on a default engine and that afterwards the
/// metrics report at least one JIT compilation.
#[test]
fn test_warmup_functionality() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());

    let warmup_result = engine.warmup();
    assert!(warmup_result.is_ok());

    let metrics = engine.get_metrics();
    assert!(metrics.jit_stats.total_compilations > 0);
}
/// Runs five ReLU graphs with distinct square input shapes (2x2 through 6x6)
/// on an engine configured with `max_cache_size: 2`. Then checks that
/// `cleanup_cache` leaves at most two entries.
#[test]
fn test_cache_cleanup() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig {
        max_cache_size: 2,
        ..Default::default()
    });
    for side in 2..7 {
        let _result = engine
            .execute_graph(|builder| {
                let node = builder.add_input(Tensor::ones(&[side, side]))?;
                let activated = builder.relu(node)?;
                Ok(activated)
            })
            .unwrap();
    }
    engine.cleanup_cache();
    assert!(engine.execution_cache.len() <= 2);
}
/// Computes (ones + twos) * halves over 3x3 inputs, with the multiply built
/// directly via `add_operation(DynamicOp::Mul, ..)`. Every element is
/// expected to be 1.5.
#[test]
fn test_multi_input_operations() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());
    let result = engine.execute_graph(|builder| {
        let ones = builder.add_input(Tensor::ones(&[3, 3]))?;
        let twos = builder.add_input(Tensor::from_vec(vec![2.0; 9], vec![3, 3]))?;
        let sum = builder.add(ones, twos)?;
        let halves = builder.add_input(Tensor::from_vec(vec![0.5; 9], vec![3, 3]))?;
        let scaled = builder.add_operation(DynamicOp::Mul, vec![sum, halves])?;
        Ok(scaled)
    });
    assert!(result.is_ok());
    let output = result.unwrap();
    assert_eq!(output.shape(), &[3, 3]);
    if let Some(slice) = output.as_slice() {
        slice
            .iter()
            .for_each(|&value| assert!((value - 1.5).abs() < 1e-6));
    }
}
/// Checks ReLU on [-1, 0, 1, 2] against [0, 0, 1, 2] and sigmoid(0) against 0.5.
#[test]
fn test_activation_functions() {
    let config = RuntimeConfig::default();
    let mut engine = RuntimeEngine::<f32>::new(config);

    let relu_result = engine
        .execute_graph(|builder| {
            let input = builder.add_input(Tensor::from_vec(vec![-1.0, 0.0, 1.0, 2.0], vec![4]))?;
            let output = builder.relu(input)?;
            Ok(output)
        })
        .unwrap();
    assert_eq!(relu_result.shape(), &[4]);
    let expected_relu = [0.0, 0.0, 1.0, 2.0];
    if let Some(slice) = relu_result.as_slice() {
        // `zip` silently stops at the shorter side, so a truncated output would
        // otherwise pass. Pin the length first.
        assert_eq!(slice.len(), expected_relu.len());
        for (actual, expected) in slice.iter().zip(expected_relu.iter()) {
            assert!(
                (actual - expected).abs() < 1e-6,
                "relu: expected {}, got {}",
                expected,
                actual
            );
        }
    }

    let sigmoid_result = engine
        .execute_graph(|builder| {
            let input = builder.add_input(Tensor::from_vec(vec![0.0], vec![1]))?;
            let output = builder.sigmoid(input)?;
            Ok(output)
        })
        .unwrap();
    assert_eq!(sigmoid_result.shape(), &[1]);
    if let Some(slice) = sigmoid_result.as_slice() {
        assert!((slice[0] - 0.5).abs() < 1e-6);
    }
}
/// Reshapes a [2, 3, 4] tensor of ones into [4, 6]. Checks the new shape and
/// that all 24 elements are still present.
#[test]
fn test_reshape_operations() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());
    let result = engine.execute_graph(|builder| {
        let cube = builder.add_input(Tensor::ones(&[2, 3, 4]))?;
        let matrix = builder.reshape(cube, vec![4, 6])?;
        Ok(matrix)
    });
    assert!(result.is_ok());
    let output = result.unwrap();
    assert_eq!(output.shape(), &[4, 6]);
    if let Some(slice) = output.as_slice() {
        assert_eq!(slice.len(), 24);
    }
}
/// Enables JIT (threshold 2), fusion and parallel execution, warms up, then
/// times ten runs of a two-layer MLP. Compares the average of the first three
/// runs against the last three and checks the JIT and execution metrics.
#[test]
fn test_performance_optimization() {
    let config = RuntimeConfig {
        enable_jit: true,
        enable_fusion: true,
        enable_parallel: true,
        jit_threshold: 2,
        ..Default::default()
    };
    let mut engine = RuntimeEngine::<f32>::new(config);
    engine.warmup().expect("warmup failed");

    let mut times = Vec::with_capacity(10);
    for _ in 0..10 {
        let start = std::time::Instant::now();
        let _output = engine
            .execute_graph(|builder| {
                let input = builder.add_input(Tensor::ones(&[32, 64]))?;
                let w1 = builder.add_parameter(Tensor::ones(&[128, 64]))?;
                let h1 = builder.linear(input, w1, None)?;
                let a1 = builder.relu(h1)?;
                let w2 = builder.add_parameter(Tensor::ones(&[32, 128]))?;
                let output = builder.linear(a1, w2, None)?;
                Ok(output)
            })
            .expect("two-layer MLP execution failed");
        times.push(start.elapsed());
    }

    // Average of the first three and the last three runs.
    let early_avg = times[..3].iter().sum::<Duration>() / 3;
    let later_avg = times[7..].iter().sum::<Duration>() / 3;
    println!(
        "Early average: {:?}, Later average: {:?}",
        early_avg, later_avg
    );
    // Deliberately loose bound: it flags only a gross slow-down, not a missing
    // speed-up.
    assert!(
        later_avg <= early_avg * 3,
        "later runs much slower than early runs: {:?} vs {:?}",
        later_avg,
        early_avg
    );

    let metrics = engine.get_metrics();
    assert!(metrics.jit_stats.total_compilations > 0);
    assert_eq!(metrics.total_executions, 10);
}
/// Runs square matmul graphs of size 10, 20, 30, 40 and 50. Checks that the
/// memory statistics record allocations, a non-negative efficiency and a
/// non-zero peak.
#[test]
fn test_memory_efficiency() {
    let config = RuntimeConfig::default();
    let mut engine = RuntimeEngine::<f32>::new(config);
    for size in (10..=50).step_by(10) {
        let _result = engine
            .execute_graph(|builder| {
                let input = builder.add_input(Tensor::ones(&[size, size]))?;
                let weight = builder.add_parameter(Tensor::ones(&[size, size]))?;
                let output = builder.matmul(input, weight)?;
                Ok(output)
            })
            .unwrap_or_else(|e| panic!("matmul of size {} failed: {:?}", size, e));
    }
    let metrics = engine.get_metrics();
    assert!(metrics.memory_stats.allocations > 0);
    assert!(metrics.memory_stats.memory_efficiency >= 0.0);
    assert!(metrics.memory_stats.peak_memory > 0);
}
/// Builds a graph with three independent branches over 5x5 ones (two ReLUs
/// and one sigmoid) joined by two adds. Checks the shape and values of the
/// result.
#[test]
fn test_execution_plan_optimization() {
    let config = RuntimeConfig::default();
    let mut engine = RuntimeEngine::<f32>::new(config);
    let result = engine.execute_graph(|builder| {
        let input1 = builder.add_input(Tensor::ones(&[5, 5]))?;
        let input2 = builder.add_input(Tensor::ones(&[5, 5]))?;
        let input3 = builder.add_input(Tensor::ones(&[5, 5]))?;
        let relu1 = builder.relu(input1)?;
        let relu2 = builder.relu(input2)?;
        let sigmoid1 = builder.sigmoid(input3)?;
        let add1 = builder.add(relu1, relu2)?;
        let final_result = builder.add(add1, sigmoid1)?;
        Ok(final_result)
    });
    let output = result.expect("three-branch graph execution failed");
    assert_eq!(output.shape(), &[5, 5]);
    // relu(1) + relu(1) + sigmoid(1) = 2 + 1 / (1 + e^-1) for every element.
    let expected = 2.0 + 1.0 / (1.0 + (-1.0f32).exp());
    if let Some(slice) = output.as_slice() {
        for &value in slice {
            assert!(
                (value - expected).abs() < 1e-5,
                "expected {}, got {}",
                expected,
                value
            );
        }
    }
}
/// Checks that a failed graph (element-wise add of [2, 3] and [4, 5]) does
/// not leave the engine unusable: a valid graph run afterwards must succeed.
#[test]
fn test_error_recovery() {
    let config = RuntimeConfig::default();
    let mut engine = RuntimeEngine::<f32>::new(config);

    let invalid_result = engine.execute_graph(|builder| {
        let input = builder.add_input(Tensor::ones(&[2, 3]))?;
        let incompatible = builder.add_input(Tensor::ones(&[4, 5]))?;
        // Return the failing `add` node itself, not a hard-coded node id. Then
        // the shape mismatch is exercised even if the engine only detects it at
        // execution time.
        let output = builder.add(input, incompatible)?;
        Ok(output)
    });
    assert!(invalid_result.is_err());

    let valid_result = engine.execute_graph(|builder| {
        let input = builder.add_input(Tensor::ones(&[3, 3]))?;
        let output = builder.relu(input)?;
        Ok(output)
    });
    assert!(valid_result.is_ok());
}
/// Runs a seven-layer MLP (1024 -> 512 -> ... -> 8, each layer Linear with
/// bias followed by ReLU) on a batch of 128. Checks the output shape and that
/// a non-zero average execution time was recorded.
#[test]
fn test_large_scale_execution() {
    let config = RuntimeConfig::default();
    let mut engine = RuntimeEngine::<f32>::new(config);
    let result = engine.execute_graph(|builder| {
        let mut current = builder.add_input(Tensor::ones(&[128, 1024]))?;
        let layer_sizes = [1024, 512, 256, 128, 64, 32, 16, 8];
        // Each adjacent pair (in_features, out_features) defines one layer.
        for pair in layer_sizes.windows(2) {
            let (in_features, out_features) = (pair[0], pair[1]);
            let weight = builder.add_parameter(Tensor::ones(&[out_features, in_features]))?;
            let bias = builder.add_parameter(Tensor::ones(&[out_features]))?;
            let linear = builder.linear(current, weight, Some(bias))?;
            current = builder.relu(linear)?;
        }
        Ok(current)
    });
    let output = result.expect("seven-layer MLP execution failed");
    assert_eq!(output.shape(), &[128, 8]);
    let metrics = engine.get_metrics();
    assert!(metrics.avg_execution_time > Duration::from_nanos(0));
}
/// Chains reshape -> linear -> ReLU -> reshape -> reshape -> linear on a
/// [4, 16, 16] input and checks the final [4, 10] shape.
#[test]
fn test_mixed_operations() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());
    let result = engine.execute_graph(|builder| {
        let grid = builder.add_input(Tensor::ones(&[4, 16, 16]))?;
        let flat = builder.reshape(grid, vec![4, 16 * 16])?;

        let proj_weight = builder.add_parameter(Tensor::ones(&[64, 16 * 16]))?;
        let projected = builder.linear(flat, proj_weight, None)?;
        let hidden = builder.relu(projected)?;

        // Round-trip through [4, 8, 8] back to [4, 64].
        let square = builder.reshape(hidden, vec![4, 8, 8])?;
        let flat_again = builder.reshape(square, vec![4, 64])?;

        let head_weight = builder.add_parameter(Tensor::ones(&[10, 64]))?;
        let logits = builder.linear(flat_again, head_weight, None)?;
        Ok(logits)
    });
    assert!(result.is_ok());
    let output = result.unwrap();
    assert_eq!(output.shape(), &[4, 10]);
}
/// Resets metrics, runs a tiny ReLU graph three times, and checks the
/// execution count, a non-zero average time, and a hit rate within [0, 1].
#[test]
fn test_metrics_accuracy() {
    let mut engine = RuntimeEngine::<f32>::new(RuntimeConfig::default());
    engine.reset_metrics();

    let executions = 3;
    (0..executions).for_each(|_| {
        let _result = engine
            .execute_graph(|builder| {
                let node = builder.add_input(Tensor::ones(&[2, 2]))?;
                let activated = builder.relu(node)?;
                Ok(activated)
            })
            .unwrap();
    });

    let metrics = engine.get_metrics();
    assert_eq!(metrics.total_executions, executions);
    assert!(metrics.avg_execution_time > Duration::from_nanos(0));
    assert!(metrics.cache_hit_rate >= 0.0 && metrics.cache_hit_rate <= 1.0);
}