use std::collections::HashMap;
use super::types::{MemoConfig, MemoStats, WhnfEntry, WhnfKey, WhnfMemo};
/// Hashes `data` with 64-bit FNV-1a.
///
/// Each byte is XOR-ed into the running state, which is then multiplied by the
/// FNV prime (wrapping on overflow). An empty slice hashes to the FNV offset basis.
/// This is a fast, deterministic, non-cryptographic hash.
pub fn hash_bytes(data: &[u8]) -> u64 {
    // 14_695_981_039_346_656_037 — the 64-bit FNV offset basis.
    const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
    // 1_099_511_628_211 — the 64-bit FNV prime.
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let mut state = FNV_OFFSET;
    for &byte in data {
        state ^= u64::from(byte);
        state = state.wrapping_mul(FNV_PRIME);
    }
    state
}
impl WhnfMemo {
    /// Creates an empty memo table governed by `config`.
    ///
    /// The hit, miss and eviction counters start at zero, and the environment
    /// version starts at `0`.
    pub fn new(config: MemoConfig) -> Self {
        WhnfMemo {
            entries: HashMap::new(),
            hits: 0,
            misses: 0,
            evictions: 0,
            env_version: 0,
            config,
            insert_order: Vec::new(),
        }
    }

    /// Returns the cached result hash for `expr_hash` under the current
    /// environment version, if any.
    ///
    /// On a hit, the entry's `access_count` is bumped (saturating) and `hits` is
    /// incremented. On a miss, `misses` is incremented and `None` is returned.
    pub fn lookup(&mut self, expr_hash: u64) -> Option<u64> {
        let key = WhnfKey {
            expr_hash,
            env_version: self.env_version,
        };
        if let Some(entry) = self.entries.get_mut(&key) {
            entry.access_count = entry.access_count.saturating_add(1);
            self.hits += 1;
            Some(entry.result_hash)
        } else {
            self.misses += 1;
            None
        }
    }

    /// Records that `expr_hash` reduces to `result_hash` in `steps` steps under
    /// the current environment version.
    ///
    /// - Results cheaper than `config.min_steps_to_cache` are not stored.
    /// - Re-inserting an existing key overwrites its result and step count, but
    ///   keeps its access count and original insertion position.
    /// - When `config.max_entries > 0` and the table is full, [`Self::evict_cold`]
    ///   runs before the new entry is added. A `max_entries` of `0` means no
    ///   capacity limit.
    pub fn insert(&mut self, expr_hash: u64, result_hash: u64, steps: u32) {
        if steps < self.config.min_steps_to_cache {
            return;
        }
        let key = WhnfKey {
            expr_hash,
            env_version: self.env_version,
        };
        if let Some(entry) = self.entries.get_mut(&key) {
            entry.result_hash = result_hash;
            entry.reduction_steps = steps;
            return;
        }
        let max = self.config.max_entries;
        if max > 0 && self.entries.len() >= max {
            self.evict_cold();
        }
        self.entries.insert(
            key,
            WhnfEntry {
                result_hash,
                reduction_steps: steps,
                access_count: 0,
            },
        );
        self.insert_order.push(key);
    }

    /// Drops every cached entry and bumps the environment version (wrapping).
    ///
    /// Keys embed the environment version, so results computed before this call
    /// can never be returned afterwards. The hit, miss and eviction counters are
    /// left untouched.
    pub fn invalidate_all(&mut self) {
        self.entries.clear();
        self.insert_order.clear();
        self.env_version = self.env_version.wrapping_add(1);
    }

    /// Evicts "cold" entries, falling back to evicting the oldest entry.
    ///
    /// The cold threshold is `max_entries * eviction_threshold`. It is converted
    /// with `as u32`, so it truncates toward zero and saturates. Every entry whose
    /// `access_count` is at or below the threshold is removed.
    ///
    /// If no entry is cold, the earliest-inserted entry still present is removed
    /// instead (FIFO). Each removed entry is added to `evictions`.
    pub fn evict_cold(&mut self) {
        let threshold =
            (self.config.max_entries as f64 * self.config.eviction_threshold) as u32;

        // One linear pass over the map. This avoids building a list of cold
        // keys and then scanning it once per `insert_order` element.
        let before = self.entries.len();
        self.entries.retain(|_, e| e.access_count > threshold);
        let removed = before - self.entries.len();

        if removed > 0 {
            // Keep `insert_order` in step with `entries` (O(n) hash lookups).
            let entries = &self.entries;
            self.insert_order.retain(|k| entries.contains_key(k));
            self.evictions += removed as u64;
        } else if !self.insert_order.is_empty() {
            // FIFO fallback. `Vec::remove(0)` is O(n), but it runs at most once
            // per call.
            let oldest = self.insert_order.remove(0);
            self.entries.remove(&oldest);
            self.evictions += 1;
        }
    }

    /// Returns a snapshot of the counters, current size and environment version.
    ///
    /// `hit_rate` is `hits / (hits + misses)`, or `0.0` when no lookups have
    /// happened yet.
    pub fn stats(&self) -> MemoStats {
        let total = self.hits + self.misses;
        let hit_rate = if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        };
        MemoStats {
            hits: self.hits,
            misses: self.misses,
            evictions: self.evictions,
            hit_rate,
            size: self.entries.len(),
            env_version: self.env_version,
        }
    }
}
/// Returns the memoized result for `key`, computing and caching it on a miss.
///
/// `compute` is only invoked when `memo` has no entry for `key`. It yields
/// `(result_hash, steps)`. The result is cached only when `steps` meets both the
/// memo's configured `min_steps_to_cache` and the caller-supplied `min_steps`.
/// The computed `result_hash` is returned either way.
pub fn with_memo<F>(memo: &mut WhnfMemo, key: u64, min_steps: u32, compute: F) -> u64
where
    F: FnOnce() -> (u64, u32),
{
    match memo.lookup(key) {
        Some(hit) => hit,
        None => {
            let (fresh, cost) = compute();
            let too_cheap = cost < memo.config.min_steps_to_cache || cost < min_steps;
            if !too_cheap {
                memo.insert(key, fresh, cost);
            }
            fresh
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Memo built from `MemoConfig::default()`.
    fn default_memo() -> WhnfMemo {
        WhnfMemo::new(MemoConfig::default())
    }

    /// Memo bounded at `n` entries that caches anything taking at least one step.
    ///
    /// `eviction_threshold: 0.0` makes the cold-access threshold zero, so only
    /// entries that were never looked up count as cold.
    fn memo_with_capacity(n: usize) -> WhnfMemo {
        WhnfMemo::new(MemoConfig {
            max_entries: n,
            min_steps_to_cache: 1,
            eviction_threshold: 0.0,
        })
    }

    #[test]
    fn test_hash_bytes_empty() {
        let h = hash_bytes(&[]);
        assert_eq!(h, 14_695_981_039_346_656_037);
    }

    #[test]
    fn test_hash_bytes_deterministic() {
        let h1 = hash_bytes(b"hello");
        let h2 = hash_bytes(b"hello");
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_hash_bytes_distinct() {
        let ha = hash_bytes(b"Nat.add");
        let hb = hash_bytes(b"List.map");
        assert_ne!(ha, hb);
    }

    #[test]
    fn test_new_memo_empty() {
        let m = default_memo();
        assert_eq!(m.hits, 0);
        assert_eq!(m.misses, 0);
        assert_eq!(m.evictions, 0);
        assert_eq!(m.env_version, 0);
        assert!(m.entries.is_empty());
    }

    #[test]
    fn test_miss_on_empty() {
        let mut m = default_memo();
        assert_eq!(m.lookup(42), None);
        assert_eq!(m.misses, 1);
    }

    #[test]
    fn test_hit_after_insert() {
        let mut m = default_memo();
        m.insert(100, 200, 5);
        let result = m.lookup(100);
        assert_eq!(result, Some(200));
        assert_eq!(m.hits, 1);
    }

    #[test]
    fn test_insert_below_min_steps_not_cached() {
        let mut m = WhnfMemo::new(MemoConfig {
            max_entries: 64,
            min_steps_to_cache: 3,
            eviction_threshold: 0.1,
        });
        m.insert(1, 999, 2);
        assert_eq!(m.lookup(1), None);
    }

    #[test]
    fn test_insert_at_min_steps_cached() {
        let mut m = WhnfMemo::new(MemoConfig {
            max_entries: 64,
            min_steps_to_cache: 3,
            eviction_threshold: 0.1,
        });
        m.insert(1, 999, 3);
        assert_eq!(m.lookup(1), Some(999));
    }

    #[test]
    fn test_insert_overwrite() {
        let mut m = default_memo();
        m.insert(1, 100, 5);
        m.insert(1, 200, 7);
        assert_eq!(m.lookup(1), Some(200));
    }

    #[test]
    fn test_invalidate_all_clears_entries() {
        let mut m = default_memo();
        m.insert(1, 10, 5);
        m.invalidate_all();
        assert!(m.entries.is_empty());
    }

    #[test]
    fn test_invalidate_all_bumps_version() {
        let mut m = default_memo();
        assert_eq!(m.env_version, 0);
        m.invalidate_all();
        assert_eq!(m.env_version, 1);
        m.invalidate_all();
        assert_eq!(m.env_version, 2);
    }

    #[test]
    fn test_invalidate_all_prior_entries_miss() {
        let mut m = default_memo();
        m.insert(42, 99, 5);
        m.invalidate_all();
        assert_eq!(m.lookup(42), None);
    }

    #[test]
    fn test_insert_after_invalidate_uses_new_version() {
        let mut m = default_memo();
        m.insert(42, 1, 5);
        m.invalidate_all();
        m.insert(42, 2, 5);
        assert_eq!(m.lookup(42), Some(2));
    }

    #[test]
    fn test_evict_cold_removes_unaccessed() {
        let mut m = memo_with_capacity(4);
        m.insert(1, 10, 5);
        m.insert(2, 20, 5);
        let _ = m.lookup(2);
        let _ = m.lookup(2);
        m.evict_cold();
        assert_eq!(m.lookup(1), None);
        assert_eq!(m.lookup(2), Some(20));
        assert!(m.evictions > 0);
    }

    #[test]
    fn test_evict_cold_fifo_fallback() {
        // Both entries are accessed once, so neither is cold at threshold 0
        // and exactly the oldest one (key 1) must be evicted.
        let mut m = WhnfMemo::new(MemoConfig {
            max_entries: 2,
            min_steps_to_cache: 1,
            eviction_threshold: 0.0,
        });
        m.insert(1, 10, 5);
        m.insert(2, 20, 5);
        let _ = m.lookup(1);
        let _ = m.lookup(2);
        m.evict_cold();
        assert_eq!(m.evictions, 1);
        assert_eq!(m.entries.len(), 1);
        assert_eq!(m.lookup(1), None);
        assert_eq!(m.lookup(2), Some(20));
    }

    #[test]
    fn test_evict_cold_keeps_insert_order_in_sync() {
        // After a cold sweep, the FIFO fallback must pick the oldest *surviving*
        // entry, not a key that was already removed.
        let mut m = memo_with_capacity(4);
        m.insert(1, 10, 5);
        m.insert(2, 20, 5);
        m.insert(3, 30, 5);
        let _ = m.lookup(2);
        let _ = m.lookup(3);
        m.evict_cold();
        assert_eq!(m.evictions, 1);
        assert_eq!(m.insert_order.len(), 2);
        m.evict_cold();
        assert_eq!(m.evictions, 2);
        assert_eq!(m.lookup(2), None);
        assert_eq!(m.lookup(3), Some(30));
    }

    #[test]
    fn test_stats_zero() {
        let m = default_memo();
        let s = m.stats();
        assert_eq!(s.hits, 0);
        assert_eq!(s.misses, 0);
        assert_eq!(s.hit_rate, 0.0);
        assert_eq!(s.size, 0);
        assert_eq!(s.env_version, 0);
    }

    #[test]
    fn test_stats_hit_rate() {
        let mut m = default_memo();
        m.insert(1, 10, 5);
        let _ = m.lookup(1);
        let _ = m.lookup(2);
        let s = m.stats();
        assert_eq!(s.hits, 1);
        assert_eq!(s.misses, 1);
        assert!((s.hit_rate - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_stats_display() {
        let m = default_memo();
        let s = m.stats();
        let text = format!("{}", s);
        assert!(text.contains("MemoStats"));
    }

    #[test]
    fn test_with_memo_miss_calls_compute() {
        let mut m = default_memo();
        let mut called = false;
        let result = with_memo(&mut m, 42, 0, || {
            called = true;
            (99, 5)
        });
        assert!(called);
        assert_eq!(result, 99);
    }

    #[test]
    fn test_with_memo_hit_skips_compute() {
        let mut m = default_memo();
        m.insert(42, 99, 5);
        let mut called = false;
        let result = with_memo(&mut m, 42, 0, || {
            called = true;
            (0, 10)
        });
        assert!(!called);
        assert_eq!(result, 99);
    }

    #[test]
    fn test_with_memo_stores_result() {
        let mut m = default_memo();
        let _ = with_memo(&mut m, 77, 0, || (55, 5));
        let mut called = false;
        let result = with_memo(&mut m, 77, 0, || {
            called = true;
            (0, 5)
        });
        assert!(!called);
        assert_eq!(result, 55);
    }

    #[test]
    fn test_with_memo_respects_min_steps() {
        let mut m = WhnfMemo::new(MemoConfig {
            max_entries: 64,
            min_steps_to_cache: 10,
            eviction_threshold: 0.1,
        });
        let _ = with_memo(&mut m, 5, 0, || (7, 3));
        assert_eq!(m.lookup(5), None);
    }

    #[test]
    fn test_with_memo_min_steps_override() {
        let mut m = WhnfMemo::new(MemoConfig {
            max_entries: 64,
            min_steps_to_cache: 2,
            eviction_threshold: 0.1,
        });
        let _ = with_memo(&mut m, 5, 10, || (7, 3));
        assert_eq!(m.lookup(5), None);
    }

    #[test]
    fn test_capacity_triggers_eviction() {
        let mut m = memo_with_capacity(2);
        m.insert(1, 10, 5);
        m.insert(2, 20, 5);
        m.insert(3, 30, 5);
        assert!(m.entries.len() <= 2);
        assert!(m.evictions > 0);
        // The entry that triggered eviction must itself be stored.
        assert_eq!(m.lookup(3), Some(30));
    }

    #[test]
    fn test_zero_capacity_no_panic() {
        // `max_entries == 0` disables the capacity check, so the entry is kept.
        let mut m = WhnfMemo::new(MemoConfig {
            max_entries: 0,
            min_steps_to_cache: 1,
            eviction_threshold: 0.0,
        });
        m.insert(1, 99, 5);
        assert_eq!(m.lookup(1), Some(99));
    }

    #[test]
    fn test_eviction_count_tracked() {
        let mut m = memo_with_capacity(1);
        m.insert(1, 10, 5);
        m.insert(2, 20, 5);
        assert!(m.evictions >= 1);
    }
}