use crate::compilation::{CompilationConfig, CompiledGraph, GraphCompiler, OptimizationLevel};
use crate::error::ExecutorError;
use crate::shape::TensorShape;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tensorlogic_ir::EinsumGraph;
/// Tuning knobs for the JIT compilation pipeline: how aggressively to
/// compile initially, when a path counts as "hot", and how the cache and
/// adaptive recompilation behave.
#[derive(Debug, Clone)]
pub struct JitConfig {
    /// Optimization level used the first time a graph is compiled.
    pub initial_optimization: OptimizationLevel,
    /// Optimization level hot paths are promoted to on recompilation.
    pub hot_path_optimization: OptimizationLevel,
    /// Number of executions after which an entry is considered hot.
    pub hot_path_threshold: usize,
    /// When true, cache keys include the inputs' static shapes.
    pub enable_specialization: bool,
    /// Cap on shape specializations per graph.
    /// NOTE(review): not enforced anywhere in this file — confirm the
    /// executor honors it.
    pub max_specializations: usize,
    /// Enables hot/cold analysis and automatic recompilation/eviction.
    pub enable_adaptive_optimization: bool,
    /// Profiling sample window size.
    /// NOTE(review): not read in this file — confirm where it is used.
    pub profiling_window: usize,
    /// Maximum number of cached compilations before LRU eviction kicks in.
    pub cache_size: usize,
    /// Allows cold entries to be dropped from the cache.
    pub enable_deoptimization: bool,
    /// Entries with fewer executions than this (and stale) count as cold.
    pub deoptimization_threshold: usize,
}
impl Default for JitConfig {
    /// Conservative defaults: compile fast at `Basic`, promote paths that
    /// run 10+ times to `Aggressive`, keep up to 1000 cached entries.
    fn default() -> Self {
        Self {
            // Compilation levels.
            initial_optimization: OptimizationLevel::Basic,
            hot_path_optimization: OptimizationLevel::Aggressive,
            // Hot/cold thresholds.
            hot_path_threshold: 10,
            deoptimization_threshold: 1,
            // Feature toggles.
            enable_specialization: true,
            enable_adaptive_optimization: true,
            enable_deoptimization: true,
            // Capacities and windows.
            max_specializations: 5,
            profiling_window: 100,
            cache_size: 1000,
        }
    }
}
/// Cache key for a compiled graph: a structural hash plus an optional
/// shape/device specialization.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct JitKey {
    /// Hash of the graph (see `JitCompiler::hash_graph`).
    pub graph_hash: u64,
    /// `Some` when the entry is specialized to concrete input shapes.
    pub specialization: Option<SpecializationContext>,
}
/// Concrete context a compilation was specialized for: the static input
/// dimensions and, optionally, the target device.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SpecializationContext {
    /// Static dimensions of each input, in input order.
    pub input_shapes: Vec<Vec<usize>>,
    /// Target device identifier (e.g. "cuda:0"), if pinned.
    pub device: Option<String>,
}
impl SpecializationContext {
    /// Builds a context from concrete input shapes, keeping only the
    /// statically-known dimensions of each shape (no device).
    ///
    /// NOTE(review): dynamic dimensions are silently dropped by the
    /// `filter_map`, so shapes that differ only in dynamic dims collapse to
    /// the same context — and therefore the same JIT cache key. Confirm
    /// this aliasing is intended.
    pub fn from_shapes(shapes: &[TensorShape]) -> Self {
        SpecializationContext {
            input_shapes: shapes
                .iter()
                .map(|s| {
                    s.dims
                        .iter()
                        .filter_map(|d| d.as_static())
                        .collect::<Vec<_>>()
                })
                .collect(),
            device: None,
        }
    }
    /// Builder-style setter attaching a device identifier.
    pub fn with_device(mut self, device: String) -> Self {
        self.device = Some(device);
        self
    }
}
/// Per-cache-entry execution statistics used for hot/cold classification.
#[derive(Debug, Clone)]
pub struct JitEntryStats {
    /// Number of times this compilation has been executed.
    pub execution_count: usize,
    /// Sum of all recorded execution durations.
    pub total_execution_time: Duration,
    /// Running average (total / count); zero until the first execution.
    pub avg_execution_time: Duration,
    /// Optimization level the entry was compiled with.
    pub optimization_level: OptimizationLevel,
    /// Timestamp of the most recent execution (creation time until then).
    pub last_executed: Instant,
    /// When the entry was compiled/inserted.
    pub compiled_at: Instant,
    /// Whether the entry is specialized to concrete input shapes.
    pub is_specialized: bool,
}
impl Default for JitEntryStats {
    /// Fresh, never-executed entry stamped with the current instant.
    fn default() -> Self {
        Self {
            execution_count: 0,
            total_execution_time: Duration::ZERO,
            avg_execution_time: Duration::ZERO,
            optimization_level: OptimizationLevel::Basic,
            last_executed: Instant::now(),
            compiled_at: Instant::now(),
            is_specialized: false,
        }
    }
}
impl JitEntryStats {
    /// Folds one execution into the running statistics: stamps the
    /// last-executed instant, bumps the counter, accumulates the duration,
    /// and refreshes the running average.
    pub fn record_execution(&mut self, duration: Duration) {
        self.last_executed = Instant::now();
        self.execution_count += 1;
        self.total_execution_time += duration;
        // The count is at least 1 here, so the division cannot panic.
        let runs = self.execution_count as u32;
        self.avg_execution_time = self.total_execution_time / runs;
    }
    /// True once the entry has executed at least `threshold` times.
    pub fn is_hot(&self, threshold: usize) -> bool {
        self.execution_count >= threshold
    }
    /// True when the entry has been idle longer than `window` AND has fewer
    /// than `threshold` executions.
    pub fn is_cold(&self, threshold: usize, window: Duration) -> bool {
        let idle_for = self.last_executed.elapsed();
        idle_for > window && self.execution_count < threshold
    }
}
/// One cache slot: the compiled artifact plus its execution statistics.
#[derive(Debug, Clone)]
pub struct JitCacheEntry {
    /// The compiled graph served on cache hits.
    pub compiled: CompiledGraph,
    /// Hot/cold bookkeeping for this entry.
    pub stats: JitEntryStats,
}
/// Thread-safe, capacity-bounded cache of compiled graphs keyed by
/// `JitKey`. Shared via `Arc` so clones of the map handle see one cache.
pub struct JitCache {
    // Interior map guarded by an RwLock: many concurrent readers, one writer.
    cache: Arc<RwLock<HashMap<JitKey, JitCacheEntry>>>,
    // Capacity and hot/cold thresholds.
    config: JitConfig,
}
impl JitCache {
    /// Creates an empty cache bounded by `config.cache_size`.
    pub fn new(config: JitConfig) -> Self {
        JitCache {
            cache: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }
    /// Inserts a compiled graph, evicting the least-recently-executed entry
    /// first when the cache is at capacity. Stats start from `Default`,
    /// recording the compilation's optimization level and whether it is
    /// shape-specialized.
    pub fn insert(&self, key: JitKey, compiled: CompiledGraph, is_specialized: bool) {
        let mut cache = self.cache.write().expect("lock should not be poisoned");
        if cache.len() >= self.config.cache_size {
            // Evict before inserting so the map never exceeds `cache_size`.
            self.evict_lru(&mut cache);
        }
        let stats = JitEntryStats {
            optimization_level: compiled.config.optimization_level,
            is_specialized,
            ..Default::default()
        };
        cache.insert(key, JitCacheEntry { compiled, stats });
    }
    /// Returns a clone of the cached compilation for `key`, if present.
    /// Does NOT touch execution stats; callers report runs separately via
    /// `record_execution`.
    pub fn get(&self, key: &JitKey) -> Option<CompiledGraph> {
        let cache = self.cache.read().expect("lock should not be poisoned");
        cache.get(key).map(|entry| entry.compiled.clone())
    }
    /// Folds one execution's duration into the entry's stats; a missing key
    /// is a silent no-op (the entry may have been evicted meanwhile).
    pub fn record_execution(&self, key: &JitKey, duration: Duration) {
        let mut cache = self.cache.write().expect("lock should not be poisoned");
        if let Some(entry) = cache.get_mut(key) {
            entry.stats.record_execution(duration);
        }
    }
    /// Snapshot of the stats for `key`, if cached.
    pub fn get_stats(&self, key: &JitKey) -> Option<JitEntryStats> {
        let cache = self.cache.read().expect("lock should not be poisoned");
        cache.get(key).map(|entry| entry.stats.clone())
    }
    /// Entries whose execution count meets `config.hot_path_threshold`.
    pub fn get_hot_paths(&self) -> Vec<(JitKey, JitEntryStats)> {
        let cache = self.cache.read().expect("lock should not be poisoned");
        cache
            .iter()
            .filter(|(_, entry)| entry.stats.is_hot(self.config.hot_path_threshold))
            .map(|(key, entry)| (key.clone(), entry.stats.clone()))
            .collect()
    }
    /// Entries that are both stale and below the deoptimization threshold.
    ///
    /// NOTE(review): the 300-second staleness window is hard-coded rather
    /// than taken from `JitConfig` — consider making it configurable.
    pub fn get_cold_paths(&self) -> Vec<(JitKey, JitEntryStats)> {
        let cache = self.cache.read().expect("lock should not be poisoned");
        let window = Duration::from_secs(300);
        cache
            .iter()
            .filter(|(_, entry)| {
                entry
                    .stats
                    .is_cold(self.config.deoptimization_threshold, window)
            })
            .map(|(key, entry)| (key.clone(), entry.stats.clone()))
            .collect()
    }
    /// Removes the entry with the oldest `last_executed` timestamp.
    /// The caller already holds the write lock and passes the map in, so no
    /// re-locking happens here.
    fn evict_lru(&self, cache: &mut HashMap<JitKey, JitCacheEntry>) {
        if let Some((key, _)) = cache
            .iter()
            .min_by_key(|(_, entry)| entry.stats.last_executed)
        {
            // Clone the key to end the borrow of `cache` before mutating it.
            let key = key.clone();
            cache.remove(&key);
        }
    }
    /// Drops every cached entry (capacity/config are kept).
    pub fn clear(&self) {
        let mut cache = self.cache.write().expect("lock should not be poisoned");
        cache.clear();
    }
    /// Aggregated counters over the whole cache: totals, hot entries,
    /// specialized entries, executions, and configured capacity.
    pub fn cache_stats(&self) -> JitCacheStats {
        let cache = self.cache.read().expect("lock should not be poisoned");
        let total_entries = cache.len();
        let hot_entries = cache
            .values()
            .filter(|e| e.stats.is_hot(self.config.hot_path_threshold))
            .count();
        let specialized_entries = cache.values().filter(|e| e.stats.is_specialized).count();
        let total_executions = cache.values().map(|e| e.stats.execution_count).sum();
        JitCacheStats {
            total_entries,
            hot_entries,
            specialized_entries,
            total_executions,
            cache_capacity: self.config.cache_size,
        }
    }
}
/// Point-in-time summary of the JIT cache contents.
#[derive(Debug, Clone)]
pub struct JitCacheStats {
    /// Number of entries currently cached.
    pub total_entries: usize,
    /// Entries meeting the hot-path threshold.
    pub hot_entries: usize,
    /// Entries specialized to concrete input shapes.
    pub specialized_entries: usize,
    /// Sum of execution counts across all entries.
    pub total_executions: usize,
    /// Configured maximum number of entries.
    pub cache_capacity: usize,
}
/// Classifies cache entries as hot or cold using the thresholds in
/// `JitConfig` and recommends recompilation/deoptimization.
pub struct HotPathDetector {
    // Thresholds and feature flags driving the classification.
    config: JitConfig,
}
impl HotPathDetector {
pub fn new(config: JitConfig) -> Self {
HotPathDetector { config }
}
pub fn detect_hot_paths(&self, cache: &JitCache) -> Vec<JitKey> {
cache
.get_hot_paths()
.into_iter()
.map(|(key, _)| key)
.collect()
}
pub fn recommend_recompilation(&self, cache: &JitCache) -> Vec<(JitKey, OptimizationLevel)> {
cache
.get_hot_paths()
.into_iter()
.filter_map(|(key, stats)| {
if stats.optimization_level < self.config.hot_path_optimization {
Some((key, self.config.hot_path_optimization))
} else {
None
}
})
.collect()
}
pub fn recommend_deoptimization(&self, cache: &JitCache) -> Vec<JitKey> {
if !self.config.enable_deoptimization {
return Vec::new();
}
cache
.get_cold_paths()
.into_iter()
.map(|(key, _)| key)
.collect()
}
}
/// Drives adaptive recompilation: promotes hot paths to a higher
/// optimization level and evicts cold entries from the cache.
pub struct AdaptiveOptimizer {
    // Shared configuration (same values the detector was built with).
    config: JitConfig,
    // Classifier producing the recompile/deoptimize recommendations.
    hot_path_detector: HotPathDetector,
}
impl AdaptiveOptimizer {
    /// Creates an adaptive optimizer; the embedded hot-path detector shares
    /// the same configuration.
    pub fn new(config: JitConfig) -> Self {
        AdaptiveOptimizer {
            hot_path_detector: HotPathDetector::new(config.clone()),
            config,
        }
    }
    /// Builds a plan: hot entries to recompile at a higher optimization
    /// level and cold entries to drop.
    pub fn analyze_and_recommend(&self, cache: &JitCache) -> AdaptiveOptimizationPlan {
        AdaptiveOptimizationPlan {
            recompile: self.hot_path_detector.recommend_recompilation(cache),
            deoptimize: self.hot_path_detector.recommend_deoptimization(cache),
        }
    }
    /// Applies the adaptive plan and returns how many entries were
    /// successfully recompiled.
    ///
    /// # Errors
    /// Propagates any `ExecutorError` raised while recompiling a graph.
    pub fn optimize(&self, cache: &JitCache) -> Result<usize, ExecutorError> {
        let plan = self.analyze_and_recommend(cache);
        let mut optimized_count = 0;
        for (key, opt_level) in plan.recompile {
            // BUGFIX: the previous code kept the RwLock read guard alive for
            // the whole `if let` body (the entry borrowed the guard) and then
            // called `cache.cache.write()` on the same lock — std::sync::RwLock
            // documents re-locking on the same thread as deadlock/panic.
            // Snapshot what we need, drop the read lock, then recompile.
            let snapshot = {
                let guard = cache.cache.read().expect("lock should not be poisoned");
                guard
                    .get(&key)
                    .map(|entry| (entry.compiled.graph.clone(), entry.compiled.config.clone()))
            };
            let (graph, mut config) = match snapshot {
                Some(data) => data,
                None => continue,
            };
            config.optimization_level = opt_level;
            let mut new_compiler = GraphCompiler::new(config);
            let recompiled = new_compiler.compile(&graph)?;
            // The entry may have been evicted while we compiled; skip instead
            // of panicking (the old `expect` could race with eviction).
            if let Some(entry) = cache
                .cache
                .write()
                .expect("lock should not be poisoned")
                .get_mut(&key)
            {
                entry.compiled = recompiled;
                optimized_count += 1;
            }
        }
        for key in plan.deoptimize {
            cache
                .cache
                .write()
                .expect("lock should not be poisoned")
                .remove(&key);
        }
        Ok(optimized_count)
    }
    /// The configuration driving this optimizer.
    pub fn config(&self) -> &JitConfig {
        &self.config
    }
    /// The embedded hot-path detector.
    pub fn hot_path_detector(&self) -> &HotPathDetector {
        &self.hot_path_detector
    }
}
/// Output of adaptive analysis: what to recompile and what to evict.
#[derive(Debug, Clone)]
pub struct AdaptiveOptimizationPlan {
    /// Hot entries to recompile, with the target optimization level.
    pub recompile: Vec<(JitKey, OptimizationLevel)>,
    /// Cold entries to remove from the cache.
    pub deoptimize: Vec<JitKey>,
}
/// Facade tying the pieces together: compiles graphs on demand, caches the
/// results, and periodically re-optimizes hot paths.
pub struct JitCompiler {
    // Shared configuration (cache and optimizer hold their own clones).
    config: JitConfig,
    // Compilation cache keyed by graph hash + specialization.
    cache: JitCache,
    // Hot/cold analysis and recompilation driver.
    adaptive_optimizer: AdaptiveOptimizer,
}
impl JitCompiler {
    /// Creates a JIT compiler whose cache and adaptive optimizer share the
    /// same configuration.
    pub fn new(config: JitConfig) -> Self {
        JitCompiler {
            cache: JitCache::new(config.clone()),
            adaptive_optimizer: AdaptiveOptimizer::new(config.clone()),
            config,
        }
    }
    /// Convenience constructor using `JitConfig::default()`.
    pub fn with_default_config() -> Self {
        Self::new(JitConfig::default())
    }
    /// Returns the cached compilation for `graph` or compiles it at the
    /// configured initial optimization level on a miss, caching the result.
    ///
    /// # Errors
    /// Propagates compilation failures from `GraphCompiler::compile`.
    pub fn compile_or_retrieve(
        &mut self,
        graph: &EinsumGraph,
        input_shapes: &[TensorShape],
    ) -> Result<CompiledGraph, ExecutorError> {
        let key = self.create_key(graph, input_shapes);
        if let Some(compiled) = self.cache.get(&key) {
            return Ok(compiled);
        }
        let config = CompilationConfig {
            optimization_level: self.config.initial_optimization,
            enable_shape_inference: true,
            enable_memory_estimation: true,
            enable_caching: true,
            enable_parallelism: true,
            ..Default::default()
        };
        let mut compiler = GraphCompiler::new(config);
        let compiled = compiler.compile(graph)?;
        let is_specialized = self.config.enable_specialization && !input_shapes.is_empty();
        self.cache.insert(key, compiled.clone(), is_specialized);
        Ok(compiled)
    }
    /// Reports one execution of `graph` so the cache can track hot paths.
    /// Recomputes the key, so `input_shapes` must match the compile call.
    pub fn record_execution(
        &self,
        graph: &EinsumGraph,
        input_shapes: &[TensorShape],
        duration: Duration,
    ) {
        let key = self.create_key(graph, input_shapes);
        self.cache.record_execution(&key, duration);
    }
    /// Runs one round of adaptive optimization; no-op (returns 0) when
    /// adaptive optimization is disabled.
    pub fn optimize_hot_paths(&mut self) -> Result<usize, ExecutorError> {
        if !self.config.enable_adaptive_optimization {
            return Ok(0);
        }
        self.adaptive_optimizer.optimize(&self.cache)
    }
    /// Snapshot of the cache counters.
    pub fn cache_stats(&self) -> JitCacheStats {
        self.cache.cache_stats()
    }
    /// Drops every cached compilation.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
    /// Builds the cache key: graph hash plus, when specialization is on and
    /// shapes were provided, a shape specialization context.
    fn create_key(&self, graph: &EinsumGraph, input_shapes: &[TensorShape]) -> JitKey {
        let graph_hash = self.hash_graph(graph);
        let specialization = if self.config.enable_specialization && !input_shapes.is_empty() {
            Some(SpecializationContext::from_shapes(input_shapes))
        } else {
            None
        };
        JitKey {
            graph_hash,
            specialization,
        }
    }
    /// Hashes the graph for cache keying.
    ///
    /// NOTE(review): only the node COUNT is hashed, so two different graphs
    /// with the same number of nodes share a key and can retrieve each
    /// other's compiled artifact. Consider hashing the actual graph
    /// structure — TODO confirm what hashable content `EinsumGraph` exposes.
    fn hash_graph(&self, graph: &EinsumGraph) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        graph.nodes.len().hash(&mut hasher);
        hasher.finish()
    }
}
/// Mixin trait for executors that embed a `JitCompiler`.
pub trait TlJitExecutor {
    /// Mutable access to the executor's embedded JIT compiler.
    fn jit_compiler(&mut self) -> &mut JitCompiler;
    /// Turns JIT compilation on.
    fn enable_jit(&mut self);
    /// Turns JIT compilation off.
    fn disable_jit(&mut self);
    /// Whether JIT compilation is currently enabled.
    fn is_jit_enabled(&self) -> bool;
    /// Runs one adaptive-optimization round on the embedded compiler.
    /// Default implementation simply delegates.
    fn optimize_hot_paths(&mut self) -> Result<usize, ExecutorError> {
        self.jit_compiler().optimize_hot_paths()
    }
    /// Snapshot of the JIT cache statistics.
    fn jit_stats(&self) -> JitCacheStats;
}
/// Cumulative JIT activity counters.
///
/// NOTE(review): nothing in this file updates these fields — confirm they
/// are maintained by the executor implementation.
#[derive(Debug, Clone)]
pub struct JitStats {
    /// Total number of compilations performed.
    pub total_compilations: usize,
    /// Lookups served from the cache.
    pub cache_hits: usize,
    /// Lookups that required a fresh compilation.
    pub cache_misses: usize,
    /// Hot-path recompilations performed.
    pub recompilations: usize,
    /// Cold entries evicted.
    pub deoptimizations: usize,
    /// Average time spent per compilation.
    pub avg_compilation_time: Duration,
    /// Estimated wall time saved by cache hits.
    pub total_time_saved: Duration,
}
impl Default for JitStats {
fn default() -> Self {
JitStats {
total_compilations: 0,
cache_hits: 0,
cache_misses: 0,
recompilations: 0,
deoptimizations: 0,
avg_compilation_time: Duration::from_secs(0),
total_time_saved: Duration::from_secs(0),
}
}
}
impl JitStats {
    /// Fraction of lookups served from cache, in [0.0, 1.0]; returns 0.0
    /// when no lookups have been recorded (avoids dividing by zero).
    pub fn cache_hit_rate(&self) -> f64 {
        let lookups = self.cache_hits + self.cache_misses;
        if lookups == 0 {
            0.0
        } else {
            self.cache_hits as f64 / lookups as f64
        }
    }
    /// One-line, human-readable summary of the JIT counters.
    pub fn summary(&self) -> String {
        format!(
            "JIT Stats: {} compilations, {:.1}% cache hit rate, {} recompilations, {:.2}ms avg compile time",
            self.total_compilations,
            self.cache_hit_rate() * 100.0,
            self.recompilations,
            self.avg_compilation_time.as_secs_f64() * 1000.0
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the minimal `CompiledGraph` fixture shared by the cache tests.
    /// (Previously this ~14-line literal was duplicated in four tests.)
    fn dummy_compiled() -> CompiledGraph {
        CompiledGraph {
            graph: EinsumGraph::new(),
            schedule: crate::scheduling::ExecutionSchedule {
                execution_order: Vec::new(),
                device_placement: HashMap::new(),
                parallel_groups: Vec::new(),
                estimated_cost: 0.0,
            },
            shapes: HashMap::new(),
            memory_usage: HashMap::new(),
            config: CompilationConfig::default(),
            stats: crate::compilation::CompilationStats::default(),
            compiled_at: std::time::SystemTime::now(),
        }
    }

    /// Convenience constructor for an unspecialized cache key.
    fn key(graph_hash: u64) -> JitKey {
        JitKey {
            graph_hash,
            specialization: None,
        }
    }

    #[test]
    fn test_jit_config_default() {
        let config = JitConfig::default();
        assert_eq!(config.initial_optimization, OptimizationLevel::Basic);
        assert_eq!(config.hot_path_optimization, OptimizationLevel::Aggressive);
        assert_eq!(config.hot_path_threshold, 10);
        assert!(config.enable_specialization);
        assert!(config.enable_adaptive_optimization);
    }

    #[test]
    fn test_specialization_context() {
        let shapes = vec![
            TensorShape::static_shape(vec![2, 3]),
            TensorShape::static_shape(vec![3, 4]),
        ];
        let ctx = SpecializationContext::from_shapes(&shapes);
        assert_eq!(ctx.input_shapes.len(), 2);
        assert_eq!(ctx.input_shapes[0], vec![2, 3]);
        assert_eq!(ctx.input_shapes[1], vec![3, 4]);
    }

    #[test]
    fn test_jit_entry_stats() {
        let mut stats = JitEntryStats::default();
        assert_eq!(stats.execution_count, 0);
        assert!(!stats.is_hot(10));
        for _ in 0..15 {
            stats.record_execution(Duration::from_millis(10));
        }
        assert_eq!(stats.execution_count, 15);
        assert!(stats.is_hot(10));
        assert_eq!(stats.total_execution_time, Duration::from_millis(150));
    }

    #[test]
    fn test_jit_cache_insert_retrieve() {
        let cache = JitCache::new(JitConfig::default());
        cache.insert(key(12345), dummy_compiled(), false);
        assert!(cache.get(&key(12345)).is_some());
    }

    #[test]
    fn test_jit_cache_eviction() {
        let config = JitConfig {
            cache_size: 2,
            ..Default::default()
        };
        let cache = JitCache::new(config);
        for i in 0..3 {
            cache.insert(key(i), dummy_compiled(), false);
            // Distinct `last_executed` timestamps so LRU order is deterministic.
            std::thread::sleep(Duration::from_millis(10));
        }
        // Third insert must evict the least-recently-used entry.
        assert_eq!(cache.cache_stats().total_entries, 2);
    }

    #[test]
    fn test_hot_path_detection() {
        let config = JitConfig::default();
        let cache = JitCache::new(config.clone());
        let detector = HotPathDetector::new(config);
        cache.insert(key(123), dummy_compiled(), false);
        // Exceed the default hot-path threshold of 10.
        for _ in 0..15 {
            cache.record_execution(&key(123), Duration::from_millis(10));
        }
        let hot_paths = detector.detect_hot_paths(&cache);
        assert_eq!(hot_paths.len(), 1);
        assert_eq!(hot_paths[0].graph_hash, 123);
    }

    #[test]
    fn test_jit_compiler_basic() {
        let mut jit = JitCompiler::with_default_config();
        let graph = EinsumGraph::new();
        let shapes = vec![];
        // First call compiles; second is served from the cache.
        assert!(jit.compile_or_retrieve(&graph, &shapes).is_ok());
        assert!(jit.compile_or_retrieve(&graph, &shapes).is_ok());
    }

    #[test]
    fn test_jit_stats() {
        let stats = JitStats::default();
        assert_eq!(stats.cache_hit_rate(), 0.0);
        let stats = JitStats {
            cache_hits: 8,
            cache_misses: 2,
            ..Default::default()
        };
        assert_eq!(stats.cache_hit_rate(), 0.8);
    }

    #[test]
    fn test_adaptive_optimization_plan() {
        let plan = AdaptiveOptimizationPlan {
            recompile: vec![(key(123), OptimizationLevel::Aggressive)],
            deoptimize: vec![],
        };
        assert_eq!(plan.recompile.len(), 1);
        assert_eq!(plan.deoptimize.len(), 0);
    }

    #[test]
    fn test_jit_cache_stats() {
        let cache = JitCache::new(JitConfig::default());
        let stats = cache.cache_stats();
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.hot_entries, 0);
        assert_eq!(stats.total_executions, 0);
    }

    #[test]
    fn test_specialization_with_device() {
        let shapes = vec![TensorShape::static_shape(vec![2, 3])];
        let ctx = SpecializationContext::from_shapes(&shapes).with_device("cuda:0".to_string());
        assert_eq!(ctx.device, Some("cuda:0".to_string()));
        assert_eq!(ctx.input_shapes[0], vec![2, 3]);
    }

    #[test]
    fn test_jit_entry_cold_detection() {
        let mut stats = JitEntryStats::default();
        stats.record_execution(Duration::from_millis(10));
        assert!(!stats.is_cold(5, Duration::from_millis(100)));
        std::thread::sleep(Duration::from_millis(150));
        assert!(stats.is_cold(5, Duration::from_millis(100)));
    }

    #[test]
    fn test_jit_cache_clear() {
        let cache = JitCache::new(JitConfig::default());
        cache.insert(key(123), dummy_compiled(), false);
        assert_eq!(cache.cache_stats().total_entries, 1);
        cache.clear();
        assert_eq!(cache.cache_stats().total_entries, 0);
    }
}