use crate::axync::{
bounded, select, stop_channel, unbounded, Receiver, RecvError, Sender, WaitGroup,
};
use crate::cache::builder::CacheBuilderCore;
use crate::policy::AsyncLFUPolicy;
use crate::ring::AsyncRingStripe;
use crate::store::ShardedMap;
use crate::ttl::{ExpirationMap, Time};
use crate::{
metrics::MetricType, CacheCallback, CacheError, Coster, DefaultCacheCallback, DefaultCoster,
DefaultKeyBuilder, DefaultUpdateValidator, KeyBuilder, Metrics, UpdateValidator,
};
use async_io::Timer;
use futures::{
future::{BoxFuture, FutureExt},
stream::StreamExt,
};
use std::collections::hash_map::RandomState;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash};
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// Builder for [`AsyncCache`].
///
/// Type parameters select the pluggable components, all defaulted:
/// - `KH`: key builder mapping `K` to a `(hash, conflict)` pair,
/// - `C`: coster computing an item's cost from its value,
/// - `U`: update validator deciding whether an in-place update is allowed,
/// - `CB`: lifecycle callbacks (evict/exit),
/// - `S`: the hasher used by the store, policy and TTL map.
///
/// Call `finalize` to validate the configuration and construct the cache.
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub struct AsyncCacheBuilder<
    K,
    V,
    KH = DefaultKeyBuilder<K>,
    C = DefaultCoster<V>,
    U = DefaultUpdateValidator<V>,
    CB = DefaultCacheCallback<V>,
    S = RandomState,
> {
    // Configuration shared with the sync builder; this type only adds the
    // async-specific `finalize`.
    inner: CacheBuilderCore<K, V, KH, C, U, CB, S>,
}
impl<K: Hash + Eq, V: Send + Sync + 'static> AsyncCacheBuilder<K, V> {
#[inline]
pub fn new(num_counters: usize, max_cost: i64) -> Self {
Self {
inner: CacheBuilderCore::new(num_counters, max_cost),
}
}
}
impl<K: Hash + Eq, V: Send + Sync + 'static, KH: KeyBuilder<Key = K>> AsyncCacheBuilder<K, V, KH> {
#[inline]
pub fn new_with_key_builder(num_counters: usize, max_cost: i64, kh: KH) -> Self {
Self {
inner: CacheBuilderCore::new_with_key_builder(num_counters, max_cost, kh),
}
}
}
impl<K, V, KH, C, U, CB, S> AsyncCacheBuilder<K, V, KH, C, U, CB, S>
where
K: Hash + Eq,
V: Send + Sync + 'static,
KH: KeyBuilder<Key = K>,
C: Coster<Value = V>,
U: UpdateValidator<Value = V>,
CB: CacheCallback<Value = V>,
S: BuildHasher + Clone + 'static + Send + Sync,
{
#[inline]
pub fn finalize<SP, R>(
self,
spawner: SP,
) -> Result<AsyncCache<K, V, KH, C, U, CB, S>, CacheError>
where
SP: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static + Copy,
{
let num_counters = self.inner.num_counters;
if num_counters == 0 {
return Err(CacheError::InvalidNumCounters);
}
let max_cost = self.inner.max_cost;
if max_cost == 0 {
return Err(CacheError::InvalidMaxCost);
}
let insert_buffer_size = self.inner.insert_buffer_size;
if insert_buffer_size == 0 {
return Err(CacheError::InvalidBufferSize);
}
let (buf_tx, buf_rx) = bounded(insert_buffer_size);
let (stop_tx, stop_rx) = stop_channel();
let (clear_tx, clear_rx) = unbounded();
let hasher = self.inner.hasher.unwrap();
let expiration_map = ExpirationMap::with_hasher(hasher.clone());
let store = Arc::new(ShardedMap::with_validator_and_hasher(
expiration_map,
self.inner.update_validator.unwrap(),
hasher.clone(),
));
let mut policy = AsyncLFUPolicy::with_hasher(num_counters, max_cost, hasher, spawner)?;
let coster = Arc::new(self.inner.coster.unwrap());
let callback = Arc::new(self.inner.callback.unwrap());
let metrics = if self.inner.metrics {
let m = Arc::new(Metrics::new_op());
policy.collect_metrics(m.clone());
m
} else {
Arc::new(Metrics::new())
};
let policy = Arc::new(policy);
CacheProcessor::new(
100000,
self.inner.ignore_internal_cost,
self.inner.cleanup_duration,
store.clone(),
policy.clone(),
buf_rx,
stop_rx,
clear_rx,
metrics.clone(),
callback.clone(),
)
.spawn(Box::new(move |fut| {
spawner(fut);
}));
let buffer_items = self.inner.buffer_items;
let get_buf = AsyncRingStripe::new(policy.clone(), buffer_items);
let this = AsyncCache {
store,
policy,
get_buf: Arc::new(get_buf),
insert_buf_tx: buf_tx,
callback,
key_to_hash: Arc::new(self.inner.key_to_hash),
stop_tx,
clear_tx,
is_closed: Arc::new(AtomicBool::new(false)),
coster,
metrics,
_marker: Default::default(),
};
Ok(this)
}
}
/// State for the background task: consumes [`Item`] events from the cache
/// handle, applies the admission policy, and performs periodic TTL cleanup.
pub(crate) struct CacheProcessor<V, U, CB, S> {
    // Receives items produced by `AsyncCache::try_insert_in` / `try_remove`.
    insert_buf_rx: Receiver<Item<V>>,
    // Fired by `AsyncCache::close` to shut the event loop down.
    stop_rx: Receiver<()>,
    // Fired by `AsyncCache::clear` to drain pending inserts.
    clear_rx: Receiver<()>,
    metrics: Arc<Metrics>,
    store: Arc<ShardedMap<V, U, S, S>>,
    policy: Arc<AsyncLFUPolicy<S>>,
    // Per-key-hash timestamps, used by the macro-generated `handle_item`;
    // presumably tracks insertion times — confirm against the macro.
    start_ts: HashMap<u64, Time, S>,
    // Upper bound on `start_ts` bookkeeping (set to 100000 by `finalize`).
    num_to_keep: usize,
    callback: Arc<CB>,
    ignore_internal_cost: bool,
    // Per-entry bookkeeping size reported by the store; presumably added to
    // item cost unless `ignore_internal_cost` is set — see `handle_item`.
    item_size: usize,
    // Interval between TTL cleanup passes.
    cleanup_duration: Duration,
}
/// Short-lived helper that drains the processor's insert buffer when a clear
/// request arrives; borrows the processor mutably for the duration of `clean`.
pub(crate) struct CacheCleaner<'a, V, U, CB, S> {
    pub(crate) processor: &'a mut CacheProcessor<V, U, CB, S>,
}
/// Events sent from the [`AsyncCache`] handle to the background
/// [`CacheProcessor`] over the insert buffer channel.
pub(crate) enum Item<V> {
    /// A new entry to admit, carrying the value and its expiration time.
    New {
        key: u64,
        conflict: u64,
        cost: i64,
        value: V,
        expiration: Time,
    },
    /// Re-cost an existing entry; `external_cost` is the caller-supplied
    /// portion of the cost (see the `update` constructor's call sites).
    Update {
        key: u64,
        cost: i64,
        external_cost: i64,
    },
    /// Remove an entry identified by key hash and conflict hash.
    Delete {
        key: u64,
        conflict: u64,
    },
    /// Synchronization barrier: the processor completes the wait group so
    /// `AsyncCache::wait` can observe that prior items were handled.
    Wait(WaitGroup),
}
impl<V> Item<V> {
    /// Returns `true` iff this item is an [`Item::Update`].
    #[inline]
    fn is_update(&self) -> bool {
        matches!(self, Item::Update { .. })
    }

    /// Builds an insert event for a new entry.
    #[inline]
    fn new(key: u64, conflict: u64, cost: i64, val: V, exp: Time) -> Self {
        Item::New {
            key,
            conflict,
            cost,
            value: val,
            expiration: exp,
        }
    }

    /// Builds a re-cost event for an existing entry.
    #[inline]
    pub(crate) fn update(key: u64, cost: i64, external_cost: i64) -> Self {
        Item::Update {
            key,
            cost,
            external_cost,
        }
    }

    /// Builds a delete event for the given key/conflict hashes.
    #[inline]
    fn delete(key: u64, conflict: u64) -> Self {
        Item::Delete { key, conflict }
    }
}
/// An asynchronous cache handle. Insert/remove operations are forwarded to a
/// background [`CacheProcessor`] task over channels; shared state lives in
/// `Arc`s so the handle can be shared across tasks.
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub struct AsyncCache<
    K,
    V,
    KH = DefaultKeyBuilder<K>,
    C = DefaultCoster<V>,
    U = DefaultUpdateValidator<V>,
    CB = DefaultCacheCallback<V>,
    S = RandomState,
> where
    K: Hash + Eq,
    V: Send + Sync + 'static,
    KH: KeyBuilder<Key = K>,
{
    // Sharded key/value storage (also holds the TTL expiration map).
    pub(crate) store: Arc<ShardedMap<V, U, S, S>>,
    // LFU admission/eviction policy, shared with the processor task.
    pub(crate) policy: Arc<AsyncLFUPolicy<S>>,
    // Sends insert/update/delete/wait items to the processor.
    pub(crate) insert_buf_tx: Sender<Item<V>>,
    // Ring buffer in front of the policy; presumably batches get-keys for
    // hit recording — see AsyncRingStripe.
    pub(crate) get_buf: Arc<AsyncRingStripe<S>>,
    // Signals the processor to shut down (see `close`).
    pub(crate) stop_tx: Sender<()>,
    // Signals the processor to drain pending inserts (see `clear`).
    pub(crate) clear_tx: Sender<()>,
    pub(crate) callback: Arc<CB>,
    // Maps a `&K` to its (key hash, conflict hash) pair.
    pub(crate) key_to_hash: Arc<KH>,
    // Set by `close`; most operations become no-ops once true.
    pub(crate) is_closed: Arc<AtomicBool>,
    pub(crate) coster: Arc<C>,
    pub metrics: Arc<Metrics>,
    // `fn(K)` ties `K` to the type without owning any `K` values.
    pub(crate) _marker: PhantomData<fn(K)>,
}
impl<K: Hash + Eq, V: Send + Sync + 'static> AsyncCache<K, V> {
    /// Creates a cache with all-default components, spawning the background
    /// worker future via `spawner`.
    #[inline]
    pub fn new<SP, R>(num_counters: usize, max_cost: i64, spawner: SP) -> Result<Self, CacheError>
    where
        SP: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static + Copy,
    {
        Self::builder(num_counters, max_cost).finalize(spawner)
    }

    /// Returns a builder pre-populated with the default key builder, coster,
    /// update validator, callback and hasher, for further customization.
    #[inline]
    pub fn builder(
        num_counters: usize,
        max_cost: i64,
    ) -> AsyncCacheBuilder<
        K,
        V,
        DefaultKeyBuilder<K>,
        DefaultCoster<V>,
        DefaultUpdateValidator<V>,
        DefaultCacheCallback<V>,
        RandomState,
    > {
        AsyncCacheBuilder::new(num_counters, max_cost)
    }
}
impl<K: Hash + Eq, V: Send + Sync + 'static, KH: KeyBuilder<Key = K>> AsyncCache<K, V, KH> {
    /// Creates a cache with a caller-supplied key builder `index`, spawning
    /// the background worker future via `spawner`.
    #[inline]
    pub fn new_with_key_builder<SP, R>(
        num_counters: usize,
        max_cost: i64,
        index: KH,
        spawner: SP,
    ) -> Result<Self, CacheError>
    where
        SP: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static + Copy,
    {
        let builder = AsyncCacheBuilder::new_with_key_builder(num_counters, max_cost, index);
        builder.finalize(spawner)
    }
}
impl<K, V, KH, C, U, CB, S> AsyncCache<K, V, KH, C, U, CB, S>
where
    K: Hash + Eq,
    V: Send + Sync + 'static,
    KH: KeyBuilder<Key = K>,
    C: Coster<Value = V>,
    U: UpdateValidator<Value = V>,
    CB: CacheCallback<Value = V>,
    S: BuildHasher + Clone + 'static + Send,
{
    /// Clears the cache: signals the processor task, then resets the policy,
    /// the store and the metrics. A no-op on a closed cache.
    ///
    /// # Errors
    /// Fails if the clear signal cannot be delivered to the processor task.
    #[inline]
    pub async fn clear(&self) -> Result<(), CacheError> {
        if self.is_closed.load(Ordering::SeqCst) {
            return Ok(());
        }
        // Let the processor drain its pending insert buffer first.
        self.clear_tx.send(()).await.map_err(|e| {
            CacheError::SendError(format!("fail to send clear signal to working thread {}", e))
        })?;
        self.policy.clear();
        self.store.clear();
        self.metrics.clear();
        Ok(())
    }

    /// Inserts `key`/`val` with `cost` and no expiration. Returns `false` if
    /// the item was dropped. Panics on internal errors; see `try_insert`.
    pub async fn insert(&self, key: K, val: V, cost: i64) -> bool {
        self.insert_with_ttl(key, val, cost, Duration::ZERO).await
    }

    /// Fallible variant of `insert`.
    pub async fn try_insert(&self, key: K, val: V, cost: i64) -> Result<bool, CacheError> {
        self.try_insert_with_ttl(key, val, cost, Duration::ZERO)
            .await
    }

    /// Inserts `key`/`val` with `cost`, expiring after `ttl`
    /// (`Duration::ZERO` means no expiration). Panics on internal errors.
    pub async fn insert_with_ttl(&self, key: K, val: V, cost: i64, ttl: Duration) -> bool {
        self.try_insert_in(key, val, cost, ttl, false)
            .await
            .unwrap()
    }

    /// Fallible variant of `insert_with_ttl`.
    pub async fn try_insert_with_ttl(
        &self,
        key: K,
        val: V,
        cost: i64,
        ttl: Duration,
    ) -> Result<bool, CacheError> {
        self.try_insert_in(key, val, cost, ttl, false).await
    }

    /// Like `insert`, but only updates an entry that is already present
    /// (never admits a new key). Panics on internal errors.
    pub async fn insert_if_present(&self, key: K, val: V, cost: i64) -> bool {
        self.try_insert_in(key, val, cost, Duration::ZERO, true)
            .await
            .unwrap()
    }

    /// Fallible variant of `insert_if_present`.
    pub async fn try_insert_if_present(
        &self,
        key: K,
        val: V,
        cost: i64,
    ) -> Result<bool, CacheError> {
        self.try_insert_in(key, val, cost, Duration::ZERO, true)
            .await
    }

    /// Blocks until every item enqueued before this call has been handled by
    /// the processor, by enqueueing a wait-group barrier item.
    ///
    /// # Errors
    /// Fails if the barrier cannot be enqueued (e.g. the insert buffer is
    /// full — `try_send` does not wait for capacity).
    pub async fn wait(&self) -> Result<(), CacheError> {
        if self.is_closed.load(Ordering::SeqCst) {
            return Ok(());
        }
        let wg = WaitGroup::new();
        let wait_item = Item::Wait(wg.add(1));
        match self.insert_buf_tx.try_send(wait_item) {
            Ok(_) => {
                // The processor signals the wait group once it reaches the
                // barrier item.
                wg.wait().await;
                Ok(())
            }
            Err(e) => Err(CacheError::SendError(format!(
                "cache set buf sender: {}",
                e
            ))),
        }
    }

    /// Removes `k` from the cache. Panics on store errors; see `try_remove`.
    pub async fn remove(&self, k: &K) {
        self.try_remove(k).await.unwrap()
    }

    /// Removes `k` from the store, fires `on_exit` with the previous value if
    /// one existed, and notifies the policy via a best-effort delete item
    /// (the send result is deliberately ignored). A no-op on a closed cache.
    pub async fn try_remove(&self, k: &K) -> Result<(), CacheError> {
        if self.is_closed.load(Ordering::SeqCst) {
            return Ok(());
        }
        let (index, conflict) = self.key_to_hash.build_key(k);
        let prev = self.store.try_remove(&index, conflict)?;
        if let Some(prev) = prev {
            self.callback.on_exit(Some(prev.value.into_inner()));
        }
        let _ = self.insert_buf_tx.send(Item::delete(index, conflict)).await;
        Ok(())
    }

    /// Shuts the cache down: clears it, stops the processor task, closes the
    /// policy, then marks the handle closed. Idempotent.
    #[inline]
    pub async fn close(&self) -> Result<(), CacheError> {
        if self.is_closed.load(Ordering::SeqCst) {
            return Ok(());
        }
        self.clear().await?;
        self.stop_tx.send(()).await.map_err(|e| {
            CacheError::SendError(format!("fail to send stop signal to working thread, {}", e))
        })?;
        self.policy.close().await?;
        // Flag is set last so earlier steps run exactly once.
        self.is_closed.store(true, Ordering::SeqCst);
        Ok(())
    }

    /// Shared insert path. `only_update` restricts the operation to keys
    /// already present. Returns `Ok(false)` when the cache is closed or when
    /// `try_update` (macro-generated) produces no item to enqueue; otherwise
    /// performs a non-blocking enqueue to the processor.
    #[inline]
    async fn try_insert_in(
        &self,
        key: K,
        val: V,
        cost: i64,
        ttl: Duration,
        only_update: bool,
    ) -> Result<bool, CacheError> {
        if self.is_closed.load(Ordering::SeqCst) {
            return Ok(false);
        }
        if let Some((index, item)) = self.try_update(key, val, cost, ttl, only_update)? {
            let is_update = item.is_update();
            // Non-blocking send: the `default` arm fires when the buffer
            // cannot accept the item right now. Updates report success either
            // way (presumably the store was already updated in `try_update`);
            // dropped new items are counted in the DropSets metric.
            select! {
                res = self.insert_buf_tx.send(item).fuse() => res.map_or_else(|_| {
                    if is_update {
                        Ok(true)
                    } else {
                        self.metrics.add(MetricType::DropSets, index, 1);
                        Ok(false)
                    }
                }, |_| Ok(true)),
                default => {
                    if is_update {
                        Ok(true)
                    } else {
                        self.metrics.add(MetricType::DropSets, index, 1);
                        Ok(false)
                    }
                }
            }
        } else {
            Ok(false)
        }
    }
}
impl<V, U, CB, S> CacheProcessor<V, U, CB, S>
where
    V: Send + Sync + 'static,
    U: UpdateValidator<Value = V>,
    CB: CacheCallback<Value = V>,
    S: BuildHasher + Clone + 'static + Send + Sync,
{
    /// Builds the processor from its channels and shared components.
    /// `item_size` and the hasher are taken from the store so cost accounting
    /// and `start_ts` use the same configuration as the rest of the cache.
    pub(crate) fn new(
        num_to_keep: usize,
        ignore_internal_cost: bool,
        cleanup_duration: Duration,
        store: Arc<ShardedMap<V, U, S, S>>,
        policy: Arc<AsyncLFUPolicy<S>>,
        insert_buf_rx: Receiver<Item<V>>,
        stop_rx: Receiver<()>,
        clear_rx: Receiver<()>,
        metrics: Arc<Metrics>,
        callback: Arc<CB>,
    ) -> Self {
        let item_size = store.item_size();
        let hasher = store.hasher();
        Self {
            insert_buf_rx,
            stop_rx,
            clear_rx,
            metrics,
            store,
            policy,
            start_ts: HashMap::with_hasher(hasher),
            num_to_keep,
            callback,
            ignore_internal_cost,
            item_size,
            cleanup_duration,
        }
    }

    /// Consumes `self` into the background event loop and hands the future to
    /// `spawner`. The loop multiplexes four event sources — insert items, the
    /// periodic cleanup tick, clear requests, and the stop signal — and
    /// returns only after a stop signal is received.
    #[inline]
    pub(crate) fn spawn(mut self, spawner: Box<dyn Fn(BoxFuture<'static, ()>) + Send + Sync>) {
        (spawner)(Box::pin(async move {
            let mut cleanup_timer = Timer::interval(self.cleanup_duration);
            loop {
                select! {
                    item = self.insert_buf_rx.recv().fuse() => {
                        // Errors are logged, not fatal: one bad item must not
                        // kill the processor.
                        if let Err(e) = self.handle_insert_event(item) {
                            tracing::error!("fail to handle insert event, error: {}", e);
                        }
                    }
                    _ = cleanup_timer.next().fuse() => {
                        if let Err(e) = self.handle_cleanup_event() {
                            tracing::error!("fail to handle cleanup event, error: {}", e);
                        }
                    },
                    _ = self.clear_rx.recv().fuse() => {
                        // Drain pending inserts via the cleaner.
                        if let Err(e) = CacheCleaner::new(&mut self).clean().await {
                            tracing::error!("fail to handle clear event, error: {}", e);
                        }
                    },
                    _ = self.stop_rx.recv().fuse() => {
                        // Close all channel endpoints and end the task.
                        _ = self.handle_close_event();
                        return;
                    },
                }
            }
        }))
    }

    /// Closes every receiver so senders fail fast after shutdown.
    /// Always returns `Ok`; the `Result` keeps the handler signatures uniform.
    #[inline]
    pub(crate) fn handle_close_event(&mut self) -> Result<(), CacheError> {
        self.insert_buf_rx.close();
        self.clear_rx.close();
        self.stop_rx.close();
        Ok(())
    }

    /// Unwraps a receive result from the insert buffer and forwards the item
    /// to the macro-generated `handle_item`, mapping receive failures into
    /// [`CacheError::RecvError`].
    #[inline]
    pub(crate) fn handle_insert_event(
        &mut self,
        res: Result<Item<V>, RecvError>,
    ) -> Result<(), CacheError> {
        res.map_err(|_| CacheError::RecvError("fail to receive msg from insert buffer".to_string()))
            .and_then(|item| self.handle_item(item))
    }

    /// One TTL cleanup pass: asks the store to evict expired entries (which
    /// consults the policy), then runs `prepare_evict` (macro-generated) and
    /// the `on_evict` callback for each victim.
    #[inline]
    pub(crate) fn handle_cleanup_event(&mut self) -> Result<(), CacheError> {
        self.store
            .try_cleanup_async(self.policy.clone())?
            .into_iter()
            .for_each(|victim| {
                self.prepare_evict(&victim);
                self.callback.on_evict(victim);
            });
        Ok(())
    }
}
impl<'a, V, U, CB, S> CacheCleaner<'a, V, U, CB, S>
where
    V: Send + Sync + 'static,
    U: UpdateValidator<Value = V>,
    CB: CacheCallback<Value = V>,
    S: BuildHasher + Clone + 'static + Send + Sync,
{
    /// Drains every item currently queued in the insert buffer, passing each
    /// to the cleaner's macro-generated `handle_item`. Returns as soon as the
    /// buffer is empty (the `default` arm fires) or the channel is closed.
    #[inline]
    pub(crate) async fn clean(mut self) -> Result<(), CacheError> {
        loop {
            select! {
                item = self.processor.insert_buf_rx.recv().fuse() => {
                    match item {
                        Ok(item) => {
                            self.handle_item(item);
                        },
                        // Channel closed: nothing left to drain.
                        Err(_) => return Ok(()),
                    }
                },
                // No item immediately available: the buffer is drained.
                default => return Ok(()),
            }
        }
    }
}
// The remaining API surface is macro-generated and shared with the sync cache:
// builder setter methods, the cache's get/try_update and related item
// handling, the processor's `handle_item`, and the cleaner's constructor and
// `handle_item`. NOTE(review): the exact generated items are defined by these
// macros elsewhere in the crate and are not visible in this file.
impl_builder!(AsyncCacheBuilder);
impl_async_cache!(AsyncCache, AsyncCacheBuilder, Item);
impl_cache_processor!(CacheProcessor, Item);
impl_cache_cleaner!(CacheCleaner, CacheProcessor, Item);