1use std::hash::{Hash, Hasher};
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18
19use bytes::Bytes;
20use moka::Expiry;
21use moka::sync::Cache;
22
23use crate::Cache as CacheTrait;
24
25#[derive(Clone, Debug, Eq)]
29struct CacheKey(Bytes);
30
31impl CacheKey {
32 #[inline]
33 fn new(key: &[u8]) -> Self {
34 Self(Bytes::copy_from_slice(key))
35 }
36
37 #[inline]
38 fn as_slice(&self) -> &[u8] {
39 &self.0
40 }
41}
42
43impl PartialEq for CacheKey {
44 #[inline]
45 fn eq(&self, other: &Self) -> bool {
46 self.0 == other.0
47 }
48}
49
50impl Hash for CacheKey {
51 #[inline]
52 fn hash<H: Hasher>(&self, state: &mut H) {
53 self.0.hash(state);
54 }
55}
56
/// Zero-copy view of a caller-supplied key, used for lookups so that reads
/// and deletes need not allocate an owned `CacheKey`.
#[derive(PartialEq, Eq)]
struct BorrowedKey<'a>(&'a [u8]);

impl Hash for BorrowedKey<'_> {
    /// Hashes the wrapped bytes as a plain `[u8]`; this must stay identical to
    /// how `CacheKey` hashes, or borrowed lookups will miss.
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        <[u8] as Hash>::hash(self.0, state);
    }
}
70
71impl equivalent::Equivalent<CacheKey> for BorrowedKey<'_> {
72 #[inline]
73 fn equivalent(&self, key: &CacheKey) -> bool {
74 self.0 == key.as_slice()
75 }
76}
77
/// A cached value together with the absolute instant at which it expires.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// The cached payload; cloning a `Bytes` shares the buffer rather than copying it.
    value: Bytes,
    /// Absolute expiry deadline. `PerEntryExpiry` turns it into the remaining
    /// lifetime that moka schedules eviction with.
    expires_at: Instant,
}
86
87struct PerEntryExpiry;
92
93impl Expiry<CacheKey, CacheEntry> for PerEntryExpiry {
94 fn expire_after_create(
96 &self,
97 _key: &CacheKey,
98 value: &CacheEntry,
99 _current_time: Instant,
100 ) -> Option<Duration> {
101 let now = Instant::now();
102 if value.expires_at > now {
103 Some(value.expires_at.duration_since(now))
104 } else {
105 Some(Duration::ZERO)
107 }
108 }
109
110 fn expire_after_read(
112 &self,
113 _key: &CacheKey,
114 _value: &CacheEntry,
115 _current_time: Instant,
116 current_duration: Option<Duration>,
117 _last_modified_at: Instant,
118 ) -> Option<Duration> {
119 current_duration
121 }
122
123 fn expire_after_update(
125 &self,
126 _key: &CacheKey,
127 value: &CacheEntry,
128 _current_time: Instant,
129 _current_duration: Option<Duration>,
130 ) -> Option<Duration> {
131 let now = Instant::now();
132 if value.expires_at > now {
133 Some(value.expires_at.duration_since(now))
134 } else {
135 Some(Duration::ZERO)
136 }
137 }
138}
139
140#[derive(Clone, Debug)]
168pub struct LocalCache {
169 inner: Arc<Cache<CacheKey, CacheEntry>>,
170}
171
172impl LocalCache {
173 pub fn new(capacity: u64, default_ttl: Duration) -> Self {
189 let cache = Cache::builder()
190 .max_capacity(capacity.max(1))
191 .time_to_live(default_ttl)
193 .expire_after(PerEntryExpiry)
195 .build();
196 Self {
197 inner: Arc::new(cache),
198 }
199 }
200
201 #[inline]
211 pub fn contains_sync(&self, key: &[u8]) -> bool {
212 self.inner.contains_key(&BorrowedKey(key))
214 }
215
216 #[inline]
225 pub fn get_sync(&self, key: &[u8]) -> Option<Bytes> {
226 self.inner.get(&BorrowedKey(key)).map(|entry| entry.value)
227 }
228}
229
230impl CacheTrait for LocalCache {
231 fn set_nx_px(
244 &self,
245 key: &[u8],
246 value: &[u8],
247 ttl: Duration,
248 ) -> impl Future<Output = anyhow::Result<bool>> + Send {
249 let cache_key = CacheKey::new(key);
251 let entry = CacheEntry {
252 value: Bytes::copy_from_slice(value),
253 expires_at: Instant::now() + ttl,
254 };
255 let inner = Arc::clone(&self.inner);
256
257 async move {
258 let result = inner.entry(cache_key).or_insert(entry);
261
262 Ok(result.is_fresh())
265 }
266 }
267
268 fn set(
276 &self,
277 key: &[u8],
278 value: &[u8],
279 ttl: Duration,
280 ) -> impl Future<Output = anyhow::Result<()>> + Send {
281 let cache_key = CacheKey::new(key);
283 let entry = CacheEntry {
284 value: Bytes::copy_from_slice(value),
285 expires_at: Instant::now() + ttl,
286 };
287 let inner = Arc::clone(&self.inner);
288
289 async move {
290 inner.insert(cache_key, entry);
291 Ok(())
292 }
293 }
294
295 fn get(&self, key: &[u8]) -> impl Future<Output = anyhow::Result<Option<Vec<u8>>>> + Send {
302 let result = self.inner.get(&BorrowedKey(key)).map(|entry| entry.value);
304 async move { Ok(result.map(|bytes| bytes.to_vec())) }
305 }
306
307 fn del(&self, key: &[u8]) -> impl Future<Output = anyhow::Result<()>> + Send {
309 self.inner.invalidate(&BorrowedKey(key));
311 async move { Ok(()) }
312 }
313}
314
#[cfg(test)]
mod tests {
    use std::collections::hash_map::RandomState;
    use std::hash::BuildHasher;

    use equivalent::Equivalent;

    use super::*;

    #[tokio::test]
    async fn test_set_and_get() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        cache
            .set(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();

        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        let result = cache.get(b"nonexistent").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_set_nx_px_new_key() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        let was_set = cache
            .set_nx_px(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(was_set, "Expected key to be set (new key)");

        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_set_nx_px_existing_key() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        let was_set1 = cache
            .set_nx_px(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(was_set1);

        let was_set2 = cache
            .set_nx_px(b"key1", b"value2", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(!was_set2, "Expected key NOT to be set (key exists)");

        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_set_overwrites_existing_value() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        cache
            .set(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set(b"key1", b"value2", Duration::from_secs(60))
            .await
            .unwrap();

        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, Some(b"value2".to_vec()));
    }

    #[tokio::test]
    async fn test_del() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        cache
            .set(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();
        cache.del(b"key1").await.unwrap();

        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_del_nonexistent_is_ok() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        cache.del(b"missing").await.unwrap();
        assert!(!cache.contains_sync(b"missing"));
    }

    #[tokio::test]
    async fn test_empty_key_and_value() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        cache
            .set(b"", b"", Duration::from_secs(60))
            .await
            .unwrap();

        assert!(cache.contains_sync(b""));
        assert_eq!(cache.get(b"").await.unwrap(), Some(Vec::new()));
    }

    #[tokio::test]
    async fn test_contains_sync() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        assert!(!cache.contains_sync(b"key1"));

        cache
            .set(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();

        assert!(cache.contains_sync(b"key1"));
    }

    #[tokio::test]
    async fn test_get_sync() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        assert!(cache.get_sync(b"key1").is_none());

        cache
            .set(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();

        let result = cache.get_sync(b"key1");
        assert_eq!(result, Some(Bytes::from_static(b"value1")));
    }

    #[test]
    fn test_borrowed_key_matches_owned_key() {
        // Borrowed lookups (get / contains_key / invalidate) only find stored
        // entries if both key forms hash and compare identically.
        let state = RandomState::new();
        let samples: [&[u8]; 4] = [b"", b"k", b"key1", b"\x00\xff\x01"];
        for raw in samples {
            let owned = CacheKey::new(raw);
            let borrowed = BorrowedKey(raw);
            assert_eq!(state.hash_one(&owned), state.hash_one(&borrowed));
            assert!(borrowed.equivalent(&owned));
        }
        assert!(!BorrowedKey(b"key1").equivalent(&CacheKey::new(b"key2")));
    }

    #[tokio::test]
    async fn test_per_entry_ttl_respected() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        cache
            .set(b"short_ttl", b"value", Duration::from_millis(50))
            .await
            .unwrap();

        let result = cache.get(b"short_ttl").await.unwrap();
        assert_eq!(result, Some(b"value".to_vec()));

        tokio::time::sleep(Duration::from_millis(100)).await;

        let result = cache.get(b"short_ttl").await.unwrap();
        assert_eq!(result, None, "Entry should have expired after TTL");
    }

    #[tokio::test]
    async fn test_different_ttls_for_different_keys() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        cache
            .set(b"short", b"value1", Duration::from_millis(50))
            .await
            .unwrap();
        cache
            .set(b"long", b"value2", Duration::from_secs(10))
            .await
            .unwrap();

        assert!(cache.get(b"short").await.unwrap().is_some());
        assert!(cache.get(b"long").await.unwrap().is_some());

        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(
            cache.get(b"short").await.unwrap().is_none(),
            "Short TTL entry should have expired"
        );
        assert!(
            cache.get(b"long").await.unwrap().is_some(),
            "Long TTL entry should still exist"
        );
    }

    #[tokio::test]
    async fn test_set_nx_px_ttl_respected() {
        let cache = LocalCache::new(100, Duration::from_secs(60));

        let was_set = cache
            .set_nx_px(b"key", b"value", Duration::from_millis(50))
            .await
            .unwrap();
        assert!(was_set);

        tokio::time::sleep(Duration::from_millis(100)).await;

        let result = cache.get(b"key").await.unwrap();
        assert_eq!(result, None, "Entry should have expired after TTL");

        let was_set_again = cache
            .set_nx_px(b"key", b"new_value", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(
            was_set_again,
            "Should be able to set after previous entry expired"
        );

        let result = cache.get(b"key").await.unwrap();
        assert_eq!(result, Some(b"new_value".to_vec()));
    }
}