use crate::{Cached, CachedIter, CachedPeek, CachedRead};
use std::cmp::Eq;
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::hash::Hash;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
/// Default `BuildHasher` type, chosen at compile time.
///
/// With the `ahash` feature enabled this is `ahash::RandomState`.
#[cfg(feature = "ahash")]
pub type DefaultHashBuilder = ahash::RandomState;
/// Default `BuildHasher` type, chosen at compile time.
///
/// Without the `ahash` feature this is the standard library's `RandomState`.
#[cfg(not(feature = "ahash"))]
pub type DefaultHashBuilder = std::collections::hash_map::RandomState;
#[inline]
pub(super) fn new_default_hash_builder() -> DefaultHashBuilder {
#[cfg(feature = "ahash")]
{
ahash::RandomState::new()
}
#[cfg(not(feature = "ahash"))]
{
std::collections::hash_map::RandomState::new()
}
}
/// Number of independent slots in a `StripedCounter`.
///
/// Threads are mapped onto slots round-robin by `thread_stripe`.
const STRIPE_COUNT: usize = 16;
/// One counter slot.
///
/// The 128-byte alignment presumably keeps each slot on its own cache line(s), so that
/// threads incrementing different slots do not contend (avoids false sharing).
#[repr(align(128))]
struct Slot(AtomicU64);
/// A `u64` counter split across `STRIPE_COUNT` atomic slots.
///
/// Each thread increments its own slot, and readers sum all slots.
pub(super) struct StripedCounter {
    // Always exactly `STRIPE_COUNT` slots (see `StripedCounter::new`).
    slots: Box<[Slot]>,
}
impl StripedCounter {
pub(super) fn new() -> Self {
let slots = (0..STRIPE_COUNT)
.map(|_| Slot(AtomicU64::new(0)))
.collect::<Vec<_>>()
.into_boxed_slice();
Self { slots }
}
#[inline]
pub(super) fn increment(&self) {
self.slots[thread_stripe()]
.0
.fetch_add(1, Ordering::Relaxed);
}
pub(super) fn load(&self) -> u64 {
self.slots.iter().map(|s| s.0.load(Ordering::Relaxed)).sum()
}
pub(super) fn reset(&self) {
for slot in self.slots.iter() {
slot.0.store(0, Ordering::Relaxed);
}
}
pub(super) fn snapshot(&self) -> Self {
let total = self.load();
let new = Self::new();
new.slots[0].0.store(total, Ordering::Relaxed);
new
}
}
/// Returns the index of the `StripedCounter` slot owned by the calling thread.
///
/// The first time a thread calls this, it takes the next value of a process-wide counter.
/// That value is reduced modulo `STRIPE_COUNT` and cached in thread-local storage. Stripes
/// are therefore handed out round-robin and stay fixed for the thread's lifetime.
#[inline]
fn thread_stripe() -> usize {
    // Process-wide: a `static` is shared by all threads even when declared in a fn body.
    static NEXT_STRIPE: AtomicUsize = AtomicUsize::new(0);

    thread_local! {
        static STRIPE: usize = NEXT_STRIPE.fetch_add(1, Ordering::Relaxed) % STRIPE_COUNT;
    }

    STRIPE.with(|stripe| *stripe)
}
#[cfg(feature = "async_core")]
use {super::CachedAsync, std::future::Future};
mod expiring;
mod expiring_lru;
mod lru;
#[cfg(feature = "time_stores")]
mod lru_ttl;
#[cfg(feature = "redb_store")]
mod redb;
#[cfg(feature = "redis_store")]
mod redis;
pub mod sharded;
#[cfg(feature = "time_stores")]
mod ttl;
#[cfg(feature = "time_stores")]
mod ttl_sorted;
mod unbound;
#[cfg(any(
feature = "time_stores",
feature = "redb_store",
feature = "redis_store"
))]
use crate::time::Duration;
#[cfg(feature = "time_stores")]
use crate::time::Instant;
/// Shared, thread-safe (`Send + Sync`) callback that receives a key and a value by reference.
///
/// NOTE(review): judging by the name, stores invoke it when an entry is evicted — confirm at
/// the call sites.
pub(super) type OnEvict<K, V> = std::sync::Arc<dyn Fn(&K, &V) + Send + Sync>;
/// Error produced when a builder's configuration is incomplete or invalid.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BuildError {
    /// A required field was never set; carries the field's name.
    MissingRequired(&'static str),
    /// A field was set to a value that was rejected.
    InvalidValue {
        /// Name of the offending field.
        field: &'static str,
        /// Human-readable reason the value was rejected.
        reason: &'static str,
    },
}

impl std::fmt::Display for BuildError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both payloads are `&'static str` (Copy), so matching on `*self` binds them by value.
        match *self {
            Self::MissingRequired(field) => write!(out, "required field `{field}` was not set"),
            Self::InvalidValue { field, reason } => {
                write!(out, "invalid value for field `{field}`: {reason}")
            }
        }
    }
}

impl std::error::Error for BuildError {}
/// Error returned when a maximum size update is rejected.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SetMaxSizeError {
    /// The requested maximum size was zero.
    ZeroSize,
}

impl std::fmt::Display for SetMaxSizeError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ZeroSize => "max_size must be greater than zero",
        };
        out.write_str(message)
    }
}

impl std::error::Error for SetMaxSizeError {}
/// Error returned when a TTL update is rejected.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SetTtlError {
    /// The requested TTL was zero.
    ZeroTtl,
}

impl std::fmt::Display for SetTtlError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ZeroTtl => "ttl must be greater than zero",
        };
        out.write_str(message)
    }
}

impl std::error::Error for SetTtlError {}
/// Error returned when a value cannot be stored in a cache.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CacheSetError {
    /// Per its message: the TTL places the entry outside the range `Instant` can represent.
    TimeBounds,
}

impl std::fmt::Display for CacheSetError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::TimeBounds => "ttl is outside Instant bounds",
        };
        out.write_str(message)
    }
}

impl std::error::Error for CacheSetError {}
/// Checks that `ttl` is non-zero.
///
/// # Errors
///
/// Returns `BuildError::InvalidValue` with field `"ttl"` when `ttl` is zero.
#[cfg(any(
    feature = "time_stores",
    feature = "redb_store",
    feature = "redis_store"
))]
pub(crate) fn validate_ttl(ttl: Duration) -> Result<(), BuildError> {
    if !ttl.is_zero() {
        return Ok(());
    }
    Err(BuildError::InvalidValue {
        field: "ttl",
        reason: "must be greater than zero",
    })
}
/// A cached value paired with an optional expiry instant.
///
/// `Clone` is derived, which yields the same `V: Clone` bound the hand-written impl had.
/// The expiry (`Option<Instant>`) is simply copied.
#[cfg(feature = "time_stores")]
#[derive(Debug, Clone)]
pub(crate) struct TimedEntry<V> {
    /// Expiry deadline.
    ///
    /// NOTE(review): `None` presumably means "never expires" — confirm against the stores.
    pub(crate) expires_at: Option<Instant>,
    /// The cached value itself.
    pub(crate) value: V,
}
#[cfg(feature = "redb_store")]
#[cfg_attr(docsrs, doc(cfg(feature = "redb_store")))]
pub use crate::stores::redb::{RedbCache, RedbCacheBuildError, RedbCacheBuilder, RedbCacheError};
#[cfg(feature = "redis_store")]
#[cfg_attr(docsrs, doc(cfg(feature = "redis_store")))]
pub use crate::stores::redis::{
ConnectionString, RedisCache, RedisCacheBuildError, RedisCacheBuilder, RedisCacheError,
};
pub use expiring::{ExpiringCache, ExpiringCacheBuilder};
pub use expiring_lru::{Expires, ExpiringLruCache, ExpiringLruCacheBuilder};
pub use lru::{LruCache, LruCacheBuilder};
#[cfg(feature = "time_stores")]
#[cfg_attr(docsrs, doc(cfg(feature = "time_stores")))]
pub use lru_ttl::{HasEvict, LruTtlCache, LruTtlCacheBuilder, NoEvict};
#[cfg(feature = "time_stores")]
#[cfg_attr(docsrs, doc(cfg(feature = "time_stores")))]
pub use ttl::{TtlCache, TtlCacheBuilder};
#[cfg(feature = "time_stores")]
#[cfg_attr(docsrs, doc(cfg(feature = "time_stores")))]
pub use ttl_sorted::{TtlSortedCache, TtlSortedCacheBuilder};
pub use unbound::{UnboundCache, UnboundCacheBuilder};
pub use sharded::{
DefaultShardHasher, ShardHasher, ShardedExpiringCache, ShardedExpiringCacheBase,
ShardedExpiringCacheBuilder, ShardedExpiringLruCache, ShardedExpiringLruCacheBase,
ShardedExpiringLruCacheBuilder, ShardedLruCache, ShardedLruCacheBase, ShardedLruCacheBuilder,
ShardedUnboundCache, ShardedUnboundCacheBase, ShardedUnboundCacheBuilder,
};
#[cfg(feature = "time_stores")]
#[cfg_attr(docsrs, doc(cfg(feature = "time_stores")))]
pub use sharded::{
ShardedLruTtlCache, ShardedLruTtlCacheBase, ShardedLruTtlCacheBuilder, ShardedTtlCache,
ShardedTtlCacheBase, ShardedTtlCacheBuilder,
};
#[cfg(all(
feature = "async_core",
feature = "redis_store",
any(
feature = "redis_smol",
feature = "redis_smol_native_tls",
feature = "redis_smol_rustls",
feature = "redis_tokio",
feature = "redis_tokio_native_tls",
feature = "redis_tokio_rustls",
feature = "redis_async_cache",
feature = "redis_connection_manager"
)
))]
#[cfg_attr(
docsrs,
doc(cfg(all(
feature = "async_core",
feature = "redis_store",
any(
feature = "redis_smol",
feature = "redis_smol_native_tls",
feature = "redis_smol_rustls",
feature = "redis_tokio",
feature = "redis_tokio_native_tls",
feature = "redis_tokio_rustls",
feature = "redis_async_cache",
feature = "redis_connection_manager"
)
)))
)]
pub use crate::stores::redis::{AsyncRedisCache, AsyncRedisCacheBuilder};
/// Lets a plain `HashMap` be used directly as an unbounded cache.
///
/// Every method delegates to the matching `HashMap` operation. Nothing is ever evicted,
/// and this impl keeps no hit/miss counters of its own. No operation can fail, hence
/// `Error = Infallible`.
impl<K, V, S> Cached<K, V> for HashMap<K, V, S>
where
    K: Hash + Eq,
    S: std::hash::BuildHasher + Default,
{
    type Error = std::convert::Infallible;

    fn cache_get<Q>(&mut self, key: &Q) -> Option<&V>
    where
        K: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq + ?Sized,
    {
        HashMap::get(self, key)
    }

    fn cache_get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
    where
        K: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq + ?Sized,
    {
        HashMap::get_mut(self, key)
    }

    /// Inserts `value` under `key`, returning the previous value, if any.
    fn cache_set(&mut self, key: K, value: V) -> Option<V> {
        HashMap::insert(self, key, value)
    }

    /// Returns the value for `key`, inserting `f()` first when the key is absent.
    fn cache_get_or_set_with_mut<F: FnOnce() -> V>(&mut self, key: K, f: F) -> &mut V {
        HashMap::entry(self, key).or_insert_with(f)
    }

    /// Like `cache_get_or_set_with_mut`, but `f` is fallible.
    ///
    /// `f` only runs for a vacant key. If it fails, nothing is inserted and the error is
    /// returned.
    fn cache_try_get_or_set_with_mut<F: FnOnce() -> Result<V, E>, E>(
        &mut self,
        key: K,
        f: F,
    ) -> Result<&mut V, E> {
        match HashMap::entry(self, key) {
            Entry::Occupied(slot) => Ok(slot.into_mut()),
            Entry::Vacant(slot) => f().map(|value| slot.insert(value)),
        }
    }

    fn cache_remove<Q>(&mut self, key: &Q) -> Option<V>
    where
        K: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq + ?Sized,
    {
        HashMap::remove(self, key)
    }

    fn cache_remove_entry<Q>(&mut self, key: &Q) -> Option<(K, V)>
    where
        K: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq + ?Sized,
    {
        HashMap::remove_entry(self, key)
    }

    /// Removes all entries; the map keeps its allocated capacity.
    fn cache_clear(&mut self) {
        HashMap::clear(self);
    }

    /// Replaces the map with a brand-new default one, then resets metrics.
    ///
    /// Unlike `cache_clear`, this drops the old map's storage and rebuilds the hasher from
    /// `S::default()`.
    fn cache_reset(&mut self) {
        *self = Default::default();
        self.cache_reset_metrics();
    }

    fn cache_size(&self) -> usize {
        HashMap::len(self)
    }
}
/// Iterates a plain `HashMap` cache as `(key, value)` pairs, in `HashMap`'s iteration order.
impl<K, V, S> CachedIter<K, V> for HashMap<K, V, S>
where
    K: Hash + Eq,
    S: std::hash::BuildHasher,
{
    fn iter<'a>(&'a self) -> impl Iterator<Item = (&'a K, &'a V)> + 'a
    where
        K: 'a,
        V: 'a,
    {
        // The fully-qualified path makes it explicit that this delegates to the inherent
        // `HashMap::iter`, not recursively to `CachedIter::iter`.
        HashMap::iter(self)
    }
}
/// Read-only lookup on a plain `HashMap` cache through a shared reference.
impl<K, V, S> CachedPeek<K, V> for HashMap<K, V, S>
where
    K: Hash + Eq,
    S: std::hash::BuildHasher,
{
    fn cache_peek<Q>(&self, key: &Q) -> Option<&V>
    where
        K: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq + ?Sized,
    {
        // Resolves to the inherent `HashMap::get`; inherent methods win over trait methods.
        self.get(key)
    }
}
/// `CachedRead` for a plain `HashMap` cache.
///
/// No items are overridden here; the impl relies entirely on what the trait itself provides.
impl<K, V, S> CachedRead<K, V> for HashMap<K, V, S>
where
    K: Hash + Eq,
    S: std::hash::BuildHasher,
{
}
/// Async get-or-insert for a plain `HashMap` cache.
///
/// In both methods the map entry, and with it the `&mut self` borrow, is held across the
/// `.await` on the producer future. The producer is only called when the key is vacant.
#[cfg(feature = "async_core")]
impl<K, V, S> CachedAsync<K, V> for HashMap<K, V, S>
where
    K: Hash + Eq + Clone + Send,
    S: std::hash::BuildHasher + Send,
{
    /// Returns the existing value for `k`, or awaits `f()` and inserts its output.
    fn async_cache_get_or_set_with_mut<'a, F, Fut>(
        &'a mut self,
        k: K,
        f: F,
    ) -> impl Future<Output = &'a mut V> + Send + 'a
    where
        K: 'a,
        V: Send + 'a,
        F: FnOnce() -> Fut + Send + 'a,
        Fut: Future<Output = V> + Send + 'a,
    {
        async move {
            match self.entry(k) {
                Entry::Occupied(o) => o.into_mut(),
                Entry::Vacant(v) => v.insert(f().await),
            }
        }
    }

    /// Like `async_cache_get_or_set_with_mut`, but the producer is fallible.
    ///
    /// If the awaited result is `Err`, `?` returns it before anything is inserted.
    fn async_cache_try_get_or_set_with_mut<'a, F, Fut, E>(
        &'a mut self,
        k: K,
        f: F,
    ) -> impl Future<Output = Result<&'a mut V, E>> + Send + 'a
    where
        K: 'a,
        V: Send + 'a,
        E: 'a,
        F: FnOnce() -> Fut + Send + 'a,
        Fut: Future<Output = Result<V, E>> + Send + 'a,
    {
        async move {
            let v = match self.entry(k) {
                Entry::Occupied(o) => o.into_mut(),
                Entry::Vacant(v) => v.insert(f().await?),
            };
            Ok(v)
        }
    }
}
/// Explicit eviction for caches accessed through `&mut self`.
pub trait CacheEvict {
    /// Runs an eviction pass.
    ///
    /// NOTE(review): the returned `usize` is presumably the number of entries removed —
    /// confirm against implementors. `#[must_use]` makes ignoring it a compiler warning.
    #[must_use]
    fn evict(&mut self) -> usize;
}
/// Explicit eviction for caches shared behind `&self`.
///
/// NOTE(review): taking `&self` suggests implementors use interior mutability (e.g. locks)
/// — confirm against implementors.
pub trait ConcurrentCacheEvict {
    /// Runs an eviction pass.
    ///
    /// NOTE(review): the returned `usize` is presumably the number of entries removed —
    /// confirm against implementors. `#[must_use]` makes ignoring it a compiler warning.
    #[must_use]
    fn evict(&self) -> usize;
}
#[cfg(test)]
mod tests {
    use super::*;

    // A bare `HashMap` works through the `Cached` API and reports no hit/miss metrics.
    #[test]
    fn hashmap() {
        let mut c = std::collections::HashMap::new();
        assert!(c.cache_get(&1).is_none());
        assert_eq!(c.cache_misses(), None);
        assert_eq!(c.cache_set(1, 100), None);
        assert_eq!(c.cache_get(&1), Some(&100));
        assert_eq!(c.cache_hits(), None);
        assert_eq!(c.cache_misses(), None);
    }

    // Pins the exact `Display` text of both `BuildError` variants.
    #[test]
    fn build_error_display() {
        let err1 = BuildError::MissingRequired("ttl");
        assert_eq!(err1.to_string(), "required field `ttl` was not set");
        let err2 = BuildError::InvalidValue {
            field: "max_size",
            reason: "must be greater than zero",
        };
        assert_eq!(
            err2.to_string(),
            "invalid value for field `max_size`: must be greater than zero"
        );
    }

    #[test]
    fn cache_set_error_is_clone_eq() {
        assert_eq!(CacheSetError::TimeBounds, CacheSetError::TimeBounds.clone());
    }

    // Compile-time guard: this stops building if `BuildError` loses `Clone`, `PartialEq`
    // or `Eq`. It is never called.
    fn _assert_build_error_bounds<T: Clone + PartialEq + Eq>() {}
    fn _check_build_error() {
        _assert_build_error_bounds::<BuildError>();
    }

    // Clones compare equal, and different variants or payloads compare unequal.
    #[test]
    fn build_error_clone_partial_eq_eq() {
        let a = BuildError::MissingRequired("ttl");
        let b = a.clone();
        assert_eq!(a, b);
        let c = BuildError::InvalidValue {
            field: "max_size",
            reason: "must be greater than zero",
        };
        let d = c.clone();
        assert_eq!(c, d);
        assert_ne!(
            BuildError::MissingRequired("ttl"),
            BuildError::InvalidValue {
                field: "ttl",
                reason: "must be greater than zero",
            },
        );
        assert_ne!(
            BuildError::MissingRequired("ttl"),
            BuildError::MissingRequired("max_size"),
        );
    }

    // Equality on `InvalidValue` considers both `field` and `reason`.
    #[test]
    fn build_error_invalid_value_field_discriminates() {
        assert_ne!(
            BuildError::InvalidValue {
                field: "max_size",
                reason: "must be greater than zero",
            },
            BuildError::InvalidValue {
                field: "ttl",
                reason: "must be greater than zero",
            },
        );
        assert_ne!(
            BuildError::InvalidValue {
                field: "max_size",
                reason: "must be greater than zero",
            },
            BuildError::InvalidValue {
                field: "max_size",
                reason: "allocation failed",
            },
        );
        let a = BuildError::InvalidValue {
            field: "max_size",
            reason: "must be greater than zero",
        };
        assert_eq!(a, a.clone());
    }

    // `LruCache`'s builder reports a missing `max_size` and a zero `max_size` as distinct
    // `BuildError`s.
    #[test]
    fn lru_build_error_is_comparable_and_cloneable() {
        let missing = LruCache::<i32, i32>::builder().build().unwrap_err();
        assert_eq!(missing, BuildError::MissingRequired("max_size"));
        assert_eq!(missing, missing.clone());
        let invalid = LruCache::<i32, i32>::builder()
            .max_size(0)
            .build()
            .unwrap_err();
        assert_eq!(
            invalid,
            BuildError::InvalidValue {
                field: "max_size",
                reason: "must be greater than zero",
            }
        );
        assert_ne!(missing, invalid);
    }

    // The Redis builder error wraps `BuildError`, so it is checked with `matches!` rather
    // than `assert_eq!`. Its `Display` text matches the wrapped error's text.
    #[cfg(feature = "redis_store")]
    #[test]
    fn redis_build_error_wraps_build_error_without_clone_eq() {
        let err = RedisCacheBuilder::<String, u32>::new().build().unwrap_err();
        assert!(
            matches!(
                err,
                RedisCacheBuildError::Build(BuildError::MissingRequired("prefix"))
            ),
            "expected Build(MissingRequired(\"prefix\")), got {err:?}"
        );
        assert!(!format!("{err:?}").is_empty());
        assert_eq!(
            err.to_string(),
            BuildError::MissingRequired("prefix").to_string()
        );
    }

    // Same checks as the Redis test, applied to the redb builder's required `name`.
    #[cfg(feature = "redb_store")]
    #[test]
    fn redb_build_error_wraps_build_error_without_clone_eq() {
        let err = RedbCacheBuilder::<String, u32>::new().build().unwrap_err();
        assert!(
            matches!(
                err,
                RedbCacheBuildError::Build(BuildError::MissingRequired("name"))
            ),
            "expected Build(MissingRequired(\"name\")), got {err:?}"
        );
        assert!(!format!("{err:?}").is_empty());
        assert_eq!(
            err.to_string(),
            BuildError::MissingRequired("name").to_string()
        );
    }
}