Skip to main content

aa_cache/
l1.rs

1//! [`L1Cache`] — a `DashMap`-backed, TTL'd, cache-aside wrapper over a store.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use aa_core::storage::Result;
7use dashmap::mapref::entry::Entry;
8use dashmap::DashMap;
9use tokio::sync::Notify;
10
11use crate::cached_value::CachedValue;
12use crate::source::CacheSource;
13
14/// In-process L1 cache that fronts a [`CacheSource`] with a [`DashMap`].
15///
16/// `get` serves fresh keys from memory and falls back to the wrapped store on a
17/// miss or once an entry's TTL elapses, repopulating the cache on the way out
18/// (cache-aside). Concurrent misses for the same key collapse to a single
19/// `load` call (stampede protection), so a burst of cold lookups never fans out
20/// into N backend round-trips.
21pub struct L1Cache<S: CacheSource> {
22    inner: S,
23    entries: Arc<DashMap<S::Key, CachedValue<S::Value>>>,
24    inflight: Arc<DashMap<S::Key, Arc<Notify>>>,
25    ttl: Duration,
26}
27
28impl<S: CacheSource> L1Cache<S> {
29    /// Wrap `inner`, expiring cached entries `ttl` after insertion.
30    pub fn new(inner: S, ttl: Duration) -> Self {
31        Self {
32            inner,
33            entries: Arc::new(DashMap::new()),
34            inflight: Arc::new(DashMap::new()),
35            ttl,
36        }
37    }
38
39    /// Borrow the wrapped store.
40    pub fn inner(&self) -> &S {
41        &self.inner
42    }
43
44    /// Number of entries currently held (including any past their TTL but not
45    /// yet evicted). Intended for diagnostics, not control flow.
46    #[must_use]
47    pub fn len(&self) -> usize {
48        self.entries.len()
49    }
50
51    /// Whether the cache holds no entries.
52    #[must_use]
53    pub fn is_empty(&self) -> bool {
54        self.entries.is_empty()
55    }
56
57    /// Drop every cached entry.
58    pub fn clear(&self) {
59        self.entries.clear();
60    }
61
62    /// Drop the cached entry for `key`; returns whether one was present.
63    ///
64    /// This is the hook the Epic C push-invalidation channel calls when the
65    /// Gateway reports that an agent's policy changed: the next `get` reloads
66    /// from the source of truth rather than serving a stale entry.
67    pub fn invalidate(&self, key: &S::Key) -> bool {
68        self.entries.remove(key).is_some()
69    }
70
71    /// Return a fresh (non-expired) cached value for `key`, if present.
72    fn fresh(&self, key: &S::Key) -> Option<S::Value> {
73        let entry = self.entries.get(key)?;
74        if entry.is_expired(self.ttl) {
75            None
76        } else {
77            Some(entry.value.clone())
78        }
79    }
80
81    /// Fetch the value for `key`, serving from cache when fresh.
82    ///
83    /// Cache-aside: a hit clones out of the `DashMap`; a miss (or an expired
84    /// entry) loads from the wrapped store, populates the cache, and returns.
85    ///
86    /// Stampede protection: the first caller to miss a key becomes the *leader*
87    /// and performs the single `load`; concurrent callers become *followers*,
88    /// wait on a shared [`Notify`], then re-read the now-populated cache. The
89    /// inner store therefore sees exactly one call per key per miss window.
90    pub async fn get(&self, key: S::Key) -> Result<S::Value> {
91        loop {
92            // Fast path: a fresh cache hit needs no coordination.
93            if let Some(value) = self.fresh(&key) {
94                return Ok(value);
95            }
96
97            // Miss: claim leadership for this key, or grab the in-flight signal.
98            let follower = match self.inflight.entry(key.clone()) {
99                Entry::Vacant(slot) => {
100                    slot.insert(Arc::new(Notify::new()));
101                    None
102                }
103                Entry::Occupied(slot) => Some(slot.get().clone()),
104            };
105
106            match follower {
107                // Leader: load once, populate, then wake every waiter.
108                None => {
109                    let result = self.inner.load(&key).await;
110                    if let Ok(ref value) = result {
111                        self.entries.insert(key.clone(), CachedValue::new(value.clone()));
112                    }
113                    if let Some((_, notify)) = self.inflight.remove(&key) {
114                        notify.notify_waiters();
115                    }
116                    return result;
117                }
118                // Follower: wait for the leader, then retry the loop.
119                Some(notify) => {
120                    let waiter = notify.notified();
121                    tokio::pin!(waiter);
122                    // Register before re-checking the cache so the leader's
123                    // notification can't be missed (tokio::sync::Notify pattern):
124                    // the leader always populates `entries` before notifying, so
125                    // either the re-check sees the value or the wait is woken.
126                    waiter.as_mut().enable();
127                    if let Some(value) = self.fresh(&key) {
128                        return Ok(value);
129                    }
130                    waiter.await;
131                }
132            }
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aa_core::storage::AgentId;

    use crate::testing::{sample_policy, MemoryPolicyStore};
    use crate::L1Cache;

    /// Deterministic agent id whose sixteen bytes are all `seed`.
    fn agent(seed: u8) -> AgentId {
        AgentId::from_bytes([seed; 16])
    }

    /// A cold lookup goes to the store once; a repeat lookup is answered from
    /// memory without touching the store again.
    #[tokio::test]
    async fn miss_populates_then_serves_from_cache() {
        let agent_id = agent(1);
        let cache = L1Cache::new(
            MemoryPolicyStore::with_policy(agent_id, sample_policy(1)),
            Duration::from_secs(60),
        );

        // Cold lookup: one store call, one cached entry afterwards.
        let cold = cache.get(agent_id).await.expect("policy present");
        assert_eq!(cold.version, 1);
        assert_eq!(cache.inner().call_count(), 1);
        assert_eq!(cache.len(), 1);

        // Warm lookup: same value, store call count unchanged.
        let warm = cache.get(agent_id).await.expect("policy present");
        assert_eq!(warm.version, 1);
        assert_eq!(cache.inner().call_count(), 1);
    }

    /// Once an entry outlives its TTL, the next lookup reloads from the store.
    #[tokio::test]
    async fn expired_entry_is_treated_as_a_miss() {
        let agent_id = agent(2);
        let cache = L1Cache::new(
            MemoryPolicyStore::with_policy(agent_id, sample_policy(1)),
            Duration::from_millis(20),
        );

        cache.get(agent_id).await.expect("policy present");
        assert_eq!(cache.inner().call_count(), 1);

        // Wait out the 20ms TTL with margin, then look the key up again.
        tokio::time::sleep(Duration::from_millis(40)).await;
        cache.get(agent_id).await.expect("policy present");
        assert_eq!(cache.inner().call_count(), 2);
    }

    /// `invalidate` removes the entry, reports presence accurately, and forces
    /// the following lookup back to the store.
    #[tokio::test]
    async fn invalidate_evicts_the_cached_entry() {
        let agent_id = agent(3);
        let cache = L1Cache::new(
            MemoryPolicyStore::with_policy(agent_id, sample_policy(1)),
            Duration::from_secs(60),
        );

        cache.get(agent_id).await.expect("policy present");
        assert_eq!(cache.len(), 1);

        // First invalidation finds and removes the entry.
        assert!(cache.invalidate(&agent_id));
        assert_eq!(cache.len(), 0);

        // Second invalidation has nothing left to remove.
        assert!(!cache.invalidate(&agent_id));

        // With the entry gone, the lookup is a miss and hits the store again.
        cache.get(agent_id).await.expect("policy present");
        assert_eq!(cache.inner().call_count(), 2);
    }

    /// One hundred simultaneous cold lookups for one key result in exactly one
    /// store call.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_misses_collapse_to_one_load() {
        use std::sync::Arc;

        let agent_id = agent(4);
        // The 50ms store delay keeps the leader busy long enough for the other
        // callers to queue up behind it.
        let slow_store = MemoryPolicyStore::with_policy(agent_id, sample_policy(7))
            .with_delay(Duration::from_millis(50));
        let shared = Arc::new(L1Cache::new(slow_store, Duration::from_secs(60)));

        // Spawn all 100 lookups up front (`collect` is eager), then join them.
        let tasks: Vec<_> = (0..100)
            .map(|_| {
                let cache = Arc::clone(&shared);
                tokio::spawn(async move { cache.get(agent_id).await })
            })
            .collect();
        for task in tasks {
            let policy = task.await.expect("task joined").expect("policy present");
            assert_eq!(policy.version, 7);
        }

        // All of the misses were served by a single inner load.
        assert_eq!(shared.inner().call_count(), 1);
    }
}
228}