1use std::sync::Arc;
4use std::time::Duration;
5
6use aa_core::storage::Result;
7use dashmap::mapref::entry::Entry;
8use dashmap::DashMap;
9use tokio::sync::Notify;
10
11use crate::cached_value::CachedValue;
12use crate::source::CacheSource;
13
14pub struct L1Cache<S: CacheSource> {
22 inner: S,
23 entries: Arc<DashMap<S::Key, CachedValue<S::Value>>>,
24 inflight: Arc<DashMap<S::Key, Arc<Notify>>>,
25 ttl: Duration,
26}
27
28impl<S: CacheSource> L1Cache<S> {
29 pub fn new(inner: S, ttl: Duration) -> Self {
31 Self {
32 inner,
33 entries: Arc::new(DashMap::new()),
34 inflight: Arc::new(DashMap::new()),
35 ttl,
36 }
37 }
38
39 pub fn inner(&self) -> &S {
41 &self.inner
42 }
43
44 #[must_use]
47 pub fn len(&self) -> usize {
48 self.entries.len()
49 }
50
51 #[must_use]
53 pub fn is_empty(&self) -> bool {
54 self.entries.is_empty()
55 }
56
57 pub fn clear(&self) {
59 self.entries.clear();
60 }
61
62 pub fn invalidate(&self, key: &S::Key) -> bool {
68 self.entries.remove(key).is_some()
69 }
70
71 fn fresh(&self, key: &S::Key) -> Option<S::Value> {
73 let entry = self.entries.get(key)?;
74 if entry.is_expired(self.ttl) {
75 None
76 } else {
77 Some(entry.value.clone())
78 }
79 }
80
81 pub async fn get(&self, key: S::Key) -> Result<S::Value> {
91 loop {
92 if let Some(value) = self.fresh(&key) {
94 return Ok(value);
95 }
96
97 let follower = match self.inflight.entry(key.clone()) {
99 Entry::Vacant(slot) => {
100 slot.insert(Arc::new(Notify::new()));
101 None
102 }
103 Entry::Occupied(slot) => Some(slot.get().clone()),
104 };
105
106 match follower {
107 None => {
109 let result = self.inner.load(&key).await;
110 if let Ok(ref value) = result {
111 self.entries.insert(key.clone(), CachedValue::new(value.clone()));
112 }
113 if let Some((_, notify)) = self.inflight.remove(&key) {
114 notify.notify_waiters();
115 }
116 return result;
117 }
118 Some(notify) => {
120 let waiter = notify.notified();
121 tokio::pin!(waiter);
122 waiter.as_mut().enable();
127 if let Some(value) = self.fresh(&key) {
128 return Ok(value);
129 }
130 waiter.await;
131 }
132 }
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use aa_core::storage::AgentId;

    use crate::testing::{sample_policy, MemoryPolicyStore};
    use crate::L1Cache;

    /// Deterministic agent id whose 16 bytes all equal `seed`.
    fn agent(seed: u8) -> AgentId {
        let bytes = [seed; 16];
        AgentId::from_bytes(bytes)
    }

    #[tokio::test]
    async fn miss_populates_then_serves_from_cache() {
        let id = agent(1);
        let cache = L1Cache::new(
            MemoryPolicyStore::with_policy(id, sample_policy(1)),
            Duration::from_secs(60),
        );

        // Cold read: goes to the store and populates the cache.
        let loaded = cache.get(id).await.expect("policy present");
        assert_eq!(loaded.version, 1);
        assert_eq!(cache.inner().call_count(), 1);
        assert_eq!(cache.len(), 1);

        // Warm read: served from memory, the store is not consulted again.
        let cached = cache.get(id).await.expect("policy present");
        assert_eq!(cached.version, 1);
        assert_eq!(cache.inner().call_count(), 1);
    }

    #[tokio::test]
    async fn expired_entry_is_treated_as_a_miss() {
        let ttl = Duration::from_millis(20);
        let id = agent(2);
        let cache = L1Cache::new(MemoryPolicyStore::with_policy(id, sample_policy(1)), ttl);

        cache.get(id).await.expect("policy present");
        assert_eq!(cache.inner().call_count(), 1);

        // Wait well past the TTL so the cached entry is stale.
        tokio::time::sleep(ttl * 2).await;

        cache.get(id).await.expect("policy present");
        assert_eq!(cache.inner().call_count(), 2);
    }

    #[tokio::test]
    async fn invalidate_evicts_the_cached_entry() {
        let id = agent(3);
        let cache = L1Cache::new(
            MemoryPolicyStore::with_policy(id, sample_policy(1)),
            Duration::from_secs(60),
        );

        cache.get(id).await.expect("policy present");
        assert_eq!(cache.len(), 1);

        // The first invalidation removes the entry; a second finds nothing to remove.
        assert!(cache.invalidate(&id));
        assert!(cache.is_empty());
        assert!(!cache.invalidate(&id));

        // With the entry gone, the next read must reload from the store.
        cache.get(id).await.expect("policy present");
        assert_eq!(cache.inner().call_count(), 2);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_misses_collapse_to_one_load() {
        const CALLERS: usize = 100;

        let id = agent(4);
        let slow_store = MemoryPolicyStore::with_policy(id, sample_policy(7))
            .with_delay(Duration::from_millis(50));
        let cache = Arc::new(L1Cache::new(slow_store, Duration::from_secs(60)));

        let handles: Vec<_> = (0..CALLERS)
            .map(|_| {
                let cache = Arc::clone(&cache);
                tokio::spawn(async move { cache.get(id).await })
            })
            .collect();

        for handle in handles {
            let policy = handle.await.expect("task joined").expect("policy present");
            assert_eq!(policy.version, 7);
        }

        // Every caller was served by a single store load.
        assert_eq!(cache.inner().call_count(), 1);
    }
}