use super::{
base_cache::{BaseCache, HouseKeeperArc, MAX_SYNC_REPEATS, WRITE_RETRY_INTERVAL_MICROS},
housekeeper::InnerSync,
value_initializer::ValueInitializer,
ConcurrentCacheExt, PredicateId, WriteOp,
};
use crate::{sync::value_initializer::InitResult, PredicateError};
use crossbeam_channel::{Sender, TrySendError};
use std::{
any::TypeId,
borrow::Borrow,
collections::hash_map::RandomState,
hash::{BuildHasher, Hash},
sync::Arc,
time::Duration,
};
/// A thread-safe, in-memory cache.
///
/// Pairs the core cache machinery (`base`) with a `ValueInitializer` used by
/// `get_or_insert_with` / `get_or_try_insert_with` to run init closures
/// (see those methods below). Cloning a `Cache` clones these two handles;
/// the underlying storage is presumably shared between clones (the
/// multi-threaded tests below rely on this) — confirm against `BaseCache`.
#[derive(Clone)]
pub struct Cache<K, V, S = RandomState> {
base: BaseCache<K, V, S>,
// Tracks in-flight value initializations per key; shared via `Arc` so all
// clones of this cache coordinate through the same initializer.
value_initializer: Arc<ValueInitializer<K, V, S>>,
}
// SAFETY(review): `Send` is asserted manually with bounds (`K: Send + Sync`,
// `V: Send + Sync`, `S: Send`) mirroring what an auto trait impl would
// require of the type parameters. The manual impl (and the clippy allow)
// exist presumably because some interior field of `BaseCache` suppresses the
// auto-derived `Send` — confirm against `BaseCache`'s internals.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<K, V, S> Send for Cache<K, V, S>
where
K: Send + Sync,
V: Send + Sync,
S: Send,
{
}
// SAFETY(review): companion to the manual `Send` impl above; `&Cache` may be
// shared across threads when keys/values are `Send + Sync` and the hasher is
// `Sync`. Soundness ultimately rests on `BaseCache` and `ValueInitializer`
// being internally synchronized — confirm in those modules.
unsafe impl<K, V, S> Sync for Cache<K, V, S>
where
K: Send + Sync,
V: Send + Sync,
S: Sync,
{
}
impl<K, V> Cache<K, V, RandomState>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Constructs a cache holding up to `max_capacity` entries, hashing keys
    /// with the default `RandomState` hasher. No initial capacity hint, no
    /// time-to-live, no time-to-idle, and invalidation closures disabled.
    pub fn new(max_capacity: usize) -> Self {
        Self::with_everything(
            max_capacity,
            None,
            RandomState::default(),
            None,
            None,
            false,
        )
    }
}
impl<K, V, S> Cache<K, V, S>
where
K: Hash + Eq + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
S: BuildHasher + Clone + Send + Sync + 'static,
{
/// Internal constructor wiring every configuration knob; `new` and
/// (presumably) the cache builder funnel into this.
///
/// `build_hasher` is cloned so the base cache and the value initializer
/// hash keys identically.
pub(crate) fn with_everything(
max_capacity: usize,
initial_capacity: Option<usize>,
build_hasher: S,
time_to_live: Option<Duration>,
time_to_idle: Option<Duration>,
invalidator_enabled: bool,
) -> Self {
Self {
base: BaseCache::new(
max_capacity,
initial_capacity,
build_hasher.clone(),
time_to_live,
time_to_idle,
invalidator_enabled,
),
value_initializer: Arc::new(ValueInitializer::with_hasher(build_hasher)),
}
}
/// Returns the value associated with `key`, if cached.
///
/// Accepts any borrowed form `Q` of the stored key (like `HashMap::get`),
/// e.g. `&str` for a `String` key.
pub fn get<Q>(&self, key: &Q) -> Option<V>
where
Arc<K>: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.base.get_with_hash(key, self.base.hash(key))
}
/// Like `get`, but takes a pre-computed `hash` so internal callers that
/// already hashed the key do not hash it twice.
pub(crate) fn get_with_hash<Q>(&self, key: &Q, hash: u64) -> Option<V>
where
Arc<K>: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.base.get_with_hash(key, hash)
}
/// Returns the value for `key`, evaluating `init` and inserting its
/// result if the key is not present.
pub fn get_or_insert_with(&self, key: K, init: impl FnOnce() -> V) -> V {
let hash = self.base.hash(&key);
let key = Arc::new(key);
self.get_or_insert_with_hash_and_fun(key, hash, init)
}
pub(crate) fn get_or_insert_with_hash_and_fun(
&self,
key: Arc<K>,
hash: u64,
init: impl FnOnce() -> V,
) -> V {
// Fast path: the value is already cached.
if let Some(v) = self.get_with_hash(&key, hash) {
return v;
}
// Slow path: delegate to the value initializer, which presumably
// ensures at most one caller runs `init` per key at a time while
// others wait — confirm in the value_initializer module.
match self.value_initializer.init_or_read(Arc::clone(&key), init) {
InitResult::Initialized(v) => {
// This thread ran `init`: publish the value, then remove the
// waiter entry. `TypeId::of::<()>()` stands in for the error
// type of this infallible variant (cf. `TypeId::of::<E>()` in
// the try_ variant below).
self.insert_with_hash(Arc::clone(&key), hash, v.clone());
self.value_initializer
.remove_waiter(&key, TypeId::of::<()>());
v
}
// Another thread completed initialization; use its value.
InitResult::ReadExisting(v) => v,
// `init` returns a plain `V` (not a `Result`), so this variant
// cannot be produced.
InitResult::InitErr(_) => unreachable!(),
}
}
/// Fallible counterpart of `get_or_insert_with`. On `Err`, nothing is
/// inserted; the error is returned wrapped in an `Arc` (presumably so
/// concurrent waiters can share the same error — confirm in
/// value_initializer).
pub fn get_or_try_insert_with<F, E>(&self, key: K, init: F) -> Result<V, Arc<E>>
where
F: FnOnce() -> Result<V, E>,
E: Send + Sync + 'static,
{
let hash = self.base.hash(&key);
let key = Arc::new(key);
self.get_or_try_insert_with_hash_and_fun(key, hash, init)
}
pub(crate) fn get_or_try_insert_with_hash_and_fun<F, E>(
&self,
key: Arc<K>,
hash: u64,
init: F,
) -> Result<V, Arc<E>>
where
F: FnOnce() -> Result<V, E>,
E: Send + Sync + 'static,
{
// Fast path: the value is already cached.
if let Some(v) = self.get_with_hash(&key, hash) {
return Ok(v);
}
match self
.value_initializer
.try_init_or_read(Arc::clone(&key), init)
{
InitResult::Initialized(v) => {
// `init` succeeded on this thread: publish the value, then
// remove the waiter entry keyed by this error type `E`.
self.insert_with_hash(Arc::clone(&key), hash, v.clone());
self.value_initializer
.remove_waiter(&key, TypeId::of::<E>());
Ok(v)
}
InitResult::ReadExisting(v) => Ok(v),
// `init` failed; no waiter removal here — presumably handled
// inside the value initializer on the error path — confirm.
InitResult::InitErr(e) => Err(e),
}
}
/// Inserts `value` for `key`, replacing any existing entry.
pub fn insert(&self, key: K, value: V) {
let hash = self.base.hash(&key);
let key = Arc::new(key);
self.insert_with_hash(key, hash, value)
}
pub(crate) fn insert_with_hash(&self, key: Arc<K>, hash: u64, value: V) {
let op = self.base.do_insert_with_hash(key, hash, value);
let hk = self.base.housekeeper.as_ref();
// `schedule_write_op` only errs when the write-op channel is
// disconnected; treat that as a fatal invariant violation.
Self::schedule_write_op(&self.base.write_op_ch, op, hk).expect("Failed to insert");
}
/// Discards the entry for `key`, if any. The removal is applied through
/// the write-op channel, like an insert.
pub fn invalidate<Q>(&self, key: &Q)
where
Arc<K>: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
if let Some(entry) = self.base.remove(key) {
let op = WriteOp::Remove(entry);
let hk = self.base.housekeeper.as_ref();
Self::schedule_write_op(&self.base.write_op_ch, op, hk).expect("Failed to remove");
}
}
/// Discards all cached entries.
pub fn invalidate_all(&self) {
self.base.invalidate_all();
}
/// Registers `predicate`; entries for which it returns `true` are
/// invalidated (see the `invalidate_entries_if` test below). Returns an
/// error when invalidation-closure support is not enabled.
pub fn invalidate_entries_if<F>(&self, predicate: F) -> Result<PredicateId, PredicateError>
where
F: Fn(&K, &V) -> bool + Send + Sync + 'static,
{
self.base.invalidate_entries_if(Arc::new(predicate))
}
/// Like `invalidate_entries_if`, but for a predicate already wrapped in
/// an `Arc` (avoids double-wrapping for internal callers).
pub(crate) fn invalidate_entries_with_arc_fun<F>(
&self,
predicate: Arc<F>,
) -> Result<PredicateId, PredicateError>
where
F: Fn(&K, &V) -> bool + Send + Sync + 'static,
{
self.base.invalidate_entries_if(predicate)
}
/// Maximum number of entries this cache may hold.
pub fn max_capacity(&self) -> usize {
self.base.max_capacity()
}
/// Configured time-to-live, if any.
pub fn time_to_live(&self) -> Option<Duration> {
self.base.time_to_live()
}
/// Configured time-to-idle, if any.
pub fn time_to_idle(&self) -> Option<Duration> {
self.base.time_to_idle()
}
/// This cache is unsegmented, so it always reports one segment.
pub fn num_segments(&self) -> usize {
1
}
}
impl<K, V, S> ConcurrentCacheExt<K, V> for Cache<K, V, S>
where
K: Hash + Eq + Send + Sync + 'static,
V: Send + Sync + 'static,
S: BuildHasher + Clone + Send + Sync + 'static,
{
/// Applies pending read/write operations to the cache's internal data
/// structures, making up to `MAX_SYNC_REPEATS` passes. The tests below
/// call this to make recent inserts/removals observable deterministically.
fn sync(&self) {
self.base.inner.sync(MAX_SYNC_REPEATS);
}
}
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// Enqueues `op` on the write-op channel, retrying while the channel is
    /// full. Each attempt first lets the housekeeper drain pending reads and
    /// writes. Fails only when the channel has been disconnected.
    #[inline]
    fn schedule_write_op(
        ch: &Sender<WriteOp<K, V>>,
        mut op: WriteOp<K, V>,
        housekeeper: Option<&HouseKeeperArc<K, V, S>>,
    ) -> Result<(), TrySendError<WriteOp<K, V>>> {
        loop {
            // Make room before (re)trying to enqueue.
            BaseCache::apply_reads_writes_if_needed(ch, housekeeper);
            op = match ch.try_send(op) {
                Ok(()) => return Ok(()),
                // Channel full: reclaim the op, back off briefly, retry.
                Err(TrySendError::Full(rejected)) => {
                    std::thread::sleep(Duration::from_micros(WRITE_RETRY_INTERVAL_MICROS));
                    rejected
                }
                Err(err @ TrySendError::Disconnected(_)) => return Err(err),
            };
        }
    }
}
// Test-only helpers exposing internal state for assertions.
#[cfg(test)]
impl<K, V, S> Cache<K, V, S>
where
K: Hash + Eq + Send + Sync + 'static,
V: Send + Sync + 'static,
S: BuildHasher + Clone + Send + Sync + 'static,
{
// True when the internal hash table holds no entries.
pub(crate) fn is_table_empty(&self) -> bool {
self.table_size() == 0
}
// Number of entries currently stored in the internal hash table.
pub(crate) fn table_size(&self) -> usize {
self.base.table_size()
}
// Number of invalidation predicates currently registered.
pub(crate) fn invalidation_predicate_count(&self) -> usize {
self.base.invalidation_predicate_count()
}
// Reconfigures internals for deterministic testing (see BaseCache).
pub(crate) fn reconfigure_for_testing(&mut self) {
self.base.reconfigure_for_testing();
}
// Replaces the expiration clock with a mock so tests control time.
pub(crate) fn set_expiration_clock(&self, clock: Option<crate::common::time::Clock>) {
self.base.set_expiration_clock(clock);
}
}
#[cfg(test)]
mod tests {
use super::{Cache, ConcurrentCacheExt};
use crate::{common::time::Clock, sync::CacheBuilder};
use std::{convert::Infallible, sync::Arc, time::Duration};
// Exercises insert/get/invalidate and capacity-based admission on a
// max-capacity-3 cache from a single thread.
#[test]
fn basic_single_thread() {
let mut cache = Cache::new(3);
cache.reconfigure_for_testing();
// Freeze: only `&self` methods are used from here on.
let cache = cache;
cache.insert("a", "alice");
cache.insert("b", "bob");
assert_eq!(cache.get(&"a"), Some("alice"));
assert_eq!(cache.get(&"b"), Some("bob"));
cache.sync();
cache.insert("c", "cindy");
assert_eq!(cache.get(&"c"), Some("cindy"));
cache.sync();
// "a" and "b" have each been read; "c" has one read.
assert_eq!(cache.get(&"a"), Some("alice"));
assert_eq!(cache.get(&"b"), Some("bob"));
cache.sync();
// The cache is full; "d" is presumably rejected by the admission policy
// because its access frequency is lower than the residents' — confirm
// against BaseCache's policy.
cache.insert("d", "david"); cache.sync();
assert_eq!(cache.get(&"d"), None);
cache.insert("d", "david");
cache.sync();
assert_eq!(cache.get(&"d"), None);
// After repeated inserts raised "d"'s frequency, it is admitted,
// evicting "c" (the least-frequently-used resident).
cache.insert("d", "dennis");
cache.sync();
assert_eq!(cache.get(&"a"), Some("alice"));
assert_eq!(cache.get(&"b"), Some("bob"));
assert_eq!(cache.get(&"c"), None);
assert_eq!(cache.get(&"d"), Some("dennis"));
cache.invalidate(&"b");
assert_eq!(cache.get(&"b"), None);
}
// Smoke-tests concurrent use: four threads share one cache via clone(),
// each inserting keys 10 and 20 and invalidating key 10.
#[test]
fn basic_multi_threads() {
let num_threads = 4;
let cache = Cache::new(100);
let handles = (0..num_threads)
.map(|id| {
let cache = cache.clone();
std::thread::spawn(move || {
cache.insert(10, format!("{}-100", id));
cache.get(&10);
cache.insert(20, format!("{}-200", id));
cache.invalidate(&10);
})
})
.collect::<Vec<_>>();
handles.into_iter().for_each(|h| h.join().expect("Failed"));
// Every thread invalidates 10 after its last insert of 10, so it must be
// gone; 20 is inserted by all threads and never invalidated.
assert!(cache.get(&10).is_none());
assert!(cache.get(&20).is_some());
}
// Verifies invalidate_all removes existing entries but not ones inserted
// afterwards.
#[test]
fn invalidate_all() {
let mut cache = Cache::new(100);
cache.reconfigure_for_testing();
let cache = cache;
cache.insert("a", "alice");
cache.insert("b", "bob");
cache.insert("c", "cindy");
assert_eq!(cache.get(&"a"), Some("alice"));
assert_eq!(cache.get(&"b"), Some("bob"));
assert_eq!(cache.get(&"c"), Some("cindy"));
// Apply pending writes so the three entries are fully resident before
// the snapshot taken by invalidate_all.
cache.sync();
cache.invalidate_all();
cache.sync();
// An insert made after invalidate_all must survive.
cache.insert("d", "david");
cache.sync();
assert!(cache.get(&"a").is_none());
assert!(cache.get(&"b").is_none());
assert!(cache.get(&"c").is_none());
assert_eq!(cache.get(&"d"), Some("david"));
}
// Verifies predicate-based invalidation: entries matching a registered
// closure are removed, entries inserted after registration survive, and
// finished predicates are deregistered.
#[test]
fn invalidate_entries_if() -> Result<(), Box<dyn std::error::Error>> {
use std::collections::HashSet;
let mut cache = CacheBuilder::new(100)
.support_invalidation_closures()
.build();
cache.reconfigure_for_testing();
// Mock clock lets the test advance time deterministically.
let (clock, mock) = Clock::mock();
cache.set_expiration_clock(Some(clock));
let cache = cache;
cache.insert(0, "alice");
cache.insert(1, "bob");
cache.insert(2, "alex");
cache.sync();
mock.increment(Duration::from_secs(5)); cache.sync();
assert_eq!(cache.get(&0), Some("alice"));
assert_eq!(cache.get(&1), Some("bob"));
assert_eq!(cache.get(&2), Some("alex"));
// Invalidate every entry whose value is "alice" or "alex".
let names = ["alice", "alex"].iter().cloned().collect::<HashSet<_>>();
cache.invalidate_entries_if(move |_k, &v| names.contains(v))?;
assert_eq!(cache.base.invalidation_predicate_count(), 1);
// Advance time, then insert key 3 AFTER predicate registration; it must
// not be invalidated even though its value matches.
mock.increment(Duration::from_secs(5));
cache.insert(3, "alice");
// The sleeps give the background invalidation work time to run —
// presumably performed by a housekeeper thread; confirm in BaseCache.
cache.sync(); std::thread::sleep(Duration::from_millis(200));
cache.sync(); std::thread::sleep(Duration::from_millis(200));
assert!(cache.get(&0).is_none());
assert!(cache.get(&2).is_none());
assert_eq!(cache.get(&1), Some("bob"));
assert_eq!(cache.get(&3), Some("alice"));
assert_eq!(cache.table_size(), 2);
// The finished predicate must have been deregistered.
assert_eq!(cache.invalidation_predicate_count(), 0);
mock.increment(Duration::from_secs(5));
// Two predicates may be pending at once.
cache.invalidate_entries_if(|_k, &v| v == "alice")?;
cache.invalidate_entries_if(|_k, &v| v == "bob")?;
assert_eq!(cache.invalidation_predicate_count(), 2);
cache.sync(); std::thread::sleep(Duration::from_millis(200));
cache.sync(); std::thread::sleep(Duration::from_millis(200));
assert!(cache.get(&1).is_none());
assert!(cache.get(&3).is_none());
assert_eq!(cache.table_size(), 0);
assert_eq!(cache.invalidation_predicate_count(), 0);
Ok(())
}
// Verifies TTL expiration: entries expire 10s after insert/update; reads do
// NOT extend the lifetime (unlike TTI, tested below).
#[test]
fn time_to_live() {
let mut cache = CacheBuilder::new(100)
.time_to_live(Duration::from_secs(10))
.build();
cache.reconfigure_for_testing();
let (clock, mock) = Clock::mock();
cache.set_expiration_clock(Some(clock));
let cache = cache;
cache.insert("a", "alice");
cache.sync();
mock.increment(Duration::from_secs(5)); cache.sync();
// Reading does not reset the TTL clock.
cache.get(&"a");
mock.increment(Duration::from_secs(5)); cache.sync();
// 10s since insert: "a" has expired and is gone from the table.
assert_eq!(cache.get(&"a"), None);
assert!(cache.is_table_empty());
cache.insert("b", "bob");
cache.sync();
assert_eq!(cache.table_size(), 1);
mock.increment(Duration::from_secs(5)); cache.sync();
assert_eq!(cache.get(&"b"), Some("bob"));
assert_eq!(cache.table_size(), 1);
// Updating the value resets the TTL clock for "b".
cache.insert("b", "bill");
cache.sync();
mock.increment(Duration::from_secs(5)); cache.sync();
// Only 5s since the update: "b" is still alive.
assert_eq!(cache.get(&"b"), Some("bill"));
assert_eq!(cache.table_size(), 1);
mock.increment(Duration::from_secs(5)); cache.sync();
// 10s since the update: "b" has expired.
assert_eq!(cache.get(&"a"), None);
assert_eq!(cache.get(&"b"), None);
assert!(cache.is_table_empty());
}
// Verifies TTI expiration: entries expire 10s after the LAST access; a read
// resets the idle clock.
#[test]
fn time_to_idle() {
let mut cache = CacheBuilder::new(100)
.time_to_idle(Duration::from_secs(10))
.build();
cache.reconfigure_for_testing();
let (clock, mock) = Clock::mock();
cache.set_expiration_clock(Some(clock));
let cache = cache;
cache.insert("a", "alice");
cache.sync();
mock.increment(Duration::from_secs(5)); cache.sync();
// This read resets "a"'s idle timer.
assert_eq!(cache.get(&"a"), Some("alice"));
mock.increment(Duration::from_secs(5)); cache.sync();
cache.insert("b", "bob");
cache.sync();
assert_eq!(cache.table_size(), 2);
mock.increment(Duration::from_secs(5)); cache.sync();
// 10s since "a" was last read: expired. "b" (inserted 5s ago) survives,
// and this read resets its idle timer.
assert_eq!(cache.get(&"a"), None);
assert_eq!(cache.get(&"b"), Some("bob"));
assert_eq!(cache.table_size(), 1);
mock.increment(Duration::from_secs(10)); cache.sync();
// 10s since "b" was last read: both gone.
assert_eq!(cache.get(&"a"), None);
assert_eq!(cache.get(&"b"), None);
assert!(cache.is_table_empty());
}
// Verifies that concurrent get_or_insert_with calls for the same key run the
// init closure exactly once; other callers block until it finishes. Timing:
// thread1 starts init at t=0 and finishes at ~t=300ms.
#[test]
fn get_or_insert_with() {
use std::thread::{sleep, spawn};
let cache = Cache::new(100);
const KEY: u32 = 0;
// thread1 runs the (slow) init closure.
let thread1 = {
let cache1 = cache.clone();
spawn(move || {
let v = cache1.get_or_insert_with(KEY, || {
sleep(Duration::from_millis(300));
"thread1"
});
assert_eq!(v, "thread1");
})
};
// thread2 calls at ~t=100ms, while init is in flight: its own closure
// must never run; it waits for thread1's value.
let thread2 = {
let cache2 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(100));
let v = cache2.get_or_insert_with(KEY, || unreachable!());
assert_eq!(v, "thread1");
})
};
// thread3 calls at ~t=400ms, after init completed: cache hit, closure
// must not run.
let thread3 = {
let cache3 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(400));
let v = cache3.get_or_insert_with(KEY, || unreachable!());
assert_eq!(v, "thread1");
})
};
// thread4: a plain get at ~t=200ms sees nothing — the value is not
// published until init completes.
let thread4 = {
let cache4 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(200));
let maybe_v = cache4.get(&KEY);
assert!(maybe_v.is_none());
})
};
// thread5: a plain get at ~t=400ms sees the published value.
let thread5 = {
let cache5 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(400));
let maybe_v = cache5.get(&KEY);
assert_eq!(maybe_v, Some("thread1"));
})
};
for t in vec![thread1, thread2, thread3, thread4, thread5] {
t.join().expect("Failed to join");
}
}
// Verifies get_or_try_insert_with under concurrency: a failed init inserts
// nothing and propagates the error to waiters; a later successful init
// publishes its value. Timeline: thread1's init fails at ~t=300ms; thread3's
// init starts at ~t=400ms and succeeds at ~t=700ms.
#[test]
fn get_or_try_insert_with() {
use std::{
sync::Arc,
thread::{sleep, spawn},
};
#[derive(Debug)]
pub struct MyError(String);
type MyResult<T> = Result<T, Arc<MyError>>;
let cache = Cache::new(100);
const KEY: u32 = 0;
// thread1 runs a slow init that fails.
let thread1 = {
let cache1 = cache.clone();
spawn(move || {
let v = cache1.get_or_try_insert_with(KEY, || {
sleep(Duration::from_millis(300));
Err(MyError("thread1 error".into()))
});
assert!(v.is_err());
})
};
// thread2 calls at ~t=100ms while thread1's init is in flight: it must
// wait and receive thread1's error (its own closure never runs).
let thread2 = {
let cache2 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(100));
let v: MyResult<_> = cache2.get_or_try_insert_with(KEY, || unreachable!());
assert!(v.is_err());
})
};
// thread3 calls at ~t=400ms, after the failed init: nothing was cached,
// so its own (successful) init runs.
let thread3 = {
let cache3 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(400));
let v: MyResult<_> = cache3.get_or_try_insert_with(KEY, || {
sleep(Duration::from_millis(300));
Ok("thread3")
});
assert_eq!(v.unwrap(), "thread3");
})
};
// thread4 at ~t=500ms waits on thread3's in-flight init and gets its value.
let thread4 = {
let cache4 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(500));
let v: MyResult<_> = cache4.get_or_try_insert_with(KEY, || unreachable!());
assert_eq!(v.unwrap(), "thread3");
})
};
// thread5 at ~t=800ms is a plain cache hit.
let thread5 = {
let cache5 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(800));
let v: MyResult<_> = cache5.get_or_try_insert_with(KEY, || unreachable!());
assert_eq!(v.unwrap(), "thread3");
})
};
// thread6: plain get at ~t=200ms — nothing published during a pending init.
let thread6 = {
let cache6 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(200));
let maybe_v = cache6.get(&KEY);
assert!(maybe_v.is_none());
})
};
// thread7: plain get at ~t=400ms — a FAILED init publishes nothing.
let thread7 = {
let cache7 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(400));
let maybe_v = cache7.get(&KEY);
assert!(maybe_v.is_none());
})
};
// thread8: plain get at ~t=800ms — thread3's value is visible.
let thread8 = {
let cache8 = cache.clone();
spawn(move || {
sleep(Duration::from_millis(800));
let maybe_v = cache8.get(&KEY);
assert_eq!(maybe_v, Some("thread3"));
})
};
for t in vec![
thread1, thread2, thread3, thread4, thread5, thread6, thread7, thread8,
] {
t.join().expect("Failed to join");
}
}
// Verifies that a panic inside a get_or_insert_with init closure does not
// poison the key: a subsequent call must be able to run its own init.
#[test]
fn handle_panic_in_get_or_insert_with() {
    use std::{sync::Barrier, thread};
    let cache = Cache::new(16);
    let barrier = Arc::new(Barrier::new(2));
    {
        let cache_ref = cache.clone();
        let barrier_ref = barrier.clone();
        // This thread panics while holding the key's initialization slot.
        thread::spawn(move || {
            let _ = cache_ref.get_or_insert_with(1, || {
                barrier_ref.wait();
                thread::sleep(Duration::from_millis(50));
                // Fixed message: it previously said "get_or_try_insert_with",
                // copied from the sibling test below.
                panic!("Panic during get_or_insert_with");
            });
        });
    }
    // Wait until the panicking closure has started, then race it: the second
    // call must observe the abandoned init and run its own closure.
    barrier.wait();
    assert_eq!(cache.get_or_insert_with(1, || 5), 5);
}
// Verifies that a panic inside a get_or_try_insert_with init closure does
// not poison the key: a subsequent call must be able to run its own init.
#[test]
fn handle_panic_in_get_or_try_insert_with() {
use std::{sync::Barrier, thread};
let cache = Cache::new(16);
let barrier = Arc::new(Barrier::new(2));
{
let cache_ref = cache.clone();
let barrier_ref = barrier.clone();
// This thread panics while holding the key's initialization slot.
// `Infallible` pins the otherwise-unconstrained error type parameter.
thread::spawn(move || {
let _ = cache_ref.get_or_try_insert_with(1, || {
barrier_ref.wait();
thread::sleep(Duration::from_millis(50));
panic!("Panic during get_or_try_insert_with");
}) as Result<_, Arc<Infallible>>;
});
}
// Wait until the panicking closure has started, then verify the second
// call can still initialize the entry.
barrier.wait();
assert_eq!(
cache.get_or_try_insert_with(1, || Ok(5)) as Result<_, Arc<Infallible>>,
Ok(5)
);
}
}