battler/common/
lru.rs

1use std::{
2    borrow::Borrow,
3    hash::{
4        Hash,
5        Hasher,
6    },
7    iter::FusedIterator,
8    marker::PhantomData,
9    mem,
10    ptr::{
11        self,
12        NonNull,
13    },
14};
15
16use ahash::{
17    HashMap,
18    HashMapExt,
19};
20
/// A hashable, comparable handle to a key that is stored somewhere else.
///
/// The handle is just a raw pointer. Hashing and equality dereference it and delegate to the
/// pointed-to key, so two handles are equal exactly when the keys they point at are equal.
#[derive(Eq)]
#[repr(transparent)]
struct KeyRef<K>(*const K);

impl<K> Hash for KeyRef<K>
where
    K: Hash,
{
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        // SAFETY: relies on whoever created this handle keeping the pointed-to key alive and
        // initialized for as long as the handle is in use.
        let key = unsafe { &*self.0 };
        key.hash(state);
    }
}

impl<K> PartialEq for KeyRef<K>
where
    K: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        // SAFETY: same requirement as in `hash`: both pointers must reference live keys.
        let (lhs, rhs) = unsafe { (&*self.0, &*other.0) };
        lhs == rhs
    }
}
37
/// A transparent wrapper around a (possibly unsized) key, used as the borrowed lookup form.
///
/// Equality and hashing are derived, so they are exactly those of the wrapped key.
#[derive(PartialEq, Eq, Hash)]
#[repr(transparent)]
struct KeyValue<K: ?Sized>(K);

impl<K: ?Sized> KeyValue<K> {
    /// Reinterprets a reference to a key as a reference to its `KeyValue` wrapper.
    fn from_ref(key: &K) -> &Self {
        let raw: *const K = key;
        // SAFETY: `KeyValue<K>` is `#[repr(transparent)]` over `K`, so both types share layout
        // and pointer metadata. The result borrows from `key` and cannot outlive it.
        unsafe { &*(raw as *const Self) }
    }
}
51
52impl<K, L> Borrow<KeyValue<L>> for KeyRef<K>
53where
54    K: Borrow<L>,
55    L: ?Sized,
56{
57    fn borrow(&self) -> &KeyValue<L> {
58        let key = unsafe { &*self.0 }.borrow();
59        KeyValue::from_ref(key)
60    }
61}
62
/// A node in the cache's doubly-linked list.
///
/// Stores one key-value pair together with raw links to the neighboring entries. Key and value
/// live in `MaybeUninit` slots so that an entry can also exist without holding a pair (see
/// `new_empty`).
struct LruEntry<K, V> {
    key: mem::MaybeUninit<K>,
    value: mem::MaybeUninit<V>,
    prev: *mut LruEntry<K, V>,
    next: *mut LruEntry<K, V>,
}

impl<K, V> LruEntry<K, V> {
    /// Creates an unlinked entry that holds `key` and `value`.
    fn new(key: K, value: V) -> Self {
        Self {
            key: mem::MaybeUninit::new(key),
            value: mem::MaybeUninit::new(value),
            // Both links start out null, exactly as in an empty entry.
            ..Self::new_empty()
        }
    }

    /// Creates an unlinked entry whose key and value slots are left uninitialized.
    fn new_empty() -> Self {
        let unlinked = ptr::null_mut();
        Self {
            key: mem::MaybeUninit::uninit(),
            value: mem::MaybeUninit::uninit(),
            prev: unlinked,
            next: unlinked,
        }
    }
}
92
/// An LRU (least-recently-used) cache.
///
/// Implemented by maintaining a doubly-linked list of cache entries. On access, entries are moved
/// to the head of the list. Once the cache reaches capacity, entries at the back of the list will
/// be evicted first to make room for newer entries.
pub struct LruCache<K, V> {
    // Lookup table from key to its heap-allocated entry. Each `KeyRef` points at the key stored
    // inside the very entry it maps to.
    map: HashMap<KeyRef<K>, NonNull<LruEntry<K, V>>>,
    // Number of entries the cache holds before the least-recently-used one is replaced.
    capacity: usize,
    // Sentinel entry (no key or value) that precedes the most-recently-used entry.
    head: *mut LruEntry<K, V>,
    // Sentinel entry (no key or value) that follows the least-recently-used entry.
    tail: *mut LruEntry<K, V>,
}
104
105impl<K, V> Clone for LruCache<K, V>
106where
107    K: PartialEq + Eq + Hash + Clone,
108    V: Clone,
109{
110    fn clone(&self) -> Self {
111        let mut cloned = Self::new(self.capacity());
112        for (key, value) in self.iter().rev() {
113            cloned.push(key.clone(), value.clone());
114        }
115        cloned
116    }
117}
118
119impl<K, V> LruCache<K, V>
120where
121    K: Eq + Hash,
122{
123    /// Creates a new LRU cache with the given capacity.
124    pub fn new(capacity: usize) -> Self {
125        let cache = Self {
126            map: HashMap::with_capacity(capacity),
127            capacity,
128            head: Box::into_raw(Box::new(LruEntry::new_empty())),
129            tail: Box::into_raw(Box::new(LruEntry::new_empty())),
130        };
131
132        unsafe {
133            (*cache.head).next = cache.tail;
134            (*cache.tail).prev = cache.head;
135        }
136        cache
137    }
138
139    /// The capacity of the cache.
140    pub fn capacity(&self) -> usize {
141        self.capacity
142    }
143
144    /// The length of the map.
145    pub fn len(&self) -> usize {
146        self.map.len()
147    }
148
149    /// Puts a key-value pair into the cache.
150    ///
151    /// If the key already exists in the cache, it is updated and the old value is returned.
152    /// Otherwise, `None` is returned.
153    pub fn put(&mut self, key: K, value: V) -> Option<V> {
154        self.capturing_put(key, value, false).map(|(_, v)| v)
155    }
156
157    /// Pushes a key-value pair into the cache.
158    ///
159    /// If the key already exists in the cache or another entry is removed (due to capacity),
160    /// then the old key-value pair is returned. Otherwise, returns `None`.
161    pub fn push(&mut self, key: K, value: V) -> Option<(K, V)> {
162        self.capturing_put(key, value, true)
163    }
164
165    fn capturing_put(&mut self, key: K, mut value: V, capture: bool) -> Option<(K, V)> {
166        let entry = self.map.get_mut(&KeyRef(&key));
167        match entry {
168            Some(entry) => {
169                // The key is already in the cache, so just update it and move it to the front of
170                // the list.
171                let entry_ptr = entry.as_ptr();
172                let stored_value = unsafe { &mut (*(*entry_ptr).value.as_mut_ptr()) };
173                mem::swap(&mut value, stored_value);
174                self.detach(entry_ptr);
175                self.attach(entry_ptr);
176                Some((key, value))
177            }
178            None => {
179                let (replaced, entry) = self.replace_or_create_entry(key, value);
180                let entry_ptr = entry.as_ptr();
181                self.attach(entry_ptr);
182                let key = unsafe { &*entry_ptr }.key.as_ptr();
183                self.map.insert(KeyRef(key), entry);
184                replaced.filter(|_| capture)
185            }
186        }
187    }
188
189    fn replace_or_create_entry(
190        &mut self,
191        key: K,
192        value: V,
193    ) -> (Option<(K, V)>, NonNull<LruEntry<K, V>>) {
194        if self.len() == self.capacity() {
195            // Cache is full, remove the last entry.
196            let old_key = KeyRef(unsafe { &(*(*(*self.tail).prev).key.as_ptr()) });
197            let old_entry = self.map.remove(&old_key).unwrap();
198            let entry_ptr = old_entry.as_ptr();
199            let replaced = unsafe {
200                (
201                    mem::replace(&mut (*entry_ptr).key, mem::MaybeUninit::new(key)).assume_init(),
202                    mem::replace(&mut (*entry_ptr).value, mem::MaybeUninit::new(value))
203                        .assume_init(),
204                )
205            };
206            self.detach(entry_ptr);
207            (Some(replaced), old_entry)
208        } else {
209            (None, unsafe {
210                NonNull::new_unchecked(Box::into_raw(Box::new(LruEntry::new(key, value))))
211            })
212        }
213    }
214
215    fn detach(&mut self, entry: *mut LruEntry<K, V>) {
216        unsafe {
217            (*(*entry).prev).next = (*entry).next;
218            (*(*entry).next).prev = (*entry).prev;
219        }
220    }
221
222    fn attach(&mut self, entry: *mut LruEntry<K, V>) {
223        unsafe {
224            (*entry).next = (*self.head).next;
225            (*entry).prev = self.head;
226            (*self.head).next = entry;
227            (*(*entry).next).prev = entry;
228        }
229    }
230
231    /// Checks if the given key is contained in the cache.
232    pub fn contains_key<'a, L>(&'a self, key: &L) -> bool
233    where
234        K: Borrow<L>,
235        L: Eq + Hash + ?Sized,
236    {
237        self.map.contains_key(KeyValue::from_ref(key))
238    }
239
240    /// Returns a reference to the value associated with the given key.
241    ///
242    /// Moves the key to the head of the LRU list if it exists. Otherwise, returns [`None`].
243    pub fn get<'a, L>(&'a mut self, key: &L) -> Option<&'a V>
244    where
245        K: Borrow<L>,
246        L: Eq + Hash + ?Sized,
247    {
248        if let Some(entry) = self.map.get_mut(KeyValue::from_ref(key)) {
249            let entry_ptr = entry.as_ptr();
250            self.detach(entry_ptr);
251            self.attach(entry_ptr);
252            Some(unsafe { &*(*entry_ptr).value.as_ptr() })
253        } else {
254            None
255        }
256    }
257
258    /// Returns a mutable reference to the value associated with the given key.
259    ///
260    /// Moves the key to the head of the LRU list if it exists. Otherwise, returns [`None`].
261    pub fn get_mut<'a, L>(&'a mut self, key: &L) -> Option<&'a mut V>
262    where
263        K: Borrow<L>,
264        L: Eq + Hash + ?Sized,
265    {
266        if let Some(entry) = self.map.get_mut(KeyValue::from_ref(key)) {
267            let entry_ptr = entry.as_ptr();
268            self.detach(entry_ptr);
269            self.attach(entry_ptr);
270            Some(unsafe { &mut *(*entry_ptr).value.as_mut_ptr() })
271        } else {
272            None
273        }
274    }
275
276    /// Returns an iterator visiting all entries in most-recently used order.
277    pub fn iter(&self) -> Iter<'_, K, V> {
278        Iter {
279            len: self.len(),
280            ptr: unsafe { (*self.head).next },
281            end: unsafe { (*self.tail).prev },
282            phantom: PhantomData,
283        }
284    }
285
286    /// Returns an iterator visiting all entries in most-recently used order, with a mutable
287    /// reference to the value.
288    pub fn iter_mut(&self) -> IterMut<'_, K, V> {
289        IterMut {
290            len: self.len(),
291            ptr: unsafe { (*self.head).next },
292            end: unsafe { (*self.tail).prev },
293            phantom: PhantomData,
294        }
295    }
296}
297
298impl<K, V> Drop for LruCache<K, V> {
299    fn drop(&mut self) {
300        self.map.drain().for_each(|(_, entry)| unsafe {
301            let mut entry = *Box::from_raw(entry.as_ptr());
302            ptr::drop_in_place(entry.key.as_mut_ptr());
303            ptr::drop_in_place(entry.value.as_mut_ptr());
304        });
305        unsafe { drop(Box::from_raw(self.head)) };
306        unsafe { drop(Box::from_raw(self.tail)) };
307    }
308}
309
310impl<'a, K, V> IntoIterator for &'a LruCache<K, V>
311where
312    K: Eq + Hash,
313{
314    type Item = (&'a K, &'a V);
315    type IntoIter = Iter<'a, K, V>;
316
317    fn into_iter(self) -> Self::IntoIter {
318        self.iter()
319    }
320}
321
322impl<'a, K, V> IntoIterator for &'a mut LruCache<K, V>
323where
324    K: Eq + Hash,
325{
326    type Item = (&'a K, &'a mut V);
327    type IntoIter = IterMut<'a, K, V>;
328
329    fn into_iter(self) -> Self::IntoIter {
330        self.iter_mut()
331    }
332}
333
334unsafe impl<K: Send, V: Send> Send for LruCache<K, V> {}
335unsafe impl<K: Sync, V: Sync> Sync for LruCache<K, V> {}
336
337/// An iterator over entries in an [`LruCache`].
338pub struct Iter<'a, K, V>
339where
340    K: 'a,
341    V: 'a,
342{
343    len: usize,
344    ptr: *const LruEntry<K, V>,
345    end: *const LruEntry<K, V>,
346    phantom: PhantomData<&'a K>,
347}
348
349impl<'a, K, V> Iterator for Iter<'a, K, V> {
350    type Item = (&'a K, &'a V);
351
352    fn next(&mut self) -> Option<Self::Item> {
353        if self.len == 0 {
354            return None;
355        }
356
357        let key = unsafe { &(*(*self.ptr).key.as_ptr()) as &K };
358        let value = unsafe { &(*(*self.ptr).value.as_ptr()) as &V };
359        self.len -= 1;
360        self.ptr = unsafe { (*self.ptr).next };
361        Some((key, value))
362    }
363
364    fn size_hint(&self) -> (usize, Option<usize>) {
365        (self.len, Some(self.len))
366    }
367
368    fn count(self) -> usize {
369        self.len
370    }
371}
372
373impl<'a, K, V> DoubleEndedIterator for Iter<'a, K, V> {
374    fn next_back(&mut self) -> Option<Self::Item> {
375        if self.len == 0 {
376            return None;
377        }
378
379        let key = unsafe { &(*(*self.end).key.as_ptr()) };
380        let value = unsafe { &(*(*self.end).value.as_ptr()) };
381        self.len -= 1;
382        self.end = unsafe { (*self.end).prev };
383        Some((key, value))
384    }
385}
386
387impl<'a, K, V> ExactSizeIterator for Iter<'a, K, V> {}
388impl<'a, K, V> FusedIterator for Iter<'a, K, V> {}
389
390unsafe impl<'a, K: Send, V: Send> Send for Iter<'a, K, V> {}
391unsafe impl<'a, K: Sync, V: Sync> Sync for Iter<'a, K, V> {}
392
393/// A mutable iterator over entries in an [`LruCache`].
394pub struct IterMut<'a, K, V>
395where
396    K: 'a,
397    V: 'a,
398{
399    len: usize,
400    ptr: *mut LruEntry<K, V>,
401    end: *mut LruEntry<K, V>,
402    phantom: PhantomData<&'a K>,
403}
404
405impl<'a, K, V> Iterator for IterMut<'a, K, V> {
406    type Item = (&'a K, &'a mut V);
407
408    fn next(&mut self) -> Option<Self::Item> {
409        if self.len == 0 {
410            return None;
411        }
412
413        let key = unsafe { &(*(*self.ptr).key.as_ptr()) };
414        let value = unsafe { &mut (*(*self.ptr).value.as_mut_ptr()) };
415        self.len -= 1;
416        self.ptr = unsafe { (*self.ptr).next };
417        Some((key, value))
418    }
419
420    fn size_hint(&self) -> (usize, Option<usize>) {
421        (self.len, Some(self.len))
422    }
423
424    fn count(self) -> usize {
425        self.len
426    }
427}
428
429impl<'a, K, V> DoubleEndedIterator for IterMut<'a, K, V> {
430    fn next_back(&mut self) -> Option<Self::Item> {
431        if self.len == 0 {
432            return None;
433        }
434
435        let key = unsafe { &(*(*self.end).key.as_ptr()) };
436        let value = unsafe { &mut (*(*self.end).value.as_mut_ptr()) };
437        self.len -= 1;
438        self.end = unsafe { (*self.end).prev };
439        Some((key, value))
440    }
441}
442
443impl<'a, K, V> ExactSizeIterator for IterMut<'a, K, V> {}
444impl<'a, K, V> FusedIterator for IterMut<'a, K, V> {}
445
446unsafe impl<'a, K: Send, V: Send> Send for IterMut<'a, K, V> {}
447unsafe impl<'a, K: Sync, V: Sync> Sync for IterMut<'a, K, V> {}
448
#[cfg(test)]
mod lru_cache_test {
    use crate::common::LruCache;

    /// Fills a two-entry cache, overwrites an existing key, and then pushes a third key to check
    /// that the least-recently-used entry is the one evicted.
    #[test]
    fn removes_least_recently_used_by_capacity() {
        let mut cache = LruCache::new(2);
        assert_eq!(cache.capacity(), 2);
        assert_eq!(cache.len(), 0);

        // Pushing into free space returns nothing and grows the cache.
        assert!(!cache.contains_key("a"));
        assert_eq!(cache.push("a", 1), None);
        assert!(cache.contains_key("a"));
        assert_eq!(cache.len(), 1);
        assert!(!cache.contains_key("b"));
        assert_eq!(cache.push("b", 2), None);
        assert!(cache.contains_key("b"));
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get("a"), Some(&1));
        assert_eq!(cache.get("b"), Some(&2));

        // Pushing an existing key returns the previous pair instead of evicting anything.
        assert_eq!(cache.push("b", 3), Some(("b", 2)));
        assert_eq!(cache.push("b", 4), Some(("b", 3)));
        // `get` refreshes recency: "a" first, then "b", so "b" ends up most recently used.
        assert_eq!(cache.get("a"), Some(&1));
        assert_eq!(cache.get("b"), Some(&4));
        assert_eq!(
            cache.iter().map(|(k, v)| (*k, *v)).collect::<Vec<_>>(),
            vec![("b", 4), ("a", 1)]
        );

        // The cache is full, so a new key evicts the least-recently-used entry ("a").
        assert_eq!(cache.push("c", 5), Some(("a", 1)));
        assert_eq!(cache.get("a"), None);
        assert_eq!(cache.get("b"), Some(&4));
        assert_eq!(cache.get("c"), Some(&5));
        assert_eq!(
            cache.iter().map(|(k, v)| (*k, *v)).collect::<Vec<_>>(),
            vec![("c", 5), ("b", 4)]
        );
    }

    /// Checks that iteration runs from most- to least-recently used, both after plain inserts
    /// and after an update plus an eviction.
    #[test]
    fn iterates_in_most_recently_used_order() {
        let mut cache = LruCache::new(5);
        assert_eq!(cache.put(1, "a"), None);
        assert_eq!(cache.put(2, "b"), None);
        assert_eq!(cache.put(3, "c"), None);
        assert_eq!(cache.put(4, "d"), None);
        assert_eq!(cache.put(5, "e"), None);
        assert_eq!(
            cache.iter().map(|(k, v)| (*k, *v)).collect::<Vec<_>>(),
            vec![(5, "e"), (4, "d"), (3, "c"), (2, "b"), (1, "a")]
        );

        // Updating key 3 returns its old value and moves it to the front.
        assert_eq!(cache.put(3, "f"), Some("c"));
        // Key 6 is new and the cache is full, so key 1 is evicted; `put` does not report it.
        assert_eq!(cache.put(6, "g"), None);
        assert_eq!(
            cache.iter().map(|(k, v)| (*k, *v)).collect::<Vec<_>>(),
            vec![(6, "g"), (3, "f"), (5, "e"), (4, "d"), (2, "b"),]
        );
    }

    /// Doubles every value through `iter_mut` and checks both the new values and that the
    /// iteration order is still most-recently used first.
    #[test]
    fn mutably_iterates_in_most_recently_used_order() {
        let mut cache = LruCache::new(5);
        assert_eq!(cache.put(1, 1), None);
        assert_eq!(cache.put(2, 2), None);
        assert_eq!(cache.put(3, 3), None);
        assert_eq!(cache.put(4, 4), None);
        assert_eq!(cache.put(5, 5), None);
        for (_, v) in cache.iter_mut() {
            *v *= 2;
        }
        assert_eq!(
            cache.iter_mut().map(|(k, v)| (*k, *v)).collect::<Vec<_>>(),
            vec![(5, 10), (4, 8), (3, 6), (2, 4), (1, 2)]
        );
    }
}