use std::borrow::Borrow;
use std::hash::{BuildHasher, Hash, Hasher};
use std::sync::{
    atomic::{AtomicU64, AtomicUsize, Ordering},
    Arc,
};

use tokio::sync::{OwnedRwLockWriteGuard, RwLock};

use hashbrown::hash_map::DefaultHashBuilder;

use rand::Rng;

use crate::{Erased, ReadHandle, WriteHandle};

mod shard;

use shard::IndexedShard;

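/// A timestamp that can be created, refreshed, and compared through a shared
/// reference, so access times can be bumped while holding only a read lock.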
pub trait AtomicTimestamp {
    /// Returns a timestamp representing "now".
    fn now() -> Self;
    /// Refreshes this timestamp to "now" through a shared reference.
    fn update(&self);
    /// Returns `true` if `self` is strictly older than `other`.
    fn is_before(&self, other: &Self) -> bool;
}

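/// An [`AtomicTimestamp`] backed by an [`AtomicU64`] holding `quanta::Instant` ticks.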
#[derive(Debug)]
pub struct AtomicInstant(AtomicU64);

impl AtomicTimestamp for AtomicInstant {
    #[inline]
    fn now() -> Self {
        AtomicInstant(AtomicU64::new(quanta::Instant::now().as_u64()))
    }

    #[inline]
    fn update(&self) {
        self.0.store(quanta::Instant::now().as_u64(), Ordering::SeqCst);
    }

    #[inline]
    fn is_before(&self, other: &Self) -> bool {
        self.0.load(Ordering::SeqCst) < other.0.load(Ordering::SeqCst)
    }
}

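/// A stored value together with the timestamp of its most recent access.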
#[derive(Debug)]
struct TimestampedValue<V, T> {
    value: V,
    timestamp: T,
}

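// Cloning produces a fresh timestamp via `T::now()` rather than copying the
// original access time.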
impl<V, T> Clone for TimestampedValue<V, T>
where
    V: Clone,
    T: AtomicTimestamp,
{
    fn clone(&self) -> Self {
        TimestampedValue {
            value: self.value.clone(),
            timestamp: T::now(),
        }
    }
}

type Shard<K, T> = Arc<RwLock<IndexedShard<K, T>>>;

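/// A sharded, asynchronous cache with sampling-based (approximate) LRU eviction.
///
/// Entries are spread across `num_shards` shards, each behind its own Tokio
/// `RwLock`, so readers and writers only contend within a shard.
///
/// A minimal usage sketch (not taken from the crate's docs; it assumes a Tokio
/// runtime and that the returned `ReadHandle` dereferences to the value, which
/// matches how the handles are mapped below):
///
/// ```ignore
/// let cache: LruCache<String, u32> = LruCache::new(8);
///
/// let _previous = cache.insert("answer".to_string(), 42).await;
///
/// if let Some(value) = cache.get("answer").await {
///     assert_eq!(*value, 42);
/// }
///
/// // Evict a single entry chosen by random sampling.
/// let _evicted = cache.evict_one(rand::thread_rng()).await;
/// ```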
#[derive(Debug)]
pub struct LruCache<K, V, T = AtomicInstant, S = DefaultHashBuilder> {
    hash_builder: S,
    shards: Vec<(Shard<K, TimestampedValue<V, T>>, AtomicUsize)>,
    size: AtomicUsize,
}

impl<K, V, T> LruCache<K, V, T, DefaultHashBuilder> {
    pub fn new(num_shards: usize) -> Self {
        Self::with_hasher(num_shards, DefaultHashBuilder::default())
    }
}

impl<K, V> Default for LruCache<K, V, AtomicInstant, DefaultHashBuilder> {
    fn default() -> Self {
        Self::new(num_cpus::get())
    }
}

impl<K, V, T, S> LruCache<K, V, T, S> {
    pub fn with_hasher(num_shards: usize, hash_builder: S) -> Self {
        LruCache {
            shards: (0..num_shards)
                .map(|_| (Arc::new(RwLock::new(IndexedShard::new())), AtomicUsize::new(0)))
                .collect(),
            hash_builder,
            size: AtomicUsize::new(0),
        }
    }
}

impl<K, V, T, S> LruCache<K, V, T, S>
where
    S: Clone,
    K: Clone,
    V: Clone,
    T: AtomicTimestamp,
{
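    /// Deep-copies the cache by cloning every shard under its read lock.
    ///
    /// Note that cloned entries receive fresh timestamps (see the `Clone` impl
    /// for `TimestampedValue`), so access ordering is not preserved.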
    pub async fn duplicate(&self) -> Self {
        let mut shards = Vec::with_capacity(self.shards.len());
        let mut size = 0;

        for shard in &self.shards {
            let shard = shard.0.read().await.clone();

            let shard_len = shard.len();
            size += shard_len;
            shards.push((Arc::new(RwLock::new(shard)), AtomicUsize::new(shard_len)));
        }

        LruCache {
            shards,
            hash_builder: self.hash_builder.clone(),
            size: AtomicUsize::new(size),
        }
    }
}

impl<K, V, T, S> LruCache<K, V, T, S>
where
    K: Hash + Eq,
    S: BuildHasher,
    T: AtomicTimestamp,
{
    #[inline]
    pub fn size(&self) -> usize {
        self.size.load(Ordering::SeqCst)
    }

    #[cfg(test)]
    pub async fn test_size(&self) -> usize {
        let mut size = 0;
        for shard in &self.shards {
            size += shard.0.read().await.len();
        }

        size
    }

    #[inline]
    pub fn hash_builder(&self) -> &S {
        &self.hash_builder
    }

    #[inline]
    pub fn num_shards(&self) -> usize {
        self.shards.len()
    }

    pub async fn retain<F>(&self, f: F)
    where
        F: Fn(&K, &mut V) -> bool,
    {
        for (shard, _) in &self.shards {
            let mut shard = shard.write().await;

            let len = shard.len();
            shard.retain(|k, tv| f(k, &mut tv.value));

            self.size.fetch_sub(len - shard.len(), Ordering::SeqCst);
        }
    }

    pub async fn clear(&self) {
        for (shard, _) in &self.shards {
            let mut shard = shard.write().await;
            let len = shard.len();
            shard.clear();

            self.size.fetch_sub(len, Ordering::SeqCst);
        }
    }

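    /// Hashes `key` once and derives the owning shard's index from the hash.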
    #[inline]
    fn hash_and_shard<Q: ?Sized>(&self, key: &Q) -> (u64, usize)
    where
        Q: Hash + Eq,
    {
        let mut hasher = self.hash_builder.build_hasher();
        key.hash(&mut hasher);
        let hash = hasher.finish();
        (hash, hash as usize % self.shards.len())
    }

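    /// Looks up `key` and returns a write handle to the raw timestamped entry
    /// without refreshing its timestamp; `get_raw` is the read-locked counterpart.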
    async fn get_mut_raw<Q: ?Sized>(
        &self,
        key: &Q,
    ) -> Option<WriteHandle<impl Erased, TimestampedValue<V, T>>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        // SAFETY: `shard_idx` is `hash % self.shards.len()`, so it is always in bounds.
        let shard = unsafe { self.shards.get_unchecked(shard_idx).0.clone().write_owned().await };

        OwnedRwLockWriteGuard::try_map(shard, |shard| shard.get_mut(hash, key)).ok()
    }

    async fn get_raw<Q: ?Sized>(&self, key: &Q) -> Option<ReadHandle<impl Erased, TimestampedValue<V, T>>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        // SAFETY: `shard_idx` is `hash % self.shards.len()`, so it is always in bounds.
        let shard = unsafe { self.shards.get_unchecked(shard_idx).0.clone().read_owned().await };

        ReadHandle::try_map(shard, |shard| shard.get(hash, key)).ok()
    }

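    /// Returns a read handle to the value for `key` without refreshing its
    /// access timestamp, so peeking does not affect eviction order.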
    pub async fn peek<Q: ?Sized>(&self, key: &Q) -> Option<ReadHandle<impl Erased, V>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.get_raw(key)
            .await
            .map(|tv| ReadHandle::map(tv, |tv| &tv.value))
    }

    /// Returns a write handle to the value for `key` without refreshing its
    /// access timestamp.
    pub async fn peek_mut<Q: ?Sized>(&self, key: &Q) -> Option<WriteHandle<impl Erased, V>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.get_mut_raw(key)
            .await
            .map(|tv| WriteHandle::map(tv, |tv| &mut tv.value))
    }

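    /// Returns a read handle to the value for `key`, refreshing its access
    /// timestamp so the entry counts as recently used.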
    pub async fn get<Q: ?Sized>(&self, key: &Q) -> Option<ReadHandle<impl Erased, V>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let tv = self.get_raw(key).await;

        if let Some(ref tv) = tv {
            tv.timestamp.update();
        }

        tv.map(|tv| ReadHandle::map(tv, |tv| &tv.value))
    }

    /// Returns a write handle to the value for `key`, resetting its access
    /// timestamp to now.
    pub async fn get_mut<Q: ?Sized>(&self, key: &Q) -> Option<WriteHandle<impl Erased, V>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let mut tv = self.get_mut_raw(key).await;

        if let Some(ref mut tv) = tv {
            tv.timestamp = T::now();
        }

        tv.map(|tv| WriteHandle::map(tv, |tv| &mut tv.value))
    }

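    /// Inserts `key`/`value` with a fresh timestamp, returning the previous
    /// value if the key was already present.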
    pub async fn insert(&self, key: K, value: V) -> Option<V> {
        let (hash, shard_idx) = self.hash_and_shard(&key);
        // SAFETY: `shard_idx` is `hash % self.shards.len()`, so it is always in bounds.
        let (locked_shard, shard_size) = unsafe { self.shards.get_unchecked(shard_idx) };

        let mut shard = locked_shard.write().await;

        let value = TimestampedValue {
            value,
            timestamp: T::now(),
        };

        shard
            .insert_full(hash, key, value, || {
                self.size.fetch_add(1, Ordering::SeqCst);
                shard_size.fetch_add(1, Ordering::SeqCst);
            })
            .1
            .map(|tv| tv.value)
    }

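    /// Removes `key` from the cache, returning its value if it was present.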
    pub async fn remove<Q: ?Sized>(&self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        // SAFETY: `shard_idx` is `hash % self.shards.len()`, so it is always in bounds.
        let (locked_shard, shard_size) = unsafe { self.shards.get_unchecked(shard_idx) };

        let mut shard = locked_shard.write().await;

        match shard.swap_remove_full(hash, key) {
            Some((_, tv)) => {
                self.size.fetch_sub(1, Ordering::SeqCst);
                shard_size.store(shard.len(), Ordering::SeqCst);

                Some(tv.value)
            }
            None => None,
        }
    }

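    /// Iterates over shards whose per-shard counter is non-zero. The counters
    /// are only advisory; callers re-check the real length under the lock.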
    fn non_empty_shards(&self) -> impl Iterator<Item = &Shard<K, TimestampedValue<V, T>>> {
        self.shards
            .iter()
            .filter_map(|(shard, shard_size)| match shard_size.load(Ordering::SeqCst) {
                0 => None,
                _ => Some(shard),
            })
    }

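    /// Evicts entries by random sampling, in the spirit of an approximated LRU:
    /// on each step two candidate entries are sampled (from one or two locked
    /// shards), the one with the older timestamp is passed to `predicate`, and
    /// `predicate` decides whether that entry is evicted and whether eviction
    /// continues ([`Evict::Continue`]), stops after this eviction
    /// ([`Evict::Once`]), or stops without evicting ([`Evict::None`]).
    ///
    /// Returns the evicted `(key, value)` pairs.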
    pub async fn evict<F>(&self, mut rng: impl Rng, mut predicate: F) -> Vec<(K, V)>
    where
        F: FnMut(&K, &mut V) -> Evict,
    {
        use rand::seq::SliceRandom;

        let mut evicted = Vec::new();

        let mut non_empty = Vec::with_capacity(self.shards.len());

        // Pops shards off `non_empty` until one that is still non-empty under
        // its write lock is found, or the list is exhausted.
        macro_rules! pop_shard {
            () => {
                loop {
                    match non_empty.pop() {
                        Some(shard) => {
                            let shard = shard.write().await;
                            if shard.len() > 0 {
                                break Some(shard);
                            }
                        }
                        None => break None,
                    }
                }
            };
        }

        'evict: while self.size() > 0 {
            non_empty.extend(self.non_empty_shards());
            non_empty.shuffle(&mut rng);

            let mut shard_a = match pop_shard!() {
                Some(shard) => shard,
                None => continue 'evict,
            };

            'walk: loop {
                match pop_shard!() {
                    // Only one locked shard remains: sample within it.
                    None => {
                        let res = match shard_a.len() {
                            1 => unsafe {
                                let shard::Bucket {
                                    ref key,
                                    ref mut value,
                                    ..
                                } = shard_a.entries.get_unchecked_mut(0);

                                let res = predicate(key, &mut value.value);

                                if matches!(res, Evict::Continue | Evict::Once) {
                                    shard_a.indices.clear();
                                    let shard::Bucket { key, value, .. } = shard_a.entries.pop().unwrap();
                                    self.size.fetch_sub(1, Ordering::SeqCst);
                                    evicted.push((key, value.value));
                                }

                                res
                            },
                            len => unsafe {
                                // Sample two distinct entries and offer the older one.
                                let (elem_a_idx, elem_b_idx) = pick_indices(len, &mut rng);

                                let ts_a = &shard_a.entries.get_unchecked(elem_a_idx).value.timestamp;
                                let ts_b = &shard_a.entries.get_unchecked(elem_b_idx).value.timestamp;
                                let idx = if ts_a.is_before(ts_b) {
                                    elem_a_idx
                                } else {
                                    elem_b_idx
                                };

                                let shard::Bucket {
                                    ref key,
                                    ref mut value,
                                    ..
                                } = shard_a.entries.get_unchecked_mut(idx);

                                let res = predicate(key, &mut value.value);

                                if matches!(res, Evict::Continue | Evict::Once) {
                                    let (key, value) = shard_a.swap_remove_index_raw(idx);
                                    self.size.fetch_sub(1, Ordering::SeqCst);
                                    evicted.push((key, value.value));
                                }

                                res
                            },
                        };

                        if matches!(res, Evict::Once | Evict::None) {
                            break 'evict;
                        }

                        continue 'evict;
                    }
                    // Two locked shards: sample across their combined index range.
                    Some(mut shard_b) => unsafe {
                        let shard_a_len = shard_a.len();
                        let shard_b_len = shard_b.len();

                        debug_assert!(shard_a_len > 0);
                        debug_assert!(shard_b_len > 0);

                        let sample_range = shard_a_len + shard_b_len;

                        let (elem_a_range_idx, elem_b_range_idx) = pick_indices(sample_range, &mut rng);

                        let ts_a = if elem_a_range_idx < shard_a_len {
                            &shard_a.entries.get_unchecked(elem_a_range_idx).value.timestamp
                        } else {
                            &shard_b
                                .entries
                                .get_unchecked(elem_a_range_idx - shard_a_len)
                                .value
                                .timestamp
                        };

                        let ts_b = if elem_b_range_idx < shard_a_len {
                            &shard_a.entries.get_unchecked(elem_b_range_idx).value.timestamp
                        } else {
                            &shard_b
                                .entries
                                .get_unchecked(elem_b_range_idx - shard_a_len)
                                .value
                                .timestamp
                        };

                        let elem_range_idx = if ts_a.is_before(ts_b) {
                            elem_a_range_idx
                        } else {
                            elem_b_range_idx
                        };

                        let (shard, idx) = if elem_range_idx < shard_a_len {
                            (&mut shard_a, elem_range_idx)
                        } else {
                            (&mut shard_b, elem_range_idx - shard_a_len)
                        };

                        let shard::Bucket {
                            ref key,
                            ref mut value,
                            ..
                        } = shard.entries.get_unchecked_mut(idx);

                        let res = predicate(key, &mut value.value);

                        if matches!(res, Evict::Continue | Evict::Once) {
                            let (key, value) = shard.swap_remove_index_raw(idx);
                            self.size.fetch_sub(1, Ordering::SeqCst);
                            evicted.push((key, value.value));
                        }

                        if matches!(res, Evict::None | Evict::Once) {
                            break 'evict;
                        }

                        // Keep walking with the second shard still locked.
                        shard_a = shard_b;
                    },
                }

                if shard_a.len() == 0 {
                    shard_a = match pop_shard!() {
                        Some(shard) => shard,
                        None => break 'walk,
                    };
                }
            }
        }

        evicted
    }

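    /// Evicts up to `count` entries (clamped to the current size) by repeatedly
    /// sampling with [`Self::evict`] until the count is reached.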
    pub async fn evict_many(&self, mut count: usize, rng: impl Rng) -> Vec<(K, V)> {
        count = count.min(self.size());

        if count == 0 {
            return Vec::new();
        }

        let mut cur = count;

        self.evict(rng, |_, _| {
            cur -= 1;

            match cur {
                0 => Evict::Once,
                _ => Evict::Continue,
            }
        })
        .await
    }

    /// Evicts a single sampled entry, if the cache is non-empty.
    pub async fn evict_one(&self, rng: impl Rng) -> Option<(K, V)> {
        self.evict(rng, |_, _| Evict::Once).await.pop()
    }

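    /// Faster, less precise variant of [`Self::evict_many`]: each non-empty
    /// shard is visited once and evicts a share of `count` proportional to its
    /// size, so the number of evicted entries is only approximately `count`.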
    pub async fn evict_many_fast(&self, mut count: usize, mut rng: impl Rng) -> Vec<(K, V)> {
        use rand::prelude::SliceRandom;

        count = count.min(self.size());

        let mut evicted = Vec::new();

        if count == 0 {
            return evicted;
        }

        let mut non_empty = Vec::with_capacity(self.shards.len());
        non_empty.extend(self.non_empty_shards());
        non_empty.shuffle(&mut rng);

        // Rough share of `count` this shard should contribute, biased upward by one.
        fn proportion_of(size: usize, len: usize, count: usize) -> usize {
            ((count as u128 * len as u128) / size as u128) as usize + 1
        }

        let size = self.size();

        let mut sum = 0;
        for shard in non_empty {
            let mut shard = shard.write().await;

            if shard.len() == 0 {
                continue;
            }

            let mut sub_count = proportion_of(size, shard.len(), count);
            sum += sub_count;

            if sum > count {
                sub_count = sum - count - 1;
            }

            if sub_count == shard.len() {
                // Evicting the whole shard: drain it rather than sampling.
                evicted.extend(
                    shard
                        .entries
                        .drain(..)
                        .map(|bucket| (bucket.key, bucket.value.value)),
                );

                shard.indices.clear();
                self.size.fetch_sub(sub_count, Ordering::SeqCst);
            } else {
                for _ in 0..sub_count {
                    let (elem_a_idx, elem_b_idx) = pick_indices(shard.len(), &mut rng);

                    unsafe {
                        let ts_a = &shard.entries.get_unchecked(elem_a_idx).value.timestamp;
                        let ts_b = &shard.entries.get_unchecked(elem_b_idx).value.timestamp;

                        let idx = if ts_a.is_before(ts_b) {
                            elem_a_idx
                        } else {
                            elem_b_idx
                        };

                        evicted.push({
                            let (key, value) = shard.swap_remove_index_raw(idx);
                            self.size.fetch_sub(1, Ordering::SeqCst);
                            (key, value.value)
                        });
                    }
                }
            }

            if sum > count {
                break;
            }
        }

        evicted
    }
}

/// Decision returned by the predicate passed to [`LruCache::evict`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Evict {
    /// Evict the offered entry and continue evicting.
    Continue,
    /// Evict the offered entry, then stop.
    Once,
    /// Keep the offered entry and stop.
    None,
}

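/// Picks two candidate indices in `0..len`, distinct whenever `len > 1`.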
fn pick_indices(len: usize, mut rng: impl Rng) -> (usize, usize) {
    match len {
        0 => panic!("Invalid length"),
        1 => (0, 0),
        2 => (0, 1),
        _ => {
            let idx_a = rng.gen_range(0..len);

            loop {
                let idx_b = rng.gen_range(0..len);

                if idx_b != idx_a {
                    return (idx_a, idx_b);
                }
            }
        }
    }
}