use std::fmt;
use std::hash::Hash;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use dashmap::DashMap;
use tokio::sync::OnceCell;
pub struct EpochCache<K, V> {
map: DashMap<K, CacheEntry<V>>,
inflight: DashMap<K, Arc<OnceCell<V>>>,
epoch: AtomicU64,
max_entries: usize,
}
struct CacheEntry<V> {
value: V,
epoch: u64,
}
impl<K, V> EpochCache<K, V>
where
K: Eq + Hash + Clone + fmt::Debug,
V: Clone,
{
pub fn new() -> Self {
Self::with_capacity(usize::MAX)
}
pub fn with_capacity(max_entries: usize) -> Self {
Self {
map: DashMap::new(),
inflight: DashMap::new(),
epoch: AtomicU64::new(0),
max_entries,
}
}
pub async fn get_or_insert_with<F, Fut>(
&self,
key: K,
f: F,
) -> Result<V, crate::error::HirnDbError>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = Result<V, crate::error::HirnDbError>>,
{
let current_epoch = self.epoch.load(Ordering::Acquire);
if let Some(entry) = self.map.get(&key).filter(|e| e.epoch >= current_epoch) {
tracing::trace!(key = ?key, "cache hit");
return Ok(entry.value.clone());
}
let cell: Arc<OnceCell<V>> = self
.inflight
.entry(key.clone())
.or_insert_with(|| Arc::new(OnceCell::new()))
.clone();
let result = cell
.get_or_try_init(|| async {
tracing::debug!(key = ?key, "cache miss — computing value");
f().await
})
.await;
self.inflight.remove_if(&key, |_, v| Arc::ptr_eq(v, &cell));
match result {
Ok(val) => {
let owned = val.clone();
self.insert_evicting(key, owned.clone(), current_epoch);
Ok(owned)
}
Err(e) => Err(e),
}
}
fn insert_evicting(&self, key: K, val: V, current_epoch: u64) {
if self.map.len() >= self.max_entries {
if let Some((evict_key, _)) = self
.map
.iter()
.min_by_key(|e| e.epoch)
.map(|e| (e.key().clone(), ()))
{
self.map.remove(&evict_key);
}
}
self.map.insert(
key,
CacheEntry {
value: val,
epoch: current_epoch,
},
);
}
pub fn get(&self, key: &K) -> Option<V> {
let current_epoch = self.epoch.load(Ordering::Acquire);
self.map.get(key).and_then(|entry| {
if entry.epoch >= current_epoch {
tracing::trace!(key = ?key, "cache hit (sync)");
Some(entry.value.clone())
} else {
tracing::debug!(key = ?key, "cache miss (stale)");
None
}
})
}
pub fn put(&self, key: K, val: V) {
let current_epoch = self.epoch.load(Ordering::Acquire);
tracing::trace!(key = ?key, epoch = current_epoch, "cache put");
self.insert_evicting(key, val, current_epoch);
}
pub fn invalidate(&self, key: &K) {
tracing::info!(key = ?key, "cache invalidate");
self.map.remove(key);
}
pub fn invalidate_all(&self) {
let old = self.epoch.fetch_add(1, Ordering::AcqRel);
tracing::info!(
old_epoch = old,
new_epoch = old + 1,
"cache invalidate_all — generation boundary advanced"
);
}
pub fn epoch(&self) -> u64 {
self.epoch.load(Ordering::Acquire)
}
pub fn len(&self) -> usize {
self.map.len()
}
pub fn is_empty(&self) -> bool {
self.map.is_empty()
}
}
impl<K, V> Default for EpochCache<K, V>
where
    K: Eq + Hash + Clone + fmt::Debug,
    V: Clone,
{
    /// An effectively unbounded cache (`usize::MAX` entries), the same as [`EpochCache::new`].
    fn default() -> Self {
        Self::with_capacity(usize::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::Barrier;

    /// Builds an owned `String` key from a literal.
    fn owned(s: &str) -> String {
        s.to_owned()
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn insert_and_get() {
        let cache = EpochCache::<String, i32>::new();
        let computed = cache
            .get_or_insert_with(owned("key"), || async { Ok(42) })
            .await
            .unwrap();
        assert_eq!(computed, 42);
        assert_eq!(cache.get(&owned("key")), Some(42));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn invalidate_single() {
        let cache = EpochCache::<String, i32>::new();
        let key = owned("key");
        cache
            .get_or_insert_with(key.clone(), || async { Ok(42) })
            .await
            .unwrap();
        cache.invalidate(&key);
        assert!(cache.get(&key).is_none());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn invalidate_all_bumps_epoch() {
        let cache = EpochCache::<String, i32>::new();
        for (name, value) in [("a", 1), ("b", 2)] {
            cache
                .get_or_insert_with(owned(name), move || async move { Ok(value) })
                .await
                .unwrap();
        }
        assert_eq!(cache.epoch(), 0);
        cache.invalidate_all();
        assert_eq!(cache.epoch(), 1);
        for name in ["a", "b"] {
            assert!(cache.get(&owned(name)).is_none());
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn recomputes_after_invalidate_all() {
        let cache = EpochCache::<String, i32>::new();
        let key = owned("key");
        cache
            .get_or_insert_with(key.clone(), || async { Ok(1) })
            .await
            .unwrap();
        cache.invalidate_all();
        let recomputed = cache
            .get_or_insert_with(key, || async { Ok(99) })
            .await
            .unwrap();
        assert_eq!(recomputed, 99);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn invalidate_then_put_preserves_new_entry() {
        let cache = EpochCache::<String, i32>::new();
        let key = owned("key");
        cache.put(key.clone(), 1);
        cache.invalidate_all();
        cache.put(key.clone(), 2);
        assert_eq!(cache.get(&key), Some(2));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn pre_invalidation_entries_not_returned_after_boundary() {
        let cache = EpochCache::<String, i32>::new();
        let key = owned("key");
        cache.put(key.clone(), 1);
        assert_eq!(cache.get(&key), Some(1));
        cache.invalidate_all();
        assert!(cache.get(&key).is_none());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_invalidate_read_insert_loops_preserve_post_boundary_values() {
        const ROUNDS: usize = 128;
        let shared = Arc::new(EpochCache::<String, usize>::new());
        let gate = Arc::new(Barrier::new(2));

        // Each round: both tasks meet at the barrier, then the invalidator advances the epoch.
        let invalidator = {
            let cache = Arc::clone(&shared);
            let gate = Arc::clone(&gate);
            tokio::spawn(async move {
                for _ in 0..ROUNDS {
                    gate.wait().await;
                    cache.invalidate_all();
                }
            })
        };

        let writer = {
            let cache = Arc::clone(&shared);
            let gate = Arc::clone(&gate);
            tokio::spawn(async move {
                let key = owned("key");
                for round in 0..ROUNDS {
                    gate.wait().await;
                    // Wait until this round's invalidation is visible before checking.
                    let target_epoch = round as u64 + 1;
                    while cache.epoch() < target_epoch {
                        tokio::task::yield_now().await;
                    }
                    assert_eq!(cache.get(&key), None);
                    cache.put(key.clone(), round);
                    assert_eq!(cache.get(&key), Some(round));
                }
            })
        };

        invalidator.await.unwrap();
        writer.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_access() {
        let cache = Arc::new(EpochCache::<u64, u64>::new());
        let tasks: Vec<_> = (0..10u64)
            .map(|i| {
                let cache = Arc::clone(&cache);
                tokio::spawn(async move {
                    let got = cache
                        .get_or_insert_with(i, || async move { Ok(i * 10) })
                        .await
                        .unwrap();
                    assert_eq!(got, i * 10);
                })
            })
            .collect();
        for task in tasks {
            task.await.unwrap();
        }
        assert_eq!(cache.len(), 10);
    }

    #[tokio::test(flavor = "multi_thread")]
    #[tracing_test::traced_test]
    async fn tracing_emits_hit_miss_invalidate() {
        let cache = EpochCache::<String, i32>::new();
        let first = owned("k1");

        cache
            .get_or_insert_with(first.clone(), || async { Ok(10) })
            .await
            .unwrap();
        assert!(logs_contain("cache miss"));

        let _ = cache
            .get_or_insert_with(first.clone(), || async { Ok(999) })
            .await
            .unwrap();
        assert!(logs_contain("cache hit"));

        cache.invalidate(&first);
        assert!(logs_contain("cache invalidate"));

        cache.put(owned("k2"), 20);
        assert!(logs_contain("cache put"));

        cache.invalidate_all();
        assert!(logs_contain("generation boundary advanced"));
    }
}