1use crate::error::CheckpointError;
6use lru::LruCache;
7use std::num::NonZeroUsize;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11
/// A single cached value together with its optional expiry deadline.
struct CacheEntry {
    /// Instant at which the entry stops being served; `None` keeps it until evicted or removed.
    expires_at: Option<std::time::Instant>,
    /// Raw bytes handed back to callers on a cache hit.
    data: Vec<u8>,
}
20
21impl CacheEntry {
22 #[must_use]
24 fn is_expired(&self) -> bool {
25 self.expires_at
26 .is_some_and(|expires_at| std::time::Instant::now() >= expires_at)
27 }
28}
29
30#[async_trait::async_trait]
34pub trait BaseCache: Send + Sync + 'static {
35 async fn get(&self, namespace: &str, key: &str) -> Result<Option<Vec<u8>>, CheckpointError>;
41
42 async fn set(
48 &self,
49 namespace: &str,
50 key: &str,
51 value: Vec<u8>,
52 ttl: Option<Duration>,
53 ) -> Result<(), CheckpointError>;
54
55 async fn delete(&self, namespace: &str, key: &str) -> Result<(), CheckpointError>;
61
62 async fn clear(&self, namespace: Option<&str>) -> Result<(), CheckpointError>;
68}
69
70#[derive(Clone, Debug)]
75pub struct MemoryCache {
76 entries: Arc<RwLock<LruCache<String, CacheEntry>>>,
78
79 default_ttl: Option<Duration>,
81}
82
83impl MemoryCache {
84 #[must_use]
90 pub fn new(capacity: usize) -> Self {
91 Self {
92 entries: Arc::new(RwLock::new(LruCache::new(
93 NonZeroUsize::new(capacity).expect("capacity must be non-zero"),
94 ))),
95 default_ttl: None,
96 }
97 }
98
99 #[must_use]
105 pub fn with_ttl(capacity: usize, default_ttl: Duration) -> Self {
106 Self {
107 entries: Arc::new(RwLock::new(LruCache::new(
108 NonZeroUsize::new(capacity).expect("capacity must be non-zero"),
109 ))),
110 default_ttl: Some(default_ttl),
111 }
112 }
113
114 #[must_use]
116 fn build_key(namespace: &str, key: &str) -> String {
117 format!("{namespace}:{key}")
118 }
119
120 async fn purge_expired(&self) {
125 let mut cache = self.entries.write().await;
126 let expired_keys: Vec<String> = cache
127 .iter()
128 .filter(|(_, entry)| entry.is_expired())
129 .map(|(key, _)| key.clone())
130 .collect();
131
132 for key in expired_keys {
133 cache.pop(&key);
134 }
135 }
136
137 pub async fn stats(&self) -> (usize, usize) {
141 let cache = self.entries.read().await;
142 (cache.len(), cache.cap().get())
143 }
144}
145
146impl Default for MemoryCache {
147 fn default() -> Self {
148 Self::new(1000)
149 }
150}
151
152#[async_trait::async_trait]
153impl BaseCache for MemoryCache {
154 async fn get(&self, namespace: &str, key: &str) -> Result<Option<Vec<u8>>, CheckpointError> {
155 self.purge_expired().await;
157
158 let cache_key = Self::build_key(namespace, key);
159 {
160 let mut cache = self.entries.write().await;
161
162 if let Some(entry) = cache.get_mut(&cache_key) {
163 if entry.is_expired() {
164 cache.pop(&cache_key);
165 drop(cache);
166 return Ok(None);
167 }
168 let result = Ok(Some(entry.data.clone()));
169 drop(cache);
170 return result;
171 }
172 }
173
174 Ok(None)
175 }
176
177 async fn set(
178 &self,
179 namespace: &str,
180 key: &str,
181 value: Vec<u8>,
182 ttl: Option<Duration>,
183 ) -> Result<(), CheckpointError> {
184 let cache_key = Self::build_key(namespace, key);
185 let ttl = ttl.or(self.default_ttl);
186
187 let entry = CacheEntry {
188 data: value,
189 expires_at: ttl.map(|duration| std::time::Instant::now() + duration),
190 };
191
192 self.entries.write().await.put(cache_key, entry);
193
194 Ok(())
195 }
196
197 async fn delete(&self, namespace: &str, key: &str) -> Result<(), CheckpointError> {
198 let cache_key = Self::build_key(namespace, key);
199 self.entries.write().await.pop(&cache_key);
200 Ok(())
201 }
202
203 async fn clear(&self, namespace: Option<&str>) -> Result<(), CheckpointError> {
204 if let Some(ns) = namespace {
205 let prefix = format!("{ns}:");
207 let mut cache = self.entries.write().await;
208 let keys_to_remove: Vec<String> = cache
209 .iter()
210 .filter(|(key, _)| key.starts_with(&prefix))
211 .map(|(key, _)| key.clone())
212 .collect();
213
214 for key in keys_to_remove {
215 cache.pop(&key);
216 }
217 } else {
218 self.entries.write().await.clear();
220 }
221
222 Ok(())
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts `data` under `(namespace, key)` with no per-call TTL; fails the test on error.
    async fn put(cache: &MemoryCache, namespace: &str, key: &str, data: &[u8]) {
        cache
            .set(namespace, key, data.to_vec(), None)
            .await
            .unwrap();
    }

    /// Reads `(namespace, key)`; fails the test on error.
    async fn fetch(cache: &MemoryCache, namespace: &str, key: &str) -> Option<Vec<u8>> {
        cache.get(namespace, key).await.unwrap()
    }

    #[tokio::test]
    async fn test_memory_cache_set_get() {
        let cache = MemoryCache::new(10);
        put(&cache, "ns1", "key1", b"hello").await;

        assert_eq!(fetch(&cache, "ns1", "key1").await, Some(b"hello".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_cache_miss() {
        let cache = MemoryCache::new(10);

        assert!(fetch(&cache, "ns1", "nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_delete() {
        let cache = MemoryCache::new(10);
        put(&cache, "ns1", "key1", b"hello").await;

        cache.delete("ns1", "key1").await.unwrap();

        assert!(fetch(&cache, "ns1", "key1").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_ttl() {
        let cache = MemoryCache::with_ttl(10, Duration::from_millis(100));
        // No per-call TTL is given, so the 100 ms default applies.
        put(&cache, "ns1", "key1", b"hello").await;

        // Still fresh right after insertion.
        assert_eq!(fetch(&cache, "ns1", "key1").await, Some(b"hello".to_vec()));

        // Outlive the default TTL.
        tokio::time::sleep(Duration::from_millis(150)).await;

        assert!(fetch(&cache, "ns1", "key1").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_clear_namespace() {
        let cache = MemoryCache::new(10);
        put(&cache, "ns1", "key1", b"data1").await;
        put(&cache, "ns2", "key2", b"data2").await;

        cache.clear(Some("ns1")).await.unwrap();

        assert!(fetch(&cache, "ns1", "key1").await.is_none());
        assert_eq!(fetch(&cache, "ns2", "key2").await, Some(b"data2".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_cache_clear_all() {
        let cache = MemoryCache::new(10);
        put(&cache, "ns1", "key1", b"data1").await;
        put(&cache, "ns2", "key2", b"data2").await;

        cache.clear(None).await.unwrap();

        assert!(fetch(&cache, "ns1", "key1").await.is_none());
        assert!(fetch(&cache, "ns2", "key2").await.is_none());
    }

    #[tokio::test]
    async fn test_memory_cache_lru_eviction() {
        let cache = MemoryCache::new(2);
        put(&cache, "ns1", "key1", b"data1").await;
        put(&cache, "ns1", "key2", b"data2").await;

        // Touch key1 so that key2 becomes the least recently used entry.
        fetch(&cache, "ns1", "key1").await;

        // A third key overflows capacity 2 and should push out key2.
        put(&cache, "ns1", "key3", b"data3").await;

        assert_eq!(fetch(&cache, "ns1", "key1").await, Some(b"data1".to_vec()));
        assert!(fetch(&cache, "ns1", "key2").await.is_none());
        assert_eq!(fetch(&cache, "ns1", "key3").await, Some(b"data3".to_vec()));
    }

    #[tokio::test]
    async fn test_memory_cache_stats() {
        let cache = MemoryCache::new(100);
        put(&cache, "ns1", "key1", b"data1").await;
        put(&cache, "ns1", "key2", b"data2").await;

        let (size, capacity) = cache.stats().await;
        assert_eq!(size, 2);
        assert_eq!(capacity, 100);
    }
}
379
380