use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
use super::POOL_ACCESS_COUNTER;
struct CacheEntry<V> {
value: V,
last_access: AtomicU64,
}
impl<V> CacheEntry<V> {
fn new(value: V) -> Self {
Self {
value,
last_access: AtomicU64::new(POOL_ACCESS_COUNTER.fetch_add(1, Ordering::Relaxed)),
}
}
fn touch(&self) {
self.last_access.store(
POOL_ACCESS_COUNTER.fetch_add(1, Ordering::Relaxed),
Ordering::Relaxed,
);
}
}
/// An async, size-bounded cache keyed by `String` that evicts the least
/// recently used entry when it is full.
///
/// Hits return a clone of the stored value, so `V` should be cheap to clone.
/// NOTE(review): the eviction log text suggests values are pool handles —
/// confirm against callers.
pub struct LruCache<V: Clone> {
    // Key -> entry map. Lookups take the shared lock; inserts and removals
    // take the exclusive lock.
    entries: RwLock<HashMap<String, CacheEntry<V>>>,
    // Capacity. Once reached, the least recently used entry is evicted before
    // a new one is inserted.
    max_entries: usize,
    // Label attached to eviction log events; not used for anything else.
    cache_label: &'static str,
}

impl<V: Clone> LruCache<V> {
    /// Creates an empty cache holding at most `max_entries` values.
    ///
    /// `cache_label` only tags eviction log events. A `max_entries` of `0`
    /// behaves like `1`: the existing entry is evicted and the new value is
    /// still inserted.
    pub fn new(max_entries: usize, cache_label: &'static str) -> Self {
        let entries = RwLock::new(HashMap::new());
        Self {
            entries,
            max_entries,
            cache_label,
        }
    }

    /// Returns the value cached under `key`, building it with `create_fn` on a miss.
    ///
    /// The fast path takes only the read lock and refreshes the entry's
    /// recency. On a miss, `create_fn` runs with no lock held. Several callers
    /// racing on the same key may therefore each run their factory. Whichever
    /// caller reaches the write lock first stores its value. Later callers drop
    /// their own value and return the stored one. When the cache is full, the
    /// least recently used entry is found by a linear scan and evicted (logged
    /// at info level) before the insert.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `create_fn`. Nothing is cached in that case.
    pub async fn get_or_create<F, Fut, E>(&self, key: &str, create_fn: F) -> Result<V, E>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<V, E>>,
    {
        // Fast path: the shared lock is released at the end of this block.
        let cached = {
            let map = self.entries.read().await;
            map.get(key).map(|hit| {
                hit.touch();
                hit.value.clone()
            })
        };
        if let Some(value) = cached {
            return Ok(value);
        }

        // Slow path: build the value without holding any lock.
        let fresh = create_fn().await?;

        let mut map = self.entries.write().await;
        // Another caller may have filled the slot while we were building.
        if let Some(winner) = map.get(key) {
            winner.touch();
            return Ok(winner.value.clone());
        }

        if map.len() >= self.max_entries {
            let victim = map
                .iter()
                .min_by_key(|(_, entry)| entry.last_access.load(Ordering::Relaxed))
                .map(|(k, _)| k.clone());
            if let Some(lru_key) = victim {
                tracing::info!(
                    evicted = %lru_key,
                    cache = self.cache_label,
                    "Pool cache at capacity, evicting least-recently-used entry"
                );
                map.remove(&lru_key);
            }
        }

        map.insert(key.to_string(), CacheEntry::new(fresh.clone()));
        Ok(fresh)
    }

    /// Removes `key` from the cache. A missing key is silently ignored.
    pub async fn evict(&self, key: &str) {
        let mut map = self.entries.write().await;
        map.remove(key);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};

    /// Fetches `key`, building `produced` on a miss.
    async fn fetch(cache: &LruCache<String>, key: &str, produced: &str) -> Result<String, String> {
        let produced = produced.to_string();
        cache
            .get_or_create(key, move || async move { Ok(produced) })
            .await
    }

    /// Like [`fetch`], but counts every factory invocation in `calls`.
    async fn fetch_counted(
        cache: &LruCache<String>,
        key: &str,
        produced: &str,
        calls: &Arc<AtomicUsize>,
    ) -> Result<String, String> {
        let calls = Arc::clone(calls);
        let produced = produced.to_string();
        cache
            .get_or_create(key, move || async move {
                calls.fetch_add(1, AtomicOrdering::Relaxed);
                Ok(produced)
            })
            .await
    }

    /// Spawns a task that requests "key1". Its factory blocks on `barrier`,
    /// which forces both racers past the fast path before either inserts.
    fn spawn_racer(
        cache: Arc<LruCache<String>>,
        barrier: Arc<tokio::sync::Barrier>,
        produced: &'static str,
    ) -> tokio::task::JoinHandle<String> {
        tokio::spawn(async move {
            let outcome: Result<String, String> = cache
                .get_or_create("key1", move || async move {
                    barrier.wait().await;
                    Ok(produced.to_string())
                })
                .await;
            outcome.expect("test")
        })
    }

    #[tokio::test]
    async fn test_cache_miss_creates() {
        let cache = LruCache::new(4, "test");
        let got = fetch(&cache, "key1", "value1").await;
        assert_eq!(got.expect("test"), "value1");
    }

    #[tokio::test]
    async fn test_cache_hit_returns_cached() {
        let cache = LruCache::new(4, "test");
        let calls = Arc::new(AtomicUsize::new(0));
        let _ = fetch_counted(&cache, "key1", "value1", &calls).await;
        let second = fetch_counted(&cache, "key1", "value2", &calls).await;
        assert_eq!(second.expect("test"), "value1");
        assert_eq!(calls.load(AtomicOrdering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_lru_eviction_at_capacity() {
        let cache = LruCache::new(2, "test");
        for (key, value) in [("a", "A"), ("b", "B"), ("c", "C")] {
            let _ = fetch(&cache, key, value).await;
        }

        // "a" was least recently used, so inserting "c" evicted it.
        let a_calls = Arc::new(AtomicUsize::new(0));
        let refetched = fetch_counted(&cache, "a", "A2", &a_calls).await;
        assert_eq!(refetched.expect("test"), "A2");
        assert_eq!(a_calls.load(AtomicOrdering::Relaxed), 1);

        // Re-inserting "a" evicted "b", so "c" must still be cached.
        let c_calls = Arc::new(AtomicUsize::new(0));
        let _ = fetch_counted(&cache, "c", "C2", &c_calls).await;
        assert_eq!(
            c_calls.load(AtomicOrdering::Relaxed),
            0,
            "c should still be cached"
        );
    }

    #[tokio::test]
    async fn test_touch_updates_lru_order() {
        let cache = LruCache::new(2, "test");
        let _ = fetch(&cache, "a", "A").await;
        let _ = fetch(&cache, "b", "B").await;
        // A hit on "a" makes "b" the least recently used entry.
        let _ = fetch(&cache, "a", "should not be called").await;
        let _ = fetch(&cache, "c", "C").await;

        let calls = Arc::new(AtomicUsize::new(0));
        let a = fetch_counted(&cache, "a", "A2", &calls).await;
        assert_eq!(a.expect("test"), "A");
        assert_eq!(
            calls.load(AtomicOrdering::Relaxed),
            0,
            "a should still be cached"
        );
    }

    #[tokio::test]
    async fn test_evict_removes_entry() {
        let cache = LruCache::new(4, "test");
        let _ = fetch(&cache, "key1", "value1").await;
        cache.evict("key1").await;

        let calls = Arc::new(AtomicUsize::new(0));
        let rebuilt = fetch_counted(&cache, "key1", "value2", &calls).await;
        assert_eq!(rebuilt.expect("test"), "value2");
        assert_eq!(calls.load(AtomicOrdering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_race_uses_existing() {
        let cache = Arc::new(LruCache::new(4, "test"));
        let barrier = Arc::new(tokio::sync::Barrier::new(2));
        let first = spawn_racer(Arc::clone(&cache), Arc::clone(&barrier), "from_task_1");
        let second = spawn_racer(Arc::clone(&cache), Arc::clone(&barrier), "from_task_2");
        let (v1, v2) = tokio::join!(first, second);
        let v1 = v1.expect("test");
        let v2 = v2.expect("test");
        assert_eq!(v1, v2, "both tasks should see the same cached value");
    }
}