ref_stable_lru/
lib.rs

1use std::collections::HashMap;
2use std::hash::{Hash, Hasher};
3use std::marker::PhantomData;
4use std::mem;
5use std::num::NonZeroUsize;
6use std::ptr;
7use std::ptr::NonNull;
8
/// Zero-sized marker that makes a type invariant in `'brand`.
///
/// `'brand` appears both as an argument and as the return type of the `fn`
/// pointer, so the lifetime can be neither shortened nor lengthened by
/// subtyping.
type InvariantLifetime<'brand> = PhantomData<fn(&'brand ()) -> &'brand ()>;
10
/// Branded handle to an exclusively borrowed [`LruCache`].
///
/// Instances are created by `LruCache::scope`, which pairs each handle with a
/// [`ValuePerm`] carrying the same `'brand`.
pub struct CacheHandle<'cache, 'brand, K, V> {
    // Invariant marker tying this handle to the permission token of its scope.
    _lifetime: InvariantLifetime<'brand>,
    // The cache being operated on, borrowed mutably for `'cache`.
    cache: &'cache mut LruCache<K, V>,
}
15
/// Permission token branded with the same `'brand` as a [`CacheHandle`].
///
/// `CacheHandle::get` returns value references that live as long as a shared
/// borrow of this token, while `CacheHandle::put` and `CacheHandle::peek_mut`
/// demand an exclusive borrow of it.
pub struct ValuePerm<'brand> {
    // Invariant marker; the token carries no data of its own.
    _lifetime: InvariantLifetime<'brand>,
}
19
/// Raw-pointer stand-in for a key, used as the hash-map key type.
///
/// Hashing and equality are forwarded to the pointed-to `K`, so two `KeyRef`s
/// compare equal whenever the keys they point at compare equal. The pointer
/// must refer to a live, initialized `K` whenever the `KeyRef` is hashed or
/// compared.
struct KeyRef<K> {
    k: *const K,
}

impl<K: Hash> Hash for KeyRef<K> {
    /// Feeds the pointed-to key into `state`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        // SAFETY: relies on the type-level contract that `k` points at a live,
        // initialized key for as long as this `KeyRef` is in use.
        let key: &K = unsafe { &*self.k };
        key.hash(state);
    }
}

impl<K: PartialEq> PartialEq for KeyRef<K> {
    /// Compares the two pointed-to keys (not the pointer addresses).
    fn eq(&self, other: &Self) -> bool {
        // SAFETY: same contract as in `hash` — both pointers refer to live,
        // initialized keys.
        let (lhs, rhs): (&K, &K) = unsafe { (&*self.k, &*other.k) };
        lhs == rhs
    }
}

impl<K: Eq> Eq for KeyRef<K> {}
38
/// One node of the recency list: a key/value pair plus links to its
/// neighbours, so entries can be kept ordered by how recently they were used.
///
/// `key` and `val` are `MaybeUninit` because sigil nodes never hold a pair.
struct LruEntry<K, V> {
    key: mem::MaybeUninit<K>,
    val: mem::MaybeUninit<V>,
    prev: *mut LruEntry<K, V>,
    next: *mut LruEntry<K, V>,
}

impl<K, V> LruEntry<K, V> {
    /// Creates an unlinked entry holding `key` and `val`.
    fn new(key: K, val: V) -> Self {
        Self {
            key: mem::MaybeUninit::new(key),
            val: mem::MaybeUninit::new(val),
            ..Self::new_sigil()
        }
    }

    /// Creates an unlinked sigil entry whose key and value stay uninitialized.
    fn new_sigil() -> Self {
        let unlinked = ptr::null_mut();
        Self {
            key: mem::MaybeUninit::uninit(),
            val: mem::MaybeUninit::uninit(),
            prev: unlinked,
            next: unlinked,
        }
    }
}
67
/// Fixed-capacity cache whose entries are kept in a doubly linked list
/// ordered by use and indexed by a hash map.
pub struct LruCache<K, V> {
    // Maps a pointer to the key stored inside an entry to that entry's node.
    map: HashMap<KeyRef<K>, NonNull<LruEntry<K, V>>>,
    // Number of entries at which inserting a new key recycles the node next
    // to `tail` instead of allocating.
    cap: NonZeroUsize,

    // head and tail are sigil nodes to facilitate inserting entries
    head: *mut LruEntry<K, V>,
    tail: *mut LruEntry<K, V>,
}
76
77impl<K: Eq + Hash, V> LruCache<K, V> {
78    pub fn new(cap: NonZeroUsize) -> Self {
79        let cache = LruCache::<K, V> {
80            map: HashMap::with_capacity(cap.get()),
81            cap,
82            head: Box::into_raw(Box::new(LruEntry::new_sigil())),
83            tail: Box::into_raw(Box::new(LruEntry::new_sigil())),
84        };
85
86        unsafe {
87            (*cache.head).next = cache.tail;
88            (*cache.tail).prev = cache.head;
89        };
90
91        cache
92    }
93
94    pub fn scope<'cache, F, R>(&'cache mut self, fun: F) -> R
95    where
96        for<'brand> F: FnOnce(CacheHandle<'cache, 'brand, K, V>, ValuePerm<'brand>) -> R,
97    {
98        let handle = CacheHandle {
99            _lifetime: Default::default(),
100            cache: self.into(),
101        };
102        let perm = ValuePerm {
103            _lifetime: InvariantLifetime::default(),
104        };
105        fun(handle, perm)
106    }
107
108    fn len(&self) -> usize {
109        self.map.len()
110    }
111
112    fn cap(&self) -> NonZeroUsize {
113        self.cap
114    }
115
116    fn detach(&mut self, node: *mut LruEntry<K, V>) {
117        unsafe {
118            (*(*node).prev).next = (*node).next;
119            (*(*node).next).prev = (*node).prev;
120        }
121    }
122
123    // Attaches `node` after the sigil `self.head` node.
124    fn attach(&mut self, node: *mut LruEntry<K, V>) {
125        unsafe {
126            (*node).next = (*self.head).next;
127            (*node).prev = self.head;
128            (*self.head).next = node;
129            (*(*node).next).prev = node;
130        }
131    }
132
133    fn replace_or_create_node(&mut self, k: K, v: V) -> (Option<(K, V)>, NonNull<LruEntry<K, V>>) {
134        if self.len() == self.cap().get() {
135            // if the cache is full, remove the last entry so we can use it for the new key
136            let old_key = KeyRef {
137                k: unsafe { &(*(*(*self.tail).prev).key.as_ptr()) },
138            };
139            let old_node = self.map.remove(&old_key).unwrap();
140            let node_ptr: *mut LruEntry<K, V> = old_node.as_ptr();
141
142            // read out the node's old key and value and then replace it
143            let replaced = unsafe {
144                (
145                    mem::replace(&mut (*node_ptr).key, mem::MaybeUninit::new(k)).assume_init(),
146                    mem::replace(&mut (*node_ptr).val, mem::MaybeUninit::new(v)).assume_init(),
147                )
148            };
149
150            self.detach(node_ptr);
151
152            (Some(replaced), old_node)
153        } else {
154            // if the cache is not full allocate a new LruEntry
155            // Safety: We allocate, turn into raw, and get NonNull all in one step.
156            (None, unsafe {
157                NonNull::new_unchecked(Box::into_raw(Box::new(LruEntry::new(k, v))))
158            })
159        }
160    }
161
162    pub fn put(&mut self, k: K, v: V) -> Option<V> {
163        self.scope(|mut cache, mut perm| cache.put(k, v, &mut perm))
164    }
165
166    pub fn get<'cache>(&'cache mut self, k: &K) -> Option<&'cache V> {
167        self.scope(|mut cache, perm| unsafe {
168            std::mem::transmute::<_, Option<&'cache V>>(cache.get(k, &perm))
169        })
170    }
171
172    pub fn peek_mut<'cache>(&'cache mut self, k: &K) -> Option<&'cache mut V> {
173        self.scope(|cache, mut perm| unsafe {
174            std::mem::transmute::<_, Option<&'cache mut V>>(cache.peek_mut(k, &mut perm))
175        })
176    }
177}
178
179impl<K, V> Drop for LruCache<K, V> {
180    fn drop(&mut self) {
181        self.map.drain().for_each(|(_, node)| unsafe {
182            let mut node = *Box::from_raw(node.as_ptr());
183            ptr::drop_in_place((node).key.as_mut_ptr());
184            ptr::drop_in_place((node).val.as_mut_ptr());
185        });
186        // We rebox the head/tail, and because these are maybe-uninit
187        // they do not have the absent k/v dropped.
188
189        let _head = unsafe { *Box::from_raw(self.head) };
190        let _tail = unsafe { *Box::from_raw(self.tail) };
191    }
192}
193
194impl<'cache, 'brand, K: Hash + Eq, V> CacheHandle<'cache, 'brand, K, V> {
195    pub fn len<'handle, 'perm>(&'handle self) -> usize {
196        self.cache.len()
197    }
198
199    pub fn is_empty<'sperm>(&self) -> bool {
200        self.len() == 0
201    }
202
203    pub fn cap<'sperm>(&self) -> NonZeroUsize {
204        self.cache.cap()
205    }
206
207    pub fn put<'handle, 'perm>(
208        &'handle mut self,
209        k: K,
210        mut v: V,
211        _perm: &'perm mut ValuePerm<'brand>,
212    ) -> Option<V> {
213        let cache = &mut self.cache;
214        let node_ref = cache.map.get_mut(&KeyRef { k: &k });
215
216        match node_ref {
217            Some(node_ref) => {
218                // if the key is already in the cache just update its value and move it to the
219                // front of the list
220                let node_ptr: *mut LruEntry<K, V> = node_ref.as_ptr();
221                let node_ref = unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) };
222                mem::swap(&mut v, node_ref);
223                let _ = node_ref;
224                cache.detach(node_ptr);
225                cache.attach(node_ptr);
226                Some(v)
227            }
228            None => {
229                let (replaced, node) = cache.replace_or_create_node(k, v);
230                let node_ptr: *mut LruEntry<K, V> = node.as_ptr();
231
232                cache.attach(node_ptr);
233
234                let keyref = unsafe { (*node_ptr).key.as_ptr() };
235                cache.map.insert(KeyRef { k: keyref }, node);
236
237                replaced.map(|(_k, v)| v)
238            }
239        }
240    }
241
242    pub fn get<'handle, 'perm>(
243        &mut self,
244        k: &K,
245        _perm: &'perm ValuePerm<'brand>,
246    ) -> Option<&'perm V> {
247        let cache = &mut self.cache;
248        if let Some(node) = cache.map.get_mut(&KeyRef { k }) {
249            let node_ptr: *mut LruEntry<K, V> = node.as_ptr();
250
251            cache.detach(node_ptr);
252            cache.attach(node_ptr);
253
254            Some(unsafe { &*(*node_ptr).val.as_ptr() })
255        } else {
256            None
257        }
258    }
259
260    // get the mutable reference of an entry, but not adjust its position.
261    pub fn peek_mut<'handle, 'key, 'perm>(
262        &'handle self,
263        k: &'key K,
264        _perm: &'perm mut ValuePerm<'brand>,
265    ) -> Option<&'perm mut V> {
266        let cache = &self.cache;
267        match cache.map.get(&KeyRef { k }) {
268            None => None,
269            Some(node) => Some(unsafe { &mut *(*node.as_ptr()).val.as_mut_ptr() }),
270        }
271    }
272}
273
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use super::*;

    /// Asserts that `opt` is `Some` and that the referenced value equals `v`.
    fn assert_opt_eq<V>(opt: Option<&V>, v: V)
    where
        V: PartialEq + Debug,
    {
        assert!(opt.is_some());
        let actual = opt.unwrap();
        assert_eq!(actual, &v);
    }

    /// Same as `assert_opt_eq`, for lookups that hand out `&mut V`.
    fn assert_opt_eq_mut<V>(opt: Option<&mut V>, v: V)
    where
        V: PartialEq + Debug,
    {
        assert!(opt.is_some());
        let actual = opt.unwrap();
        assert_eq!(actual, &v);
    }

    // Two inserts into a cache of capacity two: both stay retrievable.
    #[test]
    fn test_put_and_get() {
        let capacity = NonZeroUsize::new(2).unwrap();
        let mut cache = LruCache::new(capacity);
        cache.scope(|mut handle, mut perm| {
            assert_eq!(handle.put("apple", "red", &mut perm), None);
            assert_eq!(handle.put("banana", "yellow", &mut perm), None);

            assert_eq!(handle.cap().get(), 2);
            assert_eq!(handle.len(), 2);
            assert!(!handle.is_empty());
            assert_opt_eq(handle.get(&"apple", &perm), "red");
            assert_opt_eq(handle.get(&"banana", &perm), "yellow");
        });
    }

    // Several `get` results are held at once under one shared permission borrow.
    #[test]
    fn test_multi_get() {
        let capacity = NonZeroUsize::new(2).unwrap();
        let mut cache = LruCache::new(capacity);

        cache.scope(|mut handle, mut perm| {
            assert_eq!(handle.put("apple", "red", &mut perm), None);
            assert_eq!(handle.put("banana", "yellow", &mut perm), None);
            // The third insert overflows the capacity and evicts "apple".
            assert_eq!(handle.put("lemon", "yellow", &mut perm), Some("red"));

            let mut colors = Vec::new();
            for key in ["apple", "banana", "lemon", "watermelon"].iter() {
                colors.push(handle.get(key, &perm));
            }

            assert!(colors[0].is_none());
            assert_opt_eq(colors[1], "yellow");
            assert_opt_eq(colors[2], "yellow");
            assert!(colors[3].is_none());
        });
    }

    // `peek_mut` must not refresh recency, and writes through it must stick.
    #[test]
    fn test_peek_mut() {
        let capacity = NonZeroUsize::new(2).unwrap();
        let mut cache = LruCache::new(capacity);

        cache.scope(|mut handle, mut perm| {
            handle.put("apple", "red", &mut perm);
            handle.put("banana", "yellow", &mut perm);

            assert_opt_eq_mut(handle.peek_mut(&"banana", &mut perm), "yellow");
            assert_opt_eq_mut(handle.peek_mut(&"apple", &mut perm), "red");
            assert!(handle.peek_mut(&"pear", &mut perm).is_none());

            // Peeking did not touch the order, so "apple" is still the oldest
            // entry and is the one evicted here.
            handle.put("pear", "green", &mut perm);

            assert!(handle.peek_mut(&"apple", &mut perm).is_none());
            assert_opt_eq_mut(handle.peek_mut(&"banana", &mut perm), "yellow");
            assert_opt_eq_mut(handle.peek_mut(&"pear", &mut perm), "green");

            *handle.peek_mut(&"banana", &mut perm).unwrap() = "green";

            assert_opt_eq_mut(handle.peek_mut(&"banana", &mut perm), "green");
        });
    }
}
351}