1use foldhash::fast::RandomState;
2use hashbrown::HashTable;
3use hashbrown::hash_table::Entry;
4use slotmap::{Key, SlotMap, new_key_type};
5use std::borrow::Borrow;
6use std::hash::{BuildHasher, Hash};
7
8new_key_type! {
9 struct LruKey;
10}
11
12#[derive(Copy, Clone, Default)]
13struct LruListNode {
14 more_recent: LruKey,
15 less_recent: LruKey,
16}
17
18struct LruEntry<K, V> {
19 key: K,
20 value: V,
21 list: LruListNode,
22}
23
24pub struct LruCache<K, V, S = RandomState> {
25 table: HashTable<LruKey>,
26 elements: SlotMap<LruKey, LruEntry<K, V>>,
27 max_capacity: usize,
28 most_recent: LruKey,
29 least_recent: LruKey,
30 build_hasher: S,
31}
32
33impl<K, V> LruCache<K, V> {
34 pub fn with_capacity(capacity: usize) -> Self {
35 Self::with_capacity_and_hasher(capacity, RandomState::default())
36 }
37}
38
39impl<K, V, S> LruCache<K, V, S> {
40 pub fn with_capacity_and_hasher(max_capacity: usize, build_hasher: S) -> Self {
41 assert!(max_capacity > 0);
42 Self {
43 table: HashTable::with_capacity(max_capacity + 1),
46 elements: SlotMap::with_capacity_and_key(max_capacity + 1),
47 max_capacity,
48 most_recent: LruKey::null(),
49 least_recent: LruKey::null(),
50 build_hasher,
51 }
52 }
53}
54
55impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
56 fn lru_list_unlink(&mut self, lru_key: LruKey) {
57 let list = self.elements[lru_key].list;
58 if let Some(more_recent) = self.elements.get_mut(list.more_recent) {
59 more_recent.list.less_recent = list.less_recent;
60 } else {
61 self.most_recent = list.less_recent;
62 }
63
64 if let Some(less_recent) = self.elements.get_mut(list.less_recent) {
65 less_recent.list.more_recent = list.more_recent;
66 } else {
67 self.least_recent = list.more_recent;
68 }
69 }
70
71 fn lru_list_insert_mru(&mut self, lru_key: LruKey) {
72 let prev_most_recent_key = self.most_recent;
73 self.most_recent = lru_key;
74 if let Some(prev_most_recent) = self.elements.get_mut(prev_most_recent_key) {
75 prev_most_recent.list.more_recent = lru_key;
76 } else {
77 self.least_recent = lru_key;
78 }
79 let list = &mut self.elements[lru_key].list;
80 list.more_recent = LruKey::null();
81 list.less_recent = prev_most_recent_key;
82 }
83
84 pub fn pop_lru(&mut self) -> Option<(K, V)> {
85 if self.elements.is_empty() {
86 return None;
87 }
88 let lru_key = self.least_recent;
89 let hash = self.build_hasher.hash_one(&self.elements[lru_key].key);
90 self.lru_list_unlink(lru_key);
91 let lru_entry = self.elements.remove(lru_key).unwrap();
92 self.table
93 .find_entry(hash, |k| *k == lru_key)
94 .unwrap()
95 .remove();
96 Some((lru_entry.key, lru_entry.value))
97 }
98
99 pub fn get<Q>(&mut self, key: &Q) -> Option<&V>
100 where
101 K: Borrow<Q>,
102 Q: Hash + Eq + ?Sized,
103 {
104 let hash = self.build_hasher.hash_one(key);
105 let lru_key = *self
106 .table
107 .find(hash, |lru_key| self.elements[*lru_key].key.borrow() == key)?;
108 self.lru_list_unlink(lru_key);
109 self.lru_list_insert_mru(lru_key);
110 let lru_node = self.elements.get(lru_key).unwrap();
111 Some(&lru_node.value)
112 }
113
114 pub fn insert(&mut self, key: K, value: V) -> Option<V> {
116 let hash = self.build_hasher.hash_one(&key);
117 match self.table.entry(
118 hash,
119 |lru_key| self.elements[*lru_key].key == key,
120 |lru_key| self.build_hasher.hash_one(&self.elements[*lru_key].key),
121 ) {
122 Entry::Occupied(o) => {
123 let lru_key = *o.get();
124 self.lru_list_unlink(lru_key);
125 self.lru_list_insert_mru(lru_key);
126 Some(core::mem::replace(&mut self.elements[lru_key].value, value))
127 }
128
129 Entry::Vacant(v) => {
130 let lru_entry = LruEntry {
131 key,
132 value,
133 list: LruListNode::default(),
134 };
135 let lru_key = self.elements.insert(lru_entry);
136 v.insert(lru_key);
137 self.lru_list_insert_mru(lru_key);
138 if self.elements.len() > self.max_capacity {
139 self.pop_lru();
140 }
141 None
142 }
143 }
144 }
145
146 pub fn get_or_insert_with<Q, F>(&mut self, key: &Q, f: F) -> &mut V
147 where
148 F: FnOnce(&Q) -> V,
149 K: Borrow<Q>,
150 Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
151 {
152 enum Never {}
153 let Ok(ret) = self.try_get_or_insert_with::<Q, Never, _>(key, |k| Ok(f(k)));
154 ret
155 }
156
157 pub fn try_get_or_insert_with<Q, E, F: FnOnce(&Q) -> Result<V, E>>(
158 &mut self,
159 key: &Q,
160 f: F,
161 ) -> Result<&mut V, E>
162 where
163 K: Borrow<Q>,
164 Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
165 {
166 let hash = self.build_hasher.hash_one(key);
167 match self.table.entry(
168 hash,
169 |lru_key| self.elements[*lru_key].key.borrow() == key,
170 |lru_key| self.build_hasher.hash_one(&self.elements[*lru_key].key),
171 ) {
172 Entry::Occupied(o) => {
173 let lru_key = *o.get();
174 if lru_key != self.most_recent {
175 self.lru_list_unlink(lru_key);
176 self.lru_list_insert_mru(lru_key);
177 }
178 Ok(&mut self.elements[lru_key].value)
179 }
180
181 Entry::Vacant(v) => {
182 let lru_entry = LruEntry {
183 value: f(key)?,
184 key: key.to_owned(),
185 list: LruListNode::default(),
186 };
187 let lru_key = self.elements.insert(lru_entry);
188 v.insert(lru_key);
189 self.lru_list_insert_mru(lru_key);
190 if self.elements.len() > self.max_capacity {
191 self.pop_lru();
192 }
193 Ok(&mut self.elements[lru_key].value)
194 }
195 }
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::LruCache;

    #[test]
    fn test_lru_cache_basic() {
        let mut lru: LruCache<u32, &str> = LruCache::with_capacity(2);

        assert_eq!(lru.insert(1, "one"), None);
        assert_eq!(lru.insert(2, "two"), None);
        assert_eq!(lru.get(&1), Some(&"one"));
        // Inserting a third key evicts the least recently used one (2).
        assert_eq!(lru.insert(3, "three"), None);
        assert_eq!(lru.get(&2), None);
        assert_eq!(lru.get(&1), Some(&"one"));
        assert_eq!(lru.get(&3), Some(&"three"));
        assert_eq!(lru.insert(1, "uno"), Some("one"));
        assert_eq!(lru.get(&1), Some(&"uno"));
    }

    #[test]
    fn test_lru_cache_pop() {
        let mut lru: LruCache<u32, &str> = LruCache::with_capacity(2);

        assert_eq!(lru.insert(1, "one"), None);
        assert_eq!(lru.insert(2, "two"), None);
        assert_eq!(lru.pop_lru(), Some((1, "one")));
        assert_eq!(lru.get(&1), None);
        assert_eq!(lru.get(&2), Some(&"two"));
    }

    #[test]
    fn test_lru_cache_pop_until_empty() {
        let mut lru: LruCache<u32, &str> = LruCache::with_capacity(3);

        assert_eq!(lru.pop_lru(), None);
        assert_eq!(lru.insert(1, "one"), None);
        assert_eq!(lru.insert(2, "two"), None);
        assert_eq!(lru.insert(3, "three"), None);
        // Touching 1 leaves 2 as the least recently used entry.
        assert_eq!(lru.get(&1), Some(&"one"));
        assert_eq!(lru.pop_lru(), Some((2, "two")));
        assert_eq!(lru.pop_lru(), Some((3, "three")));
        assert_eq!(lru.pop_lru(), Some((1, "one")));
        assert_eq!(lru.pop_lru(), None);
        // The cache is still usable after being drained.
        assert_eq!(lru.insert(4, "four"), None);
        assert_eq!(lru.get(&4), Some(&"four"));
    }

    #[test]
    fn test_lru_cache_capacity_one() {
        let mut lru: LruCache<u32, u32> = LruCache::with_capacity(1);

        assert_eq!(lru.insert(1, 10), None);
        assert_eq!(lru.insert(2, 20), None);
        assert_eq!(lru.get(&1), None);
        assert_eq!(lru.get(&2), Some(&20));
        assert_eq!(lru.insert(2, 21), Some(20));
        assert_eq!(lru.pop_lru(), Some((2, 21)));
        assert_eq!(lru.pop_lru(), None);
    }

    #[test]
    fn test_lru_cache_insert_refreshes_recency() {
        let mut lru: LruCache<u32, &str> = LruCache::with_capacity(2);

        assert_eq!(lru.insert(1, "one"), None);
        assert_eq!(lru.insert(2, "two"), None);
        // Updating 1 makes it most recent, so 2 is the next to be evicted.
        assert_eq!(lru.insert(1, "uno"), Some("one"));
        assert_eq!(lru.insert(3, "three"), None);
        assert_eq!(lru.get(&2), None);
        assert_eq!(lru.get(&1), Some(&"uno"));
    }

    #[test]
    fn test_lru_cache_get_or_insert_with() {
        let mut lru: LruCache<String, usize> = LruCache::with_capacity(2);
        let mut calls = 0;

        let first = *lru.get_or_insert_with("a", |k| {
            calls += 1;
            k.len()
        });
        assert_eq!(first, 1);
        // A hit must not call the closure again.
        let second = *lru.get_or_insert_with("a", |_| {
            calls += 1;
            99
        });
        assert_eq!(second, 1);
        assert_eq!(calls, 1);

        *lru.get_or_insert_with("bb", |k| k.len()) += 10;
        assert_eq!(lru.get("bb"), Some(&12));
    }

    #[test]
    fn test_lru_cache_try_get_or_insert_with_error_leaves_cache_unchanged() {
        let mut lru: LruCache<u32, u32> = LruCache::with_capacity(1);

        assert_eq!(lru.insert(1, 10), None);
        assert_eq!(
            lru.try_get_or_insert_with(&2, |_| Err("boom")).copied(),
            Err("boom")
        );
        // The failed insert must neither store 2 nor evict 1.
        assert_eq!(lru.get(&1), Some(&10));
        assert_eq!(lru.get(&2), None);

        assert_eq!(
            lru.try_get_or_insert_with(&2, |&k| Ok::<u32, &str>(k * 10))
                .copied(),
            Ok(20)
        );
        assert_eq!(lru.get(&1), None);
        assert_eq!(lru.get(&2), Some(&20));
    }
}