use std::hash::Hash;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(feature = "ahash")]
use ahash::RandomState;
#[cfg(not(feature = "ahash"))]
use std::collections::hash_map::RandomState;
use std::collections::HashMap;
#[cfg(feature = "async_core")]
use crate::ConcurrentCachedAsync;
use crate::{CacheMetrics, ConcurrentCached};
use super::{
CachePadded, DefaultShardHasher, Shard, ShardHasher, checked_shard_count, shard_index,
};
use crate::stores::BuildError;
/// Shared callback invoked with the key and value of an entry removed from the cache.
///
/// Stored behind an `Arc` so cache clones and `deep_clone` copies can share it.
type OnEvict<K, V> = Arc<dyn Fn(&K, &V) + Send + Sync>;
/// State shared by every handle of a [`ShardedCacheBase`].
///
/// `Clone` on the cache shares one `Arc<UnboundInner>`; `deep_clone` builds a new one.
// The fully spelled-out shard type below trips `clippy::type_complexity`.
#[allow(clippy::type_complexity)]
struct UnboundInner<K, V, H> {
    // One cache-padded, independently locked map per shard.
    shards: Box<[CachePadded<Shard<HashMap<K, V, RandomState>>>]>,
    // Shard count minus one (set in `build`); passed to `shard_index` with a key's shard hash.
    shard_mask: usize,
    // Produces the hash used to pick a key's shard.
    hasher: H,
    // Optional callback for removed entries.
    on_evict: Option<OnEvict<K, V>>,
}
/// A [`ShardedCacheBase`] that selects shards with the [`DefaultShardHasher`].
pub type ShardedCache<K, V> = ShardedCacheBase<K, V, DefaultShardHasher>;
/// A concurrent cache with no capacity limit, split into independently locked shards.
///
/// Each key lives in exactly one shard, chosen by the shard hasher `H`. Cloning a
/// `ShardedCacheBase` is cheap and yields a handle to the same shared state; use
/// `deep_clone` for an independent copy.
pub struct ShardedCacheBase<K, V, H = DefaultShardHasher> {
    // Shared state; every clone of this handle points at the same allocation.
    inner: Arc<UnboundInner<K, V, H>>,
}
// Implemented by hand rather than derived so that cloning a handle does not
// require `K`, `V` or `H` to be `Clone`.
impl<K, V, H> Clone for ShardedCacheBase<K, V, H> {
    /// Returns another handle to the same shared cache.
    ///
    /// Only the `Arc` reference count changes; entries, counters and the
    /// eviction callback are shared, not copied.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
impl<K, V, H> std::fmt::Debug for ShardedCacheBase<K, V, H> {
    /// Prints only the shard count (`ShardedCache { shards: N, .. }`); entries,
    /// counters and the hasher are omitted, so no bounds on `K`, `V` or `H` are needed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shard_count = self.inner.shards.len();
        let mut out = f.debug_struct("ShardedCache");
        out.field("shards", &shard_count);
        out.finish_non_exhaustive()
    }
}
impl<K, V, H> ShardedCacheBase<K, V, H>
where
    K: Hash + Eq,
    H: ShardHasher<K>,
{
    /// Returns a builder preconfigured with the [`DefaultShardHasher`].
    ///
    /// Call `.hasher(...)` on the builder to select shards differently.
    pub fn builder() -> ShardedCacheBuilder<K, V, DefaultShardHasher> {
        <ShardedCacheBuilder<K, V, DefaultShardHasher> as Default>::default()
    }

    /// Returns the shard that owns `k`.
    ///
    /// The configured hasher's `shard_hash` is reduced to an index by
    /// `shard_index` using the stored shard mask.
    #[inline]
    fn shard_of(&self, k: &K) -> &CachePadded<Shard<HashMap<K, V, RandomState>>> {
        let inner = &*self.inner;
        let idx = shard_index(inner.hasher.shard_hash(k), inner.shard_mask);
        &inner.shards[idx]
    }
}
impl<K, V> Default for ShardedCache<K, V>
where
    K: Hash + Eq,
{
    /// Builds a cache from the default builder configuration.
    ///
    /// # Panics
    ///
    /// Panics if the default configuration is rejected by `build`.
    fn default() -> Self {
        match ShardedCacheBuilder::default().build() {
            Ok(cache) => cache,
            Err(e) => panic!("ShardedCache build failed: {e}"),
        }
    }
}
impl<K: Clone + Hash + Eq, V: Clone, H: ShardHasher<K> + Clone> ShardedCacheBase<K, V, H> {
    /// Returns an independent cache holding copies of every entry and of each
    /// shard's hit/miss counters; later writes to either cache are not seen by the other.
    ///
    /// The shard layout and a clone of the hasher are carried over, and the
    /// eviction callback `Arc` is shared. Shards are copied one at a time under
    /// their own read lock, so the result is not an atomic snapshot of the whole cache.
    #[must_use]
    pub fn deep_clone(&self) -> Self {
        let src = &*self.inner;
        let shards: Box<[_]> = src
            .shards
            .iter()
            .map(|padded| {
                // The read guard is a temporary, released once the map is cloned
                // and before the counters are read.
                let entries = padded.lock.read().clone();
                CachePadded(Shard {
                    lock: parking_lot::RwLock::new(entries),
                    hits: AtomicU64::new(padded.hits.load(Ordering::Relaxed)),
                    misses: AtomicU64::new(padded.misses.load(Ordering::Relaxed)),
                })
            })
            .collect();
        Self {
            inner: Arc::new(UnboundInner {
                shards,
                shard_mask: src.shard_mask,
                hasher: src.hasher.clone(),
                on_evict: src.on_evict.clone(),
            }),
        }
    }
}
impl<K, V, H: ShardHasher<K>> ShardedCacheBase<K, V, H>
where
    K: Hash + Eq,
{
    /// Returns aggregated hit/miss counters and the current number of entries.
    ///
    /// Shards are visited one at a time, each under its own read lock, so under
    /// concurrent writes the totals are not a single atomic snapshot.
    /// `evictions` and `capacity` are always `None`: the store has no capacity
    /// limit and keeps no eviction counter.
    #[must_use]
    pub fn metrics(&self) -> CacheMetrics {
        let mut hits = 0u64;
        let mut misses = 0u64;
        let mut size = 0usize;
        for shard in self.inner.shards.iter() {
            hits += shard.hits.load(Ordering::Relaxed);
            misses += shard.misses.load(Ordering::Relaxed);
            size += shard.lock.read().len();
        }
        CacheMetrics {
            hits: Some(hits),
            misses: Some(misses),
            evictions: None,
            size,
            capacity: None,
        }
    }

    /// Returns the number of shards.
    #[must_use]
    pub fn shards(&self) -> usize {
        self.inner.shards.len()
    }

    /// Returns the entry count of each shard, in shard order.
    ///
    /// Useful for inspecting how evenly the shard hasher spreads keys.
    #[must_use]
    pub fn shard_sizes(&self) -> Vec<usize> {
        self.inner
            .shards
            .iter()
            .map(|s| s.lock.read().len())
            .collect()
    }

    /// Returns the total number of entries across all shards.
    ///
    /// Shards are read one at a time, so the sum may be stale under concurrent writes.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.shards.iter().map(|s| s.lock.read().len()).sum()
    }

    /// Returns `true` if every shard is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.shards.iter().all(|s| s.lock.read().is_empty())
    }

    /// Removes every entry without invoking the `on_evict` callback.
    pub fn clear(&self) {
        for shard in self.inner.shards.iter() {
            shard.lock.write().clear();
        }
    }

    /// Removes every entry, invoking the `on_evict` callback (if configured)
    /// once for each removed key/value pair.
    ///
    /// Shards are drained one at a time, and each shard's write lock is released
    /// before its callbacks run, so a callback may access this cache without
    /// deadlocking on that shard. Without a callback this behaves like
    /// [`clear`](Self::clear).
    pub fn cache_clear_with_on_evict(&self) {
        let Some(on_evict) = self.inner.on_evict.as_ref() else {
            self.clear();
            return;
        };
        for shard in self.inner.shards.iter() {
            // The write guard is a temporary dropped at the end of this statement,
            // so the callbacks below run with the shard unlocked.
            let drained: Vec<(K, V)> = shard.lock.write().drain().collect();
            for (k, v) in &drained {
                on_evict(k, v);
            }
        }
    }
}
impl<K, V, H> ConcurrentCached<K, V> for ShardedCacheBase<K, V, H>
where
    K: Hash + Eq,
    V: Clone,
    H: ShardHasher<K>,
{
    /// Operations on this store cannot fail.
    type Error = std::convert::Infallible;

    /// Looks `k` up in its shard under a read lock, bumps that shard's hit or
    /// miss counter, and returns a clone of the stored value.
    fn cache_get(&self, k: &K) -> Result<Option<V>, Self::Error> {
        let shard = self.shard_of(k);
        let map = shard.lock.read();
        let found = map.get(k);
        let counter = if found.is_some() {
            &shard.hits
        } else {
            &shard.misses
        };
        counter.fetch_add(1, Ordering::Relaxed);
        Ok(found.cloned())
    }

    /// Inserts `v` under `k` and returns the value it replaced, if any.
    ///
    /// Replacing an existing value does not invoke the `on_evict` callback.
    fn cache_set(&self, k: K, v: V) -> Result<Option<V>, Self::Error> {
        let shard = self.shard_of(&k);
        let mut map = shard.lock.write();
        let previous = map.insert(k, v);
        Ok(previous)
    }

    /// Removes `k` and returns its value; see `cache_remove_entry`.
    fn cache_remove(&self, k: &K) -> Result<Option<V>, Self::Error> {
        let entry = <Self as ConcurrentCached<K, V>>::cache_remove_entry(self, k)?;
        Ok(entry.map(|(_key, value)| value))
    }

    /// Removes `k` and returns the stored key/value pair.
    ///
    /// If an entry was removed and an `on_evict` callback is configured, the
    /// callback is invoked with that pair after the shard lock is released.
    fn cache_remove_entry(&self, k: &K) -> Result<Option<(K, V)>, Self::Error> {
        // The write guard is a temporary released at the end of this statement.
        let removed = self.shard_of(k).lock.write().remove_entry(k);
        if let (Some((key, value)), Some(notify)) = (&removed, &self.inner.on_evict) {
            notify(key, value);
        }
        Ok(removed)
    }

    /// Returns the total entry count (always `Some`).
    fn cache_size(&self) -> Result<Option<usize>, Self::Error> {
        let total = self.len();
        Ok(Some(total))
    }

    /// Refresh-on-hit is not supported by this store; always returns `false`.
    fn set_refresh_on_hit(&self, _refresh: bool) -> bool {
        false
    }
}
#[cfg(feature = "async_core")]
impl<K, V, H> ConcurrentCachedAsync<K, V> for ShardedCacheBase<K, V, H>
where
    K: Hash + Eq + Send + Sync,
    V: Clone + Send + Sync,
    H: ShardHasher<K>,
{
    /// Operations on this store cannot fail.
    type Error = std::convert::Infallible;

    // Every async method runs the synchronous implementation inline; nothing is awaited.

    async fn cache_get(&self, k: &K) -> Result<Option<V>, Self::Error> {
        <Self as ConcurrentCached<K, V>>::cache_get(self, k)
    }

    async fn cache_set(&self, k: K, v: V) -> Result<Option<V>, Self::Error> {
        <Self as ConcurrentCached<K, V>>::cache_set(self, k, v)
    }

    async fn cache_remove(&self, k: &K) -> Result<Option<V>, Self::Error> {
        <Self as ConcurrentCached<K, V>>::cache_remove(self, k)
    }

    async fn cache_remove_entry(&self, k: &K) -> Result<Option<(K, V)>, Self::Error> {
        <Self as ConcurrentCached<K, V>>::cache_remove_entry(self, k)
    }

    /// Returns the total entry count (always `Some`).
    fn cache_size(&self) -> Result<Option<usize>, Self::Error> {
        let total = self.len();
        Ok(Some(total))
    }

    /// Same as the synchronous implementation: refresh-on-hit is unsupported.
    fn set_refresh_on_hit(&self, b: bool) -> bool {
        <Self as ConcurrentCached<K, V>>::set_refresh_on_hit(self, b)
    }
}
/// Builder for [`ShardedCacheBase`].
///
/// Obtain one with `ShardedCacheBase::builder()` or `ShardedCacheBuilder::default()`,
/// configure it, then call `build`.
pub struct ShardedCacheBuilder<K, V, H = DefaultShardHasher> {
    // Requested shard count; `None` leaves the choice to `checked_shard_count` in `build`.
    shards: Option<usize>,
    // Always `Some` once constructed (via `Default` or `.hasher()`); `build` relies on this.
    hasher: Option<H>,
    // Optional eviction callback handed to the built cache.
    on_evict: Option<OnEvict<K, V>>,
    // Type markers for `K` and `V` (both are also named by `on_evict`).
    _k: std::marker::PhantomData<K>,
    _v: std::marker::PhantomData<V>,
}
impl<K, V> Default for ShardedCacheBuilder<K, V, DefaultShardHasher> {
    /// A builder using the default shard hasher, with no explicit shard count
    /// and no eviction callback.
    fn default() -> Self {
        let hasher = DefaultShardHasher::default();
        Self {
            hasher: Some(hasher),
            shards: None,
            on_evict: None,
            _k: std::marker::PhantomData,
            _v: std::marker::PhantomData,
        }
    }
}
impl<K, V, H> ShardedCacheBuilder<K, V, H> {
    /// Requests `shards` shards.
    ///
    /// The value is validated by `checked_shard_count` when [`build`](Self::build)
    /// runs; a rejected value (for example `0`) is reported there as a [`BuildError`].
    #[must_use]
    pub fn shards(mut self, shards: usize) -> Self {
        self.shards = Some(shards);
        self
    }

    /// Replaces the shard hasher, which decides which shard each key belongs to.
    ///
    /// The shard count and eviction callback configured so far are kept.
    #[doc(alias = "with_hasher")]
    #[must_use]
    pub fn hasher<H2: ShardHasher<K>>(self, hasher: H2) -> ShardedCacheBuilder<K, V, H2> {
        ShardedCacheBuilder {
            shards: self.shards,
            hasher: Some(hasher),
            on_evict: self.on_evict,
            _k: std::marker::PhantomData,
            _v: std::marker::PhantomData,
        }
    }

    /// Registers a callback invoked with the key and value of each entry removed
    /// by `cache_remove` / `cache_remove_entry` or by `cache_clear_with_on_evict`.
    ///
    /// It is not invoked by `clear`, or when `cache_set` overwrites an existing key.
    #[must_use]
    pub fn on_evict(mut self, on_evict: impl Fn(&K, &V) + Send + Sync + 'static) -> Self {
        self.on_evict = Some(Arc::new(on_evict));
        self
    }

    /// Builds the cache with one empty map per shard.
    ///
    /// # Errors
    ///
    /// Returns the [`BuildError`] produced by `checked_shard_count` when the
    /// configured shard count is rejected (for example zero shards).
    pub fn build(self) -> Result<ShardedCacheBase<K, V, H>, BuildError>
    where
        K: Hash + Eq,
        H: ShardHasher<K>,
    {
        let n = checked_shard_count(self.shards)?;
        // `checked_shard_count` succeeded, so `n >= 1` and this cannot underflow.
        let mask = n - 1;
        let shards = (0..n)
            .map(|_| CachePadded(Shard::new(HashMap::with_hasher(RandomState::new()))))
            .collect::<Vec<_>>()
            .into_boxed_slice();
        Ok(ShardedCacheBase {
            inner: Arc::new(UnboundInner {
                shards,
                shard_mask: mask,
                hasher: self
                    .hasher
                    .expect("hasher is always initialized via Default or .hasher()"),
                on_evict: self.on_evict,
            }),
        })
    }

    /// Builds a cache from this configuration and fills it with a copy of every
    /// entry in `existing`.
    ///
    /// Entries are redistributed using the new cache's shard count and hasher.
    /// Hit/miss counters are not copied, and no eviction callback is invoked.
    /// Each source shard is snapshotted under its own read lock, which is released
    /// before the copied entries are inserted, so the copy is not atomic across shards.
    ///
    /// # Errors
    ///
    /// Returns the same [`BuildError`] as [`build`](Self::build); `existing` is not
    /// read in that case.
    pub fn try_copy_from<H2: ShardHasher<K>>(
        self,
        existing: &ShardedCacheBase<K, V, H2>,
    ) -> Result<ShardedCacheBase<K, V, H>, BuildError>
    where
        K: Clone + Hash + Eq,
        V: Clone,
        H: ShardHasher<K>,
    {
        let new_cache = self.build()?;
        for shard in existing.inner.shards.iter() {
            let entries: Vec<(K, V)> = {
                let guard = shard.lock.read();
                guard.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
            };
            for (k, v) in entries {
                // The error type is `Infallible`; there is nothing to handle.
                let _ = ConcurrentCached::cache_set(&new_cache, k, v);
            }
        }
        Ok(new_cache)
    }

    /// Like [`try_copy_from`](Self::try_copy_from), but panics on an invalid
    /// configuration instead of returning an error.
    ///
    /// # Panics
    ///
    /// Panics if [`build`](Self::build) fails, e.g. for a rejected shard count.
    #[must_use]
    pub fn copy_from<H2: ShardHasher<K>>(
        self,
        existing: &ShardedCacheBase<K, V, H2>,
    ) -> ShardedCacheBase<K, V, H>
    where
        K: Clone + Hash + Eq,
        V: Clone,
        H: ShardHasher<K>,
    {
        self.try_copy_from(existing)
            .unwrap_or_else(|e| panic!("ShardedCache build failed: {e}"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ConcurrentCached as SyncConcurrentCached;

    #[test]
    fn basic_get_set_remove() {
        let c = ShardedCache::<u32, u32>::builder().build().unwrap();
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1).expect("cache_get must succeed"),
            None
        );
        assert_eq!(
            SyncConcurrentCached::cache_set(&c, 1, 100).expect("insert must succeed"),
            None
        );
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1).expect("key was just inserted"),
            Some(100)
        );
        assert_eq!(
            SyncConcurrentCached::cache_set(&c, 1, 200).expect("insert must succeed"),
            Some(100)
        );
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1).expect("key was just inserted"),
            Some(200)
        );
        assert_eq!(
            SyncConcurrentCached::cache_remove(&c, &1).expect("key must be present"),
            Some(200)
        );
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1).expect("cache_get must succeed"),
            None
        );
    }

    #[test]
    fn clone_shares_state() {
        let c1 = ShardedCache::<u32, u32>::builder().build().unwrap();
        let c2 = c1.clone();
        SyncConcurrentCached::cache_set(&c1, 1, 10).expect("insert must succeed");
        assert_eq!(
            SyncConcurrentCached::cache_get(&c2, &1).expect("key was just inserted"),
            Some(10)
        );
    }

    #[test]
    fn metrics_sum() {
        let c = ShardedCache::<u32, u32>::builder().build().unwrap();
        SyncConcurrentCached::cache_set(&c, 1, 1).expect("insert must succeed");
        SyncConcurrentCached::cache_get(&c, &1).expect("key was just inserted");
        SyncConcurrentCached::cache_get(&c, &2).expect("cache_get must succeed");
        let m = c.metrics();
        assert_eq!(m.hits, Some(1));
        assert_eq!(m.misses, Some(1));
    }

    #[test]
    fn metrics_reports_size_and_no_capacity() {
        let c = ShardedCache::<u32, u32>::builder().build().unwrap();
        for i in 0..3u32 {
            SyncConcurrentCached::cache_set(&c, i, i).expect("insert must succeed");
        }
        let m = c.metrics();
        assert_eq!(m.size, 3);
        assert!(m.evictions.is_none(), "store keeps no eviction counter");
        assert!(m.capacity.is_none(), "store is unbounded");
    }

    #[test]
    fn len_and_clear() {
        let c = ShardedCache::<u32, u32>::builder().build().unwrap();
        for i in 0..10u32 {
            SyncConcurrentCached::cache_set(&c, i, i).expect("insert must succeed");
        }
        assert_eq!(c.len(), 10);
        assert_eq!(
            SyncConcurrentCached::cache_size(&c).expect("cache_size must succeed"),
            Some(10)
        );
        assert!(!c.is_empty());
        c.clear();
        assert_eq!(c.len(), 0);
        assert!(c.is_empty());
    }

    #[test]
    fn default_builds_empty_cache() {
        let c = ShardedCache::<u32, u32>::default();
        assert!(c.is_empty());
        assert!(c.shards() >= 1);
        assert_eq!(
            SyncConcurrentCached::cache_size(&c).expect("cache_size must succeed"),
            Some(0)
        );
    }

    #[test]
    fn shard_sizes() {
        let c = ShardedCache::<u32, u32>::builder()
            .shards(8)
            .build()
            .unwrap();
        for i in 0..100u32 {
            SyncConcurrentCached::cache_set(&c, i, i).expect("insert must succeed");
        }
        let sizes = c.shard_sizes();
        assert_eq!(sizes.len(), 8);
        assert_eq!(sizes.iter().sum::<usize>(), 100);
    }

    #[test]
    fn debug_reports_shard_count() {
        let c = ShardedCache::<u32, u32>::builder()
            .shards(8)
            .build()
            .unwrap();
        assert_eq!(format!("{c:?}"), "ShardedCache { shards: 8, .. }");
    }

    #[test]
    fn on_evict_fires_on_remove() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let count = Arc::new(AtomicUsize::new(0));
        let count2 = count.clone();
        let c = ShardedCacheBase::<u32, u32>::builder()
            .on_evict(move |_, _| {
                count2.fetch_add(1, Ordering::Relaxed);
            })
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1, 1).expect("insert must succeed");
        SyncConcurrentCached::cache_remove(&c, &1).expect("key must be present");
        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn on_evict_receives_removed_key_and_value() {
        use std::sync::Mutex;
        let seen = Arc::new(Mutex::new(Vec::new()));
        let seen2 = Arc::clone(&seen);
        let c = ShardedCacheBase::<u32, u32>::builder()
            .on_evict(move |k, v| seen2.lock().unwrap().push((*k, *v)))
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 7, 70).expect("insert must succeed");
        SyncConcurrentCached::cache_remove(&c, &7).expect("key must be present");
        assert_eq!(*seen.lock().unwrap(), vec![(7u32, 70u32)]);
    }

    #[test]
    fn custom_hasher() {
        #[derive(Default)]
        struct ConstHasher;
        impl ShardHasher<u32> for ConstHasher {
            fn shard_hash(&self, _key: &u32) -> u64 {
                0
            }
        }
        let c = ShardedCacheBase::<u32, u32>::builder()
            .shards(8)
            .hasher(ConstHasher)
            .build()
            .unwrap();
        for i in 0..10u32 {
            SyncConcurrentCached::cache_set(&c, i, i).expect("insert must succeed");
        }
        let sizes = c.shard_sizes();
        assert_eq!(sizes[0], 10);
        assert_eq!(sizes[1..].iter().sum::<usize>(), 0);
    }

    #[test]
    fn copy_from_preserves_entries() {
        let old = ShardedCache::<u32, u32>::builder().build().unwrap();
        for i in 0..50u32 {
            SyncConcurrentCached::cache_set(&old, i, i * 10).expect("insert must succeed");
        }
        let new_cache = ShardedCacheBase::<u32, u32>::builder()
            .shards(4)
            .copy_from(&old);
        for i in 0..50u32 {
            assert_eq!(
                SyncConcurrentCached::cache_get(&new_cache, &i).expect("key was just inserted"),
                Some(i * 10)
            );
        }
    }

    #[test]
    fn deep_clone_is_independent() {
        let c1 = ShardedCache::<u32, u32>::builder().build().unwrap();
        SyncConcurrentCached::cache_set(&c1, 1, 1).expect("insert must succeed");
        let c2 = c1.deep_clone();
        SyncConcurrentCached::cache_set(&c1, 2, 2).expect("insert must succeed");
        assert_eq!(
            SyncConcurrentCached::cache_get(&c2, &2).expect("cache_get must succeed"),
            None
        );
        assert_eq!(
            SyncConcurrentCached::cache_get(&c1, &1).expect("key was just inserted"),
            Some(1)
        );
        assert_eq!(
            SyncConcurrentCached::cache_get(&c2, &1).expect("key was copied to deep clone"),
            Some(1)
        );
    }

    #[test]
    fn deep_clone_copies_counters_independently() {
        let c1 = ShardedCache::<u32, u32>::builder().build().unwrap();
        SyncConcurrentCached::cache_set(&c1, 1, 1).expect("insert must succeed");
        SyncConcurrentCached::cache_get(&c1, &1).expect("key was just inserted");
        SyncConcurrentCached::cache_get(&c1, &2).expect("cache_get must succeed");
        let c2 = c1.deep_clone();
        let m2 = c2.metrics();
        assert_eq!(m2.hits, Some(1));
        assert_eq!(m2.misses, Some(1));
        SyncConcurrentCached::cache_get(&c2, &1).expect("key was copied to deep clone");
        assert_eq!(c2.metrics().hits, Some(2));
        assert_eq!(c1.metrics().hits, Some(1), "original counters must not change");
    }

    #[test]
    fn send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ShardedCache<u32, u32>>();
    }

    #[test]
    fn build_error_on_overflow() {
        let c = ShardedCacheBase::<u32, u32>::builder()
            .shards(usize::MAX)
            .build();
        assert!(c.is_err());
        match c.expect_err("usize::MAX shards should fail") {
            BuildError::InvalidValue { field, reason } => {
                assert_eq!(field, "shards");
                assert!(reason.contains("overflows"));
            }
            _ => panic!("expected BuildError::InvalidValue"),
        }
    }

    #[test]
    fn build_error_on_zero_shards() {
        let c = ShardedCacheBase::<u32, u32>::builder().shards(0).build();
        assert!(c.is_err(), "zero shards should return Err");
        match c.expect_err("zero shards should fail") {
            BuildError::InvalidValue { field, .. } => {
                assert_eq!(field, "shards");
            }
            _ => panic!("expected BuildError::InvalidValue"),
        }
    }

    #[test]
    fn cache_clear_with_on_evict_fires_for_all_entries() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let count = Arc::new(AtomicUsize::new(0));
        let count2 = count.clone();
        let c = ShardedCacheBase::<u32, u32>::builder()
            .on_evict(move |_, _| {
                count2.fetch_add(1, Ordering::Relaxed);
            })
            .build()
            .unwrap();
        for i in 0..20u32 {
            SyncConcurrentCached::cache_set(&c, i, i).expect("insert must succeed");
        }
        c.cache_clear_with_on_evict();
        assert_eq!(
            c.len(),
            0,
            "cache must be empty after cache_clear_with_on_evict"
        );
        assert_eq!(
            count.load(Ordering::Relaxed),
            20,
            "on_evict must fire for every entry"
        );
    }

    #[test]
    fn cache_clear_with_on_evict_without_callback_clears() {
        let c = ShardedCache::<u32, u32>::builder().build().unwrap();
        for i in 0..5u32 {
            SyncConcurrentCached::cache_set(&c, i, i).expect("insert must succeed");
        }
        c.cache_clear_with_on_evict();
        assert!(c.is_empty());
    }

    #[test]
    fn clear_does_not_fire_on_evict() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let count = Arc::new(AtomicUsize::new(0));
        let count2 = count.clone();
        let c = ShardedCacheBase::<u32, u32>::builder()
            .on_evict(move |_, _| {
                count2.fetch_add(1, Ordering::Relaxed);
            })
            .build()
            .unwrap();
        for i in 0..10u32 {
            SyncConcurrentCached::cache_set(&c, i, i).expect("insert must succeed");
        }
        c.clear();
        assert_eq!(
            count.load(Ordering::Relaxed),
            0,
            "clear must not fire on_evict"
        );
    }

    #[test]
    fn cache_remove_entry_basic() {
        let c = ShardedCacheBase::<u32, u32>::builder()
            .shards(1)
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 100u32).expect("insert must succeed");
        assert_eq!(
            SyncConcurrentCached::cache_remove_entry(&c, &999u32)
                .expect("cache_remove_entry must succeed"),
            None
        );
        assert_eq!(
            SyncConcurrentCached::cache_remove_entry(&c, &1u32).expect("key must be present"),
            Some((1u32, 100u32))
        );
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1u32).expect("cache_get must succeed"),
            None
        );
    }

    #[test]
    fn cache_remove_entry_fires_on_evict() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let count = Arc::new(AtomicUsize::new(0));
        let count2 = count.clone();
        let c = ShardedCacheBase::<u32, u32>::builder()
            .shards(1)
            .on_evict(move |_, _| {
                count2.fetch_add(1, Ordering::Relaxed);
            })
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 10u32).expect("insert must succeed");
        SyncConcurrentCached::cache_remove_entry(&c, &1u32).expect("key must be present");
        assert_eq!(count.load(Ordering::Relaxed), 1);
        SyncConcurrentCached::cache_remove_entry(&c, &999u32)
            .expect("cache_remove_entry must succeed");
        assert_eq!(count.load(Ordering::Relaxed), 1, "no fire for absent key");
    }

    #[test]
    fn cache_delete_returns_true_for_present_entry() {
        let c = ShardedCacheBase::<u32, u32>::builder()
            .shards(1)
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 10u32).expect("insert must succeed");
        assert!(SyncConcurrentCached::cache_delete(&c, &1u32).expect("cache_delete must succeed"));
        assert!(!SyncConcurrentCached::cache_delete(&c, &1u32).expect("cache_delete must succeed"));
    }

    #[test]
    fn set_refresh_on_hit_is_unsupported() {
        let c = ShardedCache::<u32, u32>::builder().build().unwrap();
        assert!(!SyncConcurrentCached::set_refresh_on_hit(&c, true));
        assert!(!SyncConcurrentCached::set_refresh_on_hit(&c, false));
    }
}