use std::sync::Arc;
use std::time::Duration;
use aa_core::storage::Result;
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use tokio::sync::Notify;
use crate::cached_value::CachedValue;
use crate::source::CacheSource;
/// Read-through, in-process cache sitting in front of a [`CacheSource`].
///
/// Values loaded from the source are kept in a concurrent map and reused while they are not
/// expired under `ttl`. Concurrent misses on the same key are coalesced: one caller loads from
/// the source while the others wait on a per-key [`Notify`] and then re-read the cache.
pub struct L1Cache<S>
where
    S: CacheSource,
{
    /// Source that cache misses are loaded from.
    inner: S,
    /// Age limit handed to `CachedValue::is_expired` when deciding whether an entry is usable.
    ttl: Duration,
    /// Cached values, keyed by the source's key type.
    entries: Arc<DashMap<S::Key, CachedValue<S::Value>>>,
    /// One wake-up handle per key whose load is currently in progress.
    inflight: Arc<DashMap<S::Key, Arc<Notify>>>,
}
impl<S: CacheSource> L1Cache<S> {
    /// Creates an empty cache in front of `inner`, treating entries as usable until
    /// `CachedValue::is_expired(ttl)` reports them expired.
    pub fn new(inner: S, ttl: Duration) -> Self {
        Self {
            inner,
            entries: Arc::new(DashMap::new()),
            inflight: Arc::new(DashMap::new()),
            ttl,
        }
    }

    /// Borrows the wrapped source.
    pub fn inner(&self) -> &S {
        &self.inner
    }

    /// Number of stored entries.
    ///
    /// Expired entries are not evicted eagerly, so they are counted until they are overwritten
    /// by a reload or removed via [`Self::invalidate`] / [`Self::clear`].
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when no entries are stored (expired entries still count as stored).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Removes every stored entry.
    ///
    /// Loads that are already in flight are not cancelled; they still insert their result when
    /// they complete.
    pub fn clear(&self) {
        self.entries.clear();
    }

    /// Removes the entry for `key`, returning whether one was stored.
    pub fn invalidate(&self, key: &S::Key) -> bool {
        self.entries.remove(key).is_some()
    }

    /// Returns a clone of the cached value for `key` if present and not expired.
    fn fresh(&self, key: &S::Key) -> Option<S::Value> {
        let entry = self.entries.get(key)?;
        if entry.is_expired(self.ttl) {
            None
        } else {
            Some(entry.value.clone())
        }
    }

    /// Returns the value for `key`, loading it from the source on a miss.
    ///
    /// Concurrent misses for the same key are coalesced. The first caller becomes the leader and
    /// calls `load`; the others wait until the leader finishes and then re-check the cache.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the source's `load`. Errors are not cached: followers that
    /// were waiting on a failed load loop around, and one of them becomes the next leader and
    /// retries.
    ///
    /// # Cancellation
    ///
    /// Dropping the leader's future mid-load, or a panic inside `load`, releases the key so
    /// waiting and future callers can proceed instead of hanging.
    pub async fn get(&self, key: S::Key) -> Result<S::Value> {
        loop {
            if let Some(value) = self.fresh(&key) {
                return Ok(value);
            }
            // Leader/follower is decided under the inflight shard lock. The lock is released at
            // the end of this statement, before any `.await`.
            let follower = match self.inflight.entry(key.clone()) {
                Entry::Vacant(slot) => {
                    slot.insert(Arc::new(Notify::new()));
                    None
                }
                Entry::Occupied(slot) => Some(Arc::clone(slot.get())),
            };
            match follower {
                None => {
                    // Created before the first `.await`, so the marker is removed and followers
                    // are woken on every exit path, including cancellation and panics.
                    let _guard = InflightGuard {
                        cache: self,
                        key: key.clone(),
                    };
                    let result = self.inner.load(&key).await;
                    if let Ok(ref value) = result {
                        // Publish before `_guard` drops so woken followers find the value.
                        self.entries
                            .insert(key.clone(), CachedValue::new(value.clone()));
                    }
                    return result;
                }
                Some(notify) => {
                    let waiter = notify.notified();
                    tokio::pin!(waiter);
                    // Register interest before inspecting shared state so a wake-up that races
                    // with the checks below is not lost.
                    waiter.as_mut().enable();
                    if let Some(value) = self.fresh(&key) {
                        return Ok(value);
                    }
                    // `notify_waiters` stores no permit. If the leader already removed its marker
                    // (and possibly notified) before we registered, waiting would hang forever.
                    // Only wait while the marker we cloned is still the one in flight; otherwise
                    // loop and re-evaluate.
                    let still_in_flight = matches!(
                        self.inflight.get(&key),
                        Some(current) if Arc::ptr_eq(current.value(), &notify)
                    );
                    if still_in_flight {
                        waiter.await;
                    }
                }
            }
        }
    }
}

/// Releases a key's in-flight marker when the leading `get` call finishes.
///
/// Doing this in `Drop` covers every exit path: normal return, error, a panic inside the
/// source, or the leader's future being dropped mid-load. Without it, a cancelled leader would
/// leave the marker behind, and every later `get` for that key would wait forever.
struct InflightGuard<'a, S: CacheSource> {
    cache: &'a L1Cache<S>,
    key: S::Key,
}

impl<S: CacheSource> Drop for InflightGuard<'_, S> {
    fn drop(&mut self) {
        if let Some((_, notify)) = self.cache.inflight.remove(&self.key) {
            notify.notify_waiters();
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use aa_core::storage::AgentId;

    use crate::testing::{sample_policy, MemoryPolicyStore};
    use crate::L1Cache;

    /// Deterministic test agent whose 16 id bytes all equal `seed`.
    fn agent(seed: u8) -> AgentId {
        let raw = [seed; 16];
        AgentId::from_bytes(raw)
    }

    #[tokio::test]
    async fn miss_populates_then_serves_from_cache() {
        let agent_id = agent(1);
        let ttl = Duration::from_secs(60);
        let cache = L1Cache::new(
            MemoryPolicyStore::with_policy(agent_id, sample_policy(1)),
            ttl,
        );

        // First lookup misses and hits the store exactly once.
        let loaded = cache.get(agent_id).await.expect("policy present");
        assert_eq!(loaded.version, 1);
        assert_eq!(cache.inner().call_count(), 1);
        assert_eq!(cache.len(), 1);

        // Second lookup is answered without touching the store again.
        let cached = cache.get(agent_id).await.expect("policy present");
        assert_eq!(cached.version, 1);
        assert_eq!(cache.inner().call_count(), 1);
    }

    #[tokio::test]
    async fn expired_entry_is_treated_as_a_miss() {
        let agent_id = agent(2);
        let ttl = Duration::from_millis(20);
        let cache = L1Cache::new(
            MemoryPolicyStore::with_policy(agent_id, sample_policy(1)),
            ttl,
        );

        cache.get(agent_id).await.expect("policy present");
        assert_eq!(cache.inner().call_count(), 1);

        // Wait out the TTL so the stored entry is no longer usable.
        tokio::time::sleep(Duration::from_millis(40)).await;

        cache.get(agent_id).await.expect("policy present");
        assert_eq!(cache.inner().call_count(), 2);
    }

    #[tokio::test]
    async fn invalidate_evicts_the_cached_entry() {
        let agent_id = agent(3);
        let ttl = Duration::from_secs(60);
        let cache = L1Cache::new(
            MemoryPolicyStore::with_policy(agent_id, sample_policy(1)),
            ttl,
        );

        cache.get(agent_id).await.expect("policy present");
        assert_eq!(cache.len(), 1);

        // The first invalidation removes the entry; a second one finds nothing to remove.
        assert!(cache.invalidate(&agent_id));
        assert_eq!(cache.len(), 0);
        assert!(!cache.invalidate(&agent_id));

        // With the entry gone, the next lookup must reload from the store.
        cache.get(agent_id).await.expect("policy present");
        assert_eq!(cache.inner().call_count(), 2);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_misses_collapse_to_one_load() {
        const CALLERS: usize = 100;

        let agent_id = agent(4);
        let store = MemoryPolicyStore::with_policy(agent_id, sample_policy(7))
            .with_delay(Duration::from_millis(50));
        let cache = Arc::new(L1Cache::new(store, Duration::from_secs(60)));

        // Spawn every caller before awaiting any of them, so they all race the same slow load.
        let handles: Vec<_> = (0..CALLERS)
            .map(|_| {
                let worker_cache = Arc::clone(&cache);
                tokio::spawn(async move { worker_cache.get(agent_id).await })
            })
            .collect();

        for handle in handles {
            let policy = handle.await.expect("task joined").expect("policy present");
            assert_eq!(policy.version, 7);
        }
        assert_eq!(cache.inner().call_count(), 1);
    }
}