1use std::collections::HashMap;
2use std::hash::{Hash, Hasher};
3use std::marker::PhantomData;
4use std::mem;
5use std::num::NonZeroUsize;
6use std::ptr;
7use std::ptr::NonNull;
8
/// Zero-sized marker that makes `'brand` invariant.
///
/// `fn(&'brand ()) -> &'brand ()` mentions `'brand` in both argument position
/// (contravariant) and return position (covariant), so the compiler may neither
/// shrink nor grow it. Together with the `for<'brand>` bound on
/// `LruCache::scope`, this keeps the brand of one scope from unifying with any
/// other lifetime.
type InvariantLifetime<'brand> = PhantomData<fn(&'brand ()) -> &'brand ()>;
10
/// Branded, exclusive access to an [`LruCache`] for the duration of one
/// [`LruCache::scope`] call.
///
/// Structural operations (insert, lookup, recency updates) go through the
/// handle. References to stored values are tied to a borrow of the
/// [`ValuePerm`] created with the same `'brand`, not to the handle itself.
pub struct CacheHandle<'cache, 'brand, K, V> {
    // Brand shared with the `ValuePerm` created alongside this handle.
    _lifetime: InvariantLifetime<'brand>,
    // The cache being operated on, exclusively borrowed for `'cache`.
    cache: &'cache mut LruCache<K, V>,
}
15
/// Permission token for the values of the cache whose handle carries the same
/// `'brand`.
///
/// `CacheHandle::get` takes `&ValuePerm` and returns references borrowing it.
/// `CacheHandle::peek_mut` and `CacheHandle::put` take `&mut ValuePerm`. The
/// borrow checker therefore rejects an overwrite or eviction while a value
/// reference obtained from the same permission is still alive.
pub struct ValuePerm<'brand> {
    _lifetime: InvariantLifetime<'brand>,
}
19
/// Pointer-based view of a key, used as the `HashMap` key type so that each
/// key is stored only once, inside its `LruEntry` node.
///
/// Hashing and equality are delegated to the pointed-to `K`. A `KeyRef` is only
/// meaningful while its pointee is alive.
struct KeyRef<K> {
    k: *const K,
}

impl<K: Hash> Hash for KeyRef<K> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // SAFETY: every `KeyRef` built in this module points at a live key. That is
        // either the key slot of a node owned by the cache, or a caller-provided
        // `&K` that outlives the single lookup it is used for.
        let key: &K = unsafe { &*self.k };
        key.hash(state);
    }
}

impl<K: PartialEq> PartialEq for KeyRef<K> {
    fn eq(&self, other: &KeyRef<K>) -> bool {
        // SAFETY: same invariant as in `Hash::hash` above, applied to both sides.
        let (lhs, rhs): (&K, &K) = unsafe { (&*self.k, &*other.k) };
        lhs == rhs
    }
}

impl<K: Eq> Eq for KeyRef<K> {}
38
/// One node of the doubly-linked recency list.
///
/// `key` and `val` are `MaybeUninit` so that the head/tail sigil nodes can exist
/// without a real key or value. For every non-sigil node, both are initialised.
/// Dropping an `LruEntry` never drops `key`/`val`; `LruCache`'s `Drop` impl does
/// that explicitly for the nodes it owns.
struct LruEntry<K, V> {
    key: mem::MaybeUninit<K>,
    val: mem::MaybeUninit<V>,
    prev: *mut LruEntry<K, V>,
    next: *mut LruEntry<K, V>,
}

impl<K, V> LruEntry<K, V> {
    /// Builds an unlinked node holding `key` and `val`.
    fn new(key: K, val: V) -> Self {
        Self::unlinked(mem::MaybeUninit::new(key), mem::MaybeUninit::new(val))
    }

    /// Builds an unlinked sigil (sentinel) node with uninitialised key and value.
    fn new_sigil() -> Self {
        Self::unlinked(mem::MaybeUninit::uninit(), mem::MaybeUninit::uninit())
    }

    /// Wraps the given key/value slots in a node whose `prev`/`next` links are null.
    fn unlinked(key: mem::MaybeUninit<K>, val: mem::MaybeUninit<V>) -> Self {
        Self {
            key,
            val,
            prev: ptr::null_mut(),
            next: ptr::null_mut(),
        }
    }
}
67
/// Fixed-capacity least-recently-used cache.
///
/// Entries live in individually boxed `LruEntry` nodes. The nodes form a
/// doubly-linked recency list between two sigil nodes:
/// - `attach` inserts right after `head`, so the most-recently-used entry sits next to `head`;
/// - eviction takes `tail.prev`, the least-recently-used entry.
///
/// `map` indexes nodes by a raw pointer to the key stored inside the node
/// itself, so each key is stored only once.
pub struct LruCache<K, V> {
    // Key (pointing into the node) -> owning pointer to the boxed node.
    map: HashMap<KeyRef<K>, NonNull<LruEntry<K, V>>>,
    // Maximum number of entries; inserting a new key at this size evicts one.
    cap: NonZeroUsize,

    // Sigil node before the most-recently-used entry (key/val uninitialised).
    head: *mut LruEntry<K, V>,
    // Sigil node after the least-recently-used entry (key/val uninitialised).
    tail: *mut LruEntry<K, V>,
}
76
77impl<K: Eq + Hash, V> LruCache<K, V> {
78 pub fn new(cap: NonZeroUsize) -> Self {
79 let cache = LruCache::<K, V> {
80 map: HashMap::with_capacity(cap.get()),
81 cap,
82 head: Box::into_raw(Box::new(LruEntry::new_sigil())),
83 tail: Box::into_raw(Box::new(LruEntry::new_sigil())),
84 };
85
86 unsafe {
87 (*cache.head).next = cache.tail;
88 (*cache.tail).prev = cache.head;
89 };
90
91 cache
92 }
93
94 pub fn scope<'cache, F, R>(&'cache mut self, fun: F) -> R
95 where
96 for<'brand> F: FnOnce(CacheHandle<'cache, 'brand, K, V>, ValuePerm<'brand>) -> R,
97 {
98 let handle = CacheHandle {
99 _lifetime: Default::default(),
100 cache: self.into(),
101 };
102 let perm = ValuePerm {
103 _lifetime: InvariantLifetime::default(),
104 };
105 fun(handle, perm)
106 }
107
108 fn len(&self) -> usize {
109 self.map.len()
110 }
111
112 fn cap(&self) -> NonZeroUsize {
113 self.cap
114 }
115
116 fn detach(&mut self, node: *mut LruEntry<K, V>) {
117 unsafe {
118 (*(*node).prev).next = (*node).next;
119 (*(*node).next).prev = (*node).prev;
120 }
121 }
122
123 fn attach(&mut self, node: *mut LruEntry<K, V>) {
125 unsafe {
126 (*node).next = (*self.head).next;
127 (*node).prev = self.head;
128 (*self.head).next = node;
129 (*(*node).next).prev = node;
130 }
131 }
132
133 fn replace_or_create_node(&mut self, k: K, v: V) -> (Option<(K, V)>, NonNull<LruEntry<K, V>>) {
134 if self.len() == self.cap().get() {
135 let old_key = KeyRef {
137 k: unsafe { &(*(*(*self.tail).prev).key.as_ptr()) },
138 };
139 let old_node = self.map.remove(&old_key).unwrap();
140 let node_ptr: *mut LruEntry<K, V> = old_node.as_ptr();
141
142 let replaced = unsafe {
144 (
145 mem::replace(&mut (*node_ptr).key, mem::MaybeUninit::new(k)).assume_init(),
146 mem::replace(&mut (*node_ptr).val, mem::MaybeUninit::new(v)).assume_init(),
147 )
148 };
149
150 self.detach(node_ptr);
151
152 (Some(replaced), old_node)
153 } else {
154 (None, unsafe {
157 NonNull::new_unchecked(Box::into_raw(Box::new(LruEntry::new(k, v))))
158 })
159 }
160 }
161
162 pub fn put(&mut self, k: K, v: V) -> Option<V> {
163 self.scope(|mut cache, mut perm| cache.put(k, v, &mut perm))
164 }
165
166 pub fn get<'cache>(&'cache mut self, k: &K) -> Option<&'cache V> {
167 self.scope(|mut cache, perm| unsafe {
168 std::mem::transmute::<_, Option<&'cache V>>(cache.get(k, &perm))
169 })
170 }
171
172 pub fn peek_mut<'cache>(&'cache mut self, k: &K) -> Option<&'cache mut V> {
173 self.scope(|cache, mut perm| unsafe {
174 std::mem::transmute::<_, Option<&'cache mut V>>(cache.peek_mut(k, &mut perm))
175 })
176 }
177}
178
179impl<K, V> Drop for LruCache<K, V> {
180 fn drop(&mut self) {
181 self.map.drain().for_each(|(_, node)| unsafe {
182 let mut node = *Box::from_raw(node.as_ptr());
183 ptr::drop_in_place((node).key.as_mut_ptr());
184 ptr::drop_in_place((node).val.as_mut_ptr());
185 });
186 let _head = unsafe { *Box::from_raw(self.head) };
190 let _tail = unsafe { *Box::from_raw(self.tail) };
191 }
192}
193
194impl<'cache, 'brand, K: Hash + Eq, V> CacheHandle<'cache, 'brand, K, V> {
195 pub fn len<'handle, 'perm>(&'handle self) -> usize {
196 self.cache.len()
197 }
198
199 pub fn is_empty<'sperm>(&self) -> bool {
200 self.len() == 0
201 }
202
203 pub fn cap<'sperm>(&self) -> NonZeroUsize {
204 self.cache.cap()
205 }
206
207 pub fn put<'handle, 'perm>(
208 &'handle mut self,
209 k: K,
210 mut v: V,
211 _perm: &'perm mut ValuePerm<'brand>,
212 ) -> Option<V> {
213 let cache = &mut self.cache;
214 let node_ref = cache.map.get_mut(&KeyRef { k: &k });
215
216 match node_ref {
217 Some(node_ref) => {
218 let node_ptr: *mut LruEntry<K, V> = node_ref.as_ptr();
221 let node_ref = unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) };
222 mem::swap(&mut v, node_ref);
223 let _ = node_ref;
224 cache.detach(node_ptr);
225 cache.attach(node_ptr);
226 Some(v)
227 }
228 None => {
229 let (replaced, node) = cache.replace_or_create_node(k, v);
230 let node_ptr: *mut LruEntry<K, V> = node.as_ptr();
231
232 cache.attach(node_ptr);
233
234 let keyref = unsafe { (*node_ptr).key.as_ptr() };
235 cache.map.insert(KeyRef { k: keyref }, node);
236
237 replaced.map(|(_k, v)| v)
238 }
239 }
240 }
241
242 pub fn get<'handle, 'perm>(
243 &mut self,
244 k: &K,
245 _perm: &'perm ValuePerm<'brand>,
246 ) -> Option<&'perm V> {
247 let cache = &mut self.cache;
248 if let Some(node) = cache.map.get_mut(&KeyRef { k }) {
249 let node_ptr: *mut LruEntry<K, V> = node.as_ptr();
250
251 cache.detach(node_ptr);
252 cache.attach(node_ptr);
253
254 Some(unsafe { &*(*node_ptr).val.as_ptr() })
255 } else {
256 None
257 }
258 }
259
260 pub fn peek_mut<'handle, 'key, 'perm>(
262 &'handle self,
263 k: &'key K,
264 _perm: &'perm mut ValuePerm<'brand>,
265 ) -> Option<&'perm mut V> {
266 let cache = &self.cache;
267 match cache.map.get(&KeyRef { k }) {
268 None => None,
269 Some(node) => Some(unsafe { &mut *(*node.as_ptr()).val.as_mut_ptr() }),
270 }
271 }
272}
273
#[cfg(test)]
mod tests {
    use std::fmt::Debug;
    use std::rc::Rc;

    use super::*;

    /// Asserts that `opt` holds a reference to a value equal to `v`.
    fn assert_opt_eq<V: PartialEq + Debug>(opt: Option<&V>, v: V) {
        assert_eq!(opt, Some(&v));
    }

    /// Asserts that `opt` holds a mutable reference to a value equal to `v`.
    fn assert_opt_eq_mut<V: PartialEq + Debug>(opt: Option<&mut V>, v: V) {
        assert_eq!(opt.as_deref(), Some(&v));
    }

    #[test]
    fn test_put_and_get() {
        let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
        cache.scope(|mut cache, mut perm| {
            assert_eq!(cache.put("apple", "red", &mut perm), None);
            assert_eq!(cache.put("banana", "yellow", &mut perm), None);

            assert_eq!(cache.cap().get(), 2);
            assert_eq!(cache.len(), 2);
            assert!(!cache.is_empty());
            assert_opt_eq(cache.get(&"apple", &perm), "red");
            assert_opt_eq(cache.get(&"banana", &perm), "yellow");
        });
    }

    #[test]
    fn test_multi_get() {
        let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());

        cache.scope(|mut cache, mut perm| {
            assert_eq!(cache.put("apple", "red", &mut perm), None);
            assert_eq!(cache.put("banana", "yellow", &mut perm), None);
            assert_eq!(cache.put("lemon", "yellow", &mut perm), Some("red"));

            let colors: Vec<_> = ["apple", "banana", "lemon", "watermelon"]
                .iter()
                .map(|k| cache.get(k, &perm))
                .collect();
            assert!(colors[0].is_none());
            assert_opt_eq(colors[1], "yellow");
            assert_opt_eq(colors[2], "yellow");
            assert!(colors[3].is_none());
        });
    }

    #[test]
    fn test_peek_mut() {
        let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());

        cache.scope(|mut cache, mut perm| {
            cache.put("apple", "red", &mut perm);
            cache.put("banana", "yellow", &mut perm);

            assert_opt_eq_mut(cache.peek_mut(&"banana", &mut perm), "yellow");
            assert_opt_eq_mut(cache.peek_mut(&"apple", &mut perm), "red");
            assert!(cache.peek_mut(&"pear", &mut perm).is_none());

            cache.put("pear", "green", &mut perm);

            assert!(cache.peek_mut(&"apple", &mut perm).is_none());
            assert_opt_eq_mut(cache.peek_mut(&"banana", &mut perm), "yellow");
            assert_opt_eq_mut(cache.peek_mut(&"pear", &mut perm), "green");

            {
                let v = cache.peek_mut(&"banana", &mut perm).unwrap();
                *v = "green";
            }

            assert_opt_eq_mut(cache.peek_mut(&"banana", &mut perm), "green");
        });
    }

    #[test]
    fn test_put_existing_key_returns_old_value() {
        let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
        cache.scope(|mut cache, mut perm| {
            assert_eq!(cache.put("apple", "red", &mut perm), None);
            assert_eq!(cache.put("apple", "green", &mut perm), Some("red"));
            assert_eq!(cache.len(), 1);
            assert_opt_eq(cache.get(&"apple", &perm), "green");
        });
    }

    #[test]
    fn test_get_refreshes_recency() {
        let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
        cache.scope(|mut cache, mut perm| {
            cache.put("apple", "red", &mut perm);
            cache.put("banana", "yellow", &mut perm);
            // Touching "apple" makes "banana" the least-recently-used entry.
            assert_opt_eq(cache.get(&"apple", &perm), "red");
            assert_eq!(cache.put("pear", "green", &mut perm), Some("yellow"));
            assert!(cache.get(&"banana", &perm).is_none());
            assert_opt_eq(cache.get(&"apple", &perm), "red");
        });
    }

    #[test]
    fn test_outer_api() {
        let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
        assert_eq!(cache.put("apple", "red"), None);
        assert_eq!(cache.put("banana", "yellow"), None);
        assert_eq!(cache.get(&"apple"), Some(&"red"));
        // "banana" is now least-recently-used and gets evicted.
        assert_eq!(cache.put("pear", "green"), Some("yellow"));
        assert_eq!(cache.get(&"banana"), None);

        *cache.peek_mut(&"apple").unwrap() = "crimson";
        assert_eq!(cache.get(&"apple"), Some(&"crimson"));
        assert_eq!(cache.put("apple", "green"), Some("crimson"));
        assert!(cache.peek_mut(&"banana").is_none());
    }

    #[test]
    fn test_drop_releases_values() {
        let tracker = Rc::new(());
        {
            let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
            cache.put(1, Rc::clone(&tracker));
            cache.put(2, Rc::clone(&tracker));
            // Evicts key 1; the returned value is dropped at the end of the statement.
            assert!(cache.put(3, Rc::clone(&tracker)).is_some());
            assert_eq!(Rc::strong_count(&tracker), 3);
        }
        // Dropping the cache must release exactly the values it still owned.
        assert_eq!(Rc::strong_count(&tracker), 1);
    }
}