use crate::device::context::{DeviceContext, DEVICE_MANAGER};
use crate::{DType, Device, Result, Tensor};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use super::config::{CachedOperation, EagerExecutionConfig, ExecutionMetrics, OpSignature};
use super::memory_pool::{MemoryGuard, MemoryPool};
use super::reporting::{CacheStatistics, EagerPerformanceReport};
/// Low-overhead eager execution engine.
///
/// All mutable state lives behind locks, so the engine can be driven
/// through `&self` from multiple threads.
pub struct EagerExecutionEngine {
    // Feature toggles and tuning targets (cache size, overhead target, ...).
    pub(super) config: EagerExecutionConfig,
    // Signature -> metadata for previously executed operations.
    pub(super) op_cache: RwLock<HashMap<OpSignature, CachedOperation>>,
    // Reusable device memory blocks.
    pub(super) memory_pool: MemoryPool,
    // Per-invocation timing records, appended by `execute_operation`.
    pub(super) metrics: Mutex<Vec<ExecutionMetrics>>,
    // Lazily acquired device contexts, keyed by device.
    pub(super) active_contexts: RwLock<HashMap<Device, Arc<dyn DeviceContext>>>,
    // Candidate kernel-fusion chains discovered during execution.
    pub(super) fusion_opportunities: RwLock<Vec<FusionOpportunity>>,
}
/// A candidate chain of operations that could be fused into one kernel.
#[derive(Debug)]
#[allow(dead_code)] // fields are currently only written and read via Debug
pub(super) struct FusionOpportunity {
    // Operation names in fusion order.
    pub(super) operations: Vec<String>,
    // Estimated multiplicative speedup from fusing the chain.
    pub(super) potential_speedup: f64,
    // Estimated bytes of intermediate buffers a fused kernel would avoid.
    pub(super) memory_savings: usize,
}
impl EagerExecutionEngine {
/// Create an engine from `config`, with empty caches, metrics and contexts.
pub fn new(config: EagerExecutionConfig) -> Self {
    // The pool needs its own copy of the config; build it first so the
    // original can then be moved into the engine.
    let memory_pool = MemoryPool::new(config.clone());
    Self {
        config,
        op_cache: RwLock::new(HashMap::new()),
        memory_pool,
        metrics: Mutex::new(Vec::new()),
        active_contexts: RwLock::new(HashMap::new()),
        fusion_opportunities: RwLock::new(Vec::new()),
    }
}
/// Execute a single eager operation through the engine.
///
/// Pipeline: build a cache signature, probe the op cache, optionally
/// reserve pooled memory and pin a device context, run `executor`, then
/// record timing metrics and (optionally) fusion opportunities.
///
/// Returns the output tensor together with the `ExecutionMetrics`
/// collected for this invocation.
///
/// NOTE(review): reads `inputs[0]` for the device — panics on an empty
/// input slice; callers must pass at least one tensor.
pub fn execute_operation<T, F>(
    &self,
    operation: &str,
    inputs: &[&Tensor<T>],
    params: &HashMap<String, String>,
    executor: F,
) -> Result<(Tensor<T>, ExecutionMetrics)>
where
    T: Clone + Send + Sync + 'static,
    F: FnOnce(&[&Tensor<T>]) -> Result<Tensor<T>>,
{
    let overall_start = Instant::now();
    let signature = self.create_signature(operation, inputs, params)?;
    // Setup phase: only the cache probe is timed as "setup".
    let setup_start = Instant::now();
    let cache_hit = self.check_cache(&signature);
    let setup_time = setup_start.elapsed();
    let exec_start = Instant::now();
    let result = if cache_hit && self.config.enable_op_cache {
        // Known signature: skip memory/context preparation and run directly.
        executor(inputs)?
    } else {
        // Cold path: optionally reserve pool memory for the duration of
        // this block (the guard is dropped when the block ends).
        let _memory_guard = if self.config.enable_memory_pool {
            Some(self.prepare_memory_for_operation(&signature)?)
        } else {
            None
        };
        let result = if self.config.enable_context_optimization {
            self.execute_with_context_optimization(inputs, executor)?
        } else {
            executor(inputs)?
        };
        // Record the operation so later identical calls take the fast path.
        if self.config.enable_op_cache {
            self.cache_operation(&signature, &result, exec_start.elapsed())?;
        }
        result
    };
    let execution_time = exec_start.elapsed();
    let teardown_start = Instant::now();
    if self.config.enable_memory_pool {
        self.cleanup_operation_memory(&signature)?;
    }
    let teardown_time = teardown_start.elapsed();
    let total_time = overall_start.elapsed();
    // Overhead = total wall time minus the executor phase (always
    // non-negative since exec_start is after overall_start).
    let total_overhead = total_time - execution_time;
    let metrics = ExecutionMetrics {
        operation: operation.to_string(),
        device: *inputs[0].device(),
        setup_time,
        execution_time,
        teardown_time,
        total_overhead,
        // Allocation time is folded into setup/teardown, not tracked separately.
        memory_allocation_time: Duration::ZERO,
        cache_hit,
        meets_target: total_overhead.as_nanos() <= self.config.target_overhead_ns as u128,
    };
    self.metrics
        .lock()
        .expect("lock should not be poisoned")
        .push(metrics.clone());
    if self.config.enable_kernel_fusion {
        self.analyze_fusion_opportunity(operation, &signature);
    }
    Ok((result, metrics))
}
/// Build the cache key for an operation invocation.
///
/// The signature captures what determines a kernel: op name, input
/// shapes, dtype, device, and the string parameters.
///
/// NOTE(review): reads dtype/device from `inputs[0]` — panics if
/// `inputs` is empty.
fn create_signature<T: 'static>(
    &self,
    operation: &str,
    inputs: &[&Tensor<T>],
    params: &HashMap<String, String>,
) -> Result<OpSignature> {
    let input_shapes: Vec<Vec<usize>> =
        inputs.iter().map(|t| t.shape().dims().to_vec()).collect();
    let device = *inputs[0].device();
    let dtype = inputs[0].dtype();
    let mut params: Vec<(String, String)> =
        params.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
    // HashMap iteration order is nondeterministic, so without sorting the
    // same parameter set could yield different signatures — defeating the
    // op cache (spurious misses and duplicate entries). Sort by key for a
    // canonical, stable cache key.
    params.sort();
    Ok(OpSignature {
        operation: operation.to_string(),
        input_shapes,
        dtype,
        device,
        params,
    })
}
fn check_cache(&self, signature: &OpSignature) -> bool {
let cache = self
.op_cache
.read()
.expect("read lock should not be poisoned");
cache.contains_key(signature)
}
/// Insert (or refresh) a cache entry for a just-executed operation.
///
/// Evicts the least-recently-used entry when the cache is full — but only
/// when the key is genuinely new: overwriting an existing key does not
/// grow the map, and the previous code evicted unconditionally in that
/// case, needlessly dropping an unrelated (possibly hot) entry.
fn cache_operation<T>(
    &self,
    signature: &OpSignature,
    result: &Tensor<T>,
    execution_time: Duration,
) -> Result<()> {
    let mut cache = self
        .op_cache
        .write()
        .expect("write lock should not be poisoned");
    if !cache.contains_key(signature) && cache.len() >= self.config.max_cache_size {
        // LRU eviction: drop the entry with the oldest `last_used`.
        let oldest_key = cache
            .iter()
            .min_by_key(|(_, cached_op)| cached_op.last_used)
            .map(|(k, _)| k.clone());
        if let Some(key) = oldest_key {
            cache.remove(&key);
        }
    }
    let now = Instant::now();
    let cached_op = CachedOperation {
        signature: signature.clone(),
        result_shape: result.shape().dims().to_vec(),
        execution_time,
        // Rough payload size: element count times in-memory element size.
        memory_usage: result.shape().size() * std::mem::size_of::<T>(),
        created_at: now,
        last_used: now,
        use_count: 1,
    };
    cache.insert(signature.clone(), cached_op);
    Ok(())
}
/// Reserve and pre-condition pool memory ahead of a cold-path operation.
///
/// Returns a guard describing the estimated footprint; the reservation's
/// scope ends when the guard is dropped.
fn prepare_memory_for_operation(&self, signature: &OpSignature) -> Result<MemoryGuard> {
    // Only pre-warm the pool for outputs above this size; smaller
    // allocations are cheap enough to take on demand.
    const PRE_WARM_THRESHOLD: usize = 1024 * 1024;
    let output_bytes = self.estimate_output_memory_requirements(signature)?;
    let scratch_bytes = self.estimate_intermediate_memory_requirements(signature)?;
    if output_bytes > PRE_WARM_THRESHOLD {
        self.pre_warm_memory_pool(&signature.device, output_bytes)?;
    }
    self.optimize_memory_layout_for_operation(signature)?;
    Ok(MemoryGuard {
        device: signature.device,
        estimated_memory: output_bytes + scratch_bytes,
        operation: signature.operation.clone(),
    })
}
/// Estimate the output buffer size in bytes for an operation, using
/// shape heuristics keyed on the operation name.
fn estimate_output_memory_requirements(&self, signature: &OpSignature) -> Result<usize> {
    let element_size = self.get_dtype_size(&signature.dtype);
    // Helper: element count of one input shape.
    let elements_of = |shape: &Vec<usize>| shape.iter().product::<usize>();
    let output_elements = match signature.operation.as_str() {
        // Element-wise ops: output matches the largest (broadcast) input.
        "add" | "sub" | "mul" | "div" | "relu" | "sigmoid" | "tanh" | "gelu" => signature
            .input_shapes
            .iter()
            .map(elements_of)
            .max()
            .unwrap_or(0),
        // Matrix multiply: batch * m * n from the trailing dims.
        "matmul" => {
            let shapes = &signature.input_shapes;
            if shapes.len() >= 2 && shapes[0].len() >= 2 && shapes[1].len() >= 2 {
                let lhs = &shapes[0];
                let rhs = &shapes[1];
                let m = lhs[lhs.len() - 2];
                let n = rhs[rhs.len() - 1];
                let batch: usize = lhs[..lhs.len() - 2].iter().product();
                batch * m * n
            } else {
                0
            }
        }
        // Reductions: heuristic shrink by the rank of each input.
        "sum" | "mean" | "max" | "min" => signature
            .input_shapes
            .iter()
            .map(|shape| elements_of(shape) / shape.len().max(1))
            .sum(),
        // Conv2d: product of the first four dims of the NCHW-like input —
        // assumes output channels/spatial dims equal the input's, so this
        // is a rough upper-bound guess, not an exact size.
        "conv2d" => match signature.input_shapes.first() {
            Some(input) if input.len() >= 4 => input[0] * input[1] * input[2] * input[3],
            _ => 0,
        },
        // Unknown op: assume the output is as large as the biggest input.
        _ => signature
            .input_shapes
            .iter()
            .map(elements_of)
            .max()
            .unwrap_or(0),
    };
    Ok(output_elements * element_size)
}
/// Estimate scratch (intermediate) memory in bytes as an
/// operation-dependent fraction of the total input footprint.
fn estimate_intermediate_memory_requirements(&self, signature: &OpSignature) -> Result<usize> {
    let element_size = self.get_dtype_size(&signature.dtype);
    let total_input_elements: usize = signature
        .input_shapes
        .iter()
        .map(|shape| shape.iter().product::<usize>())
        .sum();
    // Rough empirical factors per op family.
    let factor = match signature.operation.as_str() {
        "add" | "sub" | "mul" | "div" => 0.1,
        "relu" | "sigmoid" | "tanh" | "gelu" => 0.2,
        "matmul" => 0.5,
        "batch_norm" | "layer_norm" | "group_norm" => 0.8,
        "conv2d" | "conv3d" => 1.2,
        "sum" | "mean" | "max" | "min" => 0.3,
        _ => 0.5,
    };
    Ok((total_input_elements as f64 * factor * element_size as f64) as usize)
}
/// Pre-allocate a few pool blocks on `device` so the upcoming operation
/// does not pay allocation latency. No-op when pooling is disabled.
fn pre_warm_memory_pool(&self, device: &Device, required_memory: usize) -> Result<()> {
    if !self.config.enable_memory_pool {
        return Ok(());
    }
    // Round up to a power of two to match typical allocator bucketing.
    let warmup_size = required_memory.next_power_of_two();
    // Fewer blocks for large requests, bounding total pre-warm footprint.
    let num_blocks = if warmup_size > 1024 * 1024 { 2 } else { 3 };
    self.memory_pool.pre_warm(device, warmup_size, num_blocks)?;
    Ok(())
}
/// Placeholder hook for operation-specific memory-layout tuning.
///
/// All arms are intentionally empty today; the match documents which op
/// families would eventually get layout work.
fn optimize_memory_layout_for_operation(&self, signature: &OpSignature) -> Result<()> {
    match signature.operation.as_str() {
        "matmul" | "conv2d" | "conv3d" => {
            // TODO: blocked/tiled layouts for compute-heavy kernels.
        }
        "add" | "sub" | "mul" | "div" => {
            // TODO: contiguity checks for element-wise kernels.
        }
        _ => {
            // No layout optimization for other operations yet.
        }
    }
    Ok(())
}
/// Size in bytes of a single element of `dtype`.
///
/// Approximations: sub-byte `Int4` is counted as one byte; `String` is
/// counted as 8 (a pointer-sized handle), matching the original table.
fn get_dtype_size(&self, dtype: &DType) -> usize {
    match dtype {
        DType::Int4 | DType::Int8 | DType::UInt8 | DType::Bool => 1,
        DType::Float16 | DType::BFloat16 | DType::Int16 | DType::UInt16 => 2,
        DType::Float32 | DType::Int32 | DType::UInt32 => 4,
        DType::Float64
        | DType::Int64
        | DType::UInt64
        | DType::Complex32
        | DType::String => 8,
        DType::Complex64 => 16,
    }
}
/// Run `executor` after ensuring a device context is cached for the
/// inputs' device, so later calls skip context acquisition.
fn execute_with_context_optimization<T, F>(
    &self,
    inputs: &[&Tensor<T>],
    executor: F,
) -> Result<Tensor<T>>
where
    F: FnOnce(&[&Tensor<T>]) -> Result<Tensor<T>>,
{
    let device = *inputs[0].device();
    // Scope the write lock so it is released before the executor runs.
    {
        let mut contexts = self
            .active_contexts
            .write()
            .expect("write lock should not be poisoned");
        if !contexts.contains_key(&device) {
            let context = DEVICE_MANAGER.get_context(&device)?;
            contexts.insert(device, context);
        }
    }
    executor(inputs)
}
/// Placeholder teardown hook, called after every pooled operation.
///
/// Pool blocks are currently reclaimed lazily via `cleanup()` /
/// `MemoryPool::cleanup_old_blocks`, so there is nothing to do per-op.
fn cleanup_operation_memory(&self, _signature: &OpSignature) -> Result<()> {
    Ok(())
}
/// Record (or extend) a kernel-fusion opportunity for `operation`.
///
/// Opportunities with more than ~10% estimated benefit are kept, capped
/// at 50 entries. When the new op can legally follow an existing chain,
/// that chain is extended instead of starting a new one.
fn analyze_fusion_opportunity(&self, operation: &str, signature: &OpSignature) {
    // Estimate the benefit first (these helpers never touch the
    // opportunities list, so the lock can be taken afterwards).
    let speedup = match operation {
        "add" | "sub" | "mul" | "div" => self.calculate_elementwise_fusion_benefit(signature),
        "relu" | "sigmoid" | "tanh" | "gelu" => {
            self.calculate_activation_fusion_benefit(signature)
        }
        "batch_norm" | "layer_norm" | "group_norm" => {
            self.calculate_normalization_fusion_benefit(signature)
        }
        "matmul" | "conv2d" | "conv3d" => {
            self.calculate_compute_intensive_fusion_benefit(signature)
        }
        "sum" | "mean" | "max" | "min" => self.calculate_reduction_fusion_benefit(signature),
        _ => 1.0,
    };
    let mut opportunities = self
        .fusion_opportunities
        .write()
        .expect("write lock should not be poisoned");
    // Ignore marginal gains and bound the backlog size.
    if speedup <= 1.1 || opportunities.len() >= 50 {
        return;
    }
    let memory_savings = self.estimate_memory_savings(operation, signature);
    let extendable = opportunities
        .iter_mut()
        .find(|opp| self.can_extend_fusion_chain(&opp.operations, operation));
    match extendable {
        Some(chain) => {
            chain.operations.push(operation.to_string());
            // Damp compounding so long chains don't claim runaway speedups.
            chain.potential_speedup *= speedup.min(1.5);
            chain.memory_savings += memory_savings;
        }
        None => opportunities.push(FusionOpportunity {
            operations: vec![operation.to_string()],
            potential_speedup: speedup,
            memory_savings,
        }),
    }
}
/// Speedup estimate for fusing element-wise arithmetic: larger tensors
/// amortize kernel-launch overhead better, so bigger inputs score higher.
fn calculate_elementwise_fusion_benefit(&self, signature: &OpSignature) -> f64 {
    let total_elements: usize = signature
        .input_shapes
        .iter()
        .map(|shape| shape.iter().product::<usize>())
        .sum();
    match total_elements {
        n if n > 10_000 => 1.8,
        n if n > 1_000 => 1.4,
        _ => 1.1,
    }
}
/// Speedup estimate for fusing activation functions; GPU devices benefit
/// more (when the `gpu` feature is compiled in).
#[allow(unused_variables)] // `signature` is unused without the gpu feature
fn calculate_activation_fusion_benefit(&self, signature: &OpSignature) -> f64 {
    #[cfg(feature = "gpu")]
    let on_gpu = matches!(signature.device, Device::Gpu(_));
    #[cfg(not(feature = "gpu"))]
    let on_gpu = false;
    if on_gpu {
        1.6
    } else {
        1.3
    }
}
/// Speedup estimate for fusing normalization ops; pays off mainly on
/// large activations where the extra memory passes dominate.
fn calculate_normalization_fusion_benefit(&self, signature: &OpSignature) -> f64 {
    let largest_input: usize = signature
        .input_shapes
        .iter()
        .map(|shape| shape.iter().product::<usize>())
        .max()
        .unwrap_or(0);
    if largest_input > 50_000 {
        1.7
    } else {
        1.2
    }
}
/// Speedup estimate for fusing epilogues onto matmul/conv kernels;
/// modest unless at least one input is large.
fn calculate_compute_intensive_fusion_benefit(&self, signature: &OpSignature) -> f64 {
    let has_large_input = signature
        .input_shapes
        .iter()
        .any(|shape| shape.iter().product::<usize>() > 100_000);
    if has_large_input {
        1.4
    } else {
        1.1
    }
}
/// Speedup estimate for fusing reductions with their producers; scales
/// with the size of the largest input.
fn calculate_reduction_fusion_benefit(&self, signature: &OpSignature) -> f64 {
    let largest_input: usize = signature
        .input_shapes
        .iter()
        .map(|shape| shape.iter().product::<usize>())
        .max()
        .unwrap_or(0);
    if largest_input > 20_000 {
        1.5
    } else {
        1.2
    }
}
fn estimate_memory_savings(&self, operation: &str, signature: &OpSignature) -> usize {
let element_size = match signature.dtype {
DType::Float16 => 2,
DType::BFloat16 => 2,
DType::Float32 => 4,
DType::Float64 => 8,
DType::Int8 => 1,
DType::Int16 => 2,
DType::Int32 => 4,
DType::Int64 => 8,
DType::Int4 => 1, DType::UInt8 => 1,
DType::UInt16 => 2,
DType::UInt32 => 4,
DType::UInt64 => 8,
DType::Bool => 1,
DType::Complex32 => 8,
DType::Complex64 => 16,
DType::String => 8, };
let total_elements: usize = signature
.input_shapes
.iter()
.map(|shape| shape.iter().product::<usize>())
.sum();
match operation {
"add" | "sub" | "mul" | "div" => total_elements * element_size, "relu" | "sigmoid" | "tanh" => total_elements * element_size / 2, "batch_norm" | "layer_norm" => total_elements * element_size * 2, _ => total_elements * element_size / 4, }
}
/// Whether `new_op` may be appended to an existing fusion chain.
///
/// Encodes legal producer->consumer pairings; reductions may terminate
/// any chain that is still short (fewer than 3 ops).
fn can_extend_fusion_chain(&self, existing_ops: &[String], new_op: &str) -> bool {
    let last_op = match existing_ops.last() {
        Some(op) => op,
        None => return false,
    };
    match (last_op.as_str(), new_op) {
        // Element-wise ops chain freely with each other.
        ("add" | "sub" | "mul" | "div", "add" | "sub" | "mul" | "div") => true,
        // Activations fuse as epilogues of compute-heavy producers.
        ("matmul" | "conv2d" | "conv3d", "relu" | "sigmoid" | "tanh" | "gelu") => true,
        ("add" | "sub", "relu" | "sigmoid" | "tanh" | "gelu") => true,
        ("relu" | "sigmoid" | "tanh" | "gelu", "batch_norm" | "layer_norm") => true,
        // Reductions can close out short chains.
        (_, "sum" | "mean" | "max" | "min") => existing_ops.len() < 3,
        _ => false,
    }
}
/// Snapshot of all execution metrics recorded so far.
pub fn get_metrics(&self) -> Vec<ExecutionMetrics> {
    let guard = self.metrics.lock().expect("lock should not be poisoned");
    guard.clone()
}
/// Aggregate statistics over the operation cache.
pub fn get_cache_stats(&self) -> CacheStatistics {
    let cache = self
        .op_cache
        .read()
        .expect("read lock should not be poisoned");
    let total_entries = cache.len();
    let total_hits = cache.values().map(|op| op.use_count).sum();
    let avg_execution_time = if total_entries > 0 {
        cache
            .values()
            .map(|op| op.execution_time.as_nanos())
            .sum::<u128>()
            / total_entries as u128
    } else {
        0
    };
    CacheStatistics {
        total_entries,
        total_hits,
        // Average uses per cached entry. The previous code divided
        // entries by hits, which is inverted — the "rate" shrank as hits
        // grew. Guarding on entries also avoids division by zero.
        hit_rate: if total_entries > 0 {
            total_hits as f64 / total_entries as f64
        } else {
            0.0
        },
        avg_execution_time: Duration::from_nanos(avg_execution_time as u64),
    }
}
/// Build a summary report from all recorded metrics.
///
/// Returns the default (empty) report when nothing has been executed.
pub fn generate_performance_report(&self) -> EagerPerformanceReport {
    let metrics = self.get_metrics();
    let cache_stats = self.get_cache_stats();
    if metrics.is_empty() {
        return EagerPerformanceReport::default();
    }
    let total_operations = metrics.len();
    let operations_meeting_target = metrics.iter().filter(|m| m.meets_target).count();
    let overhead_sum: u128 = metrics.iter().map(|m| m.total_overhead.as_nanos()).sum();
    let avg_overhead =
        Duration::from_nanos((overhead_sum / total_operations as u128) as u64);
    // metrics is non-empty here, so seeding min/max from the first entry
    // is safe and matches the original `unwrap_or` fallbacks.
    let mut min_overhead = metrics[0].total_overhead;
    let mut max_overhead = metrics[0].total_overhead;
    for m in &metrics {
        min_overhead = min_overhead.min(m.total_overhead);
        max_overhead = max_overhead.max(m.total_overhead);
    }
    let hit_count = metrics.iter().filter(|m| m.cache_hit).count();
    EagerPerformanceReport {
        total_operations,
        operations_meeting_target,
        success_rate: operations_meeting_target as f64 / total_operations as f64,
        avg_overhead,
        min_overhead,
        max_overhead,
        cache_statistics: cache_stats,
        cache_hit_rate: hit_count as f64 / total_operations as f64,
        target_overhead: Duration::from_nanos(self.config.target_overhead_ns),
        recommendations: self.generate_recommendations(&metrics),
    }
}
/// Derive tuning suggestions from the recorded metrics.
fn generate_recommendations(&self, metrics: &[ExecutionMetrics]) -> Vec<String> {
    let mut recommendations = Vec::new();
    let count = metrics.len();
    // Average overhead in nanoseconds; zero with no history.
    let avg_overhead_ns = if count == 0 {
        0
    } else {
        metrics
            .iter()
            .map(|m| m.total_overhead.as_nanos())
            .sum::<u128>()
            / count as u128
    };
    if avg_overhead_ns > u128::from(self.config.target_overhead_ns) {
        recommendations
            .push("Consider enabling operation caching to reduce setup overhead".to_string());
        recommendations.push("Enable memory pooling to reduce allocation overhead".to_string());
    }
    // With no history this is 0.0, which still triggers the cache-size
    // suggestion below — identical to the original behavior.
    let cache_hit_rate = if count == 0 {
        0.0
    } else {
        metrics.iter().filter(|m| m.cache_hit).count() as f64 / count as f64
    };
    if cache_hit_rate < 0.3 {
        recommendations.push("Increase cache size to improve hit rates".to_string());
    }
    let slow_setup_count = metrics
        .iter()
        .filter(|m| m.setup_time > Duration::from_micros(100))
        .count();
    if slow_setup_count > count / 4 {
        recommendations
            .push("Enable context optimization to reduce setup overhead".to_string());
    }
    recommendations
}
/// Evict stale state: old memory-pool blocks and cache entries that have
/// been idle for more than five minutes.
pub fn cleanup(&self) {
    self.memory_pool.cleanup_old_blocks();
    // Cache entries idle longer than this are dropped.
    const IDLE_TTL: Duration = Duration::from_secs(300);
    let now = Instant::now();
    self.op_cache
        .write()
        .expect("write lock should not be poisoned")
        .retain(|_, entry| now.duration_since(entry.last_used) <= IDLE_TTL);
}
}