use std::borrow::Borrow;
use std::collections::HashMap;
use std::fmt;
use std::hash::Hash;
use std::mem;
use std::ops::Deref;
use async_trait::async_trait;
use futures::join;
use uplock::RwLock;
/// A value that can be stored in an [`LFUCache`].
pub trait Entry: Clone {
    /// How much of the cache's capacity this value consumes; new entries add this to the
    /// cache's occupancy, which `is_full` compares against the configured capacity.
    fn weight(&self) -> u64;
}
/// Eviction policy consulted by [`LFUCache::evict`].
#[async_trait]
pub trait Policy<K, V>: Sized + Send {
    /// Returns `true` if the entry holding `value` may be evicted; `evict` skips entries
    /// for which this returns `false`.
    fn can_evict(&self, value: &V) -> bool;
    /// Notified of each evicted entry. It is called after the key has been removed from the
    /// cache's map, while the entry is write-locked, and before it is unlinked from the list.
    async fn evict(&self, key: K, value: &V);
}
/// A node in the cache's doubly-linked list of entries.
struct Item<K, V> {
    /// Key under which this node is stored in the cache's map.
    key: K,
    /// The cached value.
    value: V,
    /// Neighbour toward the cache's `last` end; `None` for the node at `last`.
    prev: Option<RwLock<Self>>,
    /// Neighbour toward the cache's `first` end; `None` for the node at `first`.
    next: Option<RwLock<Self>>,
}
impl<K, V> Deref for Item<K, V> {
type Target = V;
fn deref(&self) -> &V {
&self.value
}
}
/// A weight-bounded cache whose entries are kept in a doubly-linked list.
///
/// New keys are linked in at the `last` end. Accessing or re-inserting an existing key
/// moves its entry one position toward the `first` end, and `evict` scans from `last`.
pub struct LFUCache<K, V, P> {
    /// Key -> list node lookup.
    cache: HashMap<K, RwLock<Item<K, V>>>,
    /// Tail of the list (its `next` is `None`).
    first: Option<RwLock<Item<K, V>>>,
    /// Head of the list (its `prev` is `None`); newly inserted entries are placed here.
    last: Option<RwLock<Item<K, V>>>,
    /// Running total of entry weights: increased when a new key is inserted and
    /// decreased by `remove`.
    occupied: i64,
    /// Weight limit that `is_full` compares `occupied` against.
    capacity: i64,
    /// Decides which entries `evict` may remove and is notified of each eviction.
    policy: P,
}
impl<K: Clone + Eq + Hash, V: Entry, P: Policy<K, V>> LFUCache<K, V, P> {
    /// Creates an empty cache.
    ///
    /// [`is_full`](Self::is_full) reports `true` once the total [`Entry::weight`] of the
    /// cached values reaches `capacity`. Capacities above `i64::MAX` are treated as
    /// `i64::MAX`.
    pub fn new(capacity: u64, policy: P) -> Self {
        Self {
            cache: HashMap::new(),
            first: None,
            last: None,
            occupied: 0,
            // A plain `as i64` cast would wrap large capacities (e.g. `u64::MAX` -> -1),
            // leaving the cache permanently "full"; saturate instead.
            capacity: capacity.min(i64::MAX as u64) as i64,
            policy,
        }
    }

    /// Returns `true` if `key` is cached.
    ///
    /// Unlike [`get`](Self::get), this does not move the entry within the list.
    pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.cache.contains_key(key)
    }

    /// Looks up `key` and moves its entry one position toward the `first` end of the list.
    ///
    /// Returns a read guard over the entry, which derefs to the value, or `None` if `key`
    /// is not cached.
    // NOTE(review): the returned guard keeps the entry read-locked until it is dropped;
    // operations that write-lock that entry presumably wait until then. Confirm against
    // `uplock::RwLock` semantics.
    pub async fn get<Q: ?Sized>(
        &mut self,
        key: &Q,
    ) -> Option<impl Deref<Target = impl Deref<Target = V>>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        if let Some(item) = self.cache.get(key) {
            let (last, first) = bump(item).await;
            if last.is_some() {
                self.last = last;
            }
            if first.is_some() {
                self.first = first;
            }
            Some(item.read().await)
        } else {
            None
        }
    }

    /// Inserts or replaces the value for `key`.
    ///
    /// - Existing key: the value is replaced, the entry moves one position toward `first`,
    ///   and `true` is returned.
    /// - New key: the entry is linked in at the `last` end and `false` is returned.
    ///
    /// Never evicts by itself; callers check [`is_full`](Self::is_full) and call
    /// [`evict`](Self::evict).
    pub async fn insert(&mut self, key: K, value: V) -> bool {
        if let Some(item) = self.cache.get(&key) {
            let (last, first) = bump(item).await;
            if last.is_some() {
                self.last = last;
            }
            if first.is_some() {
                self.first = first;
            }
            let mut lock = item.write().await;
            // The replacement may weigh more or less than the old value; keep `occupied`
            // in sync so `is_full` stays accurate.
            self.occupied += value.weight() as i64 - lock.value.weight() as i64;
            lock.value = value;
            true
        } else {
            self.occupied += value.weight() as i64;
            // The new node becomes the head; the previous head (if any) is its successor.
            let item = RwLock::new(Item {
                key: key.clone(),
                value,
                prev: None,
                next: self.last.take(),
            });
            if let Some(next) = &item.write().await.next {
                next.write().await.prev = Some(item.clone());
            }
            self.cache.insert(key, item.clone());
            if self.first.is_none() {
                self.first = Some(item.clone());
            }
            self.last = Some(item);
            false
        }
    }

    /// Returns `true` if no entries are cached.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Returns `true` once the total weight of cached values has reached the capacity.
    pub fn is_full(&self) -> bool {
        self.occupied >= self.capacity
    }

    /// Number of cached entries (not their total weight).
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Removes `key` from the cache and unlinks its entry from the list.
    ///
    /// Returns a clone of the removed value, or `None` if `key` was not cached.
    pub async fn remove<Q: ?Sized>(&mut self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let item = self.cache.remove(key)?;
        let mut item_lock = item.write().await;
        if item_lock.prev.is_none() && item_lock.next.is_none() {
            // Sole entry: the list becomes empty.
            self.last = None;
            self.first = None;
        } else if item_lock.prev.is_none() {
            // Head entry: its successor becomes the new head.
            self.last = item_lock.next.clone();
            let mut next = item_lock.next.as_ref().unwrap().write().await;
            mem::swap(&mut next.prev, &mut item_lock.prev);
        } else if item_lock.next.is_none() {
            // Tail entry: its predecessor becomes the new tail.
            self.first = item_lock.prev.clone();
            let mut prev = item_lock.prev.as_ref().unwrap().write().await;
            mem::swap(&mut prev.next, &mut item_lock.next);
        } else {
            // Interior entry: link its two neighbours to each other.
            let (mut prev, mut next) = join!(
                item_lock.prev.as_ref().unwrap().write(),
                item_lock.next.as_ref().unwrap().write()
            );
            mem::swap(&mut next.prev, &mut item_lock.prev);
            mem::swap(&mut prev.next, &mut item_lock.next);
        }
        // The swaps above leave the removed node linked to itself (and possibly to a
        // former neighbour). Drop those links so the node does not keep itself alive.
        item_lock.prev = None;
        item_lock.next = None;
        self.occupied -= item_lock.value.weight() as i64;
        Some(item_lock.value.clone())
    }

    /// Calls `f` on every cached value, walking from the `last` end to the `first` end.
    pub async fn traverse<C: FnMut(&V) + Send>(&self, mut f: C) {
        let mut next = self.last.clone();
        while let Some(item) = next {
            let lock = item.read().await;
            f(&lock.value);
            next = lock.next.clone();
        }
    }

    /// Evicts entries starting from the `last` end, skipping those the policy refuses.
    ///
    /// At least one evictable entry is removed, if any exists. After that, eviction stops
    /// as soon as the cache is no longer full. Each evicted key/value pair is passed to
    /// [`Policy::evict`].
    pub async fn evict(&mut self)
    where
        K: fmt::Debug,
    {
        let mut next = self.last.clone();
        while let Some(item) = next {
            let lock = item.read().await;
            next = lock.next.clone();
            if !self.policy.can_evict(&lock.value) {
                continue;
            }

            let (key, _) = self.cache.remove_entry(&lock.key).expect("cache key");
            let mut lock = lock.upgrade().await;
            self.policy.evict(key, &lock.value).await;

            // Point the predecessor (or the head pointer) past the evicted node.
            if let Some(prev) = &lock.prev {
                let mut prev = prev.write().await;
                mem::swap(&mut lock.next, &mut prev.next);
            } else {
                self.last = next.clone();
            }
            // Point the successor (or the tail pointer) back past the evicted node.
            if let Some(next) = &next {
                let mut next = next.write().await;
                mem::swap(&mut lock.prev, &mut next.prev);
            } else {
                self.first = lock.prev.clone();
            }
            // The swaps leave the evicted node linked to itself; clear its links so it
            // does not keep itself alive.
            lock.prev = None;
            lock.next = None;

            // Evicting frees occupied weight; the configured capacity never changes.
            // (Previously this subtracted from `capacity`, so eviction made the cache
            // *more* full instead of less.)
            self.occupied -= lock.value.weight() as i64;
            if !self.is_full() {
                break;
            }
        }
    }
}
/// Moves `item` one position toward the `first` (tail) end of the list by swapping it with
/// its `next` neighbour.
///
/// Returns `(last, first)`:
/// - `last` is `Some(node)` when the list's head pointer must become `node`. This happens
///   when `item` was the head.
/// - `first` is `Some(item)` when `item` has become the tail.
/// - `None` means that end of the list is unchanged.
///
/// If `item` is already the tail, nothing changes and `(None, None)` is returned.
async fn bump<K, V>(
    item: &RwLock<Item<K, V>>,
) -> (Option<RwLock<Item<K, V>>>, Option<RwLock<Item<K, V>>>) {
    let mut item_lock = item.write().await;
    let last = if item_lock.next.is_none() {
        // Already at the tail: there is no successor to swap with.
        return (None, None);
    } else if item_lock.prev.is_none() && item_lock.next.is_some() {
        // `item` is the head: swap it with its successor N, which becomes the new head.
        let mut next_lock = item_lock.next.as_ref().unwrap().write().await;
        // N.prev = None; item.prev temporarily holds N's old back-link (item itself).
        mem::swap(&mut next_lock.prev, &mut item_lock.prev);
        // item.next = N's old successor; N.next temporarily holds item's old next (N).
        mem::swap(&mut item_lock.next, &mut next_lock.next);
        // Exchange the two temporary links: item.prev = N, N.next = item.
        mem::swap(&mut item_lock.prev, &mut next_lock.next);
        item_lock.prev.clone()
    } else {
        // `item` has both neighbours P and N: reorder P, item, N into P, N, item.
        let (mut prev_lock, mut next_lock) = join!(
            item_lock.prev.as_ref().unwrap().write(),
            item_lock.next.as_ref().unwrap().write()
        );
        let next = item_lock.next.clone();
        // P.next = N; item.next temporarily holds P's old forward link (item itself).
        mem::swap(&mut prev_lock.next, &mut item_lock.next);
        // item.next = N's old successor; N.next = item.
        mem::swap(&mut item_lock.next, &mut next_lock.next);
        // N.prev = P; item.prev temporarily holds N's old back-link (item itself).
        mem::swap(&mut next_lock.prev, &mut item_lock.prev);
        // item.prev = N.
        item_lock.prev = next;
        None
    };
    // Point item's new successor back at item, or report item as the new tail.
    let first = if item_lock.next.is_some() {
        let mut skip_lock = item_lock.next.as_ref().unwrap().write().await;
        skip_lock.prev = Some(item.clone());
        None
    } else {
        Some(item.clone())
    };
    (last, first)
}
#[cfg(test)]
mod tests {
    use std::fmt;

    use rand::{thread_rng, Rng};

    use super::*;

    // Every test value weighs 2, so a cache of capacity `c` holds `c / 2` entries before
    // reporting itself full.
    impl Entry for i32 {
        fn weight(&self) -> u64 {
            2
        }
    }

    /// A policy that allows every entry to be evicted and ignores eviction notices.
    struct Evict;

    #[async_trait]
    impl Policy<i32, i32> for Evict {
        fn can_evict(&self, _value: &i32) -> bool {
            true
        }

        async fn evict(&self, _key: i32, _value: &i32) {}
    }

    /// Prints each node as `prev-value-next`, walking from `last` to `first`.
    // Kept as a debugging aid for failing tests; no test calls it, hence the allow.
    #[allow(dead_code)]
    async fn print_debug<K, V: fmt::Display, P>(cache: &LFUCache<K, V, P>) {
        let mut next = cache.last.clone();
        while let Some(item) = next {
            let lock = item.read().await;
            if let Some(item) = lock.prev.as_ref() {
                print!("{}-", item.read().await.value);
            }
            print!("{}", lock.value);
            next = lock.next.clone();
            if let Some(item) = &next {
                print!("-{}", item.read().await.value);
            }
            print!(" ");
        }
        println!();
    }

    /// Checks the structural invariants of the cache's linked list:
    /// - `first`/`last` are `None` exactly when the cache is empty;
    /// - the head (`last`) has no `prev` and the tail (`first`) has no `next`;
    /// - every node's `prev` link points at the node visited just before it;
    /// - walking `next` from `last` visits exactly `len()` nodes and ends at `first`.
    async fn validate<K: Clone + Eq + Hash, V: Entry + Copy + Eq + fmt::Debug, P: Policy<K, V>>(
        cache: &LFUCache<K, V, P>,
    ) {
        if cache.is_empty() {
            assert!(cache.first.is_none());
            assert!(cache.last.is_none());
        } else {
            assert!(cache.first.as_ref().unwrap().read().await.next.is_none());
            assert!(cache.last.as_ref().unwrap().read().await.prev.is_none());
        }

        let mut visited = 0;
        let mut last = None;
        let mut next = cache.last.clone();
        while let Some(item) = next {
            let lock = item.read().await;
            if let Some(last) = last {
                assert_eq!(lock.prev.as_ref().unwrap().read().await.value, last);
            }
            last = Some(lock.value);
            next = lock.next.clone();
            visited += 1;
        }

        assert_eq!(visited, cache.len());
        if let Some(first) = &cache.first {
            assert_eq!(Some(first.read().await.value), last);
        }
    }

    #[tokio::test]
    async fn test_order() {
        let mut cache = LFUCache::new(100, Evict);
        let expected: Vec<i32> = (0..10).collect();
        // New keys go to the `last` end, so inserting in reverse yields ascending order.
        for i in expected.iter().rev() {
            cache.insert(*i, *i).await;
        }

        let mut actual = Vec::with_capacity(expected.len());
        cache.traverse(|i| actual.push(*i)).await;
        assert_eq!(actual, expected)
    }

    #[tokio::test]
    async fn test_access() {
        let mut cache = LFUCache::new(100, Evict);
        let mut rng = thread_rng();
        for _ in 0..100_000 {
            let i: i32 = rng.gen_range(0..10);
            cache.insert(i, i).await;
            validate(&cache).await;

            let i: i32 = rng.gen_range(0..10);
            cache.remove(&i).await;
            validate(&cache).await;

            if cache.is_full() {
                cache.evict().await;
            }
            assert!(!cache.is_full());

            let mut size = 0;
            cache.traverse(|_| size += 1).await;
            assert_eq!(cache.len(), size);
        }
    }
}