use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use anyspawn::Spawner;
use cachet_tier::{CacheEntry, CacheTier};
use parking_lot::Mutex;
use crate::fallback::{FallbackCache, FallbackCacheInner};
use crate::telemetry::ext::ClockExt;
use crate::telemetry::{CacheActivity, CacheOperation};
/// Background-refresh policy for a fallback cache.
///
/// Holds the age threshold after which a cached entry should be refreshed and the spawner
/// used to run refresh tasks. It also holds the set of keys whose refresh is currently in
/// flight, which is used to avoid starting duplicate refreshes for the same key.
pub struct TimeToRefresh<K> {
    /// Entry age at or beyond which a refresh is due (see `should_refresh`, which uses `>=`).
    pub duration: Duration,
    /// Spawner used to run background refresh tasks.
    pub(crate) spawner: Spawner,
    /// Keys for which a refresh has been started (`try_start_refresh`) but not yet
    /// finished (`finish_refresh`).
    in_flight: Mutex<HashSet<K>>,
}
// Only the interval is shown: the spawner and the in-flight key set are runtime state
// (and `K` is not required to be `Debug`), so they are elided as `..`.
impl<K> Debug for TimeToRefresh<K> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TimeToRefresh");
        builder.field("duration", &self.duration);
        builder.finish_non_exhaustive()
    }
}
impl<K> TimeToRefresh<K>
where
    K: Clone + Eq + Hash + Send + 'static,
{
    /// Creates a refresh policy that considers entries stale once they are `duration` old
    /// and runs refresh tasks on `spawner`. No key starts out in flight.
    #[must_use]
    pub fn new(duration: Duration, spawner: Spawner) -> Self {
        let in_flight = Mutex::new(HashSet::new());
        Self {
            duration,
            spawner,
            in_flight,
        }
    }

    /// Returns `true` when an entry cached at `cached_at` is at least `self.duration` old
    /// as of `now`.
    ///
    /// If `cached_at` is later than `now` (the wall clock moved backwards), the entry's age
    /// cannot be determined and the method errs on the side of refreshing.
    pub(crate) fn should_refresh(&self, cached_at: SystemTime, now: SystemTime) -> bool {
        let Ok(age) = now.duration_since(cached_at) else {
            return true;
        };
        age >= self.duration
    }

    /// Marks `key` as having a refresh in flight.
    ///
    /// Returns `true` if the caller claimed the key, or `false` if a refresh for it was
    /// already in flight.
    pub(crate) fn try_start_refresh(&self, key: &K) -> bool {
        let mut pending = self.in_flight.lock();
        pending.insert(key.clone())
    }

    /// Clears the in-flight mark for `key`; a no-op if the key was not marked.
    pub(crate) fn finish_refresh(&self, key: &K) {
        let mut pending = self.in_flight.lock();
        pending.remove(key);
    }
}
/// Scope guard that invokes the wrapped closure when the guard is dropped.
///
/// Because it runs from `Drop`, the closure also executes when the owning scope or future
/// is left by a panic unwinding through it, or when a future holding it is dropped.
struct DropGuard<F: FnMut()>(F);

impl<F: FnMut()> Drop for DropGuard<F> {
    fn drop(&mut self) {
        let Self(on_drop) = self;
        on_drop();
    }
}
impl<K, V, P, F> FallbackCache<K, V, P, F>
where
    K: Clone + Eq + Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    P: CacheTier<K, V> + Send + Sync + 'static,
    F: CacheTier<K, V> + Send + Sync + 'static,
{
    /// Starts a background refresh of `key`: reads it from the fallback tier and, on a
    /// hit, promotes the entry into the primary tier.
    ///
    /// Does nothing when no refresh policy is configured. Returns without spawning when a
    /// refresh for `key` is already in flight. The spawned task's handle is discarded
    /// (the task is not awaited).
    ///
    /// The in-flight mark for `key` is cleared in every one of these cases:
    /// - the task completes;
    /// - the task panics;
    /// - the task is dropped without ever being polled;
    /// - the call to `spawn` itself panics.
    pub fn do_refresh(&self, key: &K) {
        let Some(refresh) = &self.inner.refresh else {
            return;
        };
        if !refresh.try_start_refresh(key) {
            return;
        }

        // Arm the cleanup guard *before* spawning. Both failure modes below drop the
        // future, and with it the guard:
        // - the runtime drops the future unpolled (e.g. during shutdown);
        // - `spawn` panics (e.g. no runtime available).
        // If the guard were created inside the async body, neither case would reach it,
        // and the key would stay marked in flight forever.
        let release = DropGuard({
            let inner = Arc::clone(&self.inner);
            let key = key.clone();
            move || {
                if let Some(refresh) = &inner.refresh {
                    refresh.finish_refresh(&key);
                }
            }
        });

        let inner = Arc::clone(&self.inner);
        let key = key.clone();
        drop(refresh.spawner.spawn(async move {
            inner.fetch_and_promote(key).await;
            // Normal completion releases the key here; on panic or cancellation the
            // captured guard is dropped together with the future instead.
            drop(release);
        }));
    }
}
impl<K, V, P, F> FallbackCacheInner<K, V, P, F>
where
    K: Clone + Eq + Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    P: CacheTier<K, V> + Send + Sync + 'static,
    F: CacheTier<K, V> + Send + Sync + 'static,
{
    /// Fetches `key` from the fallback tier and promotes a hit into the primary tier.
    ///
    /// The fallback read is timed via `self.clock`. A hit is recorded as `RefreshHit` and
    /// the entry is inserted into primary. An absent value or an error from the fallback
    /// is recorded as `RefreshMiss`.
    pub(crate) async fn fetch_and_promote(&self, key: K) {
        let timed = self.clock.timed_async(self.fallback.get(&key)).await;
        let elapsed = timed.duration;
        if let Ok(Some(entry)) = timed.result {
            self.handle_fallback_hit(key, entry, elapsed).await;
        } else {
            self.handle_fallback_miss(elapsed);
        }
    }

    /// Records a refresh hit taking `fetch_duration`, then promotes `value` under `key`.
    async fn handle_fallback_hit(&self, key: K, value: CacheEntry<V>, fetch_duration: Duration) {
        let activity = CacheActivity::RefreshHit;
        self.telemetry
            .record(self.name, CacheOperation::Get, activity, fetch_duration);
        self.promote_to_primary(key, value).await;
    }

    /// Inserts the refreshed entry into the primary tier.
    async fn promote_to_primary(&self, key: K, value: CacheEntry<V>) {
        // Best effort: a failed promotion only means the next read will miss primary
        // again, so the insert result is deliberately discarded.
        let insertion = self.primary.insert(key, value);
        let _ = insertion.await;
    }

    /// Records a refresh miss (no value, or a fallback error) that took `duration`.
    fn handle_fallback_miss(&self, duration: Duration) {
        let (operation, activity) = (CacheOperation::Get, CacheActivity::RefreshMiss);
        self.telemetry.record(self.name, operation, activity, duration);
    }
}
#[cfg(test)]
mod tests {
    use tick::Clock;

    use super::*;

    /// Builds a 60-second refresh policy backed by the Tokio spawner.
    fn create_refresh() -> TimeToRefresh<String> {
        TimeToRefresh::new(Duration::from_secs(60), Spawner::new_tokio())
    }

    #[test]
    fn time_to_refresh_new() {
        let refresh = create_refresh();
        assert_eq!(refresh.duration, Duration::from_secs(60));
    }

    #[test]
    fn time_to_refresh_should_refresh_false_when_recent() {
        let refresh = create_refresh();
        let clock = Clock::new_frozen();
        let now = clock.system_time();
        assert!(!refresh.should_refresh(now, now));
    }

    /// The threshold is inclusive: an entry exactly `duration` old is due for refresh.
    #[test]
    fn time_to_refresh_should_refresh_true_at_exact_duration() {
        let refresh = create_refresh();
        let clock = Clock::new_frozen();
        let now = clock.system_time();
        let cached_at = now - Duration::from_secs(60);
        assert!(refresh.should_refresh(cached_at, now));
    }

    /// One second short of the threshold must not trigger a refresh.
    #[test]
    fn time_to_refresh_should_refresh_false_just_before_duration() {
        let refresh = create_refresh();
        let clock = Clock::new_frozen();
        let now = clock.system_time();
        let cached_at = now - Duration::from_secs(59);
        assert!(!refresh.should_refresh(cached_at, now));
    }

    #[test]
    fn time_to_refresh_try_start_refresh() {
        let refresh = create_refresh();
        let key = "key1".to_string();
        assert!(refresh.try_start_refresh(&key));
        assert!(!refresh.try_start_refresh(&key));
        let key2 = "key2".to_string();
        assert!(refresh.try_start_refresh(&key2));
    }

    #[test]
    fn time_to_refresh_finish_refresh() {
        let refresh = create_refresh();
        let key = "key1".to_string();
        assert!(refresh.try_start_refresh(&key));
        assert!(!refresh.try_start_refresh(&key));
        refresh.finish_refresh(&key);
        assert!(refresh.try_start_refresh(&key));
    }

    #[test]
    fn time_to_refresh_finish_refresh_nonexistent_key() {
        let refresh = create_refresh();
        refresh.finish_refresh(&"nonexistent".to_string());
        assert!(refresh.try_start_refresh(&"other_key".to_string()));
    }

    #[test]
    fn time_to_refresh_should_refresh_true_when_clock_goes_backward() {
        let refresh: TimeToRefresh<String> =
            TimeToRefresh::new(Duration::from_secs(300), Spawner::new_tokio());
        let clock = Clock::new_frozen();
        let now = clock.system_time();
        let cached_at = now + Duration::from_secs(3600);
        assert!(
            refresh.should_refresh(cached_at, now),
            "should return true when system time goes backward"
        );
    }

    #[test]
    fn time_to_refresh_should_refresh_true_after_duration() {
        let refresh: TimeToRefresh<String> =
            TimeToRefresh::new(Duration::from_secs(60), Spawner::new_tokio());
        let clock = Clock::new_frozen();
        let now = clock.system_time();
        let cached_at = now - Duration::from_secs(61);
        assert!(refresh.should_refresh(cached_at, now));
    }

    #[test]
    fn time_to_refresh_duration_access() {
        let refresh: TimeToRefresh<String> =
            TimeToRefresh::new(Duration::from_secs(300), Spawner::new_tokio());
        assert_eq!(refresh.duration, Duration::from_secs(300));
    }

    #[test]
    fn time_to_refresh_concurrent_keys() {
        let refresh = create_refresh();
        let key1 = "key1".to_string();
        let key2 = "key2".to_string();
        let key3 = "key3".to_string();
        assert!(refresh.try_start_refresh(&key1));
        assert!(refresh.try_start_refresh(&key2));
        assert!(refresh.try_start_refresh(&key3));
        assert!(!refresh.try_start_refresh(&key1));
        assert!(!refresh.try_start_refresh(&key2));
        assert!(!refresh.try_start_refresh(&key3));
        refresh.finish_refresh(&key2);
        assert!(!refresh.try_start_refresh(&key1));
        assert!(refresh.try_start_refresh(&key2));
        assert!(!refresh.try_start_refresh(&key3));
    }

    #[test]
    fn time_to_refresh_debug() {
        let refresh = create_refresh();
        let debug_str = format!("{refresh:?}");
        assert!(debug_str.contains("TimeToRefresh"), "got: {debug_str}");
        assert!(debug_str.contains("duration"), "got: {debug_str}");
    }
}
#[cfg(test)]
mod fetch_and_promote_tests {
    use cachet_tier::MockCache;
    use testing_aids::MetricTester;
    use tick::Clock;

    use super::*;
    use crate::InsertPolicy;
    use crate::telemetry::{TelemetryConfig, attributes};
    use crate::wrapper::CacheWrapper;

    /// Drives a future to completion on the current thread (no Tokio runtime involved).
    fn block_on<F: std::future::Future>(f: F) -> F::Output {
        futures::executor::block_on(f)
    }

    /// Builds a `FallbackCache` named "test" over the given tiers, with a frozen clock,
    /// default telemetry and no refresh policy.
    fn build_fallback_cache<P, F: CacheTier<String, i32> + 'static>(
        primary: P,
        fallback: F,
    ) -> FallbackCache<String, i32, P, F> {
        let clock = Clock::new_frozen();
        let telemetry = TelemetryConfig::new().build();
        FallbackCache::new("test", primary, fallback, clock, None, telemetry)
    }

    #[cfg_attr(miri, ignore)]
    #[test]
    fn fallback_miss_records_refresh_miss_telemetry() {
        block_on(async {
            let tester = MetricTester::new();
            let clock = Clock::new_frozen();
            let telemetry = TelemetryConfig::new().with_metrics(tester.meter_provider()).build();
            let primary = MockCache::<String, i32>::new();
            let fallback = MockCache::<String, i32>::new();
            let fc = FallbackCache::new("test", primary, fallback, clock, None, telemetry);
            fc.inner.fetch_and_promote("missing".to_string()).await;
            tester.assert_attributes_contain(&[opentelemetry::KeyValue::new(
                attributes::CACHE_ACTIVITY_NAME,
                "cache.refresh_miss",
            )]);
        });
    }

    #[test]
    fn fallback_error() {
        block_on(async {
            let primary = MockCache::<String, i32>::new();
            // Clones of a MockCache share storage, so this handle observes the primary tier.
            let primary_check = primary.clone();
            let fallback = MockCache::<String, i32>::new();
            fallback.fail_when(|_| true);
            let fc = build_fallback_cache(primary, fallback);
            fc.inner.fetch_and_promote("key".to_string()).await;
            // A failing fallback is handled like a miss: nothing is promoted.
            assert!(primary_check.get(&"key".to_string()).await.unwrap().is_none());
        });
    }

    #[test]
    fn hit_no_promote() {
        block_on(async {
            let primary_mock = MockCache::<String, i32>::new();
            let primary_check = primary_mock.clone();
            let fallback = MockCache::<String, i32>::new();
            fallback.insert("key".to_string(), CacheEntry::new(42)).await.unwrap();
            let clock = Clock::new_frozen();
            let telemetry = TelemetryConfig::new().build();
            let primary = CacheWrapper::new(
                "primary",
                primary_mock,
                clock.clone(),
                None,
                telemetry.clone(),
                InsertPolicy::never(),
            );
            let fc = FallbackCache::new("test", primary, fallback, clock, None, telemetry);
            fc.inner.fetch_and_promote("key".to_string()).await;
            assert!(primary_check.get(&"key".to_string()).await.unwrap().is_none());
        });
    }

    #[test]
    fn hit_with_promote() {
        block_on(async {
            let primary = MockCache::<String, i32>::new();
            let fallback = MockCache::<String, i32>::new();
            fallback.insert("key".to_string(), CacheEntry::new(42)).await.unwrap();
            let fc = build_fallback_cache(primary.clone(), fallback);
            fc.inner.fetch_and_promote("key".to_string()).await;
            let result = primary.get(&"key".to_string()).await.unwrap();
            assert!(result.is_some());
            assert_eq!(*result.unwrap().value(), 42);
        });
    }

    #[test]
    fn promote_error() {
        block_on(async {
            let primary = MockCache::<String, i32>::new();
            primary.fail_when(|_| true);
            let fallback = MockCache::<String, i32>::new();
            fallback.insert("key".to_string(), CacheEntry::new(42)).await.unwrap();
            let fc = build_fallback_cache(primary, fallback);
            // Must complete without panicking even though the primary insert fails.
            fc.inner.fetch_and_promote("key".to_string()).await;
        });
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn panic_in_refresh_does_not_leave_key_stuck_in_flight() {
        let primary = MockCache::<String, i32>::new();
        let fallback = MockCache::<String, i32>::new();
        fallback.fail_when(|_| panic!("simulated panic in fallback get"));
        let clock = Clock::new_frozen();
        let telemetry = TelemetryConfig::new().build();
        let refresh = TimeToRefresh::new(Duration::from_secs(60), Spawner::new_tokio());
        let fc = FallbackCache::new("test", primary, fallback, clock, Some(refresh), telemetry);
        let key = "panic_key".to_string();
        fc.do_refresh(&key);
        tokio::time::sleep(Duration::from_millis(100)).await;
        let can_refresh_again = fc
            .inner
            .refresh
            .as_ref()
            .expect("refresh should be configured")
            .try_start_refresh(&key);
        assert!(
            can_refresh_again,
            "key should not be stuck in in_flight after a panic in fetch_and_promote"
        );
    }

    type MockWrapper = CacheWrapper<String, i32, MockCache<String, i32>>;

    /// Wraps `mock` as a primary tier with the default insert policy.
    fn make_wrapper(mock: MockCache<String, i32>) -> MockWrapper {
        let clock = Clock::new_frozen();
        let telemetry = TelemetryConfig::new().build();
        CacheWrapper::new("test_primary", mock, clock, None, telemetry, InsertPolicy::default())
    }

    /// Builds a `FallbackCache` over a wrapped primary, without a refresh policy.
    fn build_mock_fallback_cache(
        primary: MockWrapper,
        fallback: MockCache<String, i32>,
    ) -> FallbackCache<String, i32, MockWrapper, MockCache<String, i32>> {
        let clock = Clock::new_frozen();
        let telemetry = TelemetryConfig::new().build();
        FallbackCache::new("test", primary, fallback, clock, None, telemetry)
    }

    #[test]
    fn do_refresh_no_refresh_configured() {
        let primary = make_wrapper(MockCache::new());
        let fallback = MockCache::<String, i32>::new();
        let fc = build_mock_fallback_cache(primary, fallback);
        fc.do_refresh(&"key".to_string());
    }

    // Runs on the current-thread flavor (the `#[tokio::test]` default, made explicit):
    // the spawned refresh task cannot run until this test yields at the `sleep`.
    #[cfg_attr(miri, ignore)]
    #[tokio::test(flavor = "current_thread")]
    async fn do_refresh_already_in_flight_returns_early() {
        let primary = MockCache::<String, i32>::new();
        let fallback = MockCache::<String, i32>::new();
        let clock = Clock::new_frozen();
        let telemetry = TelemetryConfig::new().build();
        let refresh = TimeToRefresh::new(Duration::from_secs(60), Spawner::new_tokio());
        let primary_wrapper = CacheWrapper::new(
            "primary",
            primary,
            clock.clone(),
            None,
            telemetry.clone(),
            InsertPolicy::default(),
        );
        let fc = FallbackCache::new("test", primary_wrapper, fallback, clock, Some(refresh), telemetry);
        let key = "key".to_string();
        fc.do_refresh(&key);
        fc.do_refresh(&key);
        let tracker = fc.inner.refresh.as_ref().expect("refresh should be configured");
        assert!(
            !tracker.try_start_refresh(&key),
            "key should still be in flight before the refresh task has run"
        );
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(
            tracker.try_start_refresh(&key),
            "key should be released once the refresh task has finished"
        );
    }

    #[test]
    fn drop_guard_runs_on_drop() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let ran = Arc::new(AtomicBool::new(false));
        let ran_clone = Arc::clone(&ran);
        {
            let _guard = DropGuard(move || {
                ran_clone.store(true, Ordering::SeqCst);
            });
        }
        assert!(ran.load(Ordering::SeqCst));
    }
}