use std::future::Future;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use futures::FutureExt;
use futures::future::{BoxFuture, Shared};
/// A cloneable, multi-waiter fetch future. The error is wrapped in `Arc<E>`
/// so every waiter on the shared future can receive the same error value
/// without requiring `E: Clone`.
type SharedFut<V, E> = Shared<BoxFuture<'static, Result<V, Arc<E>>>>;
/// Internal cache state machine.
enum State<V, E> {
    /// Nothing cached (initial state, or after `invalidate`).
    Empty,
    /// A cached value plus the instant it was stored.
    Current(V, clock::Instant),
    /// A fetch is in flight.
    Refreshing {
        /// The last known value (and its timestamp), if any, which can be
        /// served while the fetch runs.
        previous: Option<(V, clock::Instant)>,
        /// The in-flight fetch, shared so multiple callers can await it.
        future: SharedFut<V, E>,
    },
}
impl<V: Clone, E> State<V, E> {
fn fresh_value(&self, ttl: Duration, refresh_window: Duration) -> Option<V> {
let fresh_threshold = ttl - refresh_window;
match self {
Self::Current(value, cached_at) => {
if clock::now().duration_since(*cached_at) < fresh_threshold {
Some(value.clone())
} else {
None
}
}
Self::Refreshing {
previous: Some((value, cached_at)),
..
} => {
if clock::now().duration_since(*cached_at) < fresh_threshold {
Some(value.clone())
} else {
None
}
}
_ => None,
}
}
}
/// Mutex-guarded interior shared by all clones of a `BackgroundCache`.
struct CacheInner<V, E> {
    state: State<V, E>,
    // Bumped by `invalidate`; an in-flight fetch captures the generation at
    // start and discards its result if the values differ on completion.
    generation: u64,
}
/// Outcome of inspecting the cache state under the lock: what a `get`
/// caller should do after releasing it.
enum Action<V, E> {
    /// Return this value immediately.
    Return(V),
    /// Await this shared fetch future (outside the lock).
    Wait(SharedFut<V, E>),
}
/// A single-value async cache with TTL-based expiry and proactive background
/// refresh: values are served directly until the final `refresh_window` of
/// their `ttl`, during which reads still succeed but trigger a refresh in a
/// spawned task; past `ttl`, callers wait for a fresh fetch.
///
/// Cloning is cheap; clones share the same underlying state.
pub struct BackgroundCache<V, E> {
    inner: Arc<Mutex<CacheInner<V, E>>>,
    ttl: Duration,
    refresh_window: Duration,
}
impl<V, E> std::fmt::Debug for BackgroundCache<V, E> {
    /// Manual `Debug`: only the configuration is printed, because `V` and `E`
    /// carry no `Debug` bound and the inner state is behind a mutex.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("BackgroundCache");
        builder.field("ttl", &self.ttl);
        builder.field("refresh_window", &self.refresh_window);
        // `..` marker signals the intentionally omitted fields.
        builder.finish_non_exhaustive()
    }
}
impl<V, E> Clone for BackgroundCache<V, E> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
ttl: self.ttl,
refresh_window: self.refresh_window,
}
}
}
impl<V, E> BackgroundCache<V, E>
where
    V: Clone + Send + Sync + 'static,
    E: Send + Sync + 'static,
{
    /// Creates an empty cache.
    ///
    /// A value is served without any fetch while younger than
    /// `ttl - refresh_window`; inside the refresh window it is still served
    /// but a background refresh is started; past `ttl` callers must wait for
    /// a fresh fetch.
    ///
    /// # Panics
    /// Panics if `refresh_window` is not strictly less than `ttl`.
    pub fn new(ttl: Duration, refresh_window: Duration) -> Self {
        assert!(
            refresh_window < ttl,
            "refresh_window ({refresh_window:?}) must be less than ttl ({ttl:?})"
        );
        Self {
            inner: Arc::new(Mutex::new(CacheInner {
                state: State::Empty,
                generation: 0,
            })),
            ttl,
            refresh_window,
        }
    }

    /// Non-blocking read: returns the value only while it is fresh (outside
    /// the refresh window). Never starts a fetch.
    pub fn try_get(&self) -> Option<V> {
        // Recover from a poisoned mutex: every critical section here performs
        // a single state/field assignment, so the data stays structurally
        // valid even if a lock holder panicked.
        let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        cache.state.fresh_value(self.ttl, self.refresh_window)
    }

    /// Returns the cached value, using `fetch` to populate or refresh it as
    /// needed. `fetch` is invoked at most once per call, and concurrent
    /// callers share a single in-flight fetch via the `Shared` future.
    ///
    /// # Errors
    /// Propagates the fetch error (wrapped in `Arc` so all waiters receive
    /// it) when no usable cached value exists.
    pub async fn get<F, Fut>(&self, fetch: F) -> Result<V, Arc<E>>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = Result<V, E>> + Send + 'static,
    {
        // Fast path: a fresh value needs no fetch at all.
        {
            let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
            if let Some(value) = cache.state.fresh_value(self.ttl, self.refresh_window) {
                return Ok(value);
            }
        }
        // Wrapped in Option so determine_action can consume the FnOnce at
        // most once across its recursion.
        let mut fetch = Some(fetch);
        // Decide under the lock, but await outside it: the std::sync::Mutex
        // guard must never be held across an await point.
        let action = {
            let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
            self.determine_action(&mut cache, &mut fetch)
        };
        match action {
            Action::Return(value) => Ok(value),
            Action::Wait(fut) => fut.await,
        }
    }

    /// Installs `value` as the current entry, timestamped now.
    ///
    /// NOTE(review): this replaces any `Refreshing` state, but does not bump
    /// the generation, so an in-flight fetch that completes later may
    /// overwrite the seeded value — confirm that is intended.
    pub fn seed(&self, value: V) {
        let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        cache.state = State::Current(value, clock::now());
    }

    /// Drops the cached value and bumps the generation so the result of any
    /// in-flight fetch is discarded when its task completes (see the
    /// generation check in `start_fetch`).
    pub fn invalidate(&self) {
        let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        cache.state = State::Empty;
        cache.generation += 1;
    }

    /// Inspects the state (lock held by the caller via `cache`) and decides
    /// whether the caller can return immediately or must await a fetch.
    ///
    /// `fetch` is consumed at most once. It must be `Some` whenever the state
    /// may require starting a new fetch (empty or fully expired); in the
    /// refresh window a missing closure simply skips the background refresh.
    fn determine_action<F, Fut>(
        &self,
        cache: &mut CacheInner<V, E>,
        fetch: &mut Option<F>,
    ) -> Action<V, E>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = Result<V, E>> + Send + 'static,
    {
        match &cache.state {
            // Nothing cached: the caller must wait on a brand-new fetch.
            State::Empty => {
                let f = fetch
                    .take()
                    .expect("fetch closure required for empty cache");
                let shared = self.start_fetch(cache, f, None);
                Action::Wait(shared)
            }
            State::Current(value, cached_at) => {
                let elapsed = clock::now().duration_since(*cached_at);
                if elapsed < self.ttl - self.refresh_window {
                    // Fresh: serve directly. (Usually handled by the fast
                    // path in `get`, but reachable via the recursion below.)
                    Action::Return(value.clone())
                } else if elapsed < self.ttl {
                    // In the refresh window: serve the current value and kick
                    // off a background refresh, keeping the old value as
                    // `previous` so readers are not blocked meanwhile.
                    let value = value.clone();
                    let previous = Some((value.clone(), *cached_at));
                    if let Some(f) = fetch.take() {
                        // The caller does not wait on this future; the
                        // spawned task in start_fetch drives it.
                        drop(self.start_fetch(cache, f, previous));
                    }
                    Action::Return(value)
                } else {
                    // Fully expired: the caller must wait for the new fetch.
                    // The expired value is still carried as `previous` (used
                    // only as a fallback if the fetch fails).
                    let previous = Some((value.clone(), *cached_at));
                    let f = fetch
                        .take()
                        .expect("fetch closure required for expired cache");
                    let shared = self.start_fetch(cache, f, previous);
                    Action::Wait(shared)
                }
            }
            State::Refreshing { previous, future } => {
                // `peek` yields Some once the shared future has completed and
                // been polled — i.e. the fetch finished but the spawned task
                // has not yet folded the result into the state. Fold it here
                // and re-evaluate from the resulting state.
                if let Some(result) = future.peek() {
                    match result {
                        Ok(value) => {
                            cache.state = State::Current(value.clone(), clock::now());
                        }
                        Err(_) => {
                            // Failed fetch: fall back to the previous value
                            // (with its original timestamp) or to Empty.
                            cache.state = match previous.clone() {
                                Some((v, t)) => State::Current(v, t),
                                None => State::Empty,
                            };
                        }
                    }
                    // Recurses at most once more per resolved future; `fetch`
                    // has not been consumed on this path.
                    return self.determine_action(cache, fetch);
                }
                // Fetch still running: serve the previous value while it is
                // within the full TTL, otherwise join the in-flight fetch.
                if let Some((value, cached_at)) = previous {
                    if clock::now().duration_since(*cached_at) < self.ttl {
                        Action::Return(value.clone())
                    } else {
                        Action::Wait(future.clone())
                    }
                } else {
                    Action::Wait(future.clone())
                }
            }
        }
    }

    /// Boxes and shares `fetch`'s future, spawns a task that folds the result
    /// into the cache, records the `Refreshing` state, and returns the shared
    /// future so callers may also await it directly.
    fn start_fetch<F, Fut>(
        &self,
        cache: &mut CacheInner<V, E>,
        fetch: F,
        previous: Option<(V, clock::Instant)>,
    ) -> SharedFut<V, E>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = Result<V, E>> + Send + 'static,
    {
        // Captured now; compared on completion to detect an intervening
        // `invalidate`.
        let generation = cache.generation;
        // The closure runs lazily: it is invoked on the first poll of the
        // shared future, not here.
        let shared = async move { (fetch)().await.map_err(Arc::new) }
            .boxed()
            .shared();
        let inner = self.inner.clone();
        let fut_for_spawn = shared.clone();
        tokio::spawn(async move {
            let result = fut_for_spawn.await;
            let mut cache = inner.lock().unwrap_or_else(|e| e.into_inner());
            // Invalidated while fetching: discard this (stale) result.
            if cache.generation != generation {
                return;
            }
            // NOTE(review): both arms below assume the state is still the
            // `Refreshing` entry installed for THIS fetch. If `peek` in
            // determine_action already resolved this future and a newer fetch
            // installed its own `Refreshing` state (same generation), this
            // write clobbers that registration — confirm whether the task
            // should verify the in-flight future is its own before writing.
            match result {
                Ok(value) => {
                    cache.state = State::Current(value, clock::now());
                }
                Err(_) => {
                    // On failure, fall back to whatever `previous` the
                    // current Refreshing state carries, or to Empty.
                    let prev = match &cache.state {
                        State::Refreshing { previous, .. } => previous.clone(),
                        _ => None,
                    };
                    cache.state = match prev {
                        Some((v, t)) => State::Current(v, t),
                        None => State::Empty,
                    };
                }
            }
        });
        cache.state = State::Refreshing {
            previous,
            future: shared.clone(),
        };
        shared
    }
}
#[cfg(test)]
pub mod clock {
    //! Test-only clock: a thread-local mock that, once advanced, freezes
    //! "now" at a controllable instant. Untouched, it reads the real clock.
    use std::cell::Cell;
    use std::time::Duration;
    pub use std::time::Instant;

    thread_local! {
        // None => no mock installed; fall through to the real clock.
        static MOCK_NOW: Cell<Option<Instant>> = const { Cell::new(None) };
    }

    /// Returns the mocked instant if set, otherwise the real current time.
    pub fn now() -> Instant {
        MOCK_NOW.with(|cell| match cell.get() {
            Some(instant) => instant,
            None => Instant::now(),
        })
    }

    /// Moves the mocked clock forward by `duration`, seeding it from the
    /// real clock on first use.
    pub fn advance_by(duration: Duration) {
        MOCK_NOW.with(|cell| {
            let base = match cell.get() {
                Some(instant) => instant,
                None => Instant::now(),
            };
            cell.set(Some(base + duration));
        });
    }

    /// Removes the mock so `now()` reads the real clock again.
    #[allow(dead_code)]
    pub fn clear_mock() {
        MOCK_NOW.with(|cell| cell.set(None));
    }
}
#[cfg(not(test))]
mod clock {
    //! Production clock: a thin pass-through to the real monotonic clock,
    //! mirroring the test-only mockable module of the same name.
    pub use std::time::Instant;

    /// Returns the current monotonic instant.
    pub fn now() -> Instant {
        Instant::now()
    }
}
#[cfg(test)]
mod tests {
    // NOTE(review): several tests below depend on scheduling details of the
    // current-thread #[tokio::test] runtime: a spawned refresh task does not
    // run (and the fetcher closure is not invoked, so the counter does not
    // increment) until the shared future is first polled. Confirm before
    // switching to a multi-threaded runtime.
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal string-carrying error type for exercising failure paths.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    // 30s lifetime; values are "fresh" for 25s, then in the 5s refresh window.
    const TEST_TTL: Duration = Duration::from_secs(30);
    const TEST_REFRESH_WINDOW: Duration = Duration::from_secs(5);

    fn new_cache() -> BackgroundCache<String, TestError> {
        BackgroundCache::new(TEST_TTL, TEST_REFRESH_WINDOW)
    }

    /// Fetcher that resolves to `Ok(value)`; bumps `counter` when the closure
    /// itself runs (i.e. when the fetch future is first polled).
    fn ok_fetcher(
        counter: Arc<AtomicUsize>,
        value: &str,
    ) -> impl FnOnce() -> BoxFuture<'static, Result<String, TestError>> + Send + 'static {
        let value = value.to_string();
        move || {
            counter.fetch_add(1, Ordering::SeqCst);
            async move { Ok(value) }.boxed()
        }
    }

    /// Fetcher that resolves to `Err(TestError(msg))`; counts like `ok_fetcher`.
    fn err_fetcher(
        counter: Arc<AtomicUsize>,
        msg: &str,
    ) -> impl FnOnce() -> BoxFuture<'static, Result<String, TestError>> + Send + 'static {
        let msg = msg.to_string();
        move || {
            counter.fetch_add(1, Ordering::SeqCst);
            async move { Err(TestError(msg)) }.boxed()
        }
    }

    /// Repeated gets within the fresh period hit the cache: one fetch total.
    #[tokio::test]
    async fn test_basic_caching() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        let v1 = cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        assert_eq!(v1, "hello");
        assert_eq!(count.load(Ordering::SeqCst), 1);
        let v2 = cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        assert_eq!(v2, "hello");
        assert_eq!(count.load(Ordering::SeqCst), 1);
        let v3 = cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        assert_eq!(v3, "hello");
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_try_get_returns_none_when_empty() {
        let cache: BackgroundCache<String, TestError> = new_cache();
        assert!(cache.try_get().is_none());
    }

    #[tokio::test]
    async fn test_try_get_returns_value_when_fresh() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        assert_eq!(cache.try_get().unwrap(), "hello");
    }

    /// try_get is strict: inside the refresh window (>= 25s) it returns None
    /// even though get would still serve the value.
    #[tokio::test]
    async fn test_try_get_returns_none_in_refresh_window() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        clock::advance_by(Duration::from_secs(26));
        assert!(cache.try_get().is_none());
    }

    /// Past the full TTL (31s > 30s), get blocks on a new fetch.
    #[tokio::test]
    async fn test_ttl_expiration() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        assert_eq!(count.load(Ordering::SeqCst), 1);
        clock::advance_by(Duration::from_secs(31));
        let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap();
        assert_eq!(v, "v2");
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_invalidate_forces_refetch() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        assert_eq!(count.load(Ordering::SeqCst), 1);
        cache.invalidate();
        let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap();
        assert_eq!(v, "v2");
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    /// Ten concurrent getters on an empty cache share one fetch (request
    /// coalescing via the shared future).
    #[tokio::test]
    async fn test_concurrent_get_single_fetch() {
        let cache = Arc::new(new_cache());
        let count = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..10 {
            let cache = cache.clone();
            let count = count.clone();
            handles.push(tokio::spawn(async move {
                cache.get(ok_fetcher(count, "hello")).await.unwrap()
            }));
        }
        let results: Vec<String> = futures::future::try_join_all(handles).await.unwrap();
        for r in &results {
            assert_eq!(r, "hello");
        }
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    /// In the refresh window the old value is returned immediately while a
    /// refresh is queued; the refreshed value becomes visible to a later get.
    #[tokio::test]
    async fn test_background_refresh_in_window() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        assert_eq!(count.load(Ordering::SeqCst), 1);
        clock::advance_by(Duration::from_secs(26));
        // Serves v1 and registers the v2 refresh; the v2 closure has not run
        // yet (count still 1) because its future is not yet polled.
        let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap();
        assert_eq!(v, "v1");
        assert_eq!(count.load(Ordering::SeqCst), 1);
        clock::advance_by(Duration::from_secs(30));
        // Previous value fully expired: this get awaits the in-flight v2
        // fetch (polling it runs the closure), so "v3" is never fetched.
        let v = cache.get(ok_fetcher(count.clone(), "v3")).await.unwrap();
        assert_eq!(count.load(Ordering::SeqCst), 2);
        assert_eq!(v, "v2");
    }

    /// While a refresh is in flight, further gets in the window neither start
    /// new fetches nor wait.
    #[tokio::test]
    async fn test_no_duplicate_background_refreshes() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        assert_eq!(count.load(Ordering::SeqCst), 1);
        clock::advance_by(Duration::from_secs(26));
        for _ in 0..5 {
            let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap();
            assert_eq!(v, "v1");
        }
        clock::advance_by(Duration::from_secs(30));
        cache.get(ok_fetcher(count.clone(), "v3")).await.unwrap();
        // Only the initial fetch plus the single coalesced refresh ran.
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    /// A failing background refresh keeps serving the previous value until
    /// the TTL truly expires, at which point the error surfaces.
    #[tokio::test]
    async fn test_background_refresh_error_preserves_cache() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        assert_eq!(count.load(Ordering::SeqCst), 1);
        clock::advance_by(Duration::from_secs(26));
        let v = cache.get(err_fetcher(count.clone(), "fail")).await.unwrap();
        assert_eq!(v, "v1");
        let v = cache.get(err_fetcher(count.clone(), "fail")).await.unwrap();
        assert_eq!(v, "v1");
        clock::advance_by(Duration::from_secs(30));
        // Expired: this get awaits the (failing) in-flight fetch and sees
        // its error; "fail again" is never invoked.
        let result = cache.get(err_fetcher(count.clone(), "fail again")).await;
        assert!(result.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    /// invalidate bumps the generation, so a fetch started beforehand must
    /// not write its stale result back when it completes.
    #[tokio::test]
    async fn test_invalidation_during_fetch_prevents_stale_update() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap();
        clock::advance_by(Duration::from_secs(26));
        // Registers the "stale" refresh in flight...
        cache.get(ok_fetcher(count.clone(), "stale")).await.unwrap();
        // ...then invalidates before it can land.
        cache.invalidate();
        clock::advance_by(Duration::from_secs(30));
        let v = cache.get(ok_fetcher(count.clone(), "fresh")).await.unwrap();
        assert_eq!(v, "fresh");
    }

    /// Poisons the cache mutex by panicking on a thread that holds the lock.
    fn poison_cache(cache: &BackgroundCache<String, TestError>) {
        let inner = cache.inner.clone();
        let handle = std::thread::spawn(move || {
            let _guard = inner.lock().unwrap();
            panic!("intentional panic to poison mutex");
        });
        let _ = handle.join();
        assert!(cache.inner.lock().is_err(), "mutex should be poisoned");
    }

    /// The `unwrap_or_else(into_inner)` pattern must not panic on poison.
    #[tokio::test]
    async fn test_try_get_recovers_from_poisoned_lock() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        poison_cache(&cache);
        // Only checks that the call completes; the value itself is not
        // asserted here.
        let result = cache.try_get();
        let _ = result;
    }

    #[tokio::test]
    async fn test_get_recovers_from_poisoned_lock() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        poison_cache(&cache);
        let result = cache.get(ok_fetcher(count.clone(), "recovered")).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "recovered");
    }

    #[tokio::test]
    async fn test_seed_recovers_from_poisoned_lock() {
        let cache = new_cache();
        poison_cache(&cache);
        cache.seed("seeded".to_string());
    }

    #[tokio::test]
    async fn test_invalidate_recovers_from_poisoned_lock() {
        let cache = new_cache();
        let count = Arc::new(AtomicUsize::new(0));
        cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
        poison_cache(&cache);
        cache.invalidate();
    }
}