use super::{
base_cache::{BaseCache, HouseKeeperArc},
CacheBuilder, ConcurrentCacheExt, EntryRef, Iter,
};
use crate::{
common::concurrent::{
constants::{MAX_SYNC_REPEATS, WRITE_RETRY_INTERVAL_MICROS},
housekeeper::InnerSync,
Weigher, WriteOp,
},
Policy,
};
use crossbeam_channel::{Sender, TrySendError};
use std::{
borrow::Borrow,
collections::hash_map::RandomState,
fmt,
hash::{BuildHasher, Hash},
sync::Arc,
time::Duration,
};
/// A cache handle generic over the key type `K`, the value type `V`, and the
/// hasher builder `S` (`RandomState` unless another one is supplied).
///
/// Every operation implemented in this file is forwarded to the wrapped
/// `BaseCache`.
pub struct Cache<K, V, S = RandomState> {
    /// The underlying cache that all methods delegate to.
    base: BaseCache<K, V, S>,
}
// NOTE(review): this lint is silenced without an explanation in this file.
// Presumably `BaseCache` holds a field that is not automatically `Send`;
// confirm against `BaseCache` and record the justification here.
#[allow(clippy::non_send_fields_in_send_ty)]
// SAFETY: not established by anything visible in this file. The impl asserts
// that `Cache` may be moved across threads whenever `K` and `V` are
// `Send + Sync` and `S` is `Send`; its soundness depends on the internals of
// `BaseCache`, which are defined elsewhere. TODO: confirm and document.
unsafe impl<K, V, S> Send for Cache<K, V, S>
where
    K: Send + Sync,
    V: Send + Sync,
    S: Send,
{
}
// SAFETY: not established by anything visible in this file. The impl asserts
// that `&Cache` may be shared across threads whenever `K` and `V` are
// `Send + Sync` and `S` is `Sync`; its soundness depends on the internals of
// `BaseCache`, which are defined elsewhere. TODO: confirm and document.
unsafe impl<K, V, S> Sync for Cache<K, V, S>
where
    K: Send + Sync,
    V: Send + Sync,
    S: Sync,
{
}
impl<K, V, S> Clone for Cache<K, V, S> {
    /// Creates another `Cache` handle by cloning the wrapped `BaseCache`.
    ///
    /// No bounds are placed on `K`, `V` or `S`, so keys and values themselves
    /// are not required to be `Clone` here.
    ///
    /// NOTE(review): the multi-threaded tests in this file insert through
    /// clones and then read through the original handle, so clones are
    /// expected to observe the same entries — confirm against
    /// `BaseCache::clone`.
    fn clone(&self) -> Self {
        let base = self.base.clone();
        Self { base }
    }
}
impl<K, V, S> fmt::Debug for Cache<K, V, S>
where
    K: Eq + Hash + fmt::Debug,
    V: fmt::Debug,
    S: BuildHasher + Clone,
{
    /// Renders the cache as a debug map (`{key: value, ...}`), emitting one
    /// entry for every item yielded by [`Cache::iter`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut map = f.debug_map();
        self.iter().for_each(|entry| {
            // `pair` hands out the key and value references held by the entry.
            let (key, value) = entry.pair();
            map.entry(key, value);
        });
        map.finish()
    }
}
impl<K, V> Cache<K, V, RandomState>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Creates a cache with the given `max_capacity` and a default
    /// `RandomState` hasher.
    ///
    /// Every other option (initial capacity, weigher, time-to-live and
    /// time-to-idle) is passed as `None`; use [`Cache::builder`] to set them.
    pub fn new(max_capacity: u64) -> Self {
        Self::with_everything(
            Some(max_capacity),
            None,
            RandomState::default(),
            None,
            None,
            None,
        )
    }

    /// Returns a default-constructed [`CacheBuilder`] that produces a
    /// `Cache<K, V, RandomState>`.
    pub fn builder() -> CacheBuilder<K, V, Cache<K, V, RandomState>> {
        CacheBuilder::default()
    }
}
impl<K, V, S> Cache<K, V, S> {
    /// Returns the [`Policy`] reported by the underlying base cache.
    pub fn policy(&self) -> Policy {
        let Self { base } = self;
        base.policy()
    }

    /// Returns the entry count reported by the underlying base cache.
    ///
    /// NOTE(review): the tests in this file call `sync()` before asserting on
    /// this value, which suggests the count may lag behind recent writes
    /// until pending operations are applied — confirm against `BaseCache`.
    pub fn entry_count(&self) -> u64 {
        let Self { base } = self;
        base.entry_count()
    }

    /// Returns the weighted size reported by the underlying base cache.
    ///
    /// NOTE(review): presumably the sum of the weigher results for the
    /// current entries, and subject to the same lag as `entry_count` —
    /// confirm against `BaseCache`.
    pub fn weighted_size(&self) -> u64 {
        let Self { base } = self;
        base.weighted_size()
    }
}
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// Builds a cache from the full set of options by handing them, in the
    /// same order, to `BaseCache::new`.
    pub(crate) fn with_everything(
        max_capacity: Option<u64>,
        initial_capacity: Option<usize>,
        build_hasher: S,
        weigher: Option<Weigher<K, V>>,
        time_to_live: Option<Duration>,
        time_to_idle: Option<Duration>,
    ) -> Self {
        let base = BaseCache::new(
            max_capacity,
            initial_capacity,
            build_hasher,
            weigher,
            time_to_live,
            time_to_idle,
        );
        Self { base }
    }

    /// Returns whatever the base cache's `contains_key` reports for `key`.
    ///
    /// `key` may be any borrowed form `Q` of the stored `Arc<K>`.
    pub fn contains_key<Q>(&self, key: &Q) -> bool
    where
        Arc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.base.contains_key(key)
    }

    /// Looks up `key` and returns an owned value if the base cache has one.
    ///
    /// The key is hashed once and the hash is passed along with the key to
    /// the base cache's `get_with_hash`.
    ///
    /// NOTE(review): given the `V: Clone` bound, the returned value is
    /// presumably a clone of the stored one — confirm against `BaseCache`.
    pub fn get<Q>(&self, key: &Q) -> Option<V>
    where
        Arc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let hash = self.base.hash(key);
        self.base.get_with_hash(key, hash)
    }

    /// Deprecated alias that simply calls [`Cache::get`].
    #[doc(hidden)]
    #[deprecated(since = "0.8.0", note = "Replaced with `get`")]
    pub fn get_if_present<Q>(&self, key: &Q) -> Option<V>
    where
        Arc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.get(key)
    }

    /// Inserts `value` under `key`.
    ///
    /// The key is hashed, wrapped in an `Arc`, and passed to
    /// `insert_with_hash`.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to insert" under the same condition as
    /// `insert_with_hash`.
    pub fn insert(&self, key: K, value: V) {
        // Hash first: `Arc::new` below takes ownership of the key.
        let hash = self.base.hash(&key);
        self.insert_with_hash(Arc::new(key), hash, value);
    }

    /// Inserts `value` under an already-hashed, already-shared key.
    ///
    /// The base cache performs the insert and returns a write operation,
    /// which is then queued on the write-op channel via `schedule_write_op`.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to insert" if `schedule_write_op` returns an
    /// error, which it does only when the channel is disconnected.
    pub(crate) fn insert_with_hash(&self, key: Arc<K>, hash: u64, value: V) {
        let base = &self.base;
        let write_op = base.do_insert_with_hash(key, hash, value);
        Self::schedule_write_op(
            base.inner.as_ref(),
            &base.write_op_ch,
            write_op,
            base.housekeeper.as_ref(),
        )
        .expect("Failed to insert");
    }

    /// Removes the entry for `key`, if the base cache has one.
    ///
    /// When an entry is removed, a `WriteOp::Remove` carrying it is queued on
    /// the write-op channel; when nothing is removed, nothing is queued.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to remove" if `schedule_write_op` returns an
    /// error, which it does only when the channel is disconnected.
    pub fn invalidate<Q>(&self, key: &Q)
    where
        Arc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let base = &self.base;
        if let Some(removed) = base.remove_entry(key) {
            Self::schedule_write_op(
                base.inner.as_ref(),
                &base.write_op_ch,
                WriteOp::Remove(removed),
                base.housekeeper.as_ref(),
            )
            .expect("Failed to remove");
        }
    }

    /// Forwards to the base cache's `invalidate_all`.
    pub fn invalidate_all(&self) {
        self.base.invalidate_all();
    }
}
impl<'a, K, V, S> Cache<K, V, S>
where
    K: 'a + Eq + Hash,
    V: 'a,
    S: BuildHasher + Clone,
{
    /// Returns the iterator produced by the underlying base cache.
    ///
    /// The iterator borrows the cache for as long as it is alive.
    pub fn iter(&self) -> Iter<'_, K, V, S> {
        let Self { base } = self;
        base.iter()
    }
}
impl<K, V, S> ConcurrentCacheExt<K, V> for Cache<K, V, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// Calls `sync` on the base cache's inner state, passing
    /// `MAX_SYNC_REPEATS`.
    ///
    /// NOTE(review): presumably this applies pending read/write operations,
    /// repeating at most `MAX_SYNC_REPEATS` times — confirm against the
    /// `InnerSync` implementation.
    fn sync(&self) {
        let inner = &self.base.inner;
        inner.sync(MAX_SYNC_REPEATS);
    }
}
impl<'a, K, V, S> IntoIterator for &'a Cache<K, V, S>
where
K: 'a + Eq + Hash,
V: 'a,
S: BuildHasher + Clone,
{
type Item = EntryRef<'a, K, V, S>;
type IntoIter = Iter<'a, K, V, S>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// Queues `op` on the write-op channel `ch`, retrying until it is
    /// accepted.
    ///
    /// Each attempt first calls `BaseCache::apply_reads_writes_if_needed` and
    /// then tries a non-blocking send. If the channel is full, the calling
    /// thread sleeps for `WRITE_RETRY_INTERVAL_MICROS` microseconds and tries
    /// again with the same operation. There is no retry limit, so the call
    /// blocks for as long as the channel stays full.
    ///
    /// # Errors
    ///
    /// Returns `TrySendError::Disconnected` carrying the unsent operation if
    /// the channel is disconnected.
    #[inline]
    fn schedule_write_op(
        inner: &impl InnerSync,
        ch: &Sender<WriteOp<K, V>>,
        op: WriteOp<K, V>,
        housekeeper: Option<&HouseKeeperArc<K, V, S>>,
    ) -> Result<(), TrySendError<WriteOp<K, V>>> {
        let retry_interval = Duration::from_micros(WRITE_RETRY_INTERVAL_MICROS);
        let mut pending = op;
        loop {
            BaseCache::apply_reads_writes_if_needed(inner, ch, housekeeper);
            pending = match ch.try_send(pending) {
                Ok(()) => return Ok(()),
                Err(TrySendError::Disconnected(rejected)) => {
                    return Err(TrySendError::Disconnected(rejected));
                }
                // The channel hands the operation back when it is full; keep
                // it for the next attempt.
                Err(TrySendError::Full(rejected)) => rejected,
            };
            std::thread::sleep(retry_interval);
        }
    }
}
#[cfg(test)]
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// Test helper: `true` when the base cache reports an entry count of
    /// zero.
    pub(crate) fn is_table_empty(&self) -> bool {
        let entries = self.base.entry_count();
        entries == 0
    }

    /// Test helper: forwards to the base cache's `reconfigure_for_testing`.
    pub(crate) fn reconfigure_for_testing(&mut self) {
        let Self { base } = self;
        base.reconfigure_for_testing();
    }

    /// Test helper: hands `clock` to the base cache's `set_expiration_clock`.
    pub(crate) fn set_expiration_clock(&self, clock: Option<crate::common::time::Clock>) {
        let Self { base } = self;
        base.set_expiration_clock(clock);
    }
}
#[cfg(test)]
mod tests {
use super::{Cache, ConcurrentCacheExt};
use crate::common::time::Clock;
use std::{sync::Arc, time::Duration};
/// Single-threaded walk-through of `insert`, `get`, `contains_key` and
/// `invalidate` on a cache created with `Cache::new(3)`.
///
/// The order of the `get`, `insert` and `sync` calls is part of the scenario:
/// the expected admission and eviction results depend on it, so do not
/// reorder statements.
#[test]
fn basic_single_thread() {
    let mut cache = Cache::new(3);
    cache.reconfigure_for_testing();
    // Rebind immutably; the cache is only mutated through `&self` from here.
    let cache = cache;
    cache.insert("a", "alice");
    cache.insert("b", "bob");
    assert_eq!(cache.get(&"a"), Some("alice"));
    assert!(cache.contains_key(&"a"));
    assert!(cache.contains_key(&"b"));
    assert_eq!(cache.get(&"b"), Some("bob"));
    cache.sync();
    // `get` calls so far: a = 1, b = 1.
    cache.insert("c", "cindy");
    assert_eq!(cache.get(&"c"), Some("cindy"));
    assert!(cache.contains_key(&"c"));
    cache.sync();
    // `get` calls so far: a = 1, b = 1, c = 1.
    assert!(cache.contains_key(&"a"));
    assert_eq!(cache.get(&"a"), Some("alice"));
    assert_eq!(cache.get(&"b"), Some("bob"));
    assert!(cache.contains_key(&"b"));
    cache.sync();
    // `get` calls so far: a = 2, b = 2, c = 1. The cache now holds three
    // entries, which is the capacity passed to `Cache::new`.
    //
    // The next two inserts of "d" are expected to be rejected.
    // NOTE(review): presumably the admission policy compares how often each
    // key has been accessed — confirm against `BaseCache`.
    cache.insert("d", "david"); cache.sync();
    assert_eq!(cache.get(&"d"), None); assert!(!cache.contains_key(&"d"));
    cache.insert("d", "david");
    cache.sync();
    assert!(!cache.contains_key(&"d"));
    assert_eq!(cache.get(&"d"), None);
    // After two failed `get`s for "d", the third insert is expected to be
    // admitted, with "c" removed to make room.
    cache.insert("d", "dennis");
    cache.sync();
    assert_eq!(cache.get(&"a"), Some("alice"));
    assert_eq!(cache.get(&"b"), Some("bob"));
    assert_eq!(cache.get(&"c"), None);
    assert_eq!(cache.get(&"d"), Some("dennis"));
    assert!(cache.contains_key(&"a"));
    assert!(cache.contains_key(&"b"));
    assert!(!cache.contains_key(&"c"));
    assert!(cache.contains_key(&"d"));
    // Invalidation is expected to be visible immediately, without `sync`.
    cache.invalidate(&"b");
    assert_eq!(cache.get(&"b"), None);
    assert!(!cache.contains_key(&"b"));
}
/// Exercises admission and eviction when entries carry weights.
///
/// The weigher returns the `u32` in each value tuple and the cache is built
/// with `max_capacity(31)`. The order of the `get`, `insert` and `sync` calls
/// is part of the scenario; do not reorder statements.
#[test]
fn size_aware_eviction() {
    // Weight of an entry = second element of its value tuple.
    let weigher = |_k: &&str, v: &(&str, u32)| v.1;
    let alice = ("alice", 10);
    let bob = ("bob", 15);
    let bill = ("bill", 20);
    let cindy = ("cindy", 5);
    let david = ("david", 15);
    let dennis = ("dennis", 15);
    let mut cache = Cache::builder().max_capacity(31).weigher(weigher).build();
    cache.reconfigure_for_testing();
    // Rebind immutably; the cache is only mutated through `&self` from here.
    let cache = cache;
    cache.insert("a", alice);
    cache.insert("b", bob);
    assert_eq!(cache.get(&"a"), Some(alice));
    assert!(cache.contains_key(&"a"));
    assert!(cache.contains_key(&"b"));
    assert_eq!(cache.get(&"b"), Some(bob));
    cache.sync();
    // Weights so far: a = 10, b = 15 (total 25).
    cache.insert("c", cindy);
    assert_eq!(cache.get(&"c"), Some(cindy));
    assert!(cache.contains_key(&"c"));
    cache.sync();
    // Weights so far: a = 10, b = 15, c = 5 (total 30).
    assert!(cache.contains_key(&"a"));
    assert_eq!(cache.get(&"a"), Some(alice));
    assert_eq!(cache.get(&"b"), Some(bob));
    assert!(cache.contains_key(&"b"));
    cache.sync();
    // "d" weighs 15, so admitting it would push the total past 31. The next
    // four inserts of "d" are expected to be rejected.
    // NOTE(review): presumably the admission policy weighs d's access
    // history against that of the entries it would displace — confirm
    // against `BaseCache`.
    cache.insert("d", david); cache.sync();
    assert_eq!(cache.get(&"d"), None); assert!(!cache.contains_key(&"d"));
    cache.insert("d", david);
    cache.sync();
    assert!(!cache.contains_key(&"d"));
    assert_eq!(cache.get(&"d"), None);
    cache.insert("d", david);
    cache.sync();
    assert_eq!(cache.get(&"d"), None); assert!(!cache.contains_key(&"d"));
    cache.insert("d", david);
    cache.sync();
    assert!(!cache.contains_key(&"d"));
    assert_eq!(cache.get(&"d"), None);
    // The fifth insert of "d" is expected to be admitted, with "a" (10) and
    // "c" (5) removed; that leaves b = 15 and d = 15 (total 30).
    cache.insert("d", dennis);
    cache.sync();
    assert_eq!(cache.get(&"a"), None);
    assert_eq!(cache.get(&"b"), Some(bob));
    assert_eq!(cache.get(&"c"), None);
    assert_eq!(cache.get(&"d"), Some(dennis));
    assert!(!cache.contains_key(&"a"));
    assert!(cache.contains_key(&"b"));
    assert!(!cache.contains_key(&"c"));
    assert!(cache.contains_key(&"d"));
    // Replacing b's value raises its weight from 15 to 20; 20 + 15 exceeds
    // 31, and "d" is expected to be the entry that goes.
    cache.insert("b", bill);
    cache.sync();
    assert_eq!(cache.get(&"b"), Some(bill));
    assert_eq!(cache.get(&"d"), None);
    assert!(cache.contains_key(&"b"));
    assert!(!cache.contains_key(&"d"));
    // Re-insert "a" (10) and shrink "b" back to 15: both fit (total 25).
    cache.insert("a", alice);
    cache.insert("b", bob);
    cache.sync();
    assert_eq!(cache.get(&"a"), Some(alice));
    assert_eq!(cache.get(&"b"), Some(bob));
    assert_eq!(cache.get(&"d"), None);
    assert!(cache.contains_key(&"a"));
    assert!(cache.contains_key(&"b"));
    assert!(!cache.contains_key(&"d"));
    // The reported counters must match the two surviving entries.
    assert_eq!(cache.entry_count(), 2);
    assert_eq!(cache.weighted_size(), 25);
}
/// Runs the same insert / get / invalidate sequence from several threads,
/// each working through its own clone of the cache.
///
/// Every thread ends by invalidating key `10` and never removes key `20`, so
/// after all threads are joined key `10` must be absent and key `20` present.
#[test]
fn basic_multi_threads() {
    let num_threads = 4;
    let cache = Cache::new(100);

    // Spawn every worker before joining any of them so that they can run
    // concurrently.
    let mut workers = Vec::new();
    for id in 0..num_threads {
        let cache = cache.clone();
        workers.push(std::thread::spawn(move || {
            cache.insert(10, format!("{}-100", id));
            cache.get(&10);
            cache.insert(20, format!("{}-200", id));
            cache.invalidate(&10);
        }));
    }
    for worker in workers {
        worker.join().expect("Failed");
    }

    assert!(cache.get(&10).is_none());
    assert!(cache.get(&20).is_some());
    assert!(!cache.contains_key(&10));
    assert!(cache.contains_key(&20));
}
/// Checks that `invalidate_all` hides every entry inserted before the call
/// while an entry inserted afterwards stays visible.
#[test]
fn invalidate_all() {
    let mut cache = Cache::new(100);
    cache.reconfigure_for_testing();
    let cache = cache;

    // Entries inserted before the invalidation.
    let initial = [("a", "alice"), ("b", "bob"), ("c", "cindy")];
    for &(key, value) in &initial {
        cache.insert(key, value);
    }
    for &(key, value) in &initial {
        assert_eq!(cache.get(&key), Some(value));
    }
    for &(key, _) in &initial {
        assert!(cache.contains_key(&key));
    }

    cache.invalidate_all();
    cache.sync();

    // An entry inserted after the invalidation must survive it.
    cache.insert("d", "david");
    cache.sync();

    for &(key, _) in &initial {
        assert!(cache.get(&key).is_none());
    }
    assert_eq!(cache.get(&"d"), Some("david"));
    for &(key, _) in &initial {
        assert!(!cache.contains_key(&key));
    }
    assert!(cache.contains_key(&"d"));
}
/// Checks expiration with a 10-second time-to-live, driven by a mock clock.
///
/// The comments track the mock clock's elapsed time. Statement order matters
/// because every `mock.increment` changes what the following assertions
/// expect.
#[test]
fn time_to_live() {
    let mut cache = Cache::builder()
        .max_capacity(100)
        .time_to_live(Duration::from_secs(10))
        .build();
    cache.reconfigure_for_testing();
    // Replace the expiration clock with a mock that only advances on
    // `mock.increment`.
    let (clock, mock) = Clock::mock();
    cache.set_expiration_clock(Some(clock));
    let cache = cache;
    // 0 s: insert "a".
    cache.insert("a", "alice");
    cache.sync();
    // 5 s: "a" is 5 s old and still expected to be present.
    mock.increment(Duration::from_secs(5)); cache.sync();
    assert_eq!(cache.get(&"a"), Some("alice"));
    assert!(cache.contains_key(&"a"));
    // 10 s: "a" has reached its time-to-live. It is expected to be invisible
    // to `get`, `contains_key` and `iter` even before `sync` runs.
    mock.increment(Duration::from_secs(5)); assert_eq!(cache.get(&"a"), None);
    assert!(!cache.contains_key(&"a"));
    assert_eq!(cache.iter().count(), 0);
    cache.sync();
    assert!(cache.is_table_empty());
    // 10 s: insert "b".
    cache.insert("b", "bob");
    cache.sync();
    assert_eq!(cache.entry_count(), 1);
    // 15 s: "b" is 5 s old.
    mock.increment(Duration::from_secs(5)); cache.sync();
    assert_eq!(cache.get(&"b"), Some("bob"));
    assert!(cache.contains_key(&"b"));
    assert_eq!(cache.entry_count(), 1);
    // 15 s: replace b's value. The assertions at 20 s expect this to restart
    // b's time-to-live.
    cache.insert("b", "bill");
    cache.sync();
    // 20 s: 10 s after the first insert of "b", 5 s after the replacement.
    mock.increment(Duration::from_secs(5)); cache.sync();
    assert_eq!(cache.get(&"b"), Some("bill"));
    assert!(cache.contains_key(&"b"));
    assert_eq!(cache.entry_count(), 1);
    // 25 s: 10 s after the replacement, so "b" is expected to be gone too.
    mock.increment(Duration::from_secs(5)); assert_eq!(cache.get(&"a"), None);
    assert_eq!(cache.get(&"b"), None);
    assert!(!cache.contains_key(&"a"));
    assert!(!cache.contains_key(&"b"));
    assert_eq!(cache.iter().count(), 0);
    cache.sync();
    assert!(cache.is_table_empty());
}
/// Checks expiration with a 10-second time-to-idle, driven by a mock clock.
///
/// The comments track the mock clock's elapsed time. Statement order matters
/// because every `mock.increment` and every `get` changes what the following
/// assertions expect.
#[test]
fn time_to_idle() {
    let mut cache = Cache::builder()
        .max_capacity(100)
        .time_to_idle(Duration::from_secs(10))
        .build();
    cache.reconfigure_for_testing();
    // Replace the expiration clock with a mock that only advances on
    // `mock.increment`.
    let (clock, mock) = Clock::mock();
    cache.set_expiration_clock(Some(clock));
    let cache = cache;
    // 0 s: insert "a".
    cache.insert("a", "alice");
    cache.sync();
    // 5 s: read "a". The later assertions expect this read to restart a's
    // idle period.
    mock.increment(Duration::from_secs(5)); cache.sync();
    assert_eq!(cache.get(&"a"), Some("alice"));
    // 10 s: insert "b"; "a" was last read 5 s ago.
    mock.increment(Duration::from_secs(5)); cache.sync();
    cache.insert("b", "bob");
    cache.sync();
    assert_eq!(cache.entry_count(), 2);
    // 12 s: both keys are still reported by `contains_key`.
    // NOTE(review): "a" is expected to be expired at 15 s below, which is
    // 10 s after its last `get`; that implies these `contains_key` calls do
    // not restart the idle period — confirm against `BaseCache`.
    mock.increment(Duration::from_secs(2)); cache.sync();
    assert!(cache.contains_key(&"a"));
    assert!(cache.contains_key(&"b"));
    cache.sync();
    assert_eq!(cache.entry_count(), 2);
    // 15 s: "a" has been idle for 10 s and is expected to be gone; "b" was
    // inserted 5 s ago and is read here.
    mock.increment(Duration::from_secs(3)); assert_eq!(cache.get(&"a"), None);
    assert_eq!(cache.get(&"b"), Some("bob"));
    assert!(!cache.contains_key(&"a"));
    assert!(cache.contains_key(&"b"));
    assert_eq!(cache.iter().count(), 1);
    cache.sync();
    assert_eq!(cache.entry_count(), 1);
    // 25 s: "b" was last read at 15 s, so it is expected to be gone as well.
    mock.increment(Duration::from_secs(10)); assert_eq!(cache.get(&"a"), None);
    assert_eq!(cache.get(&"b"), None);
    assert!(!cache.contains_key(&"a"));
    assert!(!cache.contains_key(&"b"));
    assert_eq!(cache.iter().count(), 0);
    cache.sync();
    assert!(cache.is_table_empty());
}
/// Inserts `NUM_KEYS` entries and checks that iteration yields each key
/// exactly once, paired with the value derived from it.
#[test]
fn test_iter() {
    const NUM_KEYS: usize = 50;

    fn make_value(key: usize) -> String {
        format!("val: {}", key)
    }

    let cache = Cache::builder()
        .max_capacity(100)
        .time_to_idle(Duration::from_secs(10))
        .build();
    (0..NUM_KEYS).for_each(|key| cache.insert(key, make_value(key)));

    // Collecting into a set lets the final length check detect both missing
    // and duplicated keys.
    let mut seen = std::collections::HashSet::new();
    for entry in cache.iter() {
        let (key, value) = entry.pair();
        assert_eq!(value, &make_value(*key));
        seen.insert(*key);
    }
    assert_eq!(seen.len(), NUM_KEYS);
}
/// Iterates over the cache from half of the worker threads while the other
/// half re-insert the same keys, then checks that every key is still present.
///
/// All workers first block on a read lock that the main thread releases only
/// after every worker has been spawned, so they start at about the same time.
#[test]
fn test_iter_multi_threads() {
    use std::collections::HashSet;

    const NUM_KEYS: usize = 1024;
    const NUM_THREADS: usize = 16;

    fn make_value(key: usize) -> String {
        format!("val: {}", key)
    }

    let cache = Cache::builder()
        .max_capacity(2048)
        .time_to_idle(Duration::from_secs(10))
        .build();
    for key in 0..NUM_KEYS {
        cache.insert(key, make_value(key));
    }

    // Holding the write lock keeps every worker parked on `read()`.
    let gate = Arc::new(std::sync::RwLock::<()>::default());
    let start_guard = gate.write().unwrap();

    let mut workers = Vec::new();
    for n in 0..NUM_THREADS {
        let cache = cache.clone();
        let gate = Arc::clone(&gate);
        let worker = if n % 2 == 0 {
            // Even-numbered workers overwrite every key with the same value.
            std::thread::spawn(move || {
                let read_guard = gate.read().unwrap();
                for key in 0..NUM_KEYS {
                    cache.insert(key, make_value(key));
                }
                drop(read_guard);
            })
        } else {
            // Odd-numbered workers iterate and verify what they see.
            std::thread::spawn(move || {
                let read_guard = gate.read().unwrap();
                let mut seen = HashSet::new();
                for entry in cache.iter() {
                    let (key, value) = entry.pair();
                    assert_eq!(value, &make_value(*key));
                    seen.insert(*key);
                }
                assert_eq!(seen.len(), NUM_KEYS);
                drop(read_guard);
            })
        };
        workers.push(worker);
    }

    // Release all workers at once, then wait for them.
    drop(start_guard);
    for worker in workers {
        worker.join().expect("Failed");
    }

    let remaining = cache.iter().map(|ent| *ent.key()).collect::<HashSet<_>>();
    assert_eq!(remaining.len(), NUM_KEYS);
}
/// Checks that the `Debug` output is a brace-delimited map containing every
/// inserted key/value pair. Entry order is not asserted.
#[test]
fn test_debug_format() {
    let cache = Cache::new(10);
    cache.insert('a', "alice");
    cache.insert('b', "bob");
    cache.insert('c', "cindy");

    let rendered = format!("{:?}", cache);
    assert!(rendered.starts_with('{'));
    let expected_fragments = [r#"'a': "alice""#, r#"'b': "bob""#, r#"'c': "cindy""#];
    for fragment in &expected_fragments {
        assert!(rendered.contains(*fragment));
    }
    assert!(rendered.ends_with('}'));
}
}