1use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
2
3use serde::{Serialize, de::DeserializeOwned};
4use tokio::sync::Mutex;
5
6use crate::cache::{CacheKey, CacheResult, CacheStats, CacheStore, jitter_ttl};
7
/// Sentinel bytes stored in place of a value to remember that the loader
/// returned `None` (negative caching). A bare identifier is not valid JSON, so
/// no `serde_json`-encoded value can be byte-equal to it.
const NOT_FOUND_PLACEHOLDER: &[u8] = "__rs_zero_not_found__".as_bytes();
9
/// TTL settings used by the cache-aside client when writing entries.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheAsideConfig {
    /// Base lifetime of a cached (found) value, before jitter is applied.
    pub value_ttl: Duration,
    /// Base lifetime of a negative ("not found") entry, before jitter is applied.
    pub not_found_ttl: Duration,
    /// Jitter ratio passed to `jitter_ttl` together with each base TTL.
    pub ttl_jitter_ratio: f64,
}

impl Default for CacheAsideConfig {
    /// Five-minute values, one-minute negative entries, jitter ratio `0.05`.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        const ONE_MINUTE: Duration = Duration::from_secs(60);

        Self {
            ttl_jitter_ratio: 0.05,
            not_found_ttl: ONE_MINUTE,
            value_ttl: FIVE_MINUTES,
        }
    }
}
30
31#[derive(Debug, Clone)]
33pub struct CacheAside<S> {
34 store: S,
35 config: CacheAsideConfig,
36 stats: CacheStats,
37 locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
38 #[cfg(feature = "observability")]
39 metrics: Option<crate::observability::MetricsRegistry>,
40}
41
42impl<S> CacheAside<S> {
43 pub fn new(store: S, config: CacheAsideConfig) -> Self {
45 Self {
46 store,
47 config,
48 stats: CacheStats::default(),
49 locks: Arc::new(Mutex::new(HashMap::new())),
50 #[cfg(feature = "observability")]
51 metrics: None,
52 }
53 }
54
55 pub fn stats(&self) -> CacheStats {
57 self.stats.clone()
58 }
59
60 #[cfg(feature = "observability")]
62 pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
63 self.metrics = Some(metrics);
64 self
65 }
66
67 fn record_event(&self, operation: &str, result: &str) {
68 #[cfg(feature = "observability")]
69 crate::observability::cache::record_cache_event(
70 self.metrics.as_ref(),
71 "cache_aside",
72 operation,
73 result,
74 );
75
76 #[cfg(not(feature = "observability"))]
77 {
78 let _ = (operation, result);
79 }
80 }
81}
82
83impl<S> CacheAside<S>
84where
85 S: CacheStore,
86{
87 pub async fn delete(&self, key: &CacheKey) -> CacheResult<()> {
89 match self.store.delete(key).await {
90 Ok(()) => {
91 self.record_event("delete", "success");
92 Ok(())
93 }
94 Err(error) => {
95 self.stats.record_delete_error();
96 self.record_event("delete", "error");
97 Err(error)
98 }
99 }
100 }
101
102 pub async fn get_or_load_json<T, F, Fut>(
104 &self,
105 key: &CacheKey,
106 loader: F,
107 ) -> CacheResult<Option<T>>
108 where
109 T: DeserializeOwned + Serialize + Send + Sync,
110 F: FnOnce() -> Fut + Send,
111 Fut: Future<Output = CacheResult<Option<T>>> + Send,
112 {
113 if let Some(value) = self.read_cached_json(key).await? {
114 return Ok(value);
115 }
116
117 self.stats.record_miss();
118 self.record_event("get", "miss");
119 let rendered = key.render();
120 let lock = self.key_lock(&rendered).await;
121 let guard = lock.lock().await;
122
123 if let Some(value) = self.read_cached_json(key).await? {
124 drop(guard);
125 self.release_key_lock(&rendered, &lock).await;
126 return Ok(value);
127 }
128
129 let loaded = loader().await.inspect_err(|_| {
130 self.stats.record_loader_error();
131 self.record_event("load", "error");
132 })?;
133 match loaded.as_ref() {
134 Some(value) => self.write_json(key, value).await?,
135 None => self.write_not_found(key).await?,
136 }
137
138 drop(guard);
139 self.release_key_lock(&rendered, &lock).await;
140 Ok(loaded)
141 }
142
143 async fn read_cached_json<T>(&self, key: &CacheKey) -> CacheResult<Option<Option<T>>>
144 where
145 T: DeserializeOwned + Send,
146 {
147 let Some(bytes) = self.store.get_raw(key).await? else {
148 return Ok(None);
149 };
150
151 if bytes == NOT_FOUND_PLACEHOLDER {
152 self.stats.record_negative_hit();
153 self.record_event("get", "negative_hit");
154 return Ok(Some(None));
155 }
156
157 match serde_json::from_slice(&bytes) {
158 Ok(value) => {
159 self.stats.record_hit();
160 self.record_event("get", "hit");
161 Ok(Some(Some(value)))
162 }
163 Err(_) => {
164 self.record_event("get", "corrupt");
165 if self.store.delete(key).await.is_err() {
166 self.stats.record_delete_error();
167 self.record_event("delete", "corrupt_error");
168 } else {
169 self.record_event("delete", "corrupt");
170 }
171 Ok(None)
172 }
173 }
174 }
175
176 async fn write_json<T>(&self, key: &CacheKey, value: &T) -> CacheResult<()>
177 where
178 T: Serialize + Sync,
179 {
180 let ttl = jitter_ttl(
181 self.config.value_ttl,
182 self.config.ttl_jitter_ratio,
183 key.render(),
184 );
185 let bytes = serde_json::to_vec(value)?;
186 match self.store.set_raw(key, bytes, Some(ttl)).await {
187 Ok(()) => {
188 self.record_event("set", "success");
189 Ok(())
190 }
191 Err(error) => {
192 self.stats.record_set_error();
193 self.record_event("set", "error");
194 Err(error)
195 }
196 }
197 }
198
199 async fn write_not_found(&self, key: &CacheKey) -> CacheResult<()> {
200 let ttl = jitter_ttl(
201 self.config.not_found_ttl,
202 self.config.ttl_jitter_ratio,
203 key.render(),
204 );
205 match self
206 .store
207 .set_raw(key, NOT_FOUND_PLACEHOLDER.to_vec(), Some(ttl))
208 .await
209 {
210 Ok(()) => {
211 self.record_event("set", "negative");
212 Ok(())
213 }
214 Err(error) => {
215 self.stats.record_set_error();
216 self.record_event("set", "error");
217 Err(error)
218 }
219 }
220 }
221
222 async fn key_lock(&self, rendered: &str) -> Arc<Mutex<()>> {
223 let mut locks = self.locks.lock().await;
224 locks
225 .entry(rendered.to_string())
226 .or_insert_with(|| Arc::new(Mutex::new(())))
227 .clone()
228 }
229
230 async fn release_key_lock(&self, rendered: &str, lock: &Arc<Mutex<()>>) {
231 let mut locks = self.locks.lock().await;
232 if locks
233 .get(rendered)
234 .is_some_and(|current| Arc::ptr_eq(current, lock) && Arc::strong_count(lock) == 2)
235 {
236 locks.remove(rendered);
237 }
238 }
239}
240
241#[cfg(test)]
242mod tests {
243 use std::{
244 sync::{
245 Arc,
246 atomic::{AtomicUsize, Ordering},
247 },
248 time::Duration,
249 };
250
251 use crate::cache::{CacheAside, CacheAsideConfig, CacheKey, CacheStore, MemoryCacheStore};
252
253 #[tokio::test]
254 async fn cache_aside_merges_concurrent_misses() {
255 let client = CacheAside::new(
256 MemoryCacheStore::new(),
257 CacheAsideConfig {
258 value_ttl: Duration::from_secs(60),
259 ..CacheAsideConfig::default()
260 },
261 );
262 let key = CacheKey::new("app", ["user", "42"]);
263 let calls = Arc::new(AtomicUsize::new(0));
264
265 let mut handles = Vec::new();
266 for _ in 0..8 {
267 let client = client.clone();
268 let key = key.clone();
269 let calls = calls.clone();
270 handles.push(tokio::spawn(async move {
271 client
272 .get_or_load_json(&key, || async move {
273 calls.fetch_add(1, Ordering::SeqCst);
274 tokio::time::sleep(Duration::from_millis(20)).await;
275 Ok(Some(serde_json::json!({"id":42})))
276 })
277 .await
278 .expect("load")
279 }));
280 }
281
282 for handle in handles {
283 assert_eq!(handle.await.expect("join").expect("value")["id"], 42);
284 }
285 assert_eq!(calls.load(Ordering::SeqCst), 1);
286 }
287
288 #[tokio::test]
289 async fn cache_aside_uses_negative_cache() {
290 let client = CacheAside::new(MemoryCacheStore::new(), CacheAsideConfig::default());
291 let key = CacheKey::new("app", ["missing"]);
292 let calls = Arc::new(AtomicUsize::new(0));
293
294 for _ in 0..2 {
295 let calls = calls.clone();
296 let value: Option<serde_json::Value> = client
297 .get_or_load_json(&key, || async move {
298 calls.fetch_add(1, Ordering::SeqCst);
299 Ok(None)
300 })
301 .await
302 .expect("load");
303 assert!(value.is_none());
304 }
305
306 assert_eq!(calls.load(Ordering::SeqCst), 1);
307 assert_eq!(client.stats().snapshot().negative_hits, 1);
308 }
309
310 #[tokio::test]
311 async fn cache_aside_deletes_corrupt_value_and_reloads() {
312 let store = MemoryCacheStore::new();
313 let client = CacheAside::new(store.clone(), CacheAsideConfig::default());
314 let key = CacheKey::new("app", ["corrupt"]);
315 let calls = Arc::new(AtomicUsize::new(0));
316
317 store
318 .set_raw(&key, b"{not-json".to_vec(), None)
319 .await
320 .expect("set corrupt");
321
322 let value: Option<serde_json::Value> = client
323 .get_or_load_json(&key, || {
324 let calls = calls.clone();
325 async move {
326 calls.fetch_add(1, Ordering::SeqCst);
327 Ok(Some(serde_json::json!({"fresh": true})))
328 }
329 })
330 .await
331 .expect("reload");
332
333 assert_eq!(value.expect("value")["fresh"], true);
334 assert_eq!(calls.load(Ordering::SeqCst), 1);
335 let cached: serde_json::Value = store.get_json(&key).await.expect("cache").expect("value");
336 assert_eq!(cached["fresh"], true);
337 }
338}