use crate::stats::{CacheCounters, CacheStats};
use axess_clock::Clock;
use lru::LruCache;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::hash::Hash;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::OnceCell;
/// A cached value paired with the absolute time at which it stops being served.
struct Entry<V> {
    /// The cached payload; handed out by clone on a hit.
    value: V,
    /// Expiry instant in microseconds, on the same scale as the cache clock's
    /// `timestamp_micros()`. The entry is treated as live while `now <= expires_at_micros`.
    expires_at_micros: i64,
}
/// Bounded LRU cache whose entries expire a fixed TTL after insertion, measured
/// against an injected [`Clock`], with single-flight coalescing of concurrent loads.
pub struct ClockTtlCache<K: Hash + Eq + Send + Sync, V: Clone + Send + Sync> {
    /// Stored entries in LRU order; capacity is set at construction.
    inner: Mutex<LruCache<K, Entry<V>>>,
    /// Loads currently in progress, keyed by cache key. Concurrent callers for the
    /// same key share one cell.
    inflight: Mutex<HashMap<K, Arc<OnceCell<V>>>>,
    /// Time source used to compute and check expiry.
    clock: Arc<dyn Clock>,
    /// How long after insertion an entry remains live.
    ttl: Duration,
    /// Hit / miss / insert / eviction / invalidation / single-flight counters.
    pub(crate) stats: CacheCounters,
}
impl<K, V> ClockTtlCache<K, V>
where
    K: Hash + Eq + Clone + Send + Sync,
    V: Clone + Send + Sync,
{
    /// Creates an empty cache.
    ///
    /// The cache holds at most `capacity` entries. Each entry expires `ttl` after it
    /// is inserted, as measured by `clock`.
    pub fn new(capacity: NonZeroUsize, ttl: Duration, clock: Arc<dyn Clock>) -> Self {
        Self {
            inner: Mutex::new(LruCache::new(capacity)),
            inflight: Mutex::new(HashMap::new()),
            clock,
            ttl,
            stats: CacheCounters::default(),
        }
    }

    /// Returns a snapshot of the cache counters.
    pub fn stats(&self) -> CacheStats {
        self.stats.snapshot()
    }

    /// Resets the cache counters.
    pub fn reset_stats(&self) {
        self.stats.reset();
    }

    /// Returns a clone of the live value for `key`, if there is one.
    ///
    /// A hit marks the entry as most recently used and increments `hits`.
    /// An entry whose expiry time is earlier than the clock's current time is removed.
    /// Both expired and absent keys increment `misses`.
    pub fn get(&self, key: &K) -> Option<V> {
        let now_micros = self.clock.now().timestamp_micros();
        let mut guard = self.inner.lock();
        if let Some(entry) = guard.get(key) {
            if entry.expires_at_micros >= now_micros {
                let v = entry.value.clone();
                drop(guard);
                self.stats.hits.fetch_add(1, Ordering::Relaxed);
                return Some(v);
            }
        }
        // The key is either absent or expired. Popping an absent key is a no-op.
        guard.pop(key);
        drop(guard);
        self.stats.misses.fetch_add(1, Ordering::Relaxed);
        None
    }

    /// Inserts or overwrites `key`, expiring `ttl` from now.
    ///
    /// Increments `inserts`. It increments `capacity_evictions` only when a *new* key is
    /// added to a full cache, because that is the case that displaces the LRU entry.
    pub fn insert(&self, key: K, value: V) {
        let now_micros = self.clock.now().timestamp_micros();
        // `Duration::as_micros` is u128. A plain `as i64` cast would wrap very large TTLs
        // (e.g. `Duration::MAX`) to negative values and make entries expire immediately.
        // Clamp instead, so such TTLs mean "effectively never".
        let ttl_micros = i64::try_from(self.ttl.as_micros()).unwrap_or(i64::MAX);
        let expires_at_micros = now_micros.saturating_add(ttl_micros);
        let mut guard = self.inner.lock();
        let was_full = guard.len() >= guard.cap().get();
        // `put` returns the previous entry when `key` was already present.
        // Overwriting an existing key never displaces another entry, even when full.
        let replaced_existing = guard
            .put(
                key,
                Entry {
                    value,
                    expires_at_micros,
                },
            )
            .is_some();
        drop(guard);
        self.stats.inserts.fetch_add(1, Ordering::Relaxed);
        if was_full && !replaced_existing {
            self.stats
                .capacity_evictions
                .fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Removes `key` from both the cache and the in-flight load map.
    ///
    /// Dropping the in-flight registration means a load already running for `key`
    /// will not insert its result into the cache.
    /// Returns `true` if either map held the key. That case counts once in `invalidations`.
    pub fn invalidate(&self, key: &K) -> bool {
        let inflight_removed = self.inflight.lock().remove(key).is_some();
        let lru_removed = self.inner.lock().pop(key).is_some();
        let removed = inflight_removed || lru_removed;
        if removed {
            self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
        }
        removed
    }

    /// Removes every cached entry and in-flight registration whose key satisfies `predicate`.
    ///
    /// Returns the number of *cached* entries removed. In-flight registrations are not
    /// counted. The same count is added to `invalidations` and recorded as the span's
    /// `removed` field.
    #[tracing::instrument(
        name = "axess_cache.invalidate_by",
        skip(self, predicate),
        fields(removed = tracing::field::Empty),
    )]
    pub fn invalidate_by(&self, predicate: impl Fn(&K) -> bool) -> usize {
        // In-place removal; no key clones needed for the in-flight map.
        self.inflight.lock().retain(|k, _| !predicate(k));

        let mut guard = self.inner.lock();
        // `LruCache` has no `retain`, so collect matching keys first, then pop them.
        let to_remove: Vec<K> = guard
            .iter()
            .filter_map(|(k, _)| if predicate(k) { Some(k.clone()) } else { None })
            .collect();
        let count = to_remove.len();
        for k in &to_remove {
            guard.pop(k);
        }
        drop(guard);
        if count > 0 {
            self.stats
                .invalidations
                .fetch_add(count as u64, Ordering::Relaxed);
        }
        tracing::Span::current().record("removed", count);
        count
    }

    /// Clears the cache and all in-flight registrations.
    ///
    /// The number of cached entries dropped is added to `invalidations` and recorded
    /// as the span's `removed` field.
    #[tracing::instrument(
        name = "axess_cache.invalidate_all",
        skip(self),
        fields(removed = tracing::field::Empty),
    )]
    pub fn invalidate_all(&self) {
        self.inflight.lock().clear();
        let mut guard = self.inner.lock();
        let count = guard.len();
        guard.clear();
        drop(guard);
        if count > 0 {
            self.stats
                .invalidations
                .fetch_add(count as u64, Ordering::Relaxed);
        }
        tracing::Span::current().record("removed", count);
    }

    /// Eagerly removes every entry whose expiry time is earlier than the clock's current time.
    ///
    /// Returns the number removed, which is also added to `invalidations` and recorded
    /// as the span's `removed` field.
    #[tracing::instrument(
        name = "axess_cache.cleanup_expired",
        skip(self),
        fields(removed = tracing::field::Empty),
    )]
    pub fn cleanup_expired(&self) -> usize {
        let now_micros = self.clock.now().timestamp_micros();
        let mut guard = self.inner.lock();
        let to_remove: Vec<K> = guard
            .iter()
            .filter_map(|(k, e)| {
                if e.expires_at_micros < now_micros {
                    Some(k.clone())
                } else {
                    None
                }
            })
            .collect();
        let count = to_remove.len();
        for k in &to_remove {
            guard.pop(k);
        }
        drop(guard);
        if count > 0 {
            self.stats
                .invalidations
                .fetch_add(count as u64, Ordering::Relaxed);
        }
        tracing::Span::current().record("removed", count);
        count
    }

    /// Returns the live value for `key`, loading it with `fetcher` on a miss.
    ///
    /// Concurrent callers for the same key share one `OnceCell`:
    /// - The first caller registers the cell and is the *leader*.
    /// - Later callers join it, which increments `single_flight_joins`.
    ///
    /// A successful value is inserted into the cache only while that same cell is still
    /// registered for `key`. Any `invalidate*` call made during the load unregisters the
    /// cell, so the result is returned to callers but not cached.
    ///
    /// Only the leader unregisters the cell on error or cancellation.
    /// - A joined caller that is dropped mid-wait leaves the registration alone, so the
    ///   leader's result is still cached and later callers still join.
    /// - If the leader itself is cancelled, its registration is removed. A joined caller
    ///   may then complete the load, but that value is not cached.
    ///
    /// # Errors
    ///
    /// Returns the fetcher's error. Errors are never cached, and each failing caller
    /// increments `single_flight_errors`.
    #[tracing::instrument(
        name = "axess_cache.get_or_try_insert_with",
        skip(self, key, fetcher),
        fields(joined = tracing::field::Empty, error = tracing::field::Empty),
    )]
    pub async fn get_or_try_insert_with<F, Fut, E>(&self, key: K, fetcher: F) -> Result<V, E>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<V, E>>,
    {
        if let Some(v) = self.get(&key) {
            return Ok(v);
        }
        let (cell, joined) = {
            let mut inflight = self.inflight.lock();
            match inflight.get(&key) {
                Some(existing) => (existing.clone(), true),
                None => {
                    let cell = Arc::new(OnceCell::<V>::new());
                    inflight.insert(key.clone(), cell.clone());
                    (cell, false)
                }
            }
        };
        tracing::Span::current().record("joined", joined);
        if joined {
            self.stats
                .single_flight_joins
                .fetch_add(1, Ordering::Relaxed);
        }
        // Only the registering caller owns cleanup of the registration. If joiners also
        // held a guard, a cancelled joiner would unregister the shared cell. The leader's
        // result would then fail the `ptr_eq` check below and never be cached.
        let inflight_guard = (!joined).then(|| InflightGuard {
            cache: self,
            key: key.clone(),
            cell: cell.clone(),
        });
        let result: Result<&V, E> = cell.get_or_try_init(fetcher).await;
        let outcome = match result {
            Ok(v_ref) => {
                let v = v_ref.clone();
                {
                    // Holding the in-flight lock across the insert makes "still registered?"
                    // and "insert" atomic with respect to `invalidate*`, which unregister
                    // under this same lock.
                    let mut inflight = self.inflight.lock();
                    if let Some(existing) = inflight.get(&key)
                        && Arc::ptr_eq(existing, &cell)
                    {
                        self.insert(key.clone(), v.clone());
                        inflight.remove(&key);
                    }
                }
                Ok(v)
            }
            Err(e) => {
                self.stats
                    .single_flight_errors
                    .fetch_add(1, Ordering::Relaxed);
                tracing::Span::current().record("error", true);
                Err(e)
            }
        };
        drop(inflight_guard);
        outcome
    }

    /// Number of stored entries. Expired entries that have not yet been removed by
    /// `get` or `cleanup_expired` are included.
    pub fn len(&self) -> usize {
        self.inner.lock().len()
    }

    /// `true` when no entries are stored (expired-but-unremoved entries count as stored).
    pub fn is_empty(&self) -> bool {
        self.inner.lock().is_empty()
    }

    /// Maximum number of entries the cache holds.
    pub fn capacity(&self) -> NonZeroUsize {
        self.inner.lock().cap()
    }

    /// Number of keys that currently have a registered in-flight load.
    pub fn pending_loads_count(&self) -> usize {
        self.inflight.lock().len()
    }
}
/// Scope guard for a single-flight registration in `cache.inflight`.
///
/// On drop it removes the entry for `key`, but only if the map still holds this exact
/// `cell`. A registration made later by another caller is left untouched.
struct InflightGuard<'a, K: Hash + Eq + Clone + Send + Sync, V: Clone + Send + Sync> {
    /// Cache whose in-flight map is cleaned up on drop.
    cache: &'a ClockTtlCache<K, V>,
    /// Key the registration was made under.
    key: K,
    /// The cell this guard is responsible for; compared by pointer identity.
    cell: Arc<OnceCell<V>>,
}
impl<K, V> Drop for InflightGuard<'_, K, V>
where
K: Hash + Eq + Clone + Send + Sync,
V: Clone + Send + Sync,
{
fn drop(&mut self) {
let mut inflight = self.cache.inflight.lock();
if let Some(existing) = inflight.get(&self.key)
&& Arc::ptr_eq(existing, &self.cell)
{
inflight.remove(&self.key);
}
}
}