//! Cache-aside strategy (`rs_zero/cache/aside.rs`): read-through cache helper
//! with per-key singleflight and negative caching.

1use std::{collections::HashMap, future::Future, sync::Arc, time::Duration};
2
3use serde::{Serialize, de::DeserializeOwned};
4use tokio::sync::Mutex;
5
6use crate::cache::{CacheKey, CacheResult, CacheStats, CacheStore, jitter_ttl};
7
/// Sentinel bytes stored for negative ("not found") cache entries.
///
/// `read_cached_json` compares raw payloads against this value before
/// attempting JSON deserialization. A genuine cached value whose serialized
/// bytes equal this string would be misread as a negative entry; the
/// `__rs_zero_` prefix makes such a collision unlikely in practice.
const NOT_FOUND_PLACEHOLDER: &[u8] = b"__rs_zero_not_found__";
9
/// Cache-aside strategy configuration.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheAsideConfig {
    /// TTL for positive values loaded from the source of truth.
    pub value_ttl: Duration,
    /// TTL for negative (not-found) entries; keeps absent keys from being
    /// re-queried on every read while still expiring reasonably soon.
    pub not_found_ttl: Duration,
    /// Per-key TTL jitter ratio, passed to `jitter_ttl` so that entries
    /// written together do not all expire at the same instant.
    pub ttl_jitter_ratio: f64,
}
20
21impl Default for CacheAsideConfig {
22    fn default() -> Self {
23        Self {
24            value_ttl: Duration::from_secs(300),
25            not_found_ttl: Duration::from_secs(60),
26            ttl_jitter_ratio: 0.05,
27        }
28    }
29}
30
/// High-level cache-aside client with singleflight and negative caching.
#[derive(Debug, Clone)]
pub struct CacheAside<S> {
    // Underlying raw byte store (anything implementing `CacheStore`).
    store: S,
    // TTL and jitter settings applied on every cache write.
    config: CacheAsideConfig,
    // Cache statistics handle; `stats()` documents these as shared, so the
    // counters are presumably reference-counted internally — cloned
    // `CacheAside` handles report into the same stats. TODO confirm.
    stats: CacheStats,
    // Per-key singleflight locks, keyed by the rendered cache key. Entries
    // are inserted in `key_lock` and pruned in `release_key_lock` once no
    // other task still holds a clone.
    locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
    #[cfg(feature = "observability")]
    // Optional metrics sink attached via `with_metrics`.
    metrics: Option<crate::observability::MetricsRegistry>,
}
41
42impl<S> CacheAside<S> {
43    /// Creates a cache-aside helper around the given store.
44    pub fn new(store: S, config: CacheAsideConfig) -> Self {
45        Self {
46            store,
47            config,
48            stats: CacheStats::default(),
49            locks: Arc::new(Mutex::new(HashMap::new())),
50            #[cfg(feature = "observability")]
51            metrics: None,
52        }
53    }
54
55    /// Returns shared cache statistics.
56    pub fn stats(&self) -> CacheStats {
57        self.stats.clone()
58    }
59
60    /// Attaches a metrics registry to this cache-aside helper.
61    #[cfg(feature = "observability")]
62    pub fn with_metrics(mut self, metrics: crate::observability::MetricsRegistry) -> Self {
63        self.metrics = Some(metrics);
64        self
65    }
66
67    fn record_event(&self, operation: &str, result: &str) {
68        #[cfg(feature = "observability")]
69        crate::observability::cache::record_cache_event(
70            self.metrics.as_ref(),
71            "cache_aside",
72            operation,
73            result,
74        );
75
76        #[cfg(not(feature = "observability"))]
77        {
78            let _ = (operation, result);
79        }
80    }
81}
82
83impl<S> CacheAside<S>
84where
85    S: CacheStore,
86{
87    /// Deletes a cached value and records delete failures in shared stats.
88    pub async fn delete(&self, key: &CacheKey) -> CacheResult<()> {
89        match self.store.delete(key).await {
90            Ok(()) => {
91                self.record_event("delete", "success");
92                Ok(())
93            }
94            Err(error) => {
95                self.stats.record_delete_error();
96                self.record_event("delete", "error");
97                Err(error)
98            }
99        }
100    }
101
102    /// Reads JSON from cache or loads it once for concurrent misses.
103    pub async fn get_or_load_json<T, F, Fut>(
104        &self,
105        key: &CacheKey,
106        loader: F,
107    ) -> CacheResult<Option<T>>
108    where
109        T: DeserializeOwned + Serialize + Send + Sync,
110        F: FnOnce() -> Fut + Send,
111        Fut: Future<Output = CacheResult<Option<T>>> + Send,
112    {
113        if let Some(value) = self.read_cached_json(key).await? {
114            return Ok(value);
115        }
116
117        self.stats.record_miss();
118        self.record_event("get", "miss");
119        let rendered = key.render();
120        let lock = self.key_lock(&rendered).await;
121        let guard = lock.lock().await;
122
123        if let Some(value) = self.read_cached_json(key).await? {
124            drop(guard);
125            self.release_key_lock(&rendered, &lock).await;
126            return Ok(value);
127        }
128
129        let loaded = loader().await.inspect_err(|_| {
130            self.stats.record_loader_error();
131            self.record_event("load", "error");
132        })?;
133        match loaded.as_ref() {
134            Some(value) => self.write_json(key, value).await?,
135            None => self.write_not_found(key).await?,
136        }
137
138        drop(guard);
139        self.release_key_lock(&rendered, &lock).await;
140        Ok(loaded)
141    }
142
143    async fn read_cached_json<T>(&self, key: &CacheKey) -> CacheResult<Option<Option<T>>>
144    where
145        T: DeserializeOwned + Send,
146    {
147        let Some(bytes) = self.store.get_raw(key).await? else {
148            return Ok(None);
149        };
150
151        if bytes == NOT_FOUND_PLACEHOLDER {
152            self.stats.record_negative_hit();
153            self.record_event("get", "negative_hit");
154            return Ok(Some(None));
155        }
156
157        match serde_json::from_slice(&bytes) {
158            Ok(value) => {
159                self.stats.record_hit();
160                self.record_event("get", "hit");
161                Ok(Some(Some(value)))
162            }
163            Err(_) => {
164                self.record_event("get", "corrupt");
165                if self.store.delete(key).await.is_err() {
166                    self.stats.record_delete_error();
167                    self.record_event("delete", "corrupt_error");
168                } else {
169                    self.record_event("delete", "corrupt");
170                }
171                Ok(None)
172            }
173        }
174    }
175
176    async fn write_json<T>(&self, key: &CacheKey, value: &T) -> CacheResult<()>
177    where
178        T: Serialize + Sync,
179    {
180        let ttl = jitter_ttl(
181            self.config.value_ttl,
182            self.config.ttl_jitter_ratio,
183            key.render(),
184        );
185        let bytes = serde_json::to_vec(value)?;
186        match self.store.set_raw(key, bytes, Some(ttl)).await {
187            Ok(()) => {
188                self.record_event("set", "success");
189                Ok(())
190            }
191            Err(error) => {
192                self.stats.record_set_error();
193                self.record_event("set", "error");
194                Err(error)
195            }
196        }
197    }
198
199    async fn write_not_found(&self, key: &CacheKey) -> CacheResult<()> {
200        let ttl = jitter_ttl(
201            self.config.not_found_ttl,
202            self.config.ttl_jitter_ratio,
203            key.render(),
204        );
205        match self
206            .store
207            .set_raw(key, NOT_FOUND_PLACEHOLDER.to_vec(), Some(ttl))
208            .await
209        {
210            Ok(()) => {
211                self.record_event("set", "negative");
212                Ok(())
213            }
214            Err(error) => {
215                self.stats.record_set_error();
216                self.record_event("set", "error");
217                Err(error)
218            }
219        }
220    }
221
222    async fn key_lock(&self, rendered: &str) -> Arc<Mutex<()>> {
223        let mut locks = self.locks.lock().await;
224        locks
225            .entry(rendered.to_string())
226            .or_insert_with(|| Arc::new(Mutex::new(())))
227            .clone()
228    }
229
230    async fn release_key_lock(&self, rendered: &str, lock: &Arc<Mutex<()>>) {
231        let mut locks = self.locks.lock().await;
232        if locks
233            .get(rendered)
234            .is_some_and(|current| Arc::ptr_eq(current, lock) && Arc::strong_count(lock) == 2)
235        {
236            locks.remove(rendered);
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use crate::cache::{CacheAside, CacheAsideConfig, CacheKey, CacheStore, MemoryCacheStore};

    /// Eight concurrent misses on the same key must collapse to a single
    /// loader invocation (singleflight), with every task observing the value.
    #[tokio::test]
    async fn cache_aside_merges_concurrent_misses() {
        let client = CacheAside::new(
            MemoryCacheStore::new(),
            CacheAsideConfig {
                value_ttl: Duration::from_secs(60),
                ..CacheAsideConfig::default()
            },
        );
        let key = CacheKey::new("app", ["user", "42"]);
        // Counts how many times the loader actually runs across all tasks.
        let calls = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..8 {
            let client = client.clone();
            let key = key.clone();
            let calls = calls.clone();
            handles.push(tokio::spawn(async move {
                client
                    .get_or_load_json(&key, || async move {
                        calls.fetch_add(1, Ordering::SeqCst);
                        // Sleep widens the miss window so all tasks pile up
                        // behind the per-key lock instead of hitting cache.
                        tokio::time::sleep(Duration::from_millis(20)).await;
                        Ok(Some(serde_json::json!({"id":42})))
                    })
                    .await
                    .expect("load")
            }));
        }

        for handle in handles {
            assert_eq!(handle.await.expect("join").expect("value")["id"], 42);
        }
        // Exactly one loader call proves the singleflight merge worked.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// A `None` from the loader is negatively cached: the second read is
    /// served from the placeholder entry without re-invoking the loader.
    #[tokio::test]
    async fn cache_aside_uses_negative_cache() {
        let client = CacheAside::new(MemoryCacheStore::new(), CacheAsideConfig::default());
        let key = CacheKey::new("app", ["missing"]);
        let calls = Arc::new(AtomicUsize::new(0));

        for _ in 0..2 {
            let calls = calls.clone();
            let value: Option<serde_json::Value> = client
                .get_or_load_json(&key, || async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok(None)
                })
                .await
                .expect("load");
            assert!(value.is_none());
        }

        // First iteration loads (miss); second hits the negative entry.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(client.stats().snapshot().negative_hits, 1);
    }

    /// A non-JSON payload in the store is evicted, the loader runs once to
    /// repair the entry, and the fresh value ends up cached.
    #[tokio::test]
    async fn cache_aside_deletes_corrupt_value_and_reloads() {
        let store = MemoryCacheStore::new();
        let client = CacheAside::new(store.clone(), CacheAsideConfig::default());
        let key = CacheKey::new("app", ["corrupt"]);
        let calls = Arc::new(AtomicUsize::new(0));

        // Seed deliberately invalid JSON directly through the raw store.
        store
            .set_raw(&key, b"{not-json".to_vec(), None)
            .await
            .expect("set corrupt");

        let value: Option<serde_json::Value> = client
            .get_or_load_json(&key, || {
                let calls = calls.clone();
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok(Some(serde_json::json!({"fresh": true})))
                }
            })
            .await
            .expect("reload");

        assert_eq!(value.expect("value")["fresh"], true);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        // The repaired value must now be readable through the store itself.
        let cached: serde_json::Value = store.get_json(&key).await.expect("cache").expect("value");
        assert_eq!(cached["fresh"], true);
    }
}