use std::hash::Hash;
use std::marker::PhantomData;
use std::time::Duration;
use cachet_tier::{CacheTier, SizeError};
use tick::Clock;
use crate::cache::CacheName;
use crate::telemetry::CacheTelemetry;
use crate::{CacheEntry, Error, InsertPolicy};
/// Instrumenting decorator around a single cache tier `CT`.
///
/// Every [`CacheTier`] operation is delegated to `inner`. Around each call the
/// wrapper:
/// - times the operation with `clock`;
/// - reports the outcome to `telemetry`;
/// - applies TTL-based expiry to values returned by `get`;
/// - consults `policy` before forwarding an `insert`.
#[derive(Debug)]
pub(crate) struct CacheWrapper<K, V, CT> {
    /// Name passed to every telemetry call made by this wrapper.
    pub(crate) name: CacheName,
    /// The wrapped tier to which all operations are delegated.
    pub(crate) inner: CT,
    /// Provides wall-clock time (entry timestamps, expiry checks) and
    /// stopwatches (operation latency).
    pub(crate) clock: Clock,
    /// Tier-wide TTL, used only when an entry carries no TTL of its own.
    /// When both are `None`, entries never expire.
    pub(crate) ttl: Option<Duration>,
    /// Destination for hit/miss/expired/insert/invalidate/clear events.
    pub(crate) telemetry: CacheTelemetry,
    /// Decides whether an entry passed to `insert` is forwarded to `inner`.
    pub(crate) policy: InsertPolicy<V>,
    /// Forwarded unchanged to every telemetry call. Presumably flags this
    /// wrapper as a fallback tier — confirm against the constructing caller.
    pub(crate) fallback: bool,
    /// Marks `K` and `V` as used type parameters (`K` appears in no other field).
    _phantom: PhantomData<(K, V)>,
}
impl<K, V, CT> CacheWrapper<K, V, CT> {
    /// Assembles a wrapper around `inner`.
    ///
    /// No validation or I/O happens here. Every argument is stored verbatim
    /// and consulted later by the [`CacheTier`] operations.
    pub(crate) fn new(
        name: CacheName,
        inner: CT,
        clock: Clock,
        ttl: Option<Duration>,
        telemetry: CacheTelemetry,
        policy: InsertPolicy<V>,
        fallback: bool,
    ) -> Self {
        let _phantom = PhantomData;
        Self {
            _phantom,
            fallback,
            policy,
            telemetry,
            ttl,
            clock,
            inner,
            name,
        }
    }

    /// Returns the name this wrapper reports to telemetry (test-only accessor).
    #[cfg(test)]
    #[must_use]
    pub(crate) fn name(&self) -> CacheName {
        self.name
    }
}
impl<K, V, CT> CacheWrapper<K, V, CT>
where
    K: Clone + Eq + Hash + Send + Sync,
    V: Clone + Send + Sync,
    CT: CacheTier<K, V> + Send + Sync,
{
    /// Decides whether `entry` should be treated as stale.
    ///
    /// The entry's own TTL takes precedence over the tier-wide `self.ttl`.
    /// - With no TTL from either source, the entry never expires.
    /// - When a TTL applies, an entry with no `cached_at` timestamp counts as
    ///   expired.
    /// - If the current wall-clock time is earlier than `cached_at` (clock
    ///   went backwards), the entry also counts as expired.
    /// - Otherwise it expires only once the elapsed time strictly exceeds the
    ///   TTL.
    fn is_expired(&self, entry: &CacheEntry<V>) -> bool {
        let Some(ttl) = entry.ttl().or(self.ttl) else {
            return false;
        };

        entry.cached_at().is_none_or(|cached_at| {
            // `duration_since` errors when `cached_at` lies in the future; that
            // case is not "within TTL", so it falls through to expired.
            let within_ttl = self
                .clock
                .system_time()
                .duration_since(cached_at)
                .is_ok_and(|elapsed| elapsed <= ttl);
            !within_ttl
        })
    }

    /// Converts the raw result of an inner `get` into what callers see, and
    /// records the matching telemetry event.
    ///
    /// The outcome is exactly one of the following:
    /// - no entry: `record_miss`, returns `None`;
    /// - stale entry: `record_expired`, returns `None`;
    /// - fresh entry: `record_hit`, returns the entry.
    fn handle_get_result(&self, value: Option<CacheEntry<V>>, duration: Duration) -> Option<CacheEntry<V>> {
        let Some(entry) = value else {
            self.telemetry.record_miss(self.name, duration, self.fallback);
            return None;
        };

        if self.is_expired(&entry) {
            self.telemetry.record_expired(self.name, duration, self.fallback);
            return None;
        }

        self.telemetry.record_hit(self.name, duration, self.fallback);
        Some(entry)
    }
}
impl<K, V, CT> CacheTier<K, V> for CacheWrapper<K, V, CT>
where
    K: Clone + Eq + Hash + Send + Sync,
    V: Clone + Send + Sync,
    CT: CacheTier<K, V> + Send + Sync,
{
    /// Reads `key` from the inner tier, timing the call.
    ///
    /// A failure from the inner tier is recorded with `record_get_error` and
    /// returned unchanged. A successful lookup goes through
    /// `handle_get_result`, which turns stale entries into `Ok(None)`.
    async fn get(&self, key: &K) -> Result<Option<CacheEntry<V>>, Error> {
        let stopwatch = self.clock.stopwatch();
        self.inner
            .get(key)
            .await
            .inspect_err(|_| {
                self.telemetry
                    .record_get_error(self.name, stopwatch.elapsed(), self.fallback);
            })
            .map(|found| self.handle_get_result(found, stopwatch.elapsed()))
    }

    /// Writes `entry` under `key`, subject to the insert policy.
    ///
    /// The entry is first passed to `ensure_cached_at` with the current system
    /// time (presumably stamping it only if it has no timestamp yet). If the
    /// policy rejects the entry, `record_insert_rejected` is emitted and
    /// `Ok(())` is returned without touching the inner tier. Otherwise the
    /// inner insert is timed and its result is recorded and returned.
    async fn insert(&self, key: K, mut entry: CacheEntry<V>) -> Result<(), Error> {
        entry.ensure_cached_at(self.clock.system_time());

        if self.policy.should_insert(&entry) {
            let stopwatch = self.clock.stopwatch();
            let outcome = self.inner.insert(key, entry).await;
            let elapsed = stopwatch.elapsed();
            if outcome.is_ok() {
                self.telemetry.record_inserted(self.name, elapsed, self.fallback);
            } else {
                self.telemetry.record_insert_error(self.name, elapsed, self.fallback);
            }
            outcome
        } else {
            self.telemetry.record_insert_rejected(self.name, self.fallback);
            Ok(())
        }
    }

    /// Removes `key` from the inner tier, recording latency and success or
    /// failure. The inner result is returned unchanged.
    async fn invalidate(&self, key: &K) -> Result<(), Error> {
        let stopwatch = self.clock.stopwatch();
        let outcome = self.inner.invalidate(key).await;
        let elapsed = stopwatch.elapsed();
        if outcome.is_ok() {
            self.telemetry.record_invalidated(self.name, elapsed, self.fallback);
        } else {
            self.telemetry.record_invalidate_error(self.name, elapsed, self.fallback);
        }
        outcome
    }

    /// Clears the inner tier, recording latency and success or failure. The
    /// inner result is returned unchanged.
    async fn clear(&self) -> Result<(), Error> {
        let stopwatch = self.clock.stopwatch();
        let outcome = self.inner.clear().await;
        let elapsed = stopwatch.elapsed();
        if outcome.is_ok() {
            self.telemetry.record_cleared(self.name, elapsed, self.fallback);
        } else {
            self.telemetry.record_clear_error(self.name, elapsed, self.fallback);
        }
        outcome
    }

    /// Reports the inner tier's size. It is passed straight through, with no
    /// telemetry.
    async fn len(&self) -> Result<u64, SizeError> {
        self.inner.len().await
    }
}
#[cfg(test)]
mod tests {
    use cachet_tier::{CacheOp, MockCache};

    use super::*;

    /// Concrete wrapper type exercised by every test in this module.
    type MockWrapper = CacheWrapper<String, i32, MockCache<String, i32>>;

    /// Wraps `inner` in a `CacheWrapper` named `name`.
    ///
    /// Uses a fresh telemetry instance, the default insert policy and
    /// `fallback = false`.
    fn wrap(name: CacheName, inner: MockCache<String, i32>, clock: Clock, ttl: Option<Duration>) -> MockWrapper {
        CacheWrapper::new(name, inner, clock, ttl, CacheTelemetry::new(), InsertPolicy::default(), false)
    }

    /// Builds an owned `String` key.
    fn key(raw: &str) -> String {
        raw.to_string()
    }

    #[test]
    fn wrapper_is_expired_with_no_ttl_returns_false() {
        let wrapper = wrap("test", MockCache::<String, i32>::new(), Clock::new_frozen(), None);
        assert!(!wrapper.is_expired(&CacheEntry::new(42)));
    }

    #[test]
    fn wrapper_is_expired_with_ttl_without_cached_at_returns_true() {
        let wrapper = wrap(
            "test",
            MockCache::<String, i32>::new(),
            Clock::new_frozen(),
            Some(Duration::from_mins(1)),
        );
        assert!(wrapper.is_expired(&CacheEntry::new(42)));
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn insert_preserves_per_entry_ttl_over_tier_ttl() {
        let clock = Clock::new_frozen();
        let inner = MockCache::<String, i32>::new();
        let observer = inner.clone();
        let entry_ttl = Duration::from_secs(30);
        let wrapper = wrap("test", inner, clock.clone(), Some(Duration::from_mins(1)));

        let entry = CacheEntry::expires_at(42, entry_ttl, clock.system_time());
        wrapper.insert(key("key"), entry).await.unwrap();

        let stored = observer.get(&key("key")).await.unwrap().unwrap();
        assert_eq!(stored.ttl(), Some(entry_ttl));
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn insert_without_tier_ttl_leaves_entry_ttl_unset() {
        let inner = MockCache::<String, i32>::new();
        let observer = inner.clone();
        let wrapper = wrap("test", inner, Clock::new_frozen(), None);

        wrapper.insert(key("key"), CacheEntry::new(42)).await.unwrap();

        let stored = observer.get(&key("key")).await.unwrap().unwrap();
        assert!(stored.ttl().is_none());
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn insert_with_tier_ttl_leaves_entry_ttl_unset() {
        let inner = MockCache::<String, i32>::new();
        let observer = inner.clone();
        let wrapper = wrap("test", inner, Clock::new_frozen(), Some(Duration::from_mins(1)));

        wrapper.insert(key("key"), CacheEntry::new(42)).await.unwrap();

        let stored = observer.get(&key("key")).await.unwrap().unwrap();
        assert!(stored.ttl().is_none());
    }

    #[test]
    fn wrapper_is_expired_when_system_time_goes_backward() {
        let clock = Clock::new_frozen();
        let wrapper = wrap(
            "test",
            MockCache::<String, i32>::new(),
            clock.clone(),
            Some(Duration::from_mins(1)),
        );
        // A `cached_at` one hour in the future simulates the clock moving backwards.
        let future = clock.system_time() + Duration::from_hours(1);
        let entry = CacheEntry::expires_at(42, Duration::from_mins(1), future);
        assert!(wrapper.is_expired(&entry));
    }

    #[test]
    fn wrapper_is_not_expired_when_elapsed_equals_ttl() {
        let clock = Clock::new_frozen();
        let ttl = Duration::from_mins(1);
        let wrapper = wrap("test", MockCache::<String, i32>::new(), clock.clone(), Some(ttl));
        let entry = CacheEntry::expires_at(42, ttl, clock.system_time() - ttl);
        assert!(!wrapper.is_expired(&entry));
    }

    #[test]
    fn mock_wrapper_new_and_accessors() {
        let wrapper = wrap("mock_test", MockCache::<String, i32>::new(), Clock::new_frozen(), None);
        assert_eq!(wrapper.name(), "mock_test");
    }

    #[test]
    fn mock_wrapper_handle_get_result_none() {
        let wrapper = wrap("test", MockCache::<String, i32>::new(), Clock::new_frozen(), None);
        assert!(wrapper.handle_get_result(None, Duration::from_secs(0)).is_none());
    }

    #[test]
    fn mock_wrapper_handle_get_result_expired() {
        let wrapper = wrap(
            "test",
            MockCache::<String, i32>::new(),
            Clock::new_frozen(),
            Some(Duration::from_mins(1)),
        );
        let outcome = wrapper.handle_get_result(Some(CacheEntry::new(42)), Duration::from_secs(0));
        assert!(outcome.is_none());
    }

    #[test]
    fn mock_wrapper_handle_get_result_valid() {
        let wrapper = wrap("test", MockCache::<String, i32>::new(), Clock::new_frozen(), None);
        let outcome = wrapper.handle_get_result(Some(CacheEntry::new(42)), Duration::from_secs(0));
        assert!(outcome.is_some());
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn mock_wrapper_get_insert_invalidate_clear() {
        let wrapper = wrap("test", MockCache::<String, i32>::new(), Clock::new_frozen(), None);

        assert!(wrapper.get(&key("key")).await.unwrap().is_none());

        wrapper.insert(key("key"), CacheEntry::new(42)).await.unwrap();
        let found = wrapper.get(&key("key")).await.unwrap().unwrap();
        assert_eq!(*found.value(), 42);

        wrapper.invalidate(&key("key")).await.unwrap();
        assert!(wrapper.get(&key("key")).await.unwrap().is_none());

        wrapper.insert(key("a"), CacheEntry::new(1)).await.unwrap();
        wrapper.clear().await.unwrap();
        assert!(wrapper.get(&key("a")).await.unwrap().is_none());
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn mock_wrapper_len() {
        let wrapper = wrap("test", MockCache::<String, i32>::new(), Clock::new_frozen(), None);
        assert_eq!(wrapper.len().await.expect("len should return Ok"), 0);
        wrapper.insert(key("key"), CacheEntry::new(1)).await.unwrap();
        assert_eq!(wrapper.len().await.expect("len should return Ok"), 1);
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn mock_wrapper_get_error() {
        let inner = MockCache::<String, i32>::new();
        inner.fail_when(|op| matches!(op, CacheOp::Get(_)));
        let wrapper = wrap("test", inner, Clock::new_frozen(), None);
        wrapper.get(&key("key")).await.unwrap_err();
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn mock_wrapper_insert_error() {
        let inner = MockCache::<String, i32>::new();
        inner.fail_when(|op| matches!(op, CacheOp::Insert { .. }));
        let wrapper = wrap("test", inner, Clock::new_frozen(), None);
        wrapper.insert(key("key"), CacheEntry::new(1)).await.unwrap_err();
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn mock_wrapper_invalidate_error() {
        let inner = MockCache::<String, i32>::new();
        inner.fail_when(|op| matches!(op, CacheOp::Invalidate(_)));
        let wrapper = wrap("test", inner, Clock::new_frozen(), None);
        wrapper.invalidate(&key("key")).await.unwrap_err();
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn mock_wrapper_clear_error() {
        let inner = MockCache::<String, i32>::new();
        inner.fail_when(|op| matches!(op, CacheOp::Clear));
        let wrapper = wrap("test", inner, Clock::new_frozen(), None);
        wrapper.clear().await.unwrap_err();
    }
}