use core::hash::Hash;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::time::{Duration, Instant};
use crate::cache::Cache;
use crate::error::CacheError;
use crate::sharding::{self, Sharded};
use crate::util::MutexExt;
/// About one hundred 365-day years.
///
/// `compute_deadline` uses this as its fallback horizon when `now + ttl` cannot be
/// represented as an `Instant`.
const FAR_FUTURE: Duration = Duration::from_secs(100 * 365 * 24 * 60 * 60);
/// A capacity-bounded cache whose entries expire after a time-to-live.
///
/// Keys are routed to shards (`shard_for`), and each shard is locked on its own
/// (`lock_recover`). NOTE(review): presumably each shard is a `Mutex<Inner>`; confirm
/// against the `sharding` module.
pub struct TtlCache<K, V> {
    // Total capacity requested at construction; this is what `capacity()` reports.
    capacity: NonZeroUsize,
    // TTL applied by `insert`; `insert_with_ttl` supplies its own TTL instead.
    default_ttl: Duration,
    // Per-shard state; the shard count and per-shard capacity come from `sharding`.
    sharded: Sharded<Inner<K, V>>,
}
/// A cached value paired with the instant at which it stops being visible.
struct Entry<V> {
    // The stored value.
    value: V,
    // Deadline: readers treat the entry as expired once `expires_at <= now`.
    expires_at: Instant,
}
/// The state held by a single shard.
struct Inner<K, V> {
    // Maximum number of entries this shard holds before an insert evicts one.
    capacity: NonZeroUsize,
    // Entries for this shard; may still contain expired entries that have not yet
    // been purged (expiry is applied lazily on access).
    map: HashMap<K, Entry<V>>,
}
impl<K, V> Inner<K, V>
where
    K: Eq + Hash + Clone,
{
    /// Creates an empty shard that will hold at most `capacity` entries.
    ///
    /// The backing map is pre-sized for the full shard capacity up front, so filling
    /// the shard does not trigger rehashing.
    fn with_capacity(capacity: NonZeroUsize) -> Self {
        let map = HashMap::with_capacity(capacity.get());
        Self { capacity, map }
    }
}
impl<K, V> TtlCache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Creates a cache holding at most `capacity` entries.
    ///
    /// Each entry expires `ttl` after insertion unless it is inserted with an explicit TTL.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::InvalidCapacity` when `capacity` is zero.
    pub fn new(capacity: usize, ttl: Duration) -> Result<Self, CacheError> {
        match NonZeroUsize::new(capacity) {
            Some(cap) => Ok(Self::with_capacity(cap, ttl)),
            None => Err(CacheError::InvalidCapacity),
        }
    }

    /// Creates a cache from an already-validated, non-zero capacity.
    ///
    /// The shard count and per-shard capacity are delegated to the `sharding` module.
    pub fn with_capacity(capacity: NonZeroUsize, ttl: Duration) -> Self {
        let shard_total = sharding::shard_count(capacity);
        let slots = sharding::per_shard_capacity(capacity, shard_total);
        Self {
            capacity,
            default_ttl: ttl,
            sharded: Sharded::from_factory(shard_total, |_| Inner::with_capacity(slots)),
        }
    }

    /// Inserts `value` under `key` using a per-entry `ttl` instead of the default.
    ///
    /// Returns the previous value only when a still-live entry was replaced.
    pub fn insert_with_ttl(&self, key: K, value: V, ttl: Duration) -> Option<V> {
        let expires_at = compute_deadline(ttl);
        let mut shard = self.sharded.shard_for(&key).lock_recover();
        Self::insert_inner(&mut shard, key, value, expires_at)
    }

    /// Shared insertion path; callers pass in the already-locked shard.
    ///
    /// If `key` maps to a still-live entry, that entry's value and deadline are
    /// replaced in place and the old value is returned.
    ///
    /// Otherwise:
    /// - any expired entry for `key` is dropped;
    /// - if the shard is full, the entry with the earliest deadline is evicted;
    /// - the new entry is stored and `None` is returned.
    fn insert_inner(inner: &mut Inner<K, V>, key: K, value: V, deadline: Instant) -> Option<V> {
        let now = Instant::now();
        if let Some(live) = inner.map.get_mut(&key).filter(|e| e.expires_at > now) {
            live.expires_at = deadline;
            return Some(std::mem::replace(&mut live.value, value));
        }

        // Absent or expired: treat as a fresh insert. Drop any stale slot first so it
        // is not counted against the shard's capacity.
        let _ = inner.map.remove(&key);

        let at_capacity = inner.map.len() >= inner.capacity.get();
        if at_capacity {
            if let Some(victim) = find_victim(&inner.map) {
                let _ = inner.map.remove(&victim);
            }
        }

        let fresh = Entry {
            value,
            expires_at: deadline,
        };
        let _ = inner.map.insert(key, fresh);
        None
    }
}
impl<K, V> Cache<K, V> for TtlCache<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Returns a clone of the value for `key` if the entry is present and not expired.
    ///
    /// An expired entry found here is removed immediately (lazy expiry).
    fn get(&self, key: &K) -> Option<V> {
        let mut inner = self.sharded.shard_for(key).lock_recover();
        let now = Instant::now();
        // Single lookup: the live case returns an owned clone, so the shared borrow
        // ends before the mutable `remove` below.
        match inner.map.get(key) {
            None => return None,
            Some(entry) if entry.expires_at > now => return Some(entry.value.clone()),
            Some(_) => {}
        }
        // Expired: evict now so it stops occupying a slot in this shard.
        let _ = inner.map.remove(key);
        None
    }

    /// Inserts `value` under `key` using the cache's default TTL.
    ///
    /// Returns the previous value only when a still-live entry was replaced.
    fn insert(&self, key: K, value: V) -> Option<V> {
        let deadline = compute_deadline(self.default_ttl);
        let mut inner = self.sharded.shard_for(&key).lock_recover();
        Self::insert_inner(&mut inner, key, value, deadline)
    }

    /// Removes `key`, returning its value only if the entry had not yet expired.
    ///
    /// The slot is always cleared. An expired entry yields `None`, which is how `get`,
    /// `contains_key`, and `insert` also treat expired entries (as absent).
    fn remove(&self, key: &K) -> Option<V> {
        let mut inner = self.sharded.shard_for(key).lock_recover();
        let now = Instant::now();
        inner
            .map
            .remove(key)
            .filter(|entry| entry.expires_at > now)
            .map(|entry| entry.value)
    }

    /// Reports whether `key` has a live entry; an expired entry is removed and
    /// reported as absent.
    fn contains_key(&self, key: &K) -> bool {
        let mut inner = self.sharded.shard_for(key).lock_recover();
        let now = Instant::now();
        match inner.map.get(key) {
            None => return false,
            Some(entry) if entry.expires_at > now => return true,
            Some(_) => {}
        }
        let _ = inner.map.remove(key);
        false
    }

    /// Counts live entries across all shards, purging expired ones along the way.
    ///
    /// Shards are locked one at a time, so the total is not an atomic snapshot of
    /// the whole cache.
    fn len(&self) -> usize {
        let mut total = 0;
        for mutex in self.sharded.iter() {
            let mut inner = mutex.lock_recover();
            purge_expired(&mut inner.map);
            total += inner.map.len();
        }
        total
    }

    /// Removes every entry from every shard.
    fn clear(&self) {
        for mutex in self.sharded.iter() {
            let mut inner = mutex.lock_recover();
            inner.map.clear();
        }
    }

    /// Returns the total capacity the cache was constructed with.
    fn capacity(&self) -> usize {
        self.capacity.get()
    }
}
/// Returns the instant `ttl` from now, without ever overflowing `Instant`.
///
/// If `now + ttl` cannot be represented on this platform, the span is capped at
/// `FAR_FUTURE` and then halved until the addition fits. A huge TTL therefore means
/// "expires as late as can be represented" rather than "already expired". The result
/// falls back to `now` itself only if not even a nanosecond can be added to `now`.
fn compute_deadline(ttl: Duration) -> Instant {
    let now = Instant::now();
    if let Some(deadline) = now.checked_add(ttl) {
        return deadline;
    }
    // `now + ttl` overflowed, so `ttl` is non-zero; back off until the addition fits.
    // The previous code fell straight back to `now` when `now + FAR_FUTURE` also
    // overflowed, which made the entry expire immediately.
    let mut span = ttl.min(FAR_FUTURE);
    while !span.is_zero() {
        if let Some(deadline) = now.checked_add(span) {
            return deadline;
        }
        span /= 2;
    }
    now
}
/// Chooses the eviction candidate: the key whose entry has the earliest deadline.
///
/// Expired entries have the earliest deadlines, so they are picked first. Returns
/// `None` for an empty map. On ties, the first candidate in the map's iteration
/// order wins.
fn find_victim<K, V>(map: &HashMap<K, Entry<V>>) -> Option<K>
where
    K: Clone,
{
    let (victim, _) = map
        .iter()
        .min_by(|(_, lhs), (_, rhs)| lhs.expires_at.cmp(&rhs.expires_at))?;
    Some(victim.clone())
}
/// Drops every entry whose deadline has already passed (`expires_at <= now`).
fn purge_expired<K, V>(map: &mut HashMap<K, Entry<V>>) {
    let cutoff = Instant::now();
    map.retain(|_key, entry| cutoff < entry.expires_at);
}