use crate::entry::CacheEntry;
use crate::entry_api_async::{AsyncEntry, AsyncOccupiedEntry, AsyncVacantEntry};
use crate::error::ComputeResult;
use crate::iter::{AsyncSnapshotIter, IterStream, DEFAULT_ITER_BATCH_SIZE};
use crate::loader::LoadFuture;
use crate::policy::AccessEvent;
use crate::shared::CacheShared;
use crate::task::janitor::COOPERATIVE_MAINTENANCE_DRAIN_LIMIT;
use crate::{time, Cache, EvictionReason, MetricsSnapshot};
use std::borrow::Borrow;
use std::cell::Cell;
use std::hash::{BuildHasher, Hash};
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use ahash::{HashMap, HashMapExt};
use equivalent::Equivalent;
use futures_util::future;
// Per-thread RNG state seeded with 1.
// NOTE(review): nothing in this part of the file reads `RNG` (opportunistic maintenance
// samples via `shard.rng` instead); confirm it is still used before relying on it.
thread_local! {
static RNG: Cell<u32> = Cell::new(1);
}
/// Asynchronous handle to a sharded cache.
///
/// The handle only wraps an `Arc` to the shared cache state, so every clone of the
/// handle (and every synchronous view obtained from it) operates on the same storage,
/// metrics and eviction policy.
#[derive(Debug)]
pub struct AsyncCache<K: Send, V: Send + Sync, H = ahash::RandomState> {
/// Shared cache state: shard store, metrics, policies and configuration.
pub(crate) shared: Arc<CacheShared<K, V, H>>,
}
impl<K, V, H> Clone for AsyncCache<K, V, H>
where
K: Send,
V: Send + Sync,
H: BuildHasher,
{
fn clone(&self) -> Self {
Self {
shared: self.shared.clone(),
}
}
}
impl<K, V, H> AsyncCache<K, V, H>
where
K: Clone + Eq + Hash + Send,
V: Send + Sync,
H: BuildHasher + Clone,
{
pub fn to_sync(&self) -> Cache<K, V, H> {
Cache {
shared: self.shared.clone(),
}
}
/// Takes a snapshot of the cache metrics.
///
/// Pending introspection state is flushed first so the snapshot reflects recent activity.
pub fn metrics(&self) -> MetricsSnapshot {
    let shared = &self.shared;
    shared.flush_for_introspection();
    shared.metrics.snapshot()
}
/// Looks up `key` and, if a non-expired entry exists, returns `f` applied to its value.
///
/// `f` runs while the shard's read lock is held, so it must not call back into cache
/// operations that need to write to the same shard. A successful lookup records the
/// access via `on_hit`; the call counts as one hit or one miss in the metrics.
pub async fn get<Q, F, R>(&self, key: &Q, f: F) -> Option<R>
where
    K: Borrow<Q>,
    Q: Eq + Hash + Equivalent<K> + ?Sized,
    F: FnOnce(&V) -> R,
{
    let shared = &self.shared;
    let hash = crate::store::hash_key(&shared.store.hasher, key);
    let shard_index = shared.get_shard_index_from_hash(hash);
    let outcome = {
        let guard = shared.store.shards[shard_index].map.read();
        match guard.get_key_value(key) {
            Some((found_key, entry)) if !entry.is_expired(shared.time_to_idle) => {
                let mapped = f(entry.value().as_ref());
                self.on_hit(found_key, hash, entry, shard_index);
                Some(mapped)
            }
            _ => None,
        }
    };
    if outcome.is_some() {
        shared.metrics.hits.fetch_add(1, Ordering::Relaxed);
    } else {
        shared.metrics.misses.fetch_add(1, Ordering::Relaxed);
    }
    outcome
}
/// Returns a shared handle to the value stored under `key`, if present and not expired.
///
/// A successful lookup records the access via `on_hit`; the call counts as one hit or
/// one miss in the metrics.
pub async fn fetch<Q>(&self, key: &Q) -> Option<Arc<V>>
where
    K: Borrow<Q>,
    Q: Eq + Hash + Equivalent<K> + ?Sized,
{
    let shared = &self.shared;
    let hash = crate::store::hash_key(&shared.store.hasher, key);
    let shard_index = shared.get_shard_index_from_hash(hash);
    let found = {
        let guard = shared.store.shards[shard_index].map.read();
        match guard.get_key_value(key) {
            Some((found_key, entry)) if !entry.is_expired(shared.time_to_idle) => {
                self.on_hit(found_key, hash, entry, shard_index);
                Some(entry.value())
            }
            _ => None,
        }
    };
    if found.is_some() {
        shared.metrics.hits.fetch_add(1, Ordering::Relaxed);
    } else {
        shared.metrics.misses.fetch_add(1, Ordering::Relaxed);
    }
    found
}
/// Returns the value for `key` if present and not expired, without side effects.
///
/// Unlike `fetch`, this neither records an access nor touches hit/miss metrics.
pub async fn peek<Q>(&self, key: &Q) -> Option<Arc<V>>
where
    K: Borrow<Q>,
    Q: Eq + Hash + Equivalent<K> + ?Sized,
{
    let guard = self.shared.store.get_shard(key).map.read();
    guard
        .get(key)
        .filter(|entry| !entry.is_expired(self.shared.time_to_idle))
        .map(|entry| entry.value())
}
/// Acquires the write lock of `key`'s shard and returns an entry view for `key`.
///
/// The returned `AsyncEntry` owns the shard write guard; it is `Occupied` when the
/// shard map already contains `key` and `Vacant` otherwise.
pub async fn entry(&self, key: K) -> AsyncEntry<'_, K, V, H> {
    let shard = self.shared.store.get_shard(&key);
    let shard_guard = shard.map.write_async().await;
    if !shard_guard.contains_key(&key) {
        return AsyncEntry::Vacant(AsyncVacantEntry {
            key,
            shared: &self.shared,
            shard_guard,
        });
    }
    AsyncEntry::Occupied(AsyncOccupiedEntry { key, shard_guard })
}
/// Inserts `value` under `key` with the given `cost`, using the cache-wide TTL/TTI.
///
/// Any previous entry for `key` is replaced: its timers are cancelled and its cost is
/// subtracted from `current_cost`. A write event is offered (best effort, dropped if
/// the shard's event buffer is full) and opportunistic maintenance may run afterwards.
pub async fn insert(&self, key: K, value: V, cost: u64)
where
    K: Clone + Sync,
    V: Sync,
    H: Send + Sync,
{
    let new_cache_entry = CacheEntry::new(
        value,
        cost,
        self.shared.time_to_live,
        self.shared.time_to_idle,
    );
    self.insert_prepared_entry(key, new_cache_entry, cost, self.shared.time_to_live)
        .await;
}

/// Inserts `value` under `key` with a per-entry time-to-live of `ttl`.
///
/// Behaves like `insert` except that the entry's hard deadline is `now + ttl`. The
/// deadline saturates at `u64::MAX` nanoseconds instead of truncating or overflowing,
/// so a huge `ttl` (e.g. `Duration::MAX`) means "effectively never" rather than
/// wrapping to an early or zero deadline.
pub async fn insert_with_ttl(&self, key: K, value: V, cost: u64, ttl: Duration)
where
    K: Clone + Sync,
    V: Sync,
    H: Send + Sync,
{
    let ttl_nanos = ttl.as_nanos().min(u64::MAX as u128) as u64;
    let now_nanos = time::now_duration().as_nanos() as u64;
    let expires_at = now_nanos.saturating_add(ttl_nanos);
    let new_cache_entry =
        CacheEntry::new_with_custom_expiry(value, cost, expires_at, self.shared.time_to_idle);
    self.insert_prepared_entry(key, new_cache_entry, cost, Some(ttl))
        .await;
}

/// Shared tail of `insert`/`insert_with_ttl`: schedules the TTL timer (if the shard has
/// a timer wheel and `ttl` is set), stores the entry, retires any replaced entry and
/// updates metrics, then gives opportunistic maintenance a chance to run.
async fn insert_prepared_entry(
    &self,
    key: K,
    mut new_cache_entry: CacheEntry<V>,
    cost: u64,
    ttl: Option<Duration>,
) where
    K: Clone + Sync,
    V: Sync,
    H: Send + Sync,
{
    let shard = self.shared.store.get_shard(&key);
    if let Some(wheel) = &shard.timer_wheel {
        let key_hash = crate::store::hash_key(&self.shared.store.hasher, &key);
        let ttl_handle = ttl.map(|ttl| wheel.schedule(key_hash, ttl));
        let tti_handle = None;
        new_cache_entry.set_timer_handles(ttl_handle, tti_handle);
    }
    let old_entry: Option<Arc<CacheEntry<V>>> = {
        let mut guard = shard.map.write_async().await;
        guard.insert(key.clone(), Arc::new(new_cache_entry))
    };
    if let Some(entry) = old_entry {
        // The replaced entry's timers would otherwise still fire for this key.
        if let Some(wheel) = &shard.timer_wheel {
            if let Some(handle) = &entry.ttl_timer_handle {
                wheel.cancel(handle);
            }
            if let Some(handle) = &entry.tti_timer_handle {
                wheel.cancel(handle);
            }
        }
        self.shared
            .metrics
            .current_cost
            .fetch_sub(entry.cost(), Ordering::Relaxed);
    }
    // Best effort: a full event buffer drops the write notification.
    let _ = shard
        .event_buffer_tx
        .try_send(AccessEvent::Write(key.clone(), cost));
    let metrics = &self.shared.metrics;
    metrics.inserts.fetch_add(1, Ordering::Relaxed);
    metrics.keys_admitted.fetch_add(1, Ordering::Relaxed);
    metrics.current_cost.fetch_add(cost, Ordering::Relaxed);
    metrics.total_cost_added.fetch_add(cost, Ordering::Relaxed);
    self._run_opportunistic_maintenance(&key, shard);
}
/// Mutates the value stored under `key` in place, retrying until it succeeds.
///
/// Returns `true` once `f` has been applied and `false` if `key` is absent. Each failed
/// attempt (the entry or value is shared, so no exclusive access is possible) yields to
/// the runtime before retrying. Note: if another holder keeps an `Arc` to the entry or
/// its value alive (e.g. one returned by `fetch`), this keeps retrying until it is dropped.
pub async fn compute<Q, F>(&self, key: &Q, mut f: F) -> bool
where
    K: Borrow<Q>,
    Q: Eq + Hash + Equivalent<K> + ?Sized,
    F: FnMut(&mut V),
{
    loop {
        match self.try_compute(key, &mut f).await {
            Some(true) => return true,
            None => return false,
            Some(false) => tokio::task::yield_now().await,
        }
    }
}
/// Attempts a single in-place mutation of the value under `key`.
///
/// Returns `Some(true)` if `f` ran, `Some(false)` if exclusive access could not be
/// obtained, and `None` if `key` is not present.
pub async fn try_compute<Q, F>(&self, key: &Q, f: F) -> Option<bool>
where
    K: Borrow<Q>,
    Q: Eq + Hash + Equivalent<K> + ?Sized,
    F: FnOnce(&mut V),
{
    let outcome = self.try_compute_val(key, f).await;
    match outcome {
        ComputeResult::Ok(_) => Some(true),
        ComputeResult::Fail => Some(false),
        ComputeResult::NotFound => None,
    }
}
/// Attempts a single in-place mutation of the value under `key`, returning `f`'s result.
///
/// The shard write lock is held for the whole attempt. Mutation requires that both the
/// entry `Arc` and the value `Arc` are uniquely owned (`Arc::get_mut` succeeds);
/// otherwise `ComputeResult::Fail` is returned. A missing key yields
/// `ComputeResult::NotFound`. A successful mutation increments the `updates` metric.
pub async fn try_compute_val<Q, F, R>(&self, key: &Q, f: F) -> ComputeResult<R>
where
    K: Borrow<Q>,
    Q: Eq + Hash + Equivalent<K> + ?Sized,
    F: FnOnce(&mut V) -> R,
{
    let shard = self.shared.store.get_shard(key);
    let mut guard = shard.map.write_async().await;
    match guard.get_mut(key) {
        None => ComputeResult::NotFound,
        Some(entry_arc) => {
            let exclusive_value =
                Arc::get_mut(entry_arc).and_then(|entry| Arc::get_mut(&mut entry.value));
            match exclusive_value {
                Some(value) => {
                    let produced = f(value);
                    self.shared.metrics.updates.fetch_add(1, Ordering::Relaxed);
                    ComputeResult::Ok(produced)
                }
                None => ComputeResult::Fail,
            }
        }
    }
}
/// Repeats `try_compute_val` until it yields anything other than `ComputeResult::Fail`.
///
/// Yields to the runtime between attempts. As with `compute`, this keeps retrying while
/// another holder keeps the entry or value `Arc` shared.
pub async fn compute_val<Q, F, R>(&self, key: &Q, mut f: F) -> ComputeResult<R>
where
    K: Borrow<Q>,
    Q: Eq + Hash + Equivalent<K> + ?Sized,
    F: FnMut(&mut V) -> R,
{
    loop {
        match self.try_compute_val(key, &mut f).await {
            ComputeResult::Fail => tokio::task::yield_now().await,
            settled => return settled,
        }
    }
}
/// Removes `key` from the cache, returning whether an entry was removed.
///
/// On removal: the entry's timers are cancelled, the policy is notified, the
/// `invalidations` metric is bumped, the entry cost is subtracted from `current_cost`,
/// and (if configured) an `Invalidated` notification is offered without blocking.
pub async fn invalidate<Q>(&self, key: &Q) -> bool
where
    K: Borrow<Q> + Clone,
    Q: Eq + Hash + Equivalent<K> + ?Sized,
    V: Sync,
{
    let shard = self.shared.store.get_shard(key);
    let removed: Option<(K, Arc<CacheEntry<V>>)> = {
        let mut guard = shard.map.write_async().await;
        guard.remove_entry(key)
    };
    match removed {
        None => false,
        Some((found_key, entry)) => {
            if let Some(wheel) = &shard.timer_wheel {
                let handles = entry
                    .ttl_timer_handle
                    .iter()
                    .chain(entry.tti_timer_handle.iter());
                for handle in handles {
                    wheel.cancel(handle);
                }
            }
            self.shared.get_cache_policy(key).on_remove(&found_key);
            let metrics = &self.shared.metrics;
            metrics.invalidations.fetch_add(1, Ordering::Relaxed);
            metrics.current_cost.fetch_sub(entry.cost(), Ordering::Relaxed);
            if let Some(sender) = &self.shared.notification_sender {
                let _ = sender.try_send((found_key, entry.value(), EvictionReason::Invalidated));
            }
            true
        }
    }
}
pub async fn clear(&self) {
let mut shard_guards =
future::join_all(self.shared.store.iter_shards().map(|s| s.map.write_async())).await;
for (i, guard) in shard_guards.iter_mut().enumerate() {
let policy = &self.shared.cache_policy[i];
for key in guard.keys() {
policy.on_remove(key);
}
guard.clear();
}
for policy in self.shared.cache_policy.iter() {
policy.clear();
}
self
.shared
.metrics
.current_cost
.store(0, std::sync::atomic::Ordering::Relaxed);
}
/// Bookkeeping for a successful read of `entry` under `key`.
///
/// Refreshes the entry's last-access time when time-to-idle is configured, then records
/// the access (key, precomputed hash, cost) in the shard's read-access batcher.
#[inline]
fn on_hit(&self, key: &K, hash: u64, entry: &Arc<CacheEntry<V>>, shard_index: usize)
where
    K: Clone,
{
    let shared = &self.shared;
    let tracks_idle = shared.time_to_idle.is_some();
    if tracks_idle {
        entry.update_last_accessed();
    }
    let batcher = &shared.store.shards[shard_index].read_access_batcher;
    batcher.record_access(key, hash, entry.cost());
}
/// Occasionally runs a bounded maintenance pass on `shard` from the caller's task.
///
/// The pass runs only when `shard.rng.should_run` accepts the configured
/// `maintenance_probability_denominator` (presumably a 1-in-N chance; confirm in the RNG
/// helper) and the shard's maintenance lock is free. Contention skips the pass instead of
/// waiting. At most `COOPERATIVE_MAINTENANCE_DRAIN_LIMIT` items are drained.
fn _run_opportunistic_maintenance(&self, key: &K, shard: &crate::store::Shard<K, V, H>)
where
    K: Eq + Hash + Clone + Send,
    V: Send + Sync,
    H: BuildHasher + Clone,
{
    let shared = &self.shared;
    let selected = shard.rng.should_run(shared.maintenance_probability_denominator);
    if !selected {
        return;
    }
    let maintenance_guard = shard.maintenance_lock.try_lock();
    if let Some(_held) = maintenance_guard {
        let shard_index = shared.get_shard_index(key);
        let context = crate::task::janitor::JanitorContext {
            store: Arc::clone(&shared.store),
            metrics: Arc::clone(&shared.metrics),
            cache_policy: shared.cache_policy.clone(),
            capacity: shared.capacity,
            time_to_idle: shared.time_to_idle,
            notification_sender: shared.notification_sender.clone(),
        };
        crate::task::janitor::perform_shard_maintenance(
            shard,
            shard_index,
            &context,
            COOPERATIVE_MAINTENANCE_DRAIN_LIMIT,
        );
    }
}
}
impl<K: Send, V: Send, H> AsyncCache<K, V, H>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Send + Sync + 'static,
    H: BuildHasher + Clone + Send + 'static,
{
    /// Returns a batched stream over the cache contents using `DEFAULT_ITER_BATCH_SIZE`.
    ///
    /// Pending introspection state is flushed before the stream is created.
    pub fn iter_stream(&self) -> IterStream<K, V, H> {
        let shared = &self.shared;
        shared.flush_for_introspection();
        IterStream::new(self, DEFAULT_ITER_BATCH_SIZE)
    }

    /// Returns a batched stream over the cache contents with a caller-chosen batch size.
    ///
    /// A `batch_size` of 0 is raised to 1.
    pub fn iter_stream_with_batch_size(&self, batch_size: usize) -> IterStream<K, V, H> {
        let effective_batch = batch_size.max(1);
        self.shared.flush_for_introspection();
        IterStream::new(self, effective_batch)
    }

    /// Returns an asynchronous snapshot iterator borrowing this cache.
    ///
    /// Pending introspection state is flushed before the iterator is created.
    pub fn iter_snapshot_async(&self) -> AsyncSnapshotIter<'_, K, V, H> {
        let shared = &self.shared;
        shared.flush_for_introspection();
        AsyncSnapshotIter::new(self)
    }
}
impl<K: Send, V: Send, H> AsyncCache<K, V, H>
where
K: Eq + Hash + Clone + Send + Sync + 'static,
V: Send + Sync,
H: BuildHasher + Clone + Send,
{
/// Looks up many keys at once and returns the live (non-expired) entries found.
///
/// Keys are grouped by shard using the same hash-to-shard mapping as `get`/`insert`.
/// Each shard is then read under a single read lock, and the per-shard lookups are
/// driven concurrently. Every requested key counts as exactly one hit or one miss,
/// duplicates included.
pub async fn multiget<I, Q>(&self, keys: I) -> HashMap<K, Arc<V>>
where
    I: IntoIterator,
    I::Item: std::borrow::Borrow<Q>,
    Q: Eq + Hash + Equivalent<K> + Sync + ?Sized + ToOwned<Owned = K>,
    K: std::borrow::Borrow<Q>,
{
    let num_shards = self.shared.store.shards.len();
    let mut keys_by_shard: Vec<Vec<K>> = (0..num_shards).map(|_| Vec::new()).collect();
    let mut total_reqs: u64 = 0;
    for item in keys {
        let q: &Q = item.borrow();
        let hash = crate::store::hash_key(&self.shared.store.hasher, q);
        // Must match the canonical mapping used by `get`/`insert`. A local `hash % n`
        // can disagree with it and send present keys to the wrong shard.
        let index = self.shared.get_shard_index_from_hash(hash);
        keys_by_shard[index].push(q.to_owned());
        total_reqs += 1;
    }
    let lookups = keys_by_shard
        .into_iter()
        .enumerate()
        .filter(|(_, shard_keys)| !shard_keys.is_empty())
        .map(|(i, shard_keys)| {
            let shared = Arc::clone(&self.shared);
            async move {
                let shard = &shared.store.shards[i];
                let guard = shard.map.read();
                let mut found = HashMap::new();
                // Hits are counted per lookup, not per distinct key, so duplicate
                // requests for a present key are not misreported as misses.
                let mut shard_hits: u64 = 0;
                for key in shard_keys {
                    if let Some(entry) = guard.get(key.borrow()) {
                        if !entry.is_expired(shared.time_to_idle) {
                            if shared.time_to_idle.is_some() {
                                entry.update_last_accessed();
                            }
                            shared.get_cache_policy(&key).on_access(&key, entry.cost());
                            found.insert(key, entry.value());
                            shard_hits += 1;
                        }
                    }
                }
                (found, shard_hits)
            }
        });
    let results_by_shard: Vec<(HashMap<K, Arc<V>>, u64)> = future::join_all(lookups).await;
    let mut final_map = HashMap::new();
    let mut hits: u64 = 0;
    for (found, shard_hits) in results_by_shard {
        hits += shard_hits;
        final_map.extend(found);
    }
    self.shared.metrics.hits.fetch_add(hits, Ordering::Relaxed);
    self.shared
        .metrics
        .misses
        .fetch_add(total_reqs - hits, Ordering::Relaxed);
    final_map
}
/// Inserts many `(key, value, cost)` items, taking each affected shard's write lock once.
///
/// Items are grouped by shard using the same hash-to-shard mapping as `insert`. For each
/// item, the replaced entry (if any) has its timers cancelled and its cost subtracted.
/// The `inserts`, `keys_admitted`, `current_cost` and `total_cost_added` metrics are
/// updated exactly as `insert` does. A write event is offered best effort.
#[cfg(feature = "bulk")]
pub async fn multi_insert<I>(&self, items: I)
where
    I: IntoIterator<Item = (K, V, u64)>,
{
    let num_shards = self.shared.store.shards.len();
    // Carry each key's hash along so it is not recomputed for the timer wheel.
    let mut items_by_shard: Vec<Vec<(u64, K, V, u64)>> =
        (0..num_shards).map(|_| Vec::new()).collect();
    for (key, value, cost) in items {
        let hash = crate::store::hash_key(&self.shared.store.hasher, &key);
        let index = self.shared.get_shard_index_from_hash(hash);
        items_by_shard[index].push((hash, key, value, cost));
    }
    let insert_futs = items_by_shard
        .into_iter()
        .enumerate()
        .filter(|(_, shard_items)| !shard_items.is_empty())
        .map(|(i, shard_items)| {
            let shared = Arc::clone(&self.shared);
            async move {
                let shard = &shared.store.shards[i];
                let mut guard = shard.map.write_async().await;
                for (key_hash, key, value, cost) in shard_items {
                    let mut new_cache_entry =
                        CacheEntry::new(value, cost, shared.time_to_live, shared.time_to_idle);
                    if let Some(wheel) = &shard.timer_wheel {
                        let ttl_handle =
                            shared.time_to_live.map(|ttl| wheel.schedule(key_hash, ttl));
                        let tti_handle = None;
                        new_cache_entry.set_timer_handles(ttl_handle, tti_handle);
                    }
                    if let Some(old_entry) = guard.insert(key.clone(), Arc::new(new_cache_entry)) {
                        if let Some(wheel) = &shard.timer_wheel {
                            if let Some(handle) = &old_entry.ttl_timer_handle {
                                wheel.cancel(handle);
                            }
                            if let Some(handle) = &old_entry.tti_timer_handle {
                                wheel.cancel(handle);
                            }
                        }
                        shared
                            .metrics
                            .current_cost
                            .fetch_sub(old_entry.cost(), Ordering::Relaxed);
                    }
                    let metrics = &shared.metrics;
                    metrics.inserts.fetch_add(1, Ordering::Relaxed);
                    metrics.keys_admitted.fetch_add(1, Ordering::Relaxed);
                    metrics.current_cost.fetch_add(cost, Ordering::Relaxed);
                    metrics.total_cost_added.fetch_add(cost, Ordering::Relaxed);
                    let _ = shard
                        .event_buffer_tx
                        .try_send(AccessEvent::Write(key, cost));
                }
            }
        });
    future::join_all(insert_futs).await;
}
/// Removes many keys, taking each affected shard's write lock once.
///
/// Keys are converted with `K::from` and grouped by shard using the same hash-to-shard
/// mapping as `invalidate`. Each removed entry gets the same treatment as in
/// `invalidate`: timers cancelled, policy notified, `invalidations` bumped, cost
/// subtracted, and a best-effort `Invalidated` notification if a sender is configured.
#[cfg(feature = "bulk")]
pub async fn multi_invalidate<I, Q>(&self, keys: I)
where
    I: IntoIterator<Item = Q>,
    K: From<Q>,
{
    let num_shards = self.shared.store.shards.len();
    let mut keys_by_shard: Vec<Vec<K>> = (0..num_shards).map(|_| Vec::new()).collect();
    for key in keys.into_iter().map(K::from) {
        let hash = crate::store::hash_key(&self.shared.store.hasher, &key);
        // Must match the canonical mapping used by `insert`/`invalidate`.
        let index = self.shared.get_shard_index_from_hash(hash);
        keys_by_shard[index].push(key);
    }
    let invalidate_futs = keys_by_shard
        .into_iter()
        .enumerate()
        .filter(|(_, shard_keys)| !shard_keys.is_empty())
        .map(|(i, shard_keys)| {
            let shared = Arc::clone(&self.shared);
            async move {
                let shard = &shared.store.shards[i];
                let mut guard = shard.map.write_async().await;
                for key in shard_keys {
                    if let Some(entry) = guard.remove(&key) {
                        if let Some(wheel) = &shard.timer_wheel {
                            if let Some(handle) = &entry.ttl_timer_handle {
                                wheel.cancel(handle);
                            }
                            if let Some(handle) = &entry.tti_timer_handle {
                                wheel.cancel(handle);
                            }
                        }
                        shared.get_cache_policy(&key).on_remove(&key);
                        shared.metrics.invalidations.fetch_add(1, Ordering::Relaxed);
                        shared
                            .metrics
                            .current_cost
                            .fetch_sub(entry.cost(), Ordering::Relaxed);
                        if let Some(sender) = &shared.notification_sender {
                            let _ = sender.try_send((
                                key,
                                entry.value(),
                                EvictionReason::Invalidated,
                            ));
                        }
                    }
                }
            }
        });
    future::join_all(invalidate_futs).await;
}
}
impl<K, V, H> AsyncCache<K, V, H>
where
K: Eq + Hash + Clone + Send + Sync + 'static,
V: Send + Sync,
H: BuildHasher + Clone + Send + Sync + 'static,
{
/// Runs a full maintenance sweep over every shard, one shard at a time.
///
/// For each shard, the maintenance lock is awaited (not skipped), then a bounded
/// maintenance pass runs, followed by TTL, TTI and capacity cleanup.
pub async fn run_maintenance(&self) {
    use crate::task::janitor::{
        perform_shard_maintenance, Janitor, JanitorContext, COOPERATIVE_MAINTENANCE_DRAIN_LIMIT,
    };
    let shared = &self.shared;
    let ctx = JanitorContext {
        store: Arc::clone(&shared.store),
        metrics: Arc::clone(&shared.metrics),
        cache_policy: shared.cache_policy.clone(),
        capacity: shared.capacity,
        time_to_idle: shared.time_to_idle,
        notification_sender: shared.notification_sender.clone(),
    };
    for (shard_index, shard) in shared.store.shards.iter().enumerate() {
        let _maintenance_guard = shard.maintenance_lock.lock_async().await;
        perform_shard_maintenance(shard, shard_index, &ctx, COOPERATIVE_MAINTENANCE_DRAIN_LIMIT);
        Janitor::cleanup_ttl_for_shard(shard, shard_index, &ctx);
        Janitor::cleanup_tti_for_shard(shard, shard_index, &ctx);
        Janitor::cleanup_capacity_for_shard(shard, shard_index, &ctx);
    }
}
/// Returns the value for `key`, loading it through the cache's loader when needed.
///
/// Outcomes of the cached lookup:
/// * Live entry (no hard deadline, or deadline not reached, and not idle-expired): the
///   value is returned, the access is recorded, and the call counts as a hit.
/// * Past its hard deadline but inside the `stale_while_revalidate` grace window: the
///   stale value is returned and counted as a hit, and a background reload is requested
///   (best effort).
/// * Otherwise: defers to `load_value_awaiting`, which records the hit/miss itself and
///   waits for a possibly shared load.
///
/// The deadline-plus-grace sum saturates instead of overflowing, so a very large grace
/// period cannot panic (debug) or wrap (release).
pub async fn fetch_with(&self, key: &K) -> Arc<V>
where
    K: Clone + Send + Sync + 'static,
    V: Send + Sync + 'static,
    H: BuildHasher + Clone + Send + Sync,
{
    let hash = crate::store::hash_key(&self.shared.store.hasher, key);
    let shard_index = self.shared.get_shard_index_from_hash(hash);
    let shard_ref = &self.shared.store.shards[shard_index];
    let hit_value = {
        let guard = shard_ref.map.read();
        match guard.get_key_value(key) {
            None => None,
            Some((found_key, entry_in_guard)) => {
                let expires_at_nanos = entry_in_guard.expires_at.load(Ordering::Relaxed);
                // `expires_at == 0` means "no hard deadline", so the clock is only read
                // when there is a deadline to compare against.
                let past_deadline_now = if expires_at_nanos == 0 {
                    None
                } else {
                    let now_nanos = crate::time::now_duration().as_nanos() as u64;
                    if now_nanos < expires_at_nanos {
                        None
                    } else {
                        Some(now_nanos)
                    }
                };
                match past_deadline_now {
                    None => {
                        if entry_in_guard.is_expired(self.shared.time_to_idle) {
                            None
                        } else {
                            self.on_hit(found_key, hash, entry_in_guard, shard_index);
                            self.shared.metrics.hits.fetch_add(1, Ordering::Relaxed);
                            Some(entry_in_guard.value())
                        }
                    }
                    Some(now_nanos) => match self.shared.stale_while_revalidate {
                        Some(grace_period) => {
                            let grace_nanos = grace_period.as_nanos().min(u64::MAX as u128) as u64;
                            if now_nanos < expires_at_nanos.saturating_add(grace_nanos) {
                                // Served from cache, so it counts as a hit; otherwise the
                                // request would be missing from hits + misses entirely.
                                self.trigger_background_load(found_key);
                                self.shared.metrics.hits.fetch_add(1, Ordering::Relaxed);
                                Some(entry_in_guard.value())
                            } else {
                                None
                            }
                        }
                        None => None,
                    },
                }
            }
        }
    };
    match hit_value {
        Some(value) => value,
        None => self.load_value_awaiting(key).await,
    }
}
/// Best-effort request to reload `key` in the background.
///
/// Picks the pending-loads slot by masking the key hash with `pending_loads.len() - 1`
/// (assumes the slot count is a power of two — confirm at construction). If that slot's
/// lock is contended, or a load for `key` is already pending, nothing happens. Otherwise
/// a new `LoadFuture` is registered and a loader task is spawned.
fn trigger_background_load(&self, key: &K)
where
    K: Clone + 'static,
    V: 'static,
{
    let shared = &self.shared;
    let slot = crate::store::hash_key(&shared.store.hasher, &key) as usize
        & (shared.pending_loads.len() - 1);
    match shared.pending_loads[slot].try_lock() {
        Some(mut pending) if !pending.contains_key(key) => {
            let load = Arc::new(LoadFuture::new());
            pending.insert(key.clone(), Arc::clone(&load));
            CacheShared::spawn_loader_task(Arc::clone(shared), key.clone(), load);
        }
        _ => {}
    }
}
/// Waits for a load of `key`, joining an in-flight load or starting a new one.
///
/// Under the key's pending-loads slot lock, an existing `LoadFuture` is reused (counted
/// as a hit). Otherwise a new one is registered (counted as a miss), and this caller
/// becomes the leader that spawns the loader task after releasing the slot lock.
/// The slot is chosen by masking the hash with `pending_loads.len() - 1` (assumes a
/// power-of-two slot count — confirm at construction).
async fn load_value_awaiting(&self, key: &K) -> Arc<V>
where
    K: Clone + Send + Sync + 'static,
    V: Send + Sync + 'static,
    H: BuildHasher + Clone + Send + Sync,
{
    let hash = crate::store::hash_key(&self.shared.store.hasher, key);
    let index = hash as usize & (self.shared.pending_loads.len() - 1);
    let pending_loads_lock = &self.shared.pending_loads[index];
    // The slot guard lives only inside this block, so it is released before spawning.
    let (future, am_leader) = {
        let mut pending = pending_loads_lock.lock_async().await;
        let existing = pending.get(key).cloned();
        match existing {
            Some(existing_future) => {
                self.shared.metrics.hits.fetch_add(1, Ordering::Relaxed);
                (existing_future, false)
            }
            None => {
                self.shared.metrics.misses.fetch_add(1, Ordering::Relaxed);
                let new_future = Arc::new(LoadFuture::new());
                pending.insert(key.clone(), Arc::clone(&new_future));
                (new_future, true)
            }
        }
    };
    if am_leader {
        CacheShared::spawn_loader_task(Arc::clone(&self.shared), key.clone(), future.clone());
    }
    (&*future).await
}
}