1use std::sync::Mutex;
2use std::time::{Duration, Instant};
3
/// Time-to-live, in seconds, applied by [`Cache::default_ttl`].
const DEFAULT_TTL_SECS: u64 = 5;

/// The map guarded by a [`Cache`]'s mutex, keyed by the caller-supplied string key.
type EntryMap<T> = std::collections::HashMap<String, CacheEntry<T>>;

/// A thread-safe, string-keyed memoization cache whose entries expire after a fixed TTL.
///
/// Values are produced on demand by [`Cache::get_or_compute`] and handed out as clones.
/// An expired entry is removed when its key is next looked up; [`Cache::purge_expired`],
/// [`Cache::invalidate`] and [`Cache::clear`] remove entries eagerly.
#[derive(Debug)]
pub struct Cache<T> {
    entries: Mutex<EntryMap<T>>,
    /// How long an entry stays fresh, measured from the start of the call that stored it.
    ttl: Duration,
}

/// A cached value plus the instant from which its age is measured.
#[derive(Debug)]
struct CacheEntry<T> {
    value: T,
    /// Start time of the `get_or_compute` call that produced `value`. Storing this instead
    /// of a precomputed `stored_at + ttl` deadline means no `Instant + Duration` addition is
    /// ever performed, so very large TTLs (e.g. `u64::MAX` seconds) cannot overflow and panic.
    stored_at: Instant,
}

impl<T> CacheEntry<T> {
    /// Returns `true` while strictly less than `ttl` has elapsed since `stored_at`.
    ///
    /// A `now` earlier than `stored_at` can occur when another thread stored the entry after
    /// this caller sampled the clock. `saturating_duration_since` treats that as zero elapsed.
    fn is_fresh(&self, now: Instant, ttl: Duration) -> bool {
        now.saturating_duration_since(self.stored_at) < ttl
    }
}

impl<T> Cache<T> {
    /// Creates an empty cache whose entries stay fresh for `ttl_secs` seconds.
    ///
    /// Every value is accepted. `0` disables caching, so each lookup recomputes.
    /// `u64::MAX` effectively means "never expires".
    pub fn new(ttl_secs: u64) -> Self {
        Self::with_ttl(Duration::from_secs(ttl_secs))
    }

    /// Creates an empty cache with an arbitrary (for example sub-second) time-to-live.
    pub fn with_ttl(ttl: Duration) -> Self {
        Self {
            entries: Mutex::new(std::collections::HashMap::new()),
            ttl,
        }
    }

    /// Creates an empty cache using the default TTL (`DEFAULT_TTL_SECS`, 5 seconds).
    pub fn default_ttl() -> Self {
        Self::new(DEFAULT_TTL_SECS)
    }

    /// Returns the value cached under `key`. If the key is missing or its entry has
    /// expired, it first computes the value with `compute` and stores it.
    ///
    /// The lock is *not* held while `compute` runs, so `compute` may itself use this cache.
    /// As a consequence, concurrent misses on the same key may each run `compute`. The last
    /// one to finish overwrites the others' entries.
    ///
    /// An entry's age is measured from the moment this call started, before `compute` ran.
    /// A slow computation therefore does not extend how long its value is served.
    ///
    /// # Panics
    ///
    /// Propagates any panic from `compute` or `T::clone`. If `compute` panics, nothing is
    /// stored, and an expired entry for `key` will already have been removed.
    pub fn get_or_compute<F>(&self, key: &str, compute: F) -> T
    where
        F: FnOnce() -> T,
        T: Clone,
    {
        let now = Instant::now();

        {
            let mut entries = self.lock_entries();
            if let Some(entry) = entries.get(key) {
                if entry.is_fresh(now, self.ttl) {
                    return entry.value.clone();
                }
                // Expired: evict it before recomputing.
                entries.remove(key);
            }
        } // Guard dropped here, before `compute` runs.

        let value = compute();

        self.lock_entries().insert(
            key.to_string(),
            CacheEntry {
                value: value.clone(),
                stored_at: now,
            },
        );

        value
    }

    /// Removes the entry for `key`, if any, so the next lookup recomputes it.
    pub fn invalidate(&self, key: &str) {
        self.lock_entries().remove(key);
    }

    /// Removes every entry.
    pub fn clear(&self) {
        self.lock_entries().clear();
    }

    /// Removes every entry whose TTL has elapsed.
    ///
    /// Without this, an expired entry is only dropped when its own key is looked up again.
    /// A long-lived cache that sees many distinct keys can call this periodically to bound
    /// its memory use.
    pub fn purge_expired(&self) {
        let now = Instant::now();
        let ttl = self.ttl;
        self.lock_entries().retain(|_, entry| entry.is_fresh(now, ttl));
    }

    /// Locks the entry map.
    ///
    /// Mutex poisoning is deliberately ignored. Each critical section in this type is a
    /// single map operation, plus a `T::clone` on the lookup/insert paths. A panic while the
    /// lock is held therefore cannot leave the map in an inconsistent state.
    fn lock_entries(&self) -> std::sync::MutexGuard<'_, EntryMap<T>> {
        self.entries
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}
81
82impl<T> Default for Cache<T> {
83 fn default() -> Self {
84 Self::default_ttl()
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_cache_basic() {
        let cache = Cache::new(1);
        let call_count = Cell::new(0);
        let get_value = || {
            call_count.set(call_count.get() + 1);
            42
        };

        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);

        // A second lookup within the TTL is served from the cache.
        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);
    }

    #[test]
    fn test_cache_expiry() {
        let cache = Cache::new(1);
        let call_count = Cell::new(0);
        let get_value = || {
            call_count.set(call_count.get() + 1);
            42
        };

        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);

        // Wait past the 1-second TTL so the entry is stale.
        thread::sleep(Duration::from_millis(1100));

        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 2);
    }

    #[test]
    fn test_zero_ttl_always_recomputes() {
        let cache = Cache::new(0);
        let call_count = Cell::new(0);
        let get_value = || {
            call_count.set(call_count.get() + 1);
            call_count.get()
        };

        assert_eq!(cache.get_or_compute("key", get_value), 1);
        assert_eq!(cache.get_or_compute("key", get_value), 2);
    }

    #[test]
    fn test_keys_are_independent() {
        let cache = Cache::new(60);
        assert_eq!(cache.get_or_compute("a", || 1), 1);
        assert_eq!(cache.get_or_compute("b", || 2), 2);
        assert_eq!(cache.get_or_compute("a", || 99), 1);
        assert_eq!(cache.get_or_compute("b", || 99), 2);
    }

    #[test]
    fn test_invalidate_forces_recompute() {
        let cache = Cache::new(60);
        assert_eq!(cache.get_or_compute("a", || 1), 1);
        assert_eq!(cache.get_or_compute("b", || 2), 2);

        cache.invalidate("a");

        assert_eq!(cache.get_or_compute("a", || 10), 10);
        // Other keys are untouched.
        assert_eq!(cache.get_or_compute("b", || 20), 2);
        // Invalidating a missing key is a no-op.
        cache.invalidate("missing");
    }

    #[test]
    fn test_clear_forces_recompute() {
        let cache = Cache::new(60);
        assert_eq!(cache.get_or_compute("a", || 1), 1);
        assert_eq!(cache.get_or_compute("b", || 2), 2);

        cache.clear();

        assert_eq!(cache.get_or_compute("a", || 10), 10);
        assert_eq!(cache.get_or_compute("b", || 20), 20);
    }

    #[test]
    fn test_default_caches() {
        let cache: Cache<String> = Cache::default();
        assert_eq!(cache.get_or_compute("k", || String::from("v")), "v");
        assert_eq!(cache.get_or_compute("k", || String::from("other")), "v");
    }
}