use crate::cache::builder::CacheBuilderCore;
use crate::policy::LFUPolicy;
use crate::ring::RingStripe;
use crate::store::ShardedMap;
use crate::sync::{
bounded, select, spawn, stop_channel, unbounded, Instant, JoinHandle, Receiver, Sender,
UnboundedReceiver, UnboundedSender, WaitGroup,
};
use crate::ttl::{ExpirationMap, Time};
use crate::{
metrics::MetricType, CacheCallback, CacheError, Coster, DefaultCacheCallback, DefaultCoster,
DefaultKeyBuilder, DefaultUpdateValidator, KeyBuilder, Metrics, UpdateValidator,
};
use crossbeam_channel::{tick, RecvError};
use std::collections::hash_map::RandomState;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash};
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// Builder for [`Cache`], collecting sizing parameters, hashers and the
/// cost/update/eviction callbacks before construction via
/// [`CacheBuilder::finalize`].
///
/// The type parameters default to the crate's `Default*` implementations and
/// `RandomState`, so the common case is just `CacheBuilder<K, V>`.
pub struct CacheBuilder<
    K,
    V,
    KH = DefaultKeyBuilder<K>,
    C = DefaultCoster<V>,
    U = DefaultUpdateValidator<V>,
    CB = DefaultCacheCallback<V>,
    S = RandomState,
> {
    // All configuration state lives in the builder core shared with other
    // cache variants; this type only adds the typed, public construction API.
    inner: CacheBuilderCore<K, V, KH, C, U, CB, S>,
}
impl<K: Hash + Eq, V: Send + Sync + 'static> CacheBuilder<K, V> {
    /// Creates a builder with the given number of counters and maximum cost,
    /// using the default key builder, coster, update validator, callback and
    /// hasher.
    #[inline]
    pub fn new(num_counters: usize, max_cost: i64) -> Self {
        let inner = CacheBuilderCore::new(num_counters, max_cost);
        Self { inner }
    }
}
impl<K: Hash + Eq, V: Send + Sync + 'static, KH: KeyBuilder<Key = K>> CacheBuilder<K, V, KH> {
    /// Like [`CacheBuilder::new`], but with a caller-provided key builder
    /// instead of the default one.
    #[inline]
    pub fn new_with_key_builder(num_counters: usize, max_cost: i64, kh: KH) -> Self {
        let inner = CacheBuilderCore::new_with_key_builder(num_counters, max_cost, kh);
        Self { inner }
    }
}
impl<K, V, KH, C, U, CB, S> CacheBuilder<K, V, KH, C, U, CB, S>
where
    K: Hash + Eq,
    V: Send + Sync + 'static,
    KH: KeyBuilder<Key = K>,
    C: Coster<Value = V>,
    U: UpdateValidator<Value = V>,
    CB: CacheCallback<Value = V>,
    S: BuildHasher + Send + Clone + 'static + Sync,
{
    /// Consumes the builder and constructs the [`Cache`], spawning the
    /// background thread that drains the insert buffer and runs periodic
    /// TTL cleanup.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::InvalidNumCounters`], [`CacheError::InvalidMaxCost`]
    /// or [`CacheError::InvalidBufferSize`] when the corresponding parameter is
    /// zero, and propagates any error from constructing the LFU policy.
    #[inline]
    pub fn finalize(self) -> Result<Cache<K, V, KH, C, U, CB, S>, CacheError> {
        // Validate the three required sizing parameters before allocating
        // anything.
        let num_counters = self.inner.num_counters;
        if num_counters == 0 {
            return Err(CacheError::InvalidNumCounters);
        }
        let max_cost = self.inner.max_cost;
        if max_cost == 0 {
            return Err(CacheError::InvalidMaxCost);
        }
        let insert_buffer_size = self.inner.insert_buffer_size;
        if insert_buffer_size == 0 {
            return Err(CacheError::InvalidBufferSize);
        }

        // Channels connecting the cache handle to the background processor:
        // a bounded buffer for insert traffic plus stop/clear control channels.
        let (buf_tx, buf_rx) = bounded(insert_buffer_size);
        let (stop_tx, stop_rx) = stop_channel();
        let (clear_tx, clear_rx) = unbounded();

        // These Options are populated by the builder core; `expect` states the
        // invariant instead of a bare `unwrap` so a violation is diagnosable.
        let hasher = self
            .inner
            .hasher
            .expect("hasher was not initialized by the builder");
        let expiration_map = ExpirationMap::with_hasher(hasher.clone());
        let store = Arc::new(ShardedMap::with_validator_and_hasher(
            expiration_map,
            self.inner
                .update_validator
                .expect("update validator was not initialized by the builder"),
            hasher.clone(),
        ));
        let buffer_items = self.inner.buffer_items;
        let mut policy = LFUPolicy::with_hasher(num_counters, max_cost, hasher)?;
        let coster = Arc::new(
            self.inner
                .coster
                .expect("coster was not initialized by the builder"),
        );
        let callback = Arc::new(
            self.inner
                .callback
                .expect("callback was not initialized by the builder"),
        );

        // Attach a metrics recorder to the policy only when metrics were
        // requested on the builder; otherwise use the plain variant.
        let metrics = if self.inner.metrics {
            let m = Arc::new(Metrics::new_op());
            policy.collect_metrics(m.clone());
            m
        } else {
            Arc::new(Metrics::new())
        };
        let policy = Arc::new(policy);

        // Spawn the background processing thread before handing out the
        // cache handle, so the insert buffer always has a consumer.
        CacheProcessor::new(
            100000,
            self.inner.ignore_internal_cost,
            self.inner.cleanup_duration,
            store.clone(),
            policy.clone(),
            buf_rx,
            stop_rx,
            clear_rx,
            metrics.clone(),
            callback.clone(),
        )
        .spawn();

        let get_buf = RingStripe::new(policy.clone(), buffer_items);
        Ok(Cache {
            store,
            policy,
            get_buf: Arc::new(get_buf),
            insert_buf_tx: buf_tx,
            callback,
            key_to_hash: Arc::new(self.inner.key_to_hash),
            stop_tx,
            clear_tx,
            is_closed: Arc::new(AtomicBool::new(false)),
            coster,
            metrics,
            _marker: Default::default(),
        })
    }
}
/// Messages the cache handle queues on the insert buffer channel for the
/// background [`CacheProcessor`] to apply.
pub(crate) enum Item<V> {
    /// Insert a value under the hashed `key`, with a `conflict` hash for
    /// collision detection, a storage `cost`, and an `expiration` time.
    New {
        key: u64,
        conflict: u64,
        cost: i64,
        value: V,
        expiration: Time,
    },
    /// Adjust the cost of an already-stored entry.
    Update {
        key: u64,
        cost: i64,
        external_cost: i64,
    },
    /// Remove the entry identified by the `key`/`conflict` hash pair.
    Delete {
        key: u64,
        conflict: u64,
    },
    /// Synchronization marker: `Cache::wait` sends this and blocks on the
    /// wait group until the processor reaches it in the queue.
    Wait(WaitGroup),
}
impl<V> Item<V> {
    /// Builds a `New` item carrying a value together with its cost and
    /// expiration time.
    #[inline]
    fn new(key: u64, conflict: u64, cost: i64, value: V, expiration: Time) -> Self {
        Item::New {
            key,
            conflict,
            cost,
            value,
            expiration,
        }
    }

    /// Builds an `Update` item describing a cost change for an existing entry.
    #[inline]
    pub(crate) fn update(key: u64, cost: i64, external_cost: i64) -> Self {
        Item::Update {
            key,
            cost,
            external_cost,
        }
    }

    /// Builds a `Delete` item for the given key/conflict hash pair.
    #[inline]
    fn delete(key: u64, conflict: u64) -> Self {
        Item::Delete { key, conflict }
    }

    /// Returns `true` iff this item is an [`Item::Update`].
    #[inline]
    fn is_update(&self) -> bool {
        matches!(self, Item::Update { .. })
    }
}
/// State of the background worker thread: owns the receiving ends of the
/// cache's channels and applies buffered [`Item`]s to the store and policy.
pub(crate) struct CacheProcessor<V, U, CB, S> {
    /// Receives items queued by the cache handle's insert/remove/wait calls.
    pub(crate) insert_buf_rx: Receiver<Item<V>>,
    /// Receiving a message here terminates the processing loop.
    pub(crate) stop_rx: Receiver<()>,
    /// Each message triggers a drain of the insert buffer (see `CacheCleaner`).
    pub(crate) clear_rx: UnboundedReceiver<()>,
    pub(crate) metrics: Arc<Metrics>,
    pub(crate) store: Arc<ShardedMap<V, U, S, S>>,
    pub(crate) policy: Arc<LFUPolicy<S>>,
    /// Timestamps keyed by hashed key. NOTE(review): exact usage/pruning
    /// semantics live in the `impl_cache_processor!` macro — confirm there.
    pub(crate) start_ts: HashMap<u64, Time, S>,
    /// Bound related to `start_ts` retention; set to 100000 in `finalize`.
    /// NOTE(review): enforcement is in macro-generated code — confirm there.
    pub(crate) num_to_keep: usize,
    pub(crate) callback: Arc<CB>,
    pub(crate) ignore_internal_cost: bool,
    /// Per-entry internal bookkeeping size, taken from `store.item_size()`.
    pub(crate) item_size: usize,
    /// Interval between TTL cleanup passes (drives the ticker in `spawn`).
    pub(crate) cleanup_duration: Duration,
}
/// Short-lived helper that drains the insert buffer when a clear request
/// arrives; it borrows the processor mutably for the duration of the drain.
pub(crate) struct CacheCleaner<'a, V, U, CB, S> {
    pub(crate) processor: &'a mut CacheProcessor<V, U, CB, S>,
}
/// The cache handle: pairs a sharded store with an LFU admission/eviction
/// policy and communicates with a background processor thread over channels.
///
/// Construct via [`Cache::new`], [`Cache::builder`] or [`CacheBuilder`].
pub struct Cache<
    K,
    V,
    KH = DefaultKeyBuilder<K>,
    C = DefaultCoster<V>,
    U = DefaultUpdateValidator<V>,
    CB = DefaultCacheCallback<V>,
    S = RandomState,
> {
    /// Sharded key/value store holding the cached entries.
    pub(crate) store: Arc<ShardedMap<V, U, S, S>>,
    /// LFU admission/eviction policy shared with the background processor.
    pub(crate) policy: Arc<LFUPolicy<S>>,
    /// Striped ring buffer built over the policy (see `RingStripe`);
    /// presumably batches access information into it — confirm in `impl_cache!`.
    pub(crate) get_buf: Arc<RingStripe<S>>,
    /// Bounded channel carrying insert/update/delete/wait items to the
    /// background processor.
    pub(crate) insert_buf_tx: Sender<Item<V>>,
    /// Signals the background processor to shut down (used by `close`).
    pub(crate) stop_tx: Sender<()>,
    /// Signals the background processor to drain pending items (used by `clear`).
    pub(crate) clear_tx: UnboundedSender<()>,
    pub(crate) callback: Arc<CB>,
    /// Maps a user key to its (hash, conflict-hash) pair.
    pub(crate) key_to_hash: Arc<KH>,
    /// Set once `close` completes; most operations become no-ops afterwards.
    pub(crate) is_closed: Arc<AtomicBool>,
    pub(crate) coster: Arc<C>,
    /// Metrics sink; a no-op variant unless enabled on the builder.
    pub metrics: Arc<Metrics>,
    // `fn(K)` ties the handle to `K` without storing keys, and keeps
    // Send/Sync independent of `K`'s own auto traits.
    pub(crate) _marker: PhantomData<fn(K)>,
}
impl<K: Hash + Eq, V: Send + Sync + 'static> Cache<K, V> {
    /// Creates a cache with all default components; shorthand for
    /// `Cache::builder(num_counters, max_cost).finalize()`.
    #[inline]
    pub fn new(num_counters: usize, max_cost: i64) -> Result<Self, CacheError> {
        Self::builder(num_counters, max_cost).finalize()
    }

    /// Returns a [`CacheBuilder`] pre-seeded with `num_counters` and
    /// `max_cost`, for callers that want to customize further components
    /// before calling `finalize`.
    #[inline]
    pub fn builder(
        num_counters: usize,
        max_cost: i64,
    ) -> CacheBuilder<
        K,
        V,
        DefaultKeyBuilder<K>,
        DefaultCoster<V>,
        DefaultUpdateValidator<V>,
        DefaultCacheCallback<V>,
        RandomState,
    > {
        CacheBuilder::new(num_counters, max_cost)
    }
}
impl<K: Hash + Eq, V: Send + Sync + 'static, KH: KeyBuilder<Key = K>> Cache<K, V, KH> {
    /// Creates a cache that hashes keys with the caller-provided `index`
    /// key builder instead of the default one.
    #[inline]
    pub fn new_with_key_builder(
        num_counters: usize,
        max_cost: i64,
        index: KH,
    ) -> Result<Self, CacheError> {
        let builder = CacheBuilder::new_with_key_builder(num_counters, max_cost, index);
        builder.finalize()
    }
}
impl<K, V, KH, C, U, CB, S> Cache<K, V, KH, C, U, CB, S>
where
    K: Hash + Eq,
    V: Send + Sync + 'static,
    KH: KeyBuilder<Key = K>,
    C: Coster<Value = V>,
    U: UpdateValidator<Value = V>,
    CB: CacheCallback<Value = V>,
    S: BuildHasher + Clone + 'static + Send + Sync,
{
    /// Clears the cache: signals the background processor to drain the insert
    /// buffer, then resets the policy, store and metrics.
    ///
    /// Returns `Ok(())` without doing anything if the cache is already closed.
    #[inline]
    pub fn clear(&self) -> Result<(), CacheError> {
        if self.is_closed.load(Ordering::SeqCst) {
            return Ok(());
        }
        self.clear_tx.send(()).map_err(|e| {
            CacheError::SendError(format!("fail to send clear signal to working thread {}", e))
        })?;
        self.policy.clear();
        self.store.clear();
        self.metrics.clear();
        Ok(())
    }

    /// Inserts `key`/`val` with the given `cost`. Returns whether the item was
    /// accepted into the insert buffer (it may still be rejected by the policy).
    ///
    /// # Panics
    ///
    /// Panics if the underlying channel operation fails (see [`Cache::try_insert`]).
    pub fn insert(&self, key: K, val: V, cost: i64) -> bool {
        self.try_insert(key, val, cost).unwrap()
    }

    /// Fallible insert with no TTL (`Duration::ZERO` means no expiration here).
    pub fn try_insert(&self, key: K, val: V, cost: i64) -> Result<bool, CacheError> {
        self.try_insert_with_ttl(key, val, cost, Duration::ZERO)
    }

    /// Inserts with a time-to-live, panicking on channel failure.
    pub fn insert_with_ttl(&self, key: K, val: V, cost: i64, ttl: Duration) -> bool {
        self.try_insert_with_ttl(key, val, cost, ttl).unwrap()
    }

    /// Fallible insert with a time-to-live.
    pub fn try_insert_with_ttl(
        &self,
        key: K,
        val: V,
        cost: i64,
        ttl: Duration,
    ) -> Result<bool, CacheError> {
        self.try_insert_in(key, val, cost, ttl, false)
    }

    /// Updates the value only if the key is already present (the `only_update`
    /// flag; enforcement happens in the macro-generated `try_update`).
    pub fn insert_if_present(&self, key: K, val: V, cost: i64) -> bool {
        self.try_insert_if_present(key, val, cost).unwrap()
    }

    /// Fallible variant of [`Cache::insert_if_present`].
    pub fn try_insert_if_present(&self, key: K, val: V, cost: i64) -> Result<bool, CacheError> {
        self.try_insert_in(key, val, cost, Duration::ZERO, true)
    }

    /// Blocks until the background processor has handled every item queued
    /// before this call, by enqueueing an `Item::Wait` and waiting on its
    /// wait group. Non-blocking no-op on a closed cache.
    ///
    /// # Errors
    ///
    /// Fails if the wait item cannot be enqueued (buffer full or disconnected,
    /// since `try_send` does not block).
    pub fn wait(&self) -> Result<(), CacheError> {
        if self.is_closed.load(Ordering::SeqCst) {
            return Ok(());
        }
        let wg = WaitGroup::new();
        let wait_item = Item::Wait(wg.add(1));
        self.insert_buf_tx
            .try_send(wait_item)
            .map(|_| wg.wait())
            .map_err(|e| CacheError::SendError(format!("cache set buf sender: {}", e)))
    }

    /// Removes `k`, panicking on channel failure.
    pub fn remove(&self, k: &K) {
        self.try_remove(k).unwrap();
    }

    /// Removes `k` from the store immediately (firing `on_exit` for the old
    /// value), then notifies the background processor so the policy can
    /// forget the key. No-op on a closed cache.
    pub fn try_remove(&self, k: &K) -> Result<(), CacheError> {
        if self.is_closed.load(Ordering::SeqCst) {
            return Ok(());
        }
        let (index, conflict) = self.key_to_hash.build_key(k);
        // Remove from the store first so readers stop seeing the value even
        // before the processor consumes the Delete item.
        let prev = self.store.try_remove(&index, conflict)?;
        if let Some(prev) = prev {
            self.callback.on_exit(Some(prev.value.into_inner()));
        }
        self.insert_buf_tx
            .try_send(Item::delete(index, conflict))
            .map_err(|e| {
                CacheError::ChannelError(format!(
                    "failed to send message to the insert buffer: {}",
                    &e
                ))
            })?;
        Ok(())
    }

    /// Shuts the cache down: clears it, signals the processor thread to stop,
    /// closes the policy, and marks the handle closed. Idempotent — a second
    /// call returns `Ok(())` immediately.
    #[inline]
    pub fn close(&self) -> Result<(), CacheError> {
        if self.is_closed.load(Ordering::SeqCst) {
            return Ok(());
        }
        self.clear()?;
        // Stop the processing loop before closing the policy it uses.
        self.stop_tx
            .send(())
            .map_err(|e| CacheError::SendError(format!("{}", e)))?;
        self.policy.close()?;
        self.is_closed.store(true, Ordering::SeqCst);
        Ok(())
    }

    /// Shared insert path: validates/transforms via the macro-generated
    /// `try_update`, then attempts a non-blocking send of the resulting item
    /// to the processor.
    ///
    /// Returns `Ok(false)` when the cache is closed, when `try_update`
    /// rejects the insert, or when a non-update item is dropped because the
    /// buffer is full (counted under `DropSets`). Updates report `Ok(true)`
    /// even when not enqueued, since the stored value was already changed.
    #[inline]
    fn try_insert_in(
        &self,
        key: K,
        val: V,
        cost: i64,
        ttl: Duration,
        only_update: bool,
    ) -> Result<bool, CacheError> {
        if self.is_closed.load(Ordering::SeqCst) {
            return Ok(false);
        }
        self.try_update(key, val, cost, ttl, only_update)?
            .map_or(Ok(false), |(index, item)| {
                // Captured before `item` is moved into the `send` operation.
                let is_update = item.is_update();
                select! {
                    send(self.insert_buf_tx, item) -> res => {
                        // NOTE(review): on a send error an update still
                        // reports success — the in-store value was already
                        // updated by `try_update`; confirm this is intended.
                        res.map_or_else(|_| {
                            if is_update {
                                Ok(true)
                            } else {
                                self.metrics.add(MetricType::DropSets, index, 1);
                                Ok(false)
                            }
                        }, |_| Ok(true))
                    },
                    default => {
                        // Buffer full: drop the set rather than block.
                        if item.is_update() {
                            Ok(true)
                        } else {
                            self.metrics.add(MetricType::DropSets, index, 1);
                            Ok(false)
                        }
                    }
                }
            })
    }
}
impl<V, U, CB, S> CacheProcessor<V, U, CB, S>
where
    V: Send + Sync + 'static,
    U: UpdateValidator<Value = V>,
    CB: CacheCallback<Value = V>,
    S: BuildHasher + Clone + 'static + Send + Sync,
{
    /// Wires a processor to the given channels, store and policy. The item
    /// size and the hash state for `start_ts` are taken from the store.
    pub(crate) fn new(
        num_to_keep: usize,
        ignore_internal_cost: bool,
        cleanup_duration: Duration,
        store: Arc<ShardedMap<V, U, S, S>>,
        policy: Arc<LFUPolicy<S>>,
        insert_buf_rx: Receiver<Item<V>>,
        stop_rx: Receiver<()>,
        clear_rx: UnboundedReceiver<()>,
        metrics: Arc<Metrics>,
        callback: Arc<CB>,
    ) -> Self {
        let item_size = store.item_size();
        let hasher = store.hasher();
        Self {
            insert_buf_rx,
            stop_rx,
            clear_rx,
            metrics,
            store,
            policy,
            start_ts: HashMap::with_hasher(hasher),
            num_to_keep,
            callback,
            ignore_internal_cost,
            item_size,
            cleanup_duration,
        }
    }

    /// Spawns the processing loop on a new thread: handles queued items,
    /// clear requests, and periodic cleanup ticks until a stop signal arrives.
    /// Handler errors are logged, not propagated — the loop keeps running.
    #[inline]
    pub(crate) fn spawn(mut self) -> JoinHandle<Result<(), CacheError>> {
        // Fires every `cleanup_duration` to trigger a TTL cleanup pass.
        let ticker = tick(self.cleanup_duration);
        spawn(move || loop {
            select! {
                recv(self.insert_buf_rx) -> res => {
                    if let Err(e) = self.handle_insert_event(res) {
                        tracing::error!("fail to handle insert event: {}", e);
                    }
                },
                recv(self.clear_rx) -> _ => {
                    if let Err(e) = self.handle_clear_event() {
                        tracing::error!("fail to handle clear event: {}", e);
                    }
                },
                recv(ticker) -> msg => {
                    if let Err(e) = self.handle_cleanup_event(msg) {
                        tracing::error!("fail to handle cleanup event: {}", e);
                    }
                },
                // Stop signal: exit the loop and finish the thread cleanly.
                recv(self.stop_rx) -> _ => return Ok(()),
            }
        })
    }

    /// Drains every currently queued item via a [`CacheCleaner`].
    #[inline]
    pub(crate) fn handle_clear_event(&mut self) -> Result<(), CacheError> {
        CacheCleaner::new(self).clean()
    }

    /// Converts a receive error into a `CacheError`, then applies the item
    /// (`handle_item` is generated by `impl_cache_processor!`).
    #[inline]
    pub(crate) fn handle_insert_event(
        &mut self,
        msg: Result<Item<V>, RecvError>,
    ) -> Result<(), CacheError> {
        msg.map_err(|e| {
            CacheError::RecvError(format!("fail to receive msg from insert buffer: {}", e))
        })
        .and_then(|item| self.handle_item(item))
    }

    /// On a cleanup tick, removes expired entries from the store and fires
    /// `on_evict` for each victim (after macro-generated `prepare_evict`).
    #[inline]
    pub(crate) fn handle_cleanup_event(
        &mut self,
        res: Result<Instant, RecvError>,
    ) -> Result<(), CacheError> {
        res.map_err(|e| CacheError::RecvError(format!("fail to receive msg from ticker: {}", e)))
            .and_then(|_| {
                self.store.try_cleanup(self.policy.clone()).map(|items| {
                    items.into_iter().for_each(|victim| {
                        self.prepare_evict(&victim);
                        self.callback.on_evict(victim);
                    });
                })
            })
    }
}
impl<'a, V, U, CB, S> CacheCleaner<'a, V, U, CB, S>
where
    V: Send + Sync + 'static,
    U: UpdateValidator<Value = V>,
    CB: CacheCallback<Value = V>,
    S: BuildHasher + Clone + 'static + Send,
{
    /// Drains the insert buffer: handles every item currently queued and
    /// returns as soon as the channel is empty (the `default` arm).
    ///
    /// # Errors
    ///
    /// Fails only when receiving from the insert buffer fails (channel
    /// disconnected); `handle_item` here is generated by `impl_cache_cleaner!`.
    #[inline]
    pub(crate) fn clean(mut self) -> Result<(), CacheError> {
        loop {
            select! {
                recv(self.processor.insert_buf_rx) -> msg => {
                    msg.map(|item| self.handle_item(item)).map_err(|e| CacheError::RecvError(format!("fail to receive msg from insert buffer: {}", e)))?;
                },
                // Channel empty: nothing left to drain.
                default => return Ok(()),
            }
        }
    }
}
// The remaining shared methods (builder setters, `get`/`try_update`,
// `handle_item`, `prepare_evict`, `CacheCleaner::new`, ...) are generated by
// these macros, defined elsewhere in the crate.
impl_builder!(CacheBuilder);
impl_cache!(Cache, CacheBuilder, Item);
impl_cache_processor!(CacheProcessor, Item);
impl_cache_cleaner!(CacheCleaner, CacheProcessor, Item);