Skip to main content

chio_guards/external/
cache.rs

1//! TTL cache with LRU eviction for external guard verdicts.
2//!
//! [`TtlCache`] is a thread-safe bounded cache that stores values with a
//! per-entry monotonic deadline. Entries are evicted when their TTL has
//! expired (observed on access, on `prune`, or when an insert finds the cache
//! full) or, if the cache is still full after that, by least-recently-used
//! eviction.
7//!
8//! The cache uses a [`Clock`] abstraction for the "now" timestamp so that
9//! tests can drive time via [`tokio::time::pause`] + `advance` without any
10//! wall-clock sleep.
11
12use std::collections::HashMap;
13use std::hash::Hash;
14use std::num::NonZeroUsize;
15use std::sync::Arc;
16use std::sync::Mutex;
17use std::time::Duration;
18
19use tokio::time::Instant;
20
/// Source of monotonic "now" timestamps for the cache and other resilience
/// primitives.
///
/// The stock implementation, [`TokioClock`], delegates to
/// [`tokio::time::Instant::now`], so it follows [`tokio::time::pause`] and
/// `advance` in tests. Implement this trait to plug in any other time source.
pub trait Clock: 'static + Send + Sync {
    /// Current monotonic instant as seen by this clock.
    fn now(&self) -> Instant;
}

/// Stock [`Clock`] that reads Tokio's pausable timer.
#[derive(Clone, Copy, Debug, Default)]
pub struct TokioClock;

impl Clock for TokioClock {
    fn now(&self) -> Instant {
        // Tokio's `Instant::now` honors a paused runtime clock in tests.
        Instant::now()
    }
}
40
/// A cached value together with its expiry deadline and LRU stamp.
#[derive(Clone, Debug)]
struct Entry<V> {
    /// The cached payload; callers receive a clone of it on a hit.
    value: V,
    /// Monotonic deadline; the entry counts as expired once `now >= expires_at`.
    expires_at: Instant,
    /// Stamp taken from the cache-wide counter when the entry is touched.
    /// Higher = more recent.
    recency: u64,
}
49
50/// Thread-safe TTL cache with LRU eviction.
51///
52/// The cache is keyed by any `Eq + Hash + Clone` type and stores any
53/// `Clone` value. Each `insert` takes a per-entry TTL. Entries expire on
54/// the first `get`/`insert` that observes the expired deadline; an explicit
55/// [`TtlCache::prune`] is also provided for bulk collection.
56pub struct TtlCache<K, V> {
57    inner: Mutex<CacheInner<K, V>>,
58    capacity: NonZeroUsize,
59    clock: Arc<dyn Clock>,
60}
61
62struct CacheInner<K, V> {
63    entries: HashMap<K, Entry<V>>,
64    /// Monotonic counter used to stamp entry recency.
65    counter: u64,
66}
67
68impl<K, V> TtlCache<K, V>
69where
70    K: Eq + Hash + Clone,
71    V: Clone,
72{
73    /// Create a new cache with the given capacity, backed by [`TokioClock`].
74    ///
75    /// Capacity is a non-zero usize because a zero-capacity cache is
76    /// degenerate (every insert would immediately evict itself).
77    pub fn new(capacity: NonZeroUsize) -> Self {
78        Self::with_clock(capacity, Arc::new(TokioClock))
79    }
80
81    /// Create a cache backed by a custom [`Clock`] implementation.
82    pub fn with_clock(capacity: NonZeroUsize, clock: Arc<dyn Clock>) -> Self {
83        Self {
84            inner: Mutex::new(CacheInner {
85                entries: HashMap::with_capacity(capacity.get()),
86                counter: 0,
87            }),
88            capacity,
89            clock,
90        }
91    }
92
93    /// Configured maximum number of live entries.
94    pub fn capacity(&self) -> usize {
95        self.capacity.get()
96    }
97
98    /// Current number of entries in the cache (may include not-yet-pruned
99    /// expired entries).
100    pub fn len(&self) -> usize {
101        let Ok(inner) = self.inner.lock() else {
102            return 0;
103        };
104        inner.entries.len()
105    }
106
107    /// Returns true when the cache holds no entries.
108    pub fn is_empty(&self) -> bool {
109        self.len() == 0
110    }
111
112    /// Look up `key`. Returns `Some(value)` on cache hit (and bumps the
113    /// entry's recency); `None` on miss or expired entry. Expired entries
114    /// are removed on observation.
115    pub fn get(&self, key: &K) -> Option<V> {
116        let now = self.clock.now();
117        let Ok(mut inner) = self.inner.lock() else {
118            return None;
119        };
120        let expired = inner
121            .entries
122            .get(key)
123            .map(|entry| entry.expires_at <= now)
124            .unwrap_or(false);
125        if expired {
126            inner.entries.remove(key);
127            return None;
128        }
129        inner.counter = inner.counter.saturating_add(1);
130        let counter = inner.counter;
131        let entry = inner.entries.get_mut(key)?;
132        entry.recency = counter;
133        Some(entry.value.clone())
134    }
135
136    /// Insert `value` under `key` with the given `ttl`. If the cache is at
137    /// capacity, evicts the least recently used live entry first.
138    pub fn insert(&self, key: K, value: V, ttl: Duration) {
139        let now = self.clock.now();
140        let expires_at = now.checked_add(ttl).unwrap_or(now);
141        let Ok(mut inner) = self.inner.lock() else {
142            return;
143        };
144
145        inner.counter = inner.counter.saturating_add(1);
146        let recency = inner.counter;
147
148        // Replace existing entry directly.
149        if let Some(entry) = inner.entries.get_mut(&key) {
150            entry.value = value;
151            entry.expires_at = expires_at;
152            entry.recency = recency;
153            return;
154        }
155
156        // Evict expired entries first, then LRU if still at capacity.
157        if inner.entries.len() >= self.capacity.get() {
158            evict_expired(&mut inner.entries, now);
159        }
160        if inner.entries.len() >= self.capacity.get() {
161            evict_lru(&mut inner.entries);
162        }
163
164        inner.entries.insert(
165            key,
166            Entry {
167                value,
168                expires_at,
169                recency,
170            },
171        );
172    }
173
174    /// Remove every entry whose TTL has expired relative to the clock's
175    /// current "now". Returns the number of entries removed.
176    pub fn prune(&self) -> usize {
177        let now = self.clock.now();
178        let Ok(mut inner) = self.inner.lock() else {
179            return 0;
180        };
181        evict_expired(&mut inner.entries, now)
182    }
183
184    /// Remove all entries.
185    pub fn clear(&self) {
186        if let Ok(mut inner) = self.inner.lock() {
187            inner.entries.clear();
188        }
189    }
190}
191
192fn evict_expired<K, V>(entries: &mut HashMap<K, Entry<V>>, now: Instant) -> usize
193where
194    K: Eq + Hash + Clone,
195{
196    let expired: Vec<K> = entries
197        .iter()
198        .filter_map(|(k, entry)| {
199            if entry.expires_at <= now {
200                Some(k.clone())
201            } else {
202                None
203            }
204        })
205        .collect();
206    let removed = expired.len();
207    for k in expired {
208        entries.remove(&k);
209    }
210    removed
211}
212
213fn evict_lru<K, V>(entries: &mut HashMap<K, Entry<V>>)
214where
215    K: Eq + Hash + Clone,
216{
217    let victim = entries
218        .iter()
219        .min_by_key(|(_, entry)| entry.recency)
220        .map(|(k, _)| k.clone());
221    if let Some(key) = victim {
222        entries.remove(&key);
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache shape shared by every test below.
    type StrCache = TtlCache<&'static str, u32>;

    fn cap(n: usize) -> NonZeroUsize {
        NonZeroUsize::new(n).expect("non-zero capacity")
    }

    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn insert_and_get_returns_value() {
        let cache = StrCache::new(cap(4));
        cache.insert("k", 42, secs(30));
        assert_eq!(cache.get(&"k"), Some(42));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn expired_entry_returns_none() {
        let cache = StrCache::new(cap(4));
        cache.insert("k", 42, secs(1));
        // Paused runtime clock: jump past the 1s TTL without a real sleep.
        tokio::time::advance(secs(2)).await;
        assert_eq!(cache.get(&"k"), None);
        assert!(cache.is_empty());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn lru_eviction_when_capacity_exceeded() {
        let cache = StrCache::new(cap(2));
        for (key, value) in [("a", 1), ("b", 2)] {
            cache.insert(key, value, secs(60));
        }
        // Reading "a" refreshes it, leaving "b" as the LRU victim.
        let _touched = cache.get(&"a");
        cache.insert("c", 3, secs(60));
        assert_eq!(cache.get(&"a"), Some(1));
        assert_eq!(cache.get(&"b"), None);
        assert_eq!(cache.get(&"c"), Some(3));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn prune_removes_only_expired() {
        let cache = StrCache::new(cap(4));
        cache.insert("short", 1, secs(1));
        cache.insert("long", 2, secs(60));
        tokio::time::advance(secs(2)).await;
        assert_eq!(cache.prune(), 1);
        assert_eq!(cache.get(&"short"), None);
        assert_eq!(cache.get(&"long"), Some(2));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn overwrite_updates_value_and_ttl() {
        let cache = StrCache::new(cap(2));
        cache.insert("k", 1, secs(1));
        // Re-inserting the same key replaces both value and deadline.
        cache.insert("k", 2, secs(30));
        tokio::time::advance(secs(2)).await;
        assert_eq!(cache.get(&"k"), Some(2));
    }
}