use core::hash::Hash;
use std::collections::{BTreeMap, HashMap};
use std::num::NonZeroUsize;
use crate::cache::Cache;
use crate::error::CacheError;
use crate::sharding::{self, Sharded};
use crate::util::MutexExt;
/// A cache split into shards whose eviction victim is the entry with the
/// smallest `(count, age)` priority in the shard receiving an insert.
///
/// Every operation picks a shard with `shard_for(key)` and locks it with
/// `lock_recover()` before touching that shard's state.
pub struct LfuCache<K, V> {
/// Capacity requested at construction; reported unchanged by `capacity()`.
capacity: NonZeroUsize,
/// Per-shard state; each shard is built with its own per-shard capacity.
sharded: Sharded<Inner<K, V>>,
}
/// A cached value together with the bookkeeping used to rank it for eviction.
struct Entry<V> {
/// The cached value itself.
value: V,
/// Access counter: starts at 1 on first insert and is bumped (saturating)
/// by every `get` hit and every overwriting `insert`.
count: u64,
/// Shard clock reading taken when the entry was last inserted or accessed.
age: u64,
}
/// State of a single shard: the entries plus an ordered eviction index.
struct Inner<K, V> {
/// Number of entries at which an insert of a new key evicts a victim first.
capacity: NonZeroUsize,
/// Key -> entry storage.
map: HashMap<K, Entry<V>>,
/// Eviction index keyed by `(count, age)`; the first (smallest) element is
/// the next victim. Kept in step with `map` by the cache operations.
by_priority: BTreeMap<(u64, u64), K>,
/// Logical clock advanced by `tick`; its readings become entry ages.
clock: u64,
}
impl<K, V> Inner<K, V>
where
    K: Eq + Hash + Clone,
{
    /// Creates an empty shard limited to `capacity` entries.
    ///
    /// The hash map is pre-sized for `capacity` entries, the priority index
    /// starts empty, and the logical clock starts at zero.
    fn with_capacity(capacity: NonZeroUsize) -> Self {
        let map = HashMap::with_capacity(capacity.get());
        let by_priority = BTreeMap::new();
        Self {
            capacity,
            map,
            by_priority,
            clock: 0,
        }
    }

    /// Advances the logical clock by one step and returns the new reading.
    ///
    /// The addition wraps, so a clock at `u64::MAX` rolls over to `0`.
    fn tick(&mut self) -> u64 {
        let next = self.clock.wrapping_add(1);
        self.clock = next;
        next
    }
}
impl<K, V> LfuCache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Creates a cache with the given total capacity.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::InvalidCapacity` when `capacity` is zero.
    pub fn new(capacity: usize) -> Result<Self, CacheError> {
        match NonZeroUsize::new(capacity) {
            Some(non_zero) => Ok(Self::with_capacity(non_zero)),
            None => Err(CacheError::InvalidCapacity),
        }
    }

    /// Creates a cache from a capacity already known to be non-zero.
    ///
    /// The shard count comes from `sharding::shard_count` and each shard's
    /// own limit from `sharding::per_shard_capacity`; every shard starts empty.
    pub fn with_capacity(capacity: NonZeroUsize) -> Self {
        let shards = sharding::shard_count(capacity);
        let shard_capacity = sharding::per_shard_capacity(capacity, shards);
        Self {
            capacity,
            sharded: Sharded::from_factory(shards, |_| Inner::with_capacity(shard_capacity)),
        }
    }
}
impl<K, V> Cache<K, V> for LfuCache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Returns a clone of the value stored under `key`, if any.
    ///
    /// A hit bumps the entry's count (saturating), stamps it with the current
    /// clock reading, and moves it to its new slot in the priority index.
    /// The shard clock advances on misses as well as hits.
    fn get(&self, key: &K) -> Option<V> {
        let mut shard = self.sharded.shard_for(key).lock_recover();
        let stamp = shard.tick();

        let entry = shard.map.get_mut(key)?;
        let stale = (entry.count, entry.age);
        entry.count = entry.count.saturating_add(1);
        entry.age = stamp;
        let fresh = (entry.count, stamp);
        let value = entry.value.clone();

        // Re-index: drop the slot for the old priority, add one for the new.
        let _ = shard.by_priority.remove(&stale);
        let _ = shard.by_priority.insert(fresh, key.clone());
        Some(value)
    }

    /// Stores `value` under `key`, returning the value it replaced, if any.
    ///
    /// Overwriting an existing key counts as an access (count bumped, age
    /// refreshed). Inserting a new key into a shard that is at capacity first
    /// evicts the entry with the smallest `(count, age)`; the new entry then
    /// starts with a count of 1.
    fn insert(&self, key: K, value: V) -> Option<V> {
        let mut shard = self.sharded.shard_for(&key).lock_recover();
        let stamp = shard.tick();

        if let Some(existing) = shard.map.get_mut(&key) {
            let stale = (existing.count, existing.age);
            let bumped = existing.count.saturating_add(1);
            existing.count = bumped;
            existing.age = stamp;
            let previous = core::mem::replace(&mut existing.value, value);
            let _ = shard.by_priority.remove(&stale);
            let _ = shard.by_priority.insert((bumped, stamp), key);
            return Some(previous);
        }

        // New key: make room first when the shard has reached its limit.
        let is_full = shard.map.len() >= shard.capacity.get();
        if is_full {
            if let Some((_, victim)) = shard.by_priority.pop_first() {
                let _ = shard.map.remove(&victim);
            }
        }

        let fresh = Entry {
            value,
            count: 1,
            age: stamp,
        };
        let priority = (fresh.count, fresh.age);
        let _ = shard.map.insert(key.clone(), fresh);
        let _ = shard.by_priority.insert(priority, key);
        None
    }

    /// Removes `key` from the cache and returns its value, if it was present.
    fn remove(&self, key: &K) -> Option<V> {
        let mut shard = self.sharded.shard_for(key).lock_recover();
        let Entry { value, count, age } = shard.map.remove(key)?;
        let _ = shard.by_priority.remove(&(count, age));
        Some(value)
    }

    /// Reports whether `key` is present, without touching its count or age.
    fn contains_key(&self, key: &K) -> bool {
        let shard = self.sharded.shard_for(key).lock_recover();
        shard.map.contains_key(key)
    }

    /// Returns the number of entries, summed over the shards, locking each
    /// shard in turn.
    fn len(&self) -> usize {
        let mut total: usize = 0;
        for shard in self.sharded.iter() {
            total += shard.lock_recover().map.len();
        }
        total
    }

    /// Empties every shard and resets each shard's clock to zero.
    fn clear(&self) {
        self.sharded.iter().for_each(|shard| {
            let mut state = shard.lock_recover();
            state.map.clear();
            state.by_priority.clear();
            state.clock = 0;
        });
    }

    /// Returns the total capacity the cache was constructed with.
    fn capacity(&self) -> usize {
        usize::from(self.capacity)
    }
}