use super::basic_lru_cache::private::Cache;
use super::BasicLruCache;
use parking_lot::Mutex;
use std::collections::hash_map::DefaultHasher;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
/// Smallest per-shard capacity the sharding policy aims for: a cache of `capacity` entries is
/// never split into more than `capacity / MIN_SHARD_CAPACITY` shards (but always at least one).
const MIN_SHARD_CAPACITY: usize = 4;
/// Upper bound on the number of shards a single cache is split into.
///
/// Must be a power of two: the shard for a key is chosen by masking its hash with
/// `num_shards - 1`, and `ShardedLruCache::new` may use this value directly as the shard count.
const MAX_SHARDS: usize = 16;

// Compile-time guards for the invariants documented above, so that retuning the constants
// cannot silently break shard selection or cause a division by zero in the constructor.
const _: () = assert!(MIN_SHARD_CAPACITY > 0, "MIN_SHARD_CAPACITY must be non-zero");
const _: () = assert!(MAX_SHARDS.is_power_of_two(), "MAX_SHARDS must be a power of two");
/// A thread-safe cache whose entries are spread over several independently locked shards.
///
/// Every key is routed to exactly one shard by hashing it, so operations on keys that land in
/// different shards take different locks. Each shard is a separate `BasicLruCache` with its own
/// slice of the total capacity, so any recency tracking and eviction happen per shard rather
/// than across the whole cache.
pub struct ShardedLruCache<K, V>
where
    K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
    V: Clone + Debug + Send + Sync + 'static,
{
    /// Capacity requested at construction; the shard capacities add up to exactly this value.
    total_capacity: usize,
    /// Number of shards (a power of two, so `num_shards - 1` works as a hash mask).
    num_shards: usize,
    /// The shards, each guarded by its own mutex.
    shards: Vec<Mutex<BasicLruCache<K, V>>>,
}
impl<K, V> ShardedLruCache<K, V>
where
K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
V: Clone + Debug + Send + Sync + 'static,
{
pub fn new(capacity: usize) -> Self {
assert!(capacity > 0, "Capacity must be positive");
let theoretical_shards = capacity / MIN_SHARD_CAPACITY;
let num_shards = if theoretical_shards >= MAX_SHARDS {
MAX_SHARDS
} else {
let mut n = 1;
while n * 2 <= theoretical_shards && n < MAX_SHARDS {
n *= 2;
}
n
};
let base_shard_capacity = capacity / num_shards;
let remaining_capacity = capacity % num_shards;
let mut shards = Vec::with_capacity(num_shards);
for i in 0..num_shards {
let shard_capacity = if i < remaining_capacity {
base_shard_capacity + 1
} else {
base_shard_capacity
};
shards.push(Mutex::new(BasicLruCache::new(shard_capacity)));
}
Self {
shards,
total_capacity: capacity,
num_shards,
}
}
pub fn capacity(&self) -> usize {
self.total_capacity
}
pub fn num_shards(&self) -> usize {
self.num_shards
}
fn get_shard_index(&self, key: &K) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
(hash as usize) & (self.num_shards - 1)
}
pub fn get(&self, key: &K) -> Option<V> {
let shard_idx = self.get_shard_index(key);
let shard = self.shards[shard_idx].lock();
shard.deref().get(key)
}
pub fn put(&self, key: K, value: V) -> Option<V> {
let shard_idx = self.get_shard_index(&key);
let shard = self.shards[shard_idx].lock();
shard.deref().put(key, value)
}
pub fn remove(&self, key: &K) -> Option<V> {
let shard_idx = self.get_shard_index(key);
let shard = self.shards[shard_idx].lock();
shard.deref().remove(key)
}
pub fn len(&self) -> usize {
self.shards
.iter()
.map(|shard| shard.lock().deref().len())
.sum()
}
pub fn is_empty(&self) -> bool {
self.shards
.iter()
.all(|shard| shard.lock().deref().is_empty())
}
pub fn clear(&self) {
for shard in &self.shards {
let shard = shard.lock();
shard.deref().clear();
}
}
}
impl<K, V> Cache<K, V> for ShardedLruCache<K, V>
where
K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
V: Clone + Debug + Send + Sync + 'static,
{
fn get(&self, key: &K) -> Option<V> {
self.get(key)
}
fn put(&self, key: K, value: V) -> Option<V> {
self.put(key, value)
}
fn remove(&self, key: &K) -> Option<V> {
self.remove(key)
}
fn len(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.is_empty()
}
fn clear(&self) {
self.clear()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex as PLMutex;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// Single-shard cache: insert, read back, then overflow by one entry and check that the
    /// least recently used entry is the one that disappears.
    #[test]
    fn test_basic_operations() {
        let cache = ShardedLruCache::new(2);
        // capacity / MIN_SHARD_CAPACITY == 0, so everything lives in one shard of capacity 2.
        assert_eq!(cache.num_shards(), 1);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.put("key1".to_string(), "one".to_string()), None);
        assert_eq!(cache.put("key2".to_string(), "two".to_string()), None);
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&"key1".to_string()), Some("one".to_string()));
        assert_eq!(cache.get(&"key2".to_string()), Some("two".to_string()));

        // key1 was both inserted and read before key2, so it is the least recently used entry
        // and must be the one evicted when a third key arrives.
        cache.put("key3".to_string(), "three".to_string());
        assert!(cache.len() <= cache.capacity());
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&"key1".to_string()), None);
        assert_eq!(cache.get(&"key2".to_string()), Some("two".to_string()));
        assert_eq!(cache.get(&"key3".to_string()), Some("three".to_string()));
    }

    /// Ten threads each write and read back their own keys; reads must never observe a
    /// value belonging to a different key.
    #[test]
    fn test_concurrent_access() {
        let cache = Arc::new(ShardedLruCache::new(1000));
        let mut handles = vec![];
        for i in 0..10 {
            let cache = Arc::clone(&cache);
            let handle = thread::spawn(move || {
                for j in 0..100 {
                    let key = format!("key_{}_{}", i, j);
                    cache.put(key.clone(), format!("value_{}", j));
                    thread::sleep(Duration::from_micros(1));
                    // The entry may have been evicted in the meantime, but if present it must
                    // still hold this thread's value.
                    if let Some(value) = cache.get(&key) {
                        assert_eq!(value, format!("value_{}", j));
                    }
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert!(cache.len() <= cache.capacity());
    }

    /// Writers, readers and mixed read/write threads hammer a shared set of 100 keys; readers
    /// must only ever see values produced by one of the writer kinds.
    #[test]
    fn test_concurrent_mixed_operations() {
        let cache = Arc::new(ShardedLruCache::new(2000));
        let mut handles = vec![];
        let operation_count = 1000;
        for i in 0..4 {
            let cache = Arc::clone(&cache);
            let handle = thread::spawn(move || {
                for j in 0..operation_count {
                    let key = format!("key_{}", j % 100);
                    cache.put(key.clone(), format!("writer_{}_value_{}", i, j));
                }
            });
            handles.push(handle);
        }
        for _i in 0..4 {
            let cache = Arc::clone(&cache);
            let handle = thread::spawn(move || {
                for j in 0..operation_count {
                    let key = format!("key_{}", j % 100);
                    if let Some(value) = cache.get(&key) {
                        assert!(
                            value.starts_with("writer_") || value.starts_with("mixed_"),
                            "Invalid value format: {}",
                            value
                        );
                    }
                }
            });
            handles.push(handle);
        }
        for i in 0..4 {
            let cache = Arc::clone(&cache);
            let handle = thread::spawn(move || {
                for j in 0..operation_count {
                    let key = format!("key_{}", j % 100);
                    if j % 2 == 0 {
                        cache.put(key.clone(), format!("mixed_{}_value_{}", i, j));
                    } else {
                        let _ = cache.get(&key);
                    }
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert!(cache.len() <= cache.capacity());
    }

    /// Far more distinct keys than capacity are inserted concurrently; the cache must never
    /// report more entries than its capacity.
    #[test]
    fn test_concurrent_capacity_correctness() {
        let capacity = 100;
        let cache = Arc::new(ShardedLruCache::new(capacity));
        let threads_count = 8;
        let operations_per_thread = 1000;
        let total_ops_counter = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];
        for i in 0..threads_count {
            let cache = Arc::clone(&cache);
            let ops_counter = Arc::clone(&total_ops_counter);
            let handle = thread::spawn(move || {
                for j in 0..operations_per_thread {
                    let key = format!("key_{}_{}", i, j);
                    cache.put(key, j);
                    ops_counter.fetch_add(1, Ordering::SeqCst);
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert!(
            cache.len() <= capacity,
            "Cache size {} exceeded capacity {}",
            cache.len(),
            capacity
        );
        assert_eq!(
            total_ops_counter.load(Ordering::SeqCst),
            threads_count * operations_per_thread
        );
    }

    /// Four threads remove disjoint quarters of 100 pre-inserted keys while readers run
    /// concurrently; every key must be removed exactly once and the cache must end empty.
    #[test]
    fn test_concurrent_remove_correctness() {
        let cache = Arc::new(ShardedLruCache::new(1000));
        let removed_values = Arc::new(PLMutex::new(HashSet::new()));
        let mut handles = vec![];
        for i in 0..100 {
            cache.put(format!("key_{}", i), i);
        }
        for i in 0..4 {
            let cache = Arc::clone(&cache);
            let removed = Arc::clone(&removed_values);
            let handle = thread::spawn(move || {
                for j in 0..25 {
                    let key = format!("key_{}", i * 25 + j);
                    if let Some(value) = cache.remove(&key) {
                        removed.lock().insert(value);
                    }
                }
            });
            handles.push(handle);
        }
        for _ in 0..4 {
            let cache = Arc::clone(&cache);
            let handle = thread::spawn(move || {
                for i in 0..100 {
                    let key = format!("key_{}", i);
                    if let Some(value) = cache.get(&key) {
                        assert!((0..100).contains(&value));
                    }
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let removed_count = removed_values.lock().len();
        // Each key belongs to exactly one remover, readers never delete, and 100 keys cannot
        // overflow shards of ~62 entries each, so every pre-inserted value must be removed.
        assert_eq!(
            removed_count, 100,
            "Every pre-inserted key should have been removed exactly once"
        );
        assert!(
            cache.is_empty(),
            "All keys were removed, so the cache should be empty"
        );
    }

    /// A `clear` racing with concurrent writers must leave the cache within capacity.
    #[test]
    fn test_concurrent_clear_correctness() {
        let cache = Arc::new(ShardedLruCache::new(1000));
        let mut handles = vec![];
        for i in 0..500 {
            cache.put(format!("init_key_{}", i), i);
        }
        let cache_clone = Arc::clone(&cache);
        handles.push(thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            cache_clone.clear();
        }));
        for i in 0..4 {
            let cache = Arc::clone(&cache);
            let handle = thread::spawn(move || {
                for j in 0..100 {
                    let key = format!("key_{}_{}", i, j);
                    cache.put(key.clone(), j);
                    thread::sleep(Duration::from_micros(10));
                    let _ = cache.get(&key);
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert!(cache.len() <= cache.capacity());
    }

    /// 800 distinct keys should be routed to every one of the cache's shards.
    #[test]
    fn test_concurrent_shard_distribution() {
        let cache = Arc::new(ShardedLruCache::new(1000));
        let shard_counts = (0..cache.num_shards())
            .map(|_| Arc::new(AtomicUsize::new(0)))
            .collect::<Vec<_>>();
        let mut handles = vec![];
        for i in 0..8 {
            let cache = Arc::clone(&cache);
            let shard_counts = shard_counts.clone();
            let handle = thread::spawn(move || {
                for j in 0..100 {
                    let key = format!("key_{}_{}", i, j);
                    let shard_idx = cache.get_shard_index(&key);
                    shard_counts[shard_idx].fetch_add(1, Ordering::SeqCst);
                    cache.put(key.clone(), j);
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let total_ops: usize = shard_counts
            .iter()
            .map(|counter| counter.load(Ordering::SeqCst))
            .sum();
        assert_eq!(total_ops, 800);
        let unused_shards = shard_counts
            .iter()
            .filter(|counter| counter.load(Ordering::SeqCst) == 0)
            .count();
        assert_eq!(unused_shards, 0, "All shards should be used");
    }
}