use std::hash::Hash;
use std::marker::PhantomData;
use std::sync::Arc;
use cachet_tier::{CacheEntry, CacheTier, SizeError};
use futures::join;
use tick::Clock;
use crate::Error;
use crate::cache::CacheName;
use crate::refresh::TimeToRefresh;
use crate::telemetry::ext::ClockExt;
use crate::telemetry::{CacheActivity, CacheOperation, CacheTelemetry};
/// Shared, thread-safe predicate consulted by [`FallbackPromotionPolicy::when`] to decide
/// whether an entry read from the fallback tier is copied into the primary tier.
type PromotionPredicate<V> = Arc<dyn Fn(&CacheEntry<V>) -> bool + Sync + Send>;
#[derive(Debug, Default)]
pub struct FallbackPromotionPolicy<V>(PolicyType<V>);
/// Internal representation of a [`FallbackPromotionPolicy`].
enum PolicyType<V> {
    /// Promote every entry served by the fallback tier.
    Always,
    /// Never promote entries served by the fallback tier.
    Never,
    /// Promote only entries for which the predicate returns `true`.
    When(PromotionPredicate<V>),
}

impl<V> Default for PolicyType<V> {
    /// The default policy promotes everything.
    fn default() -> Self {
        Self::Always
    }
}
impl<V> std::fmt::Debug for PolicyType<V> {
    /// Prints the variant name; the `When` predicate is opaque, so a placeholder is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // NOTE: the `When` label says "WhenBoxed" although the predicate is held in an
        // `Arc`; the `policy_type_debug_formatting` test asserts on this exact text.
        let label = match self {
            Self::Always => "Always",
            Self::Never => "Never",
            Self::When(_) => "WhenBoxed(<closure>)",
        };
        f.write_str(label)
    }
}
impl<V> FallbackPromotionPolicy<V> {
    /// Promotes every entry served by the fallback tier into the primary tier.
    #[must_use]
    pub fn always() -> Self {
        Self(PolicyType::Always)
    }

    /// Never promotes entries served by the fallback tier.
    #[must_use]
    pub fn never() -> Self {
        Self(PolicyType::Never)
    }

    /// Promotes a fallback entry only when `predicate` returns `true` for it.
    ///
    /// The predicate is shared behind an `Arc`, so it must be `Send + Sync + 'static`.
    #[must_use]
    pub fn when<F>(predicate: F) -> Self
    where
        F: Fn(&CacheEntry<V>) -> bool + Send + Sync + 'static,
    {
        Self(PolicyType::When(Arc::new(predicate)))
    }

    /// Returns `true` if `entry`, just read from the fallback tier, should be copied into
    /// the primary tier.
    #[inline]
    pub(crate) fn should_promote(&self, entry: &CacheEntry<V>) -> bool {
        match &self.0 {
            PolicyType::Always => true,
            PolicyType::Never => false,
            PolicyType::When(predicate) => predicate(entry),
        }
    }
}
/// State of a [`FallbackCache`], which holds it behind an `Arc` (`FallbackCache::inner`).
pub(crate) struct FallbackCacheInner<K, V, P, F> {
/// Name used in this layer's telemetry records and in its `Debug` output.
pub(crate) name: CacheName,
/// Tier consulted first by `get`; written by `insert` and by fallback promotion.
pub(crate) primary: P,
/// Tier consulted when the primary read misses or returns an error.
pub(crate) fallback: F,
/// Decides which fallback hits are copied into `primary`.
pub(crate) policy: FallbackPromotionPolicy<V>,
/// Times fallback reads and promotions, and supplies "now" for refresh checks.
pub(crate) clock: Clock,
/// When set, primary hits that `should_refresh` reports as due trigger `do_refresh`.
pub(crate) refresh: Option<TimeToRefresh<K>>,
/// Sink for the fallback-read and promotion telemetry records.
pub(crate) telemetry: CacheTelemetry,
// Marker for `K`. NOTE(review): `K` already appears in `refresh`, so this field may be
// redundant — confirm variance/drop-check intent before removing.
_phantom: PhantomData<K>,
}
impl<K, V, P, F> std::fmt::Debug for FallbackCacheInner<K, V, P, F> {
    /// Shows only the cache name; everything else is elided as `..`, so no `Debug`
    /// bounds are needed on `K`, `V`, `P` or `F`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("FallbackCacheInner");
        out.field("name", &self.name);
        out.finish_non_exhaustive()
    }
}
/// A two-tier cache: reads go to the primary tier and fall back to the fallback tier on a
/// primary miss or error, optionally promoting fallback hits into the primary tier.
pub struct FallbackCache<K, V, P, F> {
    pub(crate) inner: Arc<FallbackCacheInner<K, V, P, F>>,
}

// Hand-written so that `Debug` does not require `K`, `V`, `P` or `F` to be `Debug`:
// `#[derive(Debug)]` would add those bounds although the inner `Debug` impl needs none.
// The output is identical to the derived form, e.g.
// `FallbackCache { inner: FallbackCacheInner { name: "...", .. } }`.
impl<K, V, P, F> std::fmt::Debug for FallbackCache<K, V, P, F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FallbackCache").field("inner", &self.inner).finish()
    }
}
impl<K, V, P, F> FallbackCache<K, V, P, F> {
    /// Assembles a fallback cache from its two tiers and supporting services.
    ///
    /// When `refresh` is `None`, primary hits are returned without any refresh check.
    pub(crate) fn new(
        name: CacheName,
        primary: P,
        fallback: F,
        policy: FallbackPromotionPolicy<V>,
        clock: Clock,
        refresh: Option<TimeToRefresh<K>>,
        telemetry: CacheTelemetry,
    ) -> Self {
        let state = FallbackCacheInner {
            name,
            primary,
            fallback,
            policy,
            clock,
            refresh,
            telemetry,
            _phantom: PhantomData,
        };
        Self { inner: Arc::new(state) }
    }
}
impl<K, V, P, F> FallbackCache<K, V, P, F>
where
    K: Clone + Eq + Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    P: CacheTier<K, V> + Send + Sync + 'static,
    F: CacheTier<K, V> + Send + Sync + 'static,
{
    /// Reads `key` from the fallback tier. If the promotion policy accepts the entry, it
    /// also copies the entry into the primary tier.
    ///
    /// The fallback read's duration is always recorded (as `Get` / `Fallback`), even when
    /// the read fails. Promotion is best-effort: a failed primary insert is neither
    /// reported to the caller nor recorded. A successful one is recorded as
    /// `Insert` / `FallbackPromotion`.
    ///
    /// # Errors
    ///
    /// Returns the fallback tier's error if the fallback read fails.
    async fn get_from_fallback(&self, key: &K) -> Result<Option<CacheEntry<V>>, Error> {
        let inner = &self.inner;

        let read = inner.clock.timed_async(inner.fallback.get(key)).await;
        inner
            .telemetry
            .record(inner.name, CacheOperation::Get, CacheActivity::Fallback, read.duration);
        let found = read.result?;

        let promotable = found.as_ref().filter(|entry| inner.policy.should_promote(entry));
        if let Some(entry) = promotable {
            let write = inner
                .clock
                .timed_async(inner.primary.insert(key.clone(), entry.clone()))
                .await;
            // Best-effort promotion: only successful writes are reported.
            if write.result.is_ok() {
                inner.telemetry.record(
                    inner.name,
                    CacheOperation::Insert,
                    CacheActivity::FallbackPromotion,
                    write.duration,
                );
            }
        }

        Ok(found)
    }
}
impl<K, V, P, F> CacheTier<K, V> for FallbackCache<K, V, P, F>
where
    K: Clone + Eq + Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    P: CacheTier<K, V> + Send + Sync + 'static,
    F: CacheTier<K, V> + Send + Sync + 'static,
{
    /// Looks `key` up in the primary tier and falls back to the fallback tier on a miss.
    ///
    /// A primary error is treated like a miss (the fallback tier is consulted instead), so
    /// only fallback errors reach the caller. On a primary hit, `do_refresh(key)` is called
    /// when all of the following hold:
    /// - a refresh policy is configured;
    /// - the entry has a `cached_at` timestamp;
    /// - `should_refresh` reports that timestamp as due.
    ///
    /// The primary entry is returned either way.
    async fn get(&self, key: &K) -> Result<Option<CacheEntry<V>>, Error> {
        // Primary errors deliberately fall through to the fallback tier.
        if let Ok(Some(value)) = self.inner.primary.get(key).await {
            if let Some(refresh) = &self.inner.refresh
                && let Some(cached_at) = value.cached_at()
                && refresh.should_refresh(cached_at, self.inner.clock.system_time())
            {
                self.do_refresh(key);
            }
            return Ok(Some(value));
        }
        self.get_from_fallback(key).await
    }

    /// Writes the entry to both tiers concurrently.
    ///
    /// Both writes are always attempted. If both fail, the primary error is returned and
    /// the fallback error is dropped.
    async fn insert(&self, key: K, entry: CacheEntry<V>) -> Result<(), Error> {
        // The primary write gets clones; the fallback write takes ownership of the
        // originals, so the key is cloned only once.
        let (primary_result, fallback_result) = join!(
            self.inner.primary.insert(key.clone(), entry.clone()),
            self.inner.fallback.insert(key, entry)
        );
        primary_result?;
        fallback_result
    }

    /// Invalidates `key` in both tiers concurrently; the primary error takes precedence.
    async fn invalidate(&self, key: &K) -> Result<(), Error> {
        let (primary_result, fallback_result) =
            join!(self.inner.primary.invalidate(key), self.inner.fallback.invalidate(key));
        primary_result?;
        fallback_result
    }

    /// Clears both tiers concurrently; the primary error takes precedence.
    async fn clear(&self) -> Result<(), Error> {
        let (primary_result, fallback_result) =
            join!(self.inner.primary.clear(), self.inner.fallback.clear());
        primary_result?;
        fallback_result
    }

    /// Reports the primary tier's length only; fallback entries are not counted.
    async fn len(&self) -> Result<u64, SizeError> {
        self.inner.primary.len().await
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the promotion policy and the fallback cache layer. Tests that drive
    // tokio are skipped under Miri.
    use std::time::Duration;

    use cachet_tier::MockCache;

    use super::*;
    use crate::Cache;
    use crate::telemetry::TelemetryConfig;
    use crate::wrapper::CacheWrapper;

    type TestPrimary = CacheWrapper<String, i32, MockCache<String, i32>>;
    type TestFallbackCache = FallbackCache<String, i32, TestPrimary, MockCache<String, i32>>;

    /// Builds an empty `CacheWrapper`-based primary tier on a frozen clock.
    fn make_primary() -> TestPrimary {
        let clock = Clock::new_frozen();
        let telemetry = TelemetryConfig::new().build();
        CacheWrapper::new("primary", MockCache::new(), clock, None, telemetry)
    }

    /// Builds a fallback cache named "fallback" with empty tiers and no refresh policy.
    fn make_fallback_cache(policy: FallbackPromotionPolicy<i32>) -> TestFallbackCache {
        let clock = Clock::new_frozen();
        let primary = make_primary();
        let fallback_mock = MockCache::<String, i32>::new();
        let telemetry = TelemetryConfig::new().build();
        FallbackCache::new("fallback", primary, fallback_mock, policy, clock, None, telemetry)
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn fallback_cachet_promotes_from_fallback_to_primary() {
        let clock = Clock::new_frozen();
        let primary_storage = MockCache::<String, i32>::new();
        // Second handle used to inspect the primary storage directly; the assertions
        // below rely on it observing writes made through `primary_storage`.
        let primary_check = primary_storage.clone();
        let fallback_storage = MockCache::<String, i32>::new();
        fallback_storage
            .insert("key".to_string(), CacheEntry::new(42))
            .await
            .expect("insert failed");
        let fallback = Cache::builder::<String, i32>(clock.clone()).storage(fallback_storage);
        let cache = Cache::builder::<String, i32>(clock)
            .storage(primary_storage)
            .fallback(fallback)
            .promotion_policy(FallbackPromotionPolicy::always())
            .build();
        let primary_result = primary_check.get(&"key".to_string()).await.expect("get failed");
        assert!(primary_result.is_none());
        let result = cache.get(&"key".to_string()).await.expect("get failed");
        assert!(result.is_some());
        assert_eq!(*result.unwrap().value(), 42);
        // The promoted entry must carry the fallback's value, not merely exist.
        let primary_result = primary_check.get(&"key".to_string()).await.expect("get failed");
        assert_eq!(primary_result.map(|entry| *entry.value()), Some(42));
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn fallback_cachet_never_policy_does_not_promote() {
        let clock = Clock::new_frozen();
        let primary_storage = MockCache::<String, i32>::new();
        let primary_check = primary_storage.clone();
        let fallback_storage = MockCache::<String, i32>::new();
        fallback_storage
            .insert("key".to_string(), CacheEntry::new(42))
            .await
            .expect("insert failed");
        let fallback = Cache::builder::<String, i32>(clock.clone()).storage(fallback_storage);
        let cache = Cache::builder::<String, i32>(clock)
            .storage(primary_storage)
            .fallback(fallback)
            .promotion_policy(FallbackPromotionPolicy::never())
            .build();
        let result = cache.get(&"key".to_string()).await.expect("get failed");
        assert!(result.is_some());
        assert_eq!(*result.unwrap().value(), 42);
        let primary_result = primary_check.get(&"key".to_string()).await.expect("get failed");
        assert!(primary_result.is_none());
    }

    #[test]
    fn fallback_cachet_inner_debug() {
        let cache = make_fallback_cache(FallbackPromotionPolicy::always());
        let debug_str = format!("{cache:?}");
        assert_eq!(debug_str, "FallbackCache { inner: FallbackCacheInner { name: \"fallback\", .. } }");
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn fallback_cachet_when_policy_conditional_promotion() {
        fn is_positive(entry: &CacheEntry<i32>) -> bool {
            *entry.value() > 0
        }
        let clock = Clock::new_frozen();
        let primary_storage = MockCache::<String, i32>::new();
        let primary_check = primary_storage.clone();
        let fallback_storage = MockCache::<String, i32>::new();
        fallback_storage
            .insert("positive".to_string(), CacheEntry::new(42))
            .await
            .expect("insert failed");
        fallback_storage
            .insert("negative".to_string(), CacheEntry::new(-10))
            .await
            .expect("insert failed");
        let fallback = Cache::builder::<String, i32>(clock.clone()).storage(fallback_storage);
        let cache = Cache::builder::<String, i32>(clock)
            .storage(primary_storage)
            .fallback(fallback)
            .promotion_policy(FallbackPromotionPolicy::when(is_positive))
            .build();
        let result = cache.get(&"positive".to_string()).await.expect("get failed");
        assert!(result.is_some());
        assert_eq!(*result.unwrap().value(), 42);
        let result = cache.get(&"negative".to_string()).await.expect("get failed");
        assert!(result.is_some());
        assert_eq!(*result.unwrap().value(), -10);
        let positive = primary_check.get(&"positive".to_string()).await.expect("get failed");
        assert!(positive.is_some());
        let negative = primary_check.get(&"negative".to_string()).await.expect("get failed");
        assert!(negative.is_none());
    }

    // Pure formatting: no runtime needed, so this is a plain synchronous test that can
    // also run under Miri.
    #[test]
    fn policy_type_debug_formatting() {
        let always = FallbackPromotionPolicy::<i32>::always();
        let never = FallbackPromotionPolicy::<i32>::never();
        let when = FallbackPromotionPolicy::<i32>::when(|_| true);
        let always_str = format!("{always:?}");
        let never_str = format!("{never:?}");
        let when_str = format!("{when:?}");
        assert!(always_str.contains("Always"), "got: {always_str}");
        assert!(never_str.contains("Never"), "got: {never_str}");
        assert!(when_str.contains("WhenBoxed"), "got: {when_str}");
    }

    #[test]
    fn promotion_policy_always() {
        let policy = FallbackPromotionPolicy::<i32>::always();
        let entry = CacheEntry::new(42);
        assert!(policy.should_promote(&entry));
    }

    #[test]
    fn promotion_policy_never() {
        let policy = FallbackPromotionPolicy::<i32>::never();
        let entry = CacheEntry::new(42);
        assert!(!policy.should_promote(&entry));
    }

    #[test]
    fn promotion_policy_when() {
        let policy = FallbackPromotionPolicy::<i32>::when(|e| *e.value() > 10);
        assert!(policy.should_promote(&CacheEntry::new(42)));
        assert!(!policy.should_promote(&CacheEntry::new(5)));
    }

    #[test]
    fn fallback_cache_new_constructs() {
        let cache = make_fallback_cache(FallbackPromotionPolicy::always());
        assert_eq!(cache.inner.name, "fallback");
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn fallback_get_miss_both() {
        let cache = make_fallback_cache(FallbackPromotionPolicy::always());
        let result = cache.get(&"key".to_string()).await.unwrap();
        assert!(result.is_none());
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn fallback_insert_writes_both() {
        let cache = make_fallback_cache(FallbackPromotionPolicy::always());
        cache.insert("key".to_string(), CacheEntry::new(42)).await.unwrap();
        let entry = cache.get(&"key".to_string()).await.unwrap().unwrap();
        assert_eq!(*entry.value(), 42);
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn fallback_invalidate() {
        let cache = make_fallback_cache(FallbackPromotionPolicy::always());
        cache.insert("key".to_string(), CacheEntry::new(42)).await.unwrap();
        cache.invalidate(&"key".to_string()).await.unwrap();
        assert!(cache.get(&"key".to_string()).await.unwrap().is_none());
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn fallback_clear() {
        let cache = make_fallback_cache(FallbackPromotionPolicy::always());
        cache.insert("key".to_string(), CacheEntry::new(42)).await.unwrap();
        cache.clear().await.unwrap();
        assert!(cache.get(&"key".to_string()).await.unwrap().is_none());
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn fallback_len() {
        let cache = make_fallback_cache(FallbackPromotionPolicy::always());
        assert_eq!(cache.len().await.expect("len should return Ok"), 0);
        cache.insert("key".to_string(), CacheEntry::new(42)).await.unwrap();
        assert_eq!(cache.len().await.expect("len should return Ok"), 1);
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn fallback_get_triggers_background_refresh() {
        let clock = Clock::new_frozen();
        let primary_mock = MockCache::<String, i32>::new();
        // Entry cached 120s ago; with a 30s time-to-refresh it is due for refresh.
        let old_time = clock.system_time() - Duration::from_secs(120);
        let entry = CacheEntry::expires_at(42, Duration::from_secs(300), old_time);
        primary_mock.insert("key".to_string(), entry).await.unwrap();
        let fallback_mock = MockCache::<String, i32>::new();
        let telemetry = TelemetryConfig::new().build();
        let refresh = crate::refresh::TimeToRefresh::new(Duration::from_secs(30), anyspawn::Spawner::new_tokio());
        let primary = CacheWrapper::new("primary", primary_mock, clock.clone(), None, telemetry.clone());
        let fc = FallbackCache::new(
            "test",
            primary,
            fallback_mock,
            FallbackPromotionPolicy::always(),
            clock,
            Some(refresh),
            telemetry,
        );
        // The stale primary entry is still served while the refresh path runs.
        let result = fc.get(&"key".to_string()).await.unwrap();
        assert!(result.is_some());
        assert_eq!(*result.unwrap().value(), 42);
        // NOTE(review): this only gives any spawned refresh work time to run; it does not
        // assert on the refresh's effect.
        tokio::time::sleep(Duration::from_millis(50)).await;
    }

    #[test]
    fn do_refresh_without_time_to_refresh_is_noop() {
        let cache = make_fallback_cache(FallbackPromotionPolicy::always());
        cache.do_refresh(&"key".to_string());
    }
}