1use super::{deques::Deques, CacheBuilder, Iter, KeyHashDate, ValueEntry};
2use crate::{
3 common::{self, deque::DeqNode, frequency_sketch::FrequencySketch, CacheRegion},
4 Policy,
5};
6
7use std::{
8 borrow::Borrow,
9 collections::{hash_map::RandomState, HashMap},
10 fmt,
11 hash::{BuildHasher, Hash},
12 ptr::NonNull,
13 rc::Rc,
14};
15
/// Maximum number of probation-deque nodes examined by a single call to
/// `Cache::evict_lru_entries`, bounding the work done per cache operation.
const EVICTION_BATCH_SIZE: usize = 100;
17
18type CacheStore<K, V, S> = std::collections::HashMap<Rc<K>, ValueEntry<K, V>, S>;
19
20pub struct Cache<K, V, S = RandomState> {
94 max_capacity: Option<u64>,
95 entry_count: u64,
96 cache: CacheStore<K, V, S>,
97 build_hasher: S,
98 deques: Deques<K>,
99 frequency_sketch: FrequencySketch,
100 frequency_sketch_enabled: bool,
101}
102
103impl<K, V, S> fmt::Debug for Cache<K, V, S>
104where
105 K: fmt::Debug + Eq + Hash,
106 V: fmt::Debug,
107 S: BuildHasher + Clone,
109{
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 let mut d_map = f.debug_map();
112
113 for (k, v) in self.iter() {
114 d_map.entry(&k, &v);
115 }
116
117 d_map.finish()
118 }
119}
120
121impl<K, V> Cache<K, V, RandomState>
122where
123 K: Hash + Eq,
124{
125 pub fn new(max_capacity: u64) -> Self {
132 let build_hasher = RandomState::default();
133 Self::with_everything(Some(max_capacity), None, build_hasher)
134 }
135
136 pub fn builder() -> CacheBuilder<K, V, Cache<K, V, RandomState>> {
141 CacheBuilder::default()
142 }
143}
144
145impl<K, V, S> Cache<K, V, S> {
149 pub fn policy(&self) -> Policy {
154 Policy::new(self.max_capacity)
155 }
156
157 pub fn entry_count(&self) -> u64 {
177 self.entry_count
178 }
179
180 pub fn weighted_size(&self) -> u64 {
184 self.entry_count
185 }
186}
187
188impl<K, V, S> Cache<K, V, S>
189where
190 K: Hash + Eq,
191 S: BuildHasher + Clone,
192{
193 pub(crate) fn with_everything(
194 max_capacity: Option<u64>,
195 initial_capacity: Option<usize>,
196 build_hasher: S,
197 ) -> Self {
198 let cache = HashMap::with_capacity_and_hasher(
199 initial_capacity.unwrap_or_default(),
200 build_hasher.clone(),
201 );
202
203 Self {
204 max_capacity,
205 entry_count: 0,
206 cache,
207 build_hasher,
208 deques: Default::default(),
209 frequency_sketch: Default::default(),
210 frequency_sketch_enabled: false,
211 }
212 }
213
214 pub fn contains_key<Q>(&mut self, key: &Q) -> bool
222 where
223 Rc<K>: Borrow<Q>,
224 Q: Hash + Eq + ?Sized,
225 {
226 self.cache.contains_key(key)
227 }
228
229 pub fn get<Q>(&mut self, key: &Q) -> Option<&V>
234 where
235 Rc<K>: Borrow<Q>,
236 Q: Hash + Eq + ?Sized,
237 {
238 self.frequency_sketch.increment(self.hash(key));
239
240 if let Some(entry) = self.cache.get_mut(key) {
241 Self::record_hit(&mut self.deques, entry);
242 Some(&entry.value)
243 } else {
244 None
245 }
246 }
247
248 pub fn insert(&mut self, key: K, value: V) {
252 self.evict_lru_entries();
253 let policy_weight = 1;
254 let key = Rc::new(key);
255 let entry = ValueEntry::new(value);
256
257 if let Some(old_entry) = self.cache.insert(Rc::clone(&key), entry) {
258 self.handle_update(key, policy_weight, old_entry);
259 } else {
260 let hash = self.hash(&key);
261 self.handle_insert(key, hash, policy_weight);
262 }
263 }
264
265 pub fn invalidate<Q>(&mut self, key: &Q)
270 where
271 Rc<K>: Borrow<Q>,
272 Q: Hash + Eq + ?Sized,
273 {
274 self.evict_lru_entries();
275
276 if let Some(mut entry) = self.cache.remove(key) {
277 self.deques.unlink_ao(&mut entry);
278 self.entry_count -= 1;
279 }
280 }
281
282 pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
287 where
288 Rc<K>: Borrow<Q>,
289 Q: Hash + Eq + ?Sized,
290 {
291 self.evict_lru_entries();
292
293 if let Some(mut entry) = self.cache.remove(key) {
294 self.deques.unlink_ao(&mut entry);
295 self.entry_count -= 1;
296 Some(entry.value)
297 } else {
298 None
299 }
300 }
301
302 pub fn invalidate_all(&mut self) {
308 let old_capacity = self.cache.capacity();
311 let old_cache = std::mem::replace(
312 &mut self.cache,
313 HashMap::with_hasher(self.build_hasher.clone()),
314 );
315 self.deques.clear();
316 self.entry_count = 0;
317
318 drop(old_cache);
320
321 let _ = self.cache.try_reserve(old_capacity);
323 }
324
325 #[allow(clippy::needless_collect)]
340 pub fn invalidate_entries_if(&mut self, mut predicate: impl FnMut(&K, &V) -> bool) {
341 let Self { cache, deques, .. } = self;
342
343 let keys_to_invalidate = cache
349 .iter()
350 .filter(|(key, entry)| (predicate)(key, &entry.value))
351 .map(|(key, _)| Rc::clone(key))
352 .collect::<Vec<_>>();
353
354 let mut invalidated = 0u64;
355
356 keys_to_invalidate.into_iter().for_each(|k| {
357 if let Some(mut entry) = cache.remove(&k) {
358 let _weight = entry.policy_weight();
359 deques.unlink_ao(&mut entry);
360 invalidated += 1;
361 }
362 });
363 self.entry_count -= invalidated;
364 }
365
366 pub fn iter(&self) -> Iter<'_, K, V> {
389 Iter::new(self, self.cache.iter())
390 }
391}
392
393impl<K, V, S> Cache<K, V, S>
397where
398 K: Hash + Eq,
399 S: BuildHasher + Clone,
400{
401 #[inline]
402 fn hash<Q>(&self, key: &Q) -> u64
403 where
404 Rc<K>: Borrow<Q>,
405 Q: Hash + Eq + ?Sized,
406 {
407 self.build_hasher.hash_one(key)
408 }
409
410 fn record_hit(deques: &mut Deques<K>, entry: &mut ValueEntry<K, V>) {
411 deques.move_to_back_ao(entry)
412 }
413
414 fn has_enough_capacity(&self, candidate_weight: u32, ws: u64) -> bool {
415 self.max_capacity
416 .map(|limit| ws + candidate_weight as u64 <= limit)
417 .unwrap_or(true)
418 }
419
420 fn weights_to_evict(&self) -> u64 {
421 self.max_capacity
422 .map(|limit| self.entry_count.saturating_sub(limit))
423 .unwrap_or_default()
424 }
425
426 #[inline]
427 fn should_enable_frequency_sketch(&self) -> bool {
428 if self.frequency_sketch_enabled {
429 false
430 } else if let Some(max_cap) = self.max_capacity {
431 self.entry_count >= max_cap / 2
432 } else {
433 false
434 }
435 }
436
437 #[inline]
438 fn enable_frequency_sketch(&mut self) {
439 if let Some(max_cap) = self.max_capacity {
440 self.do_enable_frequency_sketch(max_cap);
441 }
442 }
443
444 #[cfg(test)]
445 fn enable_frequency_sketch_for_testing(&mut self) {
446 if let Some(max_cap) = self.max_capacity {
447 self.do_enable_frequency_sketch(max_cap);
448 }
449 }
450
451 #[inline]
452 fn do_enable_frequency_sketch(&mut self, cache_capacity: u64) {
453 let skt_capacity = common::sketch_capacity(cache_capacity);
454 self.frequency_sketch.ensure_capacity(skt_capacity);
455 self.frequency_sketch_enabled = true;
456 }
457
458 #[inline]
459 fn handle_insert(&mut self, key: Rc<K>, hash: u64, policy_weight: u32) {
460 debug_assert_eq!(policy_weight, 1);
461 let has_free_space = self.has_enough_capacity(policy_weight, self.entry_count);
462 let (cache, deqs, freq) = (&mut self.cache, &mut self.deques, &self.frequency_sketch);
463
464 if has_free_space {
465 let key = Rc::clone(&key);
467 let entry = cache.get_mut(&key).unwrap();
468 deqs.push_back_ao(
469 CacheRegion::MainProbation,
470 KeyHashDate::new(Rc::clone(&key), hash),
471 entry,
472 );
473 self.entry_count += 1;
474 if self.should_enable_frequency_sketch() {
477 self.enable_frequency_sketch();
478 }
479
480 return;
481 }
482
483 if let Some(max) = self.max_capacity {
484 if policy_weight as u64 > max {
485 cache.remove(&Rc::clone(&key));
487 return;
488 }
489 }
490
491 let candidate_freq = freq.frequency(hash);
492
493 match Self::admit(candidate_freq, deqs, freq) {
494 AdmissionResult::Admitted { victim_node } => {
495 let mut vic_entry = cache
497 .remove(unsafe { &victim_node.as_ref().element.key })
498 .expect("Cannot remove a victim from the hash map");
499 deqs.unlink_ao(&mut vic_entry);
500 self.entry_count -= 1;
501
502 let entry = cache.get_mut(&key).unwrap();
504 let key = Rc::clone(&key);
505 deqs.push_back_ao(
506 CacheRegion::MainProbation,
507 KeyHashDate::new(Rc::clone(&key), hash),
508 entry,
509 );
510
511 self.entry_count += 1;
512 if self.should_enable_frequency_sketch() {
516 self.enable_frequency_sketch();
517 }
518 }
519 AdmissionResult::Rejected => {
520 cache.remove(&key);
522 }
523 }
524 }
525
526 #[inline]
536 fn admit(candidate_freq: u8, deqs: &Deques<K>, freq: &FrequencySketch) -> AdmissionResult<K> {
537 let Some(victim_node) = deqs.probation.peek_front_ptr() else {
538 return AdmissionResult::Rejected;
539 };
540 let victim_hash = unsafe { victim_node.as_ref() }.element.hash;
541 let victim_freq = freq.frequency(victim_hash);
542
543 if candidate_freq > victim_freq {
547 AdmissionResult::Admitted { victim_node }
548 } else {
549 AdmissionResult::Rejected
550 }
551 }
552
553 fn handle_update(&mut self, key: Rc<K>, policy_weight: u32, old_entry: ValueEntry<K, V>) {
554 let entry = self.cache.get_mut(&key).unwrap();
555 entry.replace_deq_nodes_with(old_entry);
556 entry.set_policy_weight(policy_weight);
557
558 let deqs = &mut self.deques;
559 deqs.move_to_back_ao(entry);
560
561 }
564
565 #[inline]
566 fn evict_lru_entries(&mut self) {
567 const DEQ_NAME: &str = "probation";
568
569 let weights_to_evict = self.weights_to_evict();
570 let mut evicted_count = 0u64;
571 let mut evicted_policy_weight = 0u64;
572
573 {
574 let deqs = &mut self.deques;
575 let (probation, cache) = (&mut deqs.probation, &mut self.cache);
576
577 for _ in 0..EVICTION_BATCH_SIZE {
578 if evicted_policy_weight >= weights_to_evict {
579 break;
580 }
581
582 #[allow(clippy::map_clone)]
585 let key = probation
586 .peek_front()
587 .map(|node| Rc::clone(&node.element.key));
588
589 if key.is_none() {
590 break;
591 }
592 let key = key.unwrap();
593
594 if let Some(mut entry) = cache.remove(&key) {
595 let weight = entry.policy_weight();
596 Deques::unlink_ao_from_deque(DEQ_NAME, probation, &mut entry);
597 evicted_count += 1;
598 evicted_policy_weight = evicted_policy_weight.saturating_add(weight as u64);
599 } else {
600 probation.pop_front();
601 }
602 }
603 }
604
605 self.entry_count -= evicted_count;
606 }
608}
609
// Test-only helpers (e.g. `enable_frequency_sketch_for_testing`) are declared with
// `#[cfg(test)]` directly on their methods in the private `impl Cache` block.
620
621type AoqNode<K> = NonNull<DeqNode<KeyHashDate<K>>>;
623
624enum AdmissionResult<K> {
625 Admitted { victim_node: AoqNode<K> },
626 Rejected,
627}
628
#[cfg(test)]
mod tests {
    use super::Cache;

    /// Exercises hits, misses, frequency-based admission and invalidation.
    #[test]
    fn basic_single_thread() {
        let mut cache = Cache::new(3);
        cache.enable_frequency_sketch_for_testing();

        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.get(&"a"), Some(&"alice"));
        assert!(cache.contains_key(&"a"));
        assert!(cache.contains_key(&"b"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));

        cache.insert("c", "cindy");
        assert_eq!(cache.get(&"c"), Some(&"cindy"));
        assert!(cache.contains_key(&"c"));
        assert!(cache.contains_key(&"a"));
        assert_eq!(cache.get(&"a"), Some(&"alice"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert!(cache.contains_key(&"b"));

        // The cache is full and "d" has never been read, so it loses the
        // frequency comparison against the probation-front victim and is rejected.
        // (`get` counts as a read even on a miss; `contains_key` does not.)
        cache.insert("d", "david");
        assert_eq!(cache.get(&"d"), None);
        assert!(!cache.contains_key(&"d"));

        // One read of "d" is still not strictly more than the victim's.
        cache.insert("d", "david");
        assert!(!cache.contains_key(&"d"));
        assert_eq!(cache.get(&"d"), None);

        // After a second read "d" wins; "c" is evicted to make room.
        cache.insert("d", "dennis");
        assert_eq!(cache.get(&"a"), Some(&"alice"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert_eq!(cache.get(&"c"), None);
        assert_eq!(cache.get(&"d"), Some(&"dennis"));
        assert!(cache.contains_key(&"a"));
        assert!(cache.contains_key(&"b"));
        assert!(!cache.contains_key(&"c"));
        assert!(cache.contains_key(&"d"));

        cache.invalidate(&"b");
        assert_eq!(cache.get(&"b"), None);
        assert!(!cache.contains_key(&"b"));
    }

    #[test]
    fn invalidate_all() {
        let mut cache = Cache::new(100);
        cache.enable_frequency_sketch_for_testing();

        cache.insert("a", "alice");
        cache.insert("b", "bob");
        cache.insert("c", "cindy");
        assert_eq!(cache.get(&"a"), Some(&"alice"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert_eq!(cache.get(&"c"), Some(&"cindy"));
        assert!(cache.contains_key(&"a"));
        assert!(cache.contains_key(&"b"));
        assert!(cache.contains_key(&"c"));

        cache.invalidate_all();

        cache.insert("d", "david");

        assert!(cache.get(&"a").is_none());
        assert!(cache.get(&"b").is_none());
        assert!(cache.get(&"c").is_none());
        assert_eq!(cache.get(&"d"), Some(&"david"));
        assert!(!cache.contains_key(&"a"));
        assert!(!cache.contains_key(&"b"));
        assert!(!cache.contains_key(&"c"));
        assert!(cache.contains_key(&"d"));
    }

    #[test]
    fn invalidate_entries_if() {
        use std::collections::HashSet;

        let mut cache = Cache::new(100);
        cache.enable_frequency_sketch_for_testing();

        cache.insert(0, "alice");
        cache.insert(1, "bob");
        cache.insert(2, "alex");

        assert_eq!(cache.get(&0), Some(&"alice"));
        assert_eq!(cache.get(&1), Some(&"bob"));
        assert_eq!(cache.get(&2), Some(&"alex"));
        assert!(cache.contains_key(&0));
        assert!(cache.contains_key(&1));
        assert!(cache.contains_key(&2));

        let names = ["alice", "alex"].iter().cloned().collect::<HashSet<_>>();
        cache.invalidate_entries_if(move |_k, &v| names.contains(v));

        cache.insert(3, "alice");

        assert!(cache.get(&0).is_none());
        assert!(cache.get(&2).is_none());
        assert_eq!(cache.get(&1), Some(&"bob"));
        assert_eq!(cache.get(&3), Some(&"alice"));

        assert!(!cache.contains_key(&0));
        assert!(cache.contains_key(&1));
        assert!(!cache.contains_key(&2));
        assert!(cache.contains_key(&3));

        assert_eq!(cache.cache.len(), 2);

        cache.invalidate_entries_if(|_k, &v| v == "alice");
        cache.invalidate_entries_if(|_k, &v| v == "bob");

        assert!(cache.get(&1).is_none());
        assert!(cache.get(&3).is_none());

        assert!(!cache.contains_key(&1));
        assert!(!cache.contains_key(&3));

        assert_eq!(cache.cache.len(), 0);
    }

    #[cfg_attr(target_pointer_width = "16", ignore)]
    #[test]
    fn test_skt_capacity_will_not_overflow() {
        let pot = |exp| 2u64.pow(exp);

        let ensure_sketch_len = |max_capacity, len, name| {
            let mut cache = Cache::<u8, u8>::new(max_capacity);
            cache.enable_frequency_sketch_for_testing();
            assert_eq!(cache.frequency_sketch.table_len(), len as usize, "{}", name);
        };

        if cfg!(target_pointer_width = "32") {
            let pot24 = pot(24);
            let pot16 = pot(16);
            ensure_sketch_len(0, 128, "0");
            ensure_sketch_len(128, 128, "128");
            ensure_sketch_len(pot16, pot16, "pot16");
            ensure_sketch_len(pot16 + 1, pot(17), "pot16 + 1");
            ensure_sketch_len(pot24 - 1, pot24, "pot24 - 1");
            ensure_sketch_len(pot24, pot24, "pot24");
            ensure_sketch_len(pot(27), pot24, "pot(27)");
            ensure_sketch_len(u32::MAX as u64, pot24, "u32::MAX");
        } else {
            let pot30 = pot(30);
            let pot16 = pot(16);
            ensure_sketch_len(0, 128, "0");
            ensure_sketch_len(128, 128, "128");
            ensure_sketch_len(pot16, pot16, "pot16");
            ensure_sketch_len(pot16 + 1, pot(17), "pot16 + 1");

            if !cfg!(circleci) {
                ensure_sketch_len(pot30 - 1, pot30, "pot30- 1");
                ensure_sketch_len(pot30, pot30, "pot30");
                ensure_sketch_len(u64::MAX, pot30, "u64::MAX");
            }
        }
    }

    #[test]
    fn remove_decrements_entry_count() {
        let mut cache = Cache::new(3);
        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.entry_count(), 2);

        let removed = cache.remove(&"a");
        assert_eq!(removed, Some("alice"));
        assert_eq!(cache.entry_count(), 1);

        cache.remove(&"nonexistent");
        assert_eq!(cache.entry_count(), 1);

        cache.remove(&"b");
        assert_eq!(cache.entry_count(), 0);
    }

    #[test]
    fn invalidate_decrements_entry_count() {
        let mut cache = Cache::new(3);
        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.entry_count(), 2);

        cache.invalidate(&"a");
        assert_eq!(cache.entry_count(), 1);

        cache.invalidate(&"nonexistent");
        assert_eq!(cache.entry_count(), 1);

        cache.invalidate(&"b");
        assert_eq!(cache.entry_count(), 0);
    }

    #[test]
    fn insert_after_remove_on_full_cache() {
        let mut cache = Cache::new(2);
        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.entry_count(), 2);

        cache.remove(&"a");
        assert_eq!(cache.entry_count(), 1);

        cache.insert("c", "cindy");
        assert_eq!(cache.entry_count(), 2);
        assert_eq!(cache.get(&"c"), Some(&"cindy"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert_eq!(cache.get(&"a"), None);
    }

    #[test]
    fn insert_after_invalidate_on_full_cache() {
        let mut cache = Cache::new(2);
        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.entry_count(), 2);

        cache.invalidate(&"a");
        assert_eq!(cache.entry_count(), 1);

        cache.insert("c", "cindy");
        assert_eq!(cache.entry_count(), 2);
        assert_eq!(cache.get(&"c"), Some(&"cindy"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert_eq!(cache.get(&"a"), None);
    }

    /// Replacing the value of an existing key must not change the entry count.
    #[test]
    fn insert_existing_key_replaces_value() {
        let mut cache = Cache::new(3);
        cache.insert("a", "alice");
        cache.insert("a", "alex");
        assert_eq!(cache.entry_count(), 1);
        assert!(cache.contains_key(&"a"));
        assert_eq!(cache.get(&"a"), Some(&"alex"));
    }

    /// A zero-capacity cache can never store anything.
    #[test]
    fn zero_capacity_rejects_every_insert() {
        let mut cache = Cache::new(0);
        cache.insert("a", "alice");
        assert_eq!(cache.entry_count(), 0);
        assert!(!cache.contains_key(&"a"));
        assert_eq!(cache.get(&"a"), None);
    }

    /// A panicking value `Drop` during `invalidate_all` must still leave the
    /// cache empty and usable.
    #[test]
    fn invalidate_all_panic_safety() {
        use std::panic::catch_unwind;
        use std::panic::AssertUnwindSafe;
        use std::sync::atomic::{AtomicU32, Ordering};

        static DROP_COUNT: AtomicU32 = AtomicU32::new(0);

        struct PanicOnDrop {
            id: u32,
            should_panic: bool,
        }

        impl Drop for PanicOnDrop {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
                if self.should_panic {
                    panic!("intentional panic in drop for id={}", self.id);
                }
            }
        }

        DROP_COUNT.store(0, Ordering::Relaxed);
        let mut cache = Cache::new(10);
        cache.insert(
            1,
            PanicOnDrop {
                id: 1,
                should_panic: false,
            },
        );
        cache.insert(
            2,
            PanicOnDrop {
                id: 2,
                should_panic: true,
            },
        );
        cache.insert(
            3,
            PanicOnDrop {
                id: 3,
                should_panic: false,
            },
        );
        assert_eq!(cache.entry_count(), 3);

        let result = catch_unwind(AssertUnwindSafe(|| {
            cache.invalidate_all();
        }));
        assert!(result.is_err());

        assert_eq!(cache.entry_count(), 0);
        assert_eq!(cache.cache.len(), 0);

        cache.insert(
            4,
            PanicOnDrop {
                id: 4,
                should_panic: false,
            },
        );
        assert_eq!(cache.entry_count(), 1);
        assert!(cache.contains_key(&4));
    }

    #[test]
    fn test_debug_format() {
        let mut cache = Cache::new(10);
        cache.insert('a', "alice");
        cache.insert('b', "bob");
        cache.insert('c', "cindy");

        let debug_str = format!("{:?}", cache);
        assert!(debug_str.starts_with('{'));
        assert!(debug_str.contains(r#"'a': "alice""#));
        assert!(debug_str.contains(r#"'b': "bob""#));
        assert!(debug_str.contains(r#"'c': "cindy""#));
        assert!(debug_str.ends_with('}'));
    }
}