Skip to main content

radiate_utils/
lru.rs

1use foldhash::fast::RandomState;
2use hashbrown::HashTable;
3use hashbrown::hash_table::Entry;
4use slotmap::{Key, SlotMap, new_key_type};
5use std::borrow::Borrow;
6use std::hash::{BuildHasher, Hash};
7
8new_key_type! {
9    struct LruKey;
10}
11
12#[derive(Copy, Clone, Default)]
13struct LruListNode {
14    more_recent: LruKey,
15    less_recent: LruKey,
16}
17
18struct LruEntry<K, V> {
19    key: K,
20    value: V,
21    list: LruListNode,
22}
23
24pub struct LruCache<K, V, S = RandomState> {
25    table: HashTable<LruKey>,
26    elements: SlotMap<LruKey, LruEntry<K, V>>,
27    max_capacity: usize,
28    most_recent: LruKey,
29    least_recent: LruKey,
30    build_hasher: S,
31}
32
33impl<K, V> LruCache<K, V> {
34    pub fn with_capacity(capacity: usize) -> Self {
35        Self::with_capacity_and_hasher(capacity, RandomState::default())
36    }
37}
38
39impl<K, V, S> LruCache<K, V, S> {
40    pub fn with_capacity_and_hasher(max_capacity: usize, build_hasher: S) -> Self {
41        assert!(max_capacity > 0);
42        Self {
43            // Allocate one more capacity to prevent double-lookup or realloc
44            // when doing get_or_insert when full.
45            table: HashTable::with_capacity(max_capacity + 1),
46            elements: SlotMap::with_capacity_and_key(max_capacity + 1),
47            max_capacity,
48            most_recent: LruKey::null(),
49            least_recent: LruKey::null(),
50            build_hasher,
51        }
52    }
53}
54
55impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
56    fn lru_list_unlink(&mut self, lru_key: LruKey) {
57        let list = self.elements[lru_key].list;
58        if let Some(more_recent) = self.elements.get_mut(list.more_recent) {
59            more_recent.list.less_recent = list.less_recent;
60        } else {
61            self.most_recent = list.less_recent;
62        }
63
64        if let Some(less_recent) = self.elements.get_mut(list.less_recent) {
65            less_recent.list.more_recent = list.more_recent;
66        } else {
67            self.least_recent = list.more_recent;
68        }
69    }
70
71    fn lru_list_insert_mru(&mut self, lru_key: LruKey) {
72        let prev_most_recent_key = self.most_recent;
73        self.most_recent = lru_key;
74        if let Some(prev_most_recent) = self.elements.get_mut(prev_most_recent_key) {
75            prev_most_recent.list.more_recent = lru_key;
76        } else {
77            self.least_recent = lru_key;
78        }
79        let list = &mut self.elements[lru_key].list;
80        list.more_recent = LruKey::null();
81        list.less_recent = prev_most_recent_key;
82    }
83
84    pub fn pop_lru(&mut self) -> Option<(K, V)> {
85        if self.elements.is_empty() {
86            return None;
87        }
88        let lru_key = self.least_recent;
89        let hash = self.build_hasher.hash_one(&self.elements[lru_key].key);
90        self.lru_list_unlink(lru_key);
91        let lru_entry = self.elements.remove(lru_key).unwrap();
92        self.table
93            .find_entry(hash, |k| *k == lru_key)
94            .unwrap()
95            .remove();
96        Some((lru_entry.key, lru_entry.value))
97    }
98
99    pub fn get<Q>(&mut self, key: &Q) -> Option<&V>
100    where
101        K: Borrow<Q>,
102        Q: Hash + Eq + ?Sized,
103    {
104        let hash = self.build_hasher.hash_one(key);
105        let lru_key = *self
106            .table
107            .find(hash, |lru_key| self.elements[*lru_key].key.borrow() == key)?;
108        self.lru_list_unlink(lru_key);
109        self.lru_list_insert_mru(lru_key);
110        let lru_node = self.elements.get(lru_key).unwrap();
111        Some(&lru_node.value)
112    }
113
114    /// Returns the old value, if any.
115    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
116        let hash = self.build_hasher.hash_one(&key);
117        match self.table.entry(
118            hash,
119            |lru_key| self.elements[*lru_key].key == key,
120            |lru_key| self.build_hasher.hash_one(&self.elements[*lru_key].key),
121        ) {
122            Entry::Occupied(o) => {
123                let lru_key = *o.get();
124                self.lru_list_unlink(lru_key);
125                self.lru_list_insert_mru(lru_key);
126                Some(core::mem::replace(&mut self.elements[lru_key].value, value))
127            }
128
129            Entry::Vacant(v) => {
130                let lru_entry = LruEntry {
131                    key,
132                    value,
133                    list: LruListNode::default(),
134                };
135                let lru_key = self.elements.insert(lru_entry);
136                v.insert(lru_key);
137                self.lru_list_insert_mru(lru_key);
138                if self.elements.len() > self.max_capacity {
139                    self.pop_lru();
140                }
141                None
142            }
143        }
144    }
145
146    pub fn get_or_insert_with<Q, F>(&mut self, key: &Q, f: F) -> &mut V
147    where
148        F: FnOnce(&Q) -> V,
149        K: Borrow<Q>,
150        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
151    {
152        enum Never {}
153        let Ok(ret) = self.try_get_or_insert_with::<Q, Never, _>(key, |k| Ok(f(k)));
154        ret
155    }
156
157    pub fn try_get_or_insert_with<Q, E, F: FnOnce(&Q) -> Result<V, E>>(
158        &mut self,
159        key: &Q,
160        f: F,
161    ) -> Result<&mut V, E>
162    where
163        K: Borrow<Q>,
164        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
165    {
166        let hash = self.build_hasher.hash_one(key);
167        match self.table.entry(
168            hash,
169            |lru_key| self.elements[*lru_key].key.borrow() == key,
170            |lru_key| self.build_hasher.hash_one(&self.elements[*lru_key].key),
171        ) {
172            Entry::Occupied(o) => {
173                let lru_key = *o.get();
174                if lru_key != self.most_recent {
175                    self.lru_list_unlink(lru_key);
176                    self.lru_list_insert_mru(lru_key);
177                }
178                Ok(&mut self.elements[lru_key].value)
179            }
180
181            Entry::Vacant(v) => {
182                let lru_entry = LruEntry {
183                    value: f(key)?,
184                    key: key.to_owned(),
185                    list: LruListNode::default(),
186                };
187                let lru_key = self.elements.insert(lru_entry);
188                v.insert(lru_key);
189                self.lru_list_insert_mru(lru_key);
190                if self.elements.len() > self.max_capacity {
191                    self.pop_lru();
192                }
193                Ok(&mut self.elements[lru_key].value)
194            }
195        }
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::LruCache;

    /// Insert, lookup, eviction order and in-place update through `insert`.
    #[test]
    fn test_lru_cache_basic() {
        let mut lru: LruCache<u32, &str> = LruCache::with_capacity(2);

        assert_eq!(lru.insert(1, "one"), None);
        assert_eq!(lru.insert(2, "two"), None);
        assert_eq!(lru.get(&1), Some(&"one"));
        assert_eq!(lru.insert(3, "three"), None); // Evicts key 2
        assert_eq!(lru.get(&2), None);
        assert_eq!(lru.get(&1), Some(&"one"));
        assert_eq!(lru.get(&3), Some(&"three"));
        assert_eq!(lru.insert(1, "uno"), Some("one")); // Update key 1
        assert_eq!(lru.get(&1), Some(&"uno"));
    }

    /// `pop_lru` removes entries oldest-first and reports an empty cache.
    #[test]
    fn test_lru_cache_pop() {
        let mut lru: LruCache<u32, &str> = LruCache::with_capacity(2);

        assert_eq!(lru.insert(1, "one"), None);
        assert_eq!(lru.insert(2, "two"), None);
        assert_eq!(lru.pop_lru(), Some((1, "one"))); // Evicts key 1
        assert_eq!(lru.get(&1), None); // Key 1 should be gone
        assert_eq!(lru.get(&2), Some(&"two")); // Key 2 should still be present
        assert_eq!(lru.pop_lru(), Some((2, "two"))); // Last remaining entry
        assert_eq!(lru.pop_lru(), None); // Nothing left to pop
    }

    /// `get_or_insert_with` computes on a miss, reuses on a hit, hands out a
    /// mutable reference and takes part in the normal eviction order.
    #[test]
    fn test_lru_cache_get_or_insert_with() {
        let mut lru: LruCache<String, usize> = LruCache::with_capacity(2);

        // Miss: the value is computed from the borrowed key.
        assert_eq!(*lru.get_or_insert_with("a", |k| k.len()), 1);
        // Hit: the stored value wins over what the closure would produce.
        assert_eq!(*lru.get_or_insert_with("a", |_| 99), 1);
        // The returned reference allows in-place mutation.
        *lru.get_or_insert_with("a", |_| 0) += 10;
        assert_eq!(lru.get("a"), Some(&11));

        lru.get_or_insert_with("bb", |k| k.len());
        lru.get_or_insert_with("ccc", |k| k.len()); // Evicts "a", the oldest
        assert_eq!(lru.get("a"), None);
        assert_eq!(lru.get("bb"), Some(&2));
        assert_eq!(lru.get("ccc"), Some(&3));
    }

    /// A failing producer must not leave an entry behind, and a capacity of
    /// one must evict the previous entry on every new key.
    #[test]
    fn test_lru_cache_try_get_or_insert_with() {
        let mut lru: LruCache<String, usize> = LruCache::with_capacity(1);

        let failed = lru.try_get_or_insert_with("a", |_| Err("boom")).err();
        assert_eq!(failed, Some("boom"));
        assert_eq!(lru.get("a"), None);

        let first = lru
            .try_get_or_insert_with("a", |k| Ok::<_, &str>(k.len()))
            .map(|v| *v);
        assert_eq!(first, Ok(1));

        let second = lru
            .try_get_or_insert_with("bb", |k| Ok::<_, &str>(k.len()))
            .map(|v| *v);
        assert_eq!(second, Ok(2)); // Evicts "a"
        assert_eq!(lru.get("a"), None);
        assert_eq!(lru.get("bb"), Some(&2));
    }
}
227}