use std::borrow::Borrow;
use std::fmt::Debug;
use std::hash::Hash;
use cachet_tier::{CacheEntry, CacheTier, DynamicCache, SizeError};
use tick::Clock;
use uniflight::Merger;
use crate::Error;
use crate::builder::CacheBuilder;
use crate::telemetry::CacheTelemetry;
use crate::telemetry::cache::{WithRequestIdExt, next_request_id};
/// Static name of a cache; it is passed to telemetry when cache operations complete.
pub type CacheName = &'static str;
/// Per-operation request mergers used when stampede protection is enabled.
///
/// Every merged operation has its own [`Merger`], so calls are only coalesced with other calls of
/// the same operation on the same key (e.g. a `get` never joins a `get_or_insert`).
struct Mergers<K, V> {
    /// Merges concurrent `Cache::get` lookups.
    get: Merger<K, Result<Option<CacheEntry<V>>, Error>>,
    /// Merges concurrent `Cache::invalidate` calls.
    invalidate: Merger<K, Result<(), Error>>,
    /// Merges concurrent `Cache::get_or_insert` calls.
    get_or_insert: Merger<K, Result<CacheEntry<V>, Error>>,
    /// Merges concurrent `Cache::get_or_insert_with` calls.
    get_or_insert_with: Merger<K, Result<CacheEntry<V>, Error>>,
    /// Merges concurrent `Cache::try_get_or_insert` calls.
    try_get_or_insert: Merger<K, Result<CacheEntry<V>, Error>>,
    /// Merges concurrent `Cache::try_get_or_insert_with` calls.
    try_get_or_insert_with: Merger<K, Result<CacheEntry<V>, Error>>,
    /// Merges concurrent `Cache::optionally_get_or_insert` calls.
    optionally_get_or_insert: Merger<K, Result<Option<CacheEntry<V>>, Error>>,
}

impl<K, V> Mergers<K, V>
where
    K: Clone + Eq + Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Creates a fresh, empty merger for every supported operation.
    fn new() -> Self {
        Self {
            optionally_get_or_insert: Merger::new(),
            try_get_or_insert_with: Merger::new(),
            try_get_or_insert: Merger::new(),
            get_or_insert_with: Merger::new(),
            get_or_insert: Merger::new(),
            invalidate: Merger::new(),
            get: Merger::new(),
        }
    }
}

impl<K, V> Debug for Mergers<K, V> {
    /// Prints only the type name; the individual mergers are elided (`Mergers { .. }`).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Mergers");
        builder.finish_non_exhaustive()
    }
}
/// A named cache that delegates storage to a [`DynamicCache`], times operations with a
/// [`Clock`], reports completions through [`CacheTelemetry`], and can optionally merge
/// concurrent requests for the same key (stampede protection).
///
/// Construct one with [`Cache::builder`].
#[derive(Debug)]
pub struct Cache<K, V> {
    /// Name passed to telemetry when operations on this cache complete.
    pub(crate) name: CacheName,
    /// Backing store that reads and writes are delegated to.
    pub(crate) storage: DynamicCache<K, V>,
    /// Time source for operation stopwatches and for `cached_at` stamps on newly created entries.
    pub(crate) clock: Clock,
    /// Receives an operation-complete event for instrumented operations.
    pub(crate) telemetry: CacheTelemetry,
    /// Per-operation request mergers; `Some` only when stampede protection was enabled.
    mergers: Option<Mergers<K, V>>,
}
impl Cache<(), ()> {
    /// Starts configuring a new [`Cache`] with key type `K` and value type `V`.
    ///
    /// Defined on `Cache<(), ()>` so callers can write `Cache::builder::<K, V>(clock)` without
    /// spelling out the cache's own type parameters first.
    #[must_use]
    pub fn builder<K, V>(clock: Clock) -> CacheBuilder<K, V> {
        let builder = CacheBuilder::<K, V>::new(clock);
        builder
    }
}
impl<K, V> Cache<K, V>
where
    K: Clone + Eq + Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Assembles a cache from already-configured parts.
    ///
    /// Merger state is only allocated when `stampede_protection` is `true`; without it, every
    /// operation calls straight into `storage`.
    pub(crate) fn new(
        name: CacheName,
        storage: DynamicCache<K, V>,
        clock: Clock,
        telemetry: CacheTelemetry,
        stampede_protection: bool,
    ) -> Self {
        let mergers = if stampede_protection {
            Some(Mergers::new())
        } else {
            None
        };

        Self {
            name,
            storage,
            clock,
            telemetry,
            mergers,
        }
    }
}
impl<K, V> Cache<K, V>
where
    K: Clone + Eq + Hash + Send + Sync,
    V: Clone + Send + Sync,
{
    /// Returns the name this cache was built with.
    #[must_use]
    pub fn name(&self) -> CacheName {
        let Self { name, .. } = self;
        *name
    }

    /// Returns the clock used for timing operations and stamping new entries.
    #[must_use]
    pub fn clock(&self) -> &Clock {
        let Self { clock, .. } = self;
        clock
    }
}
impl<K, V> Cache<K, V>
where
    K: Clone + Eq + Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Looks up the entry stored under `key`.
    ///
    /// With stampede protection enabled, the lookup is routed through the `get` merger so that
    /// concurrent calls for the same key can share one storage read. Completion is reported to
    /// telemetry as `cache.get`; `coalesced` is reported as `true` whenever the merger path is
    /// taken.
    ///
    /// # Errors
    ///
    /// Returns the storage error if the lookup fails, or an [`Error`] wrapping the merger's
    /// failure value if the shared lookup did not complete.
    pub async fn get<Q>(&self, key: &Q) -> Result<Option<CacheEntry<V>>, Error>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized + Send + Sync,
    {
        let request_id = next_request_id();
        let watch = self.clock.stopwatch();
        // Run the whole operation under `request_id` so tier events raised by storage share the
        // id of the completion event.
        async {
            let (result, coalesced) = if let Some(mergers) = &self.mergers {
                let owned = key.to_owned();
                let storage = &self.storage;
                let result = mergers
                    .get
                    .execute(key, move || async move { storage.get(&owned).await })
                    .await
                    .unwrap_or_else(|panicked| Err(Error::from_source(panicked)));
                (result, true)
            } else {
                let owned = key.to_owned();
                (self.storage.get(&owned).await, false)
            };
            self.telemetry
                .complete_operation(request_id, self.name, "cache.get", watch.elapsed(), coalesced);
            result
        }
        .with_request_id(request_id)
        .await
    }

    /// Stores `entry` under `key`, delegating directly to storage (inserts are never merged).
    ///
    /// Completion is reported to telemetry as `cache.insert`. Unlike the `get_or_insert*`
    /// family, this method does not itself stamp `cached_at` on the entry.
    ///
    /// # Errors
    ///
    /// Returns the storage error if the write fails.
    pub async fn insert(&self, key: K, entry: impl Into<CacheEntry<V>>) -> Result<(), Error> {
        let request_id = next_request_id();
        let watch = self.clock.stopwatch();
        async {
            let result = self.storage.insert(key, entry.into()).await;
            self.telemetry
                .complete_operation(request_id, self.name, "cache.insert", watch.elapsed(), false);
            result
        }
        .with_request_id(request_id)
        .await
    }

    /// Removes any entry stored under `key`.
    ///
    /// With stampede protection enabled, concurrent invalidations of the same key are routed
    /// through the `invalidate` merger. Completion is reported to telemetry as
    /// `cache.invalidate`.
    ///
    /// # Errors
    ///
    /// Returns the storage error if the removal fails, or an [`Error`] wrapping the merger's
    /// failure value if the shared invalidation did not complete.
    pub async fn invalidate<Q>(&self, key: &Q) -> Result<(), Error>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized + Send + Sync,
    {
        let request_id = next_request_id();
        let watch = self.clock.stopwatch();
        async {
            let (result, coalesced) = if let Some(mergers) = &self.mergers {
                let owned = key.to_owned();
                let storage = &self.storage;
                let result = mergers
                    .invalidate
                    .execute(key, move || async move { storage.invalidate(&owned).await })
                    .await
                    .unwrap_or_else(|panicked| Err(Error::from_source(panicked)));
                (result, true)
            } else {
                let owned = key.to_owned();
                (self.storage.invalidate(&owned).await, false)
            };
            self.telemetry
                .complete_operation(request_id, self.name, "cache.invalidate", watch.elapsed(), coalesced);
            result
        }
        .with_request_id(request_id)
        .await
    }

    /// Returns whether an entry is currently stored under `key`.
    ///
    /// Implemented on top of [`Cache::get`], so it is reported to telemetry as `cache.get` and
    /// takes part in `get` merging.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Cache::get`].
    pub async fn contains<Q>(&self, key: &Q) -> Result<bool, Error>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized + Send + Sync,
    {
        Ok(self.get(key).await?.is_some())
    }

    /// Removes every entry from storage.
    ///
    /// Completion is reported to telemetry as `cache.clear`; clears are never merged.
    ///
    /// # Errors
    ///
    /// Returns the storage error if clearing fails.
    pub async fn clear(&self) -> Result<(), Error> {
        let request_id = next_request_id();
        let watch = self.clock.stopwatch();
        async {
            let result = self.storage.clear().await;
            self.telemetry
                .complete_operation(request_id, self.name, "cache.clear", watch.elapsed(), false);
            result
        }
        .with_request_id(request_id)
        .await
    }

    /// Returns the number of entries reported by storage (not reported to telemetry).
    ///
    /// # Errors
    ///
    /// Propagates the [`SizeError`] returned by the storage's `len`.
    pub async fn len(&self) -> Result<u64, SizeError> {
        self.storage.len().await
    }

    /// Returns whether storage reports no entries (not reported to telemetry).
    ///
    /// # Errors
    ///
    /// Propagates the [`SizeError`] returned by the storage's `is_empty`.
    pub async fn is_empty(&self) -> Result<bool, SizeError> {
        self.storage.is_empty().await
    }

    /// Returns the entry under `key`, or awaits `f` for a value, stores it and returns it.
    ///
    /// With stampede protection enabled, concurrent calls for the same key are merged. Without
    /// it, the lookup and the insert are separate storage calls, so concurrent callers may each
    /// run their `f`. Completion is reported to telemetry as `cache.get_or_insert`.
    ///
    /// # Errors
    ///
    /// Returns the storage error if the lookup or insert fails, or an [`Error`] wrapping the
    /// merger's failure value if the shared call did not complete.
    pub async fn get_or_insert<Q, Fut>(&self, key: &Q, f: impl FnOnce() -> Fut + Send) -> Result<CacheEntry<V>, Error>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized + Send + Sync,
        Fut: Future<Output = V> + Send,
    {
        let request_id = next_request_id();
        let watch = self.clock.stopwatch();
        async {
            let owned = key.to_owned();
            let (result, coalesced) = if let Some(mergers) = &self.mergers {
                let result = mergers
                    .get_or_insert
                    .execute(key, move || async move { self.do_get_or_insert(&owned, f).await })
                    .await
                    .unwrap_or_else(|panicked| Err(Error::from_source(panicked)));
                (result, true)
            } else {
                (self.do_get_or_insert(&owned, f).await, false)
            };
            self.telemetry
                .complete_operation(request_id, self.name, "cache.get_or_insert", watch.elapsed(), coalesced);
            result
        }
        .with_request_id(request_id)
        .await
    }

    /// Unmerged body of [`Cache::get_or_insert`]: read, then build and store on a miss.
    async fn do_get_or_insert<Fut>(&self, key: &K, f: impl FnOnce() -> Fut) -> Result<CacheEntry<V>, Error>
    where
        Fut: Future<Output = V>,
    {
        if let Some(entry) = self.storage.get(key).await? {
            return Ok(entry);
        }
        let value = f().await;
        self.store_new_entry(key, CacheEntry::new(value)).await
    }

    /// Returns the entry under `key`, or awaits `f` for a complete entry, stores it and returns
    /// it.
    ///
    /// With stampede protection enabled, concurrent calls for the same key are merged. Without
    /// it, the lookup and the insert are separate storage calls, so concurrent callers may each
    /// run their `f`. Completion is reported to telemetry as `cache.get_or_insert_with`.
    ///
    /// # Errors
    ///
    /// Returns the storage error if the lookup or insert fails, or an [`Error`] wrapping the
    /// merger's failure value if the shared call did not complete.
    pub async fn get_or_insert_with<Q, Fut>(&self, key: &Q, f: impl FnOnce() -> Fut + Send) -> Result<CacheEntry<V>, Error>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized + Send + Sync,
        Fut: Future<Output = CacheEntry<V>> + Send,
    {
        let request_id = next_request_id();
        let watch = self.clock.stopwatch();
        // Instrumented like the sibling operations so this call is visible to telemetry and its
        // tier events carry `request_id`.
        async {
            let owned = key.to_owned();
            let (result, coalesced) = if let Some(mergers) = &self.mergers {
                let result = mergers
                    .get_or_insert_with
                    .execute(key, move || async move { self.do_get_or_insert_with(&owned, f).await })
                    .await
                    .unwrap_or_else(|panicked| Err(Error::from_source(panicked)));
                (result, true)
            } else {
                (self.do_get_or_insert_with(&owned, f).await, false)
            };
            self.telemetry.complete_operation(
                request_id,
                self.name,
                "cache.get_or_insert_with",
                watch.elapsed(),
                coalesced,
            );
            result
        }
        .with_request_id(request_id)
        .await
    }

    /// Unmerged body of [`Cache::get_or_insert_with`]: read, then build and store on a miss.
    async fn do_get_or_insert_with<Fut>(&self, key: &K, f: impl FnOnce() -> Fut) -> Result<CacheEntry<V>, Error>
    where
        Fut: Future<Output = CacheEntry<V>>,
    {
        if let Some(entry) = self.storage.get(key).await? {
            return Ok(entry);
        }
        let entry = f().await;
        self.store_new_entry(key, entry).await
    }

    /// Returns the entry under `key`, or awaits the fallible `f` for a complete entry, stores it
    /// and returns it.
    ///
    /// If `f` fails, its error is wrapped with [`Error::from_source`] and nothing is stored.
    /// With stampede protection enabled, concurrent calls for the same key are merged.
    /// Completion is reported to telemetry as `cache.try_get_or_insert_with`.
    ///
    /// # Errors
    ///
    /// Returns the wrapped error from `f`, the storage error if the lookup or insert fails, or an
    /// [`Error`] wrapping the merger's failure value if the shared call did not complete.
    pub async fn try_get_or_insert_with<Q, E, Fut>(&self, key: &Q, f: impl FnOnce() -> Fut + Send) -> Result<CacheEntry<V>, Error>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized + Send + Sync,
        E: std::error::Error + Send + Sync + 'static,
        Fut: Future<Output = Result<CacheEntry<V>, E>> + Send,
    {
        let request_id = next_request_id();
        let watch = self.clock.stopwatch();
        // Instrumented like the sibling operations so this call is visible to telemetry and its
        // tier events carry `request_id`.
        async {
            let owned = key.to_owned();
            let (result, coalesced) = if let Some(mergers) = &self.mergers {
                let result = mergers
                    .try_get_or_insert_with
                    .execute(key, move || async move { self.do_try_get_or_insert_with(&owned, f).await })
                    .await
                    .unwrap_or_else(|panicked| Err(Error::from_source(panicked)));
                (result, true)
            } else {
                (self.do_try_get_or_insert_with(&owned, f).await, false)
            };
            self.telemetry.complete_operation(
                request_id,
                self.name,
                "cache.try_get_or_insert_with",
                watch.elapsed(),
                coalesced,
            );
            result
        }
        .with_request_id(request_id)
        .await
    }

    /// Unmerged body of [`Cache::try_get_or_insert_with`]: read, then build and store on a miss.
    async fn do_try_get_or_insert_with<E, Fut>(&self, key: &K, f: impl FnOnce() -> Fut) -> Result<CacheEntry<V>, Error>
    where
        E: std::error::Error + Send + Sync + 'static,
        Fut: Future<Output = Result<CacheEntry<V>, E>>,
    {
        if let Some(entry) = self.storage.get(key).await? {
            return Ok(entry);
        }
        let entry = f().await.map_err(Error::from_source)?;
        self.store_new_entry(key, entry).await
    }

    /// Returns the entry under `key`, or awaits the fallible `f` for a value, stores it and
    /// returns it.
    ///
    /// If `f` fails, its error is wrapped with [`Error::from_source`] and nothing is stored.
    /// With stampede protection enabled, concurrent calls for the same key are merged.
    /// Completion is reported to telemetry as `cache.try_get_or_insert`.
    ///
    /// # Errors
    ///
    /// Returns the wrapped error from `f`, the storage error if the lookup or insert fails, or an
    /// [`Error`] wrapping the merger's failure value if the shared call did not complete.
    pub async fn try_get_or_insert<Q, E, Fut>(&self, key: &Q, f: impl FnOnce() -> Fut + Send) -> Result<CacheEntry<V>, Error>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized + Send + Sync,
        E: std::error::Error + Send + Sync + 'static,
        Fut: Future<Output = Result<V, E>> + Send,
    {
        let request_id = next_request_id();
        let watch = self.clock.stopwatch();
        async {
            let owned = key.to_owned();
            let (result, coalesced) = if let Some(mergers) = &self.mergers {
                let result = mergers
                    .try_get_or_insert
                    .execute(key, move || async move { self.do_try_get_or_insert(&owned, f).await })
                    .await
                    .unwrap_or_else(|panicked| Err(Error::from_source(panicked)));
                (result, true)
            } else {
                (self.do_try_get_or_insert(&owned, f).await, false)
            };
            self.telemetry
                .complete_operation(request_id, self.name, "cache.try_get_or_insert", watch.elapsed(), coalesced);
            result
        }
        .with_request_id(request_id)
        .await
    }

    /// Unmerged body of [`Cache::try_get_or_insert`]: read, then build and store on a miss.
    async fn do_try_get_or_insert<E, Fut>(&self, key: &K, f: impl FnOnce() -> Fut) -> Result<CacheEntry<V>, Error>
    where
        E: std::error::Error + Send + Sync + 'static,
        Fut: Future<Output = Result<V, E>>,
    {
        if let Some(entry) = self.storage.get(key).await? {
            return Ok(entry);
        }
        let value = f().await.map_err(Error::from_source)?;
        self.store_new_entry(key, CacheEntry::new(value)).await
    }

    /// Returns the entry under `key`, or awaits `f` for an optional value; if `f` yields
    /// `Some`, the value is stored and returned.
    ///
    /// If `f` yields `None`, nothing is stored and `Ok(None)` is returned. With stampede
    /// protection enabled, concurrent calls for the same key are merged. Completion is reported
    /// to telemetry as `cache.optionally_get_or_insert`.
    ///
    /// # Errors
    ///
    /// Returns the storage error if the lookup or insert fails, or an [`Error`] wrapping the
    /// merger's failure value if the shared call did not complete.
    pub async fn optionally_get_or_insert<Q, Fut>(&self, key: &Q, f: impl FnOnce() -> Fut + Send) -> Result<Option<CacheEntry<V>>, Error>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized + Send + Sync,
        Fut: Future<Output = Option<V>> + Send,
    {
        let request_id = next_request_id();
        let watch = self.clock.stopwatch();
        async {
            let owned = key.to_owned();
            let (result, coalesced) = if let Some(mergers) = &self.mergers {
                let result = mergers
                    .optionally_get_or_insert
                    .execute(key, move || async move { self.do_optionally_get_or_insert(&owned, f).await })
                    .await
                    .unwrap_or_else(|panicked| Err(Error::from_source(panicked)));
                (result, true)
            } else {
                (self.do_optionally_get_or_insert(&owned, f).await, false)
            };
            self.telemetry
                .complete_operation(request_id, self.name, "cache.optionally_get_or_insert", watch.elapsed(), coalesced);
            result
        }
        .with_request_id(request_id)
        .await
    }

    /// Unmerged body of [`Cache::optionally_get_or_insert`]: read, then build and store on a
    /// miss when `f` produces a value.
    async fn do_optionally_get_or_insert<Fut>(&self, key: &K, f: impl FnOnce() -> Fut) -> Result<Option<CacheEntry<V>>, Error>
    where
        Fut: Future<Output = Option<V>>,
    {
        if let Some(entry) = self.storage.get(key).await? {
            return Ok(Some(entry));
        }
        match f().await {
            Some(value) => self.store_new_entry(key, CacheEntry::new(value)).await.map(Some),
            None => Ok(None),
        }
    }

    /// Shared tail of the `do_*get_or_insert*` helpers.
    ///
    /// Applies `ensure_cached_at` with the clock's current system time, writes a clone of the
    /// entry to storage under `key`, and returns the entry.
    async fn store_new_entry(&self, key: &K, mut entry: CacheEntry<V>) -> Result<CacheEntry<V>, Error> {
        entry.ensure_cached_at(self.clock.system_time());
        self.storage.insert(key.clone(), entry.clone()).await?;
        Ok(entry)
    }
}
#[cfg(feature = "service")]
impl<K, V> layered::Service<cachet_service::CacheOperation<K, V>> for Cache<K, V>
where
    K: Clone + Eq + Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    type Out = Result<cachet_service::CacheResponse<V>, Error>;

    /// Dispatches one [`cachet_service::CacheOperation`] to the matching `Cache` method and wraps
    /// a successful outcome in the corresponding [`cachet_service::CacheResponse`] variant.
    ///
    /// Errors from the underlying cache method are returned as-is.
    async fn execute(&self, input: cachet_service::CacheOperation<K, V>) -> Self::Out {
        use cachet_service::{CacheOperation as Op, CacheResponse as Resp};

        match input {
            Op::Get(request) => self.get(&request.key).await.map(Resp::Get),
            Op::Insert(request) => self
                .insert(request.key, request.entry)
                .await
                .map(|()| Resp::Insert),
            Op::Invalidate(request) => self
                .invalidate(&request.key)
                .await
                .map(|()| Resp::Invalidate),
            Op::Clear => self.clear().await.map(|()| Resp::Clear),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use cachet_tier::MockCache;

    use super::*;
    use crate::telemetry::handler::RequestId;
    use crate::{CacheEventHandler, CacheOperationEvent, CacheTierEvent};

    /// Drives a future to completion on the current thread.
    fn block_on<F: std::future::Future>(f: F) -> F::Output {
        futures::executor::block_on(f)
    }

    /// Builds a cache over an in-memory mock store without enabling stampede protection.
    fn build_cache() -> Cache<String, i32> {
        let clock = Clock::new_frozen();
        Cache::builder::<String, i32>(clock).storage(MockCache::new()).build()
    }

    /// Builds a cache over an in-memory mock store with stampede protection enabled.
    fn build_cache_with_stampede() -> Cache<String, i32> {
        let clock = Clock::new_frozen();
        Cache::builder::<String, i32>(clock)
            .storage(MockCache::new())
            .stampede_protection()
            .build()
    }

    #[test]
    fn mergers_new_and_debug() {
        let m = Mergers::<String, i32>::new();
        let debug = format!("{m:?}");
        assert!(debug.contains("Mergers"));
    }

    #[test]
    fn cache_builder_creates_cache() {
        let clock = Clock::new_frozen();
        let _ = Cache::builder::<String, i32>(clock);
    }

    #[test]
    fn cache_new_and_accessors() {
        let cache = build_cache();
        assert!(!cache.name().is_empty());
        let _ = cache.clock();
    }

    #[test]
    fn cache_get_miss() {
        block_on(async {
            let cache = build_cache();
            let result = cache.get("missing").await.unwrap();
            assert!(result.is_none());
        });
    }

    /// A miss that falls through to a fallback tier must emit tier events for both tiers and one
    /// operation event, all sharing the same request id.
    #[test]
    fn cache_event_handler_receives_fallback_tier_events() {
        type EventRecord = Vec<(RequestId, String, String, bool)>;

        #[derive(Clone)]
        struct RecordingHandler {
            tier_events: Arc<Mutex<EventRecord>>,
            operation_events: Arc<Mutex<EventRecord>>,
        }

        impl CacheEventHandler for RecordingHandler {
            fn on_tier_event(&self, event: &CacheTierEvent<'_>) {
                self.tier_events.lock().expect("test handler mutex should not be poisoned").push((
                    event.request_id,
                    event.tier_name.to_string(),
                    event.outcome.to_string(),
                    event.fallback,
                ));
            }

            fn on_operation_complete(&self, event: &CacheOperationEvent<'_>) {
                self.operation_events
                    .lock()
                    .expect("test handler mutex should not be poisoned")
                    .push((
                        event.request_id,
                        event.cache_name.to_string(),
                        event.operation.to_string(),
                        event.coalesced,
                    ));
            }
        }

        let tier_events = Arc::new(Mutex::new(Vec::new()));
        let operation_events = Arc::new(Mutex::new(Vec::new()));
        block_on(async {
            let clock = Clock::new_frozen();
            let handler = RecordingHandler {
                tier_events: Arc::clone(&tier_events),
                operation_events: Arc::clone(&operation_events),
            };
            let l2 = Cache::builder::<String, i32>(clock.clone()).storage(MockCache::new()).name("l2");
            let cache = Cache::builder::<String, i32>(clock)
                .storage(MockCache::new())
                .name("l1")
                .event_handler(handler)
                .fallback(l2)
                .build();
            let result = cache.get("missing").await.unwrap();
            assert!(result.is_none());
        });

        let tier_events = tier_events.lock().expect("test handler mutex should not be poisoned").clone();
        let operation_events = operation_events.lock().expect("test handler mutex should not be poisoned").clone();
        let request_id = operation_events[0].0;
        assert_eq!(
            tier_events,
            vec![
                (
                    request_id,
                    "l1".to_string(),
                    crate::telemetry::attributes::EVENT_MISS.to_string(),
                    false
                ),
                (
                    request_id,
                    "l2".to_string(),
                    crate::telemetry::attributes::EVENT_MISS.to_string(),
                    true
                ),
            ]
        );
        assert_eq!(
            operation_events,
            vec![(request_id, "l1".to_string(), "cache.get".to_string(), false)]
        );
    }

    #[test]
    fn cache_insert_and_get() {
        block_on(async {
            let cache = build_cache();
            cache.insert("key".to_string(), CacheEntry::new(42)).await.unwrap();
            let entry = cache.get("key").await.unwrap().expect("should exist");
            assert_eq!(*entry.value(), 42);
        });
    }

    #[test]
    fn cache_invalidate_no_stampede() {
        block_on(async {
            let cache = build_cache();
            cache.insert("key".to_string(), CacheEntry::new(1)).await.unwrap();
            cache.invalidate("key").await.unwrap();
            assert!(cache.get("key").await.unwrap().is_none());
        });
    }

    #[test]
    fn cache_invalidate_with_stampede() {
        block_on(async {
            let cache = build_cache_with_stampede();
            cache.insert("key".to_string(), CacheEntry::new(1)).await.unwrap();
            cache.invalidate("key").await.unwrap();
            assert!(cache.get("key").await.unwrap().is_none());
        });
    }

    #[test]
    fn cache_contains() {
        block_on(async {
            let cache = build_cache();
            assert!(!cache.contains("key").await.unwrap());
            cache.insert("key".to_string(), CacheEntry::new(1)).await.unwrap();
            assert!(cache.contains("key").await.unwrap());
        });
    }

    #[test]
    fn cache_clear() {
        block_on(async {
            let cache = build_cache();
            cache.insert("a".to_string(), CacheEntry::new(1)).await.unwrap();
            cache.clear().await.unwrap();
            assert!(cache.get("a").await.unwrap().is_none());
        });
    }

    #[test]
    fn cache_len_and_is_empty() {
        block_on(async {
            let cache = build_cache();
            assert_eq!(cache.len().await.expect("len should return Ok"), 0);
            assert!(cache.is_empty().await.expect("is_empty should return Ok"));
            cache.insert("key".to_string(), CacheEntry::new(1)).await.unwrap();
            assert_eq!(cache.len().await.expect("len should return Ok"), 1);
            assert!(!cache.is_empty().await.expect("is_empty should return Ok"));
        });
    }

    #[test]
    fn cache_get_with_stampede() {
        block_on(async {
            let cache = build_cache_with_stampede();
            cache.insert("key".to_string(), CacheEntry::new(99)).await.unwrap();
            let entry = cache.get("key").await.unwrap().expect("should exist");
            assert_eq!(*entry.value(), 99);
        });
    }

    #[test]
    fn cache_get_miss_with_stampede() {
        block_on(async {
            let cache = build_cache_with_stampede();
            assert!(cache.get("missing").await.unwrap().is_none());
        });
    }

    #[test]
    fn cache_get_or_insert_miss() {
        block_on(async {
            let cache = build_cache();
            let entry = cache.get_or_insert("key", || async { 42 }).await.unwrap();
            assert_eq!(*entry.value(), 42);
        });
    }

    #[test]
    fn cache_get_or_insert_hit() {
        block_on(async {
            let cache = build_cache();
            cache.insert("key".to_string(), CacheEntry::new(1)).await.unwrap();
            let entry = cache.get_or_insert("key", || async { 99 }).await.unwrap();
            assert_eq!(*entry.value(), 1);
        });
    }

    #[test]
    fn cache_get_or_insert_with_stampede() {
        block_on(async {
            let cache = build_cache_with_stampede();
            let entry = cache.get_or_insert("key", || async { 42 }).await.unwrap();
            assert_eq!(*entry.value(), 42);
        });
    }

    /// A miss builds the entry with the closure and persists it for later reads.
    #[test]
    fn cache_get_or_insert_with_miss_stores_entry() {
        block_on(async {
            let cache = build_cache();
            let entry = cache
                .get_or_insert_with("key", || async { CacheEntry::new(42) })
                .await
                .unwrap();
            assert_eq!(*entry.value(), 42);
            let stored = cache.get("key").await.unwrap().expect("entry should be stored");
            assert_eq!(*stored.value(), 42);
        });
    }

    #[test]
    fn cache_get_or_insert_with_hit() {
        block_on(async {
            let cache = build_cache();
            cache.insert("key".to_string(), CacheEntry::new(1)).await.unwrap();
            let entry = cache
                .get_or_insert_with("key", || async { CacheEntry::new(99) })
                .await
                .unwrap();
            assert_eq!(*entry.value(), 1);
        });
    }

    #[test]
    fn cache_get_or_insert_with_merged() {
        block_on(async {
            let cache = build_cache_with_stampede();
            let entry = cache
                .get_or_insert_with("key", || async { CacheEntry::new(42) })
                .await
                .unwrap();
            assert_eq!(*entry.value(), 42);
        });
    }

    #[test]
    fn cache_try_get_or_insert_ok() {
        block_on(async {
            let cache = build_cache();
            let entry = cache
                .try_get_or_insert("key", || async { Ok::<_, std::io::Error>(42) })
                .await
                .unwrap();
            assert_eq!(*entry.value(), 42);
        });
    }

    #[test]
    fn cache_try_get_or_insert_hit() {
        block_on(async {
            let cache = build_cache();
            cache.insert("key".to_string(), CacheEntry::new(1)).await.unwrap();
            let entry = cache
                .try_get_or_insert("key", || async { Ok::<_, std::io::Error>(99) })
                .await
                .unwrap();
            assert_eq!(*entry.value(), 1);
        });
    }

    #[test]
    fn cache_try_get_or_insert_err() {
        block_on(async {
            let cache = build_cache();
            let result = cache
                .try_get_or_insert("key", || async { Err::<i32, _>(std::io::Error::other("fail")) })
                .await;
            result.unwrap_err();
        });
    }

    /// A failing loader must not leave anything behind in storage.
    #[test]
    fn cache_try_get_or_insert_err_does_not_store() {
        block_on(async {
            let cache = build_cache();
            cache
                .try_get_or_insert("key", || async { Err::<i32, _>(std::io::Error::other("fail")) })
                .await
                .unwrap_err();
            assert!(cache.get("key").await.unwrap().is_none());
        });
    }

    #[test]
    fn cache_try_get_or_insert_with_stampede() {
        block_on(async {
            let cache = build_cache_with_stampede();
            let entry = cache
                .try_get_or_insert("key", || async { Ok::<_, std::io::Error>(42) })
                .await
                .unwrap();
            assert_eq!(*entry.value(), 42);
        });
    }

    #[test]
    fn cache_try_get_or_insert_with_ok() {
        block_on(async {
            let cache = build_cache();
            let entry = cache
                .try_get_or_insert_with("key", || async { Ok::<_, std::io::Error>(CacheEntry::new(42)) })
                .await
                .unwrap();
            assert_eq!(*entry.value(), 42);
            let stored = cache.get("key").await.unwrap().expect("entry should be stored");
            assert_eq!(*stored.value(), 42);
        });
    }

    #[test]
    fn cache_try_get_or_insert_with_hit() {
        block_on(async {
            let cache = build_cache();
            cache.insert("key".to_string(), CacheEntry::new(1)).await.unwrap();
            let entry = cache
                .try_get_or_insert_with("key", || async { Ok::<_, std::io::Error>(CacheEntry::new(99)) })
                .await
                .unwrap();
            assert_eq!(*entry.value(), 1);
        });
    }

    #[test]
    fn cache_try_get_or_insert_with_err_does_not_store() {
        block_on(async {
            let cache = build_cache();
            cache
                .try_get_or_insert_with("key", || async {
                    Err::<CacheEntry<i32>, _>(std::io::Error::other("fail"))
                })
                .await
                .unwrap_err();
            assert!(cache.get("key").await.unwrap().is_none());
        });
    }

    #[test]
    fn cache_try_get_or_insert_with_merged() {
        block_on(async {
            let cache = build_cache_with_stampede();
            let entry = cache
                .try_get_or_insert_with("key", || async { Ok::<_, std::io::Error>(CacheEntry::new(42)) })
                .await
                .unwrap();
            assert_eq!(*entry.value(), 42);
        });
    }

    #[test]
    fn cache_optionally_get_or_insert_some() {
        block_on(async {
            let cache = build_cache();
            let entry = cache.optionally_get_or_insert("key", || async { Some(42) }).await.unwrap();
            assert_eq!(*entry.unwrap().value(), 42);
        });
    }

    #[test]
    fn cache_optionally_get_or_insert_none() {
        block_on(async {
            let cache = build_cache();
            let result = cache.optionally_get_or_insert("key", || async { None }).await.unwrap();
            assert!(result.is_none());
        });
    }

    /// A loader that yields `None` must not store anything.
    #[test]
    fn cache_optionally_get_or_insert_none_does_not_store() {
        block_on(async {
            let cache = build_cache();
            cache.optionally_get_or_insert("key", || async { None }).await.unwrap();
            assert!(cache.get("key").await.unwrap().is_none());
        });
    }

    #[test]
    fn cache_optionally_get_or_insert_hit() {
        block_on(async {
            let cache = build_cache();
            cache.insert("key".to_string(), CacheEntry::new(1)).await.unwrap();
            let entry = cache.optionally_get_or_insert("key", || async { Some(99) }).await.unwrap();
            assert_eq!(*entry.unwrap().value(), 1);
        });
    }

    #[test]
    fn cache_optionally_get_or_insert_with_stampede() {
        block_on(async {
            let cache = build_cache_with_stampede();
            let entry = cache.optionally_get_or_insert("key", || async { Some(42) }).await.unwrap();
            assert_eq!(*entry.unwrap().value(), 42);
        });
    }
}