// miden_node_utils/lru_cache.rs

1use std::hash::Hash;
2use std::num::NonZeroUsize;
3use std::sync::{Arc, Mutex, MutexGuard};
4
5use lru::LruCache as InnerCache;
6use tracing::instrument;
7
8/// A newtype wrapper around an LRU cache. Ensures that the cache lock is not held across await
9/// points.
10#[derive(Clone)]
11pub struct LruCache<K, V>(Arc<Mutex<InnerCache<K, V>>>);
12
13impl<K, V> LruCache<K, V>
14where
15    K: Hash + Eq,
16    V: Clone,
17{
18    /// Creates a new cache with the given capacity.
19    pub fn new(capacity: NonZeroUsize) -> Self {
20        Self(Arc::new(Mutex::new(InnerCache::new(capacity))))
21    }
22
23    /// Retrieves a value from the cache.
24    pub fn get(&self, key: &K) -> Option<V> {
25        self.lock().get(key).cloned()
26    }
27
28    /// Puts a value into the cache.
29    pub fn put(&self, key: K, value: V) {
30        self.lock().put(key, value);
31    }
32
33    /// Retrieves multiple values from the cache while holding the cache lock once.
34    pub fn get_many<'a>(&self, keys: impl IntoIterator<Item = &'a K>) -> Vec<Option<V>>
35    where
36        K: 'a,
37    {
38        let mut cache = self.lock();
39        keys.into_iter().map(|key| cache.get(key).cloned()).collect()
40    }
41
42    /// Puts multiple values into the cache while holding the cache lock once.
43    pub fn put_many(&self, entries: impl IntoIterator<Item = (K, V)>) {
44        let mut cache = self.lock();
45        for (key, value) in entries {
46            cache.put(key, value);
47        }
48    }
49
50    /// Clears all entries from the cache.
51    pub fn clear(&self) {
52        self.lock().clear();
53    }
54
55    #[instrument(name = "lru.lock", skip_all)]
56    fn lock(&self) -> MutexGuard<'_, InnerCache<K, V>> {
57        // SAFETY: The mutex is only held for the duration of the get/put operation where panics are
58        // possible only if we're running out of memory, in which case the entire process is likely
59        // to be unstable anyway.
60        self.0.lock().expect("LRU cache mutex poisoned")
61    }
62}
63
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::LruCache;

    // None of these tests await anything, so plain `#[test]` is used instead of starting a Tokio
    // runtime for every test.

    /// Builds a cache holding at most `cap` entries. Panics if `cap` is zero.
    fn cache(cap: usize) -> LruCache<u32, &'static str> {
        LruCache::new(NonZeroUsize::new(cap).unwrap())
    }

    #[test]
    fn get_returns_none_on_empty_cache() {
        let c = cache(4);
        assert_eq!(c.get(&1), None);
    }

    #[test]
    fn get_returns_inserted_value() {
        let c = cache(4);
        c.put(1, "a");
        assert_eq!(c.get(&1), Some("a"));
    }

    #[test]
    fn evicts_least_recently_used_when_full() {
        let c = cache(2);
        c.put(1, "a");
        c.put(2, "b");
        c.get(&1); // 1 is now most recently used
        c.put(3, "c"); // evicts 2 (least recently used)
        assert_eq!(c.get(&1), Some("a"));
        assert_eq!(c.get(&2), None);
        assert_eq!(c.get(&3), Some("c"));
    }

    #[test]
    fn put_overwrites_existing_value() {
        let c = cache(4);
        c.put(1, "a");
        c.put(1, "b");
        assert_eq!(c.get(&1), Some("b"));
    }

    #[test]
    fn clone_shares_state() {
        let c1 = cache(4);
        let c2 = c1.clone();
        c1.put(1, "a");
        assert_eq!(c2.get(&1), Some("a"));
    }

    #[test]
    fn get_many_preserves_key_order_and_reports_misses() {
        let c = cache(4);
        c.put(1, "a");
        c.put(3, "c");
        assert_eq!(c.get_many(&[3, 2, 1]), vec![Some("c"), None, Some("a")]);
    }

    #[test]
    fn get_many_with_no_keys_returns_empty_vec() {
        let c = cache(4);
        c.put(1, "a");
        let no_keys: [u32; 0] = [];
        assert!(c.get_many(&no_keys).is_empty());
    }

    #[test]
    fn get_many_promotes_hits() {
        let c = cache(2);
        c.put(1, "a");
        c.put(2, "b");
        let _ = c.get_many(&[1]); // 1 is now most recently used
        c.put(3, "c"); // evicts 2 (least recently used)
        assert_eq!(c.get(&2), None);
        assert_eq!(c.get(&1), Some("a"));
    }

    #[test]
    fn put_many_inserts_all_entries() {
        let c = cache(4);
        c.put_many([(1, "a"), (2, "b")]);
        assert_eq!(c.get_many(&[1, 2]), vec![Some("a"), Some("b")]);
    }

    #[test]
    fn put_many_over_capacity_keeps_latest_entries() {
        let c = cache(2);
        c.put_many([(1, "a"), (2, "b"), (3, "c")]);
        assert_eq!(c.get_many(&[1, 2, 3]), vec![None, Some("b"), Some("c")]);
    }

    #[test]
    fn clear_removes_all_entries() {
        let c = cache(4);
        c.put_many([(1, "a"), (2, "b")]);
        c.clear();
        assert_eq!(c.get_many(&[1, 2]), vec![None, None]);
    }
}