use crate::SieveCache;
use std::borrow::Borrow;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
/// Number of shards used by `ShardedSieveCache::new` (and therefore by `Default`).
const DEFAULT_SHARDS: usize = 16;
/// A thread-safe cache that partitions its keys across several independently
/// locked `SieveCache` shards.
///
/// Each key is routed to exactly one shard by hashing it, so operations on
/// keys that land in different shards do not contend for the same lock.
/// Whole-cache operations (`len`, `clear`, `entries`, ...) lock the shards
/// one at a time, so they do not observe a single atomic snapshot.
///
/// The derived `Clone` is shallow: it clones the `Arc`s, so the clone shares
/// the same shards and sees the same entries.
#[derive(Clone)]
pub struct ShardedSieveCache<K, V>
where
K: Eq + Hash + Clone + Send + Sync,
V: Send + Sync,
{
/// One mutex-protected cache per shard, shared via `Arc` between clones.
shards: Vec<Arc<Mutex<SieveCache<K, V>>>>,
/// Number of shards; set to `shards.len()` by `with_shards` and used as the
/// modulus when routing a key to a shard.
num_shards: usize,
}
impl<K, V> Default for ShardedSieveCache<K, V>
where
    K: Eq + Hash + Clone + Send + Sync,
    V: Send + Sync,
{
    /// Returns a cache with a total capacity of 100 entries, spread across the
    /// default number of shards (see `ShardedSieveCache::new`).
    ///
    /// # Panics
    ///
    /// Panics if construction fails. The capacity is non-zero, so a failure
    /// would have to come from creating an individual shard.
    fn default() -> Self {
        const DEFAULT_CAPACITY: usize = 100;
        let built = Self::new(DEFAULT_CAPACITY);
        built.expect("Failed to create cache with default capacity")
    }
}
impl<K, V> fmt::Debug for ShardedSieveCache<K, V>
where
    K: Eq + Hash + Clone + Send + Sync + fmt::Debug,
    V: Send + Sync + fmt::Debug,
{
    /// Prints a summary (`capacity`, `len`, `num_shards`), not the entries.
    ///
    /// Computing `capacity` and `len` locks every shard in turn. Formatting
    /// the cache while the same thread already holds one of its shard locks
    /// (e.g. inside `with_key_lock`) will therefore not return.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let total_capacity = self.capacity();
        let entry_count = self.len();
        let mut summary = f.debug_struct("ShardedSieveCache");
        summary.field("capacity", &total_capacity);
        summary.field("len", &entry_count);
        summary.field("num_shards", &self.num_shards);
        summary.finish()
    }
}
impl<K, V> IntoIterator for ShardedSieveCache<K, V>
where
    K: Eq + Hash + Clone + Send + Sync,
    V: Clone + Send + Sync,
{
    type Item = (K, V);
    type IntoIter = std::vec::IntoIter<(K, V)>;

    /// Consumes this handle and yields a snapshot of all `(key, value)` pairs.
    ///
    /// The pairs are cloned out of the shards rather than moved, because other
    /// clones of this cache may still share the same shards. Pairs are grouped
    /// by shard, visiting shards in index order.
    fn into_iter(self) -> Self::IntoIter {
        let snapshot: Vec<(K, V)> = self.entries();
        snapshot.into_iter()
    }
}
#[cfg(feature = "sync")]
impl<K, V> From<crate::SyncSieveCache<K, V>> for ShardedSieveCache<K, V>
where
    K: Eq + Hash + Clone + Send + Sync,
    V: Clone + Send + Sync,
{
    /// Builds a sharded cache with the same total capacity as `sync_cache`
    /// (using the default shard count) and copies every entry into it.
    ///
    /// Entries are routed to shards by hash, so one shard can receive more
    /// entries than its share of the capacity. In that case inserting may
    /// evict within that shard, and the result can hold fewer entries than
    /// `sync_cache` did.
    ///
    /// # Panics
    ///
    /// Panics if the sharded cache cannot be created (e.g. when
    /// `sync_cache.capacity()` is 0).
    fn from(sync_cache: crate::SyncSieveCache<K, V>) -> Self {
        let target = Self::new(sync_cache.capacity()).expect("Failed to create sharded cache");
        sync_cache
            .entries()
            .into_iter()
            .for_each(|(key, value)| {
                target.insert(key, value);
            });
        target
    }
}
impl<K, V> ShardedSieveCache<K, V>
where
    K: Eq + Hash + Clone + Send + Sync,
    V: Send + Sync,
{
    /// Creates a cache with a total capacity of `capacity`, split across
    /// `DEFAULT_SHARDS` shards.
    ///
    /// # Errors
    ///
    /// Same as [`with_shards`](Self::with_shards): fails if `capacity` is 0 or
    /// if a shard cannot be created.
    pub fn new(capacity: usize) -> Result<Self, &'static str> {
        Self::with_shards(capacity, DEFAULT_SHARDS)
    }

    /// Creates a cache with a total capacity of `capacity`, split across
    /// `num_shards` independently locked shards.
    ///
    /// The capacity is divided as evenly as possible: the first
    /// `capacity % num_shards` shards receive one extra slot. Every shard gets
    /// at least one slot. When `capacity < num_shards`, the total reported by
    /// [`capacity`](Self::capacity) is therefore `num_shards`, not `capacity`.
    ///
    /// # Errors
    ///
    /// Returns an error if `capacity` or `num_shards` is 0. Also propagates
    /// the error from `SieveCache::new` if a shard cannot be created.
    pub fn with_shards(capacity: usize, num_shards: usize) -> Result<Self, &'static str> {
        if capacity == 0 {
            return Err("capacity must be greater than 0");
        }
        if num_shards == 0 {
            return Err("number of shards must be greater than 0");
        }
        let base = capacity / num_shards;
        let remainder = capacity % num_shards;
        let shards = (0..num_shards)
            .map(|i| {
                // Hand the remainder to the first shards; never create an
                // empty shard.
                let shard_capacity = (base + usize::from(i < remainder)).max(1);
                SieveCache::new(shard_capacity).map(|cache| Arc::new(Mutex::new(cache)))
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self { shards, num_shards })
    }

    /// Locks `shard`, recovering the guard if a previous holder panicked.
    ///
    /// Poisoning is ignored (via `PoisonError::into_inner`), so a panic in one
    /// caller does not make the shard unusable for everyone else.
    #[inline]
    fn lock_shard(shard: &Mutex<SieveCache<K, V>>) -> MutexGuard<'_, SieveCache<K, V>> {
        shard.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Clones every entry of `shard`, holding its lock only while copying.
    fn snapshot(shard: &Mutex<SieveCache<K, V>>) -> Vec<(K, V)>
    where
        V: Clone,
    {
        let guard = Self::lock_shard(shard);
        guard
            .iter()
            .map(|(key, value)| (key.clone(), value.clone()))
            .collect()
    }

    /// Maps `key` to a shard index in `0..num_shards` using `DefaultHasher`.
    ///
    /// `K: Borrow<Q>` requires owned and borrowed forms to hash identically,
    /// so a key and its borrowed form always map to the same shard.
    #[inline]
    fn get_shard_index<Q>(&self, key: &Q) -> usize
    where
        Q: Hash + ?Sized,
    {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        // Reduce in u64 before narrowing, so all 64 hash bits participate even
        // where `usize` is 32 bits. The result is `< num_shards`, so the cast
        // back to `usize` cannot truncate.
        (hasher.finish() % self.num_shards as u64) as usize
    }

    /// Returns the shard responsible for `key`.
    #[inline]
    fn get_shard<Q>(&self, key: &Q) -> &Arc<Mutex<SieveCache<K, V>>>
    where
        Q: Hash + ?Sized,
    {
        let index = self.get_shard_index(key);
        &self.shards[index]
    }

    /// Locks and returns the shard responsible for `key`.
    #[inline]
    fn locked_shard<Q>(&self, key: &Q) -> MutexGuard<'_, SieveCache<K, V>>
    where
        Q: Hash + ?Sized,
    {
        Self::lock_shard(self.get_shard(key))
    }

    /// Returns the sum of all shard capacities.
    ///
    /// This can exceed the capacity passed to the constructor when that value
    /// was smaller than the number of shards.
    pub fn capacity(&self) -> usize {
        self.shards
            .iter()
            .map(|shard| Self::lock_shard(shard).capacity())
            .sum()
    }

    /// Returns the total number of entries across all shards.
    ///
    /// Shards are locked one at a time. Under concurrent modification the
    /// result is not an atomic snapshot.
    pub fn len(&self) -> usize {
        self.shards
            .iter()
            .map(|shard| Self::lock_shard(shard).len())
            .sum()
    }

    /// Returns `true` if every shard is empty (checked one shard at a time).
    pub fn is_empty(&self) -> bool {
        self.shards
            .iter()
            .all(|shard| Self::lock_shard(shard).is_empty())
    }

    /// Returns `true` if `key` is present in its shard.
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        Q: Hash + Eq + ?Sized,
        K: Borrow<Q>,
    {
        let guard = self.locked_shard(key);
        guard.contains_key(key)
    }

    /// Returns a clone of the value stored under `key`, if any.
    ///
    /// The lookup goes through `SieveCache::get`, which takes the shard
    /// mutably (presumably to record the access for eviction purposes).
    pub fn get<Q>(&self, key: &Q) -> Option<V>
    where
        Q: Hash + Eq + ?Sized,
        K: Borrow<Q>,
        V: Clone,
    {
        let mut guard = self.locked_shard(key);
        guard.get(key).cloned()
    }

    /// Applies `f` to the value stored under `key` and writes the result back.
    ///
    /// Returns `true` if the modified value was stored. Returns `false` if
    /// `key` was absent (in which case `f` is not called) or was removed
    /// while `f` ran.
    ///
    /// `f` runs on a clone of the value *without* holding the shard lock, so
    /// it may call back into this cache. The trade-off is that the update is
    /// not atomic: if two threads call `get_mut` on the same key concurrently,
    /// one thread's write can overwrite the other's.
    pub fn get_mut<Q, F>(&self, key: &Q, f: F) -> bool
    where
        Q: Hash + Eq + ?Sized,
        K: Borrow<Q>,
        F: FnOnce(&mut V),
        V: Clone,
    {
        let snapshot = {
            let mut guard = self.locked_shard(key);
            guard.get_mut(key).cloned()
        };
        if let Some(mut value) = snapshot {
            f(&mut value);
            let mut guard = self.locked_shard(key);
            if let Some(slot) = guard.get_mut(key) {
                *slot = value;
                true
            } else {
                false
            }
        } else {
            false
        }
    }

    /// Inserts or updates `key` in its shard.
    ///
    /// Returns `true` when `key` was newly added (presumably `false` when an
    /// existing entry was updated). If the shard is full, eviction happens
    /// within that shard only.
    pub fn insert(&self, key: K, value: V) -> bool {
        let mut guard = self.locked_shard(&key);
        guard.insert(key, value)
    }

    /// Removes `key` and returns its value, if it was present.
    pub fn remove<Q>(&self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Eq + Hash + ?Sized,
    {
        let mut guard = self.locked_shard(key);
        guard.remove(key)
    }

    /// Evicts one entry and returns its value, or `None` if nothing could be
    /// evicted.
    ///
    /// Shards are tried in index order and the first shard that evicts wins,
    /// so lower-indexed shards are drained first.
    pub fn evict(&self) -> Option<V> {
        self.shards
            .iter()
            .find_map(|shard| Self::lock_shard(shard).evict())
    }

    /// Removes all entries, clearing the shards one at a time.
    pub fn clear(&self) {
        for shard in &self.shards {
            Self::lock_shard(shard).clear();
        }
    }

    /// Returns clones of all keys, grouped by shard in shard index order.
    pub fn keys(&self) -> Vec<K> {
        let mut all_keys = Vec::new();
        for shard in &self.shards {
            all_keys.extend(Self::lock_shard(shard).keys().cloned());
        }
        all_keys
    }

    /// Returns clones of all values, grouped by shard in shard index order.
    pub fn values(&self) -> Vec<V>
    where
        V: Clone,
    {
        let mut all_values = Vec::new();
        for shard in &self.shards {
            all_values.extend(Self::lock_shard(shard).values().cloned());
        }
        all_values
    }

    /// Returns clones of all `(key, value)` pairs, grouped by shard in shard
    /// index order.
    pub fn entries(&self) -> Vec<(K, V)>
    where
        V: Clone,
    {
        let mut all_entries = Vec::new();
        for shard in &self.shards {
            all_entries.extend(
                Self::lock_shard(shard)
                    .iter()
                    .map(|(k, v)| (k.clone(), v.clone())),
            );
        }
        all_entries
    }

    /// Calls `f` on every value and stores the modified values.
    ///
    /// Same locking and write-back semantics as
    /// [`for_each_entry`](Self::for_each_entry).
    pub fn for_each_value<F>(&self, mut f: F)
    where
        F: FnMut(&mut V),
        V: Clone,
    {
        self.for_each_entry(|(_, value)| f(value));
    }

    /// Calls `f` on every entry and stores the values it modifies.
    ///
    /// Shards are processed one at a time, in three steps:
    /// 1. The shard's entries are cloned under its lock.
    /// 2. `f` runs with the lock released, so it may call back into this cache.
    /// 3. The results are written back under the lock.
    ///
    /// Only entries still present at write-back are updated. An entry removed
    /// concurrently while `f` ran stays removed instead of being re-inserted,
    /// and re-inserting cannot evict a live entry. Concurrent writes to the
    /// same entry made while `f` ran are overwritten.
    pub fn for_each_entry<F>(&self, mut f: F)
    where
        F: FnMut((&K, &mut V)),
        V: Clone,
    {
        for shard in &self.shards {
            let mut updated = Self::snapshot(shard);
            for (key, value) in &mut updated {
                f((&*key, value));
            }
            let mut guard = Self::lock_shard(shard);
            for (key, value) in updated {
                if let Some(slot) = guard.get_mut(&key) {
                    *slot = value;
                }
            }
        }
    }

    /// Runs `f` with exclusive access to the shard responsible for `key`.
    ///
    /// The shard lock is held for the whole call. `f` must not call back into
    /// this cache for any key routed to the same shard: `std::sync::Mutex` is
    /// not re-entrant, so such a call will deadlock or panic. Entries inserted
    /// through the shard land in that shard regardless of their own hash, so
    /// they may not be reachable through the key-routed methods.
    pub fn with_key_lock<Q, F, T>(&self, key: &Q, f: F) -> T
    where
        Q: Hash + ?Sized,
        F: FnOnce(&mut SieveCache<K, V>) -> T,
    {
        let mut guard = self.locked_shard(key);
        f(&mut guard)
    }

    /// Returns the number of shards.
    pub fn num_shards(&self) -> usize {
        self.num_shards
    }

    /// Returns the shard at `index`, or `None` if `index >= num_shards()`.
    ///
    /// This is the live shard: holding its lock blocks every operation routed
    /// to it.
    pub fn get_shard_by_index(&self, index: usize) -> Option<&Arc<Mutex<SieveCache<K, V>>>> {
        self.shards.get(index)
    }

    /// Removes every entry for which `f(key, value)` returns `false`.
    ///
    /// Per shard, entries are cloned under the lock and `f` runs with the lock
    /// released, so it may call back into this cache. The rejected keys are
    /// then removed under the lock. The decision is made on the snapshot: an
    /// entry whose value changed concurrently is still removed if its old
    /// value was rejected.
    pub fn retain<F>(&self, mut f: F)
    where
        F: FnMut(&K, &V) -> bool,
        V: Clone,
    {
        for shard in &self.shards {
            let rejected: Vec<K> = Self::snapshot(shard)
                .into_iter()
                .filter(|(key, value)| !f(key, value))
                .map(|(key, _)| key)
                .collect();
            if rejected.is_empty() {
                continue;
            }
            let mut guard = Self::lock_shard(shard);
            for key in rejected {
                guard.remove(&key);
            }
        }
    }

    /// Returns the sum of every shard's `SieveCache::recommended_capacity` for
    /// the given factors and thresholds.
    ///
    /// If every shard is empty, returns the current total capacity. Otherwise
    /// the result is at least `num_shards`, so each shard can keep one slot.
    /// Each shard is locked exactly once, and its recommendation, capacity and
    /// emptiness are read together.
    pub fn recommended_capacity(
        &self,
        min_factor: f64,
        max_factor: f64,
        low_threshold: f64,
        high_threshold: f64,
    ) -> usize {
        let mut total_recommended = 0;
        let mut total_capacity = 0;
        let mut all_empty = true;
        for shard in &self.shards {
            let guard = Self::lock_shard(shard);
            total_recommended +=
                guard.recommended_capacity(min_factor, max_factor, low_threshold, high_threshold);
            total_capacity += guard.capacity();
            all_empty &= guard.is_empty();
        }
        if all_empty {
            total_capacity
        } else {
            total_recommended.max(self.num_shards)
        }
    }
}
// Unit tests for ShardedSieveCache. Several tests spawn threads; the sleeps
// only shape the interleaving, and the final assertions do not depend on timing.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
// Round-trips one entry through insert / get / contains_key / remove.
#[test]
fn test_sharded_cache_basics() {
let cache = ShardedSieveCache::new(100).unwrap();
// insert returns true for a key that was not present before.
assert!(cache.insert("key1".to_string(), "value1".to_string()));
assert_eq!(cache.get(&"key1".to_string()), Some("value1".to_string()));
assert!(cache.contains_key(&"key1".to_string()));
// capacity() sums the shard capacities; `>=` tolerates the per-shard minimum of 1.
assert!(cache.capacity() >= 100); assert_eq!(cache.len(), 1);
assert_eq!(
cache.remove(&"key1".to_string()),
Some("value1".to_string())
);
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
}
// with_shards honours an explicit shard count, and len() sums across shards.
#[test]
fn test_custom_shard_count() {
let cache = ShardedSieveCache::with_shards(100, 4).unwrap();
assert_eq!(cache.num_shards(), 4);
for i in 0..10 {
let key = format!("key{}", i);
let value = format!("value{}", i);
cache.insert(key, value);
}
assert_eq!(cache.len(), 10);
}
// 8 threads insert 100 distinct keys each; 800 entries fit in capacity 1000.
#[test]
fn test_parallel_access() {
let cache = Arc::new(ShardedSieveCache::with_shards(1000, 16).unwrap());
let mut handles = vec![];
for t in 0..8 {
let cache_clone = Arc::clone(&cache);
let handle = thread::spawn(move || {
for i in 0..100 {
let key = format!("thread{}key{}", t, i);
let value = format!("value{}_{}", t, i);
cache_clone.insert(key, value);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
assert_eq!(cache.len(), 800);
assert_eq!(
cache.get(&"thread0key50".to_string()),
Some("value0_50".to_string())
);
assert_eq!(
cache.get(&"thread7key99".to_string()),
Some("value7_99".to_string())
);
}
// with_key_lock hands out the shard chosen for "test_key". Entries inserted
// through it land in that shard regardless of their own hash, so only the
// total len() is asserted (get() routes by key and may look in another shard).
#[test]
fn test_with_key_lock() {
let cache = ShardedSieveCache::new(100).unwrap();
cache.with_key_lock(&"test_key", |shard| {
shard.insert("key1".to_string(), "value1".to_string());
shard.insert("key2".to_string(), "value2".to_string());
shard.insert("key3".to_string(), "value3".to_string());
});
assert_eq!(cache.len(), 3);
}
// 10 slots over 2 shards; inserting 15 keys must trigger evictions.
#[test]
fn test_eviction() {
let cache = ShardedSieveCache::with_shards(10, 2).unwrap();
for i in 0..15 {
let key = format!("key{}", i);
let value = format!("value{}", i);
cache.insert(key, value);
}
assert!(cache.len() <= 10);
let evicted = cache.evict();
assert!(evicted.is_some());
}
// 16 threads each repeatedly overwrite and read back their own key.
#[test]
fn test_contention() {
let cache = Arc::new(ShardedSieveCache::with_shards(1000, 16).unwrap());
let mut handles = vec![];
let keys: Vec<String> = (0..16).map(|i| format!("shard_key_{}", i)).collect();
for i in 0..16 {
let cache_clone = Arc::clone(&cache);
let key = keys[i].clone();
let handle = thread::spawn(move || {
for j in 0..1000 {
cache_clone.insert(key.clone(), format!("value_{}", j));
let _ = cache_clone.get(&key);
if j % 100 == 0 {
thread::sleep(Duration::from_micros(1));
}
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
for key in keys {
assert!(cache.contains_key(&key));
}
}
// get_mut updates an existing value and never calls the closure for a missing key.
#[test]
fn test_get_mut() {
let cache = ShardedSieveCache::new(100).unwrap();
cache.insert("key".to_string(), "value".to_string());
let modified = cache.get_mut(&"key".to_string(), |value| {
*value = "new_value".to_string();
});
assert!(modified);
assert_eq!(cache.get(&"key".to_string()), Some("new_value".to_string()));
let modified = cache.get_mut(&"missing".to_string(), |_| {
panic!("This should not be called");
});
assert!(!modified);
}
// get_mut runs the closure on a clone outside the shard lock, so concurrent
// increments of the same key can be lost; only a positive result is asserted.
#[test]
fn test_get_mut_concurrent() {
let cache = Arc::new(ShardedSieveCache::with_shards(100, 8).unwrap());
for i in 0..10 {
cache.insert(format!("key{}", i), 0);
}
let mut handles = vec![];
for _ in 0..5 {
let cache_clone = Arc::clone(&cache);
let handle = thread::spawn(move || {
for i in 0..10 {
for _ in 0..100 {
cache_clone.get_mut(&format!("key{}", i), |value| {
*value += 1;
});
}
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
for i in 0..10 {
let value = cache.get(&format!("key{}", i));
assert!(value.is_some());
let num = value.unwrap();
assert!(
num > 0,
"Value for key{} should be positive but was {}",
i,
num
);
}
}
// clear empties every shard.
#[test]
fn test_clear() {
let cache = ShardedSieveCache::with_shards(100, 4).unwrap();
for i in 0..20 {
cache.insert(format!("key{}", i), format!("value{}", i));
}
assert_eq!(cache.len(), 20);
assert!(!cache.is_empty());
cache.clear();
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
for i in 0..20 {
assert_eq!(cache.get(&format!("key{}", i)), None);
}
}
// keys/values/entries collect across all shards; order depends on shard
// routing, hence the contains() checks.
#[test]
fn test_keys_values_entries() {
let cache = ShardedSieveCache::with_shards(100, 4).unwrap();
for i in 0..10 {
cache.insert(format!("key{}", i), format!("value{}", i));
}
let keys = cache.keys();
assert_eq!(keys.len(), 10);
for i in 0..10 {
assert!(keys.contains(&format!("key{}", i)));
}
let values = cache.values();
assert_eq!(values.len(), 10);
for i in 0..10 {
assert!(values.contains(&format!("value{}", i)));
}
let entries = cache.entries();
assert_eq!(entries.len(), 10);
for i in 0..10 {
assert!(entries.contains(&(format!("key{}", i), format!("value{}", i))));
}
}
// for_each_value and for_each_entry write their modifications back to the cache.
#[test]
fn test_for_each_operations() {
let cache = ShardedSieveCache::with_shards(100, 4).unwrap();
for i in 0..10 {
cache.insert(format!("key{}", i), format!("value{}", i));
}
cache.for_each_value(|value| {
*value = format!("{}_updated", value);
});
for i in 0..10 {
assert_eq!(
cache.get(&format!("key{}", i)),
Some(format!("value{}_updated", i))
);
}
cache.for_each_entry(|(key, value)| {
if key.ends_with("5") {
*value = format!("{}_special", value);
}
});
assert_eq!(
cache.get(&"key5".to_string()),
Some("value5_updated_special".to_string())
);
assert_eq!(
cache.get(&"key1".to_string()),
Some("value1_updated".to_string())
);
}
// A reader thread runs while the main thread clears and repopulates. The
// reader never writes, so the final contents are deterministic.
#[test]
fn test_multithreaded_operations() {
let cache = Arc::new(ShardedSieveCache::with_shards(100, 8).unwrap());
for i in 0..20 {
cache.insert(format!("key{}", i), format!("value{}", i));
}
let cache_clone = Arc::clone(&cache);
let handle = thread::spawn(move || {
thread::sleep(Duration::from_millis(10));
for i in 0..20 {
let _ = cache_clone.get(&format!("key{}", i));
thread::sleep(Duration::from_micros(100));
}
});
thread::sleep(Duration::from_millis(5));
cache.clear();
for i in 30..40 {
cache.insert(format!("newkey{}", i), format!("newvalue{}", i));
}
handle.join().unwrap();
assert_eq!(cache.len(), 10);
for i in 30..40 {
assert_eq!(
cache.get(&format!("newkey{}", i)),
Some(format!("newvalue{}", i))
);
}
}
// retain removes entries for which the predicate returns false.
#[test]
fn test_retain() {
let cache = ShardedSieveCache::with_shards(100, 4).unwrap();
cache.insert("even1".to_string(), 2);
cache.insert("even2".to_string(), 4);
cache.insert("odd1".to_string(), 1);
cache.insert("odd2".to_string(), 3);
assert_eq!(cache.len(), 4);
cache.retain(|_, v| v % 2 == 0);
assert_eq!(cache.len(), 2);
assert!(cache.contains_key(&"even1".to_string()));
assert!(cache.contains_key(&"even2".to_string()));
assert!(!cache.contains_key(&"odd1".to_string()));
assert!(!cache.contains_key(&"odd2".to_string()));
cache.retain(|k, _| k.contains('1'));
assert_eq!(cache.len(), 1);
assert!(cache.contains_key(&"even1".to_string()));
assert!(!cache.contains_key(&"even2".to_string()));
}
// Scenarios for recommended_capacity; the bounds reflect the expectations
// encoded in the assertions below.
#[test]
fn test_recommended_capacity() {
// Empty cache: expected to report the current capacity.
let cache = ShardedSieveCache::<String, u32>::with_shards(100, 4).unwrap();
assert_eq!(cache.recommended_capacity(0.5, 2.0, 0.3, 0.7), 100);
// 90 entries, only 5 re-read: expected to recommend shrinking, but not below
// min_factor * capacity or the number of shards.
let cache = ShardedSieveCache::with_shards(100, 4).unwrap();
for i in 0..90 {
cache.insert(i.to_string(), i);
}
for i in 0..5 {
cache.get(&i.to_string()); }
let recommended = cache.recommended_capacity(0.5, 2.0, 0.1, 0.7); assert!(recommended < 100);
assert!(recommended >= 50); assert!(recommended >= 4);
// 90 entries, 81 re-read: expected to recommend growing, at most max_factor * capacity.
let cache = ShardedSieveCache::with_shards(100, 4).unwrap();
for i in 0..90 {
cache.insert(i.to_string(), i);
if i % 10 != 0 {
cache.get(&i.to_string());
}
}
let recommended = cache.recommended_capacity(0.5, 2.0, 0.3, 0.7);
assert!(recommended > 100);
assert!(recommended <= 200);
// 90 entries, half re-read: expected to stay at roughly the current capacity.
let cache = ShardedSieveCache::with_shards(100, 4).unwrap();
for i in 0..90 {
cache.insert(i.to_string(), i);
if i % 2 == 0 {
cache.get(&i.to_string());
}
}
let recommended = cache.recommended_capacity(0.5, 2.0, 0.3, 0.7);
assert!(
recommended >= 95,
"With normal utilization, capacity should be close to original"
);
assert!(
recommended <= 100,
"With normal utilization, capacity should not exceed original"
);
// 100 entries in 2000 slots, all re-read: the low fill ratio is expected to
// outweigh the high hit rate.
let cache = ShardedSieveCache::with_shards(2000, 4).unwrap();
for i in 0..100 {
cache.insert(i.to_string(), i);
cache.get(&i.to_string());
}
let recommended = cache.recommended_capacity(0.5, 2.0, 0.3, 0.7);
assert!(
recommended < 2000,
"With low fill ratio, capacity should be decreased despite high hit rate"
);
assert!(
recommended >= 1000, "Capacity should not go below min_factor of current capacity"
);
assert!(
recommended >= 4, "Capacity should not go below number of shards"
);
}
// Despite its name this test is single-threaded; it checks retain across 8 shards.
#[test]
fn test_retain_concurrent() {
let cache = ShardedSieveCache::with_shards(100, 8).unwrap();
for i in 0..10 {
cache.insert(format!("even{}", i * 2), i * 2);
cache.insert(format!("odd{}", i * 2 + 1), i * 2 + 1);
}
cache.retain(|_, value| value % 2 == 1);
assert_eq!(cache.len(), 10, "Should have 10 odd-valued entries");
for (_, value) in cache.entries() {
assert_eq!(
value % 2,
1,
"Found an even value {value} which should have been removed"
);
}
for i in 0..10 {
let odd_key = format!("odd{}", i * 2 + 1);
assert!(cache.contains_key(&odd_key), "Missing odd entry: {odd_key}");
assert_eq!(cache.get(&odd_key), Some(i * 2 + 1));
}
}
// get_mut callbacks call back into the cache. This cannot deadlock, even when
// the keys share a shard, because get_mut releases the shard lock before
// running the callback.
#[test]
fn test_deadlock_prevention() {
let cache = Arc::new(ShardedSieveCache::with_shards(100, 4).unwrap());
cache.insert("keyA_1".to_string(), 1);
cache.insert("keyB_2".to_string(), 2);
let cache_clone1 = Arc::clone(&cache);
let cache_clone2 = Arc::clone(&cache);
let thread1 = thread::spawn(move || {
cache_clone1.get_mut(&"keyA_1".to_string(), |value| {
let other_value = cache_clone1.get(&"keyB_2".to_string());
assert_eq!(other_value, Some(2));
cache_clone1.insert("keyC_3".to_string(), 3);
*value += 10;
});
});
let thread2 = thread::spawn(move || {
thread::sleep(Duration::from_millis(5));
cache_clone2.get_mut(&"keyB_2".to_string(), |value| {
*value += 20;
let _ = cache_clone2.get(&"keyA_1".to_string());
});
});
thread1.join().unwrap();
thread2.join().unwrap();
// Each key is modified by exactly one thread, so the final values are exact.
assert_eq!(cache.get(&"keyA_1".to_string()), Some(11)); assert_eq!(cache.get(&"keyB_2".to_string()), Some(22)); assert_eq!(cache.get(&"keyC_3".to_string()), Some(3));
}
}