use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

use super::{compile_with, CompileConfig, CompiledFn, JitError};
use crate::kernel::{ExprId, ExprPool};
/// Lookup key for [`CompileCache`]: which expression, over which inputs
/// (order is significant), compiled with which configuration.
type CacheKey = (
    ExprId,        // expression being compiled
    Vec<ExprId>,   // input symbols; a permuted list is a distinct key
    CompileConfig, // compilation options
);
/// Memoizing front end for the JIT compiler.
///
/// Each distinct (expression, ordered inputs, config) combination is
/// compiled at most once; later requests share the same `Arc`.
pub struct CompileCache {
    /// Number of requests answered from `store` without compiling.
    hits: u64,
    /// Number of requests that missed `store` and invoked the compiler.
    compiles: u64,
    /// Compiled functions, one per distinct key.
    store: HashMap<CacheKey, Arc<CompiledFn>>,
}
impl CompileCache {
pub fn new() -> Self {
Self {
store: HashMap::new(),
compiles: 0,
hits: 0,
}
}
pub fn compile(
&mut self,
expr: ExprId,
inputs: &[ExprId],
pool: &ExprPool,
) -> Result<Arc<CompiledFn>, JitError> {
self.compile_with(expr, inputs, pool, CompileConfig::default())
}
pub fn compile_with(
&mut self,
expr: ExprId,
inputs: &[ExprId],
pool: &ExprPool,
config: CompileConfig,
) -> Result<Arc<CompiledFn>, JitError> {
let key: CacheKey = (expr, inputs.to_vec(), config);
if let Some(cached) = self.store.get(&key) {
self.hits += 1;
return Ok(Arc::clone(cached));
}
self.compiles += 1;
let compiled = Arc::new(compile_with(expr, inputs, pool, config)?);
self.store.insert(key, Arc::clone(&compiled));
Ok(compiled)
}
pub fn len(&self) -> usize {
self.store.len()
}
pub fn is_empty(&self) -> bool {
self.store.is_empty()
}
pub fn contains(&self, expr: ExprId, inputs: &[ExprId]) -> bool {
self.contains_with(expr, inputs, CompileConfig::default())
}
pub fn contains_with(&self, expr: ExprId, inputs: &[ExprId], config: CompileConfig) -> bool {
self.store.contains_key(&(expr, inputs.to_vec(), config))
}
pub fn compile_count(&self) -> u64 {
self.compiles
}
pub fn hit_count(&self) -> u64 {
self.hits
}
pub fn hit_rate(&self) -> f64 {
let total = self.compiles + self.hits;
if total == 0 {
0.0
} else {
self.hits as f64 / total as f64
}
}
pub fn clear(&mut self) {
self.store.clear();
}
pub fn evict(&mut self, expr: ExprId, inputs: &[ExprId]) -> Option<Arc<CompiledFn>> {
self.evict_with(expr, inputs, CompileConfig::default())
}
pub fn evict_with(
&mut self,
expr: ExprId,
inputs: &[ExprId],
config: CompileConfig,
) -> Option<Arc<CompiledFn>> {
self.store.remove(&(expr, inputs.to_vec(), config))
}
}
impl Default for CompileCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh expression pool for a single test.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    /// Asserts that `actual` is within `1e-10` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn cache_miss_then_hit() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));
        let mut cache = CompileCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.compile_count(), 0);
        assert_eq!(cache.hit_count(), 0);

        let f1 = cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.compile_count(), 1);
        assert_eq!(cache.hit_count(), 0);

        let f2 = cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.compile_count(), 1);
        assert_eq!(cache.hit_count(), 1);
        assert!(Arc::ptr_eq(&f1, &f2));
    }

    #[test]
    fn cache_correct_result() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));
        let mut cache = CompileCache::new();
        let f = cache.compile(expr, &[x], &pool).unwrap();
        assert_close(f.call(&[3.0]), 9.0);
        assert_close(f.call(&[5.0]), 25.0);
    }

    #[test]
    fn different_var_order_different_entry() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let expr = pool.add(vec![x, y]);
        let mut cache = CompileCache::new();
        let f_xy = cache.compile(expr, &[x, y], &pool).unwrap();
        let f_yx = cache.compile(expr, &[y, x], &pool).unwrap();
        assert_eq!(cache.len(), 2);
        assert!(!Arc::ptr_eq(&f_xy, &f_yx));
        assert_close(f_xy.call(&[1.0, 2.0]), 3.0);
        assert_close(f_yx.call(&[1.0, 2.0]), 3.0);
    }

    #[test]
    fn different_exprs_different_entries() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let sq = pool.pow(x, pool.integer(2_i32));
        let cube = pool.pow(x, pool.integer(3_i32));
        let mut cache = CompileCache::new();
        let f_sq = cache.compile(sq, &[x], &pool).unwrap();
        let f_cu = cache.compile(cube, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 2);
        assert!(!Arc::ptr_eq(&f_sq, &f_cu));
        assert_close(f_sq.call(&[3.0]), 9.0);
        assert_close(f_cu.call(&[3.0]), 27.0);
    }

    #[test]
    fn arc_survives_cache_clear() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));
        let mut cache = CompileCache::new();
        let f = cache.compile(expr, &[x], &pool).unwrap();
        cache.clear();
        assert!(cache.is_empty());
        assert_close(f.call(&[4.0]), 16.0);
    }

    #[test]
    fn clear_keeps_statistics() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));
        let mut cache = CompileCache::new();
        cache.compile(expr, &[x], &pool).unwrap();
        cache.compile(expr, &[x], &pool).unwrap();
        cache.clear();
        assert_eq!(cache.compile_count(), 1);
        assert_eq!(cache.hit_count(), 1);

        // After clearing, the same key is a miss again.
        cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.compile_count(), 2);
        assert_eq!(cache.hit_count(), 1);
    }

    #[test]
    fn evict_removes_single_entry() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let sq = pool.pow(x, pool.integer(2_i32));
        let cube = pool.pow(x, pool.integer(3_i32));
        let mut cache = CompileCache::new();
        cache.compile(sq, &[x], &pool).unwrap();
        cache.compile(cube, &[x], &pool).unwrap();
        assert_eq!(cache.len(), 2);

        let evicted = cache.evict(sq, &[x]);
        assert!(evicted.is_some());
        assert_eq!(cache.len(), 1);
        assert!(!cache.contains(sq, &[x]));
        assert!(cache.contains(cube, &[x]));
    }

    #[test]
    fn evict_missing_key_returns_none() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));
        let mut cache = CompileCache::new();
        assert!(cache.evict(expr, &[x]).is_none());

        cache.compile(expr, &[x], &pool).unwrap();
        assert!(cache.evict(expr, &[x]).is_some());
        assert!(cache.evict(expr, &[x]).is_none());
        assert!(cache.is_empty());
    }

    #[test]
    fn evicted_entry_is_recompiled() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));
        let mut cache = CompileCache::new();
        let first = cache.compile(expr, &[x], &pool).unwrap();
        assert!(cache.evict(expr, &[x]).is_some());

        let second = cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.compile_count(), 2);
        assert_eq!(cache.hit_count(), 0);
        assert!(!Arc::ptr_eq(&first, &second));
        assert_close(second.call(&[3.0]), 9.0);
    }

    #[test]
    fn contains_checks_key() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let expr = pool.add(vec![x, y]);
        let mut cache = CompileCache::new();
        assert!(!cache.contains(expr, &[x, y]));
        cache.compile(expr, &[x, y], &pool).unwrap();
        assert!(cache.contains(expr, &[x, y]));
        assert!(!cache.contains(expr, &[y, x]));
    }

    #[test]
    fn hit_rate_is_correct() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.pow(x, pool.integer(2_i32));
        let mut cache = CompileCache::new();
        assert_eq!(cache.hit_rate(), 0.0);

        cache.compile(expr, &[x], &pool).unwrap();
        assert_eq!(cache.hit_rate(), 0.0);

        cache.compile(expr, &[x], &pool).unwrap();
        cache.compile(expr, &[x], &pool).unwrap();
        assert_close(cache.hit_rate(), 2.0 / 3.0);
    }

    #[test]
    fn default_is_empty() {
        let cache = CompileCache::default();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.compile_count(), 0);
        assert_eq!(cache.hit_count(), 0);
        assert_eq!(cache.hit_rate(), 0.0);
    }
}