use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
use anyhow::Result;
use tensorlogic_ir::{EinsumGraph, TLExpr};
use crate::{
compile_to_einsum_with_config,
config::CompilationConfig,
dead_code::{DceConfig, DeadCodeEliminator},
optimize::pipeline::{OptimizationPipeline, PipelineConfig},
};
/// Errors surfaced by the JIT compilation layer.
#[derive(Debug)]
pub enum JitError {
    /// The underlying (cold or optimized) compilation pipeline failed.
    CompilationFailed(anyhow::Error),
}
impl std::fmt::Display for JitError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
JitError::CompilationFailed(e) => write!(f, "JIT compilation failed: {}", e),
}
}
}
impl std::error::Error for JitError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
JitError::CompilationFailed(e) => e.source(),
}
}
}
impl From<anyhow::Error> for JitError {
fn from(e: anyhow::Error) -> Self {
JitError::CompilationFailed(e)
}
}
/// Snapshot of JIT cache activity counters.
#[derive(Debug, Clone, Default)]
pub struct JitStats {
    /// Expressions currently promoted to optimized hot paths.
    pub hot_paths: usize,
    /// Compile calls that performed a full (non-cached) compilation.
    pub cold_compilations: usize,
    /// Compile calls served from the hot-path cache.
    pub jit_hits: usize,
    /// Promotions from cold to hot (aggressively optimized) graphs.
    pub jit_upgrades: usize,
}
/// A cached hot-path graph plus its usage counter.
#[derive(Clone)]
struct JitEntry {
    // Shared so callers can hold the graph after the cache lock drops.
    graph: Arc<EinsumGraph>,
    // Times this cached graph has been served.
    hit_count: usize,
}
/// Per-expression call tracking used to decide hot promotion.
struct CallRecord {
    // Total times this expression was submitted to `compile`.
    count: usize,
    // First-seen copy of the expression, re-optimized on promotion.
    expr: TLExpr,
}
/// Mutex-protected state shared by all `JitCompiler` clones of a cache.
struct JitCacheInner {
    // Promoted graphs keyed by the expression's debug-repr hash.
    hot_paths: HashMap<u64, JitEntry>,
    // Call counters keyed by the same hash.
    call_counts: HashMap<u64, CallRecord>,
    // Running activity counters exposed via `JitCompiler::stats`.
    stats: JitStats,
}
impl JitCacheInner {
    /// Builds an empty cache with all statistics zeroed.
    fn new() -> Self {
        Self {
            hot_paths: HashMap::default(),
            call_counts: HashMap::default(),
            stats: JitStats::default(),
        }
    }
}
/// Tiered compiler: cold compiles cheaply, then promotes frequently
/// seen expressions to aggressively optimized cached graphs.
pub struct JitCompiler {
    // Base configuration used for both cold and optimized compiles.
    config: CompilationConfig,
    // Call count at which an expression is promoted to a hot path.
    pub hot_threshold: usize,
    // Shared, lock-protected cache of counters and hot graphs.
    cache: Arc<Mutex<JitCacheInner>>,
}
/// Derives a cache key from the expression's `Debug` representation.
///
/// NOTE(review): this assumes the `Debug` output is stable and that
/// structurally equal expressions format identically — confirm that
/// `TLExpr` upholds this, since cache identity depends on it.
fn expr_hash(expr: &TLExpr) -> u64 {
    let mut hasher = DefaultHasher::new();
    format!("{expr:?}").hash(&mut hasher);
    hasher.finish()
}
impl JitCompiler {
    /// Creates a JIT compiler with the default compilation config.
    ///
    /// `hot_threshold` is the call count at which an expression is
    /// promoted to an aggressively optimized "hot" graph.
    pub fn new(hot_threshold: usize) -> Self {
        Self::with_config(CompilationConfig::default(), hot_threshold)
    }

    /// Creates a JIT compiler using a caller-supplied compilation config.
    pub fn with_config(config: CompilationConfig, hot_threshold: usize) -> Self {
        Self {
            config,
            hot_threshold,
            cache: Arc::new(Mutex::new(JitCacheInner::new())),
        }
    }

    /// Compiles `expr`, serving promoted graphs from the hot-path cache.
    ///
    /// Per call: (1) bump the expression's call count; (2) if a hot graph
    /// is cached, return it (a `jit_hit`); (3) if the count has reached
    /// `hot_threshold`, build, cache, and return an aggressively
    /// optimized graph; (4) otherwise perform a plain cold compilation.
    ///
    /// Invariant (relied on by tests): `cold_compilations + jit_hits`
    /// equals the total number of `compile` calls.
    ///
    /// # Errors
    /// Returns `JitError::CompilationFailed` if either compilation path
    /// fails.
    pub fn compile(&self, expr: &TLExpr) -> Result<Arc<EinsumGraph>, JitError> {
        let key = expr_hash(expr);
        // One critical section handles counting, hot-path lookup, and the
        // promotion decision. The previous version re-locked the cache up
        // to four times per call and re-read the call count after other
        // threads could have changed it.
        let upgrade_expr = {
            let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
            let record = guard.call_counts.entry(key).or_insert_with(|| CallRecord {
                count: 0,
                expr: expr.clone(),
            });
            record.count += 1;
            // Clone the first-seen expression only when this call crosses
            // the hot threshold.
            let candidate = (record.count >= self.hot_threshold).then(|| record.expr.clone());
            if let Some(entry) = guard.hot_paths.get_mut(&key) {
                entry.hit_count += 1;
                let graph = Arc::clone(&entry.graph);
                guard.stats.jit_hits += 1;
                return Ok(graph);
            }
            candidate
        };
        if let Some(original_expr) = upgrade_expr {
            // Hot promotion. The previous version also ran a cold compile
            // here and discarded the result; skipping it avoids a wasted
            // full compilation on every promoting call without changing
            // any observable result.
            let optimized_graph = self.apply_extra_optimization(&original_expr)?;
            let arc = Arc::new(optimized_graph);
            let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
            // Only the first promoter installs the cached entry; a
            // concurrent winner's entry is left in place.
            if let std::collections::hash_map::Entry::Vacant(slot) = guard.hot_paths.entry(key) {
                slot.insert(JitEntry {
                    graph: Arc::clone(&arc),
                    hit_count: 0,
                });
                guard.stats.jit_upgrades += 1;
                guard.stats.hot_paths += 1;
            }
            // The promoting call still counts as a cold compilation so
            // that cold_compilations + jit_hits == total calls.
            guard.stats.cold_compilations += 1;
            return Ok(arc);
        }
        // Cold path: plain compilation with the base config.
        let cold_graph = compile_to_einsum_with_config(expr, &self.config)?;
        let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
        guard.stats.cold_compilations += 1;
        Ok(Arc::new(cold_graph))
    }

    /// Runs the full optimization pipeline plus dead-code elimination on
    /// `expr`, then compiles the optimized expression with the base
    /// config.
    fn apply_extra_optimization(&self, expr: &TLExpr) -> Result<EinsumGraph, JitError> {
        // Every pipeline pass enabled; iterate to a fixed point, capped
        // at 20 rounds.
        let aggressive_config = PipelineConfig {
            enable_negation_opt: true,
            enable_constant_folding: true,
            enable_algebraic_simplification: true,
            enable_strength_reduction: true,
            enable_distributivity: true,
            enable_quantifier_opt: true,
            enable_dead_code_elimination: true,
            max_iterations: 20,
            stop_on_fixed_point: true,
        };
        let pipeline = OptimizationPipeline::with_config(aggressive_config);
        let (after_pipeline, _pipeline_stats) = pipeline.optimize(expr);
        // Follow-up DCE sweep with every elimination enabled.
        let dce_config = DceConfig {
            eliminate_constant_and: true,
            eliminate_constant_or: true,
            eliminate_constant_not: true,
            eliminate_if_branches: true,
            eliminate_unused_let: true,
            max_passes: 20,
        };
        let eliminator = DeadCodeEliminator::new(dce_config);
        let (fully_optimized, _dce_stats) = eliminator.run(after_pipeline);
        let graph = compile_to_einsum_with_config(&fully_optimized, &self.config)?;
        Ok(graph)
    }

    /// Returns a snapshot of the current JIT statistics.
    pub fn stats(&self) -> JitStats {
        let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
        guard.stats.clone()
    }

    /// Resets the entire cache: hot paths, call counts, and statistics.
    pub fn clear_cache(&mut self) {
        // Recover from mutex poisoning like every other accessor; the
        // previous `if let Ok(..)` silently skipped the reset when the
        // lock was poisoned.
        let mut guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
        *guard = JitCacheInner::new();
    }

    /// Number of expressions currently promoted to hot paths.
    pub fn hot_path_count(&self) -> usize {
        let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
        guard.hot_paths.len()
    }

    /// How many times `expr` has been submitted to [`JitCompiler::compile`].
    pub fn call_count(&self, expr: &TLExpr) -> usize {
        let guard = self.cache.lock().unwrap_or_else(|e| e.into_inner());
        guard
            .call_counts
            .get(&expr_hash(expr))
            .map(|r| r.count)
            .unwrap_or(0)
    }

    /// The configured hot-promotion threshold.
    pub fn threshold(&self) -> usize {
        self.hot_threshold
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    /// Binary predicate used as the standard test workload.
    fn simple_expr() -> TLExpr {
        TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")])
    }

    /// Unary predicate that hashes to a different cache key.
    fn different_expr() -> TLExpr {
        TLExpr::pred("likes", vec![Term::var("a")])
    }

    #[test]
    fn test_cold_path_returns_graph() {
        let jit = JitCompiler::new(5);
        let _graph = jit.compile(&simple_expr()).expect("cold compile");
        let stats = jit.stats();
        assert_eq!(stats.cold_compilations, 1);
        assert_eq!(stats.jit_hits, 0);
    }

    #[test]
    fn test_hot_upgrade_at_threshold() {
        let jit = JitCompiler::new(3);
        let expr = simple_expr();
        for _call in 0..3 {
            jit.compile(&expr).expect("compile");
        }
        assert_eq!(jit.hot_path_count(), 1);
        assert!(jit.stats().jit_upgrades >= 1);
    }

    #[test]
    fn test_jit_hit_after_upgrade() {
        let jit = JitCompiler::new(2);
        let expr = simple_expr();
        for label in ["call 1", "call 2", "call 3"] {
            jit.compile(&expr).expect(label);
        }
        let stats = jit.stats();
        assert!(
            stats.jit_hits >= 1,
            "expected at least 1 jit_hit, got {stats:?}"
        );
    }

    #[test]
    fn test_clear_cache_resets() {
        let mut jit = JitCompiler::new(1);
        let expr = simple_expr();
        jit.compile(&expr).expect("compile once");
        assert_eq!(jit.hot_path_count(), 1);
        jit.clear_cache();
        assert_eq!(jit.hot_path_count(), 0);
        assert_eq!(jit.call_count(&expr), 0);
    }

    #[test]
    fn test_different_exprs_tracked_separately() {
        let jit = JitCompiler::new(10);
        let first = simple_expr();
        let second = different_expr();
        for _ in 0..3 {
            jit.compile(&first).expect("e1");
        }
        jit.compile(&second).expect("e2");
        assert_eq!(jit.call_count(&first), 3);
        assert_eq!(jit.call_count(&second), 1);
    }

    #[test]
    fn test_threshold_one_upgrades_immediately() {
        let jit = JitCompiler::new(1);
        jit.compile(&simple_expr()).expect("first call");
        assert_eq!(jit.hot_path_count(), 1);
    }

    #[test]
    fn test_stats_consistent() {
        let jit = JitCompiler::new(3);
        let expr = simple_expr();
        let total = 5usize;
        for _ in 0..total {
            jit.compile(&expr).expect("compile");
        }
        let stats = jit.stats();
        // Every call is either a cold compile or a cache hit — never both.
        assert_eq!(
            stats.cold_compilations + stats.jit_hits,
            total,
            "cold + hits must equal total calls; got {stats:?}"
        );
    }

    #[test]
    fn test_hot_graph_not_empty() {
        let jit = JitCompiler::new(2);
        let expr = simple_expr();
        jit.compile(&expr).expect("call 1");
        jit.compile(&expr).expect("call 2");
        let _graph = jit.compile(&expr).expect("call 3 (hot)");
    }

    #[test]
    fn test_threshold_accessor() {
        assert_eq!(JitCompiler::new(7).threshold(), 7);
    }
}