chio_guards/external/
cache.rs1use std::collections::HashMap;
13use std::hash::Hash;
14use std::num::NonZeroUsize;
15use std::sync::Arc;
16use std::sync::Mutex;
17use std::time::Duration;
18
19use tokio::time::Instant;
20
/// Time source consulted by [`TtlCache`] for every expiry decision.
///
/// A cache built with `TtlCache::with_clock` asks this trait for "now" instead of
/// reading the system clock directly, so callers can plug in their own time source.
pub trait Clock: Send + Sync + 'static {
    /// Returns the instant the cache should treat as the current time.
    fn now(&self) -> Instant;
}

/// Default [`Clock`], backed by `Instant::now()`.
///
/// Because it reads tokio's `Instant`, it follows tokio's paused/advanced test clock
/// (`tokio::time::advance`) as well as real time.
#[derive(Default, Debug, Copy, Clone)]
pub struct TokioClock;

impl Clock for TokioClock {
    fn now(&self) -> Instant {
        let current = Instant::now();
        current
    }
}
40
/// One cached value plus the metadata used for expiry and LRU eviction.
#[derive(Clone, Debug)]
struct Entry<V> {
    /// Payload returned (as a clone) on a cache hit.
    value: V,
    /// The entry counts as expired once the clock reads this instant or later.
    expires_at: Instant,
    /// Logical timestamp of the last insert or hit; the smallest value is the LRU victim.
    recency: u64,
}
49
50pub struct TtlCache<K, V> {
57 inner: Mutex<CacheInner<K, V>>,
58 capacity: NonZeroUsize,
59 clock: Arc<dyn Clock>,
60}
61
62struct CacheInner<K, V> {
63 entries: HashMap<K, Entry<V>>,
64 counter: u64,
66}
67
68impl<K, V> TtlCache<K, V>
69where
70 K: Eq + Hash + Clone,
71 V: Clone,
72{
73 pub fn new(capacity: NonZeroUsize) -> Self {
78 Self::with_clock(capacity, Arc::new(TokioClock))
79 }
80
81 pub fn with_clock(capacity: NonZeroUsize, clock: Arc<dyn Clock>) -> Self {
83 Self {
84 inner: Mutex::new(CacheInner {
85 entries: HashMap::with_capacity(capacity.get()),
86 counter: 0,
87 }),
88 capacity,
89 clock,
90 }
91 }
92
93 pub fn capacity(&self) -> usize {
95 self.capacity.get()
96 }
97
98 pub fn len(&self) -> usize {
101 let Ok(inner) = self.inner.lock() else {
102 return 0;
103 };
104 inner.entries.len()
105 }
106
107 pub fn is_empty(&self) -> bool {
109 self.len() == 0
110 }
111
112 pub fn get(&self, key: &K) -> Option<V> {
116 let now = self.clock.now();
117 let Ok(mut inner) = self.inner.lock() else {
118 return None;
119 };
120 let expired = inner
121 .entries
122 .get(key)
123 .map(|entry| entry.expires_at <= now)
124 .unwrap_or(false);
125 if expired {
126 inner.entries.remove(key);
127 return None;
128 }
129 inner.counter = inner.counter.saturating_add(1);
130 let counter = inner.counter;
131 let entry = inner.entries.get_mut(key)?;
132 entry.recency = counter;
133 Some(entry.value.clone())
134 }
135
136 pub fn insert(&self, key: K, value: V, ttl: Duration) {
139 let now = self.clock.now();
140 let expires_at = now.checked_add(ttl).unwrap_or(now);
141 let Ok(mut inner) = self.inner.lock() else {
142 return;
143 };
144
145 inner.counter = inner.counter.saturating_add(1);
146 let recency = inner.counter;
147
148 if let Some(entry) = inner.entries.get_mut(&key) {
150 entry.value = value;
151 entry.expires_at = expires_at;
152 entry.recency = recency;
153 return;
154 }
155
156 if inner.entries.len() >= self.capacity.get() {
158 evict_expired(&mut inner.entries, now);
159 }
160 if inner.entries.len() >= self.capacity.get() {
161 evict_lru(&mut inner.entries);
162 }
163
164 inner.entries.insert(
165 key,
166 Entry {
167 value,
168 expires_at,
169 recency,
170 },
171 );
172 }
173
174 pub fn prune(&self) -> usize {
177 let now = self.clock.now();
178 let Ok(mut inner) = self.inner.lock() else {
179 return 0;
180 };
181 evict_expired(&mut inner.entries, now)
182 }
183
184 pub fn clear(&self) {
186 if let Ok(mut inner) = self.inner.lock() {
187 inner.entries.clear();
188 }
189 }
190}
191
192fn evict_expired<K, V>(entries: &mut HashMap<K, Entry<V>>, now: Instant) -> usize
193where
194 K: Eq + Hash + Clone,
195{
196 let expired: Vec<K> = entries
197 .iter()
198 .filter_map(|(k, entry)| {
199 if entry.expires_at <= now {
200 Some(k.clone())
201 } else {
202 None
203 }
204 })
205 .collect();
206 let removed = expired.len();
207 for k in expired {
208 entries.remove(&k);
209 }
210 removed
211}
212
213fn evict_lru<K, V>(entries: &mut HashMap<K, Entry<V>>)
214where
215 K: Eq + Hash + Clone,
216{
217 let victim = entries
218 .iter()
219 .min_by_key(|(_, entry)| entry.recency)
220 .map(|(k, _)| k.clone());
221 if let Some(key) = victim {
222 entries.remove(&key);
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NonZeroUsize`, panicking on zero (test-only convenience).
    fn nz(n: usize) -> NonZeroUsize {
        NonZeroUsize::new(n).expect("non-zero capacity")
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn insert_and_get_returns_value() {
        let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(4));
        cache.insert("k", 42, Duration::from_secs(30));
        assert_eq!(cache.get(&"k"), Some(42));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn expired_entry_returns_none() {
        let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(4));
        cache.insert("k", 42, Duration::from_secs(1));
        tokio::time::advance(Duration::from_secs(2)).await;
        assert_eq!(cache.get(&"k"), None);
        assert!(cache.is_empty());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn lru_eviction_when_capacity_exceeded() {
        let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(2));
        cache.insert("a", 1, Duration::from_secs(60));
        cache.insert("b", 2, Duration::from_secs(60));
        let _ = cache.get(&"a");

        cache.insert("c", 3, Duration::from_secs(60));
        assert_eq!(cache.get(&"a"), Some(1));
        assert_eq!(cache.get(&"b"), None);
        assert_eq!(cache.get(&"c"), Some(3));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn prune_removes_only_expired() {
        let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(4));
        cache.insert("short", 1, Duration::from_secs(1));
        cache.insert("long", 2, Duration::from_secs(60));
        tokio::time::advance(Duration::from_secs(2)).await;
        let removed = cache.prune();
        assert_eq!(removed, 1);
        assert_eq!(cache.get(&"short"), None);
        assert_eq!(cache.get(&"long"), Some(2));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn overwrite_updates_value_and_ttl() {
        let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(2));
        cache.insert("k", 1, Duration::from_secs(1));
        cache.insert("k", 2, Duration::from_secs(30));
        tokio::time::advance(Duration::from_secs(2)).await;
        assert_eq!(cache.get(&"k"), Some(2));
    }

    /// A full cache must reclaim expired entries before evicting a live LRU entry.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn full_cache_evicts_expired_before_lru() {
        let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(2));
        cache.insert("stale", 1, Duration::from_secs(1));
        cache.insert("fresh", 2, Duration::from_secs(60));
        // Touch "stale" so that, by recency alone, "fresh" would be the LRU victim.
        assert_eq!(cache.get(&"stale"), Some(1));
        tokio::time::advance(Duration::from_secs(2)).await;

        cache.insert("new", 3, Duration::from_secs(60));
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&"fresh"), Some(2));
        assert_eq!(cache.get(&"new"), Some(3));
        assert_eq!(cache.get(&"stale"), None);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn zero_ttl_entry_is_never_returned() {
        let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(2));
        cache.insert("k", 1, Duration::ZERO);
        assert_eq!(cache.get(&"k"), None);
        assert!(cache.is_empty());
    }

    /// `len` reports stored entries, including expired ones that were not purged yet.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn len_counts_expired_entries_until_purged() {
        let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(4));
        cache.insert("k", 1, Duration::from_secs(1));
        tokio::time::advance(Duration::from_secs(2)).await;
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.prune(), 1);
        assert!(cache.is_empty());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn clear_empties_cache_and_keeps_capacity() {
        let cache: TtlCache<&'static str, u32> = TtlCache::new(nz(3));
        assert_eq!(cache.capacity(), 3);
        cache.insert("a", 1, Duration::from_secs(60));
        cache.insert("b", 2, Duration::from_secs(60));
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.capacity(), 3);
    }
}