use crate::{
CacheCallback, CacheError, Coster, DefaultCacheCallback, DefaultCoster, DefaultKeyBuilder,
DefaultUpdateValidator, Item as CrateItem, KeyBuilder, Metrics, UpdateValidator,
cache::builder::CacheBuilderCore,
metrics::MetricType,
policy::{AddOutcome, LFUPolicy},
ring::RingStripe,
store::ShardedMap,
sync::{Instant, JoinHandle, Receiver, Sender, Signal, WaitGroup, select, spawn, stop_channel},
ttl::{ExpirationMap, Time},
};
use crossbeam_channel::{RecvError, tick};
use std::{
cell::Cell,
collections::{HashMap, hash_map::RandomState},
hash::{BuildHasher, Hash},
marker::PhantomData,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
thread_local! {
    // Per-thread flag, `false` by default. The closure spawned by
    // `CacheProcessor::spawn` flips it to `true` on its own thread.
    static ON_PROCESSOR_THREAD: Cell<bool> = const { Cell::new(false) };
}

/// Returns `true` when the calling thread has set `ON_PROCESSOR_THREAD`,
/// i.e. when it is a cache processor thread.
#[inline]
fn on_processor_thread() -> bool {
    ON_PROCESSOR_THREAD.with(Cell::get)
}
/// Builder for [`Cache`].
///
/// A thin wrapper around `CacheBuilderCore`; `finalize` consumes the builder,
/// validates the configuration and starts the background processor.
///
/// Type parameters (bounds are enforced on the impl providing `finalize`):
/// * `K` – key type (`Hash + Eq`).
/// * `V` – value type (`Send + Sync + 'static`).
/// * `KH` – `KeyBuilder` turning a key into its hash pair.
/// * `C` – `Coster` used to compute a value's cost.
/// * `U` – `UpdateValidator` handed to the store.
/// * `CB` – `CacheCallback` receiving exit / reject / evict notifications.
/// * `S` – `BuildHasher` shared by the store, expiration map and policy.
pub struct CacheBuilder<
K,
V,
KH = DefaultKeyBuilder<K>,
C = DefaultCoster<V>,
U = DefaultUpdateValidator<V>,
CB = DefaultCacheCallback<V>,
S = RandomState,
> {
// All configuration lives in the shared core.
inner: CacheBuilderCore<K, V, KH, C, U, CB, S>,
}
impl<K: Hash + Eq, V: Send + Sync + 'static> CacheBuilder<K, V> {
    /// Creates a builder with every pluggable component left at its default
    /// type, delegating to `CacheBuilderCore::new`.
    ///
    /// `finalize` rejects a `num_counters` or `max_cost` of zero.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn new(num_counters: usize, max_cost: i64) -> Self {
        let inner = CacheBuilderCore::new(num_counters, max_cost);
        Self { inner }
    }
}
impl<K: Hash + Eq, V: Send + Sync + 'static, KH: KeyBuilder<Key = K>> CacheBuilder<K, V, KH> {
    /// Creates a builder that hashes keys with the supplied `kh`, delegating
    /// to `CacheBuilderCore::new_with_key_builder`.
    ///
    /// `finalize` rejects a `num_counters` or `max_cost` of zero.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn new_with_key_builder(num_counters: usize, max_cost: i64, kh: KH) -> Self {
        let inner = CacheBuilderCore::new_with_key_builder(num_counters, max_cost, kh);
        Self { inner }
    }
}
impl<K, V, KH, C, U, CB, S> CacheBuilder<K, V, KH, C, U, CB, S>
where
K: Hash + Eq,
V: Send + Sync + 'static,
KH: KeyBuilder<Key = K>,
C: Coster<Value = V>,
U: UpdateValidator<Value = V>,
CB: CacheCallback<Value = V>,
S: BuildHasher + Send + Clone + 'static + Sync,
{
/// Validates the configuration, wires up every component and spawns the
/// background processor, returning a ready-to-use [`Cache`].
///
/// # Errors
///
/// * `CacheError::InvalidNumCounters` if `num_counters` is `0`.
/// * `CacheError::InvalidMaxCost` if `max_cost` is `0`.
/// * `CacheError::InvalidBufferSize` if `insert_stripe_high_water` is `0`.
/// * Any error returned by `LFUPolicy::with_hasher`.
///
/// # Panics
///
/// Panics if the core's `hasher`, `update_validator`, `coster` or `callback`
/// is `None` (each is unwrapped).
/// NOTE(review): presumably the builder core always populates them — confirm.
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn finalize(self) -> Result<Cache<K, V, KH, C, U, CB, S>, CacheError> {
// Reject configurations that cannot work before allocating anything.
let num_counters = self.inner.num_counters;
if num_counters == 0 {
return Err(CacheError::InvalidNumCounters);
}
let max_cost = self.inner.max_cost;
if max_cost == 0 {
return Err(CacheError::InvalidMaxCost);
}
let high_water = self.inner.insert_stripe_high_water;
if high_water == 0 {
return Err(CacheError::InvalidBufferSize);
}
// Write path: the ring stays with the cache handle (and is shared with the
// processor); `buf_rx` is the receiving end consumed by the processor.
let (insert_buf, buf_rx) =
crate::cache::insert_stripe::InsertStripeRing::<Item<V>>::new(high_water);
let insert_buf = Arc::new(insert_buf);
// Stop channel: the sender is kept in `CacheInner`, receivers go to the
// policy and the processor.
let (stop_tx, stop_rx) = stop_channel();
// One hasher instance is cloned into the expiration map, the store and the policy.
let hasher = self.inner.hasher.unwrap();
let expiration_map = ExpirationMap::with_hasher(hasher.clone());
let store = Arc::new(ShardedMap::with_validator_and_hasher(
expiration_map,
self.inner.update_validator.unwrap(),
hasher.clone(),
));
let buffer_items = self.inner.buffer_items;
let mut policy = LFUPolicy::with_hasher(num_counters, max_cost, hasher, stop_rx.clone())?;
let coster = Arc::new(self.inner.coster.unwrap());
let callback = Arc::new(self.inner.callback.unwrap());
// With metrics enabled the same `Metrics` instance is also registered with
// the policy; otherwise a `Metrics::new()` instance is used.
let metrics = if self.inner.metrics {
let m = Arc::new(Metrics::new_op());
policy.collect_metrics(m.clone());
m
} else {
Arc::new(Metrics::new())
};
let policy = Arc::new(policy);
// Generation counter shared between the cache handle and the processor.
let clear_generation = Arc::new(AtomicU64::new(0));
// The first argument is the processor's `num_to_keep`.
let processor = CacheProcessor::new(
100000,
self.inner.ignore_internal_cost,
self.inner.drain_interval,
self.inner.cleanup_duration,
store.clone(),
policy.clone(),
buf_rx,
insert_buf.clone(),
stop_rx,
metrics.clone(),
callback.clone(),
clear_generation.clone(),
)
.spawn();
// Read path: a ring stripe over the policy that `get`/`get_mut` push key hashes into.
let get_buf = RingStripe::new(policy.clone(), buffer_items);
let inner = CacheInner {
store,
policy,
get_buf: Arc::new(get_buf),
insert_buf,
callback,
key_to_hash: Arc::new(self.inner.key_to_hash),
stop_tx: Some(stop_tx),
coster,
metrics,
clear_generation,
processor: Some(processor),
_marker: Default::default(),
};
Ok(Cache(Arc::new(inner)))
}
}
/// Message placed on the insert buffer for the background processor.
///
/// `key` / `conflict` are the two hashes produced by the key builder,
/// `version` is the entry version reported by the store, and `generation` is
/// the value of the cache's `clear_generation` counter captured when the item
/// was created.
pub(crate) enum Item<V> {
/// A value was newly inserted into the store (built via [`Item::new`]).
New {
key: u64,
conflict: u64,
// Total cost: caller-supplied cost plus the coster-derived cost.
cost: i64,
expiration: Time,
version: u64,
generation: u64,
// Ties the message to `V` without storing a value.
_marker: std::marker::PhantomData<fn() -> V>,
},
/// An existing store entry was overwritten (built via [`Item::update`]).
Update {
key: u64,
conflict: u64,
// Caller-supplied cost and coster-derived cost are kept separate here.
cost: i64,
external_cost: i64,
// NOTE(review): not read anywhere in the visible code, hence the
// `allow`; confirm whether it is still needed.
#[allow(dead_code)]
expiration: Time,
version: u64,
generation: u64,
},
/// An entry was removed from the store by `try_remove` (built via [`Item::delete`]).
Delete {
key: u64,
conflict: u64,
generation: u64,
version: u64,
},
/// Marker enqueued by `Cache::wait`; the caller blocks on the signal's wait group.
Wait(Signal),
/// Marker enqueued by `Cache::clear`; the caller blocks on the signal's wait group.
Clear(Signal),
}
impl<V> Item<V> {
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) fn new(
key: u64,
conflict: u64,
cost: i64,
exp: Time,
version: u64,
generation: u64,
) -> Self {
Self::New {
key,
conflict,
cost,
expiration: exp,
version,
generation,
_marker: std::marker::PhantomData,
}
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) fn update(
key: u64,
conflict: u64,
cost: i64,
external_cost: i64,
expiration: Time,
version: u64,
generation: u64,
) -> Self {
Self::Update {
key,
conflict,
cost,
external_cost,
expiration,
version,
generation,
}
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) fn delete(key: u64, conflict: u64, generation: u64, version: u64) -> Self {
Self::Delete {
key,
conflict,
generation,
version,
}
}
}
/// State owned by the background processor thread started by `spawn`.
pub(crate) struct CacheProcessor<V, U, CB, S> {
/// Receives batches of [`Item`]s from the insert buffer.
pub(crate) insert_buf_rx: Receiver<Vec<Item<V>>>,
/// Stop channel; the processor loop returns once a receive on it completes.
pub(crate) stop_rx: Receiver<()>,
pub(crate) metrics: Arc<Metrics>,
pub(crate) store: Arc<ShardedMap<V, U, S, S>>,
pub(crate) policy: Arc<LFUPolicy<S>>,
/// NOTE(review): not touched in the visible code; presumably per-key
/// timestamps maintained by the macro-generated handlers — confirm.
pub(crate) start_ts: HashMap<u64, Time, S>,
/// NOTE(review): presumably bounds `start_ts`; confirm against the handlers.
pub(crate) num_to_keep: usize,
pub(crate) callback: Arc<CB>,
pub(crate) ignore_internal_cost: bool,
/// Captured from `store.item_size()` at construction.
pub(crate) item_size: usize,
/// Period of the ticker that drains the insert stripes.
pub(crate) drain_interval: Duration,
/// Period of the ticker that triggers `handle_cleanup_event`.
pub(crate) cleanup_duration: Duration,
/// Insert stripe ring shared with the cache handle.
pub(crate) insert_stripe: Arc<crate::cache::insert_stripe::InsertStripeRing<Item<V>>>,
/// Generation counter shared with the cache handle.
pub(crate) clear_generation: Arc<AtomicU64>,
}
/// Handle to a cache backed by a background processor thread.
///
/// The handle is a reference-counted pointer to `CacheInner`; cloning it
/// yields another handle to the same cache. The type parameters mirror those
/// of [`CacheBuilder`].
pub struct Cache<
K,
V,
KH = DefaultKeyBuilder<K>,
C = DefaultCoster<V>,
U = DefaultUpdateValidator<V>,
CB = DefaultCacheCallback<V>,
S = RandomState,
>(pub(crate) Arc<CacheInner<K, V, KH, C, U, CB, S>>);
impl<K, V, KH, C, U, CB, S> Clone for Cache<K, V, KH, C, U, CB, S> {
#[cfg_attr(not(tarpaulin), inline(always))]
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Shared state behind every [`Cache`] handle.
///
/// Dropping it releases the stop sender and joins the processor thread.
pub(crate) struct CacheInner<
K,
V,
KH = DefaultKeyBuilder<K>,
C = DefaultCoster<V>,
U = DefaultUpdateValidator<V>,
CB = DefaultCacheCallback<V>,
S = RandomState,
> {
/// Value store, shared with the processor.
pub(crate) store: Arc<ShardedMap<V, U, S, S>>,
/// Admission / eviction policy, shared with the processor.
pub(crate) policy: Arc<LFUPolicy<S>>,
/// Ring stripe over the policy; `get` / `get_mut` push key hashes into it.
pub(crate) get_buf: Arc<RingStripe<S>>,
/// Insert stripe ring through which write-side [`Item`]s reach the processor.
pub(crate) insert_buf: Arc<crate::cache::insert_stripe::InsertStripeRing<Item<V>>>,
/// Sender half of the stop channel; taken and dropped in `Drop`.
pub(crate) stop_tx: Option<Sender<()>>,
pub(crate) callback: Arc<CB>,
/// Key builder producing the `(index, conflict)` hash pair.
pub(crate) key_to_hash: Arc<KH>,
/// Consulted for a value's cost when the caller passes a cost of `0`.
pub(crate) coster: Arc<C>,
pub(crate) metrics: Arc<Metrics>,
/// Generation counter captured into every enqueued item.
pub(crate) clear_generation: Arc<AtomicU64>,
/// Join handle of the processor thread; taken in `Drop`.
pub(crate) processor: Option<JoinHandle<Result<(), CacheError>>>,
pub(crate) _marker: PhantomData<fn(K)>,
}
impl<K, V, KH, C, U, CB, S> Drop for CacheInner<K, V, KH, C, U, CB, S> {
    /// Releases the stop sender and waits for the processor thread to finish.
    ///
    /// When the last handle is dropped on the processor thread itself, the
    /// join is skipped (a thread cannot join itself) and the handle is simply
    /// released.
    fn drop(&mut self) {
        // Release the sender first so the processor's `recv(self.stop_rx)` arm can fire.
        drop(self.stop_tx.take());

        let Some(worker) = self.processor.take() else {
            return;
        };
        let dropping_on_worker = worker.thread().id() == std::thread::current().id();
        if !dropping_on_worker {
            // The processor's own result is deliberately ignored during teardown.
            let _ = worker.join();
        }
    }
}
impl<K: Hash + Eq, V: Send + Sync + 'static> Cache<K, V> {
    /// Builds a cache with default components.
    ///
    /// # Errors
    ///
    /// Returns whatever `CacheBuilder::finalize` returns, e.g. when
    /// `num_counters` or `max_cost` is zero.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn new(num_counters: usize, max_cost: i64) -> Result<Self, CacheError> {
        Self::builder(num_counters, max_cost).finalize()
    }

    /// Returns a [`CacheBuilder`] with default components for further
    /// configuration before calling `finalize`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn builder(
        num_counters: usize,
        max_cost: i64,
    ) -> CacheBuilder<
        K,
        V,
        DefaultKeyBuilder<K>,
        DefaultCoster<V>,
        DefaultUpdateValidator<V>,
        DefaultCacheCallback<V>,
        RandomState,
    > {
        CacheBuilder::new(num_counters, max_cost)
    }
}
impl<K: Hash + Eq, V: Send + Sync + 'static, KH: KeyBuilder<Key = K>> Cache<K, V, KH> {
    /// Builds a cache that hashes keys with `index`, leaving every other
    /// component at its default.
    ///
    /// # Errors
    ///
    /// Returns whatever `CacheBuilder::finalize` returns.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn new_with_key_builder(
        num_counters: usize,
        max_cost: i64,
        index: KH,
    ) -> Result<Self, CacheError> {
        let builder = CacheBuilder::new_with_key_builder(num_counters, max_cost, index);
        builder.finalize()
    }
}
impl<K, V, KH, C, U, CB, S> Cache<K, V, KH, C, U, CB, S>
where
K: Hash + Eq,
V: Send + Sync + 'static,
KH: KeyBuilder<Key = K>,
C: Coster<Value = V>,
U: UpdateValidator<Value = V>,
CB: CacheCallback<Value = V>,
S: BuildHasher + Clone + 'static + Send + Sync,
{
/// Enqueues a `Clear` marker behind all pending writes and blocks until the
/// processor releases it.
///
/// NOTE(review): the processor's handling of `Item::Clear` is not visible
/// here; presumably it empties the cache before signalling — confirm.
///
/// # Errors
///
/// * `CacheError::ChannelError` when called on the processor thread (for
///   example from a `CacheCallback`), because the wait could never finish.
/// * `CacheError::SendError` when flushing the stripes or enqueuing the
///   marker fails.
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn clear(&self) -> Result<(), CacheError> {
    if on_processor_thread() {
        return Err(CacheError::ChannelError(
            "clear() cannot be called from a CacheCallback: \
             it would wait on the processor thread that is currently running \
             the callback"
                .to_string(),
        ));
    }

    let wg = WaitGroup::new();
    let marker = Item::Clear(Signal::new(wg.add(1)));

    // Flush buffered writes first so the marker is queued behind them.
    self.0
        .insert_buf
        .drain_all_stripes_to_channel()
        .map_err(|_| {
            CacheError::SendError(
                "cache insert buffer: channel closed during clear prelude".to_string(),
            )
        })?;

    if self.0.insert_buf.send_single(marker).is_err() {
        return Err(CacheError::SendError(
            "fail to enqueue clear marker: channel closed".to_string(),
        ));
    }
    wg.wait();
    Ok(())
}
/// Inserts `val` under `key` with the given `cost` and no TTL.
///
/// Returns the boolean produced by [`Self::try_insert`].
///
/// # Panics
///
/// Panics if [`Self::try_insert`] returns an error.
pub fn insert(&self, key: K, val: V, cost: i64) -> bool {
self.try_insert(key, val, cost).unwrap()
}
/// Fallible form of `insert`: forwards to [`Self::try_insert_with_ttl`] with
/// a zero TTL.
pub fn try_insert(&self, key: K, val: V, cost: i64) -> Result<bool, CacheError> {
self.try_insert_with_ttl(key, val, cost, Duration::ZERO)
}
/// Inserts `val` under `key` with the given `cost` and `ttl`.
///
/// Returns the boolean produced by [`Self::try_insert_with_ttl`].
///
/// # Panics
///
/// Panics if [`Self::try_insert_with_ttl`] returns an error.
pub fn insert_with_ttl(&self, key: K, val: V, cost: i64, ttl: Duration) -> bool {
self.try_insert_with_ttl(key, val, cost, ttl).unwrap()
}
/// Fallible insert with a TTL: forwards to `try_insert_in` with
/// `only_update = false`, so the value is inserted if absent and updated if
/// present.
///
/// Returns `Ok(false)` when the store declines the write or, for a brand-new
/// key, when the insert buffer drops the batch.
pub fn try_insert_with_ttl(
&self,
key: K,
val: V,
cost: i64,
ttl: Duration,
) -> Result<bool, CacheError> {
self.try_insert_in(key, val, cost, ttl, false)
}
/// Replaces the value under `key` only if the key is already present.
///
/// Returns the boolean produced by [`Self::try_insert_if_present`].
///
/// # Panics
///
/// Panics if [`Self::try_insert_if_present`] returns an error.
pub fn insert_if_present(&self, key: K, val: V, cost: i64) -> bool {
self.try_insert_if_present(key, val, cost).unwrap()
}
/// Fallible update-only write: forwards to `try_insert_in` with a zero TTL
/// and `only_update = true`, so an absent key yields `Ok(false)`.
pub fn try_insert_if_present(&self, key: K, val: V, cost: i64) -> Result<bool, CacheError> {
self.try_insert_in(key, val, cost, Duration::ZERO, true)
}
/// Blocks until the processor has reached a `Wait` marker enqueued behind
/// all writes that were pending when this call started.
///
/// # Errors
///
/// * `CacheError::ChannelError` when called on the processor thread (for
///   example from a `CacheCallback`), because the wait could never finish.
/// * `CacheError::SendError` when flushing the stripes or enqueuing the
///   marker fails.
pub fn wait(&self) -> Result<(), CacheError> {
    if on_processor_thread() {
        return Err(CacheError::ChannelError(
            "wait() cannot be called from a CacheCallback: \
             it would wait on the processor thread that is currently running \
             the callback"
                .to_string(),
        ));
    }

    let wg = WaitGroup::new();
    let marker = Item::Wait(Signal::new(wg.add(1)));

    // Flush buffered writes first so the marker is queued behind them.
    self.0
        .insert_buf
        .drain_all_stripes_to_channel()
        .map_err(|_| {
            CacheError::SendError(
                "cache insert buffer: channel closed during wait prelude".to_string(),
            )
        })?;

    if self.0.insert_buf.send_single(marker).is_err() {
        return Err(CacheError::SendError(
            "cache insert buffer: channel closed".to_string(),
        ));
    }
    wg.wait();
    Ok(())
}
/// Removes the entry stored under `k`, if any.
///
/// # Panics
///
/// Panics if [`Self::try_remove`] returns an error.
pub fn remove(&self, k: &K) {
self.try_remove(k).unwrap();
}
/// Removes the entry stored under `k`, if any, and hands the removed value
/// to `CacheCallback::on_exit`.
///
/// On the processor thread the policy is updated inline; otherwise a
/// `Delete` item is sent to the processor.
///
/// # Errors
///
/// * Any error from the store's `try_remove_if_not_stale`.
/// * `CacheError::ChannelError` when the `Delete` item cannot be sent. The
///   store entry is already gone and `on_exit` has already run by then.
pub fn try_remove(&self, k: &K) -> Result<(), CacheError> {
    let (index, conflict) = self.0.key_to_hash.build_key(k);
    let inline_policy_update = on_processor_thread();

    let captured_gen = self.0.clear_generation.load(Ordering::Acquire);
    let Some(prev) = self
        .0
        .store
        .try_remove_if_not_stale(&index, conflict, captured_gen)?
    else {
        return Ok(());
    };

    if inline_policy_update {
        // Drop the policy entry only if the generation is unchanged and the
        // key has not been re-inserted in the meantime.
        let current_gen = self.0.clear_generation.load(Ordering::Acquire);
        if captured_gen == current_gen && !self.0.store.contains_key(&index, 0) {
            self.0.policy.remove(&index);
        }
        self.0.callback.on_exit(Some(prev.value));
        return Ok(());
    }

    let delete = Item::delete(index, conflict, captured_gen, prev.version);
    let sent = self.0.insert_buf.send_single(delete);
    // The callback runs whether or not the send succeeded.
    self.0.callback.on_exit(Some(prev.value));
    sent.map_err(|_| {
        CacheError::ChannelError(
            "failed to send delete to insert buffer: channel closed".to_string(),
        )
    })
}
/// Shared implementation of the insert family.
///
/// Writes the value to the store via `try_update`, then pushes the resulting
/// item onto the insert buffer. When the buffer drops a batch, every dropped
/// `New` item is rolled back from the store (by version) and reported through
/// `on_reject`; dropped updates are only counted.
///
/// Returns `Ok(false)` when the store declined the write, `Ok(true)` when the
/// item was buffered or sent, and, for a dropped batch, `true` only if the
/// pushed item was an update. The replaced value, if any, is passed to
/// `on_exit` in every case where the store accepted the write.
#[cfg_attr(not(tarpaulin), inline(always))]
fn try_insert_in(
    &self,
    key: K,
    val: V,
    cost: i64,
    ttl: Duration,
    only_update: bool,
) -> Result<bool, CacheError> {
    use crate::cache::insert_stripe::PushOutcome;

    let Some((_index, item, prev_val)) = self.try_update(key, val, cost, ttl, only_update)?
    else {
        return Ok(false);
    };
    let is_update = matches!(item, Item::Update { .. });

    let accepted = match self.0.insert_buf.push(item) {
        PushOutcome::Buffered | PushOutcome::Sent => true,
        PushOutcome::Dropped(batch) => {
            for dropped in batch {
                match dropped {
                    Item::New {
                        key,
                        conflict,
                        cost,
                        expiration,
                        version,
                        ..
                    } => {
                        self.0.metrics.add(MetricType::DropSets, key, 1);
                        // Roll back only the exact version this item stored.
                        if let Ok(Some(sitem)) =
                            self.0.store.try_remove_if_version(&key, conflict, version)
                        {
                            if !self.0.store.contains_key(&key, 0) {
                                self.0.policy.remove(&key);
                            }
                            self.0.callback.on_reject(CrateItem {
                                val: Some(sitem.value),
                                index: key,
                                conflict,
                                cost,
                                exp: expiration,
                            });
                        }
                    }
                    Item::Update { key, .. } => {
                        self.0.metrics.add(MetricType::DropSets, key, 1);
                    }
                    Item::Delete { .. } | Item::Wait(_) | Item::Clear(_) => {}
                }
            }
            is_update
        }
    };

    if let Some(v) = prev_val {
        self.0.callback.on_exit(Some(v));
    }
    Ok(accepted)
}
}
impl<V, U, CB, S> CacheProcessor<V, U, CB, S>
where
V: Send + Sync + 'static,
U: UpdateValidator<Value = V>,
CB: CacheCallback<Value = V>,
S: BuildHasher + Clone + 'static + Send + Sync,
{
/// Assembles a processor from its collaborators.
///
/// `item_size` and the hasher for `start_ts` are read from `store` before it
/// is moved into the struct.
pub(crate) fn new(
    num_to_keep: usize,
    ignore_internal_cost: bool,
    drain_interval: Duration,
    cleanup_duration: Duration,
    store: Arc<ShardedMap<V, U, S, S>>,
    policy: Arc<LFUPolicy<S>>,
    insert_buf_rx: Receiver<Vec<Item<V>>>,
    insert_stripe: Arc<crate::cache::insert_stripe::InsertStripeRing<Item<V>>>,
    stop_rx: Receiver<()>,
    metrics: Arc<Metrics>,
    callback: Arc<CB>,
    clear_generation: Arc<AtomicU64>,
) -> Self {
    Self {
        // These two initialisers borrow `store`, so they must precede the
        // field that moves it.
        item_size: store.item_size(),
        start_ts: HashMap::with_hasher(store.hasher()),
        insert_buf_rx,
        stop_rx,
        metrics,
        store,
        policy,
        num_to_keep,
        callback,
        ignore_internal_cost,
        drain_interval,
        cleanup_duration,
        insert_stripe,
        clear_generation,
    }
}
/// Moves the processor onto a new thread and runs its event loop there.
///
/// The loop multiplexes four sources: insert batches, a periodic stripe
/// drain, a periodic cleanup, and the stop channel. The thread returns
/// `Ok(())` when the insert receiver disconnects or a receive on the stop
/// channel completes.
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) fn spawn(mut self) -> JoinHandle<Result<(), CacheError>> {
let drain_ticker = tick(self.drain_interval);
let cleanup_ticker = tick(self.cleanup_duration);
// Separate handle to the ring so the drain closures can borrow `self` mutably.
let stripe = self.insert_stripe.clone();
spawn(move || {
// Mark this thread so `clear` / `wait` / `try_remove` can detect re-entrancy.
ON_PROCESSOR_THREAD.with(|c| c.set(true));
loop {
select! {
recv(self.insert_buf_rx) -> res => {
match res {
Ok(batch) => {
if let Err(e) = self.handle_insert_batch(batch) {
tracing::error!("fail to handle insert batch: {}", e);
}
}
Err(e) => {
// Receiver disconnected: process what is still buffered in the
// stripes (per-item errors are ignored), then exit.
stripe.drain_all_stripes_inline(|batch| {
for item in batch {
let _ = self.handle_item(item);
}
});
tracing::debug!("insert receiver disconnected: {}", e);
return Ok(());
}
}
},
recv(drain_ticker) -> _ => {
// Periodic flush of the stripes; per-item errors are ignored.
stripe.drain_all_stripes_inline(|batch| {
for item in batch {
let _ = self.handle_item(item);
}
});
},
recv(cleanup_ticker) -> msg => {
if let Err(e) = self.handle_cleanup_event(msg) {
tracing::error!("fail to handle cleanup event: {}", e);
}
},
recv(self.stop_rx) -> _ => {
// Shutdown: stripe contents are processed normally first.
stripe.drain_all_stripes_inline(|batch| {
for item in batch {
let _ = self.handle_item(item);
}
});
// Batches still queued on the channel are not processed: waiters
// are released and the matching store entries are removed by version.
while let Ok(batch) = self.insert_buf_rx.try_recv() {
for item in batch {
match item {
Item::Wait(wg) | Item::Clear(wg) => wg.done(),
Item::New { key, conflict, version, .. }
| Item::Update { key, conflict, version, .. } => {
let _ = self.store.try_remove_if_version(&key, conflict, version);
}
Item::Delete { .. } => {}
}
}
}
return Ok(());
},
}
}
})
}
/// Handles one tick of the cleanup ticker.
///
/// Asks the store to clean up, then for every returned victim calls
/// `prepare_evict` followed by `CacheCallback::on_evict`.
///
/// # Errors
///
/// * `CacheError::RecvError` when the ticker receive itself failed; no
///   cleanup is attempted in that case.
/// * Any error returned by the store's `try_cleanup`.
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) fn handle_cleanup_event(
    &mut self,
    res: Result<Instant, RecvError>,
) -> Result<(), CacheError> {
    let _ = res
        .map_err(|e| CacheError::RecvError(format!("fail to receive msg from ticker: {}", e)))?;

    let victims = self.store.try_cleanup(self.policy.clone())?;
    for victim in victims {
        self.prepare_evict(&victim);
        self.callback.on_evict(victim);
    }
    Ok(())
}
}
// NOTE(review): presumably expands to the configuration setters on
// `CacheBuilder` (the fields read by `finalize`); the macro is defined
// elsewhere — confirm.
impl_builder!(CacheBuilder);
// NOTE(review): `handle_insert_batch`, `handle_item` and `prepare_evict` are
// called on `CacheProcessor` in this file but not defined here, so they
// presumably come from this macro — confirm.
impl_cache_processor!(CacheProcessor, Item);
use crate::{ValueRef, ValueRefMut, store::UpdateResult};
impl<K, V, KH, C, U, CB, S> Cache<K, V, KH, C, U, CB, S>
where
K: Hash + Eq,
V: Send + Sync + 'static,
KH: KeyBuilder<Key = K>,
C: Coster<Value = V>,
U: UpdateValidator<Value = V>,
CB: CacheCallback<Value = V>,
S: BuildHasher + Clone + 'static + Send,
{
/// Looks up `key` and returns a shared reference guard to its value.
///
/// Every call pushes the key hash onto the read buffer and records a `Hit`
/// or `Miss` metric, whether or not the key is found.
pub fn get<Q>(&self, key: &Q) -> Option<ValueRef<'_, V>>
where
    K: core::borrow::Borrow<Q>,
    Q: core::hash::Hash + Eq + ?Sized,
{
    let (index, conflict) = self.0.key_to_hash.build_key(key);
    self.0.get_buf.push(index);

    let found = self.0.store.get(&index, conflict);
    let outcome = if found.is_some() {
        MetricType::Hit
    } else {
        MetricType::Miss
    };
    self.0.metrics.add(outcome, index, 1);
    found
}
/// Looks up `key` and returns a mutable reference guard to its value.
///
/// Every call pushes the key hash onto the read buffer and records a `Hit`
/// or `Miss` metric, whether or not the key is found.
pub fn get_mut<Q>(&self, key: &Q) -> Option<ValueRefMut<'_, V>>
where
    K: core::borrow::Borrow<Q>,
    Q: core::hash::Hash + Eq + ?Sized,
{
    let (index, conflict) = self.0.key_to_hash.build_key(key);
    self.0.get_buf.push(index);

    let found = self.0.store.get_mut(&index, conflict);
    let outcome = if found.is_some() {
        MetricType::Hit
    } else {
        MetricType::Miss
    };
    self.0.metrics.add(outcome, index, 1);
    found
}
/// Returns the remaining TTL reported for `key`, or `None` when the key is
/// absent or the store has no expiration recorded for it.
///
/// Unlike `get`, this neither touches the read buffer nor records metrics.
pub fn get_ttl<Q>(&self, key: &Q) -> Option<Duration>
where
    K: core::borrow::Borrow<Q>,
    Q: core::hash::Hash + Eq + ?Sized,
{
    let (index, conflict) = self.0.key_to_hash.build_key(key);
    let store = &self.0.store;

    // The entry reference stays alive while the expiration is read.
    // NOTE(review): if `get` and `expiration` lock the same shard, confirm
    // that holding the reference here cannot self-block.
    let _entry = store.get(&index, conflict)?;
    store.expiration(&index).map(|time| time.get_ttl())
}
/// Returns the maximum cost currently configured on the policy.
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn max_cost(&self) -> i64 {
self.0.policy.max_cost()
}
/// Forwards a new maximum cost to the policy.
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn update_max_cost(&self, max_cost: i64) {
self.0.policy.update_max_cost(max_cost)
}
/// Returns the number of entries reported by the store.
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn len(&self) -> usize {
self.0.store.len()
}
/// Returns `true` when the store reports no entries.
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Applies a write to the store and describes it as an [`Item`] for the
/// processor.
///
/// The store is first asked to update an existing entry. If the key does not
/// exist and `only_update` is `false`, the value is inserted instead.
///
/// Returns `Ok(Some((index, item, previous)))` when the store accepted the
/// write, where `previous` is the replaced value for an update and `None`
/// for a fresh insert. Returns `Ok(None)` when the key is absent and
/// `only_update` is set, when the store reports `Reject`, `Conflict` or
/// `Stale`, or when the insert itself yields no version.
///
/// A `cost` of `0` makes the coster compute the value's cost. A zero `ttl`
/// uses `Time::now()` rather than `Time::now_with_expiration`.
#[cfg_attr(not(tarpaulin), inline(always))]
fn try_update(
    &self,
    key: K,
    val: V,
    cost: i64,
    ttl: Duration,
    only_update: bool,
) -> Result<Option<(u64, Item<V>, Option<V>)>, CacheError> {
    let expiration = if ttl.is_zero() {
        Time::now()
    } else {
        Time::now_with_expiration(ttl)
    };
    let (index, conflict) = self.0.key_to_hash.build_key(&key);
    let external_cost = match cost {
        0 => self.0.coster.cost(&val),
        _ => 0,
    };
    // Capture the generation before touching the store; it is passed to the
    // store and recorded in the item.
    let captured_gen = self
        .0
        .clear_generation
        .load(std::sync::atomic::Ordering::Acquire);

    let outcome = self
        .0
        .store
        .try_update(index, val, conflict, expiration, captured_gen)?;

    let queued: Option<(u64, Item<V>, Option<V>)> = match outcome {
        UpdateResult::Update(previous, version) => Some((
            index,
            Item::update(
                index,
                conflict,
                cost,
                external_cost,
                expiration,
                version,
                captured_gen,
            ),
            Some(previous),
        )),
        UpdateResult::NotExist(_) if only_update => None,
        UpdateResult::NotExist(fresh) => self
            .0
            .store
            .try_insert(index, fresh, conflict, expiration, captured_gen)?
            .map(|version| {
                (
                    index,
                    Item::new(
                        index,
                        conflict,
                        cost + external_cost,
                        expiration,
                        version,
                        captured_gen,
                    ),
                    None,
                )
            }),
        UpdateResult::Reject(_) | UpdateResult::Conflict(_) | UpdateResult::Stale(_) => None,
    };
    Ok(queued)
}
}
/// Identity `AsRef`, so APIs generic over `AsRef<Cache<..>>` accept a cache
/// handle directly.
impl<K, V, KH, C, U, CB, S> AsRef<Cache<K, V, KH, C, U, CB, S>> for Cache<K, V, KH, C, U, CB, S>
where
K: Hash + Eq,
V: Send + Sync + 'static,
KH: KeyBuilder<Key = K>,
C: Coster<Value = V>,
U: UpdateValidator<Value = V>,
CB: CacheCallback<Value = V>,
S: BuildHasher + Clone + 'static,
{
/// Returns `self`.
fn as_ref(&self) -> &Cache<K, V, KH, C, U, CB, S> {
self
}
}