use core::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Mutex;
use crate::cache::Cache;
use crate::error::CacheError;
use crate::util::MutexExt;
/// Number of independently seeded rows kept by the count-min sketch.
const SKETCH_DEPTH: usize = 4;
/// Lower bound applied to the sketch's per-row width before it is rounded up to a power of two.
const MIN_SKETCH_WIDTH: usize = 64;
/// Bounded key/value cache that combines last-access ordering with a
/// frequency sketch to decide which entries are kept.
///
/// All mutable state is stored behind a single mutex, so the cache is
/// operated through shared references.
pub struct TinyLfuCache<K, V> {
    /// Entries, frequency sketch and logical clock, guarded by one lock.
    inner: Mutex<Inner<K, V>>,
    /// Maximum number of entries the map is allowed to hold; never zero.
    capacity: NonZeroUsize,
}
/// A stored value paired with the logical time at which it was last touched.
struct Entry<V> {
    /// Reading of the cache's logical clock at the most recent access.
    last_access: u64,
    /// The cached payload.
    value: V,
}
/// Lock-protected state of the cache.
struct Inner<K, V> {
    /// Logical clock; advanced with wrapping arithmetic on every lookup and insert.
    clock: u64,
    /// Resident entries keyed by the caller's key.
    map: HashMap<K, Entry<V>>,
    /// Approximate access-frequency counters consulted when the map is full.
    sketch: CountMinSketch,
}
impl<K, V> TinyLfuCache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Builds an empty cache that holds at most `capacity` entries.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::InvalidCapacity`] when `capacity` is zero.
    pub fn new(capacity: usize) -> Result<Self, CacheError> {
        match NonZeroUsize::new(capacity) {
            Some(non_zero) => Ok(Self::with_capacity(non_zero)),
            None => Err(CacheError::InvalidCapacity),
        }
    }

    /// Builds an empty cache from a capacity already known to be non-zero.
    ///
    /// The entry map is pre-sized for `capacity` entries, the frequency
    /// sketch is dimensioned from the same number, and the logical clock
    /// starts at zero.
    pub fn with_capacity(capacity: NonZeroUsize) -> Self {
        let slots = capacity.get();
        let state = Inner {
            map: HashMap::with_capacity(slots),
            sketch: CountMinSketch::new(slots),
            clock: 0,
        };
        Self {
            capacity,
            inner: Mutex::new(state),
        }
    }
}
impl<K, V> Cache<K, V> for TinyLfuCache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Looks up `key`, returning a clone of its value when present.
    ///
    /// Every call, hit or miss, bumps the key's frequency counters and
    /// advances the logical clock. A hit also refreshes the entry's
    /// last-access stamp.
    fn get(&self, key: &K) -> Option<V> {
        let mut guard = self.inner.lock_recover();
        let state = &mut *guard;

        state.sketch.increment(key);
        state.clock = state.clock.wrapping_add(1);
        let tick = state.clock;

        state.map.get_mut(key).map(|entry| {
            entry.last_access = tick;
            entry.value.clone()
        })
    }

    /// Stores `value` under `key`, subject to frequency-based admission.
    ///
    /// * If the key is already resident, its value is replaced and the
    ///   previous value is returned.
    /// * If the map is full, the entry with the smallest last-access stamp
    ///   is the eviction candidate. The new entry is admitted only when its
    ///   estimated frequency is strictly greater than the candidate's.
    ///   Otherwise the new value is dropped and `None` is returned.
    /// * In every other case the entry is inserted and `None` is returned.
    fn insert(&self, key: K, value: V) -> Option<V> {
        let mut guard = self.inner.lock_recover();
        let state = &mut *guard;

        state.sketch.increment(&key);
        state.clock = state.clock.wrapping_add(1);
        let tick = state.clock;

        // Overwrite in place when the key is already cached.
        if let Some(slot) = state.map.get_mut(&key) {
            slot.last_access = tick;
            return Some(core::mem::replace(&mut slot.value, value));
        }

        if state.map.len() >= self.capacity.get() {
            let incoming = state.sketch.estimate(&key);
            if let Some(victim_key) = find_lru_victim(&state.map) {
                // Ties favour the resident entry.
                if state.sketch.estimate(&victim_key) >= incoming {
                    return None;
                }
                let _ = state.map.remove(&victim_key);
            }
        }

        let fresh = Entry {
            value,
            last_access: tick,
        };
        let _ = state.map.insert(key, fresh);
        None
    }

    /// Removes `key` and returns its value, if it was resident.
    ///
    /// The frequency sketch and logical clock are left untouched.
    fn remove(&self, key: &K) -> Option<V> {
        let mut guard = self.inner.lock_recover();
        let taken = guard.map.remove(key);
        taken.map(|entry| entry.value)
    }

    /// Reports whether `key` is resident, without recording an access.
    fn contains_key(&self, key: &K) -> bool {
        let guard = self.inner.lock_recover();
        guard.map.contains_key(key)
    }

    /// Number of entries currently resident.
    fn len(&self) -> usize {
        let guard = self.inner.lock_recover();
        guard.map.len()
    }

    /// Drops every entry, zeroes the frequency sketch and rewinds the clock.
    fn clear(&self) {
        let mut guard = self.inner.lock_recover();
        let state = &mut *guard;
        state.map.clear();
        state.sketch.reset();
        state.clock = 0;
    }

    /// Maximum number of entries the cache will hold.
    fn capacity(&self) -> usize {
        usize::from(self.capacity)
    }
}
/// Returns a clone of the key whose entry has the smallest `last_access`
/// stamp, or `None` when `map` is empty.
///
/// When several entries share the smallest stamp, the one encountered first
/// in the map's iteration order is chosen.
fn find_lru_victim<K, V>(map: &HashMap<K, Entry<V>>) -> Option<K>
where
    K: Clone,
{
    let mut oldest: Option<(&K, u64)> = None;
    for (candidate, entry) in map {
        // Strict comparison keeps the earliest-seen entry on ties.
        let is_older = match oldest {
            Some((_, stamp)) => entry.last_access < stamp,
            None => true,
        };
        if is_older {
            oldest = Some((candidate, entry.last_access));
        }
    }
    oldest.map(|(key, _)| key.clone())
}
/// Approximate per-key frequency counter with periodic decay.
struct CountMinSketch {
    /// Number of counters in each row.
    width: usize,
    /// `width` pre-converted to `u64` for reducing hash values to a column.
    width_u64: u64,
    /// Counters for all rows, stored as one flat vector.
    counters: Vec<u8>,
    /// Increments recorded since the counters were last halved or reset.
    samples: u64,
    /// Number of increments after which every counter is halved.
    sample_window: u64,
}
impl CountMinSketch {
    /// Creates a zeroed sketch dimensioned for a cache of `capacity` entries.
    ///
    /// Each row is `max(2 * capacity, MIN_SKETCH_WIDTH)` counters wide,
    /// rounded up to a power of two. Counters are halved after every
    /// `max(10 * capacity, 64)` increments.
    fn new(capacity: usize) -> Self {
        let width = capacity
            .saturating_mul(2)
            .max(MIN_SKETCH_WIDTH)
            .next_power_of_two();
        let sample_window = (capacity as u64).saturating_mul(10).max(64);
        let counters = vec![0u8; width.saturating_mul(SKETCH_DEPTH)];
        Self {
            counters,
            width,
            width_u64: width as u64,
            samples: 0,
            sample_window,
        }
    }

    /// Estimated frequency of `key`: the smallest of its counters across all rows.
    fn estimate<K: Hash>(&self, key: &K) -> u8 {
        (0..SKETCH_DEPTH)
            .map(|row| {
                // An out-of-range index is read as zero.
                self.counters
                    .get(self.cell(row, key))
                    .copied()
                    .unwrap_or(0)
            })
            .min()
            .unwrap_or(u8::MAX)
    }

    /// Records one access to `key`, saturating each of its counters at `u8::MAX`.
    ///
    /// Once `sample_window` accesses have accumulated, every counter is
    /// halved and the sample count starts over.
    fn increment<K: Hash>(&mut self, key: &K) {
        for row in 0..SKETCH_DEPTH {
            let idx = self.cell(row, key);
            if let Some(counter) = self.counters.get_mut(idx) {
                *counter = counter.saturating_add(1);
            }
        }

        self.samples = self.samples.saturating_add(1);
        if self.samples < self.sample_window {
            return;
        }
        self.age();
        self.samples = 0;
    }

    /// Zeroes every counter and the sample count.
    fn reset(&mut self) {
        self.counters.fill(0);
        self.samples = 0;
    }

    /// Halves every counter, rounding down.
    fn age(&mut self) {
        self.counters.iter_mut().for_each(|counter| *counter /= 2);
    }

    /// Flat index into `counters` for `key` in row `d`.
    ///
    /// The column is the row-seeded hash of `key` reduced modulo the row width.
    fn cell<K: Hash>(&self, d: usize, key: &K) -> usize {
        let column = (hash_with_seed(key, d as u64) % self.width_u64) as usize;
        let row_start = d.saturating_mul(self.width);
        row_start.saturating_add(column)
    }
}
/// Hashes `key` with a fresh `DefaultHasher` after first feeding it `seed`,
/// so different seeds give different hash functions over the same key.
fn hash_with_seed<K>(key: &K, seed: u64) -> u64
where
    K: Hash,
{
    let mut state = DefaultHasher::new();
    // The seed is fed before the key.
    Hash::hash(&seed, &mut state);
    Hash::hash(key, &mut state);
    state.finish()
}