use std::{
    borrow::Borrow,
    cell::UnsafeCell,
    fmt,
    hash::{Hash, Hasher},
    ops::{Deref, DerefMut},
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
};

use crossbeam_queue::SegQueue;

mod dll;
use crate::dll::{DoublyLinkedList, Node};

// Accesses are buffered locally until this many have accumulated, then
// flushed to the shared shards in one batch.
const MAX_QUEUE_ITEMS: usize = 32;
// Costs up to this value are stored exactly; larger costs are compressed to a
// nearby power of two (see `Resizer` below).
const RESIZE_CUTOFF: usize = 63;
const RESIZE_CUTOFF_U8: u8 = RESIZE_CUTOFF as u8;
// The cache is split into this many independently locked shards.
const N_SHARDS: usize = 256;

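/// A minimal non-blocking mutex: `try_lock` either acquires the lock and
/// returns a guard or returns `None` immediately if the lock is already
/// held. It never blocks, which lets contended shard access fall back to
/// queueing (see `CacheAdvisor::accessed_inner`) instead of waiting.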
struct TryMutex<T> {
    inner: UnsafeCell<T>,
    mu: AtomicBool,
}

impl<T> TryMutex<T> {
    fn new(inner: T) -> TryMutex<T> {
        TryMutex {
            inner: inner.into(),
            mu: false.into(),
        }
    }

    #[inline]
    fn try_lock(&self) -> Option<TryMutexGuard<'_, T>> {
        if self.mu.swap(true, Ordering::Acquire) {
            None
        } else {
            Some(TryMutexGuard { tm: self })
        }
    }
}

struct TryMutexGuard<'a, T> {
    tm: &'a TryMutex<T>,
}

unsafe impl<T: Send> Send for TryMutex<T> {}

unsafe impl<T: Send> Sync for TryMutex<T> {}

impl<'a, T> Drop for TryMutexGuard<'a, T> {
    #[inline]
    fn drop(&mut self) {
        assert!(self.tm.mu.swap(false, Ordering::Release));
    }
}

impl<'a, T> Deref for TryMutexGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        unsafe { &*self.tm.inner.get() }
    }
}

impl<'a, T> DerefMut for TryMutexGuard<'a, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        unsafe { &mut *self.tm.inner.get() }
    }
}

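/// Tracks the error introduced by compressing item costs into a single byte.
/// Costs above `RESIZE_CUTOFF` are rounded to a nearby power of two, and the
/// rounding direction is chosen so that the running sum of decompressed
/// costs converges toward the running sum of true costs.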
#[derive(Clone, Default)]
struct Resizer {
    actual: u128,
    decompressed: u128,
}

impl Resizer {
    fn compress(&mut self, raw_input: usize) -> u8 {
        // Small costs are stored exactly.
        if raw_input <= RESIZE_CUTOFF {
            return u8::try_from(raw_input).unwrap();
        }

        // Larger costs are rounded to a power of two. `compressed` is the
        // exponent of the next power of two at or above the input.
        let upgraded_input = u128::try_from(raw_input).unwrap();
        let po2 = upgraded_input.next_power_of_two();
        let compressed = po2.trailing_zeros() as u8;
        let decompressed = decompress(compressed + RESIZE_CUTOFF_U8) as u128;
        self.actual += raw_input as u128;

        // Round down instead if rounding up would push the running
        // decompressed total past the running actual total.
        let ret = if self.decompressed + decompressed > self.actual {
            compressed - 1
        } else {
            compressed
        };

        self.decompressed += decompress(ret + RESIZE_CUTOFF_U8) as u128;

        let sz = ret + RESIZE_CUTOFF_U8;

        // The top bit is reserved for the promotion flag in `CacheAccess`.
        assert!(sz < 128);

        sz
    }
}

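/// Inverse of `Resizer::compress`: codes `0..=RESIZE_CUTOFF` decode to
/// themselves, larger codes decode to `1 << (code - RESIZE_CUTOFF)`,
/// saturating at `usize::MAX`. The top bit is masked off because it serves
/// as the promotion flag in `CacheAccess::size`.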
#[inline]
const fn decompress(input: u8) -> usize {
    let masked = input & 127;
    match masked {
        0..=RESIZE_CUTOFF_U8 => masked as usize,
        _ => {
            if let Some(o) = 1_usize.checked_shl((masked - RESIZE_CUTOFF_U8) as u32) {
                o
            } else {
                usize::MAX
            }
        }
    }
}

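/// 64-bit FNV-1a hasher used for the per-shard entry set.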
struct Fnv(u64);

impl Default for Fnv {
    #[inline]
    fn default() -> Fnv {
        Fnv(0xcbf29ce484222325)
    }
}

impl std::hash::Hasher for Fnv {
    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }

    #[inline]
    fn write(&mut self, bytes: &[u8]) {
        let Fnv(mut hash) = *self;

        for byte in bytes.iter() {
            hash ^= *byte as u64;
            hash = hash.wrapping_mul(0x100000001b3);
        }

        *self = Fnv(hash);
    }
}

pub(crate) type FnvSet8<V> = std::collections::HashSet<V, std::hash::BuildHasherDefault<Fnv>>;

type PageId = u64;

// Compile-time check that `CacheAccess` stays at 8 bytes with alignment 1.
fn _sz_test() {
    let _: [u8; 8] = [0; std::mem::size_of::<CacheAccess>()];
    let _: [u8; 1] = [0; std::mem::align_of::<CacheAccess>()];
}

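/// A single access packed into 8 bytes: one byte of compressed cost, whose
/// high bit marks promotion into the main cache, plus the upper seven
/// little-endian bytes of the item id. The low id byte is the shard index,
/// so it is implied by which shard stores the access.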
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct CacheAccess {
    size: u8,
    pid_bytes: [u8; 7],
}

impl CacheAccess {
    fn was_promoted(&self) -> bool {
        self.size & 128 != 0
    }

    fn size(&self) -> usize {
        // `decompress` masks off the promotion bit itself.
        decompress(self.size)
    }

    fn pid(&self, shard: u8) -> PageId {
        // Reattach the shard index as the low little-endian byte.
        let mut pid_bytes = [0; 8];
        pid_bytes[1..8].copy_from_slice(&self.pid_bytes);
        pid_bytes[0] = shard;
        PageId::from_le_bytes(pid_bytes)
    }

    fn new(pid: PageId, sz: usize, resizer: &mut Resizer) -> CacheAccess {
        let size = resizer.compress(sz);

        let mut pid_bytes = [0; 7];
        pid_bytes.copy_from_slice(&pid.to_le_bytes()[1..8]);

        CacheAccess { size, pid_bytes }
    }
}

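/// A sharded, scan-resistant eviction advisor: report accesses with
/// `accessed`, and it answers with the `(id, cost)` pairs that should now be
/// evicted from the cache it is advising. Internally it keeps 256
/// independently locked shards, each split into a small "entry" segment for
/// items seen once and a larger "main" segment that items are promoted into
/// on repeated access, and it compresses per-item costs to one byte via
/// `Resizer`.
///
/// # Examples
///
/// A usage sketch, assuming this file is the root of a crate named
/// `cache_advisor` (adjust the import to the real crate name):
///
/// ```ignore
/// use cache_advisor::CacheAdvisor;
///
/// // 64 KiB of total capacity, 20% of it reserved for the entry segments.
/// let mut ca = CacheAdvisor::new(64 * 1024, 20);
///
/// for id in 0..10_000_u64 {
///     // Report that `id` was accessed at a cost of 16 units, then evict
///     // whatever the advisor hands back from the cache being managed.
///     for (evict_id, evict_cost) in ca.accessed(id, 16) {
///         let _ = (evict_id, evict_cost);
///     }
/// }
/// ```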
pub struct CacheAdvisor {
    shards: Arc<[TryMutex<Shard>]>,
    access_queues: Arc<[SegQueue<CacheAccess>]>,
    local_queue: Vec<(u64, usize)>,
    resizer: Resizer,
    access_buffer: Vec<(u64, usize)>,
}

impl Clone for CacheAdvisor {
    fn clone(&self) -> CacheAdvisor {
        // Clones share the shards and access queues but start with their own
        // empty local buffers.
        CacheAdvisor {
            shards: self.shards.clone(),
            access_queues: self.access_queues.clone(),
            local_queue: vec![],
            resizer: self.resizer.clone(),
            access_buffer: vec![],
        }
    }
}

impl fmt::Debug for CacheAdvisor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CacheAdvisor").finish()
    }
}

impl Default for CacheAdvisor {
    fn default() -> CacheAdvisor {
        CacheAdvisor::new(1024 * 1024, 20)
    }
}

// Compile-time check that `CacheAdvisor` is `Send + Sync`.
const fn _send_sync_ca() {
    const fn send_sync<T: Send + Sync>() {}
    send_sync::<CacheAdvisor>();
}

impl CacheAdvisor {
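    /// Creates an advisor with `capacity` total cost units spread across 256
    /// shards, with `entry_percent` percent of each shard dedicated to the
    /// entry segment. `capacity` must be at least 256 so every shard gets a
    /// nonzero share, and `entry_percent` must be at most 100.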
    pub fn new(capacity: usize, entry_percent: u8) -> Self {
        assert!(
            capacity >= N_SHARDS,
            "Please configure the cache \
             capacity to be at least 256"
        );
        let shard_capacity = capacity / N_SHARDS;

        let mut shards = Vec::with_capacity(N_SHARDS);
        for _ in 0..N_SHARDS {
            shards.push(TryMutex::new(Shard::new(shard_capacity, entry_percent)))
        }

        let mut access_queues = Vec::with_capacity(N_SHARDS);
        for _ in 0..N_SHARDS {
            access_queues.push(SegQueue::default());
        }

        Self {
            shards: shards.into(),
            access_queues: access_queues.into(),
            local_queue: Vec::with_capacity(MAX_QUEUE_ITEMS),
            resizer: Resizer::default(),
            access_buffer: vec![],
        }
    }

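    /// Records an access of `id` with the given `cost` and returns the
    /// `(id, cost)` pairs that should now be evicted. Accesses are buffered
    /// locally and applied in batches, so evictions may surface on a later
    /// call than the access that caused them. Allocates a fresh `Vec` each
    /// call; see `accessed_reuse_buffer` to reuse an internal buffer instead.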
    pub fn accessed(&mut self, id: u64, cost: usize) -> Vec<(u64, usize)> {
        let mut ret = vec![];
        self.accessed_inner(id, cost, &mut ret);
        ret
    }

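    /// Like `accessed`, but writes evictions into an internal buffer that is
    /// reused across calls to avoid a per-call allocation. The returned slice
    /// is only valid until the next call that uses the buffer.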
    pub fn accessed_reuse_buffer(&mut self, id: u64, cost: usize) -> &[(u64, usize)] {
        let mut swapped = std::mem::take(&mut self.access_buffer);
        swapped.clear();
        self.accessed_inner(id, cost, &mut swapped);
        self.access_buffer = swapped;
        &self.access_buffer
    }

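    /// Drops the internal buffer used by `accessed_reuse_buffer`, releasing
    /// its memory.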
    pub fn reset_internal_access_buffer(&mut self) {
        self.access_buffer = vec![]
    }

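    // Buffers accesses locally until `MAX_QUEUE_ITEMS` have accumulated, then
    // drains them. Each access is routed to a shard by the low byte of its
    // id. If the shard's try-lock succeeds we apply the access, along with
    // anything other threads queued while the lock was busy; otherwise we
    // push it onto the shard's queue for the current lock holder to apply.
    // This is a simple flat-combining style scheme that never blocks.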
    fn accessed_inner(&mut self, id: u64, cost: usize, ret: &mut Vec<(u64, usize)>) {
        self.local_queue.push((id, cost));

        if self.local_queue.len() < MAX_QUEUE_ITEMS {
            return;
        }

        while let Some((id, cost)) = self.local_queue.pop() {
            let shard_idx = (id.to_le_bytes()[0] as u64 % N_SHARDS as u64) as usize;
            let shard_mu = &self.shards[shard_idx];
            let access_queue = &self.access_queues[shard_idx];
            let cache_access = CacheAccess::new(id, cost, &mut self.resizer);

            if let Some(mut shard) = shard_mu.try_lock() {
                // Apply any accesses that were queued while the lock was held
                // elsewhere, then our own.
                for _ in 0..access_queue.len() {
                    if let Some(queued_cache_access) = access_queue.pop() {
                        shard.accessed(queued_cache_access, shard_idx, ret);
                    }
                }

                shard.accessed(cache_access, shard_idx, ret);
            } else {
                // The lock is busy; leave this access for the lock holder.
                access_queue.push(cache_access);
            }
        }
    }
}

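/// A raw pointer to a linked-list `Node`, hashed, compared, and borrowed by
/// the pid bytes of the `CacheAccess` it holds, so the shard's `HashSet` can
/// be probed with a plain `[u8; 7]` key.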
#[derive(Eq)]
struct Entry(*mut Node);

unsafe impl Send for Entry {}

impl Ord for Entry {
    fn cmp(&self, other: &Entry) -> std::cmp::Ordering {
        let left_pid: &[u8; 7] = self.borrow();
        let right_pid: &[u8; 7] = other.borrow();
        left_pid.cmp(right_pid)
    }
}

impl PartialOrd<Entry> for Entry {
    fn partial_cmp(&self, other: &Entry) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for Entry {
    fn eq(&self, other: &Entry) -> bool {
        unsafe { (*self.0).pid_bytes == (*other.0).pid_bytes }
    }
}

impl Borrow<[u8; 7]> for Entry {
    fn borrow(&self) -> &[u8; 7] {
        unsafe { &(*self.0).pid_bytes }
    }
}

impl Hash for Entry {
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        unsafe { (*self.0).pid_bytes.hash(hasher) }
    }
}

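/// One of the 256 cache shards: a two-segment LRU. New items land in the
/// small `entry_cache`; items accessed again are promoted into the larger
/// `main_cache`. A burst of one-hit items (a scan) can therefore only churn
/// the entry segment, which is what provides scan resistance. Until the main
/// cache has evicted anything, new items go straight into it so the shard
/// warms up quickly.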
struct Shard {
    entry_cache: DoublyLinkedList,
    main_cache: DoublyLinkedList,
    entries: FnvSet8<Entry>,
    entry_capacity: usize,
    entry_size: usize,
    main_capacity: usize,
    main_size: usize,
    ever_evicted_main: bool,
}

impl Shard {
    fn new(capacity: usize, entry_pct: u8) -> Self {
        assert!(
            entry_pct <= 100,
            "entry cache percent must be less than or equal to 100"
        );
        assert!(capacity > 0, "shard capacity must be non-zero");

        let entry_capacity = (capacity * entry_pct as usize) / 100;
        let main_capacity = capacity - entry_capacity;

        Self {
            entry_cache: DoublyLinkedList::default(),
            main_cache: DoublyLinkedList::default(),
            entries: FnvSet8::default(),
            entry_capacity,
            main_capacity,
            entry_size: 0,
            main_size: 0,
            ever_evicted_main: false,
        }
    }

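    // Applies one access to this shard. Known items are moved to the head of
    // the main cache, promoting them out of the entry cache if they were only
    // seen once before. Unknown items go into the entry cache, or straight
    // into the main cache while it is still warming up. Anything popped from
    // a segment's tail to get back under capacity is appended to `ret`.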
    fn accessed(
        &mut self,
        cache_access: CacheAccess,
        shard_idx: usize,
        ret: &mut Vec<(u64, usize)>,
    ) {
        let new_size = cache_access.size();

        if let Some(entry) = self.entries.get(&cache_access.pid_bytes) {
            let (old_size, was_promoted) = unsafe {
                let old_size = (*entry.0).size();
                let was_promoted = (*entry.0).was_promoted();

                // Update the node in place with the new size and mark it as
                // promoted into the main cache.
                (*entry.0).inner.get_mut().size = 128 | cache_access.size;

                (old_size, was_promoted)
            };

            if was_promoted {
                // Already in the main cache: move it back to the head.
                self.main_size -= old_size;

                self.main_cache.unwire(entry.0);
                self.main_cache.install(entry.0);
            } else {
                // Repeat access: promote from the entry cache to the main cache.
                self.entry_size -= old_size;

                self.entry_cache.unwire(entry.0);
                self.main_cache.install(entry.0);
            }

            self.main_size += new_size;
        } else if !self.ever_evicted_main {
            // The main cache has never evicted anything, so let new items
            // fill it directly while the shard warms up.
            let mut cache_access = cache_access;
            cache_access.size |= 128;
            let ptr = self.main_cache.push_head(cache_access);
            self.entries.insert(Entry(ptr));
            self.main_size += new_size;
        } else {
            let ptr = self.entry_cache.push_head(cache_access);
            self.entries.insert(Entry(ptr));
            self.entry_size += new_size;
        };

        while self.entry_size > self.entry_capacity && self.entry_cache.len() > 1 {
            let node: *mut Node = self.entry_cache.pop_tail().unwrap();

            let popped_entry: CacheAccess = unsafe { *(*node).inner.get() };
            let node_size = popped_entry.size();
            let item = popped_entry.pid(u8::try_from(shard_idx).unwrap());

            self.entry_size -= node_size;

            assert!(
                !popped_entry.was_promoted(),
                "somehow, promoted item was still in entry cache"
            );

            let pid_bytes = popped_entry.pid_bytes;
            assert!(self.entries.remove(&pid_bytes));

            ret.push((item, node_size));

            let node_box: Box<Node> = unsafe { Box::from_raw(node) };

            drop(node_box);
        }

        while self.main_size > self.main_capacity && self.main_cache.len() > 1 {
            self.ever_evicted_main = true;

            let node: *mut Node = self.main_cache.pop_tail().unwrap();

            let popped_main: CacheAccess = unsafe { *(*node).inner.get() };
            let node_size = popped_main.size();
            let item = popped_main.pid(u8::try_from(shard_idx).unwrap());

            self.main_size -= node_size;

            let pid_bytes = popped_main.pid_bytes;
            assert!(self.entries.remove(&pid_bytes));

            ret.push((item, node_size));

            let node_box: Box<Node> = unsafe { Box::from_raw(node) };

            drop(node_box);
        }
    }
}

#[test]
fn lru_smoke_test() {
    let mut lru = CacheAdvisor::new(256, 50);
    let mut evicted = 0;
    for i in 0..10_000 {
        evicted += lru.accessed(i, 16).len();
    }
    assert!(evicted > 9700, "only evicted {} items", evicted);
}

#[test]
fn probabilistic_sum() {
    let mut resizer = Resizer::default();
    let mut resized = 0;
    let mut actual = 0;
    for i in 0..1000 {
        let compressed = resizer.compress(i);
        let decompressed = decompress(compressed);
        resized += decompressed;
        actual += i;
    }

    let abs_delta = ((resized as f64 / actual as f64) - 1.).abs();

    assert!(abs_delta < 0.005, "delta is actually {}", abs_delta);
}

#[test]
fn probabilistic_ev() {
    let mut resizer = Resizer::default();

    fn assert_rt(i: usize, resizer: &mut Resizer) {
        let mut resized = 0_u128;
        let mut actual = 0_u128;
        for _ in 1..10_000 {
            let compressed = resizer.compress(i);
            let decompressed = decompress(compressed);
            resized += decompressed as u128;
            actual += i as u128;
        }

        if i == 0 {
            assert_eq!(actual, 0);
            assert_eq!(resized, 0);
        } else {
            let abs_delta = ((resized as f64 / actual as f64) - 1.).abs();
            assert!(
                abs_delta < 0.0001,
                "delta is actually {} for inputs of size {}. actual: {} round-trip: {}",
                abs_delta,
                i,
                actual,
                resized
            );
        }
    }

    for i in 0..1024 {
        assert_rt(i, &mut resizer)
    }

    assert_rt(usize::MAX, &mut resizer)
}

#[test]
fn probabilistic_n() {
    const N: usize = 9;

    let mut resizer = Resizer::default();
    let mut resized = 0;
    let mut actual = 0;

    for _ in 0..1000 {
        let compressed = resizer.compress(N);
        let decompressed = decompress(compressed);
        resized += decompressed;
        actual += N;
    }

    let abs_delta = ((resized as f64 / actual as f64) - 1.).abs();

    assert!(abs_delta < 0.005, "delta is actually {}", abs_delta);
}

#[test]
fn scan_resistance() {
    // Tiny shards: 10 cost units each, 10% of which forms the entry cache.
    let mut ca = CacheAdvisor::new(256 * 10, 10);

    // Access item 0 a couple of times so it ends up in shard 0's main cache.
    ca.accessed(0, 1);
    ca.accessed(0, 1);

    // A long scan of one-hit items in the same shard must not evict item 0.
    for i in 1..5000 {
        let id = i * 256;
        let evicted = ca.accessed(id, 1);

        assert!(!evicted.contains(&(0, 1)));
    }

    let mut zero_evicted = false;

    // Once the scanned items are accessed repeatedly they get promoted too,
    // and item 0 eventually falls out of the main cache.
    for i in 1..5000 {
        let id = i * 256;
        zero_evicted |= ca.accessed(id, 1).contains(&(0, 1));
        zero_evicted |= ca.accessed(id, 1).contains(&(0, 1));
        zero_evicted |= ca.accessed(id, 1).contains(&(0, 1));
    }

    assert!(zero_evicted);
}