1use std::{
2 borrow::Borrow,
3 hash::{
4 Hash,
5 Hasher,
6 },
7 iter::FusedIterator,
8 marker::PhantomData,
9 mem,
10 ptr::{
11 self,
12 NonNull,
13 },
14};
15
16use ahash::{
17 HashMap,
18 HashMapExt,
19};
20
/// Non-owning handle to a key stored inside an `LruEntry`, used as the map key.
///
/// Hashing and equality are forwarded to the pointed-to `K`, so two handles
/// compare equal exactly when their keys do. The pointee must stay alive (and
/// at the same address) for as long as the handle is hashed or compared.
#[repr(transparent)]
#[derive(Eq)]
struct KeyRef<K>(*const K);

impl<K: Hash> Hash for KeyRef<K> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // SAFETY: a handle is only hashed while the key it points to is alive.
        let key: &K = unsafe { &*self.0 };
        key.hash(state);
    }
}

impl<K: PartialEq> PartialEq for KeyRef<K> {
    fn eq(&self, other: &KeyRef<K>) -> bool {
        // SAFETY: both handles are only compared while their keys are alive.
        let (lhs, rhs): (&K, &K) = unsafe { (&*self.0, &*other.0) };
        lhs == rhs
    }
}
37
/// Transparent wrapper around a (possibly unsized) borrowed key form `K`.
///
/// It lets the cache's map, which is keyed by `KeyRef`, be queried with a
/// borrowed form of the key such as `&str` for `String` keys.
#[repr(transparent)]
#[derive(PartialEq, Eq, Hash)]
struct KeyValue<K: ?Sized>(K);

impl<K: ?Sized> KeyValue<K> {
    /// Reinterprets `&K` as `&KeyValue<K>` without copying.
    fn from_ref(key: &K) -> &Self {
        let raw = key as *const K as *const KeyValue<K>;
        // SAFETY: `KeyValue<K>` is `#[repr(transparent)]` over `K`, so the two
        // types share layout and pointer metadata; the lifetime is preserved.
        unsafe { &*raw }
    }
}
51
52impl<K, L> Borrow<KeyValue<L>> for KeyRef<K>
53where
54 K: Borrow<L>,
55 L: ?Sized,
56{
57 fn borrow(&self) -> &KeyValue<L> {
58 let key = unsafe { &*self.0 }.borrow();
59 KeyValue::from_ref(key)
60 }
61}
62
/// A node of the cache's doubly linked recency list.
///
/// `key` and `value` are `MaybeUninit` so that the list's sentinel nodes can
/// exist without a real key or value; nodes built with `new` have both
/// initialized.
struct LruEntry<K, V> {
    key: mem::MaybeUninit<K>,
    value: mem::MaybeUninit<V>,
    /// Neighbour towards the most recently used end (null until first linked).
    prev: *mut LruEntry<K, V>,
    /// Neighbour towards the least recently used end (null until first linked).
    next: *mut LruEntry<K, V>,
}

impl<K, V> LruEntry<K, V> {
    /// Creates an unlinked node holding `key` and `value`.
    fn new(key: K, value: V) -> Self {
        Self {
            key: mem::MaybeUninit::new(key),
            value: mem::MaybeUninit::new(value),
            ..Self::new_empty()
        }
    }

    /// Creates an unlinked node whose key and value are left uninitialized.
    fn new_empty() -> Self {
        let unlinked: *mut Self = ptr::null_mut();
        Self {
            key: mem::MaybeUninit::uninit(),
            value: mem::MaybeUninit::uninit(),
            prev: unlinked,
            next: unlinked,
        }
    }
}
92
93pub struct LruCache<K, V> {
99 map: HashMap<KeyRef<K>, NonNull<LruEntry<K, V>>>,
100 capacity: usize,
101 head: *mut LruEntry<K, V>,
102 tail: *mut LruEntry<K, V>,
103}
104
105impl<K, V> Clone for LruCache<K, V>
106where
107 K: PartialEq + Eq + Hash + Clone,
108 V: Clone,
109{
110 fn clone(&self) -> Self {
111 let mut cloned = Self::new(self.capacity());
112 for (key, value) in self.iter().rev() {
113 cloned.push(key.clone(), value.clone());
114 }
115 cloned
116 }
117}
118
119impl<K, V> LruCache<K, V>
120where
121 K: Eq + Hash,
122{
123 pub fn new(capacity: usize) -> Self {
125 let cache = Self {
126 map: HashMap::with_capacity(capacity),
127 capacity,
128 head: Box::into_raw(Box::new(LruEntry::new_empty())),
129 tail: Box::into_raw(Box::new(LruEntry::new_empty())),
130 };
131
132 unsafe {
133 (*cache.head).next = cache.tail;
134 (*cache.tail).prev = cache.head;
135 }
136 cache
137 }
138
139 pub fn capacity(&self) -> usize {
141 self.capacity
142 }
143
144 pub fn len(&self) -> usize {
146 self.map.len()
147 }
148
149 pub fn put(&mut self, key: K, value: V) -> Option<V> {
154 self.capturing_put(key, value, false).map(|(_, v)| v)
155 }
156
157 pub fn push(&mut self, key: K, value: V) -> Option<(K, V)> {
162 self.capturing_put(key, value, true)
163 }
164
165 fn capturing_put(&mut self, key: K, mut value: V, capture: bool) -> Option<(K, V)> {
166 let entry = self.map.get_mut(&KeyRef(&key));
167 match entry {
168 Some(entry) => {
169 let entry_ptr = entry.as_ptr();
172 let stored_value = unsafe { &mut (*(*entry_ptr).value.as_mut_ptr()) };
173 mem::swap(&mut value, stored_value);
174 self.detach(entry_ptr);
175 self.attach(entry_ptr);
176 Some((key, value))
177 }
178 None => {
179 let (replaced, entry) = self.replace_or_create_entry(key, value);
180 let entry_ptr = entry.as_ptr();
181 self.attach(entry_ptr);
182 let key = unsafe { &*entry_ptr }.key.as_ptr();
183 self.map.insert(KeyRef(key), entry);
184 replaced.filter(|_| capture)
185 }
186 }
187 }
188
189 fn replace_or_create_entry(
190 &mut self,
191 key: K,
192 value: V,
193 ) -> (Option<(K, V)>, NonNull<LruEntry<K, V>>) {
194 if self.len() == self.capacity() {
195 let old_key = KeyRef(unsafe { &(*(*(*self.tail).prev).key.as_ptr()) });
197 let old_entry = self.map.remove(&old_key).unwrap();
198 let entry_ptr = old_entry.as_ptr();
199 let replaced = unsafe {
200 (
201 mem::replace(&mut (*entry_ptr).key, mem::MaybeUninit::new(key)).assume_init(),
202 mem::replace(&mut (*entry_ptr).value, mem::MaybeUninit::new(value))
203 .assume_init(),
204 )
205 };
206 self.detach(entry_ptr);
207 (Some(replaced), old_entry)
208 } else {
209 (None, unsafe {
210 NonNull::new_unchecked(Box::into_raw(Box::new(LruEntry::new(key, value))))
211 })
212 }
213 }
214
215 fn detach(&mut self, entry: *mut LruEntry<K, V>) {
216 unsafe {
217 (*(*entry).prev).next = (*entry).next;
218 (*(*entry).next).prev = (*entry).prev;
219 }
220 }
221
222 fn attach(&mut self, entry: *mut LruEntry<K, V>) {
223 unsafe {
224 (*entry).next = (*self.head).next;
225 (*entry).prev = self.head;
226 (*self.head).next = entry;
227 (*(*entry).next).prev = entry;
228 }
229 }
230
231 pub fn contains_key<'a, L>(&'a self, key: &L) -> bool
233 where
234 K: Borrow<L>,
235 L: Eq + Hash + ?Sized,
236 {
237 self.map.contains_key(KeyValue::from_ref(key))
238 }
239
240 pub fn get<'a, L>(&'a mut self, key: &L) -> Option<&'a V>
244 where
245 K: Borrow<L>,
246 L: Eq + Hash + ?Sized,
247 {
248 if let Some(entry) = self.map.get_mut(KeyValue::from_ref(key)) {
249 let entry_ptr = entry.as_ptr();
250 self.detach(entry_ptr);
251 self.attach(entry_ptr);
252 Some(unsafe { &*(*entry_ptr).value.as_ptr() })
253 } else {
254 None
255 }
256 }
257
258 pub fn get_mut<'a, L>(&'a mut self, key: &L) -> Option<&'a mut V>
262 where
263 K: Borrow<L>,
264 L: Eq + Hash + ?Sized,
265 {
266 if let Some(entry) = self.map.get_mut(KeyValue::from_ref(key)) {
267 let entry_ptr = entry.as_ptr();
268 self.detach(entry_ptr);
269 self.attach(entry_ptr);
270 Some(unsafe { &mut *(*entry_ptr).value.as_mut_ptr() })
271 } else {
272 None
273 }
274 }
275
276 pub fn iter(&self) -> Iter<'_, K, V> {
278 Iter {
279 len: self.len(),
280 ptr: unsafe { (*self.head).next },
281 end: unsafe { (*self.tail).prev },
282 phantom: PhantomData,
283 }
284 }
285
286 pub fn iter_mut(&self) -> IterMut<'_, K, V> {
289 IterMut {
290 len: self.len(),
291 ptr: unsafe { (*self.head).next },
292 end: unsafe { (*self.tail).prev },
293 phantom: PhantomData,
294 }
295 }
296}
297
298impl<K, V> Drop for LruCache<K, V> {
299 fn drop(&mut self) {
300 self.map.drain().for_each(|(_, entry)| unsafe {
301 let mut entry = *Box::from_raw(entry.as_ptr());
302 ptr::drop_in_place(entry.key.as_mut_ptr());
303 ptr::drop_in_place(entry.value.as_mut_ptr());
304 });
305 unsafe { drop(Box::from_raw(self.head)) };
306 unsafe { drop(Box::from_raw(self.tail)) };
307 }
308}
309
310impl<'a, K, V> IntoIterator for &'a LruCache<K, V>
311where
312 K: Eq + Hash,
313{
314 type Item = (&'a K, &'a V);
315 type IntoIter = Iter<'a, K, V>;
316
317 fn into_iter(self) -> Self::IntoIter {
318 self.iter()
319 }
320}
321
322impl<'a, K, V> IntoIterator for &'a mut LruCache<K, V>
323where
324 K: Eq + Hash,
325{
326 type Item = (&'a K, &'a mut V);
327 type IntoIter = IterMut<'a, K, V>;
328
329 fn into_iter(self) -> Self::IntoIter {
330 self.iter_mut()
331 }
332}
333
334unsafe impl<K: Send, V: Send> Send for LruCache<K, V> {}
335unsafe impl<K: Sync, V: Sync> Sync for LruCache<K, V> {}
336
337pub struct Iter<'a, K, V>
339where
340 K: 'a,
341 V: 'a,
342{
343 len: usize,
344 ptr: *const LruEntry<K, V>,
345 end: *const LruEntry<K, V>,
346 phantom: PhantomData<&'a K>,
347}
348
349impl<'a, K, V> Iterator for Iter<'a, K, V> {
350 type Item = (&'a K, &'a V);
351
352 fn next(&mut self) -> Option<Self::Item> {
353 if self.len == 0 {
354 return None;
355 }
356
357 let key = unsafe { &(*(*self.ptr).key.as_ptr()) as &K };
358 let value = unsafe { &(*(*self.ptr).value.as_ptr()) as &V };
359 self.len -= 1;
360 self.ptr = unsafe { (*self.ptr).next };
361 Some((key, value))
362 }
363
364 fn size_hint(&self) -> (usize, Option<usize>) {
365 (self.len, Some(self.len))
366 }
367
368 fn count(self) -> usize {
369 self.len
370 }
371}
372
373impl<'a, K, V> DoubleEndedIterator for Iter<'a, K, V> {
374 fn next_back(&mut self) -> Option<Self::Item> {
375 if self.len == 0 {
376 return None;
377 }
378
379 let key = unsafe { &(*(*self.end).key.as_ptr()) };
380 let value = unsafe { &(*(*self.end).value.as_ptr()) };
381 self.len -= 1;
382 self.end = unsafe { (*self.end).prev };
383 Some((key, value))
384 }
385}
386
387impl<'a, K, V> ExactSizeIterator for Iter<'a, K, V> {}
388impl<'a, K, V> FusedIterator for Iter<'a, K, V> {}
389
390unsafe impl<'a, K: Send, V: Send> Send for Iter<'a, K, V> {}
391unsafe impl<'a, K: Sync, V: Sync> Sync for Iter<'a, K, V> {}
392
393pub struct IterMut<'a, K, V>
395where
396 K: 'a,
397 V: 'a,
398{
399 len: usize,
400 ptr: *mut LruEntry<K, V>,
401 end: *mut LruEntry<K, V>,
402 phantom: PhantomData<&'a K>,
403}
404
405impl<'a, K, V> Iterator for IterMut<'a, K, V> {
406 type Item = (&'a K, &'a mut V);
407
408 fn next(&mut self) -> Option<Self::Item> {
409 if self.len == 0 {
410 return None;
411 }
412
413 let key = unsafe { &(*(*self.ptr).key.as_ptr()) };
414 let value = unsafe { &mut (*(*self.ptr).value.as_mut_ptr()) };
415 self.len -= 1;
416 self.ptr = unsafe { (*self.ptr).next };
417 Some((key, value))
418 }
419
420 fn size_hint(&self) -> (usize, Option<usize>) {
421 (self.len, Some(self.len))
422 }
423
424 fn count(self) -> usize {
425 self.len
426 }
427}
428
429impl<'a, K, V> DoubleEndedIterator for IterMut<'a, K, V> {
430 fn next_back(&mut self) -> Option<Self::Item> {
431 if self.len == 0 {
432 return None;
433 }
434
435 let key = unsafe { &(*(*self.end).key.as_ptr()) };
436 let value = unsafe { &mut (*(*self.end).value.as_mut_ptr()) };
437 self.len -= 1;
438 self.end = unsafe { (*self.end).prev };
439 Some((key, value))
440 }
441}
442
443impl<'a, K, V> ExactSizeIterator for IterMut<'a, K, V> {}
444impl<'a, K, V> FusedIterator for IterMut<'a, K, V> {}
445
446unsafe impl<'a, K: Send, V: Send> Send for IterMut<'a, K, V> {}
447unsafe impl<'a, K: Sync, V: Sync> Sync for IterMut<'a, K, V> {}
448
#[cfg(test)]
mod lru_cache_test {
    use crate::common::LruCache;

    /// Copies the cache's entries out in iteration order (most recently used first).
    fn snapshot<K, V>(cache: &LruCache<K, V>) -> Vec<(K, V)>
    where
        K: Eq + std::hash::Hash + Copy,
        V: Copy,
    {
        cache.iter().map(|(&k, &v)| (k, v)).collect()
    }

    #[test]
    fn removes_least_recently_used_by_capacity() {
        let mut cache = LruCache::new(2);
        assert_eq!(cache.capacity(), 2);
        assert_eq!(cache.len(), 0);

        // Fill the cache one key at a time.
        for (filled, &(key, value)) in [("a", 1), ("b", 2)].iter().enumerate() {
            assert!(!cache.contains_key(key));
            assert_eq!(cache.push(key, value), None);
            assert!(cache.contains_key(key));
            assert_eq!(cache.len(), filled + 1);
        }
        assert_eq!(cache.get("a"), Some(&1));
        assert_eq!(cache.get("b"), Some(&2));

        // Re-pushing an existing key hands back the value it replaced.
        assert_eq!(cache.push("b", 3), Some(("b", 2)));
        assert_eq!(cache.push("b", 4), Some(("b", 3)));
        assert_eq!(cache.get("a"), Some(&1));
        assert_eq!(cache.get("b"), Some(&4));
        assert_eq!(snapshot(&cache), vec![("b", 4), ("a", 1)]);

        // A new key evicts the least recently used entry, "a".
        assert_eq!(cache.push("c", 5), Some(("a", 1)));
        assert_eq!(cache.get("a"), None);
        assert_eq!(cache.get("b"), Some(&4));
        assert_eq!(cache.get("c"), Some(&5));
        assert_eq!(snapshot(&cache), vec![("c", 5), ("b", 4)]);
    }

    #[test]
    fn iterates_in_most_recently_used_order() {
        let mut cache = LruCache::new(5);
        for &(key, value) in &[(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")] {
            assert_eq!(cache.put(key, value), None);
        }
        assert_eq!(
            snapshot(&cache),
            vec![(5, "e"), (4, "d"), (3, "c"), (2, "b"), (1, "a")]
        );

        assert_eq!(cache.put(3, "f"), Some("c"));
        assert_eq!(cache.put(6, "g"), None);
        assert_eq!(
            snapshot(&cache),
            vec![(6, "g"), (3, "f"), (5, "e"), (4, "d"), (2, "b")]
        );
    }

    #[test]
    fn mutably_iterates_in_most_recently_used_order() {
        let mut cache = LruCache::new(5);
        for n in 1..=5 {
            assert_eq!(cache.put(n, n), None);
        }
        cache.iter_mut().for_each(|(_, v)| *v *= 2);
        assert_eq!(
            cache.iter_mut().map(|(k, v)| (*k, *v)).collect::<Vec<_>>(),
            vec![(5, 10), (4, 8), (3, 6), (2, 4), (1, 2)]
        );
    }
}