1use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt;
15use std::hash::Hash;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::{Arc, RwLock};
18use std::time::{Duration, Instant};
19
/// Snapshot of a cache's lifetime counters, as returned by `Cache::stats`.
///
/// The counters only ever increase (nothing in `Cache` decrements or resets
/// them, not even `clear`) and are shared by every clone of the same cache.
/// The struct is three plain integers, so it is `Copy` and has an all-zero
/// `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CacheStats {
    /// Lookups through `get` (also used by `get_many` and
    /// `get_or_insert_with`) that found a live entry.
    pub hits: u64,
    /// Lookups through `get` that found no entry or an expired one.
    pub misses: u64,
    /// Entries removed by `set`/`set_with` to make room in a full cache.
    pub evictions: u64,
}
30
/// One stored value plus the metadata the cache keeps alongside it.
struct Entry<V> {
    /// Labels that `invalidate_by_tag` matches against.
    tags: HashSet<String>,
    /// Deadline after which the entry counts as expired; `None` never expires.
    expires_at: Option<Instant>,
    /// The cached payload handed back (cloned) on a hit.
    value: V,
}
36
37pub struct Cache<K, V>
39where
40 K: Eq + Hash + Clone,
41{
42 inner: Arc<RwLock<CacheInner<K, V>>>,
43 hits: Arc<AtomicU64>,
44 misses: Arc<AtomicU64>,
45 evictions: Arc<AtomicU64>,
46}
47
48struct CacheInner<K, V>
49where
50 K: Eq + Hash + Clone,
51{
52 items: HashMap<K, Entry<V>>,
53 order: VecDeque<K>,
54 max_size: usize,
55 default_ttl: Option<Duration>,
56}
57
58impl<K, V> Default for Cache<K, V>
59where
60 K: Eq + Hash + Clone,
61{
62 fn default() -> Self {
63 Self::new(100, None)
64 }
65}
66
67impl<K, V> fmt::Debug for Cache<K, V>
68where
69 K: Eq + Hash + Clone,
70{
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 let inner = self.inner.read().unwrap();
73 f.debug_struct("Cache")
74 .field("size", &inner.items.len())
75 .field("max_size", &inner.max_size)
76 .field("default_ttl", &inner.default_ttl)
77 .finish()
78 }
79}
80
81impl<K, V> Cache<K, V>
82where
83 K: Eq + Hash + Clone,
84{
85 pub fn new(max_size: usize, default_ttl: Option<Duration>) -> Self {
87 Self {
88 inner: Arc::new(RwLock::new(CacheInner {
89 items: HashMap::with_capacity(max_size),
90 order: VecDeque::with_capacity(max_size),
91 max_size,
92 default_ttl,
93 })),
94 hits: Arc::new(AtomicU64::new(0)),
95 misses: Arc::new(AtomicU64::new(0)),
96 evictions: Arc::new(AtomicU64::new(0)),
97 }
98 }
99
100 pub fn set(&self, key: K, value: V) {
102 self.set_with(key, value, None, &[]);
103 }
104
105 pub fn set_with(&self, key: K, value: V, ttl: Option<Duration>, tags: &[&str]) {
107 let mut inner = self.inner.write().unwrap();
108 let effective_ttl = ttl.or(inner.default_ttl);
109 let expires_at = effective_ttl.map(|d| Instant::now() + d);
110 let tag_set: HashSet<String> = tags.iter().map(|s| s.to_string()).collect();
111
112 if inner.items.contains_key(&key) {
113 inner.order.retain(|k| k != &key);
114 } else if inner.items.len() >= inner.max_size {
115 let mut evicted = false;
117 let now = Instant::now();
118 let expired_key = inner
119 .order
120 .iter()
121 .find(|k| {
122 inner
123 .items
124 .get(*k)
125 .is_some_and(|e| e.expires_at.is_some_and(|t| now > t))
126 })
127 .cloned();
128
129 if let Some(ek) = expired_key {
130 inner.items.remove(&ek);
131 inner.order.retain(|k| k != &ek);
132 self.evictions.fetch_add(1, Ordering::Relaxed);
133 evicted = true;
134 }
135
136 if !evicted {
137 if let Some(lru_key) = inner.order.pop_back() {
138 inner.items.remove(&lru_key);
139 self.evictions.fetch_add(1, Ordering::Relaxed);
140 }
141 }
142 }
143
144 inner.items.insert(
145 key.clone(),
146 Entry {
147 value,
148 expires_at,
149 tags: tag_set,
150 },
151 );
152 inner.order.push_front(key);
153 }
154
155 pub fn get(&self, key: &K) -> Option<V>
159 where
160 V: Clone,
161 {
162 let mut inner = self.inner.write().unwrap();
163 let entry = match inner.items.get(key) {
164 Some(e) => e,
165 None => {
166 self.misses.fetch_add(1, Ordering::Relaxed);
167 return None;
168 }
169 };
170
171 if let Some(expires_at) = entry.expires_at {
172 if Instant::now() > expires_at {
173 inner.items.remove(key);
174 inner.order.retain(|k| k != key);
175 self.misses.fetch_add(1, Ordering::Relaxed);
176 return None;
177 }
178 }
179
180 let value = entry.value.clone();
181 inner.order.retain(|k| k != key);
182 inner.order.push_front(key.clone());
183 self.hits.fetch_add(1, Ordering::Relaxed);
184 Some(value)
185 }
186
187 pub fn has(&self, key: &K) -> bool {
189 let mut inner = self.inner.write().unwrap();
190 let entry = match inner.items.get(key) {
191 Some(e) => e,
192 None => return false,
193 };
194
195 if let Some(expires_at) = entry.expires_at {
196 if Instant::now() > expires_at {
197 inner.items.remove(key);
198 inner.order.retain(|k| k != key);
199 return false;
200 }
201 }
202
203 true
204 }
205
206 pub fn delete(&self, key: &K) -> bool {
208 let mut inner = self.inner.write().unwrap();
209 if inner.items.remove(key).is_some() {
210 inner.order.retain(|k| k != key);
211 true
212 } else {
213 false
214 }
215 }
216
217 pub fn invalidate_by_tag(&self, tag: &str) -> usize {
219 let mut inner = self.inner.write().unwrap();
220 let keys: Vec<K> = inner
221 .items
222 .iter()
223 .filter(|(_, v)| v.tags.contains(tag))
224 .map(|(k, _)| k.clone())
225 .collect();
226 let count = keys.len();
227 for key in &keys {
228 inner.items.remove(key);
229 }
230 inner.order.retain(|k| !keys.contains(k));
231 count
232 }
233
234 pub fn clear(&self) {
236 let mut inner = self.inner.write().unwrap();
237 inner.items.clear();
238 inner.order.clear();
239 }
240
241 pub fn size(&self) -> usize {
243 self.inner.read().unwrap().items.len()
244 }
245
246 pub fn is_empty(&self) -> bool {
248 self.inner.read().unwrap().items.is_empty()
249 }
250
251 pub fn max_size(&self) -> usize {
253 self.inner.read().unwrap().max_size
254 }
255
256 pub fn keys(&self) -> Vec<K> {
258 let inner = self.inner.read().unwrap();
259 let now = Instant::now();
260 inner
261 .items
262 .iter()
263 .filter(|(_, entry)| {
264 entry.expires_at.map_or(true, |t| now <= t)
265 })
266 .map(|(k, _)| k.clone())
267 .collect()
268 }
269
270 pub fn remove_expired(&self) -> usize {
272 let mut inner = self.inner.write().unwrap();
273 let now = Instant::now();
274 let expired_keys: Vec<K> = inner
275 .items
276 .iter()
277 .filter(|(_, entry)| entry.expires_at.is_some_and(|t| now > t))
278 .map(|(k, _)| k.clone())
279 .collect();
280 let count = expired_keys.len();
281 for key in &expired_keys {
282 inner.items.remove(key);
283 }
284 inner.order.retain(|k| !expired_keys.contains(k));
285 count
286 }
287
288 pub fn get_or_insert_with<F>(&self, key: K, f: F) -> V
290 where
291 V: Clone,
292 F: FnOnce() -> V,
293 {
294 if let Some(val) = self.get(&key) {
296 return val;
297 }
298 let value = f();
300 self.set(key, value.clone());
301 value
302 }
303
304 pub fn stats(&self) -> CacheStats {
306 CacheStats {
307 hits: self.hits.load(Ordering::Relaxed),
308 misses: self.misses.load(Ordering::Relaxed),
309 evictions: self.evictions.load(Ordering::Relaxed),
310 }
311 }
312
313 pub fn get_many(&self, keys: &[K]) -> HashMap<K, V>
316 where
317 V: Clone,
318 {
319 let mut result = HashMap::with_capacity(keys.len());
320 for key in keys {
321 if let Some(val) = self.get(key) {
322 result.insert(key.clone(), val);
323 }
324 }
325 result
326 }
327
328 pub fn delete_where<F>(&self, predicate: F) -> usize
331 where
332 F: Fn(&K, &V) -> bool,
333 {
334 let mut inner = self.inner.write().unwrap();
335 let keys_to_remove: Vec<K> = inner
336 .items
337 .iter()
338 .filter(|(k, entry)| predicate(k, &entry.value))
339 .map(|(k, _)| k.clone())
340 .collect();
341 let count = keys_to_remove.len();
342 for key in &keys_to_remove {
343 inner.items.remove(key);
344 }
345 inner.order.retain(|k| !keys_to_remove.contains(k));
346 count
347 }
348
349 pub fn len(&self) -> usize {
351 self.size()
352 }
353}
354
355impl<K, V> Clone for Cache<K, V>
356where
357 K: Eq + Hash + Clone,
358{
359 fn clone(&self) -> Self {
360 Self {
361 inner: Arc::clone(&self.inner),
362 hits: Arc::clone(&self.hits),
363 misses: Arc::clone(&self.misses),
364 evictions: Arc::clone(&self.evictions),
365 }
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// TTL used by the expiry tests; `wait_past_ttl` sleeps well beyond it.
    const SHORT_TTL: Duration = Duration::from_millis(1);

    /// Blocks long enough that any entry stored with `SHORT_TTL` has expired.
    fn wait_past_ttl() {
        std::thread::sleep(Duration::from_millis(10));
    }

    #[test]
    fn test_set_and_get() {
        let cache = Cache::new(10, None);
        cache.set("key", "value");
        let fetched = cache.get(&"key");
        assert_eq!(fetched, Some("value"));
    }

    #[test]
    fn test_get_missing_key() {
        let cache: Cache<&str, &str> = Cache::new(10, None);
        assert!(cache.get(&"missing").is_none());
    }

    #[test]
    fn test_overwrite_value() {
        let cache = Cache::new(10, None);
        for &version in &["v1", "v2"] {
            cache.set("key", version);
        }
        // The second write replaces the first rather than adding an entry.
        assert_eq!(cache.get(&"key"), Some("v2"));
        assert_eq!(cache.size(), 1);
    }

    #[test]
    fn test_delete() {
        let cache = Cache::new(10, None);
        cache.set("key", "value");
        let removed_first = cache.delete(&"key");
        assert!(removed_first);
        assert!(cache.get(&"key").is_none());
        let removed_again = cache.delete(&"key");
        assert!(!removed_again);
    }

    #[test]
    fn test_has() {
        let cache = Cache::new(10, None);
        cache.set("key", "value");
        let (present, absent) = (cache.has(&"key"), cache.has(&"missing"));
        assert!(present);
        assert!(!absent);
    }

    #[test]
    fn test_clear() {
        let cache = Cache::new(10, None);
        for &(key, value) in &[("a", 1), ("b", 2)] {
            cache.set(key, value);
        }
        assert_eq!(cache.size(), 2);
        cache.clear();
        assert_eq!(0, cache.size());
    }

    #[test]
    fn test_lru_eviction() {
        let cache = Cache::new(3, None);
        for &(key, value) in &[("a", 1), ("b", 2), ("c", 3), ("d", 4)] {
            cache.set(key, value);
        }
        // Capacity is 3, so inserting "d" pushed out the oldest key, "a".
        assert!(cache.get(&"a").is_none());
        assert_eq!(cache.size(), 3);
    }

    #[test]
    fn test_lru_access_updates_order() {
        let cache = Cache::new(3, None);
        for &(key, value) in &[("a", 1), ("b", 2), ("c", 3)] {
            cache.set(key, value);
        }
        // Reading "a" refreshes it, leaving "b" as the least recently used key.
        let _ = cache.get(&"a");
        cache.set("d", 4);
        assert_eq!(cache.get(&"a"), Some(1));
        assert!(cache.get(&"b").is_none());
    }

    #[test]
    fn test_ttl_expiration() {
        let cache = Cache::new(10, None);
        cache.set_with("key", "value", Some(SHORT_TTL), &[]);
        wait_past_ttl();
        assert!(cache.get(&"key").is_none());
    }

    #[test]
    fn test_has_with_expired_ttl() {
        let cache = Cache::new(10, None);
        cache.set_with("key", "value", Some(SHORT_TTL), &[]);
        wait_past_ttl();
        assert!(!cache.has(&"key"));
    }

    #[test]
    fn test_default_ttl() {
        let cache = Cache::new(10, Some(SHORT_TTL));
        cache.set("key", "value");
        wait_past_ttl();
        assert!(cache.get(&"key").is_none());
    }

    #[test]
    fn test_tag_invalidation() {
        let cache = Cache::new(10, None);
        cache.set_with("a", 1, None, &["group1"]);
        cache.set_with("b", 2, None, &["group1", "group2"]);
        cache.set_with("c", 3, None, &["group2"]);

        assert_eq!(cache.invalidate_by_tag("group1"), 2);
        let survivors: Vec<Option<i32>> = ["a", "b", "c"].iter().map(|k| cache.get(k)).collect();
        assert_eq!(survivors, vec![None, None, Some(3)]);
    }

    #[test]
    fn test_clone_shares_state() {
        let cache = Cache::new(10, None);
        let handle = cache.clone();
        cache.set("key", "value");
        assert_eq!(handle.get(&"key"), Some("value"));
    }

    #[test]
    fn test_debug_impl() {
        let cache: Cache<&str, &str> = Cache::new(10, None);
        let rendered = format!("{:?}", cache);
        for needle in &["Cache", "max_size"] {
            assert!(rendered.contains(*needle));
        }
    }

    #[test]
    fn test_default_impl() {
        let cache = Cache::<String, String>::default();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_thread_safety() {
        let cache = Cache::new(100, None);
        let workers: Vec<_> = (0..10)
            .map(|i| {
                let shared = cache.clone();
                std::thread::spawn(move || shared.set(i, i * 10))
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(cache.size(), 10);
    }

    #[test]
    fn test_is_empty() {
        let cache: Cache<&str, &str> = Cache::new(10, None);
        assert!(cache.is_empty());
        cache.set("key", "value");
        assert!(!cache.is_empty());
    }

    #[test]
    fn test_max_size() {
        let cache = Cache::<&str, &str>::new(42, None);
        assert_eq!(42, cache.max_size());
    }

    #[test]
    fn test_keys() {
        let cache = Cache::new(10, None);
        cache.set("a", 1);
        cache.set("b", 2);
        let mut keys = cache.keys();
        keys.sort_unstable();
        assert_eq!(keys, ["a", "b"]);
    }

    #[test]
    fn test_keys_excludes_expired() {
        let cache = Cache::new(10, None);
        cache.set_with("fresh", 1, None, &[]);
        cache.set_with("expired", 2, Some(SHORT_TTL), &[]);
        wait_past_ttl();
        assert_eq!(cache.keys(), vec!["fresh"]);
    }

    #[test]
    fn test_remove_expired() {
        let cache = Cache::new(10, None);
        cache.set_with("fresh", 1, None, &[]);
        for &(key, value) in &[("stale1", 2), ("stale2", 3)] {
            cache.set_with(key, value, Some(SHORT_TTL), &[]);
        }
        wait_past_ttl();
        assert_eq!(cache.remove_expired(), 2);
        assert_eq!(cache.size(), 1);
        assert!(cache.has(&"fresh"));
    }

    #[test]
    fn test_get_or_insert_with_existing() {
        let cache = Cache::new(10, None);
        cache.set("key", 42);
        let found = cache.get_or_insert_with("key", || 99);
        assert_eq!(found, 42);
    }

    #[test]
    fn test_get_or_insert_with_missing() {
        let cache = Cache::new(10, None);
        let computed = cache.get_or_insert_with("key", || 99);
        assert_eq!(computed, 99);
        assert_eq!(cache.get(&"key"), Some(99));
    }

    #[test]
    fn test_get_or_insert_with_expired() {
        let cache = Cache::new(10, None);
        cache.set_with("key", 42, Some(SHORT_TTL), &[]);
        wait_past_ttl();
        let recomputed = cache.get_or_insert_with("key", || 99);
        assert_eq!(recomputed, 99);
    }

    #[test]
    fn test_stats_hits_and_misses() {
        let cache = Cache::new(10, None);
        cache.set("a", 1);
        cache.set("b", 2);

        // Two hits followed by two misses.
        for (key, expected) in [("a", Some(1)), ("b", Some(2)), ("c", None), ("d", None)].iter() {
            assert_eq!(cache.get(key), *expected);
        }

        let stats = cache.stats();
        assert_eq!(stats, CacheStats { hits: 2, misses: 2, evictions: 0 });
    }

    #[test]
    fn test_stats_evictions() {
        let cache = Cache::new(2, None);
        for &(key, value) in &[("a", 1), ("b", 2), ("c", 3)] {
            cache.set(key, value);
        }
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_stats_miss_on_expired() {
        let cache = Cache::new(10, None);
        cache.set_with("key", 1, Some(SHORT_TTL), &[]);
        wait_past_ttl();
        assert!(cache.get(&"key").is_none());

        let CacheStats { hits, misses, .. } = cache.stats();
        assert_eq!(misses, 1);
        assert_eq!(hits, 0);
    }

    #[test]
    fn test_get_many() {
        let cache = Cache::new(10, None);
        for &(key, value) in &[("a", 1), ("b", 2), ("c", 3)] {
            cache.set(key, value);
        }

        let found = cache.get_many(&["a", "c", "missing"]);
        assert_eq!(found.len(), 2);
        assert_eq!(found.get(&"a"), Some(&1));
        assert_eq!(found.get(&"c"), Some(&3));
    }

    #[test]
    fn test_get_many_empty() {
        let cache: Cache<&str, i32> = Cache::new(10, None);
        assert!(cache.get_many(&["a", "b"]).is_empty());
    }

    #[test]
    fn test_delete_where() {
        let cache = Cache::new(10, None);
        for &(key, value) in &[("a", 1), ("b", 20), ("c", 3), ("d", 40)] {
            cache.set(key, value);
        }

        let removed = cache.delete_where(|_key, value| *value >= 10);
        assert_eq!(removed, 2);
        assert_eq!(cache.size(), 2);
        for &(key, expected) in &[("a", Some(1)), ("c", Some(3)), ("b", None), ("d", None)] {
            assert_eq!(cache.get(&key), expected);
        }
    }

    #[test]
    fn test_delete_where_none_match() {
        let cache = Cache::new(10, None);
        for &(key, value) in &[("a", 1), ("b", 2)] {
            cache.set(key, value);
        }

        assert_eq!(cache.delete_where(|_key, value| *value > 100), 0);
        assert_eq!(cache.size(), 2);
    }

    #[test]
    fn test_len_alias() {
        let cache = Cache::new(10, None);
        assert_eq!(cache.len(), 0);
        for &(key, value) in &[("a", 1), ("b", 2)] {
            cache.set(key, value);
        }
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.size(), cache.len());
    }
}