use std::hash::Hash;
use std::marker::PhantomData;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
#[cfg(feature = "async_core")]
use crate::ConcurrentCachedAsync;
use crate::time::{Duration, Instant};
use crate::{CacheMetrics, CacheTtl, ConcurrentCached, ConcurrentCloneCached};
use super::{
CachePadded, DefaultShardHasher, Shard, ShardHasher, checked_shard_count, shard_index,
};
use crate::stores::{BuildError, CacheEvict, HasEvict, LruCache, NoEvict, TimedEntry};
use crate::{Cached, CachedIter, CachedPeek};
/// Shared, thread-safe callback invoked with a borrowed key and value when an entry leaves the cache.
type OnEvict<K, V> = Arc<dyn Fn(&K, &V) + Send + Sync>;
/// State shared by every handle to one sharded LRU + TTL cache.
#[allow(clippy::type_complexity)]
struct LruTtlInner<K, V, H> {
/// One LRU store of timestamped entries per shard, each behind that shard's lock.
shards: Box<[CachePadded<Shard<LruCache<K, TimedEntry<V>>>>]>,
/// Mask handed to `shard_index` together with the key hash to select a shard.
shard_mask: usize,
/// Produces the per-key hash used for shard selection.
hasher: H,
/// Optional user callback fired when entries are removed, expired or evicted.
on_evict: Option<OnEvict<K, V>>,
/// Time-to-live in nanoseconds; a stored `0` is read back as "no TTL".
ttl_nanos: AtomicU64,
/// When `true`, a successful lookup resets the entry's timestamp.
refresh: AtomicBool,
/// Number of removals performed by this wrapper (expiry, explicit removal,
/// callback-aware clear) as opposed to evictions counted by the LRU stores.
non_capacity_evictions: AtomicU64,
/// Effective total capacity, reported by `capacity()` and `metrics()`.
total_capacity: usize,
}
/// [`ShardedLruTtlCacheBase`] that selects shards with [`DefaultShardHasher`].
pub type ShardedLruTtlCache<K, V> = ShardedLruTtlCacheBase<K, V, DefaultShardHasher>;
/// Sharded cache that combines per-shard LRU stores with a time-to-live.
///
/// The handle only holds an `Arc` to the shared state, so cloning it yields
/// another handle to the same cache rather than a copy of the entries.
pub struct ShardedLruTtlCacheBase<K, V, H = DefaultShardHasher> {
/// Reference-counted state shared between all clones of this handle.
inner: Arc<LruTtlInner<K, V, H>>,
}
impl<K, V, H> Clone for ShardedLruTtlCacheBase<K, V, H> {
    /// Returns another handle to the same underlying cache.
    ///
    /// Only the `Arc` is cloned; entries, counters and settings stay shared.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
impl<K, V, H> std::fmt::Debug for ShardedLruTtlCacheBase<K, V, H> {
    /// Prints the shard count, total capacity and current TTL.
    ///
    /// Entries are never printed, so no `Debug` bound on `K` or `V` is needed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner = &self.inner;
        // A stored value of zero nanoseconds means that no TTL is configured.
        let ttl = match inner.ttl_nanos.load(Ordering::Relaxed) {
            0 => None,
            nanos => Some(Duration::from_nanos(nanos)),
        };
        let mut out = f.debug_struct("ShardedLruTtlCache");
        out.field("shards", &inner.shards.len());
        out.field("capacity", &inner.total_capacity);
        out.field("ttl", &ttl);
        out.finish_non_exhaustive()
    }
}
impl<K, V, H> ShardedLruTtlCacheBase<K, V, H>
where
    K: Hash + Eq + Clone,
    H: ShardHasher<K>,
{
    /// Starts a builder preconfigured with [`DefaultShardHasher`].
    pub fn builder() -> ShardedLruTtlCacheBuilder<K, V, DefaultShardHasher> {
        Default::default()
    }

    /// Returns the shard responsible for `k`.
    ///
    /// The key hash and the shard mask are combined by `shard_index`.
    #[inline]
    fn shard_of(&self, k: &K) -> &CachePadded<Shard<LruCache<K, TimedEntry<V>>>> {
        let inner = &self.inner;
        let idx = shard_index(inner.hasher.shard_hash(k), inner.shard_mask);
        &inner.shards[idx]
    }

    /// Reads the current TTL; a stored value of `0` nanoseconds means "no TTL".
    #[inline]
    fn ttl_duration(&self) -> Option<Duration> {
        match self.inner.ttl_nanos.load(Ordering::Relaxed) {
            0 => None,
            nanos => Some(Duration::from_nanos(nanos)),
        }
    }
}
impl<K: Clone + Hash + Eq, V: Clone, H: ShardHasher<K> + Clone> ShardedLruTtlCacheBase<K, V, H> {
    /// Creates an independent cache with a copy of every shard.
    ///
    /// Unlike [`Clone`], which only shares the same state, this clones each
    /// shard's store and copies the hit/miss counters, TTL, refresh flag and
    /// eviction counter into fresh state. The eviction callback, if any, is
    /// the same `Arc` in both caches.
    #[must_use]
    pub fn deep_clone(&self) -> Self {
        let src = &self.inner;
        let mut copied = Vec::with_capacity(src.shards.len());
        for shard in src.shards.iter() {
            // Snapshot the store and its counters while the read lock is held.
            let (store, hits, misses) = {
                let guard = shard.lock.read();
                (
                    guard.clone(),
                    shard.hits.load(Ordering::Relaxed),
                    shard.misses.load(Ordering::Relaxed),
                )
            };
            copied.push(CachePadded(Shard {
                lock: parking_lot::RwLock::new(store),
                hits: AtomicU64::new(hits),
                misses: AtomicU64::new(misses),
            }));
        }

        let ttl_nanos = src.ttl_nanos.load(Ordering::Relaxed);
        let refresh = src.refresh.load(Ordering::Relaxed);
        let non_capacity = src.non_capacity_evictions.load(Ordering::Relaxed);
        let inner = LruTtlInner {
            shards: copied.into_boxed_slice(),
            shard_mask: src.shard_mask,
            hasher: src.hasher.clone(),
            on_evict: src.on_evict.clone(),
            ttl_nanos: AtomicU64::new(ttl_nanos),
            refresh: AtomicBool::new(refresh),
            non_capacity_evictions: AtomicU64::new(non_capacity),
            total_capacity: src.total_capacity,
        };
        Self {
            inner: Arc::new(inner),
        }
    }
}
impl<K, V, H: ShardHasher<K>> ShardedLruTtlCacheBase<K, V, H>
where
    K: Hash + Eq + Clone,
{
    /// Maps a raw stored nanosecond count to a TTL, treating `0` as "no TTL".
    fn ttl_from_nanos(nanos: u64) -> Option<Duration> {
        (nanos != 0).then(|| Duration::from_nanos(nanos))
    }

    /// Aggregates hit, miss, eviction and size figures across all shards.
    ///
    /// `evictions` is the sum of the evictions reported by the per-shard LRU
    /// stores and the removals counted by this wrapper (expiry, explicit
    /// removal, callback-aware clear). Shards are read one after another, so
    /// the result is not a single atomic snapshot.
    #[must_use]
    pub fn metrics(&self) -> CacheMetrics {
        let mut hits = 0u64;
        let mut misses = 0u64;
        let mut lru_evictions = 0u64;
        let mut size = 0usize;
        for shard in self.inner.shards.iter() {
            hits += shard.hits.load(Ordering::Relaxed);
            misses += shard.misses.load(Ordering::Relaxed);
            let guard = shard.lock.read();
            lru_evictions += guard.cache_evictions().unwrap_or(0);
            size += guard.cache_size();
        }
        let non_capacity = self.inner.non_capacity_evictions.load(Ordering::Relaxed);
        CacheMetrics {
            hits: Some(hits),
            misses: Some(misses),
            evictions: Some(lru_evictions + non_capacity),
            size,
            capacity: Some(self.inner.total_capacity),
        }
    }

    /// Number of shards in this cache.
    #[must_use]
    pub fn shards(&self) -> usize {
        self.inner.shards.len()
    }

    /// Current entry count of every shard, in shard order.
    #[must_use]
    pub fn shard_sizes(&self) -> Vec<usize> {
        self.inner
            .shards
            .iter()
            .map(|s| s.lock.read().cache_size())
            .collect()
    }

    /// Total number of stored entries, including ones that have expired but
    /// have not been removed yet.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner
            .shards
            .iter()
            .map(|s| s.lock.read().cache_size())
            .sum()
    }

    /// Returns `true` when every shard reports a size of zero.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner
            .shards
            .iter()
            .all(|s| s.lock.read().cache_size() == 0)
    }

    /// Clears every shard via the store's `cache_clear`.
    ///
    /// This path does not invoke `on_evict` itself and does not touch the
    /// eviction counter kept by this wrapper; use
    /// [`cache_clear_with_on_evict`](Self::cache_clear_with_on_evict) when the
    /// callback must observe every removed entry.
    pub fn clear(&self) {
        for shard in self.inner.shards.iter() {
            shard.lock.write().cache_clear();
        }
    }

    /// Removes every entry, reporting each one to `on_evict` and counting it
    /// as an eviction.
    ///
    /// Without a registered callback this is the same as [`clear`](Self::clear).
    /// Callbacks run after the shard's write lock has been released.
    pub fn cache_clear_with_on_evict(&self) {
        let Some(on_evict) = &self.inner.on_evict else {
            self.clear();
            return;
        };
        for shard in self.inner.shards.iter() {
            let removed: Vec<(K, TimedEntry<V>)> = {
                let mut guard = shard.lock.write();
                // Keys are collected first because the store cannot be mutated
                // while it is being iterated.
                let keys: Vec<K> = guard.iter().map(|(k, _)| k.clone()).collect();
                let mut removed = Vec::with_capacity(keys.len());
                for k in keys {
                    if let Some(pair) = guard.pop_raw(&k) {
                        removed.push(pair);
                    }
                }
                removed
            };
            if removed.is_empty() {
                continue;
            }
            self.inner
                .non_capacity_evictions
                .fetch_add(removed.len() as u64, Ordering::Relaxed);
            for (k, entry) in &removed {
                on_evict(k, &entry.value);
            }
        }
    }

    /// Effective total capacity of the cache (shard count times per-shard capacity).
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.inner.total_capacity
    }

    /// Removes every entry whose age is at least the current TTL and returns
    /// how many were removed.
    ///
    /// Returns `0` immediately when no TTL is set. One timestamp is taken up
    /// front and used for all shards. Removed entries are added to the
    /// eviction counter and passed to `on_evict` after the shard's write lock
    /// has been released.
    #[must_use]
    pub fn evict(&self) -> usize {
        let Some(ttl) = self.ttl_duration() else {
            return 0;
        };
        let now = Instant::now();
        let mut total = 0;
        for shard in self.inner.shards.iter() {
            let removed: Vec<(K, TimedEntry<V>)> = {
                let mut guard = shard.lock.write();
                let expired: Vec<K> = guard
                    .iter()
                    .filter(|(_, e)| now.saturating_duration_since(e.instant) >= ttl)
                    .map(|(k, _)| k.clone())
                    .collect();
                // At most `expired.len()` entries can come back, so reserve once.
                let mut removed = Vec::with_capacity(expired.len());
                for k in expired {
                    if let Some(pair) = guard.pop_raw(&k) {
                        removed.push(pair);
                    }
                }
                removed
            };
            total += removed.len();
            if removed.is_empty() {
                continue;
            }
            self.inner
                .non_capacity_evictions
                .fetch_add(removed.len() as u64, Ordering::Relaxed);
            if let Some(cb) = &self.inner.on_evict {
                for (k, entry) in &removed {
                    cb(k, &entry.value);
                }
            }
        }
        total
    }

    /// Current TTL, or `None` when expiry is disabled.
    #[must_use]
    pub fn ttl(&self) -> Option<Duration> {
        self.ttl_duration()
    }

    /// Replaces the TTL and returns the previous one.
    ///
    /// Durations that do not fit in `u64` nanoseconds are saturated to
    /// `u64::MAX` nanoseconds.
    ///
    /// # Panics
    ///
    /// Panics if `ttl` is zero; use [`unset_ttl`](Self::unset_ttl) to disable expiry.
    pub fn set_ttl(&self, ttl: Duration) -> Option<Duration> {
        assert!(
            !ttl.is_zero(),
            "TTL must be non-zero; use unset_ttl() to disable expiry"
        );
        let nanos = u64::try_from(ttl.as_nanos()).unwrap_or(u64::MAX);
        let prev = self.inner.ttl_nanos.swap(nanos, Ordering::Relaxed);
        Self::ttl_from_nanos(prev)
    }

    /// Disables expiry and returns the TTL that was in effect, if any.
    pub fn unset_ttl(&self) -> Option<Duration> {
        let prev = self.inner.ttl_nanos.swap(0, Ordering::Relaxed);
        Self::ttl_from_nanos(prev)
    }

    /// Sets whether a hit resets the entry's timestamp; returns the previous setting.
    pub fn set_refresh_on_hit(&self, refresh: bool) -> bool {
        self.inner.refresh.swap(refresh, Ordering::Relaxed)
    }

    /// Whether a hit currently resets the entry's timestamp.
    #[must_use]
    pub fn refresh_on_hit(&self) -> bool {
        self.inner.refresh.load(Ordering::Relaxed)
    }
}
impl<K, V, H> CacheTtl for ShardedLruTtlCacheBase<K, V, H>
where
    K: Hash + Eq + Clone,
    H: ShardHasher<K>,
{
    /// Current TTL, or `None` when expiry is disabled.
    fn ttl(&self) -> Option<Duration> {
        ShardedLruTtlCacheBase::ttl(self)
    }

    /// Forwards to the inherent `set_ttl`, which only needs shared access.
    fn set_ttl(&mut self, ttl: Duration) -> Option<Duration> {
        let shared: &Self = self;
        ShardedLruTtlCacheBase::set_ttl(shared, ttl)
    }

    /// Forwards to the inherent `unset_ttl`, which only needs shared access.
    fn unset_ttl(&mut self) -> Option<Duration> {
        let shared: &Self = self;
        ShardedLruTtlCacheBase::unset_ttl(shared)
    }
}
impl<K, V, H> CacheEvict for ShardedLruTtlCacheBase<K, V, H>
where
    K: Hash + Eq + Clone,
    H: ShardHasher<K>,
{
    /// Forwards to the inherent `evict` and returns the number of entries it removed.
    fn evict(&mut self) -> usize {
        let shared: &Self = self;
        ShardedLruTtlCacheBase::evict(shared)
    }
}
impl<K, V, H> ConcurrentCached<K, V> for ShardedLruTtlCacheBase<K, V, H>
where
    K: Hash + Eq + Clone,
    V: Clone,
    H: ShardHasher<K>,
{
    type Error = std::convert::Infallible;

    /// Looks up `k`, returning a clone of the value when it is present and not expired.
    ///
    /// An expired entry is removed, reported to `on_evict` (after the shard
    /// lock is released), counted as an eviction and recorded as a miss.
    /// On a hit with refresh enabled the entry's timestamp is reset.
    fn cache_get(&self, k: &K) -> Result<Option<V>, Self::Error> {
        let shard = self.shard_of(k);
        let ttl = self.ttl_duration();
        let refresh_on_hit = self.inner.refresh.load(Ordering::Relaxed);
        let mut store = shard.lock.write();

        // Peek first so that expiry can be decided before the lookup proper.
        let Some(found) = store.cache_peek(k) else {
            shard.misses.fetch_add(1, Ordering::Relaxed);
            return Ok(None);
        };
        let is_expired = ttl.is_some_and(|limit| found.instant.elapsed() >= limit);

        if is_expired {
            let popped = store.pop_raw(k);
            drop(store);
            if let Some((old_key, old_entry)) = &popped {
                if let Some(on_evict) = self.inner.on_evict.as_ref() {
                    on_evict(old_key, &old_entry.value);
                }
                self.inner
                    .non_capacity_evictions
                    .fetch_add(1, Ordering::Relaxed);
            }
            shard.misses.fetch_add(1, Ordering::Relaxed);
            return Ok(None);
        }

        let value = if !refresh_on_hit {
            store.cache_get(k).map(|entry| entry.value.clone())
        } else {
            store.cache_get_mut(k).map(|entry| {
                entry.instant = Instant::now();
                entry.value.clone()
            })
        };
        shard.hits.fetch_add(1, Ordering::Relaxed);
        Ok(value)
    }

    /// Stores `v` under `k` with a fresh timestamp and returns the value it replaced, if any.
    fn cache_set(&self, k: K, v: V) -> Result<Option<V>, Self::Error> {
        let shard = self.shard_of(&k);
        let stamped = TimedEntry {
            instant: Instant::now(),
            value: v,
        };
        let previous = shard.lock.write().cache_set(k, stamped);
        Ok(previous.map(|entry| entry.value))
    }

    /// Removes `k` and returns its value, or `None` when the key is absent or
    /// the removed entry had already expired.
    ///
    /// Any entry that is actually removed, expired or not, is counted as an
    /// eviction and passed to `on_evict`.
    fn cache_remove(&self, k: &K) -> Result<Option<V>, Self::Error> {
        let removed = self.shard_of(k).lock.write().pop_raw(k);
        let Some((key, entry)) = removed else {
            return Ok(None);
        };
        self.inner
            .non_capacity_evictions
            .fetch_add(1, Ordering::Relaxed);
        if let Some(on_evict) = self.inner.on_evict.as_ref() {
            on_evict(&key, &entry.value);
        }
        // The TTL is read after the removal, so a concurrent TTL change applies here.
        let still_live = self
            .ttl_duration()
            .is_none_or(|ttl| entry.instant.elapsed() < ttl);
        Ok(still_live.then_some(entry.value))
    }

    /// Removes `k` and returns the stored key/value pair, even if the entry had expired.
    ///
    /// A removed entry is counted as an eviction and passed to `on_evict`.
    fn cache_remove_entry(&self, k: &K) -> Result<Option<(K, V)>, Self::Error> {
        let removed = self.shard_of(k).lock.write().pop_raw(k);
        if let Some((stored_key, entry)) = removed.as_ref() {
            self.inner
                .non_capacity_evictions
                .fetch_add(1, Ordering::Relaxed);
            if let Some(on_evict) = self.inner.on_evict.as_ref() {
                on_evict(stored_key, &entry.value);
            }
        }
        Ok(removed.map(|(stored_key, entry)| (stored_key, entry.value)))
    }

    /// Total number of stored entries across all shards.
    fn cache_size(&self) -> Result<Option<usize>, Self::Error> {
        let total = self.len();
        Ok(Some(total))
    }

    /// Sets whether a hit resets the entry's timestamp; returns the previous setting.
    fn set_refresh_on_hit(&self, refresh: bool) -> bool {
        let flag = &self.inner.refresh;
        flag.swap(refresh, Ordering::Relaxed)
    }

    /// Current TTL, or `None` when expiry is disabled.
    fn ttl(&self) -> Option<Duration> {
        ShardedLruTtlCacheBase::ttl(self)
    }

    /// Replaces the TTL via the inherent `set_ttl`; returns the previous TTL.
    fn set_ttl(&self, ttl: Duration) -> Option<Duration> {
        let previous = ShardedLruTtlCacheBase::set_ttl(self, ttl);
        previous
    }

    /// Disables expiry via the inherent `unset_ttl`; returns the previous TTL.
    fn unset_ttl(&self) -> Option<Duration> {
        let previous = ShardedLruTtlCacheBase::unset_ttl(self);
        previous
    }
}
#[cfg(feature = "async_core")]
impl<K, V, H> ConcurrentCachedAsync<K, V> for ShardedLruTtlCacheBase<K, V, H>
where
    K: Hash + Eq + Clone + Send + Sync,
    V: Clone + Send + Sync,
    H: ShardHasher<K>,
{
    type Error = std::convert::Infallible;

    /// Async wrapper that runs the synchronous `cache_get` inline.
    async fn cache_get(&self, k: &K) -> Result<Option<V>, Self::Error> {
        <Self as ConcurrentCached<K, V>>::cache_get(self, k)
    }

    /// Async wrapper that runs the synchronous `cache_set` inline.
    async fn cache_set(&self, k: K, v: V) -> Result<Option<V>, Self::Error> {
        <Self as ConcurrentCached<K, V>>::cache_set(self, k, v)
    }

    /// Async wrapper that runs the synchronous `cache_remove` inline.
    async fn cache_remove(&self, k: &K) -> Result<Option<V>, Self::Error> {
        <Self as ConcurrentCached<K, V>>::cache_remove(self, k)
    }

    /// Async wrapper that runs the synchronous `cache_remove_entry` inline.
    async fn cache_remove_entry(&self, k: &K) -> Result<Option<(K, V)>, Self::Error> {
        <Self as ConcurrentCached<K, V>>::cache_remove_entry(self, k)
    }

    /// Total number of stored entries across all shards.
    fn cache_size(&self) -> Result<Option<usize>, Self::Error> {
        let total = self.len();
        Ok(Some(total))
    }

    /// Sets whether a hit resets the entry's timestamp; returns the previous setting.
    fn set_refresh_on_hit(&self, b: bool) -> bool {
        let previous = <Self as ConcurrentCached<K, V>>::set_refresh_on_hit(self, b);
        previous
    }

    /// Current TTL, or `None` when expiry is disabled.
    fn ttl(&self) -> Option<Duration> {
        ShardedLruTtlCacheBase::ttl(self)
    }

    /// Replaces the TTL via the inherent `set_ttl`; returns the previous TTL.
    fn set_ttl(&self, ttl: Duration) -> Option<Duration> {
        let previous = ShardedLruTtlCacheBase::set_ttl(self, ttl);
        previous
    }

    /// Disables expiry via the inherent `unset_ttl`; returns the previous TTL.
    fn unset_ttl(&self) -> Option<Duration> {
        let previous = ShardedLruTtlCacheBase::unset_ttl(self);
        previous
    }
}
/// Builder for [`ShardedLruTtlCacheBase`].
///
/// The `E` marker (`NoEvict` or `HasEvict`) records at the type level whether
/// an eviction callback has been registered and selects the matching `build`.
pub struct ShardedLruTtlCacheBuilder<K, V, H = DefaultShardHasher, E = NoEvict> {
/// Requested shard count; checked by `checked_shard_count` when building.
shards: Option<usize>,
/// Total capacity to divide among the shards; mutually exclusive with `per_shard_max_size`.
max_size: Option<usize>,
/// Capacity of each individual shard; mutually exclusive with `max_size`.
per_shard_max_size: Option<usize>,
/// Time-to-live for entries; required, and validated when building.
ttl: Option<Duration>,
/// Initial value of the refresh-on-hit flag.
refresh: bool,
/// Shard hasher; set by `Default` or `.hasher()`.
hasher: Option<H>,
/// User eviction callback, if one was registered.
on_evict: Option<OnEvict<K, V>>,
/// Zero-sized marker carrying the `E` type-state.
_evict: PhantomData<E>,
}
impl<K, V> Default for ShardedLruTtlCacheBuilder<K, V, DefaultShardHasher> {
    /// Returns a builder with nothing configured except the default shard hasher.
    fn default() -> Self {
        let hasher = Some(DefaultShardHasher::default());
        Self {
            hasher,
            shards: None,
            max_size: None,
            per_shard_max_size: None,
            ttl: None,
            on_evict: None,
            refresh: false,
            _evict: PhantomData,
        }
    }
}
impl<K, V, H, E> ShardedLruTtlCacheBuilder<K, V, H, E> {
    /// Sets the total capacity, which is divided among the shards at build time.
    ///
    /// Mutually exclusive with [`per_shard_max_size`](Self::per_shard_max_size).
    #[doc(alias = "size")]
    #[doc(alias = "capacity")]
    #[must_use]
    pub fn max_size(self, max_size: usize) -> Self {
        Self {
            max_size: Some(max_size),
            ..self
        }
    }

    /// Sets the capacity of each individual shard.
    ///
    /// Mutually exclusive with [`max_size`](Self::max_size).
    #[must_use]
    pub fn per_shard_max_size(self, per_shard_max_size: usize) -> Self {
        Self {
            per_shard_max_size: Some(per_shard_max_size),
            ..self
        }
    }

    /// Sets the time-to-live for entries (required).
    #[must_use]
    pub fn ttl(self, ttl: Duration) -> Self {
        Self {
            ttl: Some(ttl),
            ..self
        }
    }

    /// Sets the requested number of shards.
    #[must_use]
    pub fn shards(self, shards: usize) -> Self {
        Self {
            shards: Some(shards),
            ..self
        }
    }

    /// Sets whether a hit resets the entry's timestamp.
    #[must_use]
    pub fn refresh_on_hit(self, refresh: bool) -> Self {
        Self { refresh, ..self }
    }

    /// Alias for [`refresh_on_hit`](Self::refresh_on_hit).
    #[must_use]
    pub fn refresh(self, refresh: bool) -> Self {
        Self { refresh, ..self }
    }

    /// Replaces the shard hasher, changing the builder's hasher type.
    ///
    /// All other settings, including a registered callback, are carried over.
    #[must_use]
    pub fn hasher<H2: ShardHasher<K>>(self, hasher: H2) -> ShardedLruTtlCacheBuilder<K, V, H2, E> {
        let Self {
            shards,
            max_size,
            per_shard_max_size,
            ttl,
            refresh,
            on_evict,
            ..
        } = self;
        ShardedLruTtlCacheBuilder {
            shards,
            max_size,
            per_shard_max_size,
            ttl,
            refresh,
            hasher: Some(hasher),
            on_evict,
            _evict: PhantomData,
        }
    }

    /// Resolves the capacity of one shard from `max_size` or `per_shard_max_size`.
    ///
    /// Exactly one of the two must be set and it must be non-zero. A total
    /// `max_size` is divided by the shard count rounding up; with more than
    /// one shard the per-shard result is raised to at least 16.
    fn resolve_per_shard_cap(&self, n_shards: usize) -> Result<usize, BuildError> {
        match (self.max_size, self.per_shard_max_size) {
            (None, None) => Err(BuildError::MissingRequired("max_size")),
            (Some(_), Some(_)) => Err(BuildError::InvalidValue {
                field: "max_size / per_shard_max_size",
                reason: "`max_size` and `per_shard_max_size` are mutually exclusive",
            }),
            (Some(0), None) => Err(BuildError::InvalidValue {
                field: "max_size",
                reason: "must be greater than zero",
            }),
            (None, Some(0)) => Err(BuildError::InvalidValue {
                field: "per_shard_max_size",
                reason: "must be greater than zero",
            }),
            (Some(total), None) => {
                let even_share = total.div_ceil(n_shards);
                Ok(if n_shards > 1 {
                    even_share.max(16)
                } else {
                    even_share
                })
            }
            (None, Some(per)) => Ok(per),
        }
    }

    /// Multiplies shard count by per-shard capacity, failing on `usize` overflow.
    ///
    /// The error names whichever capacity option the caller actually set.
    fn total_capacity(&self, n_shards: usize, per_shard_cap: usize) -> Result<usize, BuildError> {
        let field = match self.per_shard_max_size {
            Some(_) => "per_shard_max_size",
            None => "max_size",
        };
        match n_shards.checked_mul(per_shard_cap) {
            Some(total) => Ok(total),
            None => Err(BuildError::InvalidValue {
                field,
                reason: "effective sharded capacity overflows usize",
            }),
        }
    }

    /// Validates the configuration and returns
    /// `(ttl, shard_mask, per_shard_capacity, total_capacity)`.
    ///
    /// Checks run in this order: TTL present, TTL valid, shard count,
    /// per-shard capacity, total capacity.
    fn validated_parts(&self) -> Result<(Duration, usize, usize, usize), BuildError> {
        let Some(ttl) = self.ttl else {
            return Err(BuildError::MissingRequired("ttl"));
        };
        crate::stores::validate_ttl(ttl)?;
        let shard_count = checked_shard_count(self.shards)?;
        let mask = shard_count - 1;
        let per_shard_cap = self.resolve_per_shard_cap(shard_count)?;
        let total_cap = self.total_capacity(shard_count, per_shard_cap)?;
        Ok((ttl, mask, per_shard_cap, total_cap))
    }
}
impl<K, V, H> ShardedLruTtlCacheBuilder<K, V, H, NoEvict> {
#[must_use]
pub fn on_evict(
self,
on_evict: impl Fn(&K, &V) + Send + Sync + 'static,
) -> ShardedLruTtlCacheBuilder<K, V, H, HasEvict> {
ShardedLruTtlCacheBuilder {
shards: self.shards,
max_size: self.max_size,
per_shard_max_size: self.per_shard_max_size,
ttl: self.ttl,
refresh: self.refresh,
hasher: self.hasher,
on_evict: Some(Arc::new(on_evict)),
_evict: PhantomData,
}
}
pub fn build(self) -> Result<ShardedLruTtlCacheBase<K, V, H>, BuildError>
where
K: Hash + Eq + Clone,
H: ShardHasher<K>,
{
let (ttl, mask, per_shard_cap, total_cap) = self.validated_parts()?;
let n = mask + 1;
let shards = (0..n)
.map(|_| {
let mut lru: LruCache<K, TimedEntry<V>> =
LruCache::builder().max_size(per_shard_cap).build()?;
lru.disable_hit_miss_tracking();
Ok(CachePadded(Shard::new(lru)))
})
.collect::<Result<Vec<_>, BuildError>>()?
.into_boxed_slice();
Ok(ShardedLruTtlCacheBase {
inner: Arc::new(LruTtlInner {
shards,
shard_mask: mask,
hasher: self
.hasher
.expect("hasher is always initialized via Default or .hasher()"),
on_evict: None,
ttl_nanos: AtomicU64::new(ttl.as_nanos().min(u64::MAX as u128) as u64),
refresh: AtomicBool::new(self.refresh),
non_capacity_evictions: AtomicU64::new(0),
total_capacity: total_cap,
}),
})
}
#[must_use]
pub fn copy_from<H2: ShardHasher<K>>(
self,
existing: &ShardedLruTtlCacheBase<K, V, H2>,
) -> ShardedLruTtlCacheBase<K, V, H>
where
K: Clone + Hash + Eq,
V: Clone,
H: ShardHasher<K>,
{
copy_from_lru_ttl(
self.build()
.unwrap_or_else(|e| panic!("ShardedLruTtlCache build failed: {e}")),
existing,
)
}
}
impl<K, V, H> ShardedLruTtlCacheBuilder<K, V, H, HasEvict> {
    /// Builds a cache whose shards report evictions to the registered callback.
    ///
    /// The user callback takes `(&K, &V)`, while each shard store holds
    /// `TimedEntry<V>`; an adapter that unwraps the entry's value is installed
    /// on every shard store. The original callback is also kept on the cache
    /// for removals performed by the cache itself.
    ///
    /// # Errors
    ///
    /// Returns a [`BuildError`] when the configuration fails validation or a
    /// per-shard store cannot be built.
    pub fn build(self) -> Result<ShardedLruTtlCacheBase<K, V, H>, BuildError>
    where
        K: Hash + Eq + Clone + 'static,
        V: 'static,
        H: ShardHasher<K>,
    {
        let (ttl, mask, per_shard_cap, total_cap) = self.validated_parts()?;
        let shard_count = mask + 1;

        #[allow(clippy::type_complexity)]
        let lru_on_evict: Option<Arc<dyn Fn(&K, &TimedEntry<V>) + Send + Sync>> =
            match &self.on_evict {
                Some(user_cb) => {
                    let user_cb = Arc::clone(user_cb);
                    let adapter: Arc<dyn Fn(&K, &TimedEntry<V>) + Send + Sync> =
                        Arc::new(move |key: &K, entry: &TimedEntry<V>| user_cb(key, &entry.value));
                    Some(adapter)
                }
                None => None,
            };

        let mut built = Vec::with_capacity(shard_count);
        for _ in 0..shard_count {
            let mut lru: LruCache<K, TimedEntry<V>> =
                LruCache::builder().max_size(per_shard_cap).build()?;
            lru.on_evict = lru_on_evict.clone();
            lru.disable_hit_miss_tracking();
            built.push(CachePadded(Shard::new(lru)));
        }

        let hasher = self
            .hasher
            .expect("hasher is always initialized via Default or .hasher()");
        let ttl_nanos = ttl.as_nanos().min(u64::MAX as u128) as u64;
        let inner = LruTtlInner {
            shards: built.into_boxed_slice(),
            shard_mask: mask,
            hasher,
            on_evict: self.on_evict,
            ttl_nanos: AtomicU64::new(ttl_nanos),
            refresh: AtomicBool::new(self.refresh),
            non_capacity_evictions: AtomicU64::new(0),
            total_capacity: total_cap,
        };
        Ok(ShardedLruTtlCacheBase {
            inner: Arc::new(inner),
        })
    }

    /// Builds a cache and fills it with the non-expired entries of `existing`.
    ///
    /// # Panics
    ///
    /// Panics if the builder configuration is invalid.
    #[must_use]
    pub fn copy_from<H2: ShardHasher<K>>(
        self,
        existing: &ShardedLruTtlCacheBase<K, V, H2>,
    ) -> ShardedLruTtlCacheBase<K, V, H>
    where
        K: Clone + Hash + Eq + 'static,
        V: Clone + 'static,
        H: ShardHasher<K>,
    {
        let fresh = match self.build() {
            Ok(cache) => cache,
            Err(e) => panic!("ShardedLruTtlCache build failed: {e}"),
        };
        copy_from_lru_ttl(fresh, existing)
    }
}
/// Copies every entry of `existing` that has not expired into `new_cache` and returns it.
///
/// Expiry is judged against the TTL of `existing`. Entries keep their
/// original timestamps and are re-sharded with the new cache's hasher. Each
/// source shard is snapshotted under its read lock via `iter_order`, and the
/// snapshot is then inserted in reverse order.
fn copy_from_lru_ttl<K, V, H, H2>(
    new_cache: ShardedLruTtlCacheBase<K, V, H>,
    existing: &ShardedLruTtlCacheBase<K, V, H2>,
) -> ShardedLruTtlCacheBase<K, V, H>
where
    K: Clone + Hash + Eq,
    V: Clone,
    H: ShardHasher<K>,
    H2: ShardHasher<K>,
{
    let source_ttl = existing.ttl_duration();
    for source in existing.inner.shards.iter() {
        // The read guard is a temporary and is released at the end of this statement.
        let snapshot: Vec<(K, TimedEntry<V>)> = source.lock.read().iter_order();
        let live = snapshot
            .into_iter()
            .rev()
            .filter(|(_, entry)| source_ttl.is_none_or(|ttl| entry.instant.elapsed() < ttl));
        for (key, entry) in live {
            let target = new_cache.shard_of(&key);
            target.lock.write().cache_set(key, entry);
        }
    }
    new_cache
}
impl<K, V, H> ConcurrentCloneCached<K, V> for ShardedLruTtlCacheBase<K, V, H>
where
    K: Hash + Eq + Clone,
    V: Clone,
    H: ShardHasher<K>,
{
    /// Looks up `k` and reports whether the returned value is stale.
    ///
    /// * live entry: `(Some(value), false)`, counted as a hit; the timestamp
    ///   is reset when refresh-on-hit is enabled.
    /// * expired entry: `(Some(value), true)`, counted as a miss; the entry is
    ///   left in place and nothing is evicted.
    /// * absent key: `(None, false)`, counted as a miss.
    fn cache_get_with_expiry_status(&self, k: &K) -> (Option<V>, bool) {
        let shard = self.shard_of(k);
        let ttl = self.ttl_duration();
        let refresh_on_hit = self.inner.refresh.load(Ordering::Relaxed);
        let mut store = shard.lock.write();

        let fresh = if !refresh_on_hit {
            store
                .get_if(k, |e| ttl.is_none_or(|t| e.instant.elapsed() < t))
                .map(|e| e.value.clone())
        } else {
            store
                .get_mut_if(k, |e| ttl.is_none_or(|t| e.instant.elapsed() < t))
                .map(|e| {
                    e.instant = Instant::now();
                    e.value.clone()
                })
        };

        match fresh {
            Some(value) => {
                drop(store);
                shard.hits.fetch_add(1, Ordering::Relaxed);
                (Some(value), false)
            }
            None => {
                // Not live: either absent, or present but expired. Peek tells them apart.
                let stale = store.cache_peek(k).map(|e| e.value.clone());
                drop(store);
                shard.misses.fetch_add(1, Ordering::Relaxed);
                let was_present = stale.is_some();
                (stale, was_present)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the sharded LRU + TTL cache and its builder.
    use super::*;
    use crate::ConcurrentCached as SyncConcurrentCached;
    use crate::ConcurrentCloneCached;

    // Round trip: miss, insert, hit, remove, miss.
    #[test]
    fn basic_get_set_remove() {
        let c = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_secs(60))
            .build()
            .unwrap();
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1).expect("cache_get must succeed"),
            None
        );
        assert_eq!(
            SyncConcurrentCached::cache_set(&c, 1, 100).expect("insert must succeed"),
            None
        );
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1).expect("key was just inserted"),
            Some(100)
        );
        assert_eq!(
            SyncConcurrentCached::cache_remove(&c, &1).expect("key must be present"),
            Some(100)
        );
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1).expect("cache_get must succeed"),
            None
        );
    }

    // Removing a present key fires the callback once and bumps evictions by one;
    // removing an absent key does neither.
    #[test]
    fn cache_remove_fires_on_evict_and_increments_metrics() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let count = Arc::new(AtomicUsize::new(0));
        let count2 = count.clone();
        let c = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_secs(60))
            .shards(1)
            .on_evict(move |_, _| {
                count2.fetch_add(1, Ordering::Relaxed);
            })
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1, 10).expect("insert must succeed");
        let before = c
            .metrics()
            .evictions
            .expect("eviction-tracking stores report an evictions count");
        assert_eq!(
            SyncConcurrentCached::cache_remove(&c, &1).expect("key must be present"),
            Some(10)
        );
        assert_eq!(
            SyncConcurrentCached::cache_remove(&c, &999).expect("cache_remove must succeed"),
            None
        );
        let after = c
            .metrics()
            .evictions
            .expect("eviction-tracking stores report an evictions count");
        assert_eq!(count.load(Ordering::Relaxed), 1);
        assert_eq!(after - before, 1);
    }

    // A cloned handle observes writes made through the original.
    #[test]
    fn clone_shares_state() {
        let c1 = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_secs(60))
            .build()
            .unwrap();
        let c2 = c1.clone();
        SyncConcurrentCached::cache_set(&c1, 1, 10).expect("insert must succeed");
        assert_eq!(
            SyncConcurrentCached::cache_get(&c2, &1).expect("key was just inserted"),
            Some(10)
        );
    }

    // An entry is readable before its TTL elapses and gone afterwards.
    #[test]
    fn ttl_expiry() {
        let c = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_millis(50))
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1, 100).expect("insert must succeed");
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1).expect("key was just inserted"),
            Some(100)
        );
        std::thread::sleep(std::time::Duration::from_millis(100));
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1).expect("cache_get must succeed"),
            None
        );
    }

    // Overfilling a single shard triggers the callback through LRU eviction.
    #[test]
    fn lru_eviction_fires() {
        use std::sync::atomic::{AtomicUsize, Ordering as AO};
        let count = std::sync::Arc::new(AtomicUsize::new(0));
        let count2 = count.clone();
        let c = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(8)
            .shards(1)
            .ttl(Duration::from_secs(60))
            .on_evict(move |_, _| {
                count2.fetch_add(1, AO::Relaxed);
            })
            .build()
            .unwrap();
        for i in 0..16u32 {
            SyncConcurrentCached::cache_set(&c, i, i).expect("insert must succeed");
        }
        assert!(
            count.load(AO::Relaxed) > 0,
            "LRU eviction should have fired"
        );
    }

    // Setting both capacity options is rejected.
    #[test]
    fn per_shard_max_size_and_size_exclusive() {
        let err = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(100)
            .per_shard_max_size(10)
            .ttl(Duration::from_secs(60))
            .build();
        assert!(err.is_err());
    }

    // An oversized shard count and an overflowing total capacity are both rejected.
    #[test]
    fn build_rejects_overflowing_shards_and_capacity() {
        let err = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(1)
            .ttl(Duration::from_secs(60))
            .shards(usize::MAX)
            .build();
        assert!(matches!(
            err,
            Err(BuildError::InvalidValue {
                field: "shards",
                ..
            })
        ));
        let err = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .per_shard_max_size(usize::MAX)
            .ttl(Duration::from_secs(60))
            .shards(2)
            .build();
        assert!(matches!(
            err,
            Err(BuildError::InvalidValue {
                field: "per_shard_max_size",
                ..
            })
        ));
    }

    // Without a callback the cache can hold borrowed (non-'static) keys and values.
    #[test]
    fn builder_without_on_evict_does_not_require_static_keys_or_values() {
        let key = String::from("key");
        let value = String::from("value");
        let cache: ShardedLruTtlCacheBase<&str, &str> = ShardedLruTtlCache::builder()
            .max_size(8)
            .ttl(Duration::from_secs(60))
            .build()
            .expect("valid builder config");
        SyncConcurrentCached::cache_set(&cache, key.as_str(), value.as_str())
            .expect("insert must succeed");
        assert_eq!(
            SyncConcurrentCached::cache_get(&cache, &key.as_str()).expect("key was just inserted"),
            Some(value.as_str())
        );
    }

    // The inherent `set_ttl` returns the previous TTL and installs the new one.
    #[test]
    fn set_ttl_inherent() {
        let c = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_secs(60))
            .build()
            .unwrap();
        let prev = c.set_ttl(Duration::from_secs(30));
        assert_eq!(prev, Some(Duration::from_secs(60)));
        assert_eq!(c.ttl(), Some(Duration::from_secs(30)));
    }

    // Entries already expired in the source are not copied.
    #[test]
    fn copy_from_skips_expired() {
        let old = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_millis(50))
            .build()
            .unwrap();
        for i in 0..10u32 {
            SyncConcurrentCached::cache_set(&old, i, i).expect("insert must succeed");
        }
        std::thread::sleep(std::time::Duration::from_millis(100));
        let new_cache = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_secs(60))
            .copy_from(&old);
        assert_eq!(new_cache.len(), 0);
    }

    // Live entries survive a copy into a cache with a different shard count.
    #[test]
    fn copy_from_preserves_live_entries() {
        let old = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(1024)
            .shards(1)
            .ttl(Duration::from_secs(60))
            .build()
            .unwrap();
        for i in 0..20u32 {
            SyncConcurrentCached::cache_set(&old, i, i * 10).expect("insert must succeed");
        }
        let new_cache = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(1024)
            .shards(4)
            .ttl(Duration::from_secs(60))
            .copy_from(&old);
        for i in 0..20u32 {
            assert_eq!(
                SyncConcurrentCached::cache_get(&new_cache, &i).expect("key was just inserted"),
                Some(i * 10)
            );
        }
    }

    // Copying into a smaller cache never exceeds the smaller capacity.
    #[test]
    fn copy_from_respects_capacity() {
        let old = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(64)
            .shards(1)
            .ttl(Duration::from_secs(60))
            .build()
            .unwrap();
        for i in 0..32u32 {
            SyncConcurrentCached::cache_set(&old, i, i).expect("insert must succeed");
        }
        let new_cache = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(16)
            .shards(1)
            .ttl(Duration::from_secs(60))
            .copy_from(&old);
        assert!(new_cache.len() <= 16);
    }

    // Zero capacity, zero shards and zero TTL each produce the matching error.
    #[test]
    fn build_reports_invalid_config() {
        let err = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(0)
            .ttl(Duration::from_secs(60))
            .build();
        assert!(matches!(
            err,
            Err(BuildError::InvalidValue {
                field: "max_size",
                ..
            })
        ));
        let err = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(1)
            .ttl(Duration::from_secs(60))
            .shards(0)
            .build();
        assert!(matches!(
            err,
            Err(BuildError::InvalidValue {
                field: "shards",
                ..
            })
        ));
        let err = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(1)
            .ttl(Duration::from_nanos(0))
            .build();
        assert!(matches!(err, Err(BuildError::InvalidTtl { .. })));
    }

    // Compile-time check that the cache handle is `Send + Sync`.
    #[test]
    fn send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ShardedLruTtlCache<u32, u32>>();
    }

    // A zero TTL is rejected with `InvalidTtl`.
    #[test]
    fn build_rejects_zero_ttl() {
        let err = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(8)
            .ttl(Duration::from_nanos(0))
            .build();
        assert!(
            matches!(err, Err(crate::stores::BuildError::InvalidTtl { .. })),
            "expected InvalidTtl, got {err:?}",
        );
    }

    // The callback-aware clear empties the cache, fires once per entry and
    // counts every entry as an eviction.
    #[test]
    fn cache_clear_with_on_evict_fires_for_all_entries() {
        use std::sync::atomic::{AtomicU64, Ordering};
        let count = Arc::new(AtomicU64::new(0));
        let count2 = count.clone();
        let c = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .shards(1)
            .max_size(64)
            .ttl(Duration::from_secs(3600))
            .on_evict(move |_, _| {
                count2.fetch_add(1, Ordering::Relaxed);
            })
            .build()
            .unwrap();
        for i in 0..20u32 {
            SyncConcurrentCached::cache_set(&c, i, i).expect("insert must succeed");
        }
        let before = c
            .metrics()
            .evictions
            .expect("eviction-tracking stores report an evictions count");
        c.cache_clear_with_on_evict();
        assert_eq!(
            c.len(),
            0,
            "cache must be empty after cache_clear_with_on_evict"
        );
        assert_eq!(
            count.load(Ordering::Relaxed),
            20,
            "on_evict must fire for every entry"
        );
        assert_eq!(
            c.metrics()
                .evictions
                .expect("eviction-tracking stores report an evictions count")
                - before,
            20,
            "evictions counter must increment for each entry"
        );
    }

    // The plain clear never invokes the callback.
    #[test]
    fn clear_does_not_fire_on_evict() {
        use std::sync::atomic::{AtomicU64, Ordering};
        let count = Arc::new(AtomicU64::new(0));
        let count2 = count.clone();
        let c = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_secs(3600))
            .on_evict(move |_, _| {
                count2.fetch_add(1, Ordering::Relaxed);
            })
            .build()
            .unwrap();
        for i in 0..10u32 {
            SyncConcurrentCached::cache_set(&c, i, i).expect("insert must succeed");
        }
        c.clear();
        assert_eq!(
            count.load(Ordering::Relaxed),
            0,
            "clear must not fire on_evict"
        );
    }

    // `cache_remove_entry` returns the pair for a live key and `None` for an absent one.
    #[test]
    fn cache_remove_entry_returns_some_for_live_entry() {
        let c = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_secs(60))
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 100u32).expect("insert must succeed");
        assert_eq!(
            SyncConcurrentCached::cache_remove_entry(&c, &999u32)
                .expect("cache_remove_entry must succeed"),
            None
        );
        assert_eq!(
            SyncConcurrentCached::cache_remove_entry(&c, &1u32).expect("key must be present"),
            Some((1u32, 100u32))
        );
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1u32).expect("cache_get must succeed"),
            None
        );
    }

    // For an expired entry `cache_remove` yields `None`, while
    // `cache_remove_entry` still returns the stored pair.
    #[test]
    fn cache_remove_entry_returns_some_for_expired_entry() {
        let c = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_millis(50))
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 100u32).expect("insert must succeed");
        SyncConcurrentCached::cache_set(&c, 2u32, 200u32).expect("insert must succeed");
        std::thread::sleep(std::time::Duration::from_millis(100));
        assert_eq!(
            SyncConcurrentCached::cache_remove(&c, &1u32).expect("cache_remove must succeed"),
            None
        );
        let removed =
            SyncConcurrentCached::cache_remove_entry(&c, &2u32).expect("key must be present");
        assert!(
            removed.is_some(),
            "cache_remove_entry must return Some for expired entry"
        );
        assert_eq!(removed.expect("must be Some"), (2u32, 200u32));
    }

    // `cache_delete` reports `true` for an expired-but-present entry, then `false`.
    #[test]
    fn cache_delete_returns_true_for_expired_entry() {
        let c = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_millis(50))
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 100u32).expect("insert must succeed");
        std::thread::sleep(std::time::Duration::from_millis(100));
        assert!(
            SyncConcurrentCached::cache_delete(&c, &1u32).expect("cache_delete must succeed"),
            "cache_delete must be true for expired entry"
        );
        assert!(!SyncConcurrentCached::cache_delete(&c, &1u32).expect("cache_delete must succeed"));
    }

    // The callback fires when an expired entry is removed, and not for an absent key.
    #[test]
    fn cache_remove_entry_fires_on_evict_for_expired() {
        use std::sync::atomic::{AtomicU64, Ordering};
        let count = Arc::new(AtomicU64::new(0));
        let count2 = count.clone();
        let c = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_millis(50))
            .shards(1)
            .on_evict(move |_, _| {
                count2.fetch_add(1, Ordering::Relaxed);
            })
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 10u32).expect("insert must succeed");
        std::thread::sleep(std::time::Duration::from_millis(100));
        SyncConcurrentCached::cache_remove_entry(&c, &1u32).expect("key must be present");
        assert_eq!(
            count.load(Ordering::Relaxed),
            1,
            "on_evict fires for expired entries"
        );
        SyncConcurrentCached::cache_remove_entry(&c, &999u32)
            .expect("cache_remove_entry must succeed");
        assert_eq!(count.load(Ordering::Relaxed), 1, "no fire for absent key");
    }

    // Only the removal of a present key changes the eviction counter.
    #[test]
    fn cache_remove_entry_increments_eviction_counter() {
        let c = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_millis(10))
            .shards(1)
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 10u32).expect("insert must succeed");
        std::thread::sleep(std::time::Duration::from_millis(100));
        let before = c.metrics().evictions.expect("evictions are always tracked");
        SyncConcurrentCached::cache_remove_entry(&c, &1u32).expect("key must be present");
        SyncConcurrentCached::cache_remove_entry(&c, &999u32)
            .expect("cache_remove_entry must succeed");
        assert_eq!(
            c.metrics().evictions.expect("evictions are always tracked") - before,
            1,
            "cache_remove_entry must increment evictions for present key only"
        );
    }

    // Expiry-status lookup of an absent key: `(None, false)` and one miss.
    #[test]
    fn concurrent_clone_cached_absent_is_none_false() {
        let c = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_secs(60))
            .build()
            .unwrap();
        assert_eq!(
            ConcurrentCloneCached::cache_get_with_expiry_status(&c, &1u32),
            (None, false),
            "absent key must return (None, false)"
        );
        assert_eq!(
            c.metrics().misses,
            Some(1),
            "absent lookup must increment misses"
        );
    }

    // Expiry-status lookup of a live key: `(Some(v), false)`, one hit, no eviction.
    #[test]
    fn concurrent_clone_cached_live_entry_is_some_false() {
        let c = ShardedLruTtlCache::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_secs(60))
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 42u32).expect("insert must succeed");
        assert_eq!(
            ConcurrentCloneCached::cache_get_with_expiry_status(&c, &1u32),
            (Some(42), false),
            "live entry must return (Some(v), false)"
        );
        assert_eq!(c.metrics().hits, Some(1), "live lookup must increment hits");
        assert_eq!(
            c.metrics().evictions,
            Some(0),
            "live lookup must not increment evictions"
        );
    }

    // Expiry-status lookup of an expired key returns the stale value, counts a
    // miss and leaves the entry in place.
    #[test]
    fn concurrent_clone_cached_expired_returns_stale_no_eviction() {
        let c = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(64)
            .ttl(Duration::from_millis(50))
            .shards(1)
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 99u32).expect("insert must succeed");
        std::thread::sleep(std::time::Duration::from_millis(100));
        let (val, expired) = ConcurrentCloneCached::cache_get_with_expiry_status(&c, &1u32);
        assert_eq!(val, Some(99), "expired entry must return the stale value");
        assert!(expired, "expired entry must set the expired flag");
        assert_eq!(
            c.metrics().misses,
            Some(1),
            "expired lookup must increment misses"
        );
        assert_eq!(
            c.metrics().evictions,
            Some(0),
            "expired lookup must NOT increment evictions"
        );
        let (val2, expired2) = ConcurrentCloneCached::cache_get_with_expiry_status(&c, &1u32);
        assert_eq!(
            val2,
            Some(99),
            "entry must still be present after expiry-status lookup"
        );
        assert!(
            expired2,
            "entry must still be expired on second expiry-status call"
        );
    }

    // A live expiry-status lookup promotes the entry, so the other key is evicted next.
    #[test]
    fn concurrent_clone_cached_live_lookup_promotes_lru() {
        let c = ShardedLruTtlCacheBase::<u32, u32>::builder()
            .max_size(2)
            .ttl(Duration::from_secs(60))
            .shards(1)
            .build()
            .unwrap();
        SyncConcurrentCached::cache_set(&c, 1u32, 10u32).expect("insert must succeed");
        SyncConcurrentCached::cache_set(&c, 2u32, 20u32).expect("insert must succeed");
        assert_eq!(
            ConcurrentCloneCached::cache_get_with_expiry_status(&c, &1u32),
            (Some(10), false),
            "live lookup must return the value"
        );
        SyncConcurrentCached::cache_set(&c, 3u32, 30u32).expect("insert must succeed");
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &1u32).expect("cache_get must succeed"),
            Some(10),
            "key 1 must survive eviction because the live expiry-status lookup promoted it"
        );
        assert_eq!(
            SyncConcurrentCached::cache_get(&c, &2u32).expect("cache_get must succeed"),
            None,
            "key 2 must be evicted as the least-recently-used entry"
        );
    }
}