use super::{
value_initializer::{InitResult, ValueInitializer},
ConcurrentCacheExt,
};
use crate::{
sync::{
base_cache::{BaseCache, HouseKeeperArc, MAX_SYNC_REPEATS, WRITE_RETRY_INTERVAL_MICROS},
housekeeper::InnerSync,
PredicateId, WriteOp,
},
PredicateError,
};
use crossbeam_channel::{Sender, TrySendError};
use std::{
any::TypeId,
borrow::Borrow,
collections::hash_map::RandomState,
future::Future,
hash::{BuildHasher, Hash},
sync::Arc,
time::Duration,
};
/// A futures-aware concurrent in-memory cache.
///
/// `Cache` is a cheap-to-clone handle: `value_initializer` is behind an
/// `Arc`, so clones share the in-flight initialization registry, and clones
/// presumably share `base` as well (confirm against `BaseCache`'s `Clone`
/// impl, which is defined elsewhere).
#[derive(Clone)]
pub struct Cache<K, V, S = RandomState> {
    // Core storage, eviction, and expiration machinery.
    base: BaseCache<K, V, S>,
    // Deduplicates concurrent `get_or_insert_with`-style initializations so
    // a given key's `init` future runs at most once at a time.
    value_initializer: Arc<ValueInitializer<K, V, S>>,
}
// SAFETY: NOTE(review) — asserts `Cache` is `Send` whenever `K` and `V` are
// `Send + Sync` and `S` is `Send`. The actual soundness argument depends on
// `BaseCache`'s internals, which are not visible in this file; confirm that
// no non-`Send` state escapes these bounds.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<K, V, S> Send for Cache<K, V, S>
where
    K: Send + Sync,
    V: Send + Sync,
    S: Send,
{
}
// SAFETY: NOTE(review) — asserts `Cache` is `Sync` under the same key/value
// bounds plus `S: Sync`. As with the `Send` impl above the burden of proof is
// on `BaseCache`/`ValueInitializer`'s internals; verify there.
unsafe impl<K, V, S> Sync for Cache<K, V, S>
where
    K: Send + Sync,
    V: Send + Sync,
    S: Sync,
{
}
impl<K, V> Cache<K, V, RandomState>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Creates a cache holding up to `max_capacity` entries, hashed with the
    /// standard library's default `RandomState` hasher and with no
    /// time-to-live, time-to-idle, or invalidation closures configured.
    pub fn new(max_capacity: usize) -> Self {
        Self::with_everything(max_capacity, None, RandomState::default(), None, None, false)
    }
}
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// Internal constructor taking every tuning knob; used by `new` and the
    /// cache builder.
    pub(crate) fn with_everything(
        max_capacity: usize,
        initial_capacity: Option<usize>,
        build_hasher: S,
        time_to_live: Option<Duration>,
        time_to_idle: Option<Duration>,
        invalidator_enabled: bool,
    ) -> Self {
        let base = BaseCache::new(
            max_capacity,
            initial_capacity,
            build_hasher.clone(),
            time_to_live,
            time_to_idle,
            invalidator_enabled,
        );
        let value_initializer = Arc::new(ValueInitializer::with_hasher(build_hasher));
        Self {
            base,
            value_initializer,
        }
    }

    /// Returns a clone of the value cached under `key`, or `None` if absent.
    pub fn get<Q>(&self, key: &Q) -> Option<V>
    where
        Arc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let key_hash = self.base.hash(key);
        self.base.get_with_hash(key, key_hash)
    }

    /// Returns the value for `key`, evaluating `init` to produce and insert
    /// it when the key is absent. Concurrent callers for the same key share
    /// a single `init` execution.
    pub async fn get_or_insert_with<F>(&self, key: K, init: F) -> V
    where
        F: Future<Output = V>,
    {
        // Hash before moving `key` into the `Arc`.
        let key_hash = self.base.hash(&key);
        self.get_or_insert_with_hash_and_fun(Arc::new(key), key_hash, init)
            .await
    }

    /// Fallible variant of `get_or_insert_with`: on `Err`, nothing is cached
    /// and the error is shared with concurrent waiters via `Arc`.
    pub async fn get_or_try_insert_with<F, E>(&self, key: K, init: F) -> Result<V, Arc<E>>
    where
        F: Future<Output = Result<V, E>>,
        E: Send + Sync + 'static,
    {
        let key_hash = self.base.hash(&key);
        self.get_or_try_insert_with_hash_and_fun(Arc::new(key), key_hash, init)
            .await
    }

    /// Inserts `value` under `key`, replacing any existing entry.
    pub async fn insert(&self, key: K, value: V) {
        let key_hash = self.base.hash(&key);
        self.insert_with_hash(Arc::new(key), key_hash, value).await
    }

    /// Blocking (non-async) insert for use outside of an async context.
    pub fn blocking_insert(&self, key: K, value: V) {
        let key_hash = self.base.hash(&key);
        let write_op = self.base.do_insert_with_hash(Arc::new(key), key_hash, value);
        let housekeeper = self.base.housekeeper.as_ref();
        Self::blocking_schedule_write_op(&self.base.write_op_ch, write_op, housekeeper)
            .expect("Failed to insert");
    }

    /// Removes the entry for `key`, if any, and schedules the removal with
    /// the write-op channel so the eviction policy stays consistent.
    pub async fn invalidate<Q>(&self, key: &Q)
    where
        Arc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let entry = match self.base.remove(key) {
            Some(entry) => entry,
            None => return,
        };
        let housekeeper = self.base.housekeeper.as_ref();
        Self::schedule_write_op(&self.base.write_op_ch, WriteOp::Remove(entry), housekeeper)
            .await
            .expect("Failed to remove");
    }

    /// Blocking counterpart of `invalidate`.
    pub fn blocking_invalidate<Q>(&self, key: &Q)
    where
        Arc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let entry = match self.base.remove(key) {
            Some(entry) => entry,
            None => return,
        };
        let housekeeper = self.base.housekeeper.as_ref();
        Self::blocking_schedule_write_op(&self.base.write_op_ch, WriteOp::Remove(entry), housekeeper)
            .expect("Failed to remove");
    }

    /// Discards all cached entries.
    pub fn invalidate_all(&self) {
        self.base.invalidate_all();
    }

    /// Registers `predicate`; entries matching it are invalidated lazily.
    /// Requires invalidation-closure support to be enabled at build time.
    pub fn invalidate_entries_if<F>(&self, predicate: F) -> Result<PredicateId, PredicateError>
    where
        F: Fn(&K, &V) -> bool + Send + Sync + 'static,
    {
        self.base.invalidate_entries_if(Arc::new(predicate))
    }

    /// Maximum number of entries this cache will hold.
    pub fn max_capacity(&self) -> usize {
        self.base.max_capacity()
    }

    /// Configured time-to-live, if any.
    pub fn time_to_live(&self) -> Option<Duration> {
        self.base.time_to_live()
    }

    /// Configured time-to-idle, if any.
    pub fn time_to_idle(&self) -> Option<Duration> {
        self.base.time_to_idle()
    }

    /// This cache type is unsegmented, so the segment count is always one.
    pub fn num_segments(&self) -> usize {
        1
    }
}
impl<K, V, S> ConcurrentCacheExt<K, V> for Cache<K, V, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// Synchronously applies pending read/write operations to the cache's
    /// internals, passing `MAX_SYNC_REPEATS` as the repeat budget (exact
    /// repeat semantics are defined by `InnerSync::sync`).
    fn sync(&self) {
        self.base.inner.sync(MAX_SYNC_REPEATS);
    }
}
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// Resolves `key` to a value, evaluating `init` (at most once across
    /// concurrent callers for the same key) when the key is absent.
    /// `hash` must be the hash of `key` under the cache's hasher.
    async fn get_or_insert_with_hash_and_fun(
        &self,
        key: Arc<K>,
        hash: u64,
        init: impl Future<Output = V>,
    ) -> V {
        // Fast path: the value is already cached.
        if let Some(v) = self.base.get_with_hash(&key, hash) {
            return v;
        }
        match self
            .value_initializer
            .init_or_read(Arc::clone(&key), init)
            .await
        {
            InitResult::Initialized(v) => {
                self.insert_with_hash(Arc::clone(&key), hash, v.clone())
                    .await;
                // Drop the waiter entry only after the insert so concurrent
                // callers either join the in-flight init or find the value.
                self.value_initializer
                    .remove_waiter(&key, TypeId::of::<()>());
                v
            }
            InitResult::ReadExisting(v) => v,
            // The infallible path cannot produce an error.
            InitResult::InitErr(_) => unreachable!(),
        }
    }

    /// Fallible counterpart of `get_or_insert_with_hash_and_fun`. On `Err`
    /// nothing is inserted and the error is propagated (shared via `Arc`
    /// with any concurrent waiters). Waiters are keyed by the error's
    /// `TypeId`, matching `ValueInitializer`'s bookkeeping.
    async fn get_or_try_insert_with_hash_and_fun<F, E>(
        &self,
        key: Arc<K>,
        hash: u64,
        init: F,
    ) -> Result<V, Arc<E>>
    where
        F: Future<Output = Result<V, E>>,
        E: Send + Sync + 'static,
    {
        if let Some(v) = self.base.get_with_hash(&key, hash) {
            return Ok(v);
        }
        match self
            .value_initializer
            .try_init_or_read(Arc::clone(&key), init)
            .await
        {
            InitResult::Initialized(v) => {
                // Fix: previously re-hashed `key` here, shadowing the `hash`
                // parameter that was computed from the same key by the
                // caller. Use the parameter directly, as the infallible
                // variant above does.
                self.insert_with_hash(Arc::clone(&key), hash, v.clone())
                    .await;
                self.value_initializer
                    .remove_waiter(&key, TypeId::of::<E>());
                Ok(v)
            }
            InitResult::ReadExisting(v) => Ok(v),
            InitResult::InitErr(e) => Err(e),
        }
    }

    /// Inserts the entry and schedules the resulting write op, yielding to
    /// the async runtime while the write channel is full.
    async fn insert_with_hash(&self, key: Arc<K>, hash: u64, value: V) {
        let op = self.base.do_insert_with_hash(key, hash, value);
        let hk = self.base.housekeeper.as_ref();
        Self::schedule_write_op(&self.base.write_op_ch, op, hk)
            .await
            .expect("Failed to insert");
    }

    /// Pushes `op` onto the write channel, retrying with an async sleep of
    /// `WRITE_RETRY_INTERVAL_MICROS` while the channel is full. Returns an
    /// error only if the channel is disconnected.
    #[inline]
    async fn schedule_write_op(
        ch: &Sender<WriteOp<K, V>>,
        op: WriteOp<K, V>,
        housekeeper: Option<&HouseKeeperArc<K, V, S>>,
    ) -> Result<(), TrySendError<WriteOp<K, V>>> {
        let mut op = op;
        loop {
            // Give the housekeeper a chance to drain the channel first.
            BaseCache::apply_reads_writes_if_needed(ch, housekeeper);
            match ch.try_send(op) {
                Ok(()) => break,
                Err(TrySendError::Full(op1)) => {
                    // Reclaim the op from the error and retry after a pause.
                    op = op1;
                    async_io::Timer::after(Duration::from_micros(WRITE_RETRY_INTERVAL_MICROS))
                        .await;
                }
                Err(e @ TrySendError::Disconnected(_)) => return Err(e),
            }
        }
        Ok(())
    }

    /// Blocking counterpart of `schedule_write_op`; parks the current thread
    /// between retries instead of awaiting a timer.
    #[inline]
    fn blocking_schedule_write_op(
        ch: &Sender<WriteOp<K, V>>,
        op: WriteOp<K, V>,
        housekeeper: Option<&HouseKeeperArc<K, V, S>>,
    ) -> Result<(), TrySendError<WriteOp<K, V>>> {
        let mut op = op;
        loop {
            BaseCache::apply_reads_writes_if_needed(ch, housekeeper);
            match ch.try_send(op) {
                Ok(()) => break,
                Err(TrySendError::Full(op1)) => {
                    op = op1;
                    std::thread::sleep(Duration::from_micros(WRITE_RETRY_INTERVAL_MICROS));
                }
                Err(e @ TrySendError::Disconnected(_)) => return Err(e),
            }
        }
        Ok(())
    }
}
// Test-only helpers that expose internal state for assertions.
#[cfg(test)]
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq + Send + Sync + 'static,
    V: Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// Whether the underlying hash table currently holds no entries.
    fn is_table_empty(&self) -> bool {
        self.table_size() == 0
    }
    /// Number of entries currently in the hash table (may include entries
    /// that are expired but not yet swept — TODO confirm in `BaseCache`).
    fn table_size(&self) -> usize {
        self.base.table_size()
    }
    /// Number of `invalidate_entries_if` predicates still registered.
    fn invalidation_predicate_count(&self) -> usize {
        self.base.invalidation_predicate_count()
    }
    /// Reconfigures internals for deterministic testing (exact effect is
    /// defined by `BaseCache::reconfigure_for_testing`).
    fn reconfigure_for_testing(&mut self) {
        self.base.reconfigure_for_testing();
    }
    /// Swaps in a mock clock so tests control the passage of time.
    fn set_expiration_clock(&self, clock: Option<crate::common::time::Clock>) {
        self.base.set_expiration_clock(clock);
    }
}
#[cfg(test)]
mod tests {
use super::{Cache, ConcurrentCacheExt};
use crate::{common::time::Clock, future::CacheBuilder};
use async_io::Timer;
use std::{convert::Infallible, sync::Arc, time::Duration};
    /// Single-task smoke test for the async API: insert, get, invalidate,
    /// and capacity-based admission on a cache of size 3.
    #[tokio::test]
    async fn basic_single_async_task() {
        let mut cache = Cache::new(3);
        cache.reconfigure_for_testing();
        // Freeze the configuration; `cache` is immutable from here on.
        let cache = cache;
        cache.insert("a", "alice").await;
        cache.insert("b", "bob").await;
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        cache.sync();
        cache.insert("c", "cindy").await;
        assert_eq!(cache.get(&"c"), Some("cindy"));
        cache.sync();
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        cache.sync();
        // The cache is full, so "d" is initially rejected: its access
        // frequency is too low to displace a resident entry (presumably a
        // TinyLFU-style admission policy in `BaseCache` — confirm there).
        cache.insert("d", "david").await; cache.sync();
        assert_eq!(cache.get(&"d"), None);
        cache.insert("d", "david").await;
        cache.sync();
        assert_eq!(cache.get(&"d"), None);
        // By the third attempt "d" has been seen often enough to be
        // admitted, evicting the least-valuable entry ("c").
        cache.insert("d", "dennis").await;
        cache.sync();
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        assert_eq!(cache.get(&"c"), None);
        assert_eq!(cache.get(&"d"), Some("dennis"));
        cache.invalidate(&"b").await;
        assert_eq!(cache.get(&"b"), None);
    }
#[test]
fn basic_single_blocking_api() {
let mut cache = Cache::new(3);
cache.reconfigure_for_testing();
let cache = cache;
cache.blocking_insert("a", "alice");
cache.blocking_insert("b", "bob");
assert_eq!(cache.get(&"a"), Some("alice"));
assert_eq!(cache.get(&"b"), Some("bob"));
cache.sync();
cache.blocking_insert("c", "cindy");
assert_eq!(cache.get(&"c"), Some("cindy"));
cache.sync();
assert_eq!(cache.get(&"a"), Some("alice"));
assert_eq!(cache.get(&"b"), Some("bob"));
cache.sync();
cache.blocking_insert("d", "david"); cache.sync();
assert_eq!(cache.get(&"d"), None);
cache.blocking_insert("d", "david");
cache.sync();
assert_eq!(cache.get(&"d"), None);
cache.blocking_insert("d", "dennis");
cache.sync();
assert_eq!(cache.get(&"a"), Some("alice"));
assert_eq!(cache.get(&"b"), Some("bob"));
assert_eq!(cache.get(&"c"), None);
assert_eq!(cache.get(&"d"), Some("dennis"));
cache.blocking_invalidate(&"b");
assert_eq!(cache.get(&"b"), None);
}
    /// Runs four tasks against one shared cache handle. Task 0 deliberately
    /// uses the blocking API from an async task to exercise both call paths.
    #[tokio::test]
    async fn basic_multi_async_tasks() {
        let num_threads = 4;
        let cache = Cache::new(100);
        let tasks = (0..num_threads)
            .map(|id| {
                let cache = cache.clone();
                if id == 0 {
                    // NOTE(review): blocking calls on a tokio worker can stall
                    // that worker; acceptable in a test, not in app code.
                    tokio::spawn(async move {
                        cache.blocking_insert(10, format!("{}-100", id));
                        cache.get(&10);
                        cache.blocking_insert(20, format!("{}-200", id));
                        cache.blocking_invalidate(&10);
                    })
                } else {
                    tokio::spawn(async move {
                        cache.insert(10, format!("{}-100", id)).await;
                        cache.get(&10);
                        cache.insert(20, format!("{}-200", id)).await;
                        cache.invalidate(&10).await;
                    })
                }
            })
            .collect::<Vec<_>>();
        let _ = futures_util::future::join_all(tasks).await;
        // Every task invalidates key 10 as its final step, so it must be
        // gone; key 20 is inserted by every task and never invalidated.
        assert!(cache.get(&10).is_none());
        assert!(cache.get(&20).is_some());
    }
    /// `invalidate_all` discards everything inserted before it; entries
    /// inserted afterwards survive.
    #[tokio::test]
    async fn invalidate_all() {
        let mut cache = Cache::new(100);
        cache.reconfigure_for_testing();
        let cache = cache;
        cache.insert("a", "alice").await;
        cache.insert("b", "bob").await;
        cache.insert("c", "cindy").await;
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        assert_eq!(cache.get(&"c"), Some("cindy"));
        // Flush pending ops so the three entries are fully applied before
        // the invalidation point.
        cache.sync();
        cache.invalidate_all();
        cache.sync();
        // Inserted after `invalidate_all`, so it must remain.
        cache.insert("d", "david").await;
        cache.sync();
        assert!(cache.get(&"a").is_none());
        assert!(cache.get(&"b").is_none());
        assert!(cache.get(&"c").is_none());
        assert_eq!(cache.get(&"d"), Some("david"));
    }
    /// Predicate-based invalidation: entries written at or before the
    /// predicate's registration time are removed; later writes survive even
    /// if they match the predicate.
    #[tokio::test]
    async fn invalidate_entries_if() -> Result<(), Box<dyn std::error::Error>> {
        use std::collections::HashSet;
        let mut cache = CacheBuilder::new(100)
            .support_invalidation_closures()
            .build();
        cache.reconfigure_for_testing();
        // Drive timestamps from a mock clock for determinism.
        let (clock, mock) = Clock::mock();
        cache.set_expiration_clock(Some(clock));
        let cache = cache;
        cache.insert(0, "alice").await;
        cache.insert(1, "bob").await;
        cache.insert(2, "alex").await;
        cache.sync();
        mock.increment(Duration::from_secs(5)); cache.sync();
        assert_eq!(cache.get(&0), Some("alice"));
        assert_eq!(cache.get(&1), Some("bob"));
        assert_eq!(cache.get(&2), Some("alex"));
        let names = ["alice", "alex"].iter().cloned().collect::<HashSet<_>>();
        cache.invalidate_entries_if(move |_k, &v| names.contains(v))?;
        assert_eq!(cache.invalidation_predicate_count(), 1);
        mock.increment(Duration::from_secs(5));
        // Written after the predicate was registered, so it must survive
        // even though its value ("alice") matches.
        cache.insert(3, "alice").await;
        // The invalidator runs on a background thread; these real-time
        // sleeps give it time to finish. NOTE(review): this makes the test
        // timing-sensitive on slow machines.
        cache.sync(); std::thread::sleep(Duration::from_millis(200));
        cache.sync(); std::thread::sleep(Duration::from_millis(200));
        assert!(cache.get(&0).is_none());
        assert!(cache.get(&2).is_none());
        assert_eq!(cache.get(&1), Some("bob"));
        assert_eq!(cache.get(&3), Some("alice"));
        assert_eq!(cache.table_size(), 2);
        // Completed predicates are deregistered automatically.
        assert_eq!(cache.invalidation_predicate_count(), 0);
        mock.increment(Duration::from_secs(5));
        // Multiple predicates may be pending at once.
        cache.invalidate_entries_if(|_k, &v| v == "alice")?;
        cache.invalidate_entries_if(|_k, &v| v == "bob")?;
        assert_eq!(cache.invalidation_predicate_count(), 2);
        cache.sync(); std::thread::sleep(Duration::from_millis(200));
        cache.sync(); std::thread::sleep(Duration::from_millis(200));
        assert!(cache.get(&1).is_none());
        assert!(cache.get(&3).is_none());
        assert_eq!(cache.table_size(), 0);
        assert_eq!(cache.invalidation_predicate_count(), 0);
        Ok(())
    }
    /// TTL expiration: entries expire 10 mock-seconds after their last
    /// *write*; reads do not extend the lifetime.
    #[tokio::test]
    async fn time_to_live() {
        let mut cache = CacheBuilder::new(100)
            .time_to_live(Duration::from_secs(10))
            .build();
        cache.reconfigure_for_testing();
        let (clock, mock) = Clock::mock();
        cache.set_expiration_clock(Some(clock));
        let cache = cache;
        cache.insert("a", "alice").await;
        cache.sync();
        mock.increment(Duration::from_secs(5)); cache.sync();
        // A read at t=5s does not reset the TTL.
        cache.get(&"a");
        mock.increment(Duration::from_secs(5)); cache.sync();
        // t=10s: "a" has expired and been swept.
        assert_eq!(cache.get(&"a"), None);
        assert!(cache.is_table_empty());
        cache.insert("b", "bob").await;
        cache.sync();
        assert_eq!(cache.table_size(), 1);
        mock.increment(Duration::from_secs(5)); cache.sync();
        assert_eq!(cache.get(&"b"), Some("bob"));
        assert_eq!(cache.table_size(), 1);
        // Re-inserting at t=15s resets "b"'s TTL.
        cache.insert("b", "bill").await;
        cache.sync();
        mock.increment(Duration::from_secs(5)); cache.sync();
        assert_eq!(cache.get(&"b"), Some("bill"));
        assert_eq!(cache.table_size(), 1);
        // t=25s: 10s after the re-insert, "b" expires too.
        mock.increment(Duration::from_secs(5)); cache.sync();
        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.get(&"b"), None);
        assert!(cache.is_table_empty());
    }
    /// TTI expiration: entries expire 10 mock-seconds after their last
    /// *access* (read or write), so reads extend the lifetime.
    #[tokio::test]
    async fn time_to_idle() {
        let mut cache = CacheBuilder::new(100)
            .time_to_idle(Duration::from_secs(10))
            .build();
        cache.reconfigure_for_testing();
        let (clock, mock) = Clock::mock();
        cache.set_expiration_clock(Some(clock));
        let cache = cache;
        cache.insert("a", "alice").await;
        cache.sync();
        mock.increment(Duration::from_secs(5)); cache.sync();
        // Reading "a" at t=5s refreshes its idle deadline to t=15s.
        assert_eq!(cache.get(&"a"), Some("alice"));
        mock.increment(Duration::from_secs(5)); cache.sync();
        cache.insert("b", "bob").await;
        cache.sync();
        assert_eq!(cache.table_size(), 2);
        mock.increment(Duration::from_secs(5)); cache.sync();
        // t=15s: "a" was last touched at t=5s, so it is gone; "b" (written
        // at t=10s) is still live and this read refreshes it.
        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.get(&"b"), Some("bob"));
        assert_eq!(cache.table_size(), 1);
        // t=25s: "b" idles out 10s after its last read.
        mock.increment(Duration::from_secs(10)); cache.sync();
        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.get(&"b"), None);
        assert!(cache.is_table_empty());
    }
    /// Five tasks race `get_or_insert_with` on one key: only task1's init
    /// future runs; everyone else waits for it or reads the cached result.
    /// Timeline (ms): 0 task1 starts init; 100 task2 joins the wait;
    /// 200 task4 gets None; 300 init completes; 400 task3/task5 read it.
    #[tokio::test]
    async fn get_or_insert_with() {
        let cache = Cache::new(100);
        const KEY: u32 = 0;
        // task1: kicks off the (slow, 300 ms) initialization.
        let task1 = {
            let cache1 = cache.clone();
            async move {
                let v = cache1
                    .get_or_insert_with(KEY, async {
                        Timer::after(Duration::from_millis(300)).await;
                        "task1"
                    })
                    .await;
                assert_eq!(v, "task1");
            }
        };
        // task2: arrives while init is in flight — must wait, not re-run
        // init (its own future would panic via unreachable!).
        let task2 = {
            let cache2 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(100)).await;
                let v = cache2
                    .get_or_insert_with(KEY, async { unreachable!() })
                    .await;
                assert_eq!(v, "task1");
            }
        };
        // task3: arrives after completion and reads the cached value.
        let task3 = {
            let cache3 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(400)).await;
                let v = cache3
                    .get_or_insert_with(KEY, async { unreachable!() })
                    .await;
                assert_eq!(v, "task1");
            }
        };
        // task4: a plain `get` at 200 ms — init has not finished, so None.
        let task4 = {
            let cache4 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(200)).await;
                let maybe_v = cache4.get(&KEY);
                assert!(maybe_v.is_none());
            }
        };
        // task5: a plain `get` at 400 ms sees the inserted value.
        let task5 = {
            let cache5 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(400)).await;
                let maybe_v = cache5.get(&KEY);
                assert_eq!(maybe_v, Some("task1"));
            }
        };
        futures_util::join!(task1, task2, task3, task4, task5);
    }
    /// Eight tasks race `get_or_try_insert_with` on one key. task1's init
    /// fails at 300 ms (shared with task2, nothing cached); task3 then runs
    /// a fresh init succeeding at 700 ms, which task4/task5 observe.
    #[tokio::test]
    async fn get_or_try_insert_with() {
        use std::sync::Arc;
        #[derive(Debug)]
        pub struct MyError(String);
        type MyResult<T> = Result<T, Arc<MyError>>;
        let cache = Cache::new(100);
        const KEY: u32 = 0;
        // task1: starts an init that fails after 300 ms.
        let task1 = {
            let cache1 = cache.clone();
            async move {
                let v = cache1
                    .get_or_try_insert_with(KEY, async {
                        Timer::after(Duration::from_millis(300)).await;
                        Err(MyError("task1 error".into()))
                    })
                    .await;
                assert!(v.is_err());
            }
        };
        // task2: joins task1's in-flight init at 100 ms and receives the
        // same shared error; its own init must not run.
        let task2 = {
            let cache2 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(100)).await;
                let v: MyResult<_> = cache2
                    .get_or_try_insert_with(KEY, async { unreachable!() })
                    .await;
                assert!(v.is_err());
            }
        };
        // task3: after the failure nothing is cached, so its init runs and
        // succeeds at ~700 ms.
        let task3 = {
            let cache3 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(400)).await;
                let v: MyResult<_> = cache3
                    .get_or_try_insert_with(KEY, async {
                        Timer::after(Duration::from_millis(300)).await;
                        Ok("task3")
                    })
                    .await;
                assert_eq!(v.unwrap(), "task3");
            }
        };
        // task4: joins task3's in-flight init at 500 ms.
        let task4 = {
            let cache4 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(500)).await;
                let v: MyResult<_> = cache4
                    .get_or_try_insert_with(KEY, async { unreachable!() })
                    .await;
                assert_eq!(v.unwrap(), "task3");
            }
        };
        // task5: arrives after completion and reads the cached value.
        let task5 = {
            let cache5 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(800)).await;
                let v: MyResult<_> = cache5
                    .get_or_try_insert_with(KEY, async { unreachable!() })
                    .await;
                assert_eq!(v.unwrap(), "task3");
            }
        };
        // task6/7: plain `get`s at 200/400 ms — nothing cached yet (the
        // first init failed, the second has not finished).
        let task6 = {
            let cache6 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(200)).await;
                let maybe_v = cache6.get(&KEY);
                assert!(maybe_v.is_none());
            }
        };
        let task7 = {
            let cache7 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(400)).await;
                let maybe_v = cache7.get(&KEY);
                assert!(maybe_v.is_none());
            }
        };
        // task8: a plain `get` at 800 ms sees task3's value.
        let task8 = {
            let cache8 = cache.clone();
            async move {
                Timer::after(Duration::from_millis(800)).await;
                let maybe_v = cache8.get(&KEY);
                assert_eq!(maybe_v, Some("task3"));
            }
        };
        futures_util::join!(task1, task2, task3, task4, task5, task6, task7, task8);
    }
#[tokio::test]
async fn handle_panic_in_get_or_insert_with() {
use tokio::time::{sleep, Duration};
let cache = Cache::new(16);
let semaphore = Arc::new(tokio::sync::Semaphore::new(0));
{
let cache_ref = cache.clone();
let semaphore_ref = semaphore.clone();
tokio::task::spawn(async move {
let _ = cache_ref
.get_or_insert_with(1, async move {
semaphore_ref.add_permits(1);
sleep(Duration::from_millis(50)).await;
panic!("Panic during get_or_try_insert_with");
})
.await;
});
}
let _ = semaphore.acquire().await.expect("semaphore acquire failed");
assert_eq!(cache.get_or_insert_with(1, async { 5 }).await, 5);
}
    /// Same as `handle_panic_in_get_or_insert_with` but for the fallible
    /// variant: a panicking `init` must not poison subsequent calls.
    #[tokio::test]
    async fn handle_panic_in_get_or_try_insert_with() {
        use tokio::time::{sleep, Duration};
        let cache = Cache::new(16);
        // Gates the second call until the panicking init has started.
        let semaphore = Arc::new(tokio::sync::Semaphore::new(0));
        {
            let cache_ref = cache.clone();
            let semaphore_ref = semaphore.clone();
            tokio::task::spawn(async move {
                let _ = cache_ref
                    .get_or_try_insert_with(1, async move {
                        semaphore_ref.add_permits(1);
                        sleep(Duration::from_millis(50)).await;
                        panic!("Panic during get_or_try_insert_with");
                    })
                    .await as Result<_, Arc<Infallible>>;
            });
        }
        let _ = semaphore.acquire().await.expect("semaphore acquire failed");
        // The panicked initialization must not block or poison this call.
        assert_eq!(
            cache.get_or_try_insert_with(1, async { Ok(5) }).await as Result<_, Arc<Infallible>>,
            Ok(5)
        );
    }
}