use crate::runtime::cache::lru::AccessTracker;
use rustc_hash::FxHashMap;
/// Metadata describing one cached item: its key, its size, and the tier holding it.
///
/// `TieredCache` creates these records when it places a key into a tier and
/// hands out references to them from `TieredCache::get`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct CacheEntry {
/// Identifier of the entry; the same value is used as the key in the tier's entry map.
pub key: u64,
/// Amount charged against the owning tier's `used` counter (the unit is not fixed by this file).
pub size: u64,
/// Index, within the cache's tier list, of the tier that currently holds the entry.
pub tier: usize,
}
/// One level of a `TieredCache`: a named, capacity-bounded map of entries.
///
/// `Debug` is derived so tiers can be logged and inspected like the other
/// public types in this module.
#[derive(Debug)]
#[non_exhaustive]
pub struct CacheTier {
    /// Human-readable label for the tier; not interpreted by the cache logic in this file.
    pub name: String,
    /// Maximum total entry size the tier may hold.
    pub capacity: u64,
    /// Sum of the sizes of the entries currently charged to this tier.
    ///
    /// The cache adds to it when an entry is placed and subtracts (saturating)
    /// when one is removed.
    pub used: u64,
    /// Entries resident in this tier, keyed by `CacheEntry::key`.
    pub(crate) entries: FxHashMap<u64, CacheEntry>,
}
impl CacheTier {
#[inline]
pub fn new(name: impl Into<String>, capacity: u64) -> Self {
Self {
name: name.into(),
capacity,
used: 0,
entries: FxHashMap::default(),
}
}
}
/// Access statistics for a single key, passed to `TierPolicy::should_promote`.
///
/// The common value-type traits are derived so the record can be logged,
/// copied and compared, matching `CacheEntry`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct AccessStats {
    /// Access count for the key; `LruPolicy` compares it with its promotion threshold.
    pub frequency: u32,
    /// Position of the key in the tracker's recency ordering.
    // NOTE(review): the direction of the ranking (0 = hottest or coldest) is defined
    // by the access tracker, not in this file; confirm before relying on it.
    pub recency_rank: usize,
    /// Size associated with the key; presumably the value the cache registers with
    /// the tracker on insert (TODO confirm against the tracker implementation).
    pub size: u64,
}
/// Strategy consulted by `TieredCache` to decide when an entry moves to the
/// next tier and which entry to evict when a tier needs space.
///
/// Implementations must be `Send + Sync`.
pub trait TierPolicy: Send + Sync {
/// Returns `true` if the entry identified by `key` should be promoted, given its access `stats`.
fn should_promote(&self, key: u64, stats: &AccessStats) -> bool;
/// Chooses the key to evict from the tier with index `tier`.
///
/// `entries` is that tier's entry map and `tracker` is the cache's access tracker.
/// Returning `None` tells the cache that nothing can be evicted, so it stops
/// trying to make room in that tier. The returned key is expected to be a key
/// of `entries`: the cache removes it from that tier's map only.
fn eviction_candidate(
&self,
tier: usize,
entries: &FxHashMap<u64, CacheEntry>,
tracker: &AccessTracker,
) -> Option<u64>;
}
/// Policy that promotes entries by access frequency and evicts the coldest tracked entry.
///
/// The common value-type traits are derived so a policy can be logged, copied
/// into several caches and compared in tests. `Default` is implemented by hand
/// elsewhere so that it uses `DEFAULT_THRESHOLD` rather than zero.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct LruPolicy {
    /// Minimum access frequency at which an entry qualifies for promotion.
    pub promote_threshold: u32,
}
impl LruPolicy {
pub const DEFAULT_THRESHOLD: u32 = 3;
#[inline]
pub fn new(promote_threshold: u32) -> Self {
Self { promote_threshold }
}
}
impl Default for LruPolicy {
    /// Builds a policy whose promotion threshold is `LruPolicy::DEFAULT_THRESHOLD`.
    fn default() -> Self {
        Self {
            promote_threshold: Self::DEFAULT_THRESHOLD,
        }
    }
}
impl TierPolicy for LruPolicy {
    /// Promotes once the entry's access frequency has reached `promote_threshold`.
    ///
    /// The key itself is not consulted.
    fn should_promote(&self, _key: u64, stats: &AccessStats) -> bool {
        self.promote_threshold <= stats.frequency
    }

    /// Returns the first key yielded by `tracker.iter_coldest()` that is present
    /// in `entries`; if the tracker yields no such key, falls back to whichever
    /// key the map iterator produces first (`None` when `entries` is empty).
    ///
    /// The tier index is not used by this policy.
    fn eviction_candidate(
        &self,
        _tier: usize,
        entries: &FxHashMap<u64, CacheEntry>,
        tracker: &AccessTracker,
    ) -> Option<u64> {
        for (candidate, _) in tracker.iter_coldest() {
            // The tracker may know keys that live in other tiers; skip those.
            if !entries.contains_key(candidate) {
                continue;
            }
            return Some(*candidate);
        }
        // Nothing tracked is resident here: pick an arbitrary resident key.
        entries.keys().copied().next()
    }
}
/// Errors returned by `TieredCache` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CacheError {
/// The key is not resident in any tier, or the access tracker has no statistics for it.
KeyNotFound,
/// No tier that was tried could make room for the entry.
EntryTooLarge,
}
impl std::fmt::Display for CacheError {
    /// Writes a fixed, user-facing message for the variant, including a suggested fix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::KeyNotFound => {
                "Key not found in cache. Fix: verify the key was inserted before operating on it."
            }
            Self::EntryTooLarge => {
                "Entry size exceeds the capacity of the largest tier. Fix: reduce the buffer size or increase the tier capacity."
            }
        };
        f.write_str(message)
    }
}
// Marks `CacheError` as a standard error type; the empty body keeps the trait's
// default methods, so the error reports no underlying source.
impl std::error::Error for CacheError {}
/// Multi-tier cache index: an ordered list of tiers, an access tracker, and the
/// policy `P` that decides promotions and evictions.
///
/// `P` defaults to `LruPolicy`.
#[non_exhaustive]
pub struct TieredCache<P: TierPolicy = LruPolicy> {
/// Tiers in index order; new entries are tried at index 0 first and spill to higher indices.
pub(crate) tiers: Vec<CacheTier>,
/// Access bookkeeping handed to the policy when it picks eviction candidates.
pub(crate) tracker: AccessTracker,
/// Promotion / eviction strategy.
pub(crate) policy: P,
}
impl TieredCache<LruPolicy> {
#[inline]
pub fn new(tiers: Vec<CacheTier>) -> Self {
Self::with_policy(tiers, LruPolicy::default())
}
}
impl<P: TierPolicy> TieredCache<P> {
    /// Creates a cache over `tiers` that consults `policy` for promotion and
    /// eviction decisions.
    ///
    /// The cache starts with a fresh access tracker; `tiers` is taken as given.
    #[inline]
    pub fn with_policy(tiers: Vec<CacheTier>, policy: P) -> Self {
        Self {
            tiers,
            tracker: AccessTracker::new(),
            policy,
        }
    }

    /// Looks up `key`, scanning the tiers in index order, and returns its entry
    /// if it is resident in any of them.
    ///
    /// This is a pure lookup: it does not record an access. Call
    /// `record_access` for that.
    #[inline]
    pub fn get(&self, key: u64) -> Option<&CacheEntry> {
        self.tiers.iter().find_map(|tier| tier.entries.get(&key))
    }

    /// Inserts `key` with the given `size`, replacing any existing entry for
    /// the same key.
    ///
    /// An existing entry is evicted first, together with its tracker record.
    /// The new entry is then tried at tier 0 and spills to higher tier indices
    /// until one can make room, evicting other entries as chosen by the policy.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::EntryTooLarge` when no tier can make room. In that
    /// case `key` is not resident afterwards and its tracker record is dropped.
    #[inline]
    pub fn insert(&mut self, key: u64, size: u64) -> Result<(), CacheError> {
        // Replace semantics: a stale entry (and its access history) must not survive.
        // `evict` is a no-op when the key is not resident.
        self.evict(key);
        self.tracker.set_size(key, size);
        self.insert_into_tier(key, size, 0)
    }

    /// Records an access to `key` with the tracker, but only if the key is
    /// currently resident in some tier.
    #[inline]
    pub fn record_access(&mut self, key: u64) {
        if self.get(key).is_some() {
            self.tracker.record(key);
        }
    }

    /// Moves `key` to the next higher tier index if the policy approves.
    ///
    /// Returns `Ok(())` without moving anything when the policy declines or
    /// when the entry is already in the last tier. If the target tier cannot
    /// make room, the entry is put back starting from the tier it came from.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::KeyNotFound` if the key is not resident or the
    /// tracker has no statistics for it, and `CacheError::EntryTooLarge` if the
    /// entry could not be placed in the target tier or put back.
    #[inline]
    pub fn promote(&mut self, key: u64) -> Result<(), CacheError> {
        let entry = self.get(key).copied().ok_or(CacheError::KeyNotFound)?;
        let stats = self.tracker.stats(key).ok_or(CacheError::KeyNotFound)?;
        if !self.policy.should_promote(key, &stats) {
            return Ok(());
        }
        let target = entry.tier.saturating_add(1);
        if target >= self.tiers.len() {
            return Ok(());
        }
        // Detach first so the entry cannot be picked as its own eviction victim.
        self.remove_entry(key);
        self.move_into_tier(key, entry.size, target, entry.tier)
    }

    /// Moves `key` to the next lower tier index.
    ///
    /// Returns `Ok(())` without moving anything when the entry is already in
    /// tier 0. If the target tier cannot make room, the entry is put back
    /// starting from the tier it came from.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::KeyNotFound` if the key is not resident, and
    /// `CacheError::EntryTooLarge` if the entry could not be placed in the
    /// target tier or put back.
    #[inline]
    pub fn demote(&mut self, key: u64) -> Result<(), CacheError> {
        let entry = self.get(key).copied().ok_or(CacheError::KeyNotFound)?;
        if entry.tier == 0 {
            return Ok(());
        }
        let target = entry.tier - 1;
        self.remove_entry(key);
        self.move_into_tier(key, entry.size, target, entry.tier)
    }

    /// Places `key` in the first tier at index `start` or above that can make room.
    ///
    /// On failure the key is resident nowhere, so its tracker record is removed
    /// to keep the tracker in step with the tiers.
    fn insert_into_tier(&mut self, key: u64, size: u64, start: usize) -> Result<(), CacheError> {
        for tier in start..self.tiers.len() {
            // `make_room` rejects tiers whose capacity is below `size` without evicting.
            if self.make_room(tier, size) {
                self.place(key, size, tier);
                return Ok(());
            }
        }
        self.tracker.remove(key);
        Err(CacheError::EntryTooLarge)
    }

    /// Places a detached entry into `target`, or, if that tier cannot make
    /// room, back into the first tier at index `fallback` or above that can.
    fn move_into_tier(
        &mut self,
        key: u64,
        size: u64,
        target: usize,
        fallback: usize,
    ) -> Result<(), CacheError> {
        if self.make_room(target, size) {
            self.place(key, size, target);
            Ok(())
        } else {
            self.insert_into_tier(key, size, fallback)
        }
    }

    /// Stores `key` in `tier` and charges `size` to that tier's `used` counter.
    ///
    /// The caller must already have made room; no capacity check happens here.
    fn place(&mut self, key: u64, size: u64, tier: usize) {
        let slot = &mut self.tiers[tier];
        slot.used = slot.used.saturating_add(size);
        slot.entries.insert(key, CacheEntry { key, size, tier });
    }

    /// Evicts policy-chosen entries from `tier` until `size` more units fit.
    ///
    /// Returns `false`, leaving the tier as it is at that point, when:
    /// - `size` exceeds the tier's capacity, in which case no eviction could
    ///   ever help and nothing is evicted;
    /// - the policy has no candidate to offer; or
    /// - the policy names a key that is not in this tier, which would
    ///   otherwise make this loop spin forever.
    fn make_room(&mut self, tier: usize, size: u64) -> bool {
        if size > self.tiers[tier].capacity {
            return false;
        }
        loop {
            let slot = &self.tiers[tier];
            if slot.used.saturating_add(size) <= slot.capacity {
                return true;
            }
            let candidate = self
                .policy
                .eviction_candidate(tier, &slot.entries, &self.tracker);
            let evicted = candidate.and_then(|victim| self.evict_from_tier(victim, tier));
            if evicted.is_none() {
                return false;
            }
        }
    }

    /// Detaches `key` from whichever tier holds it and releases its size, while
    /// keeping its tracker record (used when an entry is moved between tiers).
    fn remove_entry(&mut self, key: u64) -> Option<CacheEntry> {
        self.tiers.iter_mut().find_map(|tier| {
            let entry = tier.entries.remove(&key)?;
            tier.used = tier.used.saturating_sub(entry.size);
            Some(entry)
        })
    }

    /// Removes `key` from whichever tier holds it and drops its tracker record.
    ///
    /// Returns `None`, and leaves the tracker untouched, when the key is not resident.
    fn evict(&mut self, key: u64) -> Option<CacheEntry> {
        let entry = self.remove_entry(key)?;
        self.tracker.remove(key);
        Some(entry)
    }

    /// Removes `key` from the tier at index `tier` only and drops its tracker record.
    ///
    /// Returns `None`, and leaves the tracker untouched, when that tier does
    /// not hold the key.
    fn evict_from_tier(&mut self, key: u64, tier: usize) -> Option<CacheEntry> {
        let slot = &mut self.tiers[tier];
        let entry = slot.entries.remove(&key)?;
        slot.used = slot.used.saturating_sub(entry.size);
        self.tracker.remove(key);
        Some(entry)
    }
}