use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
/// A stored mask, held behind an [`Arc`] so that cache hits hand out a
/// reference-counted clone instead of copying the mask contents.
struct CachedMask {
    mask: Arc<[bool]>,
}

/// Bounded cache of boolean masks keyed by a `u64` state hash.
///
/// When the cache is full, inserting a new key evicts the least recently
/// used entry. A successful [`get`](Self::get) counts as a use; inserting a
/// key that is already present does not.
pub struct AllowedTokensCache {
    /// Maximum number of entries retained; always at least 1.
    capacity: usize,
    /// Mask storage keyed by state hash.
    inner: HashMap<u64, CachedMask>,
    /// Keys ordered from least recently used (front) to most recently used (back).
    lru: VecDeque<u64>,
    /// Number of `get` calls that found an entry.
    hits: u64,
    /// Number of `get` calls that found nothing.
    misses: u64,
}

impl AllowedTokensCache {
    /// Upper bound on how many entries are pre-allocated by
    /// [`with_capacity`](Self::with_capacity). The logical capacity may be
    /// larger; storage then grows on demand.
    const MAX_PREALLOCATED_ENTRIES: usize = 1024;

    /// Creates an empty cache that retains at most `capacity` entries.
    ///
    /// A `capacity` of 0 is treated as 1. Pre-allocation is capped, so a very
    /// large `capacity` (for example `usize::MAX` to mean "effectively
    /// unbounded") neither panics nor reserves memory up front.
    pub fn with_capacity(capacity: usize) -> Self {
        let capacity = capacity.max(1);
        let prealloc = capacity.min(Self::MAX_PREALLOCATED_ENTRIES);
        Self {
            capacity,
            inner: HashMap::with_capacity(prealloc),
            lru: VecDeque::with_capacity(prealloc),
            hits: 0,
            misses: 0,
        }
    }

    /// Looks up the mask stored for `state_hash`.
    ///
    /// On a hit the key becomes the most recently used one, the hit counter
    /// is incremented, and a shared handle to the mask is returned. On a miss
    /// the miss counter is incremented and `None` is returned.
    pub fn get(&mut self, state_hash: u64) -> Option<Arc<[bool]>> {
        let mask = match self.inner.get(&state_hash) {
            Some(entry) => Arc::clone(&entry.mask),
            None => {
                self.misses += 1;
                return None;
            }
        };

        // Promote the key to the back of the recency queue. This is a linear
        // scan over the queue, whose length is bounded by `capacity`.
        if let Some(pos) = self.lru.iter().position(|&key| key == state_hash) {
            self.lru.remove(pos);
        }
        self.lru.push_back(state_hash);

        self.hits += 1;
        Some(mask)
    }

    /// Stores `mask` under `state_hash`.
    ///
    /// If the key is already present the call is a no-op: the existing mask
    /// is kept and its recency is not refreshed. Otherwise, if the cache is
    /// full, the least recently used entry is evicted first.
    pub fn insert(&mut self, state_hash: u64, mask: Vec<bool>) {
        if self.inner.contains_key(&state_hash) {
            return;
        }

        if self.inner.len() >= self.capacity {
            if let Some(oldest) = self.lru.pop_front() {
                self.inner.remove(&oldest);
            }
        }

        // Convert directly; going through `into_boxed_slice` first would add
        // a shrink-to-fit step before the copy into the `Arc` allocation.
        let mask: Arc<[bool]> = Arc::from(mask);
        self.inner.insert(state_hash, CachedMask { mask });
        self.lru.push_back(state_hash);
    }

    /// Number of lookups that found an entry.
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of lookups that found nothing.
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built cache holds nothing.
    #[test]
    fn cache_empty_initially() {
        let cache = AllowedTokensCache::with_capacity(4);
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    /// Looking up an unknown key counts as a miss and never as a hit.
    #[test]
    fn cache_miss_on_empty() {
        let mut cache = AllowedTokensCache::with_capacity(4);
        let looked_up = cache.get(42);
        assert!(looked_up.is_none());
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 1);
    }

    /// An inserted mask is returned unchanged and counted as a hit.
    #[test]
    fn cache_insert_and_hit() {
        let mut cache = AllowedTokensCache::with_capacity(4);
        cache.insert(1, vec![true, false, true]);

        let result = cache.get(1).expect("should be present");
        assert_eq!(&*result, &[true, false, true]);
        assert_eq!(cache.misses(), 0);
        assert_eq!(cache.hits(), 1);
    }

    /// Re-inserting an existing key keeps the first mask.
    #[test]
    fn cache_duplicate_insert_is_noop() {
        let mut cache = AllowedTokensCache::with_capacity(4);
        cache.insert(7, vec![true]);
        cache.insert(7, vec![false]);

        let result = cache.get(7).expect("present");
        assert_eq!(&*result, &[true]);
    }

    /// With the cache full, the least recently used key is dropped.
    #[test]
    fn cache_evicts_lru_at_capacity() {
        let mut cache = AllowedTokensCache::with_capacity(2);
        for key in [10u64, 20] {
            cache.insert(key, vec![true]);
        }
        // Touch 20 so that 10 is the eviction candidate.
        let _ = cache.get(20);
        cache.insert(30, vec![true]);

        assert_eq!(cache.len(), 2);
        assert!(cache.get(10).is_none(), "10 should have been evicted");
        assert!(cache.get(20).is_some(), "20 should still be present");
        assert!(cache.get(30).is_some(), "30 should be present");
    }

    /// A single-slot cache only ever keeps the newest key.
    #[test]
    fn cache_capacity_one_always_evicts() {
        let mut cache = AllowedTokensCache::with_capacity(1);
        cache.insert(1, vec![true]);
        cache.insert(2, vec![false]);

        assert_eq!(cache.len(), 1);
        assert!(cache.get(1).is_none());
        assert!(cache.get(2).is_some());
    }

    /// Two lookups before the insert are misses; two after are hits.
    #[test]
    fn cache_stats_track_correctly() {
        let mut cache = AllowedTokensCache::with_capacity(8);
        let _ = cache.get(99);
        let _ = cache.get(99);
        cache.insert(99, vec![true, true]);
        let _ = cache.get(99);
        let _ = cache.get(99);

        assert_eq!(cache.misses(), 2);
        assert_eq!(cache.hits(), 2);
    }

    /// A hit moves the key to most-recently-used, protecting it from eviction.
    #[test]
    fn cache_lru_promotes_on_hit() {
        let mut cache = AllowedTokensCache::with_capacity(2);
        cache.insert(1, vec![true]);
        cache.insert(2, vec![true]);
        // Reading 1 leaves 2 as the oldest entry when 3 arrives.
        let _ = cache.get(1);
        cache.insert(3, vec![true]);

        assert!(cache.get(1).is_some(), "1 was promoted, should survive");
        assert!(cache.get(2).is_none(), "2 was LRU, should be evicted");
        assert!(cache.get(3).is_some(), "3 was just inserted");
    }
}