use super::CachePolicy;
use super::slru::SlruPolicy;
use crate::policy::AdmissionDecision;
use crate::policy::lru_list::LruList;
use parking_lot::Mutex;
use std::hash::Hash;
/// Window-TinyLFU cache policy.
///
/// New keys enter a small LRU admission window. When the window exceeds its cost budget, its
/// least-recently-used entries are offered to the main SLRU region and are admitted only if
/// their estimated access frequency is at least that of main's next eviction victim;
/// otherwise they are rejected.
#[derive(Debug)]
pub struct TinyLfuPolicy<K>
where
    K: Eq + Hash + Clone,
{
    /// Approximate access-frequency counters consulted for admission decisions.
    sketch: cms::CountMinSketch,
    /// Recency-ordered admission window that every newly admitted key enters first.
    window: Mutex<LruList<K>>,
    /// Main region (probationary + protected segments) for keys that passed the filter.
    main: SlruPolicy<K>,
    /// Cost budget of the window; entries beyond it are pushed out toward main.
    window_target_cost: u64,
}
impl<K: Eq + Hash + Clone> TinyLfuPolicy<K> {
    /// Creates a policy for a cache whose entries may sum to at most
    /// `total_cache_cost_capacity` cost units.
    ///
    /// The budget is split into an admission window of 1% of the capacity (rounded to the
    /// nearest unit, at least 1 unless the capacity is 0) and a main SLRU region that receives
    /// the remainder. The frequency sketch is built with a reset threshold of
    /// `10 * capacity` increments (at least 100).
    ///
    /// NOTE(review): the sketch allocates counters proportional to its reset threshold, so its
    /// memory grows linearly with the capacity — confirm this is acceptable for the cost units
    /// callers use (e.g. bytes).
    pub fn new(total_cache_cost_capacity: u64) -> Self {
        // Saturate so very large capacities (e.g. `u64::MAX` as "unbounded") cannot overflow.
        let cms_reset_threshold = total_cache_cost_capacity.saturating_mul(10).max(100);
        // Clamp before narrowing so 32-bit targets cannot silently truncate the threshold.
        let cms_reset_threshold = cms_reset_threshold.min(usize::MAX as u64) as usize;

        let window_target_cost = if total_cache_cost_capacity == 0 {
            0
        } else {
            // 1% of the capacity, rounded half-up with exact integer arithmetic.
            let whole = total_cache_cost_capacity / 100;
            let round_up = u64::from(total_cache_cost_capacity % 100 >= 50);
            (whole + round_up).max(1)
        };
        let main_cache_cost = total_cache_cost_capacity.saturating_sub(window_target_cost);

        Self {
            sketch: cms::CountMinSketch::new(cms_reset_threshold),
            window: Mutex::new(LruList::new()),
            main: SlruPolicy::new(main_cache_cost),
            window_target_cost,
        }
    }

    /// Records an access to `key` with its (possibly updated) `cost`.
    ///
    /// Always bumps the key's frequency. If the key is in the window it is pushed back to the
    /// window's front; otherwise the access is forwarded to the main region.
    ///
    /// NOTE(review): a cost change here can push the window past `window_target_cost`; nothing
    /// rebalances it until the next `on_admit`.
    fn access_internal(&self, key: &K, cost: u64) {
        self.sketch.increment(key);
        {
            let mut window = self.window.lock();
            if window.contains(key) {
                window.push_front(key.clone(), cost);
                return;
            }
        } // Window lock released here so it is never held while main is locked.
        self.main.access_internal(key, cost);
    }
}
impl<K, V> CachePolicy<K, V> for TinyLfuPolicy<K>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Send + Sync + 'static,
{
    /// Records a hit on `key`: bumps its frequency and refreshes it in whichever region holds it.
    fn on_access(&self, key: &K, cost: u64) {
        self.access_internal(key, cost);
    }

    /// Admits `key` with `cost` into the policy.
    ///
    /// A key already tracked by the main region is treated as an access. Otherwise the key is
    /// pushed to the front of the window. While the window exceeds its cost target, its
    /// least-recently-used entries are taken out one at a time. Each one is admitted to main
    /// when its estimated frequency is at least that of main's LRU victim, or when main has no
    /// victim; otherwise it is rejected.
    ///
    /// Rejected keys are no longer tracked by this policy. They are returned via
    /// [`AdmissionDecision::AdmitAndEvict`] so the caller can drop them.
    fn on_admit(&self, key: &K, cost: u64) -> AdmissionDecision<K> {
        self.sketch.increment(key);

        // Check the two main segments one at a time so their locks are never held together.
        let is_in_main = {
            let in_probationary = self.main.probationary.lock().contains(key);
            in_probationary || self.main.protected.lock().contains(key)
        };
        if is_in_main {
            self.main.access_internal(key, cost);
            return AdmissionDecision::Admit;
        }

        let mut window = self.window.lock();
        window.push_front(key.clone(), cost);
        let mut rejected = Vec::new();
        while window.current_total_cost() > self.window_target_cost {
            let (candidate_key, candidate_cost) = match window.pop_back() {
                Some(entry) => entry,
                None => break,
            };
            // Release the window while consulting main so the window and main locks are never
            // held at the same time.
            drop(window);

            // NOTE(review): the filter runs even when main still has free capacity, so a cold
            // candidate can be rejected while main is far from full — confirm this is intended.
            let admit = match self.main.peek_lru() {
                // Nothing in main to compete against.
                None => true,
                Some(victim) => {
                    self.sketch.estimate(&candidate_key) >= self.sketch.estimate(&victim)
                }
            };
            if admit {
                self.main.admit_internal(candidate_key, candidate_cost);
            } else {
                rejected.push(candidate_key);
            }

            window = self.window.lock();
        }
        drop(window);

        if rejected.is_empty() {
            AdmissionDecision::Admit
        } else {
            AdmissionDecision::AdmitAndEvict(rejected)
        }
    }

    /// Stops tracking `key`, looking in the window first and then in the main region.
    fn on_remove(&self, key: &K) {
        // The window guard is a temporary of the `if` condition, so it is released before
        // main is consulted.
        if self.window.lock().remove(key).is_some() {
            return;
        }
        <SlruPolicy<K> as CachePolicy<K, V>>::on_remove(&self.main, key);
    }

    /// Evicts entries until at least `cost_to_free` has been freed or nothing is left.
    ///
    /// Victims come from the main region first. Only if main cannot free enough are the
    /// window's least-recently-used entries evicted as well; without this fallback, window
    /// entries could never be reclaimed once main is empty.
    ///
    /// Returns the evicted keys and the total cost they freed.
    fn evict(&self, cost_to_free: u64) -> (Vec<K>, u64) {
        if cost_to_free == 0 {
            return (Vec::new(), 0);
        }

        let (mut victims, mut freed) =
            <SlruPolicy<K> as CachePolicy<K, V>>::evict(&self.main, cost_to_free);

        if freed < cost_to_free {
            let mut window = self.window.lock();
            while freed < cost_to_free {
                match window.pop_back() {
                    Some((victim_key, victim_cost)) => {
                        freed = freed.saturating_add(victim_cost);
                        victims.push(victim_key);
                    }
                    None => break,
                }
            }
        }

        (victims, freed)
    }

    /// Forgets every tracked key and resets all frequency statistics.
    fn clear(&self) {
        self.window.lock().clear();
        <SlruPolicy<K> as CachePolicy<K, V>>::clear(&self.main);
        self.sketch.clear();
    }
}
mod cms {
    //! Count-Min Sketch used to estimate key access frequencies, with periodic halving
    //! ("aging") of all counters so that stale popularity decays.

    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Number of independent hash rows.
    const DEPTH: usize = 4;
    /// Lower bound on the number of counters per row.
    const MIN_WIDTH: usize = 256;

    /// A concurrent Count-Min Sketch with `DEPTH` rows of atomic counters.
    ///
    /// All atomics use `Relaxed` ordering: the sketch is an approximate statistic and no other
    /// memory is published through it.
    #[derive(Debug)]
    pub(super) struct CountMinSketch {
        /// `DEPTH` rows, each holding a power-of-two number of counters.
        counters: Vec<Vec<AtomicUsize>>,
        /// One independently seeded hasher per row.
        hashers: Vec<ahash::RandomState>,
        /// Increments recorded since the last aging pass.
        increments: AtomicUsize,
        /// Number of increments that triggers an aging pass.
        reset_threshold: usize,
    }

    impl CountMinSketch {
        /// Builds a sketch that halves all counters every `reset_threshold` increments.
        ///
        /// Each row holds `max(2 * reset_threshold / DEPTH, MIN_WIDTH)` counters rounded up to
        /// a power of two, so memory grows linearly with `reset_threshold`.
        pub fn new(reset_threshold: usize) -> Self {
            // Saturating: a huge threshold must not wrap around to a tiny width.
            let width = (reset_threshold.saturating_mul(2) / DEPTH)
                .max(MIN_WIDTH)
                .next_power_of_two();
            let counters: Vec<Vec<AtomicUsize>> = (0..DEPTH)
                .map(|_| (0..width).map(|_| AtomicUsize::new(0)).collect())
                .collect();
            let hashers: Vec<ahash::RandomState> =
                (0..DEPTH).map(|_| ahash::RandomState::new()).collect();
            Self {
                counters,
                hashers,
                increments: AtomicUsize::new(0),
                reset_threshold,
            }
        }

        /// Yields the counter selected for `key` in each row.
        fn slots<'a, K: Hash>(&'a self, key: &'a K) -> impl Iterator<Item = &'a AtomicUsize> + 'a {
            self.hashers
                .iter()
                .zip(&self.counters)
                .map(move |(hasher, row)| {
                    let mut state = hasher.build_hasher();
                    key.hash(&mut state);
                    // Truncating the 64-bit hash on 32-bit targets is fine for bucket selection.
                    &row[state.finish() as usize % row.len()]
                })
        }

        /// Records one occurrence of `key`, aging the sketch once `reset_threshold` increments
        /// have accumulated.
        pub fn increment<K: Hash>(&self, key: &K) {
            for slot in self.slots(key) {
                slot.fetch_add(1, Ordering::Relaxed);
            }
            let seen = self.increments.fetch_add(1, Ordering::Relaxed) + 1;
            // Several threads can observe `seen >= reset_threshold` at once; only the one that
            // wins `claim_reset` halves, so the counters are not divided repeatedly.
            if seen >= self.reset_threshold && self.claim_reset() {
                self.halve_counters();
            }
        }

        /// Atomically zeroes the increment tally if it is still at or past the threshold.
        ///
        /// Returns `true` only for the caller whose update performed the zeroing. Racing
        /// callers then see a value below the threshold and back off.
        fn claim_reset(&self) -> bool {
            let threshold = self.reset_threshold;
            self.increments
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
                    if n >= threshold {
                        Some(0)
                    } else {
                        None
                    }
                })
                .is_ok()
        }

        /// Halves every counter so that old popularity decays over time.
        fn halve_counters(&self) {
            for counter in self.counters.iter().flatten() {
                // Read-modify-write (instead of load + store) so increments racing with the
                // aging pass are not overwritten. The closure always returns `Some`, so this
                // never fails.
                let _ = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| Some(v / 2));
            }
        }

        /// Returns the estimated count for `key`: the minimum of its per-row counters.
        ///
        /// Hash collisions can only inflate a counter, so taking the minimum gives the
        /// tightest available estimate.
        pub fn estimate<K: Hash>(&self, key: &K) -> usize {
            self.slots(key)
                .map(|slot| slot.load(Ordering::Relaxed))
                .min()
                .unwrap_or(usize::MAX)
        }

        /// Zeroes every counter and the increment tally.
        pub fn clear(&self) {
            self.increments.store(0, Ordering::Relaxed);
            for counter in self.counters.iter().flatten() {
                counter.store(0, Ordering::Relaxed);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the policy every scenario uses: capacity 101, i.e. a 1-unit window
    /// and a 100-unit main region.
    fn fresh_policy() -> TinyLfuPolicy<i32> {
        TinyLfuPolicy::new(101)
    }

    /// Calls `on_admit` through the `CachePolicy<i32, ()>` implementation.
    fn admit(policy: &TinyLfuPolicy<i32>, key: i32, cost: u64) -> AdmissionDecision<i32> {
        <TinyLfuPolicy<i32> as CachePolicy<i32, ()>>::on_admit(policy, &key, cost)
    }

    /// Bumps the sketch frequency of `key` `times` times.
    fn bump(policy: &TinyLfuPolicy<i32>, key: i32, times: usize) {
        (0..times).for_each(|_| policy.sketch.increment(&key));
    }

    #[test]
    fn test_new_item_goes_to_window() {
        let p = fresh_policy();

        let outcome = admit(&p, 1, 1);

        assert!(matches!(outcome, AdmissionDecision::Admit));
        assert!(p.window.lock().contains(&1), "Item should be in window");
        assert!(!p.main.probationary.lock().contains(&1));
    }

    #[test]
    fn test_window_overflow_causes_rejection() {
        let p = fresh_policy();
        p.main.admit_internal(100, 1);
        bump(&p, 100, 5);

        let _ = admit(&p, 1, 1);
        assert!(p.window.lock().contains(&1));

        match admit(&p, 2, 1) {
            AdmissionDecision::AdmitAndEvict(victims) => {
                assert_eq!(victims, vec![1], "Rejected candidate should be the victim");
            }
            decision => panic!("Expected AdmitAndEvict, got {:?}", decision),
        }

        {
            let window = p.window.lock();
            assert!(window.contains(&2));
            assert!(!window.contains(&1));
        }
        assert!(!p.main.probationary.lock().contains(&1));
        assert!(!p.main.protected.lock().contains(&1));
    }

    #[test]
    fn test_window_overflow_causes_admission() {
        let p = fresh_policy();
        p.main.admit_internal(100, 1);
        bump(&p, 100, 1);

        let _ = admit(&p, 1, 1);
        bump(&p, 1, 5);
        let _ = admit(&p, 2, 1);

        assert!(p.main.probationary.lock().contains(&1));
    }

    #[test]
    fn test_admission_logic_rejects_infrequent_item() {
        let p = fresh_policy();
        p.main.admit_internal(100, 1);
        bump(&p, 100, 10);

        let _decision1 = admit(&p, 1, 1);
        assert!(p.window.lock().contains(&1));

        match admit(&p, 2, 1) {
            AdmissionDecision::AdmitAndEvict(victims) => assert_eq!(
                victims,
                vec![1],
                "The cold candidate (1) should have been rejected and returned as a victim."
            ),
            other_decision => panic!(
                "Expected AdmitAndEvict with victim [1], got {:?}",
                other_decision
            ),
        }

        let window = p.window.lock();
        assert!(window.contains(&2));
        assert!(!window.contains(&1));
        assert_eq!(window.current_total_cost(), 1);
    }

    #[test]
    fn test_replacement_of_existing_item() {
        let p = fresh_policy();

        let _ = admit(&p, 1, 1);
        assert!(p.window.lock().contains(&1));
        assert_eq!(p.sketch.estimate(&1), 1);

        // Admitting a second key overflows the window and pushes key 1 into main.
        let _ = admit(&p, 2, 1);
        assert!(p.main.probationary.lock().contains(&1));
        assert!(!p.window.lock().contains(&1));

        // Re-admitting a resident key counts as an access with the new cost.
        let outcome = admit(&p, 1, 5);
        assert!(matches!(outcome, AdmissionDecision::Admit));
        assert_eq!(p.sketch.estimate(&1), 2, "Frequency should increase");
        assert!(!p.main.probationary.lock().contains(&1));
        assert!(p.main.protected.lock().contains(&1));

        let protected = p.main.protected.lock();
        let slot = *protected.lookup.get(&1).unwrap();
        assert_eq!(protected.nodes[slot].cost, 5, "Cost should be updated");
    }
}