1use crate::{CacheBackend, CacheConfig, CacheResult, CacheStats};
4use async_trait::async_trait;
5use dashmap::DashMap;
6use lru::LruCache;
7use parking_lot::{RwLock, Mutex};
8use std::{
9 num::NonZeroUsize,
10 sync::{
11 atomic::{AtomicU64, Ordering},
12 Arc,
13 },
14 time::{Duration, Instant},
15};
16
/// A single cached value plus the metadata needed for TTL expiry and
/// LRU / statistics bookkeeping.
#[derive(Debug)]
struct CacheEntry {
    // Raw cached bytes.
    data: Vec<u8>,
    // Insertion timestamp.
    created_at: Instant,
    // Absolute expiry deadline; `None` means the entry never expires.
    expires_at: Option<Instant>,
    // Number of reads; atomic so `access` needs no exclusive lock.
    access_count: AtomicU64,
    // Most recent read time; behind a lock because `Instant` is not atomic.
    last_accessed: RwLock<Instant>,
}
26
27impl Clone for CacheEntry {
28 fn clone(&self) -> Self {
29 Self {
30 data: self.data.clone(),
31 created_at: self.created_at,
32 expires_at: self.expires_at,
33 access_count: AtomicU64::new(self.access_count.load(Ordering::Relaxed)),
34 last_accessed: RwLock::new(*self.last_accessed.read()),
35 }
36 }
37}
38
39impl CacheEntry {
40 fn new(data: Vec<u8>, ttl: Option<Duration>) -> Self {
41 let now = Instant::now();
42 Self {
43 data,
44 created_at: now,
45 expires_at: ttl.map(|ttl| now + ttl),
46 access_count: AtomicU64::new(1),
47 last_accessed: RwLock::new(now),
48 }
49 }
50
51 fn is_expired(&self) -> bool {
52 self.expires_at.map_or(false, |exp| Instant::now() > exp)
53 }
54
55 fn access(&self) -> Vec<u8> {
56 self.access_count.fetch_add(1, Ordering::Relaxed);
57 *self.last_accessed.write() = Instant::now();
58 self.data.clone()
59 }
60
61 fn size(&self) -> usize {
62 self.data.len() + std::mem::size_of::<Self>()
63 }
64}
65
/// Tracks key recency for eviction ordering. Kept separate from the entry
/// map so `DashMap` lookups stay sharded/concurrent while recency updates
/// serialize on this single lock.
struct LruTracker {
    cache: RwLock<LruCache<String, ()>>,
}
71
72impl LruTracker {
73 fn new() -> Self {
74 let capacity = NonZeroUsize::new(1000).expect("1000 is non-zero");
76 Self {
77 cache: RwLock::new(LruCache::new(capacity)),
78 }
79 }
80
81 fn access(&self, key: &str) {
84 let mut cache = self.cache.write();
85
86 if cache.peek(key).is_none() && cache.len() == cache.cap().get() {
88 let new_capacity = NonZeroUsize::new(cache.cap().get() * 2)
89 .expect("doubled capacity should be non-zero");
90 cache.resize(new_capacity);
91 }
92
93 cache.put(key.to_string(), ());
96 }
97
98 fn remove(&self, key: &str) {
101 let mut cache = self.cache.write();
102 cache.pop(key);
103 }
104
105 fn least_recently_used(&self) -> Option<String> {
108 let cache = self.cache.read();
109 cache.iter().next_back().map(|(key, _)| key.clone())
110 }
111
112 fn clear(&self) {
114 let mut cache = self.cache.write();
115 cache.clear();
116 }
117}
118
/// In-process cache backend: a concurrent entry map plus LRU tracking,
/// limits taken from `CacheConfig`, and hit/miss statistics.
pub struct MemoryBackend {
    // Key -> entry storage; sharded for concurrent access.
    entries: DashMap<String, CacheEntry>,
    // Recency ordering used to pick eviction victims.
    lru: LruTracker,
    // Entry-count / memory limits consulted on writes.
    config: CacheConfig,
    // Aggregate counters reported by `stats()`.
    stats: Arc<Mutex<CacheStats>>,
}
126
127impl MemoryBackend {
128 pub fn new(config: CacheConfig) -> Self {
130 Self {
131 entries: DashMap::new(),
132 lru: LruTracker::new(),
133 config,
134 stats: Arc::new(Mutex::new(CacheStats::default())),
135 }
136 }
137
138 fn memory_usage(&self) -> usize {
140 self.entries.iter().map(|entry| entry.value().size()).sum()
141 }
142
143 fn should_evict(&self) -> bool {
145 if let Some(max_entries) = self.config.get_max_entries() {
146 if self.entries.len() >= *max_entries {
147 return true;
148 }
149 }
150
151 if let Some(max_memory) = self.config.get_max_memory() {
152 if self.memory_usage() >= *max_memory {
153 return true;
154 }
155 }
156
157 false
158 }
159
160 async fn evict(&self) -> CacheResult<()> {
162 let expired_keys: Vec<String> = self.entries
164 .iter()
165 .filter_map(|entry| {
166 if entry.value().is_expired() {
167 Some(entry.key().clone())
168 } else {
169 None
170 }
171 })
172 .collect();
173
174 for key in expired_keys {
175 if let Some((_, removed_entry)) = self.entries.remove(&key) {
176 self.lru.remove(&key);
177 let mut stats = self.stats.lock();
178 stats.total_keys = stats.total_keys.saturating_sub(1);
179 stats.memory_usage = stats.memory_usage.saturating_sub(removed_entry.size() as u64);
180 }
181 }
182
183 while self.should_evict() {
185 if let Some(lru_key) = self.lru.least_recently_used() {
186 if let Some((_, removed_entry)) = self.entries.remove(&lru_key) {
187 self.lru.remove(&lru_key);
188 let mut stats = self.stats.lock();
189 stats.total_keys = stats.total_keys.saturating_sub(1);
190 stats.memory_usage = stats.memory_usage.saturating_sub(removed_entry.size() as u64);
191 } else {
192 break;
193 }
194 } else {
195 break;
196 }
197 }
198
199 Ok(())
200 }
201
202 async fn cleanup_expired(&self) {
204 let expired_keys: Vec<String> = self.entries
205 .iter()
206 .filter_map(|entry| {
207 if entry.value().is_expired() {
208 Some(entry.key().clone())
209 } else {
210 None
211 }
212 })
213 .collect();
214
215 for key in expired_keys {
216 if let Some((_, removed_entry)) = self.entries.remove(&key) {
217 self.lru.remove(&key);
218 let mut stats = self.stats.lock();
219 stats.total_keys = stats.total_keys.saturating_sub(1);
220 stats.memory_usage = stats.memory_usage.saturating_sub(removed_entry.size() as u64);
221 }
222 }
223 }
224}
225
226#[async_trait]
227impl CacheBackend for MemoryBackend {
228 async fn get(&self, key: &str) -> CacheResult<Option<Vec<u8>>> {
229 if rand::random::<f64>() < 0.01 { self.cleanup_expired().await;
232 }
233
234 if let Some(entry) = self.entries.get(key) {
235 if entry.is_expired() {
236 let entry_size = entry.size() as u64;
238 drop(entry);
239
240 if self.entries.remove(key).is_some() {
242 self.lru.remove(key);
243
244 let mut stats = self.stats.lock();
246 stats.misses += 1;
247 stats.total_keys = stats.total_keys.saturating_sub(1);
248 stats.memory_usage = stats.memory_usage.saturating_sub(entry_size);
249 }
250
251 return Ok(None);
252 }
253
254 let data = entry.access();
256 self.lru.access(key);
257
258 self.stats.lock().hits += 1;
260
261 Ok(Some(data))
262 } else {
263 self.stats.lock().misses += 1;
265
266 Ok(None)
267 }
268 }
269
270 async fn put(&self, key: &str, value: Vec<u8>, ttl: Option<Duration>) -> CacheResult<()> {
271 if self.should_evict() {
273 self.evict().await?;
274 }
275
276 let entry = CacheEntry::new(value, ttl);
277 let entry_size = entry.size() as u64;
278
279 let old_entry = self.entries.insert(key.to_string(), entry);
281
282 let mut stats = self.stats.lock();
283 if let Some(old_entry) = old_entry {
284 let old_size = old_entry.size() as u64;
286 stats.memory_usage = stats.memory_usage.saturating_sub(old_size) + entry_size;
287 } else {
288 stats.total_keys += 1;
290 stats.memory_usage += entry_size;
291 }
292
293 self.lru.access(key);
295
296 Ok(())
297 }
298
299 async fn forget(&self, key: &str) -> CacheResult<bool> {
300 if let Some((_, removed_entry)) = self.entries.remove(key) {
301 self.lru.remove(key);
302
303 let mut stats = self.stats.lock();
305 stats.total_keys = stats.total_keys.saturating_sub(1);
306 stats.memory_usage = stats.memory_usage.saturating_sub(removed_entry.size() as u64);
307
308 Ok(true)
309 } else {
310 Ok(false)
311 }
312 }
313
314 async fn exists(&self, key: &str) -> CacheResult<bool> {
315 if let Some(entry) = self.entries.get(key) {
316 if entry.is_expired() {
317 let entry_size = entry.size() as u64;
319 drop(entry);
320
321 if self.entries.remove(key).is_some() {
323 self.lru.remove(key);
324
325 let mut stats = self.stats.lock();
326 stats.total_keys = stats.total_keys.saturating_sub(1);
327 stats.memory_usage = stats.memory_usage.saturating_sub(entry_size);
328 }
329
330 return Ok(false);
331 }
332 Ok(true)
333 } else {
334 Ok(false)
335 }
336 }
337
338 async fn flush(&self) -> CacheResult<()> {
339 self.entries.clear();
340
341 self.lru.clear();
343
344 let mut stats = self.stats.lock();
346 stats.total_keys = 0;
347 stats.memory_usage = 0;
348
349 Ok(())
350 }
351
352 async fn get_many(&self, keys: &[&str]) -> CacheResult<Vec<Option<Vec<u8>>>> {
353 let mut results = Vec::with_capacity(keys.len());
354
355 for key in keys {
356 results.push(self.get(key).await?);
357 }
358
359 Ok(results)
360 }
361
362 async fn put_many(&self, entries: &[(&str, Vec<u8>, Option<Duration>)]) -> CacheResult<()> {
363 for (key, value, ttl) in entries {
364 self.put(key, value.clone(), *ttl).await?;
365 }
366
367 Ok(())
368 }
369
370 async fn forget_many(&self, keys: &[&str]) -> CacheResult<usize> {
371 let mut removed_count = 0;
372 let mut total_freed_memory = 0u64;
373
374 for key in keys {
376 if let Some((_, removed_entry)) = self.entries.remove(*key) {
377 self.lru.remove(key);
378 total_freed_memory += removed_entry.size() as u64;
379 removed_count += 1;
380 }
381 }
382
383 if removed_count > 0 {
385 let mut stats = self.stats.lock();
386 stats.total_keys = stats.total_keys.saturating_sub(removed_count as u64);
387 stats.memory_usage = stats.memory_usage.saturating_sub(total_freed_memory);
388 }
389
390 Ok(removed_count)
391 }
392
393 async fn stats(&self) -> CacheResult<CacheStats> {
394 Ok(self.stats.lock().clone())
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// put/get round-trip, existence checks, and single-key removal.
    #[tokio::test]
    async fn test_memory_backend_basic_operations() {
        let backend = MemoryBackend::new(CacheConfig::default());

        backend.put("test", b"value".to_vec(), Some(Duration::from_secs(60))).await.unwrap();
        let result = backend.get("test").await.unwrap();
        assert_eq!(result, Some(b"value".to_vec()));

        assert!(backend.exists("test").await.unwrap());
        assert!(!backend.exists("nonexistent").await.unwrap());

        assert!(backend.forget("test").await.unwrap());
        assert!(!backend.exists("test").await.unwrap());
    }

    /// An entry with a 50ms TTL must be gone after sleeping 100ms, for both
    /// `exists` and `get`.
    #[tokio::test]
    async fn test_memory_backend_ttl() {
        let backend = MemoryBackend::new(CacheConfig::default());

        backend.put("ttl_test", b"value".to_vec(), Some(Duration::from_millis(50))).await.unwrap();

        assert!(backend.exists("ttl_test").await.unwrap());

        sleep(Duration::from_millis(100)).await;

        assert!(!backend.exists("ttl_test").await.unwrap());
        let result = backend.get("ttl_test").await.unwrap();
        assert_eq!(result, None);
    }

    /// With max_entries=2, inserting a third key must evict the least
    /// recently used one (key2, since key1 was touched by the `get`).
    #[tokio::test]
    async fn test_memory_backend_lru_eviction() {
        let config = CacheConfig::builder()
            .max_entries_limit(2)
            .build_config();
        let backend = MemoryBackend::new(config);

        backend.put("key1", b"value1".to_vec(), None).await.unwrap();
        backend.put("key2", b"value2".to_vec(), None).await.unwrap();

        // Promote key1 so key2 becomes the LRU victim.
        backend.get("key1").await.unwrap();

        backend.put("key3", b"value3".to_vec(), None).await.unwrap();

        assert!(backend.exists("key1").await.unwrap());
        assert!(!backend.exists("key2").await.unwrap());
        assert!(backend.exists("key3").await.unwrap());
    }

    /// Hit/miss/key/memory counters track puts and gets; with one hit and
    /// one miss the hit ratio is exactly 0.5.
    #[tokio::test]
    async fn test_memory_backend_stats() {
        let backend = MemoryBackend::new(CacheConfig::default());

        let stats = backend.stats().await.unwrap();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.total_keys, 0);

        backend.put("test1", b"value1".to_vec(), None).await.unwrap();
        backend.put("test2", b"value2".to_vec(), None).await.unwrap();

        let stats = backend.stats().await.unwrap();
        assert_eq!(stats.total_keys, 2);
        assert!(stats.memory_usage > 0);

        backend.get("test1").await.unwrap();
        let stats = backend.stats().await.unwrap();
        assert_eq!(stats.hits, 1);

        backend.get("nonexistent").await.unwrap();
        let stats = backend.stats().await.unwrap();
        assert_eq!(stats.misses, 1);

        assert_eq!(stats.hit_ratio(), 0.5);
    }

    /// Bulk removal: partial key sets, nonexistent keys, and the empty
    /// slice all return accurate removal counts and stats.
    #[tokio::test]
    async fn test_memory_backend_forget_many() {
        let backend = MemoryBackend::new(CacheConfig::default());

        backend.put("key1", b"value1".to_vec(), None).await.unwrap();
        backend.put("key2", b"value2".to_vec(), None).await.unwrap();
        backend.put("key3", b"value3".to_vec(), None).await.unwrap();
        backend.put("key4", b"value4".to_vec(), None).await.unwrap();

        assert!(backend.exists("key1").await.unwrap());
        assert!(backend.exists("key2").await.unwrap());
        assert!(backend.exists("key3").await.unwrap());
        assert!(backend.exists("key4").await.unwrap());

        let initial_stats = backend.stats().await.unwrap();
        assert_eq!(initial_stats.total_keys, 4);

        let keys_to_remove = ["key1", "key2", "key3"];
        let removed_count = backend.forget_many(&keys_to_remove).await.unwrap();
        assert_eq!(removed_count, 3);

        assert!(!backend.exists("key1").await.unwrap());
        assert!(!backend.exists("key2").await.unwrap());
        assert!(!backend.exists("key3").await.unwrap());
        assert!(backend.exists("key4").await.unwrap());

        let final_stats = backend.stats().await.unwrap();
        assert_eq!(final_stats.total_keys, 1);
        assert!(final_stats.memory_usage < initial_stats.memory_usage);

        // Nonexistent keys remove nothing.
        let nonexistent_keys = ["nonexistent1", "nonexistent2"];
        let removed_count = backend.forget_many(&nonexistent_keys).await.unwrap();
        assert_eq!(removed_count, 0);

        // Empty input is a no-op.
        let empty_keys: Vec<&str> = vec![];
        let removed_count = backend.forget_many(&empty_keys).await.unwrap();
        assert_eq!(removed_count, 0);
    }

    /// `flush` removes every entry and zeroes the size counters.
    #[tokio::test]
    async fn test_memory_backend_flush() {
        let backend = MemoryBackend::new(CacheConfig::default());

        backend.put("test1", b"value1".to_vec(), None).await.unwrap();
        backend.put("test2", b"value2".to_vec(), None).await.unwrap();

        assert!(backend.exists("test1").await.unwrap());
        assert!(backend.exists("test2").await.unwrap());

        backend.flush().await.unwrap();

        assert!(!backend.exists("test1").await.unwrap());
        assert!(!backend.exists("test2").await.unwrap());

        let stats = backend.stats().await.unwrap();
        assert_eq!(stats.total_keys, 0);
        assert_eq!(stats.memory_usage, 0);
    }

    /// Regression test for tracker/map desynchronization: inserting more
    /// keys than the LruTracker's initial capacity (1000) must grow the
    /// tracker rather than silently drop keys, so no unconfigured eviction
    /// ever occurs.
    #[tokio::test]
    async fn test_lru_tracker_consistency() {
        let backend = MemoryBackend::new(CacheConfig::default());

        // 1200 inserts exceed the tracker's initial 1000-key capacity.
        for i in 0..1200 {
            let key = format!("consistency_test_{}", i);
            let value = format!("value_{}", i).into_bytes();
            backend.put(&key, value, None).await.unwrap();
        }

        for i in 0..1200 {
            let key = format!("consistency_test_{}", i);
            assert!(backend.exists(&key).await.unwrap(),
                "Key {} should exist but was not found", key);
        }

        // Reshuffle recency ordering of the oldest 100 keys.
        for i in (0..100).rev() {
            let key = format!("consistency_test_{}", i);
            backend.get(&key).await.unwrap();
        }

        // No limits are configured, so nothing may have been evicted.
        for i in 0..1200 {
            let key = format!("consistency_test_{}", i);
            assert!(backend.exists(&key).await.unwrap(),
                "Key {} should still exist after LRU access", key);
        }
    }
}