use crate::lru_cache::LruCache;
use std::hash::{Hash, Hasher};
use std::sync::Mutex;
/// 64-bit FNV-1a offset basis (initial hasher state).
const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
/// 64-bit FNV prime multiplied into the state after every byte.
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

/// Streaming 64-bit FNV-1a hasher used to choose a shard for each key.
///
/// The state starts at a fixed constant and depends only on the bytes written, so the
/// result is deterministic across runs (unlike `std`'s randomly keyed SipHash). It is not
/// designed to resist adversarially chosen keys.
///
/// Integer keys reach [`Hasher::write`] through `Hasher`'s default `write_u*`/`write_i*`
/// methods, which use native-endian bytes, so hash values can differ between little- and
/// big-endian targets.
#[derive(Debug, Clone, Copy)]
struct Fnv1aHasher(u64);

impl Fnv1aHasher {
    /// Creates a hasher seeded with the FNV-1a offset basis.
    fn new() -> Self {
        Self(FNV_OFFSET)
    }
}

impl Hasher for Fnv1aHasher {
    /// Folds each byte into the state: XOR first, then multiply (the "1a" ordering).
    fn write(&mut self, bytes: &[u8]) {
        self.0 = bytes
            .iter()
            .fold(self.0, |state, &byte| (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME));
    }

    /// Returns the current state without resetting it.
    fn finish(&self) -> u64 {
        self.0
    }
}

/// Hashes `key` with 64-bit FNV-1a by feeding it through [`Fnv1aHasher`].
fn fnv1a_hash<K: Hash>(key: &K) -> u64 {
    let mut hasher = Fnv1aHasher::new();
    key.hash(&mut hasher);
    hasher.finish()
}
/// A thread-safe LRU cache partitioned into independently locked shards.
///
/// Every key is routed to exactly one shard by hash, so operations on keys that land in
/// different shards never contend on the same mutex.
pub struct ShardedLruCache<K: Eq + Hash + Clone + Send + 'static, V: Clone + Send + 'static> {
    /// One `LruCache` per shard, each guarded by its own mutex.
    shards: Vec<Mutex<LruCache<K, V>>>,
    /// Shard count fixed by the constructor; mirrors `shards.len()`.
    num_shards: usize,
    /// Value returned by the `capacity()` accessor.
    capacity: usize,
}
impl<K, V> ShardedLruCache<K, V>
where
K: Eq + Hash + Clone + Send + 'static,
V: Clone + Send + 'static,
{
pub fn new(num_shards: usize, capacity: usize) -> Self {
let num_shards = num_shards.clamp(1, capacity.max(1));
let base = capacity / num_shards;
let remainder = capacity % num_shards;
let shards = (0..num_shards)
.map(|i| {
let shard_cap = if i < remainder { base + 1 } else { base };
Mutex::new(LruCache::new(shard_cap.max(1)))
})
.collect();
Self {
shards,
num_shards,
capacity,
}
}
fn shard_index(&self, key: &K) -> usize {
(fnv1a_hash(key) % self.num_shards as u64) as usize
}
pub fn get(&self, key: &K) -> Option<V> {
let idx = self.shard_index(key);
self.shards[idx]
.lock()
.ok()
.and_then(|mut shard| shard.get(key).cloned())
}
pub fn put(&self, key: K, value: V, size_bytes: usize) {
let idx = self.shard_index(&key);
if let Ok(mut shard) = self.shards[idx].lock() {
shard.insert(key, value, size_bytes);
}
}
pub fn contains(&self, key: &K) -> bool {
let idx = self.shard_index(key);
self.shards[idx]
.lock()
.map(|shard| shard.contains(key))
.unwrap_or(false)
}
pub fn remove(&self, key: &K) -> Option<V> {
let idx = self.shard_index(key);
self.shards[idx]
.lock()
.ok()
.and_then(|mut shard| shard.remove(key))
}
pub fn len(&self) -> usize {
self.shards
.iter()
.map(|s| s.lock().map(|shard| shard.len()).unwrap_or(0))
.sum()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn capacity(&self) -> usize {
self.capacity
}
pub fn num_shards(&self) -> usize {
self.num_shards
}
pub fn shard_capacity(&self, idx: usize) -> usize {
self.shards
.get(idx)
.and_then(|s| s.lock().ok())
.map(|shard| shard.capacity())
.unwrap_or(0)
}
pub fn shard_lengths(&self) -> Vec<usize> {
self.shards
.iter()
.map(|s| s.lock().map(|shard| shard.len()).unwrap_or(0))
.collect()
}
}
#[cfg(test)]
mod tests {
    //! Unit and concurrency tests for `ShardedLruCache`.

    use super::*;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_put_and_get() {
        let cache: ShardedLruCache<String, i32> = ShardedLruCache::new(4, 100);
        cache.put("key_a".to_string(), 42, 8);
        assert_eq!(cache.get(&"key_a".to_string()), Some(42));
    }

    #[test]
    fn test_get_absent() {
        let cache: ShardedLruCache<String, i32> = ShardedLruCache::new(4, 100);
        assert_eq!(cache.get(&"missing".to_string()), None);
    }

    #[test]
    fn test_contains() {
        let cache: ShardedLruCache<u32, u32> = ShardedLruCache::new(8, 200);
        cache.put(99, 999, 4);
        assert!(cache.contains(&99));
        assert!(!cache.contains(&100));
    }

    #[test]
    fn test_len() {
        let cache: ShardedLruCache<u32, u32> = ShardedLruCache::new(4, 100);
        assert_eq!(cache.len(), 0);
        cache.put(1, 10, 4);
        cache.put(2, 20, 4);
        cache.put(3, 30, 4);
        assert_eq!(cache.len(), 3);
    }

    #[test]
    fn test_is_empty() {
        let cache: ShardedLruCache<i32, i32> = ShardedLruCache::new(4, 50);
        assert!(cache.is_empty());
        cache.put(1, 1, 1);
        assert!(!cache.is_empty());
    }

    #[test]
    fn test_capacity() {
        let cache: ShardedLruCache<i32, i32> = ShardedLruCache::new(4, 128);
        assert_eq!(cache.capacity(), 128);
    }

    #[test]
    fn test_num_shards() {
        let cache: ShardedLruCache<i32, i32> = ShardedLruCache::new(8, 1000);
        assert_eq!(cache.num_shards(), 8);
    }

    #[test]
    fn test_remove() {
        let cache: ShardedLruCache<String, String> = ShardedLruCache::new(4, 50);
        cache.put("k".to_string(), "v".to_string(), 2);
        let removed = cache.remove(&"k".to_string());
        assert_eq!(removed, Some("v".to_string()));
        assert!(!cache.contains(&"k".to_string()));
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_distribution_across_shards() {
        let cache: ShardedLruCache<u32, u32> = ShardedLruCache::new(4, 400);
        for i in 0u32..200 {
            cache.put(i, i, 4);
        }
        let lengths = cache.shard_lengths();
        assert_eq!(lengths.len(), 4);
        let non_empty = lengths.iter().filter(|&&l| l > 0).count();
        assert!(
            non_empty >= 2,
            "entries should spread across at least 2 shards"
        );
    }

    #[test]
    fn test_lru_eviction_within_shard() {
        // A single shard makes eviction order independent of the hash.
        let cache: ShardedLruCache<u32, u32> = ShardedLruCache::new(1, 3);
        cache.put(1, 10, 1);
        cache.put(2, 20, 1);
        cache.put(3, 30, 1);
        // Touching key 1 leaves key 2 as the least recently used entry.
        cache.get(&1);
        cache.put(4, 40, 1);
        assert!(!cache.contains(&2), "key 2 should be evicted");
        assert!(cache.contains(&1));
        assert!(cache.contains(&3));
        assert!(cache.contains(&4));
    }

    #[test]
    fn test_fill_to_capacity() {
        let cap = 50usize;
        let shards = 4;
        let cache: ShardedLruCache<usize, usize> = ShardedLruCache::new(shards, cap);
        for i in 0..200 {
            cache.put(i, i, 1);
        }
        assert!(
            cache.len() <= cap,
            "total len {} must not exceed capacity {}",
            cache.len(),
            cap
        );
    }

    #[test]
    fn test_concurrent_reads() {
        let cache = Arc::new(ShardedLruCache::<u32, u32>::new(8, 1000));
        for i in 0u32..100 {
            cache.put(i, i * 2, 4);
        }
        let threads: Vec<_> = (0..8)
            .map(|t| {
                let c = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0u32..100 {
                        let v = c.get(&i);
                        if let Some(val) = v {
                            assert_eq!(val, i * 2, "thread {t}: key {i} has wrong value");
                        }
                    }
                })
            })
            .collect();
        for t in threads {
            t.join().expect("thread panicked");
        }
    }

    #[test]
    fn test_concurrent_writes() {
        let cache = Arc::new(ShardedLruCache::<u32, u32>::new(8, 500));
        let threads: Vec<_> = (0u32..8)
            .map(|t| {
                let c = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0u32..100 {
                        c.put(t * 100 + i, t * 1000 + i, 4);
                    }
                })
            })
            .collect();
        for t in threads {
            t.join().expect("thread panicked");
        }
        assert!(
            cache.len() <= 500,
            "total len {} must not exceed 500",
            cache.len()
        );
    }

    #[test]
    fn test_concurrent_mixed_rw() {
        let cache = Arc::new(ShardedLruCache::<u32, u32>::new(4, 200));
        for i in 0u32..50 {
            cache.put(i, i, 1);
        }
        let threads: Vec<_> = (0u32..4)
            .map(|t| {
                let c = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0u32..50 {
                        c.get(&i);
                        c.put(1000 + t * 100 + i, t + i, 1);
                    }
                })
            })
            .collect();
        for t in threads {
            t.join().expect("thread panicked");
        }
        // Interleaved reads and writes must still respect the capacity bound.
        assert!(
            cache.len() <= 200,
            "total len {} must not exceed 200",
            cache.len()
        );
    }

    #[test]
    fn test_shard_capacity() {
        let total = 10usize;
        let shards = 3usize;
        let cache: ShardedLruCache<i32, i32> = ShardedLruCache::new(shards, total);
        let c0 = cache.shard_capacity(0);
        let c1 = cache.shard_capacity(1);
        let c2 = cache.shard_capacity(2);
        assert_eq!(c0 + c1 + c2, total);
        assert!(c0 >= 1);
        assert!(c1 >= 1);
        assert!(c2 >= 1);
    }

    #[test]
    fn test_all_keys_retrievable_within_capacity() {
        // Give every shard room for *all* keys, so the assertion holds however the hash
        // spreads them. With exactly `num_keys` total slots, the test would pass only
        // because FNV-1a happens to send keys 0..20 to the 4 shards exactly 5 apiece.
        let num_keys = 20usize;
        let shards = 4usize;
        let cache: ShardedLruCache<u32, u32> = ShardedLruCache::new(shards, num_keys * shards);
        let keys: Vec<u32> = (0..num_keys as u32).collect();
        for &k in &keys {
            cache.put(k, k * 10, 1);
        }
        let found: HashSet<u32> = keys
            .iter()
            .filter(|&&k| cache.contains(&k))
            .copied()
            .collect();
        assert_eq!(found.len(), num_keys, "all {num_keys} keys should be in cache");
        for &k in &keys {
            assert_eq!(cache.get(&k), Some(k * 10));
        }
    }
}