1use std::sync::Mutex;
2use std::time::{Duration, Instant};
3
/// Time-to-live, in seconds, used by [`Cache::default_ttl`] (and therefore by
/// `Cache::default()`).
const DEFAULT_TTL_SECS: u64 = 5;

/// Map from caller-supplied key to its cached entry.
type EntryMap<T> = std::collections::HashMap<String, CacheEntry<T>>;

/// A thread-safe, string-keyed memoizing cache whose entries expire a fixed
/// time-to-live after they are stored.
///
/// Lookups go through [`Cache::get_or_compute`]. The internal lock is *not* held
/// while the compute closure runs. Concurrent callers that miss on the same key
/// may therefore each run the computation, and whichever stores its result last
/// wins.
pub struct Cache<T> {
    /// Cached entries, guarded by a mutex so the cache can be shared via `&self`.
    entries: Mutex<EntryMap<T>>,
    /// How long an entry stays fresh after it is inserted.
    ttl: Duration,
}

/// A cached value together with the moment it was stored.
struct CacheEntry<T> {
    value: T,
    /// Insertion time. Storing this instead of a precomputed `inserted + ttl`
    /// expiry avoids `Instant + Duration` arithmetic, which panics on overflow
    /// for very large TTLs such as `u64::MAX` seconds.
    inserted_at: Instant,
}

impl<T> CacheEntry<T> {
    /// Returns `true` while strictly less than `ttl` has elapsed between this
    /// entry's insertion and `now`.
    fn is_fresh(&self, now: Instant, ttl: Duration) -> bool {
        // Saturates to zero if `now` predates the insertion (another thread may
        // have inserted after the caller sampled the clock).
        now.saturating_duration_since(self.inserted_at) < ttl
    }
}

impl<T> Cache<T> {
    /// Creates an empty cache whose entries stay fresh for `ttl_secs` seconds.
    ///
    /// A TTL of `0` disables caching: every lookup recomputes.
    pub fn new(ttl_secs: u64) -> Self {
        Self::with_ttl(Duration::from_secs(ttl_secs))
    }

    /// Creates an empty cache with an arbitrary time-to-live, including the
    /// sub-second values that [`Cache::new`] cannot express.
    pub fn with_ttl(ttl: Duration) -> Self {
        Self {
            entries: Mutex::new(EntryMap::new()),
            ttl,
        }
    }

    /// Creates an empty cache using the default TTL (`DEFAULT_TTL_SECS` seconds).
    pub fn default_ttl() -> Self {
        Self::new(DEFAULT_TTL_SECS)
    }

    /// Locks the entry map.
    ///
    /// A poisoned mutex is recovered rather than propagated. One caller that
    /// panicked while holding the lock therefore does not disable the cache for
    /// everyone else.
    fn lock_entries(&self) -> std::sync::MutexGuard<'_, EntryMap<T>> {
        self.entries.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Returns the cached value for `key` if it is still fresh. Otherwise calls
    /// `compute`, caches a clone of its result, and returns that result.
    ///
    /// The lock is released while `compute` runs, so a slow computation does not
    /// block lookups of other keys. A new entry's TTL starts when it is inserted,
    /// i.e. after `compute` returns. A computation slower than the TTL therefore
    /// still yields a usable cache entry.
    ///
    /// # Panics
    ///
    /// Propagates any panic raised by `compute` or by `T::clone`.
    pub fn get_or_compute<F>(&self, key: &str, compute: F) -> T
    where
        F: FnOnce() -> T,
        T: Clone,
    {
        {
            let mut entries = self.lock_entries();
            if let Some(entry) = entries.get(key) {
                if entry.is_fresh(Instant::now(), self.ttl) {
                    return entry.value.clone();
                }
                // Drop the stale value now instead of keeping it alive for the
                // whole recomputation.
                entries.remove(key);
            }
        }

        let value = compute();

        // Build (and clone into) the entry before re-taking the lock to keep the
        // critical section short. The timestamp is taken only now, after
        // `compute` has finished.
        let entry = CacheEntry {
            value: value.clone(),
            inserted_at: Instant::now(),
        };
        self.lock_entries().insert(key.to_owned(), entry);
        value
    }

    /// Removes every entry whose TTL has elapsed.
    ///
    /// `get_or_compute` only discards a stale entry when that same key is looked
    /// up again. Callers that use many short-lived keys can call this
    /// periodically to bound memory use.
    pub fn purge_expired(&self) {
        let now = Instant::now();
        let ttl = self.ttl;
        self.lock_entries()
            .retain(|_, entry| entry.is_fresh(now, ttl));
    }

    /// Removes the entry for `key`, if present, so the next lookup recomputes it.
    pub fn invalidate(&self, key: &str) {
        self.lock_entries().remove(key);
    }

    /// Removes every entry from the cache.
    pub fn clear(&self) {
        self.lock_entries().clear();
    }
}
75
76impl<T> Default for Cache<T> {
77 fn default() -> Self {
78 Self::default_ttl()
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_cache_basic() {
        let cache = Cache::new(1);
        let call_count = Cell::new(0);
        // Captures `call_count` by shared reference, so the closure is `Copy`
        // and can be passed to `get_or_compute` more than once.
        let get_value = || {
            call_count.set(call_count.get() + 1);
            42
        };

        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);

        // A second lookup within the TTL is served from the cache.
        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);
    }

    #[test]
    fn test_cache_expiry() {
        let cache = Cache::new(1);
        let call_count = Cell::new(0);
        let get_value = || {
            call_count.set(call_count.get() + 1);
            42
        };

        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);

        // Sleep past the one-second TTL so the entry goes stale.
        thread::sleep(Duration::from_millis(1100));

        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 2);
    }

    #[test]
    fn test_zero_ttl_always_recomputes() {
        let cache = Cache::new(0);
        let call_count = Cell::new(0);
        let get_value = || {
            call_count.set(call_count.get() + 1);
            call_count.get()
        };

        assert_eq!(cache.get_or_compute("key", get_value), 1);
        assert_eq!(cache.get_or_compute("key", get_value), 2);
    }

    #[test]
    fn test_invalidate_forces_recompute() {
        let cache = Cache::new(60);
        assert_eq!(cache.get_or_compute("a", || 1), 1);
        assert_eq!(cache.get_or_compute("b", || 10), 10);

        cache.invalidate("a");
        assert_eq!(cache.get_or_compute("a", || 2), 2);
        // Other keys are untouched.
        assert_eq!(cache.get_or_compute("b", || 20), 10);

        // Invalidating an absent key is a no-op.
        cache.invalidate("missing");
    }

    #[test]
    fn test_clear_removes_all_entries() {
        let cache = Cache::new(60);
        cache.get_or_compute("a", || 1);
        cache.get_or_compute("b", || 10);

        cache.clear();
        assert_eq!(cache.get_or_compute("a", || 2), 2);
        assert_eq!(cache.get_or_compute("b", || 20), 20);
    }
}