use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
use anyhow::Result;
use tensorlogic_ir::{EinsumGraph, TLExpr};
use crate::config::CompilationConfig;
use crate::CompilerContext;
/// Identity of an expression, derived from its textual representation.
///
/// The fingerprint keeps the complete representation it was computed from, so
/// two fingerprints compare equal exactly when their source strings are equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ExprFingerprint {
    /// Verbatim copy of the string passed to [`ExprFingerprint::compute`].
    pub(crate) data: String,
}

impl ExprFingerprint {
    /// Builds a fingerprint from an expression's textual representation.
    pub fn compute(expr_repr: &str) -> Self {
        ExprFingerprint {
            data: expr_repr.to_string(),
        }
    }
}

impl std::fmt::Display for ExprFingerprint {
    /// Writes `fp:` followed by a preview of at most 32 bytes of the data.
    ///
    /// The preview is shortened to the nearest preceding character boundary so
    /// that representations containing multi-byte UTF-8 characters never cause
    /// a slicing panic.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut preview_len = self.data.len().min(32);
        // Index 0 is always a boundary, so this loop terminates.
        while !self.data.is_char_boundary(preview_len) {
            preview_len -= 1;
        }
        write!(f, "fp:{}", &self.data[..preview_len])
    }
}
/// A cache entry held by `LruCompilationCache`: the compiled graph plus
/// per-entry bookkeeping.
#[derive(Debug, Clone)]
pub struct CachedResult {
/// The compiled graph stored for the fingerprint.
pub graph: EinsumGraph,
/// Number of times this entry has been returned by `LruCompilationCache::get`.
pub hit_count: u64,
/// Size estimate recorded at insertion time (`graph.nodes.len() * 256`).
pub memory_bytes: usize,
}
/// Counters describing how a compilation cache has been used.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Number of entries currently stored.
    pub current_entries: usize,
    /// Sum of the per-entry memory estimates, in bytes.
    pub total_memory_bytes: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, or `0.0` when no lookup has
    /// happened yet.
    pub fn hit_rate(&self) -> f64 {
        match self.total_lookups() {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }

    /// Total number of lookups recorded (hits plus misses).
    pub fn total_lookups(&self) -> u64 {
        self.hits + self.misses
    }
}
/// Bounded cache of compiled graphs keyed by [`ExprFingerprint`].
///
/// When the cache is full, inserting a new fingerprint evicts the entry that
/// was least recently inserted or read.
pub struct LruCompilationCache {
    /// Maximum number of entries; `new` clamps this to at least 1.
    capacity: usize,
    /// Stored results by fingerprint.
    entries: HashMap<ExprFingerprint, CachedResult>,
    /// Recency queue: the front is the least recently used fingerprint, the
    /// back the most recently used one.
    lru_order: std::collections::VecDeque<ExprFingerprint>,
    /// Usage counters, exposed through [`LruCompilationCache::stats`].
    stats: CacheStats,
}

impl LruCompilationCache {
    /// Creates an empty cache holding at most `capacity` entries.
    ///
    /// A capacity of 0 is raised to 1.
    pub fn new(capacity: usize) -> Self {
        LruCompilationCache {
            capacity: capacity.max(1),
            entries: HashMap::new(),
            lru_order: std::collections::VecDeque::new(),
            stats: CacheStats::default(),
        }
    }

    /// Removes `fp` from the recency queue and returns the queued key, or
    /// `None` when the fingerprint was not queued.
    fn detach(&mut self, fp: &ExprFingerprint) -> Option<ExprFingerprint> {
        let pos = self.lru_order.iter().position(|queued| queued == fp)?;
        self.lru_order.remove(pos)
    }

    /// Stores `graph` under `fp` and marks it as most recently used.
    ///
    /// If `fp` is already cached, its graph and memory estimate are replaced
    /// while its hit count is kept. Otherwise, when the cache is at capacity,
    /// the least recently used entry is evicted first. The memory estimate is
    /// 256 bytes per graph node.
    pub fn insert(&mut self, fp: ExprFingerprint, graph: EinsumGraph) {
        let memory_bytes = graph.nodes.len() * 256;
        if let Some(entry) = self.entries.get_mut(&fp) {
            // Swap the old estimate for the new one in the running total.
            self.stats.total_memory_bytes = self
                .stats
                .total_memory_bytes
                .saturating_sub(entry.memory_bytes);
            entry.graph = graph;
            entry.memory_bytes = memory_bytes;
            self.stats.total_memory_bytes += memory_bytes;
            self.detach(&fp);
            self.lru_order.push_back(fp);
        } else {
            if self.entries.len() >= self.capacity {
                if let Some(oldest) = self.lru_order.pop_front() {
                    if let Some(evicted) = self.entries.remove(&oldest) {
                        self.stats.total_memory_bytes = self
                            .stats
                            .total_memory_bytes
                            .saturating_sub(evicted.memory_bytes);
                    }
                    self.stats.evictions += 1;
                }
            }
            self.stats.total_memory_bytes += memory_bytes;
            self.lru_order.push_back(fp.clone());
            self.entries.insert(
                fp,
                CachedResult {
                    graph,
                    hit_count: 0,
                    memory_bytes,
                },
            );
        }
        self.stats.current_entries = self.entries.len();
    }

    /// Looks up `fp`, recording a hit or a miss in the statistics.
    ///
    /// On a hit the entry becomes the most recently used one and its
    /// `hit_count` is incremented before it is returned.
    pub fn get(&mut self, fp: &ExprFingerprint) -> Option<&CachedResult> {
        if !self.entries.contains_key(fp) {
            self.stats.misses += 1;
            return None;
        }
        // Reuse the key already held by the queue instead of cloning `fp`.
        let key = self.detach(fp).unwrap_or_else(|| fp.clone());
        self.lru_order.push_back(key);
        self.stats.hits += 1;
        let entry = self.entries.get_mut(fp)?;
        entry.hit_count += 1;
        Some(&*entry)
    }

    /// Returns whether `fp` is cached, without touching recency or statistics.
    pub fn contains(&self, fp: &ExprFingerprint) -> bool {
        self.entries.contains_key(fp)
    }

    /// Removes the entry for `fp`; returns `true` if an entry was removed.
    ///
    /// This is not counted as an eviction.
    pub fn invalidate(&mut self, fp: &ExprFingerprint) -> bool {
        match self.entries.remove(fp) {
            Some(removed) => {
                self.stats.total_memory_bytes = self
                    .stats
                    .total_memory_bytes
                    .saturating_sub(removed.memory_bytes);
                self.detach(fp);
                self.stats.current_entries = self.entries.len();
                true
            }
            None => false,
        }
    }

    /// Drops every entry. Hit, miss and eviction counters are left untouched.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.lru_order.clear();
        self.stats.current_entries = 0;
        self.stats.total_memory_bytes = 0;
    }

    /// Current usage statistics.
    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }

    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Maximum number of entries the cache will hold.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}

impl Default for LruCompilationCache {
    /// Creates a cache with a capacity of 256 entries.
    fn default() -> Self {
        Self::new(256)
    }
}
/// Boxed compile callback stored by [`CachingCompiler`].
type CompileFn =
    Box<dyn Fn(&TLExpr) -> std::result::Result<EinsumGraph, String> + Send + Sync + 'static>;

/// Pairs a compile callback with an [`LruCompilationCache`] so that repeated
/// compilation of the same expression is served from the cache.
pub struct CachingCompiler {
    cache: LruCompilationCache,
    compile_fn: CompileFn,
}

impl CachingCompiler {
    /// Creates a compiler whose cache holds up to `capacity` graphs and which
    /// delegates cache misses to `compile_fn`.
    pub fn new<F>(capacity: usize, compile_fn: F) -> Self
    where
        F: Fn(&TLExpr) -> std::result::Result<EinsumGraph, String> + Send + Sync + 'static,
    {
        let cache = LruCompilationCache::new(capacity);
        let compile_fn: CompileFn = Box::new(compile_fn);
        CachingCompiler { cache, compile_fn }
    }

    /// Returns the graph for `expr`, compiling it only when it is not cached.
    ///
    /// A failed compilation is returned to the caller and is not cached.
    pub fn compile(&mut self, expr: &TLExpr) -> std::result::Result<EinsumGraph, String> {
        let key = Self::fingerprint(expr);
        let cached_graph = self.cache.get(&key).map(|hit| hit.graph.clone());
        match cached_graph {
            Some(graph) => Ok(graph),
            None => {
                let compiled = (self.compile_fn)(expr)?;
                self.cache.insert(key, compiled.clone());
                Ok(compiled)
            }
        }
    }

    /// Compiles every expression in order, yielding one result per input.
    pub fn compile_batch(
        &mut self,
        exprs: &[TLExpr],
    ) -> Vec<std::result::Result<EinsumGraph, String>> {
        let mut results = Vec::with_capacity(exprs.len());
        for expr in exprs {
            results.push(self.compile(expr));
        }
        results
    }

    /// Statistics of the underlying cache.
    pub fn cache_stats(&self) -> &CacheStats {
        self.cache.stats()
    }

    /// Drops the cached graph for `expr`; returns `true` if one was cached.
    pub fn invalidate(&mut self, expr: &TLExpr) -> bool {
        self.cache.invalidate(&Self::fingerprint(expr))
    }

    /// Computes the cache fingerprint of `expr`.
    pub fn fingerprint(expr: &TLExpr) -> ExprFingerprint {
        let repr = Self::structural_repr(expr);
        ExprFingerprint::compute(&repr)
    }

    /// Textual form used for fingerprinting: the expression's `Debug` output.
    fn structural_repr(expr: &TLExpr) -> String {
        format!("{:?}", expr)
    }
}
/// Key of the thread-safe cache: 64-bit hashes of the expression, the
/// compilation configuration and the context's domains.
///
/// NOTE(review): only the hashes are stored, so two different inputs whose
/// hashes collide would share an entry.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Hash of the expression's `Debug` output.
    expr_hash: u64,
    /// Hash of the configuration's `Debug` output.
    config_hash: u64,
    /// Order-independent hash over every domain's name and cardinality.
    domain_hash: u64,
}

impl CacheKey {
    /// Derives the key for compiling `expr` under `config` with the domains
    /// registered in `ctx`.
    fn new(expr: &TLExpr, config: &CompilationConfig, ctx: &CompilerContext) -> Self {
        use std::collections::hash_map::DefaultHasher;

        let mut expr_hasher = DefaultHasher::new();
        format!("{:?}", expr).hash(&mut expr_hasher);
        let expr_hash = expr_hasher.finish();

        let mut config_hasher = DefaultHasher::new();
        format!("{:?}", config).hash(&mut config_hasher);
        let config_hash = config_hasher.finish();

        // Hash each domain on its own and sort the results before combining
        // them, so the key does not depend on the order in which
        // `ctx.domains` yields its entries. That order would differ between
        // otherwise equal contexts if the collection is a hash map (its type
        // is declared outside this file).
        let mut entry_hashes: Vec<u64> = Vec::new();
        for (name, domain) in &ctx.domains {
            let mut entry_hasher = DefaultHasher::new();
            name.hash(&mut entry_hasher);
            domain.cardinality.hash(&mut entry_hasher);
            entry_hashes.push(entry_hasher.finish());
        }
        entry_hashes.sort_unstable();
        let mut domain_hasher = DefaultHasher::new();
        entry_hashes.hash(&mut domain_hasher);
        let domain_hash = domain_hasher.finish();

        CacheKey {
            expr_hash,
            config_hash,
            domain_hash,
        }
    }
}
/// Entry stored by `CompilationCache`: a compiled graph and its hit counter.
#[derive(Clone)]
struct ThreadSafeCachedResult {
/// The compiled graph returned on a cache hit.
graph: EinsumGraph,
/// Incremented on every hit; the entry with the lowest count is evicted
/// first when the cache is full.
hit_count: usize,
}
/// Compilation cache usable through a shared reference, keyed by expression,
/// configuration and domains.
///
/// Entries and statistics are kept behind separate mutexes. When the cache is
/// full, the entry with the fewest recorded hits is evicted.
///
/// NOTE(review): this cache never updates `CacheStats::total_memory_bytes`
/// (it is only reset by `clear`), so that field stays at 0 here.
pub struct CompilationCache {
    cache: Arc<Mutex<HashMap<CacheKey, ThreadSafeCachedResult>>>,
    max_size: usize,
    stats: Arc<Mutex<CacheStats>>,
}

impl CompilationCache {
    /// Creates an empty cache that evicts once `max_size` entries are stored.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            max_size,
            stats: Arc::new(Mutex::new(CacheStats::default())),
        }
    }

    /// Creates a cache with room for 1000 entries.
    pub fn default_size() -> Self {
        Self::new(1000)
    }

    /// Configured entry limit.
    pub fn max_size(&self) -> usize {
        self.max_size
    }

    /// Locks the entry map, turning a poisoned lock into an error.
    fn lock_cache(
        &self,
    ) -> Result<std::sync::MutexGuard<'_, HashMap<CacheKey, ThreadSafeCachedResult>>> {
        self.cache
            .lock()
            .map_err(|e| anyhow::anyhow!("cache lock poisoned: {}", e))
    }

    /// Locks the statistics, turning a poisoned lock into an error.
    fn lock_stats(&self) -> Result<std::sync::MutexGuard<'_, CacheStats>> {
        self.stats
            .lock()
            .map_err(|e| anyhow::anyhow!("stats lock poisoned: {}", e))
    }

    /// Returns the cached graph for `expr` under the current context, or
    /// compiles it with `compile_fn` and caches the result.
    ///
    /// The key is computed from `expr`, `ctx.config` and `ctx.domains` before
    /// `compile_fn` runs. `compile_fn` is called without any lock held.
    ///
    /// # Errors
    ///
    /// Returns an error if either mutex is poisoned or if `compile_fn` fails;
    /// a failed compilation is not cached.
    pub fn get_or_compile<F>(
        &self,
        expr: &TLExpr,
        ctx: &mut CompilerContext,
        compile_fn: F,
    ) -> Result<EinsumGraph>
    where
        F: FnOnce(&TLExpr, &mut CompilerContext) -> Result<EinsumGraph>,
    {
        let key = CacheKey::new(expr, &ctx.config, ctx);
        {
            let mut cache = self.lock_cache()?;
            if let Some(cached) = cache.get_mut(&key) {
                cached.hit_count += 1;
                self.lock_stats()?.hits += 1;
                return Ok(cached.graph.clone());
            }
        }
        self.lock_stats()?.misses += 1;

        let graph = compile_fn(expr, ctx)?;
        {
            let mut cache = self.lock_cache()?;
            // The key may have been inserted while `compile_fn` ran (another
            // thread, or re-entrant use). In that case the insert below only
            // replaces an entry, so nothing needs to be evicted.
            if !cache.contains_key(&key) && cache.len() >= self.max_size {
                let coldest = cache
                    .iter()
                    .min_by_key(|(_, v)| v.hit_count)
                    .map(|(k, _)| k.clone());
                if let Some(key_to_evict) = coldest {
                    cache.remove(&key_to_evict);
                    self.lock_stats()?.evictions += 1;
                }
            }
            cache.insert(
                key,
                ThreadSafeCachedResult {
                    graph: graph.clone(),
                    hit_count: 0,
                },
            );
            self.lock_stats()?.current_entries = cache.len();
        }
        Ok(graph)
    }

    /// Snapshot of the statistics; default (all-zero) stats if the lock is
    /// poisoned.
    pub fn stats(&self) -> CacheStats {
        self.stats.lock().map(|g| g.clone()).unwrap_or_default()
    }

    /// Removes every entry. Best effort: a poisoned lock is skipped silently.
    /// Hit, miss and eviction counters are kept.
    pub fn clear(&self) {
        if let Ok(mut cache) = self.cache.lock() {
            cache.clear();
        }
        if let Ok(mut stats) = self.stats.lock() {
            stats.current_entries = 0;
            stats.total_memory_bytes = 0;
        }
    }

    /// Number of cached entries, or 0 if the lock is poisoned.
    pub fn len(&self) -> usize {
        self.cache.lock().map(|g| g.len()).unwrap_or(0)
    }

    /// Whether `len()` is 0.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl Default for CompilationCache {
    /// Same as [`CompilationCache::default_size`].
    fn default() -> Self {
        Self::default_size()
    }
}
// Unit tests for the LRU cache, the caching compiler wrapper and the
// thread-safe compilation cache.
#[cfg(test)]
mod tests {
use super::*;
use crate::compile_to_einsum_with_context;
use tensorlogic_ir::Term;
// Builds a graph by adding three tensors and one einsum node per iteration.
// The result of `add_node` is discarded via `.ok()`.
fn make_graph(node_count: usize) -> EinsumGraph {
use tensorlogic_ir::EinsumNode;
let mut g = EinsumGraph::new();
for i in 0..node_count {
let a = g.add_tensor(format!("t{}", i));
let b = g.add_tensor(format!("u{}", i));
let c = g.add_tensor(format!("v{}", i));
g.add_node(EinsumNode::einsum("i,i->i", vec![a, b], vec![c]))
.ok();
}
g
}
// Shorthand for a fingerprint built directly from a string.
fn simple_fp(s: &str) -> ExprFingerprint {
ExprFingerprint::compute(s)
}
// ---- LruCompilationCache ----
// An inserted entry can be read back.
#[test]
fn test_cache_basic_insert_get() {
let mut cache = LruCompilationCache::new(8);
let fp = simple_fp("pred(x)");
cache.insert(fp.clone(), EinsumGraph::new());
assert!(
cache.get(&fp).is_some(),
"entry should be present after insert"
);
}
// Looking up an unknown fingerprint yields None.
#[test]
fn test_cache_miss() {
let mut cache = LruCompilationCache::new(8);
let fp = simple_fp("pred(x)");
assert!(cache.get(&fp).is_none(), "empty cache must return None");
}
// Every `get` counts as a hit, including the one that reads the counter.
#[test]
fn test_cache_hit_increments_hit_count() {
let mut cache = LruCompilationCache::new(8);
let fp = simple_fp("pred(x)");
cache.insert(fp.clone(), EinsumGraph::new());
cache.get(&fp);
cache.get(&fp);
assert!(cache.contains(&fp), "entry must still exist after reads");
let entry = cache.get(&fp).expect("entry must be present");
assert_eq!(entry.hit_count, 3, "hit_count should be 3 after three gets");
}
// One hit and one miss give a hit rate of 0.5.
#[test]
fn test_cache_stats_hit_rate() {
let mut cache = LruCompilationCache::new(8);
let fp = simple_fp("pred(x)");
cache.insert(fp.clone(), EinsumGraph::new());
cache.get(&fp); cache.get(&simple_fp("missing"));
let stats = cache.stats();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 1);
assert!(
(stats.hit_rate() - 0.5).abs() < f64::EPSILON,
"hit rate must be 0.5"
);
}
// Inserting beyond capacity evicts the oldest entry.
#[test]
fn test_cache_lru_eviction() {
let mut cache = LruCompilationCache::new(2);
let fp1 = simple_fp("a");
let fp2 = simple_fp("b");
let fp3 = simple_fp("c");
cache.insert(fp1.clone(), EinsumGraph::new());
cache.insert(fp2.clone(), EinsumGraph::new());
cache.insert(fp3.clone(), EinsumGraph::new());
assert!(
!cache.contains(&fp1),
"oldest entry (fp1) must have been evicted"
);
assert!(cache.contains(&fp2), "fp2 must still be present");
assert!(cache.contains(&fp3), "fp3 must be present");
assert_eq!(cache.len(), 2);
}
// Reading an entry protects it from the next eviction.
#[test]
fn test_cache_lru_access_updates_order() {
let mut cache = LruCompilationCache::new(2);
let fp1 = simple_fp("a");
let fp2 = simple_fp("b");
let fp3 = simple_fp("c");
cache.insert(fp1.clone(), EinsumGraph::new());
cache.insert(fp2.clone(), EinsumGraph::new());
cache.get(&fp1);
cache.insert(fp3.clone(), EinsumGraph::new());
assert!(cache.contains(&fp1), "fp1 was accessed so it must survive");
assert!(
!cache.contains(&fp2),
"fp2 is LRU after fp1 was accessed; it must be evicted"
);
assert!(cache.contains(&fp3), "fp3 must be present");
}
// `invalidate` removes an existing entry and reports that it did.
#[test]
fn test_cache_invalidate() {
let mut cache = LruCompilationCache::new(8);
let fp = simple_fp("pred(x)");
cache.insert(fp.clone(), EinsumGraph::new());
let removed = cache.invalidate(&fp);
assert!(removed, "invalidate must return true when entry existed");
assert!(
!cache.contains(&fp),
"entry must be gone after invalidation"
);
}
// `clear` empties the cache and resets the memory total.
#[test]
fn test_cache_clear() {
let mut cache = LruCompilationCache::new(8);
cache.insert(simple_fp("a"), EinsumGraph::new());
cache.insert(simple_fp("b"), EinsumGraph::new());
cache.clear();
assert!(cache.is_empty(), "cache must be empty after clear");
assert_eq!(cache.len(), 0);
assert_eq!(cache.stats().total_memory_bytes, 0);
}
// `len` and `is_empty` track insertions.
#[test]
fn test_cache_len_and_is_empty() {
let mut cache = LruCompilationCache::new(8);
assert!(cache.is_empty());
assert_eq!(cache.len(), 0);
cache.insert(simple_fp("x"), EinsumGraph::new());
assert!(!cache.is_empty());
assert_eq!(cache.len(), 1);
}
// `capacity` reports the value passed to `new`.
#[test]
fn test_cache_capacity() {
let cache = LruCompilationCache::new(42);
assert_eq!(cache.capacity(), 42);
}
// Four inserts into a capacity-2 cache cause two evictions.
#[test]
fn test_cache_eviction_stat() {
let mut cache = LruCompilationCache::new(2);
cache.insert(simple_fp("a"), EinsumGraph::new());
cache.insert(simple_fp("b"), EinsumGraph::new());
cache.insert(simple_fp("c"), EinsumGraph::new()); cache.insert(simple_fp("d"), EinsumGraph::new());
assert_eq!(
cache.stats().evictions,
2,
"two evictions must have occurred"
);
}
// A graph built by `make_graph(4)` yields a positive memory estimate.
#[test]
fn test_cache_memory_estimate() {
let mut cache = LruCompilationCache::new(8);
let graph = make_graph(4);
cache.insert(simple_fp("g"), graph);
assert!(
cache.stats().total_memory_bytes > 0,
"memory estimate must be > 0 for a non-empty graph"
);
}
// ---- ExprFingerprint / CachingCompiler ----
// Fingerprinting the same expression twice gives equal fingerprints.
#[test]
fn test_fingerprint_same_for_same_expr() {
let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
let fp1 = CachingCompiler::fingerprint(&expr);
let fp2 = CachingCompiler::fingerprint(&expr);
assert_eq!(
fp1, fp2,
"identical expressions must produce identical fingerprints"
);
}
// The Display output carries the `fp:` prefix.
#[test]
fn test_fingerprint_display() {
let fp = ExprFingerprint::compute("pred(x, y)");
let display = format!("{}", fp);
assert!(display.starts_with("fp:"), "Display must start with 'fp:'");
}
// Caching compiler whose callback compiles with a fresh context per call and
// converts errors to strings.
fn make_caching_compiler(capacity: usize) -> CachingCompiler {
CachingCompiler::new(capacity, |expr| {
let mut ctx = CompilerContext::new();
compile_to_einsum_with_context(expr, &mut ctx).map_err(|e| e.to_string())
})
}
// Compiling the same expression twice records exactly one hit.
#[test]
fn test_caching_compiler_cache_hit() {
let mut cc = make_caching_compiler(32);
let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
cc.compile(&expr).expect("first compile");
cc.compile(&expr).expect("second compile");
assert_eq!(
cc.cache_stats().hits,
1,
"second compile must be a cache hit"
);
}
// A first compilation is a miss and not a hit.
#[test]
fn test_caching_compiler_cache_miss_count() {
let mut cc = make_caching_compiler(32);
let expr = TLExpr::pred("likes", vec![Term::var("a"), Term::var("b")]);
cc.compile(&expr).expect("compile");
assert_eq!(
cc.cache_stats().misses,
1,
"first compile must be a cache miss"
);
assert_eq!(cc.cache_stats().hits, 0);
}
// Batch compilation returns one successful result per input.
#[test]
fn test_caching_compiler_batch() {
let mut cc = make_caching_compiler(32);
let exprs = vec![
TLExpr::pred("p", vec![Term::var("x")]),
TLExpr::pred("q", vec![Term::var("y")]),
TLExpr::pred("r", vec![Term::var("z")]),
];
let results = cc.compile_batch(&exprs);
assert_eq!(results.len(), 3, "batch must return one result per input");
for (i, r) in results.iter().enumerate() {
assert!(r.is_ok(), "result[{}] must be Ok", i);
}
}
// After invalidation the next compile of the same expression misses again.
#[test]
fn test_caching_compiler_invalidate() {
let mut cc = make_caching_compiler(32);
let expr = TLExpr::pred("p", vec![Term::var("x")]);
cc.compile(&expr).expect("compile");
let removed = cc.invalidate(&expr);
assert!(removed, "invalidate must return true when entry existed");
cc.compile(&expr).expect("re-compile");
assert_eq!(
cc.cache_stats().misses,
2,
"re-compile after invalidation must be another miss"
);
}
// The Default impl uses a capacity of 256.
#[test]
fn test_cache_default_capacity() {
let cache = LruCompilationCache::default();
assert_eq!(cache.capacity(), 256, "default capacity must be 256");
}
// A fingerprint can be used as a HashMap key.
#[test]
fn test_expr_fingerprint_hash() {
let mut map: HashMap<ExprFingerprint, u32> = HashMap::new();
let fp = ExprFingerprint::compute("some_expr");
map.insert(fp.clone(), 42);
assert_eq!(
map.get(&fp),
Some(&42),
"fingerprint must work as HashMap key"
);
}
// ---- CompilationCache (thread-safe) ----
// A new cache reports its size limit and is empty.
#[test]
fn test_ts_cache_new() {
let cache = CompilationCache::new(100);
assert_eq!(cache.max_size(), 100);
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
}
// The second request for the same expression and context is a hit and returns
// an equal graph.
#[test]
fn test_ts_cache_hit() {
let cache = CompilationCache::new(100);
let mut ctx = CompilerContext::new();
ctx.add_domain("Person", 100);
let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
let graph1 = cache
.get_or_compile(&expr, &mut ctx, compile_to_einsum_with_context)
.expect("compile");
let stats = cache.stats();
assert_eq!(stats.misses, 1);
assert_eq!(stats.hits, 0);
let graph2 = cache
.get_or_compile(&expr, &mut ctx, compile_to_einsum_with_context)
.expect("compile");
let stats = cache.stats();
assert_eq!(stats.misses, 1);
assert_eq!(stats.hits, 1);
assert!(
(stats.hit_rate() - 0.5).abs() < f64::EPSILON,
"hit rate must be 0.5"
);
assert_eq!(graph1, graph2);
}
// Two different expressions produce two misses and two entries.
#[test]
fn test_ts_cache_different_expressions() {
let cache = CompilationCache::new(100);
let mut ctx = CompilerContext::new();
ctx.add_domain("Person", 100);
let expr1 = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
let expr2 = TLExpr::pred("likes", vec![Term::var("x"), Term::var("y")]);
let _ = cache
.get_or_compile(&expr1, &mut ctx, compile_to_einsum_with_context)
.expect("compile");
let _ = cache
.get_or_compile(&expr2, &mut ctx, compile_to_einsum_with_context)
.expect("compile");
let stats = cache.stats();
assert_eq!(stats.misses, 2);
assert_eq!(stats.hits, 0);
assert_eq!(cache.len(), 2);
}
// A third distinct expression in a size-2 cache triggers one eviction.
#[test]
fn test_ts_cache_eviction() {
let cache = CompilationCache::new(2);
let mut ctx = CompilerContext::new();
ctx.add_domain("Person", 100);
let _ = cache.get_or_compile(
&TLExpr::pred("p1", vec![Term::var("x")]),
&mut ctx,
compile_to_einsum_with_context,
);
let _ = cache.get_or_compile(
&TLExpr::pred("p2", vec![Term::var("x")]),
&mut ctx,
compile_to_einsum_with_context,
);
let _ = cache.get_or_compile(
&TLExpr::pred("p3", vec![Term::var("x")]),
&mut ctx,
compile_to_einsum_with_context,
);
let stats = cache.stats();
assert_eq!(stats.evictions, 1);
assert_eq!(cache.len(), 2);
}
// `clear` removes all entries.
#[test]
fn test_ts_cache_clear() {
let cache = CompilationCache::new(100);
let mut ctx = CompilerContext::new();
ctx.add_domain("Person", 100);
let _ = cache.get_or_compile(
&TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
&mut ctx,
compile_to_einsum_with_context,
);
assert_eq!(cache.len(), 1);
cache.clear();
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
}
// A fresh cache reports all-zero statistics.
#[test]
fn test_ts_cache_stats() {
let cache = CompilationCache::new(100);
let stats = cache.stats();
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 0);
assert_eq!(stats.evictions, 0);
assert_eq!(stats.current_entries, 0);
assert_eq!(stats.hit_rate(), 0.0);
assert_eq!(stats.total_lookups(), 0);
}
}