quick_hash_cache/lru/
mod.rs

1use std::borrow::Borrow;
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::{
4    atomic::{AtomicU64, AtomicUsize, Ordering},
5    Arc,
6};
7
8use tokio::sync::{OwnedRwLockWriteGuard, RwLock};
9
10use hashbrown::hash_map::DefaultHashBuilder;
11
12use rand::Rng;
13
14use crate::{Erased, ReadHandle, WriteHandle};
15
16mod shard;
17
18use shard::IndexedShard;
19
/// A timestamp that can be created at the current time, refreshed through a shared
/// reference, and compared against another timestamp of the same type.
pub trait AtomicTimestamp {
    /// Create a new timestamp set to the current time
    fn now() -> Self;
    /// Update the timestamp to `now` in-place
    fn update(&self);
    /// Returns whether `self` represents an earlier point in time than `other`
    fn is_before(&self, other: &Self) -> bool;
}
27
28#[derive(Debug)]
29pub struct AtomicInstant(AtomicU64);
30
31impl AtomicTimestamp for AtomicInstant {
32    #[inline]
33    fn now() -> Self {
34        AtomicInstant(AtomicU64::new(quanta::Instant::now().as_u64()))
35    }
36
37    #[inline]
38    fn update(&self) {
39        self.0.store(quanta::Instant::now().as_u64(), Ordering::SeqCst);
40    }
41
42    #[inline]
43    fn is_before(&self, other: &Self) -> bool {
44        self.0.load(Ordering::SeqCst) < other.0.load(Ordering::SeqCst)
45    }
46}
47
48#[derive(Debug)]
49struct TimestampedValue<V, T> {
50    value: V,
51    timestamp: T,
52}
53
54impl<V, T> Clone for TimestampedValue<V, T>
55where
56    V: Clone,
57    T: AtomicTimestamp,
58{
59    fn clone(&self) -> Self {
60        TimestampedValue {
61            value: self.value.clone(),
62            timestamp: T::now(),
63        }
64    }
65}
66
/// A single shard: an [`IndexedShard`] behind an async read-write lock, reference-counted so
/// that owned lock guards can be handed out.
type Shard<K, T> = Arc<RwLock<IndexedShard<K, T>>>;
68
/// Sharded key-value cache whose entries carry a last-access timestamp used for eviction.
///
/// Type parameters: `K` key, `V` value, `T` timestamp implementation, `S` hash builder.
#[derive(Debug)]
pub struct LruCache<K, V, T = AtomicInstant, S = DefaultHashBuilder> {
    /// Builds the hashers used to hash keys and pick a shard
    hash_builder: S,
    /// Each shard behind its own async lock, paired with an atomic entry-count hint for that shard
    shards: Vec<(Shard<K, TimestampedValue<V, T>>, AtomicUsize)>,
    /// Total number of entries across all shards
    size: AtomicUsize,
}
75
impl<K, V, T> LruCache<K, V, T, DefaultHashBuilder> {
    /// Creates an empty cache with `num_shards` shards, hashing keys with a default-constructed
    /// [`DefaultHashBuilder`]. Delegates to `with_hasher`.
    pub fn new(num_shards: usize) -> Self {
        Self::with_hasher(num_shards, DefaultHashBuilder::default())
    }
}
81
impl<K, V> Default for LruCache<K, V, AtomicInstant, DefaultHashBuilder> {
    /// Creates an empty cache with as many shards as `num_cpus::get()` reports.
    fn default() -> Self {
        Self::new(num_cpus::get())
    }
}
87
88impl<K, V, T, S> LruCache<K, V, T, S> {
89    pub fn with_hasher(num_shards: usize, hash_builder: S) -> Self {
90        LruCache {
91            shards: (0..num_shards)
92                .into_iter()
93                .map(|_| (Arc::new(RwLock::new(IndexedShard::new())), AtomicUsize::new(0)))
94                .collect(),
95            hash_builder,
96            size: AtomicUsize::new(0),
97        }
98    }
99}
100
101impl<K, V, T, S> LruCache<K, V, T, S>
102where
103    S: Clone,
104    K: Clone,
105    V: Clone,
106    T: AtomicTimestamp,
107{
108    /// Attempts to duplicate/clone the LruCache. An LruCache cannot be cloned regularly due to internal asynchronous locking.
109    pub async fn duplicate(&self) -> Self {
110        let mut shards = Vec::with_capacity(self.shards.len());
111        let mut size = 0;
112
113        for shard in &self.shards {
114            let shard = shard.0.read().await.clone();
115
116            let shard_len = shard.len();
117            size += shard_len;
118            shards.push((Arc::new(RwLock::new(shard)), AtomicUsize::new(shard_len)));
119        }
120
121        LruCache {
122            shards,
123            hash_builder: self.hash_builder.clone(),
124            size: AtomicUsize::new(size),
125        }
126    }
127}
128
129impl<K, V, T, S> LruCache<K, V, T, S>
130where
131    K: Hash + Eq,
132    S: BuildHasher,
133    T: AtomicTimestamp,
134{
    /// Returns the total number of entries according to the cache-wide atomic counter.
    ///
    /// No shard lock is taken, so under concurrent modification the value is only a snapshot.
    #[inline]
    pub fn size(&self) -> usize {
        self.size.load(Ordering::SeqCst)
    }
139
140    #[cfg(test)]
141    pub async fn test_size(&self) -> usize {
142        let mut size = 0;
143        for shard in &self.shards {
144            size += shard.0.read().await.len();
145        }
146
147        size
148    }
149
    /// Returns a reference to the hash builder used to hash keys.
    #[inline]
    pub fn hash_builder(&self) -> &S {
        &self.hash_builder
    }
154
    /// Returns the number of shards the cache was created with.
    #[inline]
    pub fn num_shards(&self) -> usize {
        self.shards.len()
    }
159
160    pub async fn retain<F>(&self, f: F)
161    where
162        F: Fn(&K, &mut V) -> bool,
163    {
164        for (shard, _) in &self.shards {
165            let mut shard = shard.write().await;
166
167            let len = shard.len();
168            shard.retain(|k, tv| f(k, &mut tv.value));
169
170            self.size.fetch_sub(len - shard.len(), Ordering::SeqCst);
171        }
172    }
173
174    pub async fn clear(&self) {
175        for (shard, _) in &self.shards {
176            let mut shard = shard.write().await;
177            let len = shard.len();
178            shard.clear();
179
180            self.size.fetch_sub(len, Ordering::SeqCst);
181        }
182    }
183
184    #[inline]
185    fn hash_and_shard<Q: ?Sized>(&self, key: &Q) -> (u64, usize)
186    where
187        Q: Hash + Eq,
188    {
189        let mut hasher = self.hash_builder.build_hasher();
190        key.hash(&mut hasher);
191        let hash = hasher.finish();
192        (hash, hash as usize % self.shards.len())
193    }
194
195    async fn get_mut_raw<Q: ?Sized>(
196        &self,
197        key: &Q,
198    ) -> Option<WriteHandle<impl Erased, TimestampedValue<V, T>>>
199    where
200        K: Borrow<Q>,
201        Q: Hash + Eq,
202    {
203        let (hash, shard_idx) = self.hash_and_shard(key);
204        let shard = unsafe { self.shards.get_unchecked(shard_idx).0.clone().write_owned().await };
205
206        OwnedRwLockWriteGuard::try_map(shard, |shard| shard.get_mut(hash, key)).ok()
207    }
208
209    async fn get_raw<Q: ?Sized>(&self, key: &Q) -> Option<ReadHandle<impl Erased, TimestampedValue<V, T>>>
210    where
211        K: Borrow<Q>,
212        Q: Hash + Eq,
213    {
214        let (hash, shard_idx) = self.hash_and_shard(key);
215        let shard = unsafe { self.shards.get_unchecked(shard_idx).0.clone().read_owned().await };
216
217        ReadHandle::try_map(shard, |shard| shard.get(hash, key)).ok()
218    }
219
220    pub async fn peek<Q: ?Sized>(&self, key: &Q) -> Option<ReadHandle<impl Erased, V>>
221    where
222        K: Borrow<Q>,
223        Q: Hash + Eq,
224    {
225        self.get_raw(key)
226            .await
227            .map(|tv| ReadHandle::map(tv, |tv| &tv.value))
228    }
229
230    pub async fn peek_mut<Q: ?Sized>(&self, key: &Q) -> Option<WriteHandle<impl Erased, V>>
231    where
232        K: Borrow<Q>,
233        Q: Hash + Eq,
234    {
235        self.get_mut_raw(key)
236            .await
237            .map(|tv| WriteHandle::map(tv, |tv| &mut tv.value))
238    }
239
240    pub async fn get<Q: ?Sized>(&self, key: &Q) -> Option<ReadHandle<impl Erased, V>>
241    where
242        K: Borrow<Q>,
243        Q: Hash + Eq,
244    {
245        let tv = self.get_raw(key).await;
246
247        if let Some(ref tv) = tv {
248            tv.timestamp.update();
249        }
250
251        tv.map(|tv| ReadHandle::map(tv, |tv| &tv.value))
252    }
253
254    pub async fn get_mut<Q: ?Sized>(&self, key: &Q) -> Option<WriteHandle<impl Erased, V>>
255    where
256        K: Borrow<Q>,
257        Q: Hash + Eq,
258    {
259        let mut tv = self.get_mut_raw(key).await;
260
261        // owned ref, don't bother with atomic overhead
262        if let Some(ref mut tv) = tv {
263            tv.timestamp = T::now();
264        }
265
266        tv.map(|tv| WriteHandle::map(tv, |tv| &mut tv.value))
267    }
268
269    pub async fn insert(&self, key: K, value: V) -> Option<V> {
270        let (hash, shard_idx) = self.hash_and_shard(&key);
271        let (locked_shard, shard_size) = unsafe { self.shards.get_unchecked(shard_idx) };
272
273        let mut shard = locked_shard.write().await;
274
275        let value = TimestampedValue {
276            value,
277            timestamp: T::now(),
278        };
279
280        shard
281            .insert_full(hash, key, value, || {
282                self.size.fetch_add(1, Ordering::SeqCst);
283                shard_size.fetch_add(1, Ordering::SeqCst);
284            })
285            .1
286            .map(|tv| tv.value)
287    }
288
289    pub async fn remove<Q: ?Sized>(&self, key: &Q) -> Option<V>
290    where
291        K: Borrow<Q>,
292        Q: Hash + Eq,
293    {
294        let (hash, shard_idx) = self.hash_and_shard(&key);
295        let (locked_shard, shard_size) = unsafe { self.shards.get_unchecked(shard_idx) };
296
297        let mut shard = locked_shard.write().await;
298
299        match shard.swap_remove_full(hash, key) {
300            Some((_, tv)) => {
301                self.size.fetch_sub(1, Ordering::SeqCst);
302                // know the real size, so just store it
303                shard_size.store(shard.len(), Ordering::SeqCst);
304
305                Some(tv.value)
306            }
307            None => None,
308        }
309    }
310
311    fn non_empty_shards(&self) -> impl Iterator<Item = &Shard<K, TimestampedValue<V, T>>> {
312        self.shards
313            .iter()
314            .filter_map(|(shard, shard_size)| match shard_size.load(Ordering::SeqCst) {
315                0 => None,
316                _ => Some(shard),
317            })
318    }
319
    /// Fair element eviction based on 2-random sampling of two shards at once, and performs a random walk through
    /// all shards as necessary to remain unbiased.
    ///
    /// At each step two candidate entries are sampled, and the one with the older timestamp is passed to
    /// `predicate` together with mutable access to its value. The returned [`Evict`] decides what happens:
    /// `Continue` evicts the entry and keeps going, `Once` evicts the entry and stops, `None` stops without
    /// evicting. Eviction also stops once the cache-wide size reaches zero.
    ///
    /// Returns the evicted `(key, value)` pairs in the order they were evicted.
    ///
    /// NOTE: This method acquires one write lock per element, and can be inefficient for many evictions.
    ///
    /// If you want fair eviction of a handful of items, this is the method to use. For less-predictable bulk-eviction look at `evict_many_fast`
    ///
    /// NOTE(review): the write lock on one shard is held while the next shard's write lock is awaited, and
    /// the shard order is shuffled per call, so two concurrent calls look able to block each other
    /// indefinitely - confirm whether callers serialize eviction.
    pub async fn evict<F>(&self, mut rng: impl Rng, mut predicate: F) -> Vec<(K, V)>
    where
        F: FnMut(&K, &mut V) -> Evict,
    {
        use rand::seq::SliceRandom;

        /* Algorithm:

            Overall: Evict and collect items until the predicate returns false
            The predicate will test the oldest of the two selected items at each iteration

            To start with, collect all non-empty shards, then shuffle them.

            Take one of them (pop) and lock it.

            Then, pick another random shard (pop), and begin selecting two random elements from between those,
            pass the oldest to the predicate, and if the predicate returns true then evict it.

            Swap shard_a and shard_b, then continue. This forms a random-walk of sorts between non-empty shards,
            where it goes from A->B, B->C, C->D, etc.

            Doing a random walk avoids having to reacquire the locks on each shard each iteration.

            When `non_empty` runs empty, refill it with the same method and shuffle it again

        */

        let mut evicted = Vec::new();

        let mut non_empty = Vec::with_capacity(self.shards.len());

        // Pops candidates off `non_empty`, write-locking each in turn, until one is found that really
        // holds entries. Evaluates to `Some(guard)` for that shard, or `None` once the list is exhausted.
        macro_rules! pop_shard {
            () => {
                loop {
                    match non_empty.pop() {
                        Some(shard) => {
                            let shard = shard.write().await;
                            // once locked, check if the shard is actually non-empty
                            if shard.len() > 0 {
                                break Some(shard);
                            }
                        }
                        None => break None,
                    }
                }
            };
        }

        // Note: evictions below decrement the cache-wide `size` only; the per-shard size hints are
        // left as they are, which is why `pop_shard!` re-checks the length after locking.
        'evict: while self.size() > 0 {
            // `non_empty` is always drained by the time control returns here, so this is a full refill
            non_empty.extend(self.non_empty_shards());
            non_empty.shuffle(&mut rng);

            let mut shard_a = match pop_shard!() {
                Some(shard) => shard,
                // if we couldn't find an actual non-empty shard, go back to `while size > 0`, and if there is still one, sample it.
                None => continue 'evict,
            };

            'walk: loop {
                match pop_shard!() {
                    None => {
                        // single-shard case
                        let res = match shard_a.len() {
                            // SAFETY (assumes `len()` reports `entries.len()` - TODO confirm): with a
                            // length of one, index 0 is in bounds
                            1 => unsafe {
                                let shard::Bucket {
                                    ref key,
                                    ref mut value,
                                    ..
                                } = shard_a.entries.get_unchecked_mut(0);

                                let res = predicate(key, &mut value.value);

                                if matches!(res, Evict::Continue | Evict::Once) {
                                    // removing the only entry: empty the index table and pop the entry directly
                                    shard_a.indices.clear();
                                    let shard::Bucket { key, value, .. } = shard_a.entries.pop().unwrap();
                                    self.size.fetch_sub(1, Ordering::SeqCst);
                                    evicted.push((key, value.value));
                                }

                                res
                            },
                            // SAFETY (same assumption): `shard_a` is non-empty here, and `pick_indices`
                            // only returns indices below `len`
                            len @ _ => unsafe {
                                let (elem_a_idx, elem_b_idx) = pick_indices(len, &mut rng);

                                let ts_a = &shard_a.entries.get_unchecked(elem_a_idx).value.timestamp;
                                let ts_b = &shard_a.entries.get_unchecked(elem_b_idx).value.timestamp;
                                // the candidate with the older timestamp is the one offered for eviction
                                let idx = if ts_a.is_before(ts_b) {
                                    elem_a_idx
                                } else {
                                    elem_b_idx
                                };

                                let shard::Bucket {
                                    ref key,
                                    ref mut value,
                                    ..
                                } = shard_a.entries.get_unchecked_mut(idx);

                                let res = predicate(key, &mut value.value);

                                if matches!(res, Evict::Continue | Evict::Once) {
                                    let (key, value) = shard_a.swap_remove_index_raw(idx);
                                    self.size.fetch_sub(1, Ordering::SeqCst);
                                    evicted.push((key, value.value));
                                }

                                res
                            },
                        };

                        if matches!(res, Evict::Once | Evict::None) {
                            break 'evict;
                        }

                        // since pop_shard!() returned None, there is no point in looping again,
                        // so try to refresh the non_empty shard list
                        continue 'evict;
                    }
                    // SAFETY (assumes `len()` reports `entries.len()` - TODO confirm): both shards are
                    // write-locked and non-empty, and every index below is derived from their lengths
                    Some(mut shard_b) => unsafe {
                        // two-shard case

                        let shard_a_len = shard_a.len();
                        let shard_b_len = shard_b.len();

                        debug_assert!(shard_a_len > 0);
                        debug_assert!(shard_b_len > 0);

                        // treat both shards as one contiguous range: [0, a_len) maps to shard_a and
                        // [a_len, a_len + b_len) maps to shard_b
                        let sample_range = shard_a_len + shard_b_len;

                        let (elem_a_range_idx, elem_b_range_idx) = pick_indices(sample_range, &mut rng);

                        let ts_a = if elem_a_range_idx < shard_a_len {
                            &shard_a.entries.get_unchecked(elem_a_range_idx).value.timestamp
                        } else {
                            &shard_b
                                .entries
                                .get_unchecked(elem_a_range_idx - shard_a_len)
                                .value
                                .timestamp
                        };

                        let ts_b = if elem_b_range_idx < shard_a_len {
                            &shard_a.entries.get_unchecked(elem_b_range_idx).value.timestamp
                        } else {
                            &shard_b
                                .entries
                                .get_unchecked(elem_b_range_idx - shard_a_len)
                                .value
                                .timestamp
                        };

                        // the candidate with the older timestamp is the one offered for eviction
                        let elem_range_idx = if ts_a.is_before(ts_b) {
                            elem_a_range_idx
                        } else {
                            elem_b_range_idx
                        };

                        // map the combined-range index back to a concrete shard and a local index
                        let (shard, idx) = if elem_range_idx < shard_a_len {
                            (&mut shard_a, elem_range_idx)
                        } else {
                            (&mut shard_b, elem_range_idx - shard_a_len)
                        };

                        let shard::Bucket {
                            ref key,
                            ref mut value,
                            ..
                        } = shard.entries.get_unchecked_mut(idx);

                        let res = predicate(key, &mut value.value);

                        if matches!(res, Evict::Continue | Evict::Once) {
                            let (key, value) = shard.swap_remove_index_raw(idx);
                            self.size.fetch_sub(1, Ordering::SeqCst);
                            evicted.push((key, value.value));
                        }

                        if matches!(res, Evict::None | Evict::Once) {
                            break 'evict;
                        }

                        // assigning drops the previous `shard_a` guard, releasing that shard's lock
                        shard_a = shard_b; // do random walk A->B, B->C, etc.
                    },
                }

                // if the former shard_b was emptied by the eviction, then try to find a new one before continuing
                if shard_a.len() == 0 {
                    shard_a = match pop_shard!() {
                        Some(shard) => shard,
                        None => break 'walk,
                    };
                }
            }
        }

        evicted
    }
523
524    /// Fairly evict many elements, based on 2-random sampling of two shards at once, and performs a random walk through
525    /// all shards as necessary to remain unbiased.
526    ///
527    /// NOTE: This method acquires one write lock per element, and can be inefficient for many evictions.
528    ///
529    /// If you want fair eviction of a handful of items, this is the method to use. For less-predictable bulk-eviction look at `evict_many_fast`
530    pub async fn evict_many(&self, mut count: usize, rng: impl Rng) -> Vec<(K, V)> {
531        count = count.min(self.size());
532
533        if count == 0 {
534            return Vec::new();
535        }
536
537        let mut cur = count;
538
539        self.evict(rng, |_, _| {
540            cur -= 1;
541
542            match cur {
543                0 => Evict::Once,
544                _ => Evict::Continue,
545            }
546        })
547        .await
548    }
549
550    // Fairly evict one element
551    pub async fn evict_one(&self, rng: impl Rng) -> Option<(K, V)> {
552        self.evict(rng, |_, _| Evict::Once).await.pop()
553    }
554
555    /// Less-fair and less-predictable algorithm that only acquires shard locks once at most,
556    /// but may not evict the exact number of requested elements (a couple more or less)
557    ///
558    /// Compare to `evict` or `evict_many` that acquires a shard lock *per-item evicted*,
559    /// but is more fair and unbiased in doing so.
560    pub async fn evict_many_fast(&self, mut count: usize, mut rng: impl Rng) -> Vec<(K, V)> {
561        use rand::prelude::SliceRandom;
562
563        count = count.min(self.size());
564
565        let mut evicted = Vec::new();
566
567        if count == 0 {
568            return evicted;
569        }
570
571        let mut non_empty = Vec::with_capacity(self.shards.len());
572        non_empty.extend(self.non_empty_shards());
573        non_empty.shuffle(&mut rng);
574
575        fn proportion_of(size: usize, len: usize, count: usize) -> usize {
576            // `len / size` is the fraction this shard holds of the entire structure, between 0 and 1
577            // so `count * fraction` is the number of elements to be taken from this shard
578            // reorganize to avoid floating point, at the cost of 128-bit ints
579            ((count as u128 * len as u128) / size as u128) as usize + 1
580        }
581
582        let size = self.size();
583
584        let mut sum = 0;
585        for shard in non_empty {
586            let mut shard = shard.write().await;
587
588            if shard.len() == 0 {
589                continue;
590            }
591
592            let mut sub_count = proportion_of(size, shard.len(), count);
593            sum += sub_count;
594
595            if sum > count {
596                sub_count = sum - count - 1;
597            }
598
599            if sub_count == shard.len() {
600                // fast path for evicting all of this shard
601                evicted.extend(
602                    shard
603                        .entries
604                        .drain(..)
605                        .map(|bucket| (bucket.key, bucket.value.value)),
606                );
607
608                shard.indices.clear();
609                self.size.fetch_sub(sub_count, Ordering::SeqCst); // sub_count == shard.len() here
610            } else {
611                for _ in 0..sub_count {
612                    let (elem_a_idx, elem_b_idx) = pick_indices(shard.len(), &mut rng);
613
614                    unsafe {
615                        let ts_a = &shard.entries.get_unchecked(elem_a_idx).value.timestamp;
616                        let ts_b = &shard.entries.get_unchecked(elem_b_idx).value.timestamp;
617
618                        let idx = if ts_a.is_before(ts_b) {
619                            elem_a_idx
620                        } else {
621                            elem_b_idx
622                        };
623
624                        evicted.push({
625                            let (key, value) = shard.swap_remove_index_raw(idx);
626                            self.size.fetch_sub(1, Ordering::SeqCst);
627                            (key, value.value)
628                        });
629                    }
630                }
631            }
632
633            if sum > count {
634                break;
635            }
636        }
637
638        evicted
639    }
640}
641
/// Decision returned by the predicate passed to `LruCache::evict` for each candidate entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Evict {
    /// Evict this item and continue evicting after it
    Continue,
    /// Evict only this item and then no more
    Once,
    /// Evict neither this item nor any others
    None,
}
651
652fn pick_indices(len: usize, mut rng: impl Rng) -> (usize, usize) {
653    match len {
654        0 => panic!("Invalid length"),
655        1 => (0, 0),
656        2 => (0, 1),
657        _ => {
658            let idx_a = rng.gen_range(0..len);
659
660            loop {
661                let idx_b = rng.gen_range(0..len);
662
663                if idx_b != idx_a {
664                    return (idx_a, idx_b);
665                }
666            }
667        }
668    }
669}