// miden_node_utils/lru_cache.rs

1use std::hash::Hash;
2use std::num::NonZeroUsize;
3use std::sync::{Arc, Mutex, MutexGuard};
4
5use lru::LruCache as InnerCache;
6
7/// A newtype wrapper around an LRU cache. Ensures that the cache lock is not held across await
8/// points.
9#[derive(Clone)]
10pub struct LruCache<K, V>(Arc<Mutex<InnerCache<K, V>>>);
11
12impl<K, V> LruCache<K, V>
13where
14    K: Hash + Eq,
15    V: Clone,
16{
17    /// Creates a new cache with the given capacity.
18    pub fn new(capacity: NonZeroUsize) -> Self {
19        Self(Arc::new(Mutex::new(InnerCache::new(capacity))))
20    }
21
22    /// Retrieves a value from the cache.
23    pub fn get(&self, key: &K) -> Option<V> {
24        self.lock().get(key).cloned()
25    }
26
27    /// Puts a value into the cache.
28    pub fn put(&self, key: K, value: V) {
29        self.lock().put(key, value);
30    }
31
32    /// Retrieves multiple values from the cache while holding the cache lock once.
33    pub fn get_many<'a>(&self, keys: impl IntoIterator<Item = &'a K>) -> Vec<Option<V>>
34    where
35        K: 'a,
36    {
37        let mut cache = self.lock();
38        keys.into_iter().map(|key| cache.get(key).cloned()).collect()
39    }
40
41    /// Puts multiple values into the cache while holding the cache lock once.
42    pub fn put_many(&self, entries: impl IntoIterator<Item = (K, V)>) {
43        let mut cache = self.lock();
44        for (key, value) in entries {
45            cache.put(key, value);
46        }
47    }
48
49    /// Clears all entries from the cache.
50    pub fn clear(&self) {
51        self.lock().clear();
52    }
53
54    #[crate::tracing::miden_instrument(name = "lru.lock", skip_all)]
55    fn lock(&self) -> MutexGuard<'_, InnerCache<K, V>> {
56        // SAFETY: The mutex is only held for the duration of the get/put operation where panics are
57        // possible only if we're running out of memory, in which case the entire process is likely
58        // to be unstable anyway.
59        self.0.lock().expect("LRU cache mutex poisoned")
60    }
61}
62
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::LruCache;

    /// Builds a `u32 -> &'static str` cache with the given (non-zero) capacity.
    fn cache(cap: usize) -> LruCache<u32, &'static str> {
        LruCache::new(NonZeroUsize::new(cap).unwrap())
    }

    // The cache API is fully synchronous, so plain `#[test]` suffices: no async runtime needed.

    #[test]
    fn get_returns_none_on_empty_cache() {
        let c = cache(4);
        assert_eq!(c.get(&1), None);
    }

    #[test]
    fn get_returns_inserted_value() {
        let c = cache(4);
        c.put(1, "a");
        assert_eq!(c.get(&1), Some("a"));
    }

    #[test]
    fn evicts_least_recently_used_when_full() {
        let c = cache(2);
        c.put(1, "a");
        c.put(2, "b");
        c.get(&1); // 1 is now most recently used
        c.put(3, "c"); // evicts 2 (least recently used)
        assert_eq!(c.get(&1), Some("a"));
        assert_eq!(c.get(&2), None);
        assert_eq!(c.get(&3), Some("c"));
    }

    #[test]
    fn put_overwrites_existing_value() {
        let c = cache(4);
        c.put(1, "a");
        c.put(1, "b");
        assert_eq!(c.get(&1), Some("b"));
    }

    #[test]
    fn clone_shares_state() {
        let c1 = cache(4);
        let c2 = c1.clone();
        c1.put(1, "a");
        assert_eq!(c2.get(&1), Some("a"));
    }

    #[test]
    fn get_many_preserves_key_order_and_reports_misses() {
        let c = cache(4);
        c.put(1, "a");
        c.put(3, "c");
        assert_eq!(c.get_many(&[1, 2, 3]), vec![Some("a"), None, Some("c")]);
    }

    #[test]
    fn get_many_with_no_keys_returns_empty() {
        let c = cache(4);
        c.put(1, "a");
        assert!(c.get_many(&[] as &[u32]).is_empty());
    }

    #[test]
    fn put_many_inserts_in_order_and_last_duplicate_wins() {
        let c = cache(4);
        c.put_many([(1, "a"), (2, "b"), (1, "z")]);
        assert_eq!(c.get(&1), Some("z"));
        assert_eq!(c.get(&2), Some("b"));
    }

    #[test]
    fn put_many_respects_capacity() {
        let c = cache(2);
        c.put_many([(1, "a"), (2, "b"), (3, "c")]); // 1 is evicted by 3
        assert_eq!(c.get_many(&[1, 2, 3]), vec![None, Some("b"), Some("c")]);
    }

    #[test]
    fn clear_removes_all_entries() {
        let c = cache(4);
        c.put_many([(1, "a"), (2, "b")]);
        c.clear();
        assert_eq!(c.get_many(&[1, 2]), vec![None, None]);
    }
}
113}