use core::hash::Hash;
use std::collections::HashMap;
use std::sync::Mutex;
use crate::cache::Cache;
use crate::error::CacheError;
use crate::util::MutexExt;
/// An LRU cache bounded by the total *weight* of its values rather than by entry count.
///
/// Every inserted value is measured with `weigher`; least-recently-used entries are evicted
/// whenever the summed weight would exceed `max_weight`. All mutable state sits behind a
/// `Mutex`, so the cache is operated through `&self`.
pub struct SizedCache<K, V> {
    /// Arena, recency list, key index and running weight, guarded by a mutex.
    inner: Mutex<Inner<K, V>>,
    /// Maps a value to the share of the budget it consumes.
    weigher: fn(&V) -> usize,
    /// Upper bound on the summed weight of all resident entries.
    max_weight: usize,
}
/// One cache entry stored in the arena and threaded into the recency list by slot index.
struct Node<K, V> {
    /// Slot of the neighbour nearer the head (more recently used), if any.
    prev: Option<usize>,
    /// Slot of the neighbour nearer the tail (less recently used), if any.
    next: Option<usize>,
    /// Key under which this slot is indexed; needed to un-index the entry on eviction.
    key: K,
    /// The cached value.
    value: V,
    /// Weight the cache's weigher reported for `value` when it was stored.
    weight: usize,
}
/// Mutable state of a `SizedCache`, kept behind its mutex.
///
/// Entries live in a slot arena (`nodes`) and are linked into a doubly linked recency list
/// running from `head` (most recently used) to `tail` (least recently used).
struct Inner<K, V> {
    /// Key -> arena slot index.
    map: HashMap<K, usize>,
    /// Slot arena; `None` marks a vacant slot.
    nodes: Vec<Option<Node<K, V>>>,
    /// Vacant slot indices available for reuse.
    free: Vec<usize>,
    /// Most recently used slot, or `None` when the list is empty.
    head: Option<usize>,
    /// Least recently used slot (next eviction candidate), or `None` when empty.
    tail: Option<usize>,
    /// Sum of the weights of all resident entries.
    total_weight: usize,
}
impl<K, V> Inner<K, V>
where
    K: Eq + Hash + Clone,
{
    /// Creates empty state: no slots, no free list, an empty recency list and zero weight.
    fn new() -> Self {
        Self {
            nodes: Vec::new(),
            free: Vec::new(),
            head: None,
            tail: None,
            map: HashMap::new(),
            total_weight: 0,
        }
    }

    /// Stores `node` in the arena and returns its slot index.
    ///
    /// A slot from the free list is reused when one is available; otherwise the arena grows by
    /// one. The node is neither linked into the recency list nor indexed in `map`.
    fn alloc(&mut self, node: Node<K, V>) -> usize {
        if let Some(idx) = self.free.pop() {
            self.nodes[idx] = Some(node);
            idx
        } else {
            self.nodes.push(Some(node));
            self.nodes.len() - 1
        }
    }

    /// Takes the node out of slot `idx`, pushes the slot onto the free list and returns the node.
    ///
    /// Does not touch the recency list, `map` or `total_weight`; callers unlink the slot first.
    ///
    /// # Panics
    /// If slot `idx` is vacant or out of bounds (an internal invariant violation).
    fn dealloc(&mut self, idx: usize) -> Node<K, V> {
        let node = self.nodes[idx]
            .take()
            .unwrap_or_else(|| unreachable!("arena slot must be occupied"));
        self.free.push(idx);
        node
    }

    /// Detaches slot `idx` from the recency list.
    ///
    /// Its neighbours are linked to each other, or `head`/`tail` move when `idx` sat at an end.
    /// The node's own `prev`/`next` are cleared; the slot itself stays allocated.
    ///
    /// # Panics
    /// If slot `idx`, or a neighbour it links to, is vacant.
    fn unlink(&mut self, idx: usize) {
        let (prev, next) = {
            let n = self.nodes[idx]
                .as_ref()
                .unwrap_or_else(|| unreachable!("unlink target must be occupied"));
            (n.prev, n.next)
        };
        match prev {
            Some(p) => {
                self.nodes[p]
                    .as_mut()
                    .unwrap_or_else(|| unreachable!())
                    .next = next
            }
            None => self.head = next,
        }
        match next {
            Some(n) => {
                self.nodes[n]
                    .as_mut()
                    .unwrap_or_else(|| unreachable!())
                    .prev = prev
            }
            None => self.tail = prev,
        }
        if let Some(n) = self.nodes[idx].as_mut() {
            n.prev = None;
            n.next = None;
        }
    }

    /// Links slot `idx` in as the new head (most recently used position).
    ///
    /// Expects `idx` to be detached from the list. When the list was empty, `idx` also becomes
    /// the tail.
    fn push_front(&mut self, idx: usize) {
        let old_head = self.head;
        if let Some(n) = self.nodes[idx].as_mut() {
            n.prev = None;
            n.next = old_head;
        }
        if let Some(h) = old_head {
            if let Some(n) = self.nodes[h].as_mut() {
                n.prev = Some(idx);
            }
        } else {
            self.tail = Some(idx);
        }
        self.head = Some(idx);
    }

    /// Marks slot `idx` as most recently used by moving it to the head; no-op if already there.
    fn promote(&mut self, idx: usize) {
        if self.head == Some(idx) {
            return;
        }
        self.unlink(idx);
        self.push_front(idx);
    }

    /// Evicts the least-recently-used entry (the list tail).
    ///
    /// The node is unlinked, its slot is freed and its weight is subtracted from `total_weight`.
    /// Its key is also removed from `map`, so the arena, recency list and key index stay
    /// consistent. Without that, a stale `map` entry would keep pointing at a freed slot that a
    /// later insertion may reuse for a different key.
    ///
    /// Returns the evicted key and weight, or `None` when the cache is empty.
    fn evict_tail(&mut self) -> Option<(K, usize)> {
        let tail_idx = self.tail?;
        self.unlink(tail_idx);
        let node = self.dealloc(tail_idx);
        self.total_weight = self.total_weight.saturating_sub(node.weight);
        let _ = self.map.remove(&node.key);
        Some((node.key, node.weight))
    }
}
impl<K, V> SizedCache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Builds an empty cache whose entries may weigh at most `max_weight` in total.
    ///
    /// `weigher` is applied to every inserted value to decide how much of the budget it uses.
    ///
    /// # Errors
    /// Returns `CacheError::InvalidCapacity` when `max_weight` is zero.
    pub fn new(max_weight: usize, weigher: fn(&V) -> usize) -> Result<Self, CacheError> {
        match max_weight {
            0 => Err(CacheError::InvalidCapacity),
            _ => Ok(Self {
                inner: Mutex::new(Inner::new()),
                weigher,
                max_weight,
            }),
        }
    }

    /// The configured weight budget.
    pub fn max_weight(&self) -> usize {
        self.max_weight
    }

    /// Summed weight of the entries currently resident (acquires the internal lock).
    pub fn total_weight(&self) -> usize {
        let guard = self.inner.lock_recover();
        guard.total_weight
    }
}
impl<K, V> Cache<K, V> for SizedCache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Returns a clone of the value cached under `key` and marks it most recently used.
    fn get(&self, key: &K) -> Option<V> {
        let mut inner = self.inner.lock_recover();
        let idx = *inner.map.get(key)?;
        inner.promote(idx);
        inner.nodes[idx].as_ref().map(|n| n.value.clone())
    }

    /// Inserts or replaces the value for `key`, returning the value it displaced (if any).
    ///
    /// The written entry becomes most recently used. Least-recently-used entries are then
    /// evicted until the total weight is within `max_weight` again. The entry just written is
    /// never evicted, because its own weight is at most `max_weight`.
    ///
    /// A value heavier than `max_weight` can never be cached, so it is not stored. Any existing
    /// entry for `key` is removed and its value returned, so readers do not keep seeing data the
    /// caller just tried to overwrite.
    fn insert(&self, key: K, value: V) -> Option<V> {
        let new_weight = (self.weigher)(&value);
        let mut inner = self.inner.lock_recover();

        if new_weight > self.max_weight {
            let idx = inner.map.remove(&key)?;
            inner.unlink(idx);
            let stale = inner.dealloc(idx);
            inner.total_weight = inner.total_weight.saturating_sub(stale.weight);
            return Some(stale.value);
        }

        let previous = match inner.map.get(&key).copied() {
            Some(idx) => {
                // Replace in place: the slot, its key and its map entry are unchanged.
                let node = inner.nodes[idx]
                    .as_mut()
                    .unwrap_or_else(|| unreachable!("mapped index must be occupied"));
                let old_value = core::mem::replace(&mut node.value, value);
                let old_weight = core::mem::replace(&mut node.weight, new_weight);
                inner.total_weight = inner
                    .total_weight
                    .saturating_add(new_weight)
                    .saturating_sub(old_weight);
                inner.promote(idx);
                Some(old_value)
            }
            None => {
                let idx = inner.alloc(Node {
                    key: key.clone(),
                    value,
                    weight: new_weight,
                    prev: None,
                    next: None,
                });
                inner.push_front(idx);
                let _ = inner.map.insert(key, idx);
                inner.total_weight = inner.total_weight.saturating_add(new_weight);
                None
            }
        };

        // Shared by both paths. Each evicted key must also leave `map`. Otherwise it keeps
        // pointing at a freed slot that a later insertion may hand to a different key.
        while inner.total_weight > self.max_weight {
            match inner.evict_tail() {
                Some((evicted_key, _)) => {
                    let _ = inner.map.remove(&evicted_key);
                }
                None => break,
            }
        }
        previous
    }

    /// Removes the entry for `key`, returning its value and releasing its weight.
    fn remove(&self, key: &K) -> Option<V> {
        let mut inner = self.inner.lock_recover();
        let idx = inner.map.remove(key)?;
        inner.unlink(idx);
        let node = inner.dealloc(idx);
        inner.total_weight = inner.total_weight.saturating_sub(node.weight);
        Some(node.value)
    }

    /// Reports whether `key` is cached, without changing its recency.
    fn contains_key(&self, key: &K) -> bool {
        self.inner.lock_recover().map.contains_key(key)
    }

    /// Number of resident entries.
    fn len(&self) -> usize {
        self.inner.lock_recover().map.len()
    }

    /// Drops every entry and resets the total weight to zero.
    fn clear(&self) {
        let mut inner = self.inner.lock_recover();
        inner.nodes.clear();
        inner.free.clear();
        inner.head = None;
        inner.tail = None;
        inner.map.clear();
        inner.total_weight = 0;
    }

    /// The weight budget (not an entry count).
    fn capacity(&self) -> usize {
        self.max_weight
    }
}