use super::ConcurrentCacheExt;
use crate::sync::{
base_cache::{BaseCache, HouseKeeperArc, MAX_SYNC_REPEATS, WRITE_RETRY_INTERVAL_MICROS},
housekeeper::InnerSync,
WriteOp,
};
use crossbeam_channel::{Sender, TrySendError};
use std::{
borrow::Borrow,
collections::hash_map::RandomState,
hash::{BuildHasher, Hash},
sync::Arc,
time::Duration,
};
/// A thread-safe cache front-end that delegates all storage, eviction, and
/// expiration work to an internal [`BaseCache`].
///
/// Cloning is cheap and shares the same underlying cache state, so clones can
/// be handed to multiple threads/tasks to access one logical cache.
#[derive(Clone)]
pub struct Cache<K, V, S = RandomState> {
    // Shared internal state; every public method below delegates to this.
    base: BaseCache<K, V, S>,
}
// SAFETY: NOTE(review): this manual `Send` impl asserts `Cache` may be moved
// across threads whenever `K: Send + Sync`, `V: Send + Sync`, and `S: Send`.
// That is only sound if every field of `BaseCache<K, V, S>` (not visible in
// this file) upholds the same guarantee — TODO confirm against `BaseCache`'s
// internals whenever its fields change.
unsafe impl<K, V, S> Send for Cache<K, V, S>
where
    K: Send + Sync,
    V: Send + Sync,
    S: Send,
{
}
// SAFETY: NOTE(review): as with the manual `Send` impl above, this asserts
// `&Cache` is safe to share across threads under these bounds. Soundness
// depends entirely on `BaseCache<K, V, S>` being safe for concurrent shared
// access — TODO verify whenever `BaseCache` gains new interior state.
unsafe impl<K, V, S> Sync for Cache<K, V, S>
where
    K: Send + Sync,
    V: Send + Sync,
    S: Sync,
{
}
impl<K, V> Cache<K, V, RandomState>
where
    K: Hash + Eq,
    V: Clone,
{
    /// Constructs a cache holding up to `max_capacity` entries, hashed with
    /// the standard library's default [`RandomState`] hasher and with no
    /// time-to-live or time-to-idle expiration configured.
    pub fn new(max_capacity: usize) -> Self {
        Self::with_everything(max_capacity, None, RandomState::default(), None, None)
    }
}
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq,
    V: Clone,
    S: BuildHasher + Clone,
{
    /// Fully-parameterized constructor; all other constructors (and the
    /// builder) funnel through here.
    pub(crate) fn with_everything(
        max_capacity: usize,
        initial_capacity: Option<usize>,
        build_hasher: S,
        time_to_live: Option<Duration>,
        time_to_idle: Option<Duration>,
    ) -> Self {
        let base = BaseCache::new(
            max_capacity,
            initial_capacity,
            build_hasher,
            time_to_live,
            time_to_idle,
        );
        Self { base }
    }

    /// Returns a clone of the value stored under `key`, or `None` if the key
    /// is absent (or its entry has expired, per `BaseCache`'s policy).
    pub fn get<Q>(&self, key: &Q) -> Option<V>
    where
        Arc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let hash = self.base.hash(key);
        self.base.get_with_hash(key, hash)
    }

    /// Inserts `value` under `key`. Awaits (yielding to the executor) while
    /// the internal write-op channel is full.
    pub async fn insert(&self, key: K, value: V) {
        let hash = self.base.hash(&key);
        self.insert_with_hash(key, hash, value).await
    }

    /// Blocking counterpart of [`insert`](Self::insert) for callers outside
    /// an async context; parks the current thread while the write-op channel
    /// is full.
    ///
    /// # Panics
    ///
    /// Panics if the write-op channel has been disconnected.
    pub fn blocking_insert(&self, key: K, value: V) {
        let hash = self.base.hash(&key);
        let write_op = self.base.do_insert_with_hash(key, hash, value);
        let housekeeper = self.base.housekeeper.as_ref();
        let sent = Self::blocking_schedule_write_op(&self.base.write_op_ch, write_op, housekeeper);
        if sent.is_err() {
            panic!("Failed to insert");
        }
    }

    /// Removes the entry for `key`, if any. Awaits while the write-op channel
    /// is full.
    ///
    /// # Panics
    ///
    /// Panics if the write-op channel has been disconnected.
    pub async fn invalidate<Q>(&self, key: &Q)
    where
        Arc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        // Nothing to schedule unless the key was actually present.
        if let Some(entry) = self.base.remove(key) {
            let housekeeper = self.base.housekeeper.as_ref();
            let sent = Self::schedule_write_op(
                &self.base.write_op_ch,
                WriteOp::Remove(entry),
                housekeeper,
            )
            .await;
            if sent.is_err() {
                panic!("Failed to remove");
            }
        }
    }

    /// Blocking counterpart of [`invalidate`](Self::invalidate).
    ///
    /// # Panics
    ///
    /// Panics if the write-op channel has been disconnected.
    pub fn blocking_invalidate<Q>(&self, key: &Q)
    where
        Arc<K>: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        if let Some(entry) = self.base.remove(key) {
            let housekeeper = self.base.housekeeper.as_ref();
            let sent = Self::blocking_schedule_write_op(
                &self.base.write_op_ch,
                WriteOp::Remove(entry),
                housekeeper,
            );
            if sent.is_err() {
                panic!("Failed to remove");
            }
        }
    }

    /// Marks every current entry as invalid. Delegates to `BaseCache`.
    pub fn invalidate_all(&self) {
        self.base.invalidate_all();
    }

    /// Maximum number of entries this cache was configured to hold.
    pub fn max_capacity(&self) -> usize {
        self.base.max_capacity()
    }

    /// Configured time-to-live, if any.
    pub fn time_to_live(&self) -> Option<Duration> {
        self.base.time_to_live()
    }

    /// Configured time-to-idle, if any.
    pub fn time_to_idle(&self) -> Option<Duration> {
        self.base.time_to_idle()
    }

    /// This cache type is unsegmented, so it always reports one segment.
    pub fn num_segments(&self) -> usize {
        1
    }
}
impl<K, V, S> ConcurrentCacheExt<K, V> for Cache<K, V, S>
where
    K: Hash + Eq,
    S: BuildHasher + Clone,
{
    /// Drives the housekeeping machinery synchronously: asks the inner cache
    /// to apply pending read/write operations, repeating up to
    /// `MAX_SYNC_REPEATS` passes. Used by tests to make effects deterministic.
    fn sync(&self) {
        self.base.inner.sync(MAX_SYNC_REPEATS);
    }
}
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq,
    V: Clone,
    S: BuildHasher + Clone,
{
    /// Records the insert in the base cache, then pushes the resulting write
    /// operation onto the write-op channel, awaiting until it is accepted.
    ///
    /// # Panics
    ///
    /// Panics if the write-op channel has been disconnected.
    async fn insert_with_hash(&self, key: K, hash: u64, value: V) {
        let write_op = self.base.do_insert_with_hash(key, hash, value);
        let housekeeper = self.base.housekeeper.as_ref();
        let sent = Self::schedule_write_op(&self.base.write_op_ch, write_op, housekeeper).await;
        if sent.is_err() {
            panic!("Failed to insert");
        }
    }

    /// Pushes `op` onto the write channel, retrying with an async back-off
    /// while the channel is full. Returns `Err` only when the receiving side
    /// has disconnected; the `Full` variant is never surfaced to the caller.
    #[inline]
    async fn schedule_write_op(
        ch: &Sender<WriteOp<K, V>>,
        op: WriteOp<K, V>,
        housekeeper: Option<&HouseKeeperArc<K, V, S>>,
    ) -> Result<(), TrySendError<WriteOp<K, V>>> {
        let mut pending = op;
        loop {
            // Give the housekeeper a chance to drain the channel before we
            // try (again) to enqueue.
            BaseCache::apply_reads_writes_if_needed(ch, housekeeper);
            pending = match ch.try_send(pending) {
                Ok(()) => return Ok(()),
                Err(TrySendError::Full(rejected)) => {
                    // Channel full: back off without blocking the executor,
                    // then retry with the op the channel handed back to us.
                    async_io::Timer::after(Duration::from_micros(WRITE_RETRY_INTERVAL_MICROS))
                        .await;
                    rejected
                }
                Err(err @ TrySendError::Disconnected(_)) => return Err(err),
            };
        }
    }

    /// Same contract as [`schedule_write_op`](Self::schedule_write_op), but
    /// parks the current thread between retries instead of awaiting; this is
    /// the engine behind the `blocking_*` API.
    #[inline]
    fn blocking_schedule_write_op(
        ch: &Sender<WriteOp<K, V>>,
        op: WriteOp<K, V>,
        housekeeper: Option<&HouseKeeperArc<K, V, S>>,
    ) -> Result<(), TrySendError<WriteOp<K, V>>> {
        let mut pending = op;
        loop {
            BaseCache::apply_reads_writes_if_needed(ch, housekeeper);
            pending = match ch.try_send(pending) {
                Ok(()) => return Ok(()),
                Err(TrySendError::Full(rejected)) => {
                    std::thread::sleep(Duration::from_micros(WRITE_RETRY_INTERVAL_MICROS));
                    rejected
                }
                Err(err @ TrySendError::Disconnected(_)) => return Err(err),
            };
        }
    }
}
// Test-only helpers; compiled out of release builds.
#[cfg(test)]
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq,
    S: BuildHasher + Clone,
{
    // Reconfigures the base cache for deterministic testing (presumably
    // disabling background housekeeping so tests drive it via `sync()` —
    // TODO confirm against `BaseCache::reconfigure_for_testing`).
    fn reconfigure_for_testing(&mut self) {
        self.base.reconfigure_for_testing();
    }
    // Installs a mock clock so tests can fast-forward expiration time.
    fn set_expiration_clock(&self, clock: Option<quanta::Clock>) {
        self.base.set_expiration_clock(clock);
    }
}
#[cfg(test)]
mod tests {
    use super::{Cache, ConcurrentCacheExt};
    // NOTE(review): this file's internals come from `crate::sync`, but the
    // builder is imported from `crate::future` — verify this is the intended
    // pairing and not a stale import.
    use crate::future::CacheBuilder;
    use quanta::Clock;
    use std::time::Duration;

    // Exercises insert/get/invalidate through the async API on a single task,
    // calling `sync()` between steps to apply pending operations.
    #[tokio::test]
    async fn basic_single_async_task() {
        let mut cache = Cache::new(3);
        cache.reconfigure_for_testing();
        // Re-bind immutably now that test reconfiguration is done.
        let cache = cache;
        cache.insert("a", "alice").await;
        cache.insert("b", "bob").await;
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        cache.sync();
        cache.insert("c", "cindy").await;
        assert_eq!(cache.get(&"c"), Some("cindy"));
        cache.sync();
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        cache.sync();
        // NOTE(review): "d" is expected to be rejected until it has been
        // attempted often enough — presumably the base cache's frequency-based
        // admission policy; confirm against `BaseCache`.
        cache.insert("d", "david").await;
        cache.sync();
        assert_eq!(cache.get(&"d"), None);
        cache.insert("d", "david").await;
        cache.sync();
        assert_eq!(cache.get(&"d"), None);
        // Third attempt: "d" is finally admitted, evicting "c".
        cache.insert("d", "dennis").await;
        cache.sync();
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        assert_eq!(cache.get(&"c"), None);
        assert_eq!(cache.get(&"d"), Some("dennis"));
        cache.invalidate(&"b").await;
        assert_eq!(cache.get(&"b"), None);
    }

    // Same scenario as above, but through the blocking_* API.
    #[test]
    fn basic_single_blocking_api() {
        let mut cache = Cache::new(3);
        cache.reconfigure_for_testing();
        let cache = cache;
        cache.blocking_insert("a", "alice");
        cache.blocking_insert("b", "bob");
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        cache.sync();
        cache.blocking_insert("c", "cindy");
        assert_eq!(cache.get(&"c"), Some("cindy"));
        cache.sync();
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        cache.sync();
        // "d" rejected twice before admission — see note in the async test.
        cache.blocking_insert("d", "david");
        cache.sync();
        assert_eq!(cache.get(&"d"), None);
        cache.blocking_insert("d", "david");
        cache.sync();
        assert_eq!(cache.get(&"d"), None);
        cache.blocking_insert("d", "dennis");
        cache.sync();
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        assert_eq!(cache.get(&"c"), None);
        assert_eq!(cache.get(&"d"), Some("dennis"));
        cache.blocking_invalidate(&"b");
        assert_eq!(cache.get(&"b"), None);
    }

    // Spawns several tasks sharing one cache clone each; one task uses the
    // blocking API while the rest use the async API, to exercise both paths
    // under concurrency.
    #[tokio::test]
    async fn basic_multi_async_tasks() {
        let num_threads = 4;
        let cache = Cache::new(100);
        let tasks = (0..num_threads)
            .map(|id| {
                let cache = cache.clone();
                if id == 0 {
                    tokio::spawn(async move {
                        cache.blocking_insert(10, format!("{}-100", id));
                        cache.get(&10);
                        cache.blocking_insert(20, format!("{}-200", id));
                        cache.blocking_invalidate(&10);
                    })
                } else {
                    tokio::spawn(async move {
                        cache.insert(10, format!("{}-100", id)).await;
                        cache.get(&10);
                        cache.insert(20, format!("{}-200", id)).await;
                        cache.invalidate(&10).await;
                    })
                }
            })
            .collect::<Vec<_>>();
        let _ = futures::future::join_all(tasks).await;
        // Every task invalidates key 10 last, so it must be gone; key 20 is
        // inserted by every task and never invalidated.
        assert!(cache.get(&10).is_none());
        assert!(cache.get(&20).is_some());
    }

    // Verifies that invalidate_all removes pre-existing entries but not ones
    // inserted afterwards.
    #[tokio::test]
    async fn invalidate_all() {
        let mut cache = Cache::new(100);
        cache.reconfigure_for_testing();
        let cache = cache;
        cache.insert("a", "alice").await;
        cache.insert("b", "bob").await;
        cache.insert("c", "cindy").await;
        assert_eq!(cache.get(&"a"), Some("alice"));
        assert_eq!(cache.get(&"b"), Some("bob"));
        assert_eq!(cache.get(&"c"), Some("cindy"));
        cache.sync();
        cache.invalidate_all();
        cache.sync();
        // "d" is inserted after the invalidation point and must survive.
        cache.insert("d", "david").await;
        cache.sync();
        assert!(cache.get(&"a").is_none());
        assert!(cache.get(&"b").is_none());
        assert!(cache.get(&"c").is_none());
        assert_eq!(cache.get(&"d"), Some("david"));
    }

    // Drives a 10s TTL with a mock clock: entries expire 10s after insert
    // regardless of reads, and re-inserting resets the clock.
    #[tokio::test]
    async fn time_to_live() {
        let mut cache = CacheBuilder::new(100)
            .time_to_live(Duration::from_secs(10))
            .build();
        cache.reconfigure_for_testing();
        let (clock, mock) = Clock::mock();
        cache.set_expiration_clock(Some(clock));
        let cache = cache;
        cache.insert("a", "alice").await;
        cache.sync();
        // 5s in: still alive; a read does NOT extend TTL.
        mock.increment(Duration::from_secs(5));
        cache.sync();
        cache.get(&"a");
        // 10s in: expired.
        mock.increment(Duration::from_secs(5));
        cache.sync();
        assert_eq!(cache.get(&"a"), None);
        assert!(cache.base.is_empty());
        cache.insert("b", "bob").await;
        cache.sync();
        assert_eq!(cache.base.len(), 1);
        mock.increment(Duration::from_secs(5));
        cache.sync();
        assert_eq!(cache.get(&"b"), Some("bob"));
        assert_eq!(cache.base.len(), 1);
        // Re-insert resets "b"'s TTL.
        cache.insert("b", "bill").await;
        cache.sync();
        mock.increment(Duration::from_secs(5));
        cache.sync();
        assert_eq!(cache.get(&"b"), Some("bill"));
        assert_eq!(cache.base.len(), 1);
        mock.increment(Duration::from_secs(5));
        cache.sync();
        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.get(&"b"), None);
        assert!(cache.base.is_empty());
    }

    // Drives a 10s TTI with a mock clock: unlike TTL, reads DO reset the
    // idle timer, so an entry lives as long as it keeps being accessed.
    #[tokio::test]
    async fn time_to_idle() {
        let mut cache = CacheBuilder::new(100)
            .time_to_idle(Duration::from_secs(10))
            .build();
        cache.reconfigure_for_testing();
        let (clock, mock) = Clock::mock();
        cache.set_expiration_clock(Some(clock));
        let cache = cache;
        cache.insert("a", "alice").await;
        cache.sync();
        // 5s idle: read "a", resetting its idle timer.
        mock.increment(Duration::from_secs(5));
        cache.sync();
        assert_eq!(cache.get(&"a"), Some("alice"));
        mock.increment(Duration::from_secs(5));
        cache.sync();
        cache.insert("b", "bob").await;
        cache.sync();
        assert_eq!(cache.base.len(), 2);
        // "a" has now been idle 10s since its last read -> expired;
        // "b" was touched 5s ago -> still alive.
        mock.increment(Duration::from_secs(5));
        cache.sync();
        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.get(&"b"), Some("bob"));
        assert_eq!(cache.base.len(), 1);
        mock.increment(Duration::from_secs(10));
        cache.sync();
        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.get(&"b"), None);
        assert!(cache.base.is_empty());
    }
}