use std::collections::HashMap;
use std::hash::Hash;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use tokio::time::Instant;
pub trait Clock: Send + Sync + 'static {
fn now(&self) -> Instant;
}
#[derive(Debug, Clone, Copy, Default)]
pub struct TokioClock;
impl Clock for TokioClock {
fn now(&self) -> Instant {
Instant::now()
}
}
#[derive(Debug, Clone)]
struct Entry<V> {
value: V,
expires_at: Instant,
recency: u64,
}
pub struct TtlCache<K, V> {
inner: Mutex<CacheInner<K, V>>,
capacity: NonZeroUsize,
clock: Arc<dyn Clock>,
}
struct CacheInner<K, V> {
entries: HashMap<K, Entry<V>>,
counter: u64,
}
impl<K, V> TtlCache<K, V>
where
K: Eq + Hash + Clone,
V: Clone,
{
pub fn new(capacity: NonZeroUsize) -> Self {
Self::with_clock(capacity, Arc::new(TokioClock))
}
pub fn with_clock(capacity: NonZeroUsize, clock: Arc<dyn Clock>) -> Self {
Self {
inner: Mutex::new(CacheInner {
entries: HashMap::with_capacity(capacity.get()),
counter: 0,
}),
capacity,
clock,
}
}
pub fn capacity(&self) -> usize {
self.capacity.get()
}
pub fn len(&self) -> usize {
let Ok(inner) = self.inner.lock() else {
return 0;
};
inner.entries.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn get(&self, key: &K) -> Option<V> {
let now = self.clock.now();
let Ok(mut inner) = self.inner.lock() else {
return None;
};
let expired = inner
.entries
.get(key)
.map(|entry| entry.expires_at <= now)
.unwrap_or(false);
if expired {
inner.entries.remove(key);
return None;
}
inner.counter = inner.counter.saturating_add(1);
let counter = inner.counter;
let entry = inner.entries.get_mut(key)?;
entry.recency = counter;
Some(entry.value.clone())
}
pub fn insert(&self, key: K, value: V, ttl: Duration) {
let now = self.clock.now();
let expires_at = now.checked_add(ttl).unwrap_or(now);
let Ok(mut inner) = self.inner.lock() else {
return;
};
inner.counter = inner.counter.saturating_add(1);
let recency = inner.counter;
if let Some(entry) = inner.entries.get_mut(&key) {
entry.value = value;
entry.expires_at = expires_at;
entry.recency = recency;
return;
}
if inner.entries.len() >= self.capacity.get() {
evict_expired(&mut inner.entries, now);
}
if inner.entries.len() >= self.capacity.get() {
evict_lru(&mut inner.entries);
}
inner.entries.insert(
key,
Entry {
value,
expires_at,
recency,
},
);
}
pub fn prune(&self) -> usize {
let now = self.clock.now();
let Ok(mut inner) = self.inner.lock() else {
return 0;
};
evict_expired(&mut inner.entries, now)
}
pub fn clear(&self) {
if let Ok(mut inner) = self.inner.lock() {
inner.entries.clear();
}
}
}
fn evict_expired<K, V>(entries: &mut HashMap<K, Entry<V>>, now: Instant) -> usize
where
K: Eq + Hash + Clone,
{
let expired: Vec<K> = entries
.iter()
.filter_map(|(k, entry)| {
if entry.expires_at <= now {
Some(k.clone())
} else {
None
}
})
.collect();
let removed = expired.len();
for k in expired {
entries.remove(&k);
}
removed
}
fn evict_lru<K, V>(entries: &mut HashMap<K, Entry<V>>)
where
K: Eq + Hash + Clone,
{
let victim = entries
.iter()
.min_by_key(|(_, entry)| entry.recency)
.map(|(k, _)| k.clone());
if let Some(key) = victim {
entries.remove(&key);
}
}
#[cfg(test)]
mod tests {
use super::*;
fn nz(n: usize) -> NonZeroUsize {
NonZeroUsize::new(n).expect("non-zero capacity")
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn insert_and_get_returns_value() {
let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(4));
cache.insert("k", 42, Duration::from_secs(30));
assert_eq!(cache.get(&"k"), Some(42));
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn expired_entry_returns_none() {
let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(4));
cache.insert("k", 42, Duration::from_secs(1));
tokio::time::advance(Duration::from_secs(2)).await;
assert_eq!(cache.get(&"k"), None);
assert!(cache.is_empty());
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn lru_eviction_when_capacity_exceeded() {
let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(2));
cache.insert("a", 1, Duration::from_secs(60));
cache.insert("b", 2, Duration::from_secs(60));
let _ = cache.get(&"a");
cache.insert("c", 3, Duration::from_secs(60));
assert_eq!(cache.get(&"a"), Some(1));
assert_eq!(cache.get(&"b"), None);
assert_eq!(cache.get(&"c"), Some(3));
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn prune_removes_only_expired() {
let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(4));
cache.insert("short", 1, Duration::from_secs(1));
cache.insert("long", 2, Duration::from_secs(60));
tokio::time::advance(Duration::from_secs(2)).await;
let removed = cache.prune();
assert_eq!(removed, 1);
assert_eq!(cache.get(&"short"), None);
assert_eq!(cache.get(&"long"), Some(2));
}
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn overwrite_updates_value_and_ttl() {
let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(2));
cache.insert("k", 1, Duration::from_secs(1));
cache.insert("k", 2, Duration::from_secs(30));
tokio::time::advance(Duration::from_secs(2)).await;
assert_eq!(cache.get(&"k"), Some(2));
}
}