use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;
use std::time::{Duration, Instant};
/// Cache key identifying one memoized evaluation.
///
/// Combines a fingerprint of the expression with a hash of the inputs, so the
/// same expression evaluated on different inputs maps to different keys.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct MemoKey {
    /// Fingerprint of the expression whose result is cached.
    pub expr_fingerprint: u64,
    /// Hash of the input values (`0` when the key is built from the expression alone).
    pub input_hash: u64,
}
impl MemoKey {
pub fn new(expr_fingerprint: u64, input_hash: u64) -> Self {
Self {
expr_fingerprint,
input_hash,
}
}
pub fn from_expr(expr: &tensorlogic_ir::TLExpr) -> Self {
let fp = tensorlogic_ir::expr_fingerprint(expr);
Self::new(fp, 0)
}
pub fn from_expr_and_hash(expr: &tensorlogic_ir::TLExpr, input_hash: u64) -> Self {
let fp = tensorlogic_ir::expr_fingerprint(expr);
Self::new(fp, input_hash)
}
pub fn hash_inputs(inputs: &[f64]) -> u64 {
let mut state: u64 = 14_695_981_039_346_656_037;
for &v in inputs {
let bits = v.to_bits();
for byte_idx in 0..8u64 {
let byte = (bits >> (byte_idx * 8)) & 0xFF;
state ^= byte;
state = state.wrapping_mul(1_099_511_628_211);
}
}
state
}
}
/// Strategy used to choose a victim when the cache is full.
#[derive(Clone, Debug)]
pub enum MemoEvictionPolicy {
    /// Evict the least recently used entry; a cache hit marks its entry as recently used.
    Lru,
    /// Evict the entry that was inserted first, regardless of later accesses.
    Fifo,
    /// Evict an expired entry if one exists, otherwise the oldest insertion.
    ///
    /// The carried duration is used as the time-to-live when `MemoConfig::ttl`
    /// is `None`.
    Ttl(Duration),
}
/// Tunable settings for a `MemoCache`.
#[derive(Clone, Debug)]
pub struct MemoConfig {
    /// Capacity threshold: inserting a new key once this many entries are
    /// stored triggers an eviction first.
    pub max_entries: usize,
    /// Optional time-to-live measured from insertion. When set, it takes
    /// precedence over the duration carried by `MemoEvictionPolicy::Ttl`.
    pub ttl: Option<Duration>,
    /// Policy that decides which entry is evicted when the cache is full.
    pub eviction: MemoEvictionPolicy,
}
impl Default for MemoConfig {
fn default() -> Self {
Self {
max_entries: 1024,
ttl: None,
eviction: MemoEvictionPolicy::Lru,
}
}
}
/// Running counters describing how a cache has been used.
#[derive(Clone, Debug, Default)]
pub struct MemoStats {
    /// Lookups that returned a live value.
    pub hits: u64,
    /// Lookups for keys that were not present.
    pub misses: u64,
    /// Entries removed to make room for a new insertion.
    pub evictions: u64,
    /// Lookups that found an entry past its time-to-live (the entry is dropped).
    pub expired_on_access: u64,
    /// Number of entries stored at the time of the last mutation.
    pub current_entries: usize,
}
impl MemoStats {
pub fn hit_rate(&self) -> f64 {
let total = self.total_lookups();
if total == 0 {
0.0
} else {
self.hits as f64 / total as f64
}
}
pub fn total_lookups(&self) -> u64 {
self.hits + self.misses + self.expired_on_access
}
pub fn summary(&self) -> String {
format!(
"MemoCache: entries={} hits={} misses={} expired={} evictions={} hit_rate={:.1}%",
self.current_entries,
self.hits,
self.misses,
self.expired_on_access,
self.evictions,
self.hit_rate() * 100.0,
)
}
}
/// Outcome of looking a key up in the cache.
#[derive(Clone, Debug)]
pub enum MemoLookupResult<V> {
    /// The key was present and still valid; carries a clone of the stored value.
    Hit(V),
    /// The key was not present.
    Miss,
    /// The key was present but past its time-to-live; the entry has been removed.
    Expired,
}
/// A stored value together with its bookkeeping metadata.
#[derive(Clone, Debug)]
struct MemoEntry<V> {
    /// The memoized value.
    value: V,
    /// When the entry was stored; time-to-live is measured from here.
    inserted_at: Instant,
    /// When the entry was last read or overwritten.
    last_accessed: Instant,
    /// Number of times the entry has been stored, read or overwritten.
    access_count: u64,
}
/// Bounded memoization cache keyed by [`MemoKey`].
///
/// Values live in a hash map; a separate queue records key order so that the
/// front of the queue is the next eviction candidate.
pub struct MemoCache<V>
where
    V: Clone,
{
    /// Stored values and their metadata.
    entries: HashMap<MemoKey, MemoEntry<V>>,
    /// Keys in eviction order, next victim at the front.
    insertion_order: VecDeque<MemoKey>,
    /// Capacity, time-to-live and eviction settings.
    config: MemoConfig,
    /// Usage counters.
    stats: MemoStats,
}
impl<V: Clone + std::fmt::Debug> MemoCache<V> {
pub fn new(config: MemoConfig) -> Self {
let max = config.max_entries;
Self {
entries: HashMap::with_capacity(max.min(1024)),
insertion_order: VecDeque::with_capacity(max.min(1024)),
config,
stats: MemoStats::default(),
}
}
pub fn with_default() -> Self {
Self::new(MemoConfig::default())
}
pub fn with_max_entries(max: usize) -> Self {
Self::new(MemoConfig {
max_entries: max,
..MemoConfig::default()
})
}
pub fn get(&mut self, key: &MemoKey) -> MemoLookupResult<V> {
if !self.entries.contains_key(key) {
self.stats.misses += 1;
return MemoLookupResult::Miss;
}
if self.is_expired_by_key(key) {
self.entries.remove(key);
self.insertion_order.retain(|k| k != key);
self.stats.current_entries = self.entries.len();
self.stats.expired_on_access += 1;
return MemoLookupResult::Expired;
}
if let Some(entry) = self.entries.get_mut(key) {
entry.last_accessed = Instant::now();
entry.access_count += 1;
let value = entry.value.clone();
self.update_lru(key);
self.stats.hits += 1;
MemoLookupResult::Hit(value)
} else {
self.stats.misses += 1;
MemoLookupResult::Miss
}
}
pub fn insert(&mut self, key: MemoKey, value: V) {
if self.entries.contains_key(&key) {
if let Some(entry) = self.entries.get_mut(&key) {
entry.value = value;
entry.last_accessed = Instant::now();
entry.access_count += 1;
}
return;
}
if self.entries.len() >= self.config.max_entries {
self.evict_one();
}
let now = Instant::now();
let entry = MemoEntry {
value,
inserted_at: now,
last_accessed: now,
access_count: 1,
};
self.entries.insert(key.clone(), entry);
self.insertion_order.push_back(key);
self.stats.current_entries = self.entries.len();
}
pub fn invalidate(&mut self, key: &MemoKey) -> bool {
let removed = self.entries.remove(key).is_some();
if removed {
self.insertion_order.retain(|k| k != key);
self.stats.current_entries = self.entries.len();
}
removed
}
pub fn clear(&mut self) {
self.entries.clear();
self.insertion_order.clear();
self.stats.current_entries = 0;
}
pub fn stats(&self) -> &MemoStats {
&self.stats
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
fn is_expired_by_key(&self, key: &MemoKey) -> bool {
let ttl = match (&self.config.ttl, &self.config.eviction) {
(Some(d), _) => Some(*d),
(None, MemoEvictionPolicy::Ttl(d)) => Some(*d),
_ => None,
};
if let Some(duration) = ttl {
if let Some(entry) = self.entries.get(key) {
return entry.inserted_at.elapsed() > duration;
}
}
false
}
fn is_expired(&self, entry: &MemoEntry<V>) -> bool {
let ttl = match (&self.config.ttl, &self.config.eviction) {
(Some(d), _) => Some(*d),
(None, MemoEvictionPolicy::Ttl(d)) => Some(*d),
_ => None,
};
ttl.map(|d| entry.inserted_at.elapsed() > d)
.unwrap_or(false)
}
fn evict_one(&mut self) {
let key_to_remove = match &self.config.eviction {
MemoEvictionPolicy::Lru => self.find_lru_key(),
MemoEvictionPolicy::Fifo => self.find_fifo_key(),
MemoEvictionPolicy::Ttl(_) => {
self.find_expired_key().or_else(|| self.find_fifo_key())
}
};
if let Some(key) = key_to_remove {
self.entries.remove(&key);
self.insertion_order.retain(|k| k != &key);
self.stats.evictions += 1;
self.stats.current_entries = self.entries.len();
}
}
fn find_lru_key(&self) -> Option<MemoKey> {
self.insertion_order.front().cloned()
}
fn find_fifo_key(&self) -> Option<MemoKey> {
self.insertion_order.front().cloned()
}
fn find_expired_key(&self) -> Option<MemoKey> {
self.entries
.iter()
.find(|(_, e)| self.is_expired(e))
.map(|(k, _)| k.clone())
}
fn update_lru(&mut self, key: &MemoKey) {
if matches!(self.config.eviction, MemoEvictionPolicy::Lru) {
if let Some(pos) = self.insertion_order.iter().position(|k| k == key) {
self.insertion_order.remove(pos);
self.insertion_order.push_back(key.clone());
}
}
}
}
/// A [`MemoCache`] whose cached values are `ndarray::ArrayD<f64>` arrays.
pub type ExprMemoCache = MemoCache<ndarray::ArrayD<f64>>;
/// Step-by-step builder for a `MemoCache` holding values of type `V`.
pub struct MemoCacheBuilder<V>
where
    V: Clone + std::fmt::Debug,
{
    /// Configuration accumulated so far.
    config: MemoConfig,
    /// Ties the builder to the value type without storing a value.
    _phantom: PhantomData<V>,
}
impl<V: Clone + std::fmt::Debug> MemoCacheBuilder<V> {
pub fn new() -> Self {
Self {
config: MemoConfig::default(),
_phantom: PhantomData,
}
}
pub fn max_entries(mut self, max: usize) -> Self {
self.config.max_entries = max;
self
}
pub fn ttl(mut self, duration: Duration) -> Self {
self.config.ttl = Some(duration);
self
}
pub fn eviction(mut self, policy: MemoEvictionPolicy) -> Self {
self.config.eviction = policy;
self
}
pub fn build(self) -> MemoCache<V> {
MemoCache::new(self.config)
}
}
impl<V: Clone + std::fmt::Debug> Default for MemoCacheBuilder<V> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use tensorlogic_ir::{TLExpr, Term};

    /// Key with the given fingerprint and a zero input hash.
    fn fp_key(fingerprint: u64) -> MemoKey {
        MemoKey::new(fingerprint, 0)
    }

    /// Cache of `i32` values with the default configuration.
    fn default_cache() -> MemoCache<i32> {
        MemoCache::with_default()
    }

    /// Cache of `i32` values holding at most two entries under `eviction`.
    fn two_slot_cache(eviction: MemoEvictionPolicy) -> MemoCache<i32> {
        MemoCache::new(MemoConfig {
            max_entries: 2,
            ttl: None,
            eviction,
        })
    }

    fn make_expr_a() -> TLExpr {
        TLExpr::pred("foo", vec![Term::var("x")])
    }

    fn make_expr_b() -> TLExpr {
        TLExpr::pred("bar", vec![Term::var("y")])
    }

    #[test]
    fn test_memo_key_equality() {
        let first = MemoKey::new(42, 99);
        let same = MemoKey::new(42, 99);
        let other_hash = MemoKey::new(42, 100);
        assert_eq!(first, same);
        assert_ne!(first, other_hash);
    }

    #[test]
    fn test_memo_key_from_expr() {
        let expr = make_expr_a();
        let key = MemoKey::from_expr(&expr);
        assert_eq!(key.input_hash, 0);
        // Fingerprinting the same expression twice must agree.
        let again = MemoKey::from_expr(&expr);
        assert_eq!(key.expr_fingerprint, again.expr_fingerprint);
    }

    #[test]
    fn test_memo_key_hash_inputs_consistent() {
        let inputs = vec![1.0_f64, 2.0, 3.0];
        let first = MemoKey::hash_inputs(&inputs);
        let second = MemoKey::hash_inputs(&inputs);
        assert_eq!(first, second);
    }

    #[test]
    fn test_memo_key_hash_inputs_different() {
        let base = MemoKey::hash_inputs(&[1.0, 2.0, 3.0]);
        let changed = MemoKey::hash_inputs(&[1.0, 2.0, 4.0]);
        assert_ne!(base, changed);
    }

    #[test]
    fn test_memo_cache_miss_on_empty() {
        let mut cache = default_cache();
        let key = fp_key(1);
        assert!(matches!(cache.get(&key), MemoLookupResult::Miss));
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_memo_cache_hit_after_insert() {
        let mut cache = default_cache();
        let key = fp_key(7);
        cache.insert(key.clone(), 42);
        assert!(matches!(cache.get(&key), MemoLookupResult::Hit(42)));
        assert_eq!(cache.stats().hits, 1);
    }

    #[test]
    fn test_memo_cache_hit_rate_zero_initially() {
        let cache = default_cache();
        assert_eq!(cache.stats().hit_rate(), 0.0);
    }

    #[test]
    fn test_memo_cache_hit_rate_after_hit() {
        let mut cache = default_cache();
        let key = fp_key(1);
        cache.insert(key.clone(), 10);
        // One hit followed by one miss.
        cache.get(&key);
        cache.get(&fp_key(2));
        let rate = cache.stats().hit_rate();
        assert!((rate - 0.5).abs() < 1e-9, "expected 0.5, got {rate}");
    }

    #[test]
    fn test_memo_cache_lru_evicts_oldest_access() {
        let mut cache = two_slot_cache(MemoEvictionPolicy::Lru);
        let (k1, k2, k3) = (fp_key(1), fp_key(2), fp_key(3));
        cache.insert(k1.clone(), 1);
        cache.insert(k2.clone(), 2);
        // Touching k1 leaves k2 as the least recently used entry.
        cache.get(&k1);
        cache.insert(k3.clone(), 3);
        assert!(matches!(cache.get(&k1), MemoLookupResult::Hit(1)));
        assert!(matches!(cache.get(&k2), MemoLookupResult::Miss));
        assert!(matches!(cache.get(&k3), MemoLookupResult::Hit(3)));
        assert!(cache.stats().evictions >= 1);
    }

    #[test]
    fn test_memo_cache_fifo_evicts_first_inserted() {
        let mut cache = two_slot_cache(MemoEvictionPolicy::Fifo);
        let (k1, k2, k3) = (fp_key(1), fp_key(2), fp_key(3));
        cache.insert(k1.clone(), 10);
        cache.insert(k2.clone(), 20);
        // Under FIFO a hit does not protect k1 from eviction.
        cache.get(&k1);
        cache.insert(k3.clone(), 30);
        assert!(matches!(cache.get(&k1), MemoLookupResult::Miss));
        assert!(matches!(cache.get(&k2), MemoLookupResult::Hit(20)));
        assert!(matches!(cache.get(&k3), MemoLookupResult::Hit(30)));
    }

    #[test]
    fn test_memo_cache_ttl_expires_entry() {
        let ttl = Duration::from_millis(10);
        let mut cache: MemoCache<i32> = MemoCache::new(MemoConfig {
            max_entries: 16,
            ttl: Some(ttl),
            eviction: MemoEvictionPolicy::Ttl(ttl),
        });
        let key = fp_key(99);
        cache.insert(key.clone(), 55);
        assert!(matches!(cache.get(&key), MemoLookupResult::Hit(55)));
        // Wait past the time-to-live before looking the key up again.
        thread::sleep(Duration::from_millis(20));
        assert!(matches!(cache.get(&key), MemoLookupResult::Expired));
        assert_eq!(cache.stats().expired_on_access, 1);
    }

    #[test]
    fn test_memo_cache_invalidate_key() {
        let mut cache = default_cache();
        let key = fp_key(5);
        cache.insert(key.clone(), 77);
        assert!(cache.invalidate(&key));
        // A second invalidation finds nothing to remove.
        assert!(!cache.invalidate(&key));
        assert!(matches!(cache.get(&key), MemoLookupResult::Miss));
    }

    #[test]
    fn test_memo_cache_clear() {
        let mut cache = default_cache();
        cache.insert(fp_key(1), 1);
        cache.insert(fp_key(2), 2);
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.stats().current_entries, 0);
    }

    #[test]
    fn test_memo_cache_len() {
        let mut cache = default_cache();
        assert_eq!(cache.len(), 0);
        cache.insert(fp_key(1), 10);
        assert_eq!(cache.len(), 1);
        cache.insert(fp_key(2), 20);
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_memo_stats_total_lookups() {
        let mut cache = default_cache();
        let key = fp_key(1);
        cache.insert(key.clone(), 1);
        // One hit and one miss make two lookups.
        cache.get(&key);
        cache.get(&fp_key(99));
        assert_eq!(cache.stats().total_lookups(), 2);
    }

    #[test]
    fn test_memo_stats_summary_nonempty() {
        let mut cache = default_cache();
        cache.insert(fp_key(1), 1);
        cache.get(&fp_key(1));
        let summary = cache.stats().summary();
        assert!(summary.contains("MemoCache"));
        assert!(summary.contains("hits=1"));
    }

    #[test]
    fn test_memo_lookup_result_variants() {
        let hit: MemoLookupResult<i32> = MemoLookupResult::Hit(42);
        let miss: MemoLookupResult<i32> = MemoLookupResult::Miss;
        let expired: MemoLookupResult<i32> = MemoLookupResult::Expired;
        assert!(matches!(hit, MemoLookupResult::Hit(42)));
        assert!(matches!(miss, MemoLookupResult::Miss));
        assert!(matches!(expired, MemoLookupResult::Expired));
    }

    #[test]
    fn test_memo_cache_builder_default() {
        let cache: MemoCache<i32> = MemoCacheBuilder::new().build();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_memo_cache_builder_custom_config() {
        let cache: MemoCache<i32> = MemoCacheBuilder::new()
            .max_entries(8)
            .ttl(Duration::from_secs(60))
            .eviction(MemoEvictionPolicy::Fifo)
            .build();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_expr_memo_cache_type_alias() {
        use ndarray::ArrayD;
        let mut cache: ExprMemoCache = MemoCache::with_default();
        let key = MemoKey::from_expr(&make_expr_a());
        let arr = ArrayD::<f64>::zeros(ndarray::IxDyn(&[2, 3]));
        cache.insert(key.clone(), arr.clone());
        assert!(matches!(cache.get(&key), MemoLookupResult::Hit(_)));
    }

    #[test]
    fn test_memo_key_from_expr_different_exprs() {
        let key_a = MemoKey::from_expr(&make_expr_a());
        let key_b = MemoKey::from_expr(&make_expr_b());
        assert_ne!(key_a.expr_fingerprint, key_b.expr_fingerprint);
    }

    #[test]
    fn test_memo_key_from_expr_and_hash() {
        let expr = make_expr_a();
        let input_hash = MemoKey::hash_inputs(&[1.0, 2.0]);
        let key = MemoKey::from_expr_and_hash(&expr, input_hash);
        assert_eq!(key.input_hash, input_hash);
        assert_ne!(key.input_hash, 0);
    }
}