1use super::{
2 deques::Deques, CacheBuilder, Iter, KeyHashDate, ValueEntry,
3};
4use crate::{
5 common::{
6 self,
7 deque::{DeqNode},
8 frequency_sketch::FrequencySketch,
9 CacheRegion,
10 },
11 Policy,
12};
13
14use smallvec::SmallVec;
15use std::{
16 borrow::Borrow,
17 collections::{hash_map::RandomState, HashMap},
18 fmt,
19 hash::{BuildHasher, Hash},
20 ptr::NonNull,
21 rc::Rc,
22};
23
/// Upper bound on the number of loop iterations performed by a single call to
/// `Cache::evict_lru_entries`.
const EVICTION_BATCH_SIZE: usize = 100;

/// The hash map backing the cache: keys are stored as `Rc<K>` and every value
/// is wrapped in a `ValueEntry`.
type CacheStore<K, V, S> = std::collections::HashMap<Rc<K>, ValueEntry<K, V>, S>;
27
/// An in-memory cache with an optional capacity limit.
///
/// Keys are stored behind `Rc`, so a `Cache` cannot be sent to or shared with
/// other threads. Most operations take `&mut self` because they run a batch of
/// pending evictions before doing their own work.
pub struct Cache<K, V, S = RandomState> {
    /// Capacity limit compared against `entry_count`; with `None` the capacity
    /// checks always pass and nothing is evicted for size.
    max_capacity: Option<u64>,
    /// Running count of entries tracked by the cache. Also reported by
    /// `weighted_size`.
    entry_count: u64,
    /// The key/value store.
    cache: CacheStore<K, V, S>,
    /// Hasher factory: cloned into `cache` at construction and used by `hash`
    /// to compute the hashes fed to the frequency sketch.
    build_hasher: S,
    /// Deques that order the entries; eviction and admission pick victims from
    /// the front of the probation deque.
    deques: Deques<K>,
    /// Frequency estimator consulted by the admission policy and incremented
    /// on every `get`.
    frequency_sketch: FrequencySketch,
    /// Set once the sketch has been sized by `do_enable_frequency_sketch`.
    frequency_sketch_enabled: bool,
}
110
111impl<K, V, S> fmt::Debug for Cache<K, V, S>
112where
113 K: fmt::Debug + Eq + Hash,
114 V: fmt::Debug,
115 S: BuildHasher + Clone,
117{
118 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119 let mut d_map = f.debug_map();
120
121 for (k, v) in self.iter() {
122 d_map.entry(&k, &v);
123 }
124
125 d_map.finish()
126 }
127}
128
129impl<K, V> Cache<K, V, RandomState>
130where
131 K: Hash + Eq,
132{
133 pub fn new(max_capacity: u64) -> Self {
140 let build_hasher = RandomState::default();
141 Self::with_everything(Some(max_capacity), None, build_hasher)
142 }
143
144 pub fn builder() -> CacheBuilder<K, V, Cache<K, V, RandomState>> {
149 CacheBuilder::default()
150 }
151}
152
153impl<K, V, S> Cache<K, V, S> {
157 pub fn policy(&self) -> Policy {
162 Policy::new(self.max_capacity)
163 }
164
165 pub fn entry_count(&self) -> u64 {
185 self.entry_count
186 }
187
188 pub fn weighted_size(&self) -> u64 {
192 self.entry_count
193 }
194}
195
196impl<K, V, S> Cache<K, V, S>
197where
198 K: Hash + Eq,
199 S: BuildHasher + Clone,
200{
201 pub(crate) fn with_everything(
202 max_capacity: Option<u64>,
203 initial_capacity: Option<usize>,
204 build_hasher: S,
205 ) -> Self {
206 let cache = HashMap::with_capacity_and_hasher(
207 initial_capacity.unwrap_or_default(),
208 build_hasher.clone(),
209 );
210
211 Self {
212 max_capacity,
213 entry_count: 0,
214 cache,
215 build_hasher,
216 deques: Default::default(),
217 frequency_sketch: Default::default(),
218 frequency_sketch_enabled: false,
219 }
220 }
221
222 pub fn contains_key<Q>(&mut self, key: &Q) -> bool
230 where
231 Rc<K>: Borrow<Q>,
232 Q: Hash + Eq + ?Sized,
233 {
234 self.evict_lru_entries();
235 self.cache.contains_key(key)
236 }
237
238 pub fn get<Q>(&mut self, key: &Q) -> Option<&V>
243 where
244 Rc<K>: Borrow<Q>,
245 Q: Hash + Eq + ?Sized,
246 {
247 self.evict_lru_entries();
248 self.frequency_sketch.increment(self.hash(key));
249
250 if let Some(entry) = self.cache.get_mut(key) {
251 Self::record_hit(&mut self.deques, entry);
252 Some(&entry.value)
253 } else {
254 None
255 }
256 }
257
258 pub fn insert(&mut self, key: K, value: V) {
262 self.evict_lru_entries();
263 let policy_weight = 1;
264 let key = Rc::new(key);
265 let entry = ValueEntry::new(value);
266
267 if let Some(old_entry) = self.cache.insert(Rc::clone(&key), entry) {
268 self.handle_update(key, policy_weight, old_entry);
269 } else {
270 let hash = self.hash(&key);
271 self.handle_insert(key, hash, policy_weight);
272 }
273 }
274
275 pub fn invalidate<Q>(&mut self, key: &Q)
280 where
281 Rc<K>: Borrow<Q>,
282 Q: Hash + Eq + ?Sized,
283 {
284 self.evict_lru_entries();
285
286 if let Some(mut entry) = self.cache.remove(key) {
287 self.deques.unlink_ao(&mut entry);
288 }
289 }
290
291 pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
296 where
297 Rc<K>: Borrow<Q>,
298 Q: Hash + Eq + ?Sized,
299 {
300 self.evict_lru_entries();
301
302 if let Some(mut entry) = self.cache.remove(key) {
303 self.deques.unlink_ao(&mut entry);
304 Some(entry.value)
305 } else {
306 None
307 }
308 }
309
310 pub fn invalidate_all(&mut self) {
316 self.cache.clear();
317 self.deques.clear();
318 self.entry_count = 0;
319 }
320
321 #[allow(clippy::needless_collect)]
336 pub fn invalidate_entries_if(&mut self, mut predicate: impl FnMut(&K, &V) -> bool) {
337 let Self { cache, deques, .. } = self;
338
339 let keys_to_invalidate = cache
345 .iter()
346 .filter(|(key, entry)| (predicate)(key, &entry.value))
347 .map(|(key, _)| Rc::clone(key))
348 .collect::<Vec<_>>();
349
350 let mut invalidated = 0u64;
351
352 keys_to_invalidate.into_iter().for_each(|k| {
353 if let Some(mut entry) = cache.remove(&k) {
354 let _weight = entry.policy_weight();
355 deques.unlink_ao(&mut entry);
356 invalidated += 1;
357 }
358 });
359 self.entry_count -= invalidated;
360 }
361
362 pub fn iter(&self) -> Iter<'_, K, V> {
385 Iter::new(self, self.cache.iter())
386 }
387}
388
389impl<K, V, S> Cache<K, V, S>
393where
394 K: Hash + Eq,
395 S: BuildHasher + Clone,
396{
397 #[inline]
398 fn hash<Q>(&self, key: &Q) -> u64
399 where
400 Rc<K>: Borrow<Q>,
401 Q: Hash + Eq + ?Sized,
402 {
403 self.build_hasher.hash_one(key)
404 }
405
406 fn record_hit(deques: &mut Deques<K>, entry: &mut ValueEntry<K, V>) {
407 deques.move_to_back_ao(entry)
408 }
409
410 fn has_enough_capacity(&self, candidate_weight: u32, ws: u64) -> bool {
411 self.max_capacity
412 .map(|limit| ws + candidate_weight as u64 <= limit)
413 .unwrap_or(true)
414 }
415
416 fn weights_to_evict(&self) -> u64 {
417 self.max_capacity
418 .map(|limit| self.entry_count.saturating_sub(limit))
419 .unwrap_or_default()
420 }
421
422 #[inline]
423 fn should_enable_frequency_sketch(&self) -> bool {
424 if self.frequency_sketch_enabled {
425 false
426 } else if let Some(max_cap) = self.max_capacity {
427 self.entry_count >= max_cap / 2
428 } else {
429 false
430 }
431 }
432
433 #[inline]
434 fn enable_frequency_sketch(&mut self) {
435 if let Some(max_cap) = self.max_capacity {
436 self.do_enable_frequency_sketch(max_cap);
437 }
438 }
439
440 #[cfg(test)]
441 fn enable_frequency_sketch_for_testing(&mut self) {
442 if let Some(max_cap) = self.max_capacity {
443 self.do_enable_frequency_sketch(max_cap);
444 }
445 }
446
447 #[inline]
448 fn do_enable_frequency_sketch(&mut self, cache_capacity: u64) {
449 let skt_capacity = common::sketch_capacity(cache_capacity);
450 self.frequency_sketch.ensure_capacity(skt_capacity);
451 self.frequency_sketch_enabled = true;
452 }
453
454 #[inline]
455 fn handle_insert(
456 &mut self,
457 key: Rc<K>,
458 hash: u64,
459 policy_weight: u32,
460 ) {
461 let has_free_space = self.has_enough_capacity(policy_weight, self.entry_count);
462 let (cache, deqs, freq) = (&mut self.cache, &mut self.deques, &self.frequency_sketch);
463
464 if has_free_space {
465 let key = Rc::clone(&key);
467 let entry = cache.get_mut(&key).unwrap();
468 deqs.push_back_ao(
469 CacheRegion::MainProbation,
470 KeyHashDate::new(Rc::clone(&key), hash),
471 entry,
472 );
473 self.entry_count += 1;
474 if self.should_enable_frequency_sketch() {
477 self.enable_frequency_sketch();
478 }
479
480 return;
481 }
482
483 if let Some(max) = self.max_capacity {
484 if policy_weight as u64 > max {
485 cache.remove(&Rc::clone(&key));
487 return;
488 }
489 }
490
491 let mut candidate = EntrySizeAndFrequency::new(policy_weight as u64);
492 candidate.add_frequency(freq, hash);
493
494 match Self::admit(&candidate, cache, deqs, freq) {
495 AdmissionResult::Admitted {
496 victim_nodes,
497 } => {
498 for victim in victim_nodes {
500 let mut vic_entry = cache
502 .remove(unsafe { &victim.as_ref().element.key })
503 .expect("Cannot remove a victim from the hash map");
504 deqs.unlink_ao(&mut vic_entry);
506 self.entry_count -= 1;
508 }
509
510 let entry = cache.get_mut(&key).unwrap();
512 let key = Rc::clone(&key);
513 deqs.push_back_ao(
514 CacheRegion::MainProbation,
515 KeyHashDate::new(Rc::clone(&key), hash),
516 entry,
517 );
518
519 self.entry_count += 1;
520 if self.should_enable_frequency_sketch() {
524 self.enable_frequency_sketch();
525 }
526 }
527 AdmissionResult::Rejected => {
528 cache.remove(&key);
530 }
531 }
532 }
533
534 #[inline]
552 fn admit(
553 candidate: &EntrySizeAndFrequency,
554 _cache: &CacheStore<K, V, S>,
555 deqs: &Deques<K>,
556 freq: &FrequencySketch,
557 ) -> AdmissionResult<K> {
558 let mut victims = EntrySizeAndFrequency::default();
559 let mut victim_nodes = SmallVec::default();
560
561 let mut next_victim = deqs.probation.peek_front_ptr();
563
564 while victims.weight < candidate.weight {
566 if candidate.freq < victims.freq {
567 break;
568 }
569 if let Some(victim) = next_victim.take() {
570 next_victim = DeqNode::next_node_ptr(victim);
571 let vic_elem = &unsafe { victim.as_ref() }.element;
572
573 victims.add_policy_weight();
577 victims.add_frequency(freq, vic_elem.hash);
578 victim_nodes.push(victim);
579 } else {
580 break;
582 }
583 }
584
585 if victims.weight >= candidate.weight && candidate.freq > victims.freq {
591 AdmissionResult::Admitted {
592 victim_nodes,
593 }
594 } else {
595 AdmissionResult::Rejected
596 }
597 }
598
599 fn handle_update(
600 &mut self,
601 key: Rc<K>,
602 policy_weight: u32,
603 old_entry: ValueEntry<K, V>,
604 ) {
605 let entry = self.cache.get_mut(&key).unwrap();
606 entry.replace_deq_nodes_with(old_entry);
607 entry.set_policy_weight(policy_weight);
608
609 let deqs = &mut self.deques;
610 deqs.move_to_back_ao(entry);
611
612 }
615
616 #[inline]
617 fn evict_lru_entries(&mut self) {
618 const DEQ_NAME: &str = "probation";
619
620 let weights_to_evict = self.weights_to_evict();
621 let mut evicted_count = 0u64;
622 let mut evicted_policy_weight = 0u64;
623
624 {
625 let deqs = &mut self.deques;
626 let (probation, cache) =
627 (&mut deqs.probation, &mut self.cache);
628
629 for _ in 0..EVICTION_BATCH_SIZE {
630 if evicted_policy_weight >= weights_to_evict {
631 break;
632 }
633
634 #[allow(clippy::map_clone)]
637 let key = probation
638 .peek_front()
639 .map(|node| Rc::clone(&node.element.key));
640
641 if key.is_none() {
642 break;
643 }
644 let key = key.unwrap();
645
646 if let Some(mut entry) = cache.remove(&key) {
647 let weight = entry.policy_weight();
648 Deques::unlink_ao_from_deque(DEQ_NAME, probation, &mut entry);
649 evicted_count += 1;
650 evicted_policy_weight = evicted_policy_weight.saturating_add(weight as u64);
651 } else {
652 probation.pop_front();
653 }
654 }
655 }
656
657 self.entry_count -= evicted_count;
658 }
660}
661
// Test-only inherent impl block for `Cache`; it currently declares no items.
#[cfg(test)]
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq,
    S: BuildHasher + Clone,
{
}
672
673#[derive(Default)]
674struct EntrySizeAndFrequency {
675 weight: u64,
676 freq: u32,
677}
678
679impl EntrySizeAndFrequency {
680 fn new(policy_weight: u64) -> Self {
681 Self {
682 weight: policy_weight,
683 ..Default::default()
684 }
685 }
686
687 fn add_policy_weight(&mut self) {
688 self.weight += 1;
689 }
690
691 fn add_frequency(&mut self, freq: &FrequencySketch, hash: u64) {
692 self.freq += freq.frequency(hash) as u32;
693 }
694}
695
/// Raw pointer to a deque node holding a `KeyHashDate`.
type AoqNode<K> = NonNull<DeqNode<KeyHashDate<K>>>;

/// Outcome of `Cache::admit`.
enum AdmissionResult<K> {
    /// The candidate is admitted; `victim_nodes` are the deque nodes of the
    /// entries the caller must evict to make room.
    Admitted {
        victim_nodes: SmallVec<[AoqNode<K>; 8]>,
    },
    /// The candidate is rejected and must not stay in the cache.
    Rejected,
}
705
#[cfg(test)]
mod tests {
    use super::Cache;

    /// Exercises insert/get/contains_key on a cache of capacity 3, including
    /// admission rejections and an eventual admission that evicts an entry.
    #[test]
    fn basic_single_thread() {
        let mut cache = Cache::new(3);
        cache.enable_frequency_sketch_for_testing();

        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.get(&"a"), Some(&"alice"));
        assert!(cache.contains_key(&"a"));
        assert!(cache.contains_key(&"b"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        cache.insert("c", "cindy");
        assert_eq!(cache.get(&"c"), Some(&"cindy"));
        assert!(cache.contains_key(&"c"));
        assert!(cache.contains_key(&"a"));
        assert_eq!(cache.get(&"a"), Some(&"alice"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert!(cache.contains_key(&"b"));
        // The cache is full and "d" has never been read, so it is expected to
        // be rejected: admission requires the candidate to be strictly more
        // frequent than its victims.
        cache.insert("d", "david");
        assert_eq!(cache.get(&"d"), None);
        assert!(!cache.contains_key(&"d"));

        // One recorded read of "d" is still expected to be too few.
        cache.insert("d", "david");
        assert!(!cache.contains_key(&"d"));
        assert_eq!(cache.get(&"d"), None);
        // After a second read, "d" is expected to be admitted, evicting "c".
        cache.insert("d", "dennis");
        assert_eq!(cache.get(&"a"), Some(&"alice"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert_eq!(cache.get(&"c"), None);
        assert_eq!(cache.get(&"d"), Some(&"dennis"));
        assert!(cache.contains_key(&"a"));
        assert!(cache.contains_key(&"b"));
        assert!(!cache.contains_key(&"c"));
        assert!(cache.contains_key(&"d"));

        cache.invalidate(&"b");
        assert_eq!(cache.get(&"b"), None);
        assert!(!cache.contains_key(&"b"));
    }

    /// `invalidate_all` drops every existing entry and leaves the cache usable
    /// for new inserts.
    #[test]
    fn invalidate_all() {
        let mut cache = Cache::new(100);
        cache.enable_frequency_sketch_for_testing();

        cache.insert("a", "alice");
        cache.insert("b", "bob");
        cache.insert("c", "cindy");
        assert_eq!(cache.get(&"a"), Some(&"alice"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert_eq!(cache.get(&"c"), Some(&"cindy"));
        assert!(cache.contains_key(&"a"));
        assert!(cache.contains_key(&"b"));
        assert!(cache.contains_key(&"c"));

        cache.invalidate_all();

        cache.insert("d", "david");

        assert!(cache.get(&"a").is_none());
        assert!(cache.get(&"b").is_none());
        assert!(cache.get(&"c").is_none());
        assert_eq!(cache.get(&"d"), Some(&"david"));
        assert!(!cache.contains_key(&"a"));
        assert!(!cache.contains_key(&"b"));
        assert!(!cache.contains_key(&"c"));
        assert!(cache.contains_key(&"d"));
    }

    /// `invalidate_entries_if` removes exactly the entries matching the
    /// predicate and leaves the rest in place.
    #[test]
    fn invalidate_entries_if() {
        use std::collections::HashSet;

        let mut cache = Cache::new(100);
        cache.enable_frequency_sketch_for_testing();

        cache.insert(0, "alice");
        cache.insert(1, "bob");
        cache.insert(2, "alex");

        assert_eq!(cache.get(&0), Some(&"alice"));
        assert_eq!(cache.get(&1), Some(&"bob"));
        assert_eq!(cache.get(&2), Some(&"alex"));
        assert!(cache.contains_key(&0));
        assert!(cache.contains_key(&1));
        assert!(cache.contains_key(&2));

        let names = ["alice", "alex"].iter().cloned().collect::<HashSet<_>>();
        cache.invalidate_entries_if(move |_k, &v| names.contains(v));

        // Inserted after the invalidation, so this "alice" must survive it.
        cache.insert(3, "alice");

        assert!(cache.get(&0).is_none());
        assert!(cache.get(&2).is_none());
        assert_eq!(cache.get(&1), Some(&"bob"));
        assert_eq!(cache.get(&3), Some(&"alice"));

        assert!(!cache.contains_key(&0));
        assert!(cache.contains_key(&1));
        assert!(!cache.contains_key(&2));
        assert!(cache.contains_key(&3));

        assert_eq!(cache.cache.len(), 2);

        cache.invalidate_entries_if(|_k, &v| v == "alice");
        cache.invalidate_entries_if(|_k, &v| v == "bob");

        assert!(cache.get(&1).is_none());
        assert!(cache.get(&3).is_none());

        assert!(!cache.contains_key(&1));
        assert!(!cache.contains_key(&3));

        assert_eq!(cache.cache.len(), 0);
    }

    /// Checks the frequency sketch's table length for a range of cache
    /// capacities, including the largest representable ones.
    #[cfg_attr(target_pointer_width = "16", ignore)]
    #[test]
    fn test_skt_capacity_will_not_overflow() {
        // power of two
        let pot = |exp| 2u64.pow(exp);

        // Builds a cache with `max_capacity`, enables the sketch, and asserts
        // that its table length equals `len`; `name` labels a failure.
        let ensure_sketch_len = |max_capacity, len, name| {
            let mut cache = Cache::<u8, u8>::new(max_capacity);
            cache.enable_frequency_sketch_for_testing();
            assert_eq!(cache.frequency_sketch.table_len(), len as usize, "{}", name);
        };

        if cfg!(target_pointer_width = "32") {
            let pot24 = pot(24);
            let pot16 = pot(16);
            ensure_sketch_len(0, 128, "0");
            ensure_sketch_len(128, 128, "128");
            ensure_sketch_len(pot16, pot16, "pot16");
            ensure_sketch_len(pot16 + 1, pot(17), "pot16 + 1");
            ensure_sketch_len(pot24 - 1, pot24, "pot24 - 1");
            ensure_sketch_len(pot24, pot24, "pot24");
            ensure_sketch_len(pot(27), pot24, "pot(27)");
            ensure_sketch_len(u32::MAX as u64, pot24, "u32::MAX");
        } else {
            let pot30 = pot(30);
            let pot16 = pot(16);
            ensure_sketch_len(0, 128, "0");
            ensure_sketch_len(128, 128, "128");
            ensure_sketch_len(pot16, pot16, "pot16");
            ensure_sketch_len(pot16 + 1, pot(17), "pot16 + 1");

            // NOTE(review): presumably skipped under the `circleci` cfg
            // because these capacities allocate very large sketch tables;
            // confirm against the CI configuration.
            if !cfg!(circleci) {
                ensure_sketch_len(pot30 - 1, pot30, "pot30- 1");
                ensure_sketch_len(pot30, pot30, "pot30");
                ensure_sketch_len(u64::MAX, pot30, "u64::MAX");
            }
        };
    }

    /// The `Debug` output is a brace-delimited map containing every entry.
    #[test]
    fn test_debug_format() {
        let mut cache = Cache::new(10);
        cache.insert('a', "alice");
        cache.insert('b', "bob");
        cache.insert('c', "cindy");

        let debug_str = format!("{:?}", cache);
        assert!(debug_str.starts_with('{'));
        assert!(debug_str.contains(r#"'a': "alice""#));
        assert!(debug_str.contains(r#"'b': "bob""#));
        assert!(debug_str.contains(r#"'c': "cindy""#));
        assert!(debug_str.ends_with('}'));
    }
}