use ahash::RandomState;
use crossbeam_queue::SegQueue;
use std::marker::PhantomData;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::{
AtomicBool, AtomicU8,
Ordering::{Acquire, Relaxed, SeqCst},
};
mod buckets;
mod estimation;
use buckets::Buckets;
use estimation::TinyLfu;
use std::hash::Hash;
/// Queue tag stored in a [`Location`] for items in the small FIFO (where new items start).
const SMALL: bool = false;
/// Queue tag stored in a [`Location`] for items in the main FIFO.
const MAIN: bool = true;

/// Atomic tag recording which FIFO queue an item currently belongs to.
///
/// The stored value is either [`SMALL`] or [`MAIN`]; every access uses `Relaxed` ordering.
/// `Default` yields the small tag, because `AtomicBool` defaults to `false`.
#[derive(Debug, Default)]
struct Location(AtomicBool);

impl Location {
    /// A tag pointing at the small queue.
    fn new_small() -> Self {
        Location(AtomicBool::new(SMALL))
    }

    /// The raw tag: [`SMALL`] or [`MAIN`].
    fn value(&self) -> bool {
        let Location(flag) = self;
        flag.load(Relaxed)
    }

    /// Whether the item is tagged as living in the main queue.
    fn is_main(&self) -> bool {
        self.value() == MAIN
    }

    /// Re-tag the item as living in the main queue.
    fn move_to_main(&self) {
        self.0.store(MAIN, Relaxed);
    }
}
/// Saturation point of [`Uses`]: the counter never grows past this value.
const USES_CAP: u8 = 3;

/// Small saturating access counter stored alongside each cached item.
///
/// It counts up to [`USES_CAP`] and never goes below zero.
#[derive(Debug, Default)]
struct Uses(AtomicU8);

impl Uses {
    /// Atomically increment the counter unless it is already at [`USES_CAP`].
    ///
    /// Returns the counter value after this call. That is the incremented value on success,
    /// or the current (saturated) value when the cap was already reached.
    pub fn inc_uses(&self) -> u8 {
        // The closure is lazy so `current + 1` is never evaluated once the cap is hit.
        let outcome = self.0.fetch_update(Acquire, Relaxed, |current| {
            (current < USES_CAP).then(|| current + 1)
        });
        match outcome {
            Ok(previous) => previous + 1,
            Err(saturated) => saturated,
        }
    }

    /// Atomically decrement the counter unless it is already zero.
    ///
    /// Returns the value *before* the decrement, or 0 if the counter was already zero.
    /// Callers can therefore tell "had uses left" (`> 0`) from "was exhausted" (`== 0`).
    pub fn decr_uses(&self) -> u8 {
        self.0
            .fetch_update(Acquire, Relaxed, |current| current.checked_sub(1))
            .unwrap_or(0)
    }

    /// Current counter value (relaxed load).
    pub fn uses(&self) -> u8 {
        self.0.load(Relaxed)
    }
}
/// Hashed key. Callers' keys are hashed to a `u64` with the cache's `RandomState`, and only
/// the hash is stored.
type Key = u64;
/// Item weight; admission asserts that it is non-zero.
type Weight = u16;
/// An item handed back to the caller, e.g. the items evicted by `put`/`force_put`.
#[derive(Clone)]
pub struct KV<T> {
/// Hashed key of the item (not the caller's original key).
pub key: Key,
/// The cached value.
pub data: T,
/// The weight the item was stored with.
pub weight: Weight,
}
/// Per-item record kept in `Buckets`.
pub struct Bucket<T> {
/// Saturating access counter consulted by the eviction policy.
uses: Uses,
/// Which FIFO queue (small or main) the item is tagged as belonging to.
queue: Location,
/// Weight counted against the tagged queue's weight total.
weight: Weight,
/// The cached value.
data: T,
}
/// Share of `total_weight_limit` given to the small queue. Eviction starts with the small
/// queue once its weight reaches `floor(total_weight_limit * SMALL_QUEUE_PERCENTAGE) + 1`.
const SMALL_QUEUE_PERCENTAGE: f32 = 0.1;
/// The small and main FIFO queues plus the TinyLFU estimator used on admission.
///
/// The queues hold only keys; item data lives in `Buckets`. A key may outlive its bucket in
/// a queue (for example after `TinyUfo::remove`). Weights are kept per queue in atomics so
/// that all methods work through `&self`.
struct FiFoQueues<T> {
// Maximum combined weight of both queues.
total_weight_limit: usize,
// FIFO of keys for newly admitted items.
small: SegQueue<Key>,
// Total weight of items tagged as being in the small queue.
small_weight: AtomicUsize,
// FIFO of keys promoted out of the small queue.
main: SegQueue<Key>,
// Total weight of items tagged as being in the main queue.
main_weight: AtomicUsize,
// Frequency estimator bumped and compared during admission.
estimator: TinyLfu,
_t: PhantomData<T>,
}
impl<T: Clone + Send + Sync + 'static> FiFoQueues<T> {
    /// Admit `key` with `data` and `weight`, evicting other items as needed to stay within
    /// `total_weight_limit`.
    ///
    /// The key's TinyLFU estimate is bumped on every call.
    ///
    /// * If the key is already cached, several things happen:
    ///   - Its uses counter is bumped.
    ///   - The weight of the queue it is tagged with is adjusted from its old weight to
    ///     `weight`.
    ///   - The bucket is replaced in place, keeping its queue tag.
    ///   - Eviction then runs with no extra weight.
    /// * Otherwise room for `weight` is made first.
    ///   - When exactly one item was evicted and `ignore_lfu` is false, the evicted item's
    ///     estimated frequency is compared with the new key's.
    ///   - If the evicted item is strictly more frequent, it is re-admitted and the new item
    ///     is returned as evicted instead.
    ///   - Whichever item is admitted starts in the small queue with zero uses.
    ///
    /// Returns every item that left the cache because of this call; this may include the item
    /// being admitted.
    ///
    /// # Panics
    ///
    /// Panics if `weight` is 0.
    fn admit(
        &self,
        key: Key,
        data: T,
        weight: u16,
        ignore_lfu: bool,
        buckets: &Buckets<T>,
    ) -> Vec<KV<T>> {
        // Validate before touching the estimator so a rejected call leaves no side effects.
        assert!(weight > 0);
        let new_freq = self.estimator.incr(key);

        let new_bucket = {
            let Some((uses, queue, weight)) = buckets.get_map(&key, |bucket| {
                // The key already exists; its weight may have changed.
                let old_weight = bucket.weight;
                let uses = bucket.uses.inc_uses();

                /// Shift `counter` by `new - old`.
                /// The operands are unsigned, so pick add or sub explicitly.
                fn update_atomic(counter: &AtomicUsize, old: u16, new: u16) {
                    if old == new {
                        return;
                    }
                    if old > new {
                        counter.fetch_sub((old - new) as usize, SeqCst);
                    } else {
                        counter.fetch_add((new - old) as usize, SeqCst);
                    }
                }

                let queue = bucket.queue.is_main();
                if queue == MAIN {
                    update_atomic(&self.main_weight, old_weight, weight);
                } else {
                    update_atomic(&self.small_weight, old_weight, weight);
                }
                (uses, queue, weight)
            }) else {
                // New key: make room for it before inserting.
                let mut evicted = self.evict_to_limit(weight, buckets);
                // TinyLFU admission is only applied when exactly one item was evicted.
                let (key, data, weight) = if !ignore_lfu && evicted.len() == 1 {
                    let evicted_freq = self.estimator.get(evicted[0].key);
                    if evicted_freq > new_freq {
                        // The victim is more popular: keep it and reject the newcomer instead.
                        let first = evicted.pop().expect("just check non-empty");
                        evicted.push(KV { key, data, weight });
                        (first.key, first.data, first.weight)
                    } else {
                        (key, data, weight)
                    }
                } else {
                    (key, data, weight)
                };

                let bucket = Bucket {
                    queue: Location::new_small(),
                    weight,
                    uses: Default::default(), // starts at 0
                    data,
                };
                let old = buckets.insert(key, bucket);
                if old.is_none() {
                    // Push the key before counting its weight, so an evictor that observes
                    // the new weight should also find the key in the queue.
                    self.small.push(key);
                    self.small_weight.fetch_add(weight as usize, SeqCst);
                }
                // else: presumably another thread inserted the same key concurrently and
                // already pushed/counted it. NOTE(review): if the weights differ, the
                // accounting drifts by the difference.
                return evicted;
            };
            Bucket {
                queue: Location(queue.into()),
                weight,
                uses: Uses(uses.into()),
                data,
            }
        };

        // Replace the existing bucket. The item itself may be picked by the eviction below.
        buckets.insert(key, new_bucket);
        self.evict_to_limit(0, buckets)
    }

    /// Evict items until the combined queue weight plus `extra_weight` fits within
    /// `total_weight_limit`, or until [`Self::evict_one`] finds nothing to evict.
    ///
    /// Returns the evicted items in eviction order.
    fn evict_to_limit(&self, extra_weight: Weight, buckets: &Buckets<T>) -> Vec<KV<T>> {
        let over_limit = || {
            self.total_weight_limit
                < self.small_weight.load(SeqCst)
                    + self.main_weight.load(SeqCst)
                    + extra_weight as usize
        };
        // Only allocate when something will be evicted; a single victim is the common case.
        let mut evicted = if over_limit() {
            Vec::with_capacity(1)
        } else {
            Vec::new()
        };
        while over_limit() {
            match self.evict_one(buckets) {
                Some(item) => evicted.push(item),
                // Nothing more could be evicted from the queues evict_one tries.
                None => break,
            }
        }
        evicted
    }

    /// Evict a single item.
    ///
    /// The small queue is tried first when its weight has reached
    /// [`Self::small_weight_limit`]. Otherwise, or if that yields nothing, the main queue is
    /// used. Returns `None` if no victim was found.
    fn evict_one(&self, buckets: &Buckets<T>) -> Option<KV<T>> {
        let evict_small = self.small_weight_limit() <= self.small_weight.load(SeqCst);
        if evict_small {
            let evicted = self.evict_one_from_small(buckets);
            if evicted.is_some() {
                return evicted;
            }
        }
        self.evict_one_from_main(buckets)
    }

    /// The small queue's weight threshold: `floor(total_weight_limit *
    /// SMALL_QUEUE_PERCENTAGE) + 1`.
    ///
    /// The `+ 1` keeps the threshold at least 1, so an empty small queue is never chosen.
    fn small_weight_limit(&self) -> usize {
        (self.total_weight_limit as f32 * SMALL_QUEUE_PERCENTAGE).floor() as usize + 1
    }

    /// Pop keys from the small queue until one item is evicted.
    ///
    /// Each popped item tagged as small leaves the small queue. If it was used more than once
    /// it is promoted to main and the search continues; otherwise it is removed and returned.
    /// Returns `None` once the small queue is empty.
    fn evict_one_from_small(&self, buckets: &Buckets<T>) -> Option<KV<T>> {
        loop {
            // An empty queue can also mean a concurrent evictor drained it.
            let to_evict = self.small.pop()?;

            let evicted = buckets
                .get_map(&to_evict, |bucket| {
                    // `TinyUfo::remove` leaves keys behind in the queues, and a re-admitted
                    // key is pushed again. A popped key whose bucket is now tagged MAIN
                    // is therefore a stale entry. Its weight is counted in `main_weight`,
                    // so touching it here would corrupt both counters. Just drop it.
                    if bucket.queue.is_main() {
                        return None;
                    }
                    let weight = bucket.weight;
                    self.small_weight.fetch_sub(weight as usize, SeqCst);

                    if bucket.uses.uses() > 1 {
                        // Promote to main and keep looking for a victim.
                        bucket.queue.move_to_main();
                        self.main.push(to_evict);
                        self.main_weight.fetch_add(weight as usize, SeqCst);
                        None
                    } else {
                        let data = bucket.data.clone();
                        buckets.remove(&to_evict);
                        Some(KV {
                            key: to_evict,
                            data,
                            weight,
                        })
                    }
                })
                .flatten();
            if evicted.is_some() {
                return evicted;
            }
        }
    }

    /// Pop keys from the main queue until one item is evicted.
    ///
    /// `decr_uses` returns the count *before* decrementing. An item with uses left is
    /// re-queued with one fewer use; only an item already at zero is removed and returned.
    /// Returns `None` once the main queue is empty.
    fn evict_one_from_main(&self, buckets: &Buckets<T>) -> Option<KV<T>> {
        loop {
            let to_evict = self.main.pop()?;
            let evicted = buckets
                .get_map(&to_evict, |bucket| {
                    // Stale entry (see `evict_one_from_small`): the bucket is now tagged
                    // as small and its weight is counted in `small_weight`. Drop the entry.
                    if !bucket.queue.is_main() {
                        return None;
                    }
                    if bucket.uses.decr_uses() > 0 {
                        self.main.push(to_evict);
                        None
                    } else {
                        let weight = bucket.weight;
                        self.main_weight.fetch_sub(weight as usize, SeqCst);
                        let data = bucket.data.clone();
                        buckets.remove(&to_evict);
                        Some(KV {
                            key: to_evict,
                            data,
                            weight,
                        })
                    }
                })
                .flatten();
            if evicted.is_some() {
                return evicted;
            }
        }
    }
}
/// Weighted cache keyed by `K` and storing values of type `T`.
///
/// Keys are hashed with `random_status` and only the `u64` hash is stored, so no `K` value is
/// kept. Eviction is driven by the small/main FIFO queues in `queues`, with TinyLFU consulted
/// when admitting new items. All operations take `&self`.
pub struct TinyUfo<K, T> {
// Eviction queues, their weight totals, and the admission estimator.
queues: FiFoQueues<T>,
// Item storage indexed by hashed key.
buckets: Buckets<T>,
// Hasher turning caller keys into the stored `u64` keys.
random_status: RandomState,
_k: PhantomData<K>,
}
impl<K: Hash, T: Clone + Send + Sync + 'static> TinyUfo<K, T> {
pub fn new(total_weight_limit: usize, estimated_size: usize) -> Self {
let queues = FiFoQueues {
small: SegQueue::new(),
small_weight: 0.into(),
main: SegQueue::new(),
main_weight: 0.into(),
total_weight_limit,
estimator: TinyLfu::new(estimated_size),
_t: PhantomData,
};
TinyUfo {
queues,
buckets: Buckets::new_fast(estimated_size),
random_status: RandomState::new(),
_k: PhantomData,
}
}
pub fn new_compact(total_weight_limit: usize, estimated_size: usize) -> Self {
let queues = FiFoQueues {
small: SegQueue::new(),
small_weight: 0.into(),
main: SegQueue::new(),
main_weight: 0.into(),
total_weight_limit,
estimator: TinyLfu::new_compact(estimated_size),
_t: PhantomData,
};
TinyUfo {
queues,
buckets: Buckets::new_compact(estimated_size, 32),
random_status: RandomState::new(),
_k: PhantomData,
}
}
pub fn get(&self, key: &K) -> Option<T> {
let key = self.random_status.hash_one(key);
self.buckets.get_map(&key, |p| {
p.uses.inc_uses();
p.data.clone()
})
}
pub fn put(&self, key: K, data: T, weight: Weight) -> Vec<KV<T>> {
let key = self.random_status.hash_one(&key);
self.queues.admit(key, data, weight, false, &self.buckets)
}
pub fn remove(&self, key: &K) -> Option<T> {
let key = self.random_status.hash_one(key);
let result = self.buckets.get_map(&key, |bucket| {
let data = bucket.data.clone();
let weight = bucket.weight;
if bucket.queue.is_main() {
self.queues.main_weight.fetch_sub(weight as usize, SeqCst);
} else {
self.queues.small_weight.fetch_sub(weight as usize, SeqCst);
}
data
});
if result.is_some() {
self.buckets.remove(&key);
}
result
}
pub fn force_put(&self, key: K, data: T, weight: Weight) -> Vec<KV<T>> {
let key = self.random_status.hash_one(&key);
self.queues.admit(key, data, weight, true, &self.buckets)
}
#[cfg(test)]
fn peek_queue(&self, key: K) -> Option<bool> {
let key = self.random_status.hash_one(&key);
self.buckets.get_queue(&key)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uses() {
        let uses: Uses = Default::default();
        assert_eq!(uses.uses(), 0);
        uses.inc_uses();
        assert_eq!(uses.uses(), 1);
        for _ in 0..USES_CAP {
            uses.inc_uses();
        }
        assert_eq!(uses.uses(), USES_CAP);
        for _ in 0..USES_CAP + 2 {
            uses.decr_uses();
        }
        assert_eq!(uses.uses(), 0);
    }

    #[test]
    fn test_evict_from_small() {
        let mut cache = TinyUfo::new(5, 5);
        cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
        cache.queues.estimator = TinyLfu::new_seeded(5);
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        assert_eq!(cache.peek_queue(1), Some(SMALL));
        assert_eq!(cache.peek_queue(2), Some(SMALL));
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        let evicted = cache.put(4, 4, 3);
        assert_eq!(evicted.len(), 2);
        assert_eq!(evicted[0].data, 1);
        assert_eq!(evicted[1].data, 2);
        assert_eq!(cache.peek_queue(1), None);
        assert_eq!(cache.peek_queue(2), None);
        assert_eq!(cache.peek_queue(3), Some(SMALL));
    }

    #[test]
    fn test_evict_from_small_to_main() {
        let mut cache = TinyUfo::new(5, 5);
        cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
        cache.queues.estimator = TinyLfu::new_seeded(5);
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        cache.get(&1);
        cache.get(&1);
        assert_eq!(cache.peek_queue(1), Some(SMALL));
        assert_eq!(cache.peek_queue(2), Some(SMALL));
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        let evicted = cache.put(4, 4, 2);
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0].weight, 2);
        assert_eq!(cache.peek_queue(1), Some(MAIN));
        // Either 2, 3 or 4 was evicted (TinyLFU may reject 4); find out which one.
        let mut remaining = vec![2, 3, 4];
        remaining.remove(
            remaining
                .iter()
                .position(|x| *x == evicted[0].data)
                .unwrap(),
        );
        // Look the victim up by its original key (data == key here). `evicted[0].key` is
        // already hashed, and `peek_queue` would hash it again, so that check was vacuous.
        assert_eq!(cache.peek_queue(evicted[0].data), None);
        for k in remaining {
            assert_eq!(cache.peek_queue(k), Some(SMALL));
        }
    }

    #[test]
    fn test_evict_reentry() {
        let mut cache = TinyUfo::new(5, 5);
        cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
        cache.queues.estimator = TinyLfu::new_seeded(5);
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        assert_eq!(cache.peek_queue(1), Some(SMALL));
        assert_eq!(cache.peek_queue(2), Some(SMALL));
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        let evicted = cache.put(4, 4, 1);
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0].data, 1);
        assert_eq!(cache.peek_queue(1), None);
        assert_eq!(cache.peek_queue(2), Some(SMALL));
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        assert_eq!(cache.peek_queue(4), Some(SMALL));
        let evicted = cache.put(1, 1, 1);
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0].data, 2);
        assert_eq!(cache.peek_queue(1), Some(SMALL));
        assert_eq!(cache.peek_queue(2), None);
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        assert_eq!(cache.peek_queue(4), Some(SMALL));
    }

    #[test]
    fn test_evict_entry_denied() {
        let mut cache = TinyUfo::new(5, 5);
        cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
        cache.queues.estimator = TinyLfu::new_seeded(5);
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        assert_eq!(cache.peek_queue(1), Some(SMALL));
        assert_eq!(cache.peek_queue(2), Some(SMALL));
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        // Bump the estimates of 1..=3 so the newcomer loses the admission comparison.
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        let evicted = cache.put(4, 4, 1);
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0].data, 4);
        assert_eq!(cache.peek_queue(1), Some(SMALL));
        assert_eq!(cache.peek_queue(2), Some(SMALL));
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        assert_eq!(cache.peek_queue(4), None);
    }

    #[test]
    fn test_force_put() {
        let mut cache = TinyUfo::new(5, 5);
        cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
        cache.queues.estimator = TinyLfu::new_seeded(5);
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        assert_eq!(cache.peek_queue(1), Some(SMALL));
        assert_eq!(cache.peek_queue(2), Some(SMALL));
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        // Same setup as test_evict_entry_denied, but force_put bypasses TinyLFU admission.
        let evicted = cache.force_put(4, 4, 1);
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0].data, 1);
        assert_eq!(cache.peek_queue(1), None);
        assert_eq!(cache.peek_queue(2), Some(SMALL));
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        assert_eq!(cache.peek_queue(4), Some(SMALL));
    }

    #[test]
    fn test_evict_from_main() {
        let mut cache = TinyUfo::new(5, 5);
        cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
        cache.queues.estimator = TinyLfu::new_seeded(5);
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        cache.get(&1);
        cache.get(&1);
        cache.get(&2);
        cache.get(&2);
        cache.get(&3);
        cache.get(&3);
        let evicted = cache.put(4, 4, 1);
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0].data, 1);
        assert_eq!(cache.peek_queue(1), None);
        assert_eq!(cache.peek_queue(2), Some(MAIN));
        assert_eq!(cache.peek_queue(3), Some(MAIN));
        assert_eq!(cache.peek_queue(4), Some(SMALL));
        let evicted = cache.put(1, 1, 1);
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0].data, 4);
        assert_eq!(cache.peek_queue(1), Some(SMALL));
        assert_eq!(cache.peek_queue(2), Some(MAIN));
        assert_eq!(cache.peek_queue(3), Some(MAIN));
        assert_eq!(cache.peek_queue(4), None);
    }

    #[test]
    fn test_evict_from_small_compact() {
        let mut cache = TinyUfo::new(5, 5);
        cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
        cache.queues.estimator = TinyLfu::new_compact_seeded(5);
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        assert_eq!(cache.peek_queue(1), Some(SMALL));
        assert_eq!(cache.peek_queue(2), Some(SMALL));
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        let evicted = cache.put(4, 4, 3);
        assert_eq!(evicted.len(), 2);
        assert_eq!(evicted[0].data, 1);
        assert_eq!(evicted[1].data, 2);
        assert_eq!(cache.peek_queue(1), None);
        assert_eq!(cache.peek_queue(2), None);
        assert_eq!(cache.peek_queue(3), Some(SMALL));
    }

    #[test]
    fn test_evict_from_small_to_main_compact() {
        let mut cache = TinyUfo::new(5, 5);
        cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
        cache.queues.estimator = TinyLfu::new_compact_seeded(5);
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        cache.get(&1);
        cache.get(&1);
        assert_eq!(cache.peek_queue(1), Some(SMALL));
        assert_eq!(cache.peek_queue(2), Some(SMALL));
        assert_eq!(cache.peek_queue(3), Some(SMALL));
        let evicted = cache.put(4, 4, 2);
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0].weight, 2);
        assert_eq!(cache.peek_queue(1), Some(MAIN));
        // Either 2, 3 or 4 was evicted (TinyLFU may reject 4); find out which one.
        let mut remaining = vec![2, 3, 4];
        remaining.remove(
            remaining
                .iter()
                .position(|x| *x == evicted[0].data)
                .unwrap(),
        );
        // Look the victim up by its original key (data == key here), not the stored hash.
        assert_eq!(cache.peek_queue(evicted[0].data), None);
        for k in remaining {
            assert_eq!(cache.peek_queue(k), Some(SMALL));
        }
    }

    #[test]
    fn test_remove() {
        let mut cache = TinyUfo::new(5, 5);
        cache.random_status = RandomState::with_seeds(2, 3, 4, 5);
        cache.put(1, 1, 1);
        cache.put(2, 2, 2);
        cache.put(3, 3, 2);
        assert_eq!(cache.remove(&1), Some(1));
        assert_eq!(cache.remove(&3), Some(3));
        assert_eq!(cache.get(&1), None);
        assert_eq!(cache.get(&3), None);
        cache.put(5, 5, 2);
        cache.put(6, 6, 2);
        cache.put(7, 7, 2);
        assert_eq!(cache.get(&1), None);
        assert_eq!(cache.get(&3), None);
        assert!(cache.get(&5).is_some() || cache.get(&6).is_some() || cache.get(&7).is_some());
        let total_weight =
            cache.queues.small_weight.load(SeqCst) + cache.queues.main_weight.load(SeqCst);
        assert!(total_weight <= 5);
    }
}