use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
use anyhow::Result;
use tensorlogic_ir::{EinsumGraph, TLExpr};
use crate::config::CompilationConfig;
use crate::CompilerContext;
/// Lookup key for the compilation cache: 64-bit hashes of the expression, the
/// compilation config, and the registered domains.
///
/// Only these hashes are stored and compared, so two different inputs whose
/// hashes happen to collide would be treated as the same key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Hash of the expression's `Debug` representation.
expr_hash: u64,
    /// Hash of the compilation config's `Debug` representation.
config_hash: u64,
    /// Hash of the domain names and cardinalities registered in the context.
domain_hash: u64,
}
impl CacheKey {
    /// Builds the cache key for compiling `expr` under `config` with the domains in `ctx`.
    ///
    /// `expr` and `config` are hashed through their `Debug` output, so any two values
    /// that format identically produce the same key component. Domains contribute only
    /// their names and cardinalities.
    ///
    /// The domain component does not depend on the order in which `ctx.domains` yields
    /// its entries. Two contexts that register the same domains therefore produce the
    /// same key even if their maps iterate differently (as two `HashMap`s with
    /// independently seeded hashers can).
    fn new(expr: &TLExpr, config: &CompilationConfig, ctx: &CompilerContext) -> Self {
        // Hashes a single value with a fresh `DefaultHasher`.
        fn hash_one<T: Hash + ?Sized>(value: &T) -> u64 {
            use std::collections::hash_map::DefaultHasher;
            let mut hasher = DefaultHasher::new();
            value.hash(&mut hasher);
            hasher.finish()
        }

        let expr_hash = hash_one(&format!("{:?}", expr));
        let config_hash = hash_one(&format!("{:?}", config));

        // Hash each (name, cardinality) pair on its own and combine the results with a
        // commutative operation, so the iteration order of `ctx.domains` cannot change
        // the key.
        let mut domain_hash = 0u64;
        for (name, domain) in &ctx.domains {
            domain_hash = domain_hash.wrapping_add(hash_one(&(name, &domain.cardinality)));
        }

        CacheKey {
            expr_hash,
            config_hash,
            domain_hash,
        }
    }
}
/// A compiled graph stored in the cache, together with how often it has been served.
#[derive(Clone)]
struct CachedResult {
    /// The compiled graph; a clone of it is returned on every cache hit.
graph: EinsumGraph,
    /// Number of cache hits served from this entry. When the cache is full, the
    /// entry with the lowest count is the one evicted.
hit_count: usize,
}
/// Cache of compiled `EinsumGraph`s keyed by expression, compilation config and domains.
///
/// State lives behind mutexes, so all methods take `&self`. Once `max_size` entries
/// are stored, inserting a new entry evicts the entry with the lowest hit count.
pub struct CompilationCache {
    /// Cached graphs by key.
    // NOTE(review): this struct is not `Clone` in the visible code, so the `Arc`
    // wrappers are not shared anywhere in this file — confirm whether they are needed.
cache: Arc<Mutex<HashMap<CacheKey, CachedResult>>>,
    /// Capacity used to trigger eviction.
max_size: usize,
    /// Lookup counters. Whenever both locks are held, `cache` is locked first.
stats: Arc<Mutex<CacheStats>>,
}
/// Snapshot of the counters kept by a compilation cache.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups that had to invoke the compiler.
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Number of entries stored at the time of the last update.
    pub current_size: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, or `0.0` when no lookups were recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.total_lookups() {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }

    /// Total number of recorded lookups (hits plus misses).
    pub fn total_lookups(&self) -> u64 {
        let Self { hits, misses, .. } = self;
        hits + misses
    }
}
impl CompilationCache {
    /// Creates an empty cache that retains at most `max_size` compiled graphs.
    ///
    /// A `max_size` of `0` disables storage: every call to
    /// [`get_or_compile`](Self::get_or_compile) compiles, and nothing is retained.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            max_size,
            stats: Arc::new(Mutex::new(CacheStats::default())),
        }
    }

    /// Creates an empty cache with capacity for 1000 entries.
    pub fn default_size() -> Self {
        Self::new(1000)
    }

    /// Returns the maximum number of entries the cache retains.
    pub fn max_size(&self) -> usize {
        self.max_size
    }

    /// Returns the graph for `expr`, calling `compile_fn` only on a cache miss.
    ///
    /// The lookup key combines `expr`, `ctx.config`, and the domain names and
    /// cardinalities registered in `ctx`, captured *before* `compile_fn` runs. On a hit,
    /// a clone of the cached graph is returned and `compile_fn` is not called. Any
    /// changes it would have made to `ctx` are therefore not applied.
    ///
    /// When the cache is full, inserting a new entry first evicts the entry with the
    /// lowest hit count.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `compile_fn`. Failed compilations count as misses
    /// and are not cached.
    ///
    /// # Panics
    ///
    /// Panics if an internal mutex has been poisoned.
    pub fn get_or_compile<F>(
        &self,
        expr: &TLExpr,
        ctx: &mut CompilerContext,
        compile_fn: F,
    ) -> Result<EinsumGraph>
    where
        F: FnOnce(&TLExpr, &mut CompilerContext) -> Result<EinsumGraph>,
    {
        let key = CacheKey::new(expr, &ctx.config, ctx);

        {
            let mut cache = self.cache.lock().unwrap();
            if let Some(cached) = cache.get_mut(&key) {
                cached.hit_count += 1;
                // Locks are always taken in the order `cache` -> `stats`.
                self.stats.lock().unwrap().hits += 1;
                return Ok(cached.graph.clone());
            }
        }

        self.stats.lock().unwrap().misses += 1;
        // Compile without holding the cache lock so concurrent lookups are not blocked.
        // As a consequence, two threads that miss on the same key may both compile it.
        let graph = compile_fn(expr, ctx)?;

        if self.max_size == 0 {
            // Caching is disabled; storing the entry would exceed the configured capacity.
            return Ok(graph);
        }

        let mut cache = self.cache.lock().unwrap();
        // Another thread may have cached this key while we were compiling. In that case
        // keep its entry (and hit count) instead of evicting an unrelated entry to make
        // room for a replacement that would not grow the map anyway.
        if !cache.contains_key(&key) {
            if cache.len() >= self.max_size {
                let victim = cache
                    .iter()
                    .min_by_key(|(_, entry)| entry.hit_count)
                    .map(|(k, _)| k.clone());
                if let Some(victim) = victim {
                    cache.remove(&victim);
                    self.stats.lock().unwrap().evictions += 1;
                }
            }
            cache.insert(
                key,
                CachedResult {
                    graph: graph.clone(),
                    hit_count: 0,
                },
            );
        }
        self.stats.lock().unwrap().current_size = cache.len();
        Ok(graph)
    }

    /// Returns a snapshot of the current counters.
    pub fn stats(&self) -> CacheStats {
        self.stats.lock().unwrap().clone()
    }

    /// Removes every cached entry and resets `current_size` to zero.
    ///
    /// The hit, miss, and eviction counters are left unchanged.
    pub fn clear(&self) {
        let mut cache = self.cache.lock().unwrap();
        cache.clear();
        self.stats.lock().unwrap().current_size = 0;
    }

    /// Returns the number of cached entries.
    pub fn len(&self) -> usize {
        self.cache.lock().unwrap().len()
    }

    /// Returns `true` if no entries are cached.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl Default for CompilationCache {
fn default() -> Self {
Self::default_size()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compile_to_einsum_with_context;
    use tensorlogic_ir::Term;

    /// Compiles `expr` through `cache`, failing the test if compilation errors.
    fn compile_cached(
        cache: &CompilationCache,
        expr: &TLExpr,
        ctx: &mut CompilerContext,
    ) -> EinsumGraph {
        cache
            .get_or_compile(expr, ctx, compile_to_einsum_with_context)
            .expect("compilation should succeed")
    }

    #[test]
    fn test_cache_new() {
        let cache = CompilationCache::new(100);
        assert_eq!(cache.max_size(), 100);
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_hit() {
        let cache = CompilationCache::new(100);
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        let graph1 = compile_cached(&cache, &expr, &mut ctx);
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        let graph2 = compile_cached(&cache, &expr, &mut ctx);
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.hit_rate(), 0.5);
        assert_eq!(graph1, graph2);
    }

    #[test]
    fn test_cache_different_expressions() {
        let cache = CompilationCache::new(100);
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        let expr1 = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let expr2 = TLExpr::pred("likes", vec![Term::var("x"), Term::var("y")]);

        compile_cached(&cache, &expr1, &mut ctx);
        compile_cached(&cache, &expr2, &mut ctx);

        let stats = cache.stats();
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.hits, 0);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = CompilationCache::new(2);
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        let expr1 = TLExpr::pred("p1", vec![Term::var("x")]);
        let expr2 = TLExpr::pred("p2", vec![Term::var("x")]);
        let expr3 = TLExpr::pred("p3", vec![Term::var("x")]);

        compile_cached(&cache, &expr1, &mut ctx);
        compile_cached(&cache, &expr2, &mut ctx);
        compile_cached(&cache, &expr3, &mut ctx);

        let stats = cache.stats();
        assert_eq!(stats.evictions, 1);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_cache_clear() {
        let cache = CompilationCache::new(100);
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        compile_cached(&cache, &expr, &mut ctx);
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_stats() {
        let cache = CompilationCache::new(100);
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.evictions, 0);
        assert_eq!(stats.current_size, 0);
        assert_eq!(stats.hit_rate(), 0.0);
        assert_eq!(stats.total_lookups(), 0);
    }
}