use super::{deques::Deques, CacheBuilder, Iter, KeyHashDate, ValueEntry};
use crate::{
common::{self, deque::DeqNode, frequency_sketch::FrequencySketch, CacheRegion},
Policy,
};
use std::{
borrow::Borrow,
collections::{hash_map::RandomState, HashMap},
fmt,
hash::{BuildHasher, Hash},
ptr::NonNull,
rc::Rc,
};
const EVICTION_BATCH_SIZE: usize = 100;
type CacheStore<K, V, S> = std::collections::HashMap<Rc<K>, ValueEntry<K, V>, S>;
/// A single-threaded (keys are shared through `Rc`), size-bounded cache.
///
/// Values live in `cache`. The access-order deques in `deques` track recency: hits
/// move an entry to the back, and eviction takes from the front. When the cache is
/// full, `frequency_sketch` estimates decide whether a newcomer may replace the
/// least-recently-used entry.
// NOTE(review): field declaration order is also drop order; keep it stable.
pub struct Cache<K, V, S = RandomState> {
    // Maximum number of entries; `None` means unbounded.
    max_capacity: Option<u64>,
    // Number of admitted entries. Every entry weighs 1, so this doubles as the weighted size.
    entry_count: u64,
    // Owns the values; each key `Rc` is also held by the entry's deque node.
    cache: CacheStore<K, V, S>,
    // Kept so `hash` can compute sketch hashes and `invalidate_all` can build a fresh map.
    build_hasher: S,
    // Access-order deques (only the probation region is used in this file).
    deques: Deques<K>,
    // Frequency estimates consulted by `admit`; sized lazily by `do_enable_frequency_sketch`.
    frequency_sketch: FrequencySketch,
    // Set once the sketch has been sized for `max_capacity`.
    frequency_sketch_enabled: bool,
}
impl<K, V, S> fmt::Debug for Cache<K, V, S>
where
    K: fmt::Debug + Eq + Hash,
    V: fmt::Debug,
    S: BuildHasher + Clone,
{
    /// Formats the cache as a debug map of the `(key, value)` pairs yielded by `iter()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_map().entries(self.iter()).finish()
    }
}
impl<K, V> Cache<K, V, RandomState>
where
    K: Hash + Eq,
{
    /// Creates a cache bounded to `max_capacity` entries that uses the standard
    /// library's `RandomState` hasher and no pre-sized backing map.
    pub fn new(max_capacity: u64) -> Self {
        Self::with_everything(Some(max_capacity), None, RandomState::default())
    }

    /// Returns a default [`CacheBuilder`] for a cache that uses `RandomState`.
    pub fn builder() -> CacheBuilder<K, V, Cache<K, V, RandomState>> {
        CacheBuilder::default()
    }
}
impl<K, V, S> Cache<K, V, S> {
    /// Returns a [`Policy`] built from this cache's configured maximum capacity.
    pub fn policy(&self) -> Policy {
        let max_capacity = self.max_capacity;
        Policy::new(max_capacity)
    }

    /// Returns the number of entries currently held by the cache.
    pub fn entry_count(&self) -> u64 {
        self.entry_count
    }

    /// Returns the total weight of the cached entries.
    ///
    /// Every entry weighs 1 in this cache, so this always equals
    /// [`entry_count`](Self::entry_count).
    pub fn weighted_size(&self) -> u64 {
        self.entry_count()
    }
}
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq,
    S: BuildHasher + Clone,
{
    /// Creates an empty cache from its raw configuration.
    ///
    /// `max_capacity` of `None` makes the cache unbounded. `initial_capacity` pre-sizes
    /// the backing hash map (defaults to 0). The hasher is cloned into the map and also
    /// kept so `invalidate_all` can build a replacement map with the same hasher.
    pub(crate) fn with_everything(
        max_capacity: Option<u64>,
        initial_capacity: Option<usize>,
        build_hasher: S,
    ) -> Self {
        let cache = HashMap::with_capacity_and_hasher(
            initial_capacity.unwrap_or_default(),
            build_hasher.clone(),
        );
        Self {
            max_capacity,
            entry_count: 0,
            cache,
            build_hasher,
            deques: Default::default(),
            frequency_sketch: Default::default(),
            frequency_sketch_enabled: false,
        }
    }

    /// Returns `true` if the cache holds a value for `key`.
    ///
    /// Unlike [`get`](Self::get), this is not recorded as an access. The frequency
    /// sketch and the access order are left untouched.
    pub fn contains_key<Q>(&mut self, key: &Q) -> bool
    where
        Rc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.cache.contains_key(key)
    }

    /// Returns a reference to the value stored for `key`, if any.
    ///
    /// A hit bumps the key's estimated frequency and moves the entry to the back of
    /// the access-order deque. A miss does not touch the frequency sketch.
    pub fn get<Q>(&mut self, key: &Q) -> Option<&V>
    where
        Rc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        // Hash first: `self.hash` borrows all of `self`, which would conflict with the
        // mutable borrow of `self.cache` below.
        let hash = self.hash(key);
        let Some(entry) = self.cache.get_mut(key) else {
            // Misses are intentionally not counted in the frequency sketch.
            return None;
        };
        self.frequency_sketch.increment(hash);
        Self::record_hit(&mut self.deques, entry);
        Some(&entry.value)
    }

    /// Inserts `value` under `key`, replacing any existing value.
    ///
    /// Entries above `max_capacity` are evicted first (at most `EVICTION_BATCH_SIZE`
    /// per call). Replacing an existing key always succeeds. A new key is admitted
    /// directly when there is room. Otherwise it must beat the least-recently-used
    /// entry on estimated frequency, or it is dropped again.
    pub fn insert(&mut self, key: K, value: V) {
        self.evict_lru_entries();
        let policy_weight = 1;
        let key = Rc::new(key);
        let entry = ValueEntry::new(value);
        // The map keeps one `Rc`; the other is handed to the update/insert handler.
        if let Some(old_entry) = self.cache.insert(Rc::clone(&key), entry) {
            self.handle_update(key, policy_weight, old_entry);
        } else {
            let hash = self.hash(&key);
            self.handle_insert(key, hash, policy_weight);
        }
    }

    /// Discards the entry for `key`, if present.
    ///
    /// This is [`remove`](Self::remove) without returning the value. It performs the
    /// same eviction and bookkeeping.
    pub fn invalidate<Q>(&mut self, key: &Q)
    where
        Rc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.remove(key);
    }

    /// Removes the entry for `key` and returns its value, or `None` if it was absent.
    ///
    /// Entries above `max_capacity` are evicted first, as in [`insert`](Self::insert).
    pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
    where
        Rc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.evict_lru_entries();
        let mut entry = self.cache.remove(key)?;
        self.deques.unlink_ao(&mut entry);
        self.entry_count -= 1;
        Some(entry.value)
    }

    /// Removes every entry, keeping the backing map's capacity when possible.
    pub fn invalidate_all(&mut self) {
        let old_capacity = self.cache.capacity();
        // Detach the old map and reset all bookkeeping *before* dropping the old
        // entries. A panicking `Drop` impl then still leaves the cache empty and
        // consistent.
        let old_cache = std::mem::replace(
            &mut self.cache,
            HashMap::with_hasher(self.build_hasher.clone()),
        );
        self.deques.clear();
        self.entry_count = 0;
        drop(old_cache);
        // Best effort: if the allocation fails, the map simply grows on demand.
        let _ = self.cache.try_reserve(old_capacity);
    }

    /// Removes every entry for which `predicate(key, value)` returns `true`.
    ///
    /// All matching keys are collected before anything is removed, so a panicking
    /// `predicate` leaves the cache unchanged.
    // The collect is required: `cache` cannot be mutated while it is being iterated.
    #[allow(clippy::needless_collect)]
    pub fn invalidate_entries_if(&mut self, mut predicate: impl FnMut(&K, &V) -> bool) {
        let Self { cache, deques, .. } = self;
        let keys_to_invalidate = cache
            .iter()
            .filter(|(key, entry)| (predicate)(key, &entry.value))
            .map(|(key, _)| Rc::clone(key))
            .collect::<Vec<_>>();

        let mut invalidated = 0u64;
        for key in keys_to_invalidate {
            if let Some(mut entry) = cache.remove(&key) {
                deques.unlink_ao(&mut entry);
                invalidated += 1;
            }
        }
        self.entry_count -= invalidated;
    }

    /// Returns an iterator over the cached entries, in the backing hash map's
    /// (unspecified) order.
    pub fn iter(&self) -> Iter<'_, K, V> {
        Iter::new(self, self.cache.iter())
    }
}
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq,
    S: BuildHasher + Clone,
{
    /// Hashes `key` with the cache's own `BuildHasher`.
    ///
    /// This value is fed to the frequency sketch and stored in deque nodes
    /// (`KeyHashDate`), so every caller must obtain it through this helper.
    #[inline]
    fn hash<Q>(&self, key: &Q) -> u64
    where
        Rc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.build_hasher.hash_one(key)
    }

    /// Records a hit by moving `entry` to the back (most-recent end) of the
    /// access-order deque.
    fn record_hit(deques: &mut Deques<K>, entry: &mut ValueEntry<K, V>) {
        deques.move_to_back_ao(entry)
    }

    /// Returns `true` if a candidate of `candidate_weight` fits on top of the current
    /// weighted size `ws` without exceeding `max_capacity`. Always `true` when unbounded.
    ///
    /// The sum saturates. A `ws` near `u64::MAX` therefore reports "no room" instead
    /// of overflowing (a panic in debug builds, a wrap-around in release builds).
    fn has_enough_capacity(&self, candidate_weight: u32, ws: u64) -> bool {
        self.max_capacity
            .map(|limit| ws.saturating_add(u64::from(candidate_weight)) <= limit)
            .unwrap_or(true)
    }

    /// Returns how much weight must be evicted to get back within `max_capacity`
    /// (`0` when already within bounds or unbounded).
    fn weights_to_evict(&self) -> u64 {
        self.max_capacity
            .map(|limit| self.entry_count.saturating_sub(limit))
            .unwrap_or_default()
    }

    /// Returns `true` when the sketch has not been sized yet and the cache is at least
    /// half full. Unbounded caches never enable the sketch.
    #[inline]
    fn should_enable_frequency_sketch(&self) -> bool {
        if self.frequency_sketch_enabled {
            false
        } else if let Some(max_cap) = self.max_capacity {
            self.entry_count >= max_cap / 2
        } else {
            false
        }
    }

    /// Sizes the frequency sketch for `max_capacity`. Does nothing for unbounded caches.
    #[inline]
    fn enable_frequency_sketch(&mut self) {
        if let Some(max_cap) = self.max_capacity {
            self.do_enable_frequency_sketch(max_cap);
        }
    }

    /// Test hook: enables the sketch immediately, regardless of how full the cache is.
    #[cfg(test)]
    fn enable_frequency_sketch_for_testing(&mut self) {
        self.enable_frequency_sketch();
    }

    /// Grows the sketch to `common::sketch_capacity(cache_capacity)` and marks it enabled.
    #[inline]
    fn do_enable_frequency_sketch(&mut self, cache_capacity: u64) {
        let skt_capacity = common::sketch_capacity(cache_capacity);
        self.frequency_sketch.ensure_capacity(skt_capacity);
        self.frequency_sketch_enabled = true;
    }

    /// Finishes inserting a brand-new key that `insert` has already put into `cache`.
    ///
    /// If there is room, the entry is linked into the probation deque. Otherwise the
    /// candidate competes with the least-recently-used probation entry (see `admit`).
    /// If it wins, that victim is evicted and the candidate linked in. If it loses, or
    /// can never fit, the candidate is removed from `cache` again.
    #[inline]
    fn handle_insert(&mut self, key: Rc<K>, hash: u64, policy_weight: u32) {
        debug_assert_eq!(policy_weight, 1);
        let has_free_space = self.has_enough_capacity(policy_weight, self.entry_count);
        let (cache, deqs, freq) = (&mut self.cache, &mut self.deques, &self.frequency_sketch);

        if has_free_space {
            let entry = cache.get_mut(&key).unwrap();
            // `key` moves into the deque node; the map already holds its own `Rc`.
            deqs.push_back_ao(
                CacheRegion::MainProbation,
                KeyHashDate::new(key, hash),
                entry,
            );
            self.entry_count += 1;
            if self.should_enable_frequency_sketch() {
                self.enable_frequency_sketch();
            }
            return;
        }

        if let Some(max) = self.max_capacity {
            if u64::from(policy_weight) > max {
                // The candidate can never fit in the cache; reject it.
                cache.remove(&key);
                return;
            }
        }

        let candidate_freq = freq.frequency(hash);
        match Self::admit(candidate_freq, deqs, freq) {
            AdmissionResult::Admitted { victim_node } => {
                // SAFETY: `admit` just read `victim_node` from the front of
                // `deqs.probation`, and nothing has unlinked it since. The node, and the
                // key it stores, stay alive until `unlink_ao` below.
                let mut vic_entry = cache
                    .remove(unsafe { &victim_node.as_ref().element.key })
                    .expect("Cannot remove a victim from the hash map");
                deqs.unlink_ao(&mut vic_entry);
                self.entry_count -= 1;

                let entry = cache.get_mut(&key).unwrap();
                deqs.push_back_ao(
                    CacheRegion::MainProbation,
                    KeyHashDate::new(key, hash),
                    entry,
                );
                self.entry_count += 1;
                if self.should_enable_frequency_sketch() {
                    self.enable_frequency_sketch();
                }
            }
            AdmissionResult::Rejected => {
                cache.remove(&key);
            }
        }
    }

    /// Decides whether a candidate with estimated frequency `candidate_freq` may replace
    /// the least-recently-used entry (the front of the probation deque).
    ///
    /// The candidate is admitted only if it is strictly more frequent than that victim.
    /// A tie, or an empty probation deque, rejects it.
    #[inline]
    fn admit(candidate_freq: u8, deqs: &Deques<K>, freq: &FrequencySketch) -> AdmissionResult<K> {
        let Some(victim_node) = deqs.probation.peek_front_ptr() else {
            return AdmissionResult::Rejected;
        };
        // SAFETY: `peek_front_ptr` just returned this pointer from the deque. We only
        // hold `&Deques`, so nothing can unlink or free the node while we read it.
        let victim_hash = unsafe { victim_node.as_ref() }.element.hash;
        let victim_freq = freq.frequency(victim_hash);
        if candidate_freq > victim_freq {
            AdmissionResult::Admitted { victim_node }
        } else {
            AdmissionResult::Rejected
        }
    }

    /// Finishes replacing the value of a key that was already cached.
    ///
    /// The new entry takes over `old_entry`'s deque node(s) via `replace_deq_nodes_with`,
    /// takes `policy_weight`, and moves to the back of the access-order deque.
    /// `entry_count` is unchanged because the key was already counted.
    fn handle_update(&mut self, key: Rc<K>, policy_weight: u32, old_entry: ValueEntry<K, V>) {
        let entry = self.cache.get_mut(&key).unwrap();
        entry.replace_deq_nodes_with(old_entry);
        entry.set_policy_weight(policy_weight);
        self.deques.move_to_back_ao(entry);
    }

    /// Evicts entries from the front (least-recently-used end) of the probation deque
    /// until the cache is back within `max_capacity`. Inspects at most
    /// `EVICTION_BATCH_SIZE` nodes per call.
    #[inline]
    fn evict_lru_entries(&mut self) {
        const DEQ_NAME: &str = "probation";

        let weights_to_evict = self.weights_to_evict();
        let mut evicted_count = 0u64;
        let mut evicted_policy_weight = 0u64;

        {
            let (probation, cache) = (&mut self.deques.probation, &mut self.cache);

            for _ in 0..EVICTION_BATCH_SIZE {
                if evicted_policy_weight >= weights_to_evict {
                    break;
                }

                // Clone the key out so the shared borrow of `probation` ends before it is
                // mutated below. clippy reports `map_clone` here, but this clones the
                // node's inner `Rc`, not the mapped item itself.
                #[allow(clippy::map_clone)]
                let Some(key) = probation
                    .peek_front()
                    .map(|node| Rc::clone(&node.element.key))
                else {
                    break;
                };

                if let Some(mut entry) = cache.remove(&key) {
                    let weight = entry.policy_weight();
                    Deques::unlink_ao_from_deque(DEQ_NAME, probation, &mut entry);
                    evicted_count += 1;
                    evicted_policy_weight = evicted_policy_weight.saturating_add(weight as u64);
                } else {
                    // Defensive: the node's key is no longer in the map, so drop the node.
                    probation.pop_front();
                }
            }
        }

        self.entry_count -= evicted_count;
    }
}
// Test-only inherent methods for `Cache`. The block currently declares no items.
// NOTE(review): since it is empty, it can likely be removed.
#[cfg(test)]
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq,
    S: BuildHasher + Clone,
{
}
/// Non-null pointer to a node of the probation (access-order) deque.
type AoqNode<K> = NonNull<DeqNode<KeyHashDate<K>>>;

/// Outcome of the frequency-based admission check in `Cache::admit`.
enum AdmissionResult<K> {
    /// The candidate won; `victim_node` is the probation node to evict in its place.
    Admitted {
        victim_node: AoqNode<K>,
    },
    /// The candidate lost, or there was no victim to compare against, and must be dropped.
    Rejected,
}
#[cfg(test)]
mod tests {
    use super::Cache;

    /// Exercises admission and eviction on a 3-entry cache.
    ///
    /// The trailing `// x: n` comments track the expected *estimated* frequency of each
    /// key. Only hits (`get` returning `Some`) bump it, and a candidate is admitted into a
    /// full cache only if it is strictly more frequent than the LRU probation entry.
    #[test]
    fn basic_single_thread() {
        let mut cache = Cache::new(3);
        cache.enable_frequency_sketch_for_testing();

        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.get(&"a"), Some(&"alice")); // a: 1
        assert!(cache.contains_key(&"a"));
        assert!(cache.contains_key(&"b"));
        assert_eq!(cache.get(&"b"), Some(&"bob")); // b: 1
        cache.insert("c", "cindy");
        assert_eq!(cache.get(&"c"), Some(&"cindy")); // c: 1
        assert!(cache.contains_key(&"c"));
        assert!(cache.contains_key(&"a"));
        assert_eq!(cache.get(&"a"), Some(&"alice")); // a: 2
        assert_eq!(cache.get(&"b"), Some(&"bob")); // b: 2
        assert!(cache.contains_key(&"b"));
        // Access order (LRU first) is now: c, a, b.

        // Full cache: "d" (0) does not beat the LRU victim "c" (1), so it is rejected.
        cache.insert("d", "david");
        // A miss does not bump the frequency, so "d" stays at 0...
        assert_eq!(cache.get(&"d"), None);
        assert!(!cache.contains_key(&"d"));

        // ...and is rejected again.
        cache.insert("d", "david");
        assert!(!cache.contains_key(&"d"));
        assert_eq!(cache.get(&"d"), None);

        // Free a slot so "d" is admitted without competing. Hit it twice, then drop it.
        cache.invalidate(&"c");
        cache.insert("d", "david");
        assert_eq!(cache.get(&"d"), Some(&"david")); // d: 1
        // The second hit matters: "d" needs 2 to beat "c" (1) below.
        assert_eq!(cache.get(&"d"), Some(&"david")); // d: 2
        cache.invalidate(&"d");

        // Refill the cache; hitting "a" and "b" leaves "c" as the LRU entry.
        cache.insert("c", "cindy");
        assert_eq!(cache.get(&"a"), Some(&"alice")); // a: 3
        assert_eq!(cache.get(&"b"), Some(&"bob")); // b: 3

        // "d" (2) now beats the LRU victim "c" (1), so "c" is evicted.
        cache.insert("d", "dennis");
        assert_eq!(cache.get(&"a"), Some(&"alice"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert_eq!(cache.get(&"c"), None);
        assert_eq!(cache.get(&"d"), Some(&"dennis"));
        assert!(cache.contains_key(&"a"));
        assert!(cache.contains_key(&"b"));
        assert!(!cache.contains_key(&"c"));
        assert!(cache.contains_key(&"d"));

        cache.invalidate(&"b");
        assert_eq!(cache.get(&"b"), None);
        assert!(!cache.contains_key(&"b"));
    }

    #[test]
    fn invalidate_all() {
        let mut cache = Cache::new(100);
        cache.enable_frequency_sketch_for_testing();

        cache.insert("a", "alice");
        cache.insert("b", "bob");
        cache.insert("c", "cindy");
        assert_eq!(cache.get(&"a"), Some(&"alice"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert_eq!(cache.get(&"c"), Some(&"cindy"));
        assert!(cache.contains_key(&"a"));
        assert!(cache.contains_key(&"b"));
        assert!(cache.contains_key(&"c"));

        cache.invalidate_all();
        cache.insert("d", "david");

        assert!(cache.get(&"a").is_none());
        assert!(cache.get(&"b").is_none());
        assert!(cache.get(&"c").is_none());
        assert_eq!(cache.get(&"d"), Some(&"david"));
        assert!(!cache.contains_key(&"a"));
        assert!(!cache.contains_key(&"b"));
        assert!(!cache.contains_key(&"c"));
        assert!(cache.contains_key(&"d"));
    }

    #[test]
    fn invalidate_entries_if() {
        use std::collections::HashSet;

        let mut cache = Cache::new(100);
        cache.enable_frequency_sketch_for_testing();

        cache.insert(0, "alice");
        cache.insert(1, "bob");
        cache.insert(2, "alex");
        assert_eq!(cache.get(&0), Some(&"alice"));
        assert_eq!(cache.get(&1), Some(&"bob"));
        assert_eq!(cache.get(&2), Some(&"alex"));
        assert!(cache.contains_key(&0));
        assert!(cache.contains_key(&1));
        assert!(cache.contains_key(&2));

        let names = ["alice", "alex"].iter().cloned().collect::<HashSet<_>>();
        cache.invalidate_entries_if(move |_k, &v| names.contains(v));

        cache.insert(3, "alice");

        assert!(cache.get(&0).is_none());
        assert!(cache.get(&2).is_none());
        assert_eq!(cache.get(&1), Some(&"bob"));
        // This should survive as it was inserted after calling invalidate_entries_if.
        assert_eq!(cache.get(&3), Some(&"alice"));
        assert!(!cache.contains_key(&0));
        assert!(cache.contains_key(&1));
        assert!(!cache.contains_key(&2));
        assert!(cache.contains_key(&3));
        assert_eq!(cache.cache.len(), 2);

        cache.invalidate_entries_if(|_k, &v| v == "alice");
        cache.invalidate_entries_if(|_k, &v| v == "bob");

        assert!(cache.get(&1).is_none());
        assert!(cache.get(&3).is_none());
        assert!(!cache.contains_key(&1));
        assert!(!cache.contains_key(&3));
        assert_eq!(cache.cache.len(), 0);
    }

    #[cfg_attr(target_pointer_width = "16", ignore)]
    #[test]
    fn test_skt_capacity_will_not_overflow() {
        // power of two
        let pot = |exp| 2u64.pow(exp);

        let ensure_sketch_len = |max_capacity, len, name| {
            let mut cache = Cache::<u8, u8>::new(max_capacity);
            cache.enable_frequency_sketch_for_testing();
            assert_eq!(cache.frequency_sketch.table_len(), len as usize, "{}", name);
        };

        if cfg!(target_pointer_width = "32") {
            let pot24 = pot(24);
            let pot16 = pot(16);
            ensure_sketch_len(0, 128, "0");
            ensure_sketch_len(128, 128, "128");
            ensure_sketch_len(pot16, pot16, "pot16");
            // due to ceiling to next_power_of_two
            ensure_sketch_len(pot16 + 1, pot(17), "pot16 + 1");
            // due to ceiling to next_power_of_two
            ensure_sketch_len(pot24 - 1, pot24, "pot24 - 1");
            ensure_sketch_len(pot24, pot24, "pot24");
            ensure_sketch_len(pot(27), pot24, "pot(27)");
            ensure_sketch_len(u32::MAX as u64, pot24, "u32::MAX");
        } else {
            // target_pointer_width: 64 or larger.
            let pot30 = pot(30);
            let pot16 = pot(16);
            ensure_sketch_len(0, 128, "0");
            ensure_sketch_len(128, 128, "128");
            ensure_sketch_len(pot16, pot16, "pot16");
            // due to ceiling to next_power_of_two
            ensure_sketch_len(pot16 + 1, pot(17), "pot16 + 1");

            // The following tests will allocate large memory (~8GiB).
            // Skip when running on Circle CI.
            if !cfg!(circleci) {
                // due to ceiling to next_power_of_two
                ensure_sketch_len(pot30 - 1, pot30, "pot30 - 1");
                ensure_sketch_len(pot30, pot30, "pot30");
                ensure_sketch_len(u64::MAX, pot30, "u64::MAX");
            }
        }
    }

    #[test]
    fn remove_decrements_entry_count() {
        let mut cache = Cache::new(3);
        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.entry_count(), 2);

        let removed = cache.remove(&"a");
        assert_eq!(removed, Some("alice"));
        assert_eq!(cache.entry_count(), 1);

        cache.remove(&"nonexistent");
        assert_eq!(cache.entry_count(), 1);

        cache.remove(&"b");
        assert_eq!(cache.entry_count(), 0);
    }

    #[test]
    fn invalidate_decrements_entry_count() {
        let mut cache = Cache::new(3);
        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.entry_count(), 2);

        cache.invalidate(&"a");
        assert_eq!(cache.entry_count(), 1);

        cache.invalidate(&"nonexistent");
        assert_eq!(cache.entry_count(), 1);

        cache.invalidate(&"b");
        assert_eq!(cache.entry_count(), 0);
    }

    #[test]
    fn insert_after_remove_on_full_cache() {
        let mut cache = Cache::new(2);
        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.entry_count(), 2);

        cache.remove(&"a");
        assert_eq!(cache.entry_count(), 1);

        // A slot was freed, so "c" is admitted without competing.
        cache.insert("c", "cindy");
        assert_eq!(cache.entry_count(), 2);
        assert_eq!(cache.get(&"c"), Some(&"cindy"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert_eq!(cache.get(&"a"), None);
    }

    #[test]
    fn insert_after_invalidate_on_full_cache() {
        let mut cache = Cache::new(2);
        cache.insert("a", "alice");
        cache.insert("b", "bob");
        assert_eq!(cache.entry_count(), 2);

        cache.invalidate(&"a");
        assert_eq!(cache.entry_count(), 1);

        // A slot was freed, so "c" is admitted without competing.
        cache.insert("c", "cindy");
        assert_eq!(cache.entry_count(), 2);
        assert_eq!(cache.get(&"c"), Some(&"cindy"));
        assert_eq!(cache.get(&"b"), Some(&"bob"));
        assert_eq!(cache.get(&"a"), None);
    }

    /// `invalidate_all` must leave the cache empty and usable even when a value's
    /// `Drop` impl panics while the old entries are being dropped.
    #[test]
    fn invalidate_all_panic_safety() {
        use std::panic::catch_unwind;
        use std::panic::AssertUnwindSafe;
        use std::sync::atomic::{AtomicU32, Ordering};

        static DROP_COUNT: AtomicU32 = AtomicU32::new(0);

        struct PanicOnDrop {
            id: u32,
            should_panic: bool,
        }

        impl Drop for PanicOnDrop {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
                if self.should_panic {
                    panic!("intentional panic in drop for id={}", self.id);
                }
            }
        }

        DROP_COUNT.store(0, Ordering::Relaxed);
        let mut cache = Cache::new(10);
        cache.insert(
            1,
            PanicOnDrop {
                id: 1,
                should_panic: false,
            },
        );
        cache.insert(
            2,
            PanicOnDrop {
                id: 2,
                should_panic: true,
            },
        );
        cache.insert(
            3,
            PanicOnDrop {
                id: 3,
                should_panic: false,
            },
        );
        assert_eq!(cache.entry_count(), 3);

        let result = catch_unwind(AssertUnwindSafe(|| {
            cache.invalidate_all();
        }));
        assert!(result.is_err());

        assert_eq!(cache.entry_count(), 0);
        assert_eq!(cache.cache.len(), 0);

        cache.insert(
            4,
            PanicOnDrop {
                id: 4,
                should_panic: false,
            },
        );
        assert_eq!(cache.entry_count(), 1);
        assert!(cache.contains_key(&4));
    }

    #[test]
    fn frequency_sketch_only_increments_on_hit() {
        let mut cache = Cache::new(100);
        cache.enable_frequency_sketch_for_testing();

        // Misses must not be counted.
        let miss_key = "missing";
        for _ in 0..10 {
            assert_eq!(cache.get(&miss_key), None);
        }
        let miss_hash = cache.hash(&miss_key);
        assert_eq!(cache.frequency_sketch.frequency(miss_hash), 0);

        // Each hit is counted once.
        let hit_key = "present";
        cache.insert(hit_key, "value");
        for _ in 0..5 {
            assert_eq!(cache.get(&hit_key), Some(&"value"));
        }
        let hit_hash = cache.hash(&hit_key);
        assert_eq!(cache.frequency_sketch.frequency(hit_hash), 5);
    }

    #[test]
    fn test_debug_format() {
        let mut cache = Cache::new(10);
        cache.insert('a', "alice");
        cache.insert('b', "bob");
        cache.insert('c', "cindy");

        let debug_str = format!("{:?}", cache);
        assert!(debug_str.starts_with('{'));
        assert!(debug_str.contains(r#"'a': "alice""#));
        assert!(debug_str.contains(r#"'b': "bob""#));
        assert!(debug_str.contains(r#"'c': "cindy""#));
        assert!(debug_str.ends_with('}'));
    }
}