use alloc::collections::VecDeque;
use alloc::vec::Vec;
use core::hash::{BuildHasher, Hash};
use core::sync::atomic::{AtomicU8, Ordering};
// Reader-writer lock used for every shard, selected at compile time:
// `parking_lot::RwLock` when the `std` feature is enabled, `spin::RwLock` otherwise.
#[cfg(feature = "std")]
type RwLock<T> = parking_lot::RwLock<T>;
#[cfg(not(feature = "std"))]
type RwLock<T> = spin::RwLock<T>;
/// Upper bound of a slot's access-frequency counter; `get` stops incrementing once the
/// counter reaches this value.
const MAX_FREQ: u8 = 3;
/// Assigns a weight to cache entries.
///
/// The cache calls [`Weighter::weight`] once per `insert`; the returned value is what the
/// entry is charged against the cache capacity.
pub trait Weighter<K, V> {
/// Returns the weight of the entry formed by `key` and `value`.
fn weight(&self, key: &K, value: &V) -> u64;
}
/// Weighter that charges every entry a weight of 1, which makes the cache capacity an
/// entry count.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnitWeighter;
/// Every entry weighs exactly 1, regardless of key or value.
impl<K, V> Weighter<K, V> for UnitWeighter {
/// Ignores both arguments and returns the constant weight 1.
#[inline]
fn weight(&self, _: &K, _: &V) -> u64 {
1
}
}
/// Which of the shard's two resident queues an entry is currently accounted in.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Location {
// Entry is charged to `small_bytes`; new keys without a ghost record start here.
Small,
// Entry is charged to `main_bytes`; reached by promotion from the small queue or by
// re-inserting a key that is still remembered in the ghost set.
Main,
}
/// Per-entry record stored next to the key in the shard's hash table.
struct Slot<V> {
// The cached value; cloned out on `get`/`peek`.
value: V,
// Weight charged for this entry, as supplied to `ShardCore::insert`.
weight: u64,
// Access counter, bumped by `get` (up to `MAX_FREQ`) through a shared reference, hence atomic.
freq: AtomicU8,
// Queue whose byte counter currently includes `weight`.
loc: Location,
}
/// Wrapper that forces 64-byte alignment on its contents.
///
/// NOTE(review): presumably meant to keep neighbouring shard locks on separate cache lines
/// (64 bytes is a common cache-line size) — confirm the intent.
#[repr(align(64))]
struct Padded<T>(T);
/// State of a single cache shard: the resident entries, two FIFO queues of resident keys,
/// and a bounded record of keys recently evicted from the small queue.
struct ShardCore<K, V, S> {
// Resident entries, keyed by the caller-supplied hash.
map: hashbrown::HashTable<(K, Slot<V>)>,
// Used to re-hash keys popped from the queues and when the table grows.
hasher: S,
// FIFO of keys inserted into `Location::Small`.
small: VecDeque<K>,
// FIFO of keys inserted into or promoted to `Location::Main`.
main: VecDeque<K>,
// FIFO order of the keys remembered in `ghost_set`.
ghost: VecDeque<K>,
// Keys evicted from the small queue without having been read; a hit sends a re-insert to main.
ghost_set: hashbrown::HashSet<K, S>,
// Sum of weights of entries whose `loc` is `Small`.
small_bytes: u64,
// Sum of weights of entries whose `loc` is `Main`.
main_bytes: u64,
// Maximum resident weight (`small_bytes + main_bytes`) before eviction kicks in.
capacity: u64,
// Small-queue weight at or above which eviction prefers the small queue (capacity / 10, min 1).
small_target: u64,
// Maximum number of keys kept in `ghost`.
ghost_capacity: usize,
}
/// Single-shard cache logic. The policy (small + main FIFO queues, a ghost list of recently
/// evicted keys, and a small saturating frequency counter) resembles S3-FIFO.
///
/// Invariant maintained by every method: `small_bytes` / `main_bytes` equal the summed
/// weights of the resident slots whose `loc` is `Small` / `Main`. The queues are only
/// hints: they may contain keys that were removed, or keys whose slot has since moved to
/// the other queue, and eviction skips such stale entries.
impl<K, V, S> ShardCore<K, V, S>
where
    K: Eq + Hash + Clone,
    V: Clone,
    S: BuildHasher + Clone,
{
    /// Creates an empty shard holding at most `capacity` units of weight and remembering
    /// at most `ghost_capacity` evicted keys.
    fn new(capacity: u64, ghost_capacity: usize, hasher: S) -> Self {
        Self {
            map: hashbrown::HashTable::new(),
            hasher: hasher.clone(),
            small: VecDeque::new(),
            main: VecDeque::new(),
            ghost: VecDeque::new(),
            ghost_set: hashbrown::HashSet::with_hasher(hasher),
            small_bytes: 0,
            main_bytes: 0,
            capacity,
            // The small queue is considered "full" at a tenth of the capacity, never below 1.
            small_target: (capacity / 10).max(1),
            ghost_capacity,
        }
    }

    /// Returns a clone of the value stored under `key` and bumps its frequency counter
    /// (saturating at `MAX_FREQ`). `hash` must be the hash of `key` under this shard's hasher.
    fn get(&self, hash: u64, key: &K) -> Option<V> {
        let (_, slot) = self.map.find(hash, |(k, _)| k == key)?;
        // Deliberately a relaxed load followed by a store rather than an atomic
        // read-modify-write: readers holding only a shared reference may occasionally lose
        // an increment, but the counter can never exceed `MAX_FREQ`.
        let f = slot.freq.load(Ordering::Relaxed);
        if f < MAX_FREQ {
            slot.freq.store(f + 1, Ordering::Relaxed);
        }
        Some(slot.value.clone())
    }

    /// Returns a clone of the value stored under `key` without touching its frequency counter.
    #[cfg(any(feature = "zstd", test))]
    fn peek(&self, hash: u64, key: &K) -> Option<V> {
        self.map
            .find(hash, |(k, _)| k == key)
            .map(|(_, slot)| slot.value.clone())
    }

    /// Number of resident entries.
    fn len(&self) -> usize {
        self.map.len()
    }

    /// Inserts or replaces the entry for `key`, then evicts until the shard is within capacity.
    ///
    /// * A resident key keeps its queue position and frequency; only value and weight change.
    /// * A new key goes to the main queue if it is remembered in the ghost set, otherwise to
    ///   the small queue.
    /// * An entry heavier than the whole shard can never stay resident, so it is rejected up
    ///   front (dropping any previous value for the key) instead of being admitted and
    ///   flushing other entries on its way out.
    fn insert(&mut self, hash: u64, key: K, value: V, weight: u64) {
        if weight > self.capacity {
            // The old value must not survive a replace that could not be stored.
            self.remove(hash, &key);
            return;
        }
        if let Some((_, slot)) = self.map.find_mut(hash, |(k, _)| *k == key) {
            let old = slot.weight;
            slot.value = value;
            slot.weight = weight;
            match slot.loc {
                Location::Small => self.small_bytes = adjust(self.small_bytes, old, weight),
                Location::Main => self.main_bytes = adjust(self.main_bytes, old, weight),
            }
        } else {
            // NOTE(review): a ghost hit only clears the set; the key's position in `ghost`
            // is left behind and ages out through `push_ghost`'s length bound.
            let loc = if self.ghost_set.remove(&key) {
                Location::Main
            } else {
                Location::Small
            };
            let hasher = &self.hasher;
            self.map.insert_unique(
                hash,
                (
                    key.clone(),
                    Slot {
                        value,
                        weight,
                        freq: AtomicU8::new(0),
                        loc,
                    },
                ),
                |(k, _)| hasher.hash_one(k),
            );
            match loc {
                Location::Small => {
                    self.small_bytes += weight;
                    self.small.push_back(key);
                }
                Location::Main => {
                    self.main_bytes += weight;
                    self.main.push_back(key);
                }
            }
        }
        self.evict_to_capacity();
    }

    /// Removes the entry for `key`, if resident, and releases its weight.
    ///
    /// The key is *not* removed from its FIFO queue; the eviction routines skip queue
    /// entries that no longer match a resident slot.
    ///
    /// NOTE(review): because of that lazy deletion, a workload that keeps removing and
    /// re-inserting keys without ever triggering eviction grows the queues without bound.
    fn remove(&mut self, hash: u64, key: &K) {
        let Ok(entry) = self.map.find_entry(hash, |(k, _)| k == key) else {
            return;
        };
        let (_, slot) = entry.remove().0;
        match slot.loc {
            Location::Small => self.small_bytes -= slot.weight,
            Location::Main => self.main_bytes -= slot.weight,
        }
    }

    /// Total weight of all resident entries.
    #[inline]
    fn resident_bytes(&self) -> u64 {
        self.small_bytes + self.main_bytes
    }

    /// Runs eviction steps until the shard fits its capacity or nothing is left to evict.
    fn evict_to_capacity(&mut self) {
        while self.resident_bytes() > self.capacity {
            if !self.evict_one() {
                break;
            }
        }
    }

    /// Performs one eviction step, returning `false` when both queues are exhausted.
    ///
    /// The small queue is tried first when main holds nothing or when small has reached
    /// its target share; otherwise main is tried first. A step may only promote or age an
    /// entry rather than free weight; the caller loops.
    fn evict_one(&mut self) -> bool {
        let prefer_small = self.main_bytes == 0 || self.small_bytes >= self.small_target;
        if prefer_small {
            self.evict_from_small() || self.evict_from_main()
        } else {
            self.evict_from_main() || self.evict_from_small()
        }
    }

    /// Processes the oldest live key of the small queue: an entry that was read since
    /// insertion is promoted to main (counter reset), an unread one is evicted and its key
    /// remembered in the ghost list. Returns `false` if the queue ran empty.
    fn evict_from_small(&mut self) -> bool {
        while let Some(key) = self.small.pop_front() {
            let hash = self.hasher.hash_one(&key);
            // Stale queue entry: the key was removed after being queued.
            let Ok(mut entry) = self.map.find_entry(hash, |(k, _)| *k == key) else {
                continue;
            };
            let (_, slot) = entry.get_mut();
            // Stale queue entry: the key was removed and re-inserted (or already promoted
            // through a duplicate), so its weight is charged to main. Touching it here
            // would debit the wrong counter.
            if slot.loc != Location::Small {
                continue;
            }
            let weight = slot.weight;
            if slot.freq.load(Ordering::Relaxed) > 0 {
                slot.freq.store(0, Ordering::Relaxed);
                slot.loc = Location::Main;
                self.small_bytes -= weight;
                self.main_bytes += weight;
                self.main.push_back(key);
            } else {
                entry.remove();
                self.small_bytes -= weight;
                self.push_ghost(key);
            }
            return true;
        }
        false
    }

    /// Processes the oldest live key of the main queue: an entry with a non-zero counter
    /// is aged (counter decremented) and re-queued, one at zero is evicted. Returns
    /// `false` if the queue ran empty.
    fn evict_from_main(&mut self) -> bool {
        while let Some(key) = self.main.pop_front() {
            let hash = self.hasher.hash_one(&key);
            // Stale queue entry: the key was removed after being queued.
            let Ok(mut entry) = self.map.find_entry(hash, |(k, _)| *k == key) else {
                continue;
            };
            let (_, slot) = entry.get_mut();
            // Stale queue entry: the key was removed and re-inserted into the small queue,
            // so its weight is charged to small and it is not ours to age or evict.
            if slot.loc != Location::Main {
                continue;
            }
            let freq = slot.freq.load(Ordering::Relaxed);
            if freq > 0 {
                slot.freq.store(freq - 1, Ordering::Relaxed);
                self.main.push_back(key);
            } else {
                let weight = slot.weight;
                entry.remove();
                self.main_bytes -= weight;
            }
            return true;
        }
        false
    }

    /// Remembers `key` as recently evicted, dropping the oldest remembered keys once the
    /// ghost list exceeds `ghost_capacity`. Does nothing when the ghost list is disabled.
    fn push_ghost(&mut self, key: K) {
        if self.ghost_capacity == 0 {
            return;
        }
        // Only queue the key if it was not already remembered.
        if self.ghost_set.insert(key.clone()) {
            self.ghost.push_back(key);
        }
        while self.ghost.len() > self.ghost_capacity {
            if let Some(old) = self.ghost.pop_front() {
                self.ghost_set.remove(&old);
            }
        }
    }
}
/// Swaps an `old` contribution for a `new` one inside a running `total`.
///
/// The old amount is subtracted before the new one is added, so the intermediate value
/// never exceeds the larger of the input and output totals.
#[inline]
fn adjust(total: u64, old: u64, new: u64) -> u64 {
    let without_old = total - old;
    without_old + new
}
/// Concurrent cache split into independently locked shards.
///
/// A key's hash selects its shard; each shard enforces its own share of the capacity.
pub struct ShardedCache<K, V, W, S> {
// One lock-protected shard per slot; the length is a power of two.
shards: Vec<Padded<RwLock<ShardCore<K, V, S>>>>,
// `shards.len() - 1`, used to reduce a hash to a shard index.
shard_mask: u64,
// Computes the weight charged for each inserted entry.
weighter: W,
// Hashes keys for both shard selection and the per-shard table lookup.
hasher: S,
// Capacity requested at construction, as reported by `capacity()`.
capacity: u64,
}
impl<K, V, W, S> ShardedCache<K, V, W, S>
where
    K: Eq + Hash + Clone,
    V: Clone,
    W: Weighter<K, V>,
    S: BuildHasher + Clone,
{
    /// Builds a cache with a total weight budget of `capacity`.
    ///
    /// * `shard_count_hint` is clamped to `1..=256` and rounded up to a power of two.
    /// * Each shard receives `capacity / shard_count`, rounded up, so the summed shard
    ///   capacities can slightly exceed `capacity`.
    /// * `est_items` sizes the per-shard ghost list (`est_items / shard_count`, at least 16).
    /// * `weighter` prices entries; `hasher` is cloned into every shard.
    pub fn with_weighter(
        capacity: u64,
        shard_count_hint: usize,
        est_items: usize,
        weighter: W,
        hasher: S,
    ) -> Self {
        // Clamp before rounding: `next_power_of_two` overflows for hints above
        // 2^(usize::BITS - 1) (panic in debug builds, 0 in release builds). 256 is itself
        // a power of two, so the rounded result never exceeds the clamp.
        let shard_count = shard_count_hint.clamp(1, 256).next_power_of_two();
        let per_shard_cap = capacity.div_ceil(shard_count as u64);
        let ghost_capacity = (est_items / shard_count).max(16);
        let shards = (0..shard_count)
            .map(|_| {
                Padded(RwLock::new(ShardCore::new(
                    per_shard_cap,
                    ghost_capacity,
                    hasher.clone(),
                )))
            })
            .collect();
        Self {
            shards,
            shard_mask: shard_count as u64 - 1,
            weighter,
            hasher,
            capacity,
        }
    }

    /// Hashes `key` and returns the hash together with the shard responsible for it.
    #[inline]
    fn locate(&self, key: &K) -> (u64, &RwLock<ShardCore<K, V, S>>) {
        let h = self.hasher.hash_one(key);
        // Mix the upper 32 bits into the lower ones before masking, so the shard index
        // does not depend on the low hash bits alone.
        let masked = ((h >> 32) ^ h) & self.shard_mask;
        #[expect(
            clippy::cast_possible_truncation,
            reason = "masked index <= 255 fits usize on all targets"
        )]
        let idx = masked as usize;
        #[expect(
            clippy::indexing_slicing,
            reason = "idx = hash & (len - 1) with len a power of two, so idx < len"
        )]
        (h, &self.shards[idx].0)
    }

    /// Returns a clone of the value for `key`, recording the access for the eviction
    /// policy. Only the shard's read lock is taken.
    pub fn get(&self, key: &K) -> Option<V> {
        let (h, shard) = self.locate(key);
        shard.read().get(h, key)
    }

    /// Returns a clone of the value for `key` without recording an access.
    #[cfg(any(feature = "zstd", test))]
    pub fn peek(&self, key: &K) -> Option<V> {
        let (h, shard) = self.locate(key);
        shard.read().peek(h, key)
    }

    /// Inserts or replaces the entry for `key`. The weight is computed before the shard's
    /// write lock is taken; the shard may evict entries to stay within its capacity.
    pub fn insert(&self, key: K, value: V) {
        let weight = self.weighter.weight(&key, &value);
        let (h, shard) = self.locate(&key);
        shard.write().insert(h, key, value, weight);
    }

    /// Removes the entry for `key`; a missing key is not an error.
    pub fn remove(&self, key: &K) {
        let (h, shard) = self.locate(key);
        shard.write().remove(h, key);
    }

    /// Total weight of all resident entries. Shards are read-locked one after another, so
    /// under concurrent writes the sum is not a single consistent snapshot.
    pub fn weight(&self) -> u64 {
        self.shards
            .iter()
            .map(|s| s.0.read().resident_bytes())
            .sum()
    }

    /// The capacity requested at construction.
    pub fn capacity(&self) -> u64 {
        self.capacity
    }

    /// Number of resident entries, summed shard by shard (same snapshot caveat as `weight`).
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.0.read().len()).sum()
    }
}
// Unit tests driving the cache through its public API only.
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::indexing_slicing,
clippy::expect_used,
reason = "test code"
)]
mod tests {
use super::*;
use rustc_hash::FxBuildHasher;
// Weighs a byte vector by its length, so capacities in these tests are byte budgets.
#[derive(Clone, Copy)]
struct LenWeighter;
impl Weighter<u64, alloc::vec::Vec<u8>> for LenWeighter {
fn weight(&self, _: &u64, v: &alloc::vec::Vec<u8>) -> u64 {
v.len() as u64
}
}
// Cache of byte vectors with 8 shards and an estimate of 1024 items.
fn byte_cache(
capacity: u64,
) -> ShardedCache<u64, alloc::vec::Vec<u8>, LenWeighter, FxBuildHasher> {
ShardedCache::with_weighter(capacity, 8, 1024, LenWeighter, FxBuildHasher)
}
// A stored value is returned by `get`; an unknown key yields `None`.
#[test]
fn insert_get_roundtrip() {
let c = byte_cache(10_000);
c.insert(1, vec![0u8; 100]);
assert_eq!(c.get(&1), Some(vec![0u8; 100]));
assert_eq!(c.get(&2), None);
}
// NOTE(review): despite its name, this test only checks the values `peek` returns;
// nothing here asserts promotion behavior.
#[test]
fn peek_does_not_promote_but_get_does() {
let c = byte_cache(10_000);
c.insert(7, vec![1u8; 50]);
assert_eq!(c.peek(&7), Some(vec![1u8; 50]));
assert_eq!(c.peek(&999), None);
}
// `weight()` follows inserts and removes; removing a missing key changes nothing.
#[test]
fn weight_tracks_resident_bytes() {
let c = byte_cache(10_000);
assert_eq!(c.weight(), 0);
c.insert(1, vec![0u8; 100]);
c.insert(2, vec![0u8; 200]);
assert_eq!(c.weight(), 300);
c.remove(&1);
assert_eq!(c.weight(), 200);
c.remove(&999); assert_eq!(c.weight(), 200);
}
// Replacing a value swaps its weight rather than adding to it, for growth and shrinkage.
#[test]
fn replace_adjusts_weight_in_place() {
let c = byte_cache(10_000);
c.insert(1, vec![0u8; 100]);
c.insert(1, vec![0u8; 250]); assert_eq!(c.weight(), 250);
assert_eq!(c.get(&1), Some(vec![0u8; 250]));
c.insert(1, vec![0u8; 30]); assert_eq!(c.weight(), 30);
}
// Inserting 100x the capacity must leave the resident weight near the budget; the bound
// allows 100 bytes (one entry) of slack over the nominal capacity.
#[test]
fn eviction_keeps_resident_under_capacity() {
let c = byte_cache(1_000);
for i in 0..1_000u64 {
c.insert(i, vec![0u8; 100]);
}
assert!(
c.weight() <= 1_000 + 100,
"resident weight {} exceeded capacity",
c.weight(),
);
}
// A key that is read before and during a stream of new inserts must still be resident.
#[test]
fn frequently_read_entries_survive_eviction_pressure() {
let c = byte_cache(2_000);
c.insert(0, vec![0u8; 100]);
for _ in 0..8 {
assert_eq!(c.get(&0), Some(vec![0u8; 100]));
}
for i in 1..200u64 {
c.insert(i, vec![0u8; 100]);
let _ = c.get(&0); }
assert_eq!(c.get(&0), Some(vec![0u8; 100]), "hot key was evicted");
}
// With `UnitWeighter` the capacity acts as an entry count (2 shards, slack of 2 allowed).
#[test]
fn unit_weighter_is_count_capacity() {
let c: ShardedCache<u64, u64, UnitWeighter, FxBuildHasher> =
ShardedCache::with_weighter(4, 2, 64, UnitWeighter, FxBuildHasher);
for i in 0..100u64 {
c.insert(i, i);
}
assert!(c.weight() <= 4 + 2, "entry count {} exceeded", c.weight());
}
// An entry larger than the whole cache must not prevent later, fitting inserts from
// being stored, nor leave the cache over budget.
#[test]
fn oversized_entry_does_not_wedge_the_cache() {
let c = byte_cache(1_000);
c.insert(1, vec![0u8; 5_000]); c.insert(2, vec![0u8; 100]);
assert_eq!(c.get(&2), Some(vec![0u8; 100]));
assert!(c.weight() <= 1_100);
}
// Eight threads interleave insert/get/peek/remove over 2,000 shared keys. Afterwards the
// weight must be within budget and must equal 64 bytes per resident entry.
#[cfg(feature = "std")]
#[test]
fn concurrent_stress_keeps_invariants() {
use std::sync::Arc;
use std::thread;
let cache = Arc::new(byte_cache(50_000));
let threads: Vec<_> = (0..8)
.map(|t| {
let cache = Arc::clone(&cache);
thread::spawn(move || {
for i in 0..5_000u64 {
let key = (t * 5_000 + i) % 2_000; match i % 4 {
0 => cache.insert(key, vec![0u8; 64]),
1 => {
let _ = cache.get(&key);
}
2 => {
let _ = cache.peek(&key);
}
_ => cache.remove(&key),
}
}
})
})
.collect();
for h in threads {
h.join().expect("worker thread panicked");
}
assert!(
cache.weight() <= 50_000 + 64,
"weight {} exceeded capacity after concurrent churn",
cache.weight(),
);
// Every value inserted above is 64 bytes, so weight and entry count must agree exactly.
assert_eq!(
cache.weight(),
cache.len() as u64 * 64,
"atomic weight diverged from resident entries",
);
}
}