use crate::error::ExecutorError;
use crate::memory::MemoryEstimator;
use crate::optimization::{GraphOptimizer, OptimizationResult};
use crate::scheduling::{ExecutionSchedule, Scheduler, SchedulingStrategy};
use crate::shape::ShapeInferenceContext;
use crate::validation::GraphValidator;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};
use tensorlogic_ir::EinsumGraph;
/// How much optimizer analysis [`GraphCompiler::compile`] performs.
///
/// Variants are declared from least to most aggressive, so the derived ordering
/// gives `None < Basic < Moderate < Aggressive`. The default is `Moderate`.
///
/// In this module, `GraphCompiler::compile` only distinguishes `None` (skip the
/// optimizer analysis) from the other three levels, which currently all run the
/// same analysis.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OptimizationLevel {
    /// Do not run the optimizer analysis at all.
    None,
    /// Lowest level that runs the optimizer analysis.
    Basic,
    /// Default level; runs the optimizer analysis.
    #[default]
    Moderate,
    /// Highest level; runs the optimizer analysis.
    Aggressive,
}
/// Options controlling how a [`GraphCompiler`] compiles an [`EinsumGraph`].
#[derive(Debug, Clone)]
pub struct CompilationConfig {
    /// Optimization level; `OptimizationLevel::None` skips the optimizer analysis.
    pub optimization_level: OptimizationLevel,
    /// Whether `GraphCompiler::compile` runs its shape-inference step.
    pub enable_shape_inference: bool,
    /// Whether `GraphCompiler::compile` estimates per-entry memory usage.
    pub enable_memory_estimation: bool,
    /// Optional target-device name; it is part of the `CompilationKey` used for caching.
    pub target_device: Option<String>,
    /// Optional memory budget (unit presumably bytes — TODO confirm).
    /// Not read by `GraphCompiler::compile` in this module.
    pub memory_budget: Option<usize>,
    /// Caching flag. Not read by `GraphCompiler::compile` in this module.
    pub enable_caching: bool,
    /// Parallelism flag. Not read by `GraphCompiler::compile` in this module.
    pub enable_parallelism: bool,
}

impl Default for CompilationConfig {
    /// Default optimization level (`Moderate`), every boolean feature flag
    /// enabled, no target device and no memory budget.
    fn default() -> Self {
        Self {
            target_device: None,
            memory_budget: None,
            optimization_level: OptimizationLevel::default(),
            enable_shape_inference: true,
            enable_memory_estimation: true,
            enable_caching: true,
            enable_parallelism: true,
        }
    }
}
/// Statistics recorded by `GraphCompiler::compile` for one compilation.
///
/// The derived `Default` yields a zero `compilation_time` and zero for every counter.
#[derive(Debug, Clone, Default)]
pub struct CompilationStats {
    /// Elapsed time spent inside `GraphCompiler::compile`.
    pub compilation_time: Duration,
    /// Node count of the input graph.
    pub original_nodes: usize,
    /// Node count of the compiled graph.
    pub optimized_nodes: usize,
    /// Number of fusion opportunities reported by the optimizer analysis.
    /// NOTE: `compile` reports these but does not rewrite the graph.
    pub fusions_applied: usize,
    /// Number of dead nodes reported by the optimizer analysis.
    /// NOTE: `compile` reports these but does not remove them.
    pub dead_nodes_eliminated: usize,
    /// Sum of the per-entry memory estimates, in bytes.
    pub estimated_memory_bytes: usize,
    /// Length of the schedule's execution order.
    pub execution_steps: usize,
}
/// Result of `GraphCompiler::compile`: the graph together with its schedule,
/// analysis results, configuration and statistics.
#[derive(Debug, Clone)]
pub struct CompiledGraph {
    /// The compiled graph. `GraphCompiler::compile` in this module stores a clone
    /// of its input graph here.
    pub graph: EinsumGraph,
    /// Execution schedule produced by the scheduler.
    pub schedule: ExecutionSchedule,
    /// Inferred shapes keyed by index (presumably node or tensor index — TODO confirm).
    /// `GraphCompiler::compile` in this module leaves this empty.
    pub shapes: HashMap<usize, Vec<usize>>,
    /// Estimated memory usage in bytes, keyed by index.
    pub memory_usage: HashMap<usize, usize>,
    /// Configuration this graph was compiled with.
    pub config: CompilationConfig,
    /// Statistics gathered during compilation.
    pub stats: CompilationStats,
    /// Wall-clock time at which compilation finished.
    pub compiled_at: SystemTime,
}

impl CompiledGraph {
    /// Number of nodes in the compiled graph.
    pub fn node_count(&self) -> usize {
        self.graph.nodes.len()
    }

    /// Sum of all memory estimates, in bytes.
    pub fn total_memory(&self) -> usize {
        self.memory_usage.values().copied().sum()
    }

    /// Basic sanity check. Returns `true` when the graph has at least one node
    /// and the schedule's execution order contains exactly one step per node.
    pub fn is_valid(&self) -> bool {
        let node_total = self.graph.nodes.len();
        node_total > 0 && self.schedule.execution_order.len() == node_total
    }

    /// One-line human-readable summary.
    ///
    /// Memory is reported in decimal megabytes (10^6 bytes) and the compilation
    /// time in milliseconds.
    pub fn summary(&self) -> String {
        let megabytes = self.total_memory() as f64 / 1_000_000.0;
        let millis = self.stats.compilation_time.as_secs_f64() * 1000.0;
        format!(
            "CompiledGraph: {} nodes, {} steps, {:.2}MB memory, compiled in {:.2}ms",
            self.node_count(),
            self.stats.execution_steps,
            megabytes,
            millis
        )
    }
}
pub struct GraphCompiler {
config: CompilationConfig,
optimizer: GraphOptimizer,
validator: GraphValidator,
scheduler: Scheduler,
}
impl GraphCompiler {
pub fn new(config: CompilationConfig) -> Self {
GraphCompiler {
config,
optimizer: GraphOptimizer::new(),
validator: GraphValidator::new(),
scheduler: Scheduler::new(SchedulingStrategy::Balanced),
}
}
pub fn with_default_config() -> Self {
Self::new(CompilationConfig::default())
}
pub fn compile(&mut self, graph: &EinsumGraph) -> Result<CompiledGraph, ExecutorError> {
let start_time = SystemTime::now();
let original_nodes = graph.nodes.len();
let validation_result = self.validator.validate(graph);
if !validation_result.is_valid {
return Err(ExecutorError::GraphValidationError(format!(
"Graph validation failed: {}",
validation_result
.errors
.first()
.map(|e| e.as_str())
.unwrap_or("unknown error")
)));
}
let optimized_graph = graph.clone();
let opt_result = match self.config.optimization_level {
OptimizationLevel::None => OptimizationResult {
fusion_opportunities: vec![],
dead_nodes: vec![],
redundant_computations: vec![],
estimated_improvement: 0.0,
},
OptimizationLevel::Basic
| OptimizationLevel::Moderate
| OptimizationLevel::Aggressive => {
self.optimizer.analyze(&optimized_graph)
}
};
let schedule = self.scheduler.schedule(&optimized_graph);
let shapes = if self.config.enable_shape_inference {
let _shape_ctx = ShapeInferenceContext::new();
HashMap::new()
} else {
HashMap::new()
};
let memory_usage = if self.config.enable_memory_estimation {
use crate::capabilities::DType;
let estimator = MemoryEstimator::new(DType::F32);
let estimate = estimator.estimate(&optimized_graph);
let mut per_node: HashMap<usize, usize> = HashMap::new();
for (idx, mem) in estimate.intermediate_memory.iter().enumerate() {
per_node.insert(idx, mem.bytes);
}
per_node
} else {
HashMap::new()
};
let compilation_time = start_time.elapsed().unwrap_or(Duration::from_secs(0));
let stats = CompilationStats {
compilation_time,
original_nodes,
optimized_nodes: optimized_graph.nodes.len(),
fusions_applied: opt_result.fusion_opportunities.len(),
dead_nodes_eliminated: opt_result.dead_nodes.len(),
estimated_memory_bytes: memory_usage.values().sum(),
execution_steps: schedule.execution_order.len(),
};
Ok(CompiledGraph {
graph: optimized_graph,
schedule,
shapes,
memory_usage,
config: self.config.clone(),
stats,
compiled_at: SystemTime::now(),
})
}
pub fn set_config(&mut self, config: CompilationConfig) {
self.config = config;
}
pub fn config(&self) -> &CompilationConfig {
&self.config
}
}
/// Cache key for one compilation: a structural hash of the graph plus the
/// configuration fields that participate in caching.
///
/// Only `optimization_level` and `target_device` are taken from the
/// [`CompilationConfig`]. Configs that differ only in other fields (e.g.
/// `enable_memory_estimation`) produce the same key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CompilationKey {
    /// 64-bit structural hash of the graph.
    pub graph_hash: u64,
    /// Optimization level copied from the config.
    pub optimization_level: OptimizationLevel,
    /// Target device copied from the config.
    pub target_device: Option<String>,
}

impl CompilationKey {
    /// Builds the key for compiling `graph` with `config`.
    pub fn new(graph: &EinsumGraph, config: &CompilationConfig) -> Self {
        CompilationKey {
            graph_hash: Self::hash_graph(graph),
            optimization_level: config.optimization_level,
            target_device: config.target_device.clone(),
        }
    }

    /// Hashes everything that distinguishes one graph from another for caching.
    ///
    /// This covers tensor names, graph inputs and outputs, and each node's
    /// operation plus its input and output indices.
    ///
    /// `DefaultHasher`'s algorithm is unspecified and may change between Rust
    /// releases, so the value must not be persisted. As a 64-bit hash, distinct
    /// graphs can (rarely) collide.
    fn hash_graph(graph: &EinsumGraph) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        // Graph-level wiring: without these, graphs with identical nodes but
        // different tensors, inputs or outputs would share a cache entry and
        // the cache would hand back a CompiledGraph holding the wrong graph.
        graph.tensors.hash(&mut hasher);
        graph.inputs.hash(&mut hasher);
        graph.outputs.hash(&mut hasher);
        graph.nodes.len().hash(&mut hasher);
        for node in &graph.nodes {
            // A per-variant tag keeps different op kinds with equal payloads apart.
            match &node.op {
                tensorlogic_ir::OpType::Einsum { spec } => {
                    "einsum".hash(&mut hasher);
                    spec.hash(&mut hasher);
                }
                tensorlogic_ir::OpType::Reduce { op, axes } => {
                    "reduce".hash(&mut hasher);
                    op.hash(&mut hasher);
                    axes.hash(&mut hasher);
                }
                tensorlogic_ir::OpType::ElemUnary { op } => {
                    "elemunary".hash(&mut hasher);
                    op.hash(&mut hasher);
                }
                tensorlogic_ir::OpType::ElemBinary { op } => {
                    "elembinary".hash(&mut hasher);
                    op.hash(&mut hasher);
                }
            }
            node.inputs.hash(&mut hasher);
            node.outputs.hash(&mut hasher);
        }
        hasher.finish()
    }
}
/// Usage counters for a `CompilationCache`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that found a cached graph.
    pub hits: usize,
    /// Lookups that found nothing.
    pub misses: usize,
    /// Number of cached entries as of the most recent insert or clear.
    pub size: usize,
    /// Sum of the `compilation_time` of every graph served from the cache.
    pub time_saved: Duration,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    /// Returns `0.0` when no lookups have happened yet.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
/// Thread-safe, size-bounded cache of compiled graphs keyed by [`CompilationKey`].
///
/// Inserting a new key into a full cache evicts the entry that was inserted
/// earliest (FIFO). Lookups do not change an entry's position. A `max_size`
/// of 0 disables caching: `insert` stores nothing.
pub struct CompilationCache {
    cache: Arc<RwLock<CacheEntries>>,
    stats: Arc<RwLock<CacheStats>>,
    max_size: usize,
}

/// Lock-protected contents of a [`CompilationCache`].
#[derive(Default)]
struct CacheEntries {
    /// Cached graphs by key.
    graphs: HashMap<CompilationKey, Arc<CompiledGraph>>,
    /// Keys of `graphs` in insertion order (front = oldest).
    /// Always holds exactly the keys present in `graphs`.
    order: std::collections::VecDeque<CompilationKey>,
}

impl CompilationCache {
    /// Creates an empty cache that holds at most `max_size` graphs.
    pub fn new(max_size: usize) -> Self {
        CompilationCache {
            cache: Arc::new(RwLock::new(CacheEntries::default())),
            stats: Arc::new(RwLock::new(CacheStats::default())),
            max_size,
        }
    }

    /// Creates an empty cache that holds at most 100 graphs.
    pub fn with_default_size() -> Self {
        Self::new(100)
    }

    /// Looks up `key` and records the outcome in the stats.
    ///
    /// A hit also adds the cached graph's `compilation_time` to `time_saved`.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock was poisoned by a panicking thread.
    pub fn get(&self, key: &CompilationKey) -> Option<Arc<CompiledGraph>> {
        // The read guard is a temporary of this statement, so the entries lock
        // is released before the stats lock is taken.
        let result = self
            .cache
            .read()
            .expect("lock should not be poisoned")
            .graphs
            .get(key)
            .cloned();
        let mut stats = self.stats.write().expect("lock should not be poisoned");
        match &result {
            Some(compiled) => {
                stats.hits += 1;
                stats.time_saved += compiled.stats.compilation_time;
            }
            None => stats.misses += 1,
        }
        result
    }

    /// Stores `compiled` under `key`.
    ///
    /// Replacing an existing key keeps that key's original insertion position.
    /// Inserting a new key when the cache is full first evicts the oldest entry.
    /// Does nothing when `max_size` is 0.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock was poisoned by a panicking thread.
    pub fn insert(&self, key: CompilationKey, compiled: CompiledGraph) {
        if self.max_size == 0 {
            return;
        }
        let mut entries = self.cache.write().expect("lock should not be poisoned");
        let compiled = Arc::new(compiled);
        if entries.graphs.contains_key(&key) {
            entries.graphs.insert(key, compiled);
        } else {
            while entries.graphs.len() >= self.max_size {
                match entries.order.pop_front() {
                    Some(oldest) => {
                        entries.graphs.remove(&oldest);
                    }
                    None => break,
                }
            }
            entries.order.push_back(key.clone());
            entries.graphs.insert(key, compiled);
        }
        let mut stats = self.stats.write().expect("lock should not be poisoned");
        stats.size = entries.graphs.len();
    }

    /// Removes every cached graph.
    ///
    /// Resets `stats.size` to 0. Hit, miss and time-saved counters are kept.
    pub fn clear(&self) {
        let mut entries = self.cache.write().expect("lock should not be poisoned");
        entries.graphs.clear();
        entries.order.clear();
        let mut stats = self.stats.write().expect("lock should not be poisoned");
        stats.size = 0;
    }

    /// Returns a snapshot of the usage counters.
    pub fn stats(&self) -> CacheStats {
        self.stats
            .read()
            .expect("lock should not be poisoned")
            .clone()
    }

    /// Number of cached graphs.
    pub fn len(&self) -> usize {
        self.cache
            .read()
            .expect("lock should not be poisoned")
            .graphs
            .len()
    }

    /// Whether the cache holds no graphs.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// An executor that can compile an [`EinsumGraph`] ahead of time and later run
/// the resulting [`CompiledGraph`].
pub trait TlCompilableExecutor {
/// Compiles `graph` using `config`.
///
/// # Errors
///
/// Returns an [`ExecutorError`] when compilation fails.
fn compile_graph(
&mut self,
graph: &EinsumGraph,
config: &CompilationConfig,
) -> Result<CompiledGraph, ExecutorError>;
/// Executes a previously compiled graph on `inputs` and returns its outputs.
///
/// NOTE(review): inputs and outputs are type-erased (`Box<dyn Any>`). The
/// meaning of the `usize` keys (presumably tensor indices) and the concrete
/// value types are defined by each implementor — confirm against implementations.
///
/// # Errors
///
/// Returns an [`ExecutorError`] when execution fails.
fn execute_compiled(
&mut self,
compiled: &CompiledGraph,
inputs: &HashMap<usize, Box<dyn std::any::Any>>,
) -> Result<HashMap<usize, Box<dyn std::any::Any>>, ExecutorError>;
/// Whether this executor supports compilation; the default implementation returns `true`.
fn supports_compilation(&self) -> bool {
true
}
}
// Unit tests for compilation keys, the graph compiler and the compilation cache.
#[cfg(test)]
mod tests {
use super::*;
use tensorlogic_ir::EinsumNode;
// Builds a three-node chain: tensor 0 -> node -> 1 -> node -> 2 -> node -> 3.
// Only tensor 0 is given a name in `tensors`; tensor 0 is the graph input and
// tensor 3 the graph output.
// NOTE(review): the "ij,jk->ik" node has a single input even though its spec
// names two operands; the tests below expect compilation to succeed regardless.
fn create_test_graph() -> EinsumGraph {
let mut graph = EinsumGraph::new();
graph.tensors.push("input".to_string());
graph.inputs.push(0);
graph
.nodes
.push(EinsumNode::new("ij->ij", vec![0], vec![1]));
graph
.nodes
.push(EinsumNode::new("ij,jk->ik", vec![1], vec![2]));
graph
.nodes
.push(EinsumNode::new("ik->ik", vec![2], vec![3]));
graph.outputs.push(3);
graph
}
// Structurally identical graphs with the same config must produce equal keys.
#[test]
fn test_compilation_key_equality() {
let graph1 = create_test_graph();
let graph2 = create_test_graph();
let config = CompilationConfig::default();
let key1 = CompilationKey::new(&graph1, &config);
let key2 = CompilationKey::new(&graph2, &config);
assert_eq!(key1, key2);
}
// An extra node changes the graph hash and therefore the key.
#[test]
fn test_compilation_key_different_graphs() {
let graph1 = create_test_graph();
let mut graph2 = create_test_graph();
graph2.nodes.push(EinsumNode::new("i->i", vec![3], vec![4]));
let config = CompilationConfig::default();
let key1 = CompilationKey::new(&graph1, &config);
let key2 = CompilationKey::new(&graph2, &config);
assert_ne!(key1, key2);
}
// The optimization level is part of the key.
#[test]
fn test_compilation_key_different_config() {
let graph = create_test_graph();
let config1 = CompilationConfig {
optimization_level: OptimizationLevel::Basic,
..Default::default()
};
let config2 = CompilationConfig {
optimization_level: OptimizationLevel::Aggressive,
..Default::default()
};
let key1 = CompilationKey::new(&graph, &config1);
let key2 = CompilationKey::new(&graph, &config2);
assert_ne!(key1, key2);
}
#[test]
fn test_graph_compiler_basic() {
let graph = create_test_graph();
let mut compiler = GraphCompiler::new(CompilationConfig {
optimization_level: OptimizationLevel::Basic,
..Default::default()
});
let result = compiler.compile(&graph);
assert!(result.is_ok());
let compiled = result.expect("unwrap");
assert!(compiled.is_valid());
assert_eq!(compiled.stats.original_nodes, 3);
}
// The final assertion relies on the clock resolution being fine enough to
// register a nonzero compilation time.
#[test]
fn test_graph_compiler_moderate() {
let graph = create_test_graph();
let mut compiler = GraphCompiler::new(CompilationConfig {
optimization_level: OptimizationLevel::Moderate,
..Default::default()
});
let result = compiler.compile(&graph);
assert!(result.is_ok());
let compiled = result.expect("unwrap");
assert!(compiled.is_valid());
assert!(compiled.stats.compilation_time > Duration::from_secs(0));
}
#[test]
fn test_graph_compiler_aggressive() {
let graph = create_test_graph();
let mut compiler = GraphCompiler::new(CompilationConfig {
optimization_level: OptimizationLevel::Aggressive,
..Default::default()
});
let result = compiler.compile(&graph);
assert!(result.is_ok());
let compiled = result.expect("unwrap");
assert!(compiled.is_valid());
assert_eq!(compiled.node_count(), compiled.stats.optimized_nodes);
}
// Checks only that the summary contains its fixed labels, not the numbers.
#[test]
fn test_compiled_graph_summary() {
let graph = create_test_graph();
let mut compiler = GraphCompiler::with_default_config();
let compiled = compiler.compile(&graph).expect("unwrap");
let summary = compiled.summary();
assert!(summary.contains("CompiledGraph"));
assert!(summary.contains("nodes"));
assert!(summary.contains("MB"));
}
// Miss before insert, hit after insert; length tracks the single entry.
#[test]
fn test_compilation_cache_basic() {
let cache = CompilationCache::new(10);
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
let graph = create_test_graph();
let config = CompilationConfig::default();
let key = CompilationKey::new(&graph, &config);
assert!(cache.get(&key).is_none());
let mut compiler = GraphCompiler::with_default_config();
let compiled = compiler.compile(&graph).expect("unwrap");
cache.insert(key.clone(), compiled);
assert_eq!(cache.len(), 1);
assert!(!cache.is_empty());
let cached = cache.get(&key);
assert!(cached.is_some());
}
// Checks only that the size bound holds after a third insert into a
// two-entry cache; it does not assert which entry was evicted.
#[test]
fn test_compilation_cache_eviction() {
let cache = CompilationCache::new(2);
let graph1 = create_test_graph();
let mut graph2 = create_test_graph();
graph2.nodes.push(EinsumNode::new("i->i", vec![3], vec![4]));
let mut graph3 = create_test_graph();
graph3
.nodes
.push(EinsumNode::new("ij->ji", vec![3], vec![5]));
let config = CompilationConfig::default();
let mut compiler = GraphCompiler::with_default_config();
let key1 = CompilationKey::new(&graph1, &config);
let key2 = CompilationKey::new(&graph2, &config);
let key3 = CompilationKey::new(&graph3, &config);
cache.insert(key1.clone(), compiler.compile(&graph1).expect("unwrap"));
cache.insert(key2.clone(), compiler.compile(&graph2).expect("unwrap"));
assert_eq!(cache.len(), 2);
cache.insert(key3.clone(), compiler.compile(&graph3).expect("unwrap"));
assert_eq!(cache.len(), 2);
}
// One miss (before insert) plus one hit (after insert) gives a 0.5 hit rate.
#[test]
fn test_compilation_cache_stats() {
let cache = CompilationCache::new(10);
let graph = create_test_graph();
let config = CompilationConfig::default();
let key = CompilationKey::new(&graph, &config);
let stats = cache.stats();
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 0);
assert_eq!(stats.hit_rate(), 0.0);
cache.get(&key);
let stats = cache.stats();
assert_eq!(stats.misses, 1);
let mut compiler = GraphCompiler::with_default_config();
let compiled = compiler.compile(&graph).expect("unwrap");
cache.insert(key.clone(), compiled);
cache.get(&key);
let stats = cache.stats();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 1);
assert_eq!(stats.hit_rate(), 0.5);
}
#[test]
fn test_compilation_cache_clear() {
let cache = CompilationCache::new(10);
let graph = create_test_graph();
let config = CompilationConfig::default();
let key = CompilationKey::new(&graph, &config);
let mut compiler = GraphCompiler::with_default_config();
let compiled = compiler.compile(&graph).expect("unwrap");
cache.insert(key.clone(), compiled);
assert_eq!(cache.len(), 1);
cache.clear();
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
}
// Every optimization level must compile the test graph into a valid result.
#[test]
fn test_optimization_levels() {
let graph = create_test_graph();
let levels = vec![
OptimizationLevel::None,
OptimizationLevel::Basic,
OptimizationLevel::Moderate,
OptimizationLevel::Aggressive,
];
for level in levels {
let mut compiler = GraphCompiler::new(CompilationConfig {
optimization_level: level,
..Default::default()
});
let result = compiler.compile(&graph);
assert!(result.is_ok(), "Compilation failed for level {:?}", level);
let compiled = result.expect("unwrap");
assert!(compiled.is_valid());
}
}
// Smoke test only: checks that compiling with memory estimation enabled and
// calling total_memory() does not panic; no value is asserted.
#[test]
fn test_compiled_graph_memory_estimation() {
let graph = create_test_graph();
let mut compiler = GraphCompiler::new(CompilationConfig {
enable_memory_estimation: true,
..Default::default()
});
let compiled = compiler.compile(&graph).expect("unwrap");
let _memory = compiled.total_memory();
}
// set_config replaces the configuration visible through config().
#[test]
fn test_config_update() {
let mut compiler = GraphCompiler::with_default_config();
let new_config = CompilationConfig {
optimization_level: OptimizationLevel::Aggressive,
enable_parallelism: false,
..Default::default()
};
compiler.set_config(new_config.clone());
let config = compiler.config();
assert_eq!(config.optimization_level, OptimizationLevel::Aggressive);
assert!(!config.enable_parallelism);
}
}