use crate::{
CacheCallback, CacheError, Coster, DefaultCacheCallback, DefaultCoster, DefaultKeyBuilder,
DefaultUpdateValidator, Item as CrateItem, KeyBuilder, Metrics, UpdateValidator, ValueRef,
ValueRefMut,
axync::Waiter,
cache::builder::CacheBuilderCore,
metrics::MetricType,
policy::{AddOutcome, AsyncLFUPolicy},
ring::AsyncRingStripe,
store::{ShardedMap, UpdateResult},
ttl::{ExpirationMap, Time},
};
use agnostic_lite::RuntimeLite;
use crossbeam_channel::{Receiver, Sender, bounded as cb_bounded, select as cb_select, tick};
use std::{
collections::{HashMap, hash_map::RandomState},
hash::{BuildHasher, Hash},
marker::PhantomData,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
thread::{JoinHandle, spawn as thread_spawn},
time::Duration,
};
/// Builder for [`AsyncCache`].
///
/// All configuration lives in the wrapped `CacheBuilderCore`; this type only
/// adds the async-specific `build` step.
///
/// Type parameters:
/// - `K`: key type, `V`: value type.
/// - `KH`: key builder that turns a key into an `(index, conflict)` hash pair.
/// - `C`: coster used to price a value when the caller passes a cost of `0`.
/// - `U`: update validator handed to the store.
/// - `CB`: callback notified on exit / reject / evict.
/// - `S`: hasher builder shared by the store, the expiration map and the policy.
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub struct AsyncCacheBuilder<
K,
V,
KH = DefaultKeyBuilder<K>,
C = DefaultCoster<V>,
U = DefaultUpdateValidator<V>,
CB = DefaultCacheCallback<V>,
S = RandomState,
> {
/// Runtime-agnostic builder state consumed by `build`.
inner: CacheBuilderCore<K, V, KH, C, U, CB, S>,
}
impl<K: Hash + Eq, V: Send + Sync + 'static> AsyncCacheBuilder<K, V> {
  /// Starts a builder that relies on the default key builder, coster, update
  /// validator, callback and hasher.
  ///
  /// Both arguments are forwarded unchanged to `CacheBuilderCore::new`; they
  /// are only checked (each must be non-zero) when `build` is called.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub fn new(num_counters: usize, max_cost: i64) -> Self {
    let inner = CacheBuilderCore::new(num_counters, max_cost);
    Self { inner }
  }
}
impl<K: Hash + Eq, V: Send + Sync + 'static, KH: KeyBuilder<Key = K>> AsyncCacheBuilder<K, V, KH> {
  /// Starts a builder that hashes keys with the caller-supplied key builder
  /// `kh` instead of the default one.
  ///
  /// The arguments are forwarded unchanged to
  /// `CacheBuilderCore::new_with_key_builder`.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub fn new_with_key_builder(num_counters: usize, max_cost: i64, kh: KH) -> Self {
    let inner = CacheBuilderCore::new_with_key_builder(num_counters, max_cost, kh);
    Self { inner }
  }
}
impl<K, V, KH, C, U, CB, S> AsyncCacheBuilder<K, V, KH, C, U, CB, S>
where
  K: Hash + Eq,
  V: Send + Sync + 'static,
  KH: KeyBuilder<Key = K>,
  C: Coster<Value = V>,
  U: UpdateValidator<Value = V>,
  CB: CacheCallback<Value = V>,
  S: BuildHasher + Clone + 'static + Send + Sync,
{
  /// Consumes the builder and assembles an [`AsyncCache`] bound to runtime `R`.
  ///
  /// The steps run in this order: validate the sizing parameters, create the
  /// insert stripe ring and the two stop channels, build the store and the
  /// policy, then start the background `CacheProcessor` thread.
  ///
  /// # Errors
  ///
  /// - `CacheError::InvalidNumCounters` when `num_counters` is `0`.
  /// - `CacheError::InvalidMaxCost` when `max_cost` is `0`.
  /// - Whatever `AsyncLFUPolicy::with_hasher` reports.
  ///
  /// # Panics
  ///
  /// Panics if the hasher, update validator, coster or callback slot of the
  /// builder core is `None`.
  // NOTE(review): only `max_cost == 0` is rejected; a negative budget passes
  // validation — confirm whether that is intended.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub fn build<R: RuntimeLite>(self) -> Result<AsyncCache<K, V, R, KH, C, U, CB, S>, CacheError>
  where
    <R as RuntimeLite>::Interval: Send,
  {
    // Passed to `CacheProcessor::new` as its `num_to_keep` argument.
    const NUM_TO_KEEP: usize = 100000;

    let cfg = self.inner;
    if cfg.num_counters == 0 {
      return Err(CacheError::InvalidNumCounters);
    }
    if cfg.max_cost == 0 {
      return Err(CacheError::InvalidMaxCost);
    }

    // Write path plumbing: striped insert buffer plus the channel it flushes to.
    let (ring, batch_rx) =
      crate::cache::insert_stripe::InsertStripeRing::<Item<V>>::new(cfg.insert_stripe_high_water);
    let ring = Arc::new(ring);

    // One stop channel for the policy, one for the processor thread.
    let (policy_stop_tx, policy_stop_rx) = crate::axync::stop_channel();
    let (stop_tx, stop_rx) = cb_bounded::<()>(1);

    // The same hasher builder is cloned into every hashed structure.
    let hasher = cfg.hasher.unwrap();
    let store = Arc::new(ShardedMap::with_validator_and_hasher(
      ExpirationMap::with_hasher(hasher.clone()),
      cfg.update_validator.unwrap(),
      hasher.clone(),
    ));
    let mut policy =
      AsyncLFUPolicy::with_hasher::<R>(cfg.num_counters, cfg.max_cost, hasher, policy_stop_rx)?;

    let coster = Arc::new(cfg.coster.unwrap());
    let callback = Arc::new(cfg.callback.unwrap());

    // With metrics enabled the policy reports into the same collector.
    let metrics = Arc::new(if cfg.metrics {
      Metrics::new_op()
    } else {
      Metrics::new()
    });
    if cfg.metrics {
      policy.collect_metrics(Arc::clone(&metrics));
    }

    let policy = Arc::new(policy);
    let clear_generation = Arc::new(AtomicU64::new(0));

    let processor = CacheProcessor::new(
      NUM_TO_KEEP,
      cfg.ignore_internal_cost,
      cfg.cleanup_duration,
      cfg.drain_interval,
      Arc::clone(&store),
      Arc::clone(&policy),
      batch_rx,
      Arc::clone(&ring),
      stop_rx,
      Arc::clone(&metrics),
      Arc::clone(&callback),
      Arc::clone(&clear_generation),
    )
    .spawn();

    let get_buf = Arc::new(AsyncRingStripe::new(Arc::clone(&policy), cfg.buffer_items));
    let key_to_hash = Arc::new(cfg.key_to_hash);

    Ok(AsyncCache(Arc::new(AsyncCacheInner {
      store,
      policy,
      get_buf,
      insert_buf_ring: ring,
      callback,
      key_to_hash,
      stop_tx: Some(stop_tx),
      policy_stop_tx: Some(policy_stop_tx),
      coster,
      metrics,
      clear_generation,
      processor: Some(processor),
      _marker: PhantomData,
      _runtime: PhantomData,
    })))
  }
}
/// State owned by the background thread that applies buffered write
/// operations to the policy and store and runs periodic cleanup.
///
/// Built by `CacheProcessor::new` and moved into the thread by `spawn`.
pub(crate) struct CacheProcessor<V, U, CB, S> {
/// Receives batches of items flushed from the insert stripes.
insert_buf_rx: Receiver<Vec<Item<V>>>,
/// The stripe ring itself, drained inline on every drain tick and on shutdown.
insert_stripe: Arc<crate::cache::insert_stripe::InsertStripeRing<Item<V>>>,
/// Becomes ready when the cache handle signals (or drops) the stop sender.
stop_rx: Receiver<()>,
/// Metrics collector shared with the cache handle.
metrics: Arc<Metrics>,
/// Sharded value store shared with the cache handle.
store: Arc<ShardedMap<V, U, S, S>>,
/// Admission / eviction policy shared with the cache handle.
policy: Arc<AsyncLFUPolicy<S>>,
/// Timestamps keyed by item hash; initialised empty in `new`.
/// NOTE(review): not read in the code visible here — presumably used by the
/// macro-generated methods; confirm.
start_ts: HashMap<u64, Time, S>,
/// Bound supplied by the builder (`100000`).
/// NOTE(review): consumer is not visible in this file — confirm its meaning.
num_to_keep: usize,
/// User callback invoked for evicted items during cleanup.
callback: Arc<CB>,
/// Builder flag copied verbatim from `CacheBuilderCore::ignore_internal_cost`.
ignore_internal_cost: bool,
/// Value of `store.item_size()` captured at construction time.
item_size: usize,
/// Generation counter shared with the cache handle.
clear_generation: Arc<AtomicU64>,
/// Period of the cleanup ticker created in `spawn`.
cleanup_duration: Duration,
/// Period of the stripe-drain ticker created in `spawn`.
drain_duration: Duration,
}
/// A unit of work sent from the cache handle to the background processor.
///
/// `key` / `conflict` are the two hashes produced by the key builder,
/// `version` is the store version of the entry the operation refers to, and
/// `generation` is the value of the shared `clear_generation` counter read
/// when the operation was issued.
pub(crate) enum Item<V> {
/// A value that was freshly inserted into the store.
New {
key: u64,
conflict: u64,
/// Caller-supplied cost plus the coster-computed cost.
cost: i64,
expiration: Time,
version: u64,
generation: u64,
// Ties the item to `V` without storing a value.
_marker: std::marker::PhantomData<fn() -> V>,
},
/// An existing value that was replaced in the store.
Update {
key: u64,
conflict: u64,
/// Caller-supplied cost.
cost: i64,
/// Cost computed by the coster (only non-zero when `cost` is `0`).
external_cost: i64,
// The lint allowance indicates this field is currently never read.
// NOTE(review): confirm whether it is kept on purpose.
#[allow(dead_code)]
expiration: Time,
version: u64,
generation: u64,
},
/// A value that was removed from the store by `try_remove`.
Delete {
key: u64,
conflict: u64,
generation: u64,
version: u64,
},
/// Marker enqueued by `AsyncCache::wait`; carries the waiter to signal.
Wait(Waiter),
/// Marker enqueued by `AsyncCache::clear`; carries the waiter to signal.
Clear(Waiter),
}
impl<V> Item<V> {
#[cfg_attr(not(tarpaulin), inline(always))]
fn new(key: u64, conflict: u64, cost: i64, exp: Time, version: u64, generation: u64) -> Self {
Self::New {
key,
conflict,
cost,
expiration: exp,
version,
generation,
_marker: std::marker::PhantomData,
}
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) fn update(
key: u64,
conflict: u64,
cost: i64,
external_cost: i64,
expiration: Time,
version: u64,
generation: u64,
) -> Self {
Self::Update {
key,
conflict,
cost,
external_cost,
expiration,
version,
generation,
}
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn delete(key: u64, conflict: u64, generation: u64, version: u64) -> Self {
Self::Delete {
key,
conflict,
generation,
version,
}
}
}
/// Handle to an asynchronous cache.
///
/// The handle is a thin wrapper around an `Arc` of the shared
/// `AsyncCacheInner`, so cloning it only bumps a reference count and every
/// clone talks to the same store, policy and background processor.
///
/// `R` is the `RuntimeLite` runtime; the remaining parameters mirror those of
/// `AsyncCacheBuilder`.
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub struct AsyncCache<
K,
V,
R,
KH = DefaultKeyBuilder<K>,
C = DefaultCoster<V>,
U = DefaultUpdateValidator<V>,
CB = DefaultCacheCallback<V>,
S = RandomState,
>(pub(crate) Arc<AsyncCacheInner<K, V, R, KH, C, U, CB, S>>)
where
K: Hash + Eq,
V: Send + Sync + 'static,
KH: KeyBuilder<Key = K>;
impl<K, V, R, KH, C, U, CB, S> Clone for AsyncCache<K, V, R, KH, C, U, CB, S>
where
K: Hash + Eq,
V: Send + Sync + 'static,
KH: KeyBuilder<Key = K>,
{
#[cfg_attr(not(tarpaulin), inline(always))]
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Shared state behind every [`AsyncCache`] handle.
///
/// Dropping the last handle drops this value, whose `Drop` impl releases the
/// stop senders and joins the processor thread.
pub(crate) struct AsyncCacheInner<
K,
V,
R,
KH = DefaultKeyBuilder<K>,
C = DefaultCoster<V>,
U = DefaultUpdateValidator<V>,
CB = DefaultCacheCallback<V>,
S = RandomState,
> where
K: Hash + Eq,
V: Send + Sync + 'static,
KH: KeyBuilder<Key = K>,
{
/// Sharded value store, shared with the processor thread.
pub(crate) store: Arc<ShardedMap<V, U, S, S>>,
/// Admission / eviction policy, shared with the processor thread.
pub(crate) policy: Arc<AsyncLFUPolicy<S>>,
/// Striped buffer through which write operations reach the processor.
pub(crate) insert_buf_ring: Arc<crate::cache::insert_stripe::InsertStripeRing<Item<V>>>,
/// Buffer that records the hash of every key looked up via `get` / `get_mut`.
pub(crate) get_buf: Arc<AsyncRingStripe<S>>,
/// Stop sender for the processor thread; taken (dropped) in `Drop`.
pub(crate) stop_tx: Option<Sender<()>>,
/// Stop sender for the policy; taken (dropped) in `Drop`.
pub(crate) policy_stop_tx: Option<crate::axync::Sender<()>>,
/// User callback (`on_exit`, `on_reject`, ...).
pub(crate) callback: Arc<CB>,
/// Key builder producing the `(index, conflict)` hash pair for a key.
pub(crate) key_to_hash: Arc<KH>,
/// Coster consulted when an insert is made with a cost of `0`.
pub(crate) coster: Arc<C>,
/// Metrics collector (hits, misses, dropped sets, ...).
pub metrics: Arc<Metrics>,
/// Generation counter read by writers and stamped onto the items they enqueue.
pub(crate) clear_generation: Arc<AtomicU64>,
/// Join handle of the processor thread; taken in `Drop`.
pub(crate) processor: Option<JoinHandle<Result<(), CacheError>>>,
pub(crate) _marker: PhantomData<fn(K)>,
pub(crate) _runtime: PhantomData<R>,
}
impl<K, V, R, KH, C, U, CB, S> AsyncCache<K, V, R, KH, C, U, CB, S>
where
  K: Hash + Eq,
  V: Send + Sync + 'static,
  KH: KeyBuilder<Key = K>,
  C: Coster<Value = V>,
  U: UpdateValidator<Value = V>,
  CB: CacheCallback<Value = V>,
  S: BuildHasher + Clone + 'static + Send,
  R: RuntimeLite,
{
  /// Flushes every insert stripe, enqueues an `Item::Clear` marker behind the
  /// flushed work and waits until the processor acknowledges it.
  ///
  /// # Errors
  ///
  /// `CacheError::SendError` when the insert channel is closed. The outcome
  /// of awaiting the acknowledgement itself is ignored.
  #[cfg_attr(not(tarpaulin), inline(always))]
  pub async fn clear(&self) -> Result<(), CacheError> {
    // The ring borrow is confined to this block so it is not kept across the await.
    let ack = {
      let ring = &self.0.insert_buf_ring;
      if ring.drain_all_stripes_to_channel().is_err() {
        return Err(CacheError::SendError(
          "fail to drain stripes: channel closed".to_string(),
        ));
      }
      let (waiter, ack) = Waiter::new();
      if ring.send_single(Item::Clear(waiter)).is_err() {
        return Err(CacheError::SendError(
          "fail to enqueue clear marker: channel closed".to_string(),
        ));
      }
      ack
    };
    let _ = ack.await;
    Ok(())
  }

  /// Inserts `val` under `key` without a TTL.
  ///
  /// # Panics
  ///
  /// Panics if the underlying fallible insert returns an error; use
  /// `try_insert` to handle it instead.
  pub async fn insert(&self, key: K, val: V, cost: i64) -> bool {
    let no_ttl = Duration::ZERO;
    self.insert_with_ttl(key, val, cost, no_ttl).await
  }

  /// Fallible counterpart of `insert`.
  pub async fn try_insert(&self, key: K, val: V, cost: i64) -> Result<bool, CacheError> {
    let no_ttl = Duration::ZERO;
    self.try_insert_with_ttl(key, val, cost, no_ttl).await
  }

  /// Inserts `val` under `key` with the given `ttl`.
  ///
  /// # Panics
  ///
  /// Panics if the underlying fallible insert returns an error; use
  /// `try_insert_with_ttl` to handle it instead.
  pub async fn insert_with_ttl(&self, key: K, val: V, cost: i64, ttl: Duration) -> bool {
    let outcome = self.try_insert_in(key, val, cost, ttl, false).await;
    outcome.unwrap()
  }

  /// Fallible counterpart of `insert_with_ttl`.
  pub async fn try_insert_with_ttl(
    &self,
    key: K,
    val: V,
    cost: i64,
    ttl: Duration,
  ) -> Result<bool, CacheError> {
    let only_update = false;
    self.try_insert_in(key, val, cost, ttl, only_update).await
  }

  /// Replaces the value for `key` only when the key is already present.
  ///
  /// # Panics
  ///
  /// Panics if the underlying fallible update returns an error; use
  /// `try_insert_if_present` to handle it instead.
  pub async fn insert_if_present(&self, key: K, val: V, cost: i64) -> bool {
    let outcome = self
      .try_insert_in(key, val, cost, Duration::ZERO, true)
      .await;
    outcome.unwrap()
  }

  /// Fallible counterpart of `insert_if_present`.
  pub async fn try_insert_if_present(&self, key: K, val: V, cost: i64) -> Result<bool, CacheError> {
    let only_update = true;
    self
      .try_insert_in(key, val, cost, Duration::ZERO, only_update)
      .await
  }

  /// Flushes every insert stripe, enqueues an `Item::Wait` marker behind the
  /// flushed work and waits until the processor acknowledges it.
  ///
  /// # Errors
  ///
  /// `CacheError::SendError` when the insert channel is closed. The outcome
  /// of awaiting the acknowledgement itself is ignored.
  pub async fn wait(&self) -> Result<(), CacheError> {
    // The ring borrow is confined to this block so it is not kept across the await.
    let ack = {
      let ring = &self.0.insert_buf_ring;
      if ring.drain_all_stripes_to_channel().is_err() {
        return Err(CacheError::SendError(
          "fail to drain stripes: channel closed".to_string(),
        ));
      }
      let (waiter, ack) = Waiter::new();
      if ring.send_single(Item::Wait(waiter)).is_err() {
        return Err(CacheError::SendError(
          "fail to enqueue wait marker: channel closed".to_string(),
        ));
      }
      ack
    };
    let _ = ack.await;
    Ok(())
  }

  /// Removes `k` from the cache.
  ///
  /// # Panics
  ///
  /// Panics if `try_remove` returns an error.
  pub async fn remove(&self, k: &K) {
    let outcome = self.try_remove(k).await;
    outcome.unwrap()
  }

  /// Removes `k` from the store and tells the processor about it.
  ///
  /// The entry is taken out of the store first. If nothing was removed the
  /// call is a no-op. Otherwise a delete item is enqueued and `on_exit` is
  /// invoked with the removed value, even when enqueueing failed.
  ///
  /// # Errors
  ///
  /// Propagates store errors, and returns `CacheError::ChannelError` when the
  /// delete item could not be enqueued.
  pub async fn try_remove(&self, k: &K) -> Result<(), CacheError> {
    let inner = &self.0;
    let (index, conflict) = inner.key_to_hash.build_key(k);
    let captured_gen = inner.clear_generation.load(Ordering::Acquire);

    let removed = match inner
      .store
      .try_remove_if_not_stale(&index, conflict, captured_gen)?
    {
      Some(entry) => entry,
      None => return Ok(()),
    };

    let delete = Item::delete(index, conflict, captured_gen, removed.version);
    let sent = inner.insert_buf_ring.send_single(delete);
    // The value has already left the store, so the callback fires regardless
    // of whether the processor could be notified.
    inner.callback.on_exit(Some(removed.value));

    sent.map_err(|()| {
      CacheError::ChannelError("failed to send delete to insert buffer: channel closed".to_string())
    })
  }

  /// Shared implementation of every insert flavour.
  ///
  /// Applies the write to the store via `try_update`, pushes the resulting
  /// item into the insert stripe ring and yields to the runtime once.
  ///
  /// Returns `Ok(false)` when the store refused the write. When the ring
  /// drops a batch, the batch is rolled back and the result is `true` only
  /// for updates.
  #[cfg_attr(not(tarpaulin), inline(always))]
  async fn try_insert_in(
    &self,
    key: K,
    val: V,
    cost: i64,
    ttl: Duration,
    only_update: bool,
  ) -> Result<bool, CacheError> {
    use crate::cache::insert_stripe::PushOutcome;

    let Some((_index, item, prev_val)) = self.try_update(key, val, cost, ttl, only_update)? else {
      return Ok(false);
    };
    let is_update = matches!(item, Item::Update { .. });

    let outcome = self.0.insert_buf_ring.push(item);
    let accepted = match &outcome {
      PushOutcome::Buffered | PushOutcome::Sent => true,
      PushOutcome::Dropped(batch) => {
        rollback_batch(
          &self.0.store,
          &self.0.policy,
          &self.0.callback,
          &self.0.metrics,
          batch,
        );
        is_update
      }
    };
    // The replaced value (if any) leaves the cache in either case.
    if let Some(previous) = prev_val {
      self.0.callback.on_exit(Some(previous));
    }
    // Release any dropped batch before yielding.
    drop(outcome);

    R::yield_now().await;
    Ok(accepted)
  }
}
impl<K, V, R, KH, C, U, CB, S> Drop for AsyncCacheInner<K, V, R, KH, C, U, CB, S>
where
  K: Hash + Eq,
  V: Send + Sync + 'static,
  KH: KeyBuilder<Key = K>,
{
  /// Shuts the cache down: releases both stop senders, then waits for the
  /// processor thread to finish.
  fn drop(&mut self) {
    // Dropping the senders is the shutdown signal for the processor and the policy.
    drop(self.stop_tx.take());
    drop(self.policy_stop_tx.take());

    let Some(handle) = self.processor.take() else {
      return;
    };
    // Joining from the processor thread itself would never return, so in that
    // case the handle is simply dropped.
    if handle.thread().id() != std::thread::current().id() {
      let _ = handle.join();
    }
  }
}
/// Undoes the store-side effects of a batch that the insert ring dropped.
///
/// - `Item::New`: counts a dropped set, removes the entry from the store if it
///   is still at the recorded version, removes the key from the policy when
///   the store no longer holds it, and reports the value through `on_reject`.
/// - `Item::Update`: only counts a dropped set; the store is left as is.
/// - `Item::Delete`, `Item::Wait`, `Item::Clear`: ignored.
pub(crate) fn rollback_batch<V, U, CB, S>(
  store: &Arc<ShardedMap<V, U, S, S>>,
  policy: &Arc<AsyncLFUPolicy<S>>,
  callback: &Arc<CB>,
  metrics: &Arc<Metrics>,
  batch: &[Item<V>],
) where
  V: Send + Sync + 'static,
  U: UpdateValidator<Value = V>,
  CB: CacheCallback<Value = V>,
  S: BuildHasher + Clone + 'static,
{
  for entry in batch {
    match entry {
      Item::Delete { .. } | Item::Wait(_) | Item::Clear(_) => continue,
      Item::Update { key, .. } => {
        metrics.add(MetricType::DropSets, *key, 1);
      }
      Item::New {
        key,
        conflict,
        cost,
        expiration,
        version,
        ..
      } => {
        metrics.add(MetricType::DropSets, *key, 1);
        // Only roll back if the store still holds exactly the version this
        // item inserted.
        let Ok(Some(removed)) = store.try_remove_if_version(key, *conflict, *version) else {
          continue;
        };
        if !store.contains_key(key, 0) {
          policy.remove(key);
        }
        callback.on_reject(CrateItem {
          val: Some(removed.value),
          index: *key,
          conflict: *conflict,
          cost: *cost,
          exp: *expiration,
        });
      }
    }
  }
}
impl<V, U, CB, S> CacheProcessor<V, U, CB, S>
where
V: Send + Sync + 'static,
U: UpdateValidator<Value = V>,
CB: CacheCallback<Value = V>,
S: BuildHasher + Clone + 'static + Send + Sync,
{
/// Assembles a processor from its collaborators.
///
/// `item_size` and the hasher for the `start_ts` map are read from `store`;
/// every other field is stored exactly as passed in.
pub(crate) fn new(
num_to_keep: usize,
ignore_internal_cost: bool,
cleanup_duration: Duration,
drain_duration: Duration,
store: Arc<ShardedMap<V, U, S, S>>,
policy: Arc<AsyncLFUPolicy<S>>,
insert_buf_rx: Receiver<Vec<Item<V>>>,
insert_stripe: Arc<crate::cache::insert_stripe::InsertStripeRing<Item<V>>>,
stop_rx: Receiver<()>,
metrics: Arc<Metrics>,
callback: Arc<CB>,
clear_generation: Arc<AtomicU64>,
) -> Self {
let item_size = store.item_size();
let hasher = store.hasher();
Self {
insert_buf_rx,
insert_stripe,
stop_rx,
metrics,
store,
policy,
start_ts: HashMap::with_hasher(hasher),
num_to_keep,
callback,
ignore_internal_cost,
item_size,
cleanup_duration,
drain_duration,
clear_generation,
}
}
/// Moves the processor onto a dedicated OS thread and returns its join handle.
///
/// The thread loops over four event sources until it is told to stop:
/// - a batch arriving on `insert_buf_rx` is passed to `handle_insert_batch`;
/// - a drain tick flushes every insert stripe inline, item by item;
/// - a cleanup tick runs `handle_cleanup_event`;
/// - the stop channel becoming ready triggers a final drain and ends the loop.
///
/// Errors from batch handling and cleanup are logged; errors from
/// `handle_item` are discarded. The thread always returns `Ok(())`.
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) fn spawn(mut self) -> JoinHandle<Result<(), CacheError>> {
let drain_ticker = tick(self.drain_duration);
let cleanup_ticker = tick(self.cleanup_duration);
// Separate handle to the ring so the drain closures can borrow `self`
// mutably while the ring is being iterated.
let stripe = self.insert_stripe.clone();
thread_spawn(move || {
loop {
cb_select! {
recv(self.insert_buf_rx) -> res => {
match res {
Ok(batch) => {
if let Err(e) = self.handle_insert_batch(batch) {
tracing::error!("fail to handle insert batch: {}", e);
}
}
Err(e) => {
// The channel is gone: apply whatever is still sitting in the stripes, then exit.
stripe.drain_all_stripes_inline(|batch| {
for item in batch {
let _ = self.handle_item(item);
}
});
tracing::debug!("insert receiver disconnected: {}", e);
return Ok(());
}
}
},
recv(drain_ticker) -> _ => {
// Periodic flush so items do not linger in partially filled stripes.
stripe.drain_all_stripes_inline(|batch| {
for item in batch {
let _ = self.handle_item(item);
}
});
},
recv(cleanup_ticker) -> _ => {
if let Err(e) = self.handle_cleanup_event() {
tracing::error!("fail to handle cleanup event: {}", e);
}
},
recv(self.stop_rx) -> _ => {
// Shutdown: first flush the stripes, then empty the channel without blocking.
stripe.drain_all_stripes_inline(|batch| {
for item in batch {
let _ = self.handle_item(item);
}
});
while let Ok(batch) = self.insert_buf_rx.try_recv() {
for item in batch {
match item {
// Markers still queued in the channel are only acknowledged, so
// their waiters are released; they do not go through `handle_item`.
Item::Wait(wg) | Item::Clear(wg) => wg.done(),
Item::New { .. } | Item::Update { .. } | Item::Delete { .. } => {
let _ = self.handle_item(item);
}
}
}
}
return Ok(());
},
}
}
})
}
/// Runs one cleanup pass on the store and, for every victim it returns,
/// calls `prepare_evict` followed by the `on_evict` callback.
///
/// # Errors
///
/// Propagates the error from `store.try_cleanup_async`.
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) fn handle_cleanup_event(&mut self) -> Result<(), CacheError> {
self
.store
.try_cleanup_async(self.policy.clone())?
.into_iter()
.for_each(|victim| {
self.prepare_evict(&victim);
self.callback.on_evict(victim);
});
Ok(())
}
}
// Both macros are defined elsewhere in the crate.
// NOTE(review): `impl_builder!` presumably generates the builder's
// configuration setters, and `impl_cache_processor!` presumably provides
// `handle_item`, `handle_insert_batch` and `prepare_evict`, which are called
// on `CacheProcessor` in this file but not defined here — confirm against
// the macro definitions.
impl_builder!(AsyncCacheBuilder);
impl_cache_processor!(CacheProcessor, Item);
impl<K, V, R, KH, C, U, CB, S> AsyncCache<K, V, R, KH, C, U, CB, S>
where
K: Hash + Eq,
V: Send + Sync + 'static,
KH: KeyBuilder<Key = K>,
C: Coster<Value = V>,
U: UpdateValidator<Value = V>,
CB: CacheCallback<Value = V>,
S: BuildHasher + Clone + 'static + Send,
{
pub async fn get<Q>(&self, key: &Q) -> Option<ValueRef<'_, V>>
where
K: core::borrow::Borrow<Q>,
Q: core::hash::Hash + Eq + ?Sized,
{
let (index, conflict) = self.0.key_to_hash.build_key(key);
self.0.get_buf.push(index);
match self.0.store.get(&index, conflict) {
None => {
self.0.metrics.add(MetricType::Miss, index, 1);
None
}
Some(v) => {
self.0.metrics.add(MetricType::Hit, index, 1);
Some(v)
}
}
}
pub async fn get_mut<Q>(&self, key: &Q) -> Option<ValueRefMut<'_, V>>
where
K: core::borrow::Borrow<Q>,
Q: core::hash::Hash + Eq + ?Sized,
{
let (index, conflict) = self.0.key_to_hash.build_key(key);
self.0.get_buf.push(index);
match self.0.store.get_mut(&index, conflict) {
None => {
self.0.metrics.add(MetricType::Miss, index, 1);
None
}
Some(v) => {
self.0.metrics.add(MetricType::Hit, index, 1);
Some(v)
}
}
}
pub fn get_ttl<Q>(&self, key: &Q) -> Option<Duration>
where
K: core::borrow::Borrow<Q>,
Q: core::hash::Hash + Eq + ?Sized,
{
let (index, conflict) = self.0.key_to_hash.build_key(key);
self
.0
.store
.get(&index, conflict)
.and_then(|_| self.0.store.expiration(&index).map(|time| time.get_ttl()))
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn max_cost(&self) -> i64 {
self.0.policy.max_cost()
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn update_max_cost(&self, max_cost: i64) {
self.0.policy.update_max_cost(max_cost)
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn len(&self) -> usize {
self.0.store.len()
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn is_empty(&self) -> bool {
self.0.store.len() == 0
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn try_update(
&self,
key: K,
val: V,
cost: i64,
ttl: Duration,
only_update: bool,
) -> Result<Option<(u64, Item<V>, Option<V>)>, CacheError> {
let expiration = if ttl.is_zero() {
Time::now()
} else {
Time::now_with_expiration(ttl)
};
let (index, conflict) = self.0.key_to_hash.build_key(&key);
let external_cost = if cost == 0 {
self.0.coster.cost(&val)
} else {
0
};
let captured_gen = self
.0
.clear_generation
.load(std::sync::atomic::Ordering::Acquire);
match self
.0
.store
.try_update(index, val, conflict, expiration, captured_gen)?
{
UpdateResult::NotExist(v) => {
if only_update {
Ok(None)
} else {
match self
.0
.store
.try_insert(index, v, conflict, expiration, captured_gen)?
{
Some(version) => Ok(Some((
index,
Item::new(
index,
conflict,
cost + external_cost,
expiration,
version,
captured_gen,
),
None,
))),
None => Ok(None),
}
}
}
UpdateResult::Reject(_) | UpdateResult::Conflict(_) => Ok(None),
UpdateResult::Stale(_) => Ok(None),
UpdateResult::Update(v, version) => {
Ok(Some((
index,
Item::update(
index,
conflict,
cost,
external_cost,
expiration,
version,
captured_gen,
),
Some(v),
)))
}
}
}
}
impl<K, V, R, KH, C, U, CB, S> AsRef<AsyncCache<K, V, R, KH, C, U, CB, S>>
  for AsyncCache<K, V, R, KH, C, U, CB, S>
where
  K: Hash + Eq,
  V: Send + Sync + 'static,
  KH: KeyBuilder<Key = K>,
  C: Coster<Value = V>,
  U: UpdateValidator<Value = V>,
  CB: CacheCallback<Value = V>,
  S: BuildHasher + Clone + 'static,
{
  /// Identity conversion, so APIs bounded by `AsRef<AsyncCache<..>>` accept
  /// the cache itself.
  fn as_ref(&self) -> &Self {
    self
  }
}