Skip to main content

coralstack_cmd_ipc/
ttl_map.rs

//! A map whose entries expire after a configurable time-to-live.
//!
//! Used by the registry to track pending replies, in-flight routed
//! requests, and recently-seen event IDs (for mesh deduplication).
//!
//! Expiry is **lazy**: the map has no background sweep task, so it needs
//! no timer thread or async runtime. A stale entry is dropped when an
//! operation such as `get_cloned` or `contains_key` touches it, or in bulk
//! when the caller invokes `sweep_expired`. An optional
//! `on_expire(key, value)` callback fires at that moment; `remove` and
//! `clear` drop entries without firing it.

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::hash::Hash;
use std::time::{Duration, Instant};

use parking_lot::Mutex;

17type OnExpire<K, V> = Box<dyn Fn(&K, V) + Send + Sync>;
18
19/// A `HashMap` whose entries expire after `ttl`.
20pub struct TtlMap<K, V>
21where
22    K: Eq + Hash,
23{
24    ttl: Duration,
25    inner: Mutex<HashMap<K, (V, Instant)>>,
26    on_expire: Option<OnExpire<K, V>>,
27}
28
29impl<K, V> TtlMap<K, V>
30where
31    K: Eq + Hash + Clone,
32{
33    /// Creates a new map with the given TTL.
34    ///
35    /// A `ttl` of zero disables expiry entirely — entries stay until
36    /// explicitly removed.
37    pub fn new(ttl: Duration) -> Self {
38        Self {
39            ttl,
40            inner: Mutex::new(HashMap::new()),
41            on_expire: None,
42        }
43    }
44
45    /// Sets a callback invoked whenever an entry is removed due to TTL
46    /// expiry.
47    ///
48    /// The registry uses this to reject pending request promises with a
49    /// timeout error.
50    pub fn with_on_expire<F>(mut self, cb: F) -> Self
51    where
52        F: Fn(&K, V) + Send + Sync + 'static,
53    {
54        self.on_expire = Some(Box::new(cb));
55        self
56    }
57
58    fn is_expired(&self, inserted_at: Instant) -> bool {
59        !self.ttl.is_zero() && inserted_at.elapsed() > self.ttl
60    }
61
62    /// Inserts a value, returning the previous entry if any.
63    pub fn insert(&self, key: K, value: V) -> Option<V> {
64        self.inner
65            .lock()
66            .insert(key, (value, Instant::now()))
67            .map(|(v, _)| v)
68    }
69
70    /// Removes and returns the value for `key`, bypassing expiry.
71    pub fn remove(&self, key: &K) -> Option<V> {
72        self.inner.lock().remove(key).map(|(v, _)| v)
73    }
74
75    /// Returns whether `key` is present and unexpired.
76    ///
77    /// Triggers `on_expire` as a side effect if the entry is stale.
78    pub fn contains_key(&self, key: &K) -> bool {
79        self.take_if_expired(key);
80        self.inner.lock().contains_key(key)
81    }
82
83    /// Returns a clone of the value for `key` if present and unexpired.
84    pub fn get_cloned(&self, key: &K) -> Option<V>
85    where
86        V: Clone,
87    {
88        self.take_if_expired(key);
89        self.inner.lock().get(key).map(|(v, _)| v.clone())
90    }
91
92    /// Returns the current size of the map (including any stale entries
93    /// that have not yet been touched).
94    pub fn len(&self) -> usize {
95        self.inner.lock().len()
96    }
97
98    pub fn is_empty(&self) -> bool {
99        self.len() == 0
100    }
101
102    /// Drops all entries without firing `on_expire`.
103    pub fn clear(&self) {
104        self.inner.lock().clear();
105    }
106
107    /// Removes every entry that has exceeded the TTL, firing
108    /// `on_expire` for each. Callers may invoke this periodically to
109    /// bound memory in long-running processes.
110    pub fn sweep_expired(&self) {
111        if self.ttl.is_zero() {
112            return;
113        }
114        let expired: Vec<(K, V)> = {
115            let mut inner = self.inner.lock();
116            let keys: Vec<K> = inner
117                .iter()
118                .filter(|(_, (_, t))| self.is_expired(*t))
119                .map(|(k, _)| k.clone())
120                .collect();
121            keys.into_iter()
122                .filter_map(|k| inner.remove(&k).map(|(v, _)| (k, v)))
123                .collect()
124        };
125        if let Some(cb) = &self.on_expire {
126            for (k, v) in expired {
127                cb(&k, v);
128            }
129        }
130    }
131
132    /// Returns every key whose (unexpired) value satisfies `pred`.
133    ///
134    /// The registry uses this during channel-close cleanup to find all
135    /// pending replies and routes associated with the dead channel.
136    pub fn snapshot_keys_where<F>(&self, pred: F) -> Vec<K>
137    where
138        F: Fn(&V) -> bool,
139    {
140        let inner = self.inner.lock();
141        inner
142            .iter()
143            .filter(|(_, (v, t))| !self.is_expired(*t) && pred(v))
144            .map(|(k, _)| k.clone())
145            .collect()
146    }
147
148    /// If `key` is present but stale, drop it and invoke `on_expire`.
149    fn take_if_expired(&self, key: &K) {
150        if self.ttl.is_zero() {
151            return;
152        }
153        let expired = {
154            let mut inner = self.inner.lock();
155            match inner.get(key) {
156                Some((_, t)) if self.is_expired(*t) => inner.remove(key).map(|(v, _)| v),
157                _ => None,
158            }
159        };
160        if let (Some(v), Some(cb)) = (expired, &self.on_expire) {
161            cb(key, v);
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread::sleep;

    /// TTL long enough that nothing expires during a test.
    const LONG_TTL: Duration = Duration::from_secs(60);
    /// TTL short enough to lapse within a test.
    const SHORT_TTL: Duration = Duration::from_millis(10);
    /// Sleep comfortably longer than `SHORT_TTL`.
    const PAST_SHORT_TTL: Duration = Duration::from_millis(25);

    /// Builds a `SHORT_TTL` map whose `on_expire` callback bumps the
    /// returned counter once per expired entry.
    fn counting_map<K>() -> (TtlMap<K, i32>, Arc<AtomicUsize>)
    where
        K: Eq + Hash + Clone + 'static,
    {
        let expirations = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&expirations);
        let map = TtlMap::new(SHORT_TTL).with_on_expire(move |_, _| {
            counter.fetch_add(1, Ordering::SeqCst);
        });
        (map, expirations)
    }

    #[test]
    fn insert_and_get() {
        let map: TtlMap<&'static str, i32> = TtlMap::new(LONG_TTL);
        map.insert("a", 1);
        assert_eq!(map.get_cloned(&"a"), Some(1));
        assert!(map.contains_key(&"a"));
    }

    #[test]
    fn remove_returns_value() {
        let map: TtlMap<&'static str, i32> = TtlMap::new(LONG_TTL);
        map.insert("a", 1);
        assert_eq!(map.remove(&"a"), Some(1));
        assert!(!map.contains_key(&"a"));
    }

    #[test]
    fn zero_ttl_disables_expiry() {
        let map: TtlMap<&'static str, i32> = TtlMap::new(Duration::ZERO);
        map.insert("a", 1);
        sleep(Duration::from_millis(20));
        assert_eq!(map.get_cloned(&"a"), Some(1));
    }

    #[test]
    fn lazy_expiry_drops_stale_entries_on_access() {
        let (map, expirations) = counting_map::<&'static str>();
        map.insert("a", 1);
        sleep(PAST_SHORT_TTL);
        assert_eq!(map.get_cloned(&"a"), None);
        assert_eq!(expirations.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn sweep_removes_all_stale() {
        let (map, expirations) = counting_map::<i32>();
        (0..5).for_each(|i| {
            map.insert(i, i * 10);
        });
        sleep(PAST_SHORT_TTL);
        map.sweep_expired();
        assert!(map.is_empty());
        assert_eq!(expirations.load(Ordering::SeqCst), 5);
    }
}