use core::hash::Hash;
use std::collections::{HashMap, VecDeque};
use std::sync::Mutex;
use crate::cache::Cache;
use crate::error::CacheError;
use crate::util::MutexExt;
/// A least-recently-used cache bounded by the summed weight of its values rather
/// than by the number of entries.
///
/// All mutable state lives behind one [`Mutex`], so every operation works through
/// `&self`.
pub struct SizedCache<K, V> {
    /// Lookup map, recency queue, and running weight total, guarded together.
    inner: Mutex<Inner<K, V>>,
    /// Upper bound on the total weight of cached entries.
    max_weight: usize,
    /// Computes the weight charged for a value when it is inserted.
    weigher: fn(&V) -> usize,
}
/// One cached value paired with the weight the cache's weigher assigned to it.
struct Entry<V> {
    /// Weight charged against the cache's budget for this entry.
    weight: usize,
    /// The stored value.
    value: V,
}
/// Mutable cache state, guarded as a unit by `SizedCache::inner`.
struct Inner<K, V> {
    /// Running sum of `Entry::weight` across the entries in `map`, maintained
    /// incrementally on every insert, replace, removal, and eviction.
    total_weight: usize,
    /// Lookup table from key to its stored value and weight.
    map: HashMap<K, Entry<V>>,
    /// Keys from most recently used (front) to least recently used (back);
    /// eviction pops from the back.
    order: VecDeque<K>,
}
impl<K, V> SizedCache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Creates an empty cache whose entries may weigh at most `max_weight` in total.
    ///
    /// `weigher` is applied to each value passed to `insert` to obtain its weight.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::InvalidCapacity`] if `max_weight` is zero.
    pub fn new(max_weight: usize, weigher: fn(&V) -> usize) -> Result<Self, CacheError> {
        if max_weight > 0 {
            let state = Inner {
                map: HashMap::new(),
                order: VecDeque::new(),
                total_weight: 0,
            };
            Ok(Self {
                max_weight,
                weigher,
                inner: Mutex::new(state),
            })
        } else {
            Err(CacheError::InvalidCapacity)
        }
    }

    /// Returns the weight budget this cache was created with.
    pub fn max_weight(&self) -> usize {
        self.max_weight
    }

    /// Returns the combined weight of all entries currently cached.
    ///
    /// Briefly takes the cache lock to read the running total.
    pub fn total_weight(&self) -> usize {
        let guard = self.inner.lock_recover();
        guard.total_weight
    }
}
impl<K, V> Cache<K, V> for SizedCache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Returns a clone of the value stored under `key` and marks it most recently used.
    ///
    /// Promotion scans the recency queue linearly, so a hit costs O(n) in the number
    /// of cached entries.
    fn get(&self, key: &K) -> Option<V> {
        let mut inner = self.inner.lock_recover();
        let value = inner.map.get(key)?.value.clone();
        promote(&mut inner.order, key);
        Some(value)
    }

    /// Stores `value` under `key` as the most recently used entry. Least-recently-used
    /// entries are evicted until the total weight fits within `max_weight`.
    ///
    /// Returns the value previously stored under `key`, if any.
    ///
    /// A value whose own weight exceeds `max_weight` can never be cached. In that case
    /// any existing entry for `key` is removed and returned, so readers do not keep
    /// observing the value the caller just tried to replace.
    fn insert(&self, key: K, value: V) -> Option<V> {
        let new_weight = (self.weigher)(&value);
        if new_weight > self.max_weight {
            // Invalidate the stale entry rather than silently keeping it. `remove`
            // takes the lock itself, so this must run before the lock below.
            return self.remove(&key);
        }

        let mut inner = self.inner.lock_recover();
        if let Some(existing) = inner.map.get_mut(&key) {
            let old_value = std::mem::replace(&mut existing.value, value);
            let old_weight = std::mem::replace(&mut existing.weight, new_weight);
            // Subtract first: the running total already includes `old_weight`, so
            // this cannot clamp. Adding first could saturate near `usize::MAX` and
            // corrupt the total.
            inner.total_weight = inner
                .total_weight
                .saturating_sub(old_weight)
                .saturating_add(new_weight);
            promote(&mut inner.order, &key);
            // `key` is now at the front and fits on its own (checked above), so
            // eviction only removes other entries.
            evict_until_fits(&mut inner, self.max_weight);
            return Some(old_value);
        }

        // Make room for the newcomer before it is counted; the helper is a no-op
        // when it already fits.
        evict_until_fits_for_new(&mut inner, self.max_weight, new_weight);
        inner.order.push_front(key.clone());
        let _ = inner.map.insert(
            key,
            Entry {
                value,
                weight: new_weight,
            },
        );
        inner.total_weight = inner.total_weight.saturating_add(new_weight);
        None
    }

    /// Removes `key`, returning its value and releasing its weight from the total.
    ///
    /// Locating the key in the recency queue is a linear scan.
    fn remove(&self, key: &K) -> Option<V> {
        let mut inner = self.inner.lock_recover();
        let entry = inner.map.remove(key)?;
        inner.total_weight = inner.total_weight.saturating_sub(entry.weight);
        if let Some(pos) = inner.order.iter().position(|k| k == key) {
            let _ = inner.order.remove(pos);
        }
        Some(entry.value)
    }

    /// Reports whether `key` is cached, without affecting its recency.
    fn contains_key(&self, key: &K) -> bool {
        self.inner.lock_recover().map.contains_key(key)
    }

    /// Number of cached entries (not their weight).
    fn len(&self) -> usize {
        self.inner.lock_recover().map.len()
    }

    /// Drops every entry and resets the running weight to zero.
    fn clear(&self) {
        let mut inner = self.inner.lock_recover();
        inner.map.clear();
        inner.order.clear();
        inner.total_weight = 0;
    }

    /// Reports the weight budget rather than an entry count.
    fn capacity(&self) -> usize {
        self.max_weight
    }
}
/// Moves the first occurrence of `key` in `order` to the front, marking it most
/// recently used. Does nothing if `key` is absent.
///
/// The lookup is a linear scan, so this is O(n) in the length of `order`.
fn promote<K: Eq>(order: &mut VecDeque<K>, key: &K) {
    let Some(index) = order.iter().position(|candidate| candidate == key) else {
        return;
    };
    if let Some(found) = order.remove(index) {
        order.push_front(found);
    }
}
fn evict_until_fits<K, V>(inner: &mut Inner<K, V>, max_weight: usize)
where
K: Eq + Hash,
{
while inner.total_weight > max_weight {
let Some(victim_key) = inner.order.pop_back() else {
break;
};
if let Some(victim) = inner.map.remove(&victim_key) {
inner.total_weight = inner.total_weight.saturating_sub(victim.weight);
}
}
}
/// Evicts least-recently-used entries until `incoming` more weight would fit within
/// `max_weight`.
///
/// Victims are popped from the back of `inner.order`. A popped key with no matching
/// map entry is discarded without changing `total_weight`. Eviction stops once the
/// recency queue is empty, even if the budget is still exceeded.
fn evict_until_fits_for_new<K, V>(inner: &mut Inner<K, V>, max_weight: usize, incoming: usize)
where
    K: Eq + Hash,
{
    loop {
        let fits = inner.total_weight.saturating_add(incoming) <= max_weight;
        if fits {
            return;
        }
        match inner.order.pop_back() {
            None => return,
            Some(victim_key) => {
                let freed = inner
                    .map
                    .remove(&victim_key)
                    .map_or(0, |victim| victim.weight);
                inner.total_weight = inner.total_weight.saturating_sub(freed);
            }
        }
    }
}