#![warn(clippy::pedantic)]
#![warn(clippy::cargo)]
#![warn(
missing_docs,
rustdoc::missing_crate_level_docs,
rustdoc::private_doc_tests
)]
#![deny(
rustdoc::broken_intra_doc_links,
rustdoc::private_intra_doc_links,
rustdoc::invalid_codeblock_attributes,
rustdoc::invalid_rust_codeblocks
)]
#![forbid(unsafe_code)]
use std::fmt::Debug;
use std::future::Future;
use std::sync::{Arc, Weak};
use futures::stream::{AbortHandle, Abortable, Aborted};
use parking_lot::{Mutex, MutexGuard};
use thiserror::Error;
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::broadcast::{self, Receiver, Sender};
#[derive(Debug, PartialEq, Error, Clone)]
pub enum Error<E> {
#[error("The computation for get_or_compute panicked or the Future returned by get_or_compute was dropped: {0}")]
Broadcast(#[from] RecvError),
#[error("Inflight computation returned error value: {0}")]
Computation(E),
#[error("Inflight computation was aborted")]
Aborted(#[from] Aborted),
}
#[derive(Debug, Default)]
pub struct Cached<T, E> {
inner: Arc<Mutex<CachedInner<T, E>>>,
}
impl<T, E> Clone for Cached<T, E> {
    /// Returns another handle to the same shared slot; the cached state itself is not
    /// copied, so `T` and `E` need not be `Clone`.
    fn clone(&self) -> Self {
        let Self { inner } = self;
        Self {
            inner: Arc::clone(inner),
        }
    }
}
/// Snapshot of what a cache slot held, as reported by [`Cached::force_recompute`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CachedState<T> {
    /// No value was cached and no computation was running.
    EmptyCache,
    /// This value was cached.
    ValueCached(T),
    /// A computation was in flight.
    Inflight,
}

impl<T> CachedState<T> {
    /// Returns `true` for [`CachedState::Inflight`].
    #[must_use]
    pub fn is_inflight(&self) -> bool {
        matches!(self, Self::Inflight)
    }

    /// Borrows the value of a [`CachedState::ValueCached`], or returns `None`.
    #[must_use]
    pub fn get(&self) -> Option<&T> {
        match self {
            Self::ValueCached(value) => Some(value),
            Self::EmptyCache | Self::Inflight => None,
        }
    }

    /// Mutably borrows the value of a [`CachedState::ValueCached`], or returns `None`.
    #[must_use]
    pub fn get_mut(&mut self) -> Option<&mut T> {
        match self {
            Self::ValueCached(value) => Some(value),
            Self::EmptyCache | Self::Inflight => None,
        }
    }
}
/// Handle to a running computation: the abort handle that cancels it, and the broadcast
/// sender on which its outcome is published to subscribers.
type InflightComputation<T, E> = (AbortHandle, Sender<Result<T, Error<E>>>);

/// Lock-protected state of a cache slot.
#[derive(Debug, Clone)]
enum CachedInner<T, E> {
    /// A successfully computed value.
    CachedValue(T),
    /// No value. If the `Weak` can still be upgraded, a computation is running. A dead or
    /// never-initialised `Weak` means the slot is simply empty. The computing task holds
    /// the only long-lived strong reference, so the registration vanishes on its own when
    /// that task finishes, panics or is dropped.
    EmptyOrInflight(Weak<InflightComputation<T, E>>),
}
impl<T, E> Default for CachedInner<T, E> {
fn default() -> Self {
CachedInner::new()
}
}
impl<T, E> CachedInner<T, E> {
#[must_use]
fn new() -> Self {
CachedInner::EmptyOrInflight(Weak::new())
}
#[must_use]
fn new_with_value(value: T) -> Self {
CachedInner::CachedValue(value)
}
fn invalidate(&mut self) -> Option<T> {
if matches!(self, CachedInner::EmptyOrInflight(_)) {
None
} else if let CachedInner::CachedValue(value) = std::mem::take(self) {
Some(value)
} else {
unreachable!()
}
}
fn is_inflight(&self) -> bool {
self.inflight_weak()
.map_or(false, |weak| weak.strong_count() > 0)
}
fn inflight_waiting_count(&self) -> usize {
self.inflight_arc()
.map_or(0, |arc| arc.1.receiver_count() + 1)
}
fn abort(&mut self) -> bool {
if let Some(arc) = self.inflight_arc() {
arc.0.abort();
*self = CachedInner::new();
true
} else {
false
}
}
#[must_use]
fn is_value_cached(&self) -> bool {
matches!(self, CachedInner::CachedValue(_))
}
#[must_use]
fn inflight_weak(&self) -> Option<&Weak<InflightComputation<T, E>>> {
if let CachedInner::EmptyOrInflight(weak) = self {
Some(weak)
} else {
None
}
}
#[must_use]
fn inflight_arc(&self) -> Option<Arc<InflightComputation<T, E>>> {
self.inflight_weak().and_then(Weak::upgrade)
}
#[must_use]
fn get(&self) -> Option<&T> {
if let CachedInner::CachedValue(value) = self {
Some(value)
} else {
None
}
}
#[must_use]
fn get_receiver(&self) -> Option<Receiver<Result<T, Error<E>>>> {
self.inflight_arc().map(|arc| arc.1.subscribe())
}
}
impl<T, E> Cached<T, E> {
#[must_use]
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(CachedInner::new())),
}
}
#[must_use]
pub fn new_with_value(value: T) -> Self {
Cached {
inner: Arc::new(Mutex::new(CachedInner::new_with_value(value))),
}
}
#[allow(clippy::must_use_candidate)]
pub fn invalidate(&self) -> Option<T> {
self.inner.lock().invalidate()
}
#[must_use]
pub fn is_inflight(&self) -> bool {
self.inner.lock().is_inflight()
}
#[must_use]
pub fn inflight_waiting_count(&self) -> usize {
self.inner.lock().inflight_waiting_count()
}
#[allow(clippy::must_use_candidate)]
pub fn abort(&self) -> bool {
self.inner.lock().abort()
}
#[must_use]
pub fn is_value_cached(&self) -> bool {
self.inner.lock().is_value_cached()
}
}
impl<T: Clone, E> Cached<T, E> {
    /// Returns a clone of the cached value, or `None` if nothing is cached.
    ///
    /// Never waits on or starts a computation; while one is inflight this returns `None`.
    #[must_use]
    pub fn get(&self) -> Option<T> {
        let guard = self.inner.lock();
        guard.get().cloned()
    }
}
/// Outcome of looking up the cache before possibly starting a computation.
enum GetOrSubscribeResult<'guard, T, E> {
    /// A cached value was found, or an inflight computation was awaited; holds its outcome.
    Success(Result<T, Error<E>>),
    /// Neither a value nor a live computation exists. The lock is still held, so the caller
    /// can register a new computation without another caller racing it.
    FailureKeepLock(MutexGuard<'guard, CachedInner<T, E>>),
}
impl<T, E> Cached<T, E>
where
T: Clone,
E: Clone,
{
#[allow(clippy::await_holding_lock)] pub async fn get_or_compute<Fut>(
&self,
computation: impl FnOnce() -> Fut,
) -> Result<T, Error<E>>
where
Fut: Future<Output = Result<T, E>>,
{
let inner = match self.get_or_subscribe_keep_lock().await {
GetOrSubscribeResult::Success(res) => return res,
GetOrSubscribeResult::FailureKeepLock(lock) => lock,
};
self.compute_with_lock(computation, inner).await.unwrap()
}
pub async fn get_or_subscribe(&self) -> Option<Result<T, Error<E>>> {
if let GetOrSubscribeResult::Success(res) = self.get_or_subscribe_keep_lock().await {
Some(res)
} else {
None
}
}
#[allow(clippy::await_holding_lock)] pub async fn subscribe_or_recompute<Fut>(
&self,
computation: impl FnOnce() -> Fut,
) -> (Option<T>, Result<T, Error<E>>)
where
Fut: Future<Output = Result<T, E>>,
{
let mut inner = self.inner.lock();
if let Some(mut receiver) = inner.get_receiver() {
drop(inner);
(
None,
match receiver.recv().await {
Err(why) => Err(Error::from(why)),
Ok(res) => res,
},
)
} else {
let prev = inner.invalidate();
let result = self.compute_with_lock(computation, inner).await.unwrap();
(prev, result)
}
}
#[allow(clippy::await_holding_lock)] pub async fn force_recompute<Fut>(
&self,
computation: Fut,
) -> (CachedState<T>, Result<T, Error<E>>)
where
Fut: Future<Output = Result<T, E>>,
{
let mut inner = self.inner.lock();
let aborted = inner.abort();
let prev_cache = inner.invalidate();
let prev_state = match (aborted, prev_cache) {
(false, None) => CachedState::EmptyCache,
(false, Some(val)) => CachedState::ValueCached(val),
(true, None) => CachedState::Inflight,
(true, Some(_)) => unreachable!(),
};
let result = self.compute_with_lock(|| computation, inner).await.unwrap();
(prev_state, result)
}
#[allow(clippy::await_holding_lock)] async fn get_or_subscribe_keep_lock(&self) -> GetOrSubscribeResult<'_, T, E> {
let inner = self.inner.lock();
if let CachedInner::CachedValue(value) = &*inner {
return GetOrSubscribeResult::Success(Ok(value.clone()));
}
let Some(mut receiver) = inner.get_receiver() else {
return GetOrSubscribeResult::FailureKeepLock(inner);
};
drop(inner);
let result = receiver.recv().await;
GetOrSubscribeResult::Success(match result {
Err(why) => Err(Error::from(why)),
Ok(res) => res,
})
}
#[allow(clippy::await_holding_lock)] async fn compute_with_lock<'a, Fut>(
&'a self,
computation: impl FnOnce() -> Fut,
mut inner: MutexGuard<'a, CachedInner<T, E>>,
) -> Option<Result<T, Error<E>>>
where
Fut: Future<Output = Result<T, E>>,
{
if inner.is_value_cached() || inner.is_inflight() {
return None;
}
let (tx, _) = broadcast::channel(1);
let (abort_handle, abort_registration) = AbortHandle::new_pair();
let arc = Arc::new((abort_handle, tx));
*inner = CachedInner::EmptyOrInflight(Arc::downgrade(&arc));
drop(inner);
let future = computation();
let res = match Abortable::new(future, abort_registration).await {
Ok(res) => res.map_err(Error::Computation),
Err(aborted) => Err(Error::from(aborted)),
};
{
let mut inner = self.inner.lock();
if !matches!(res, Err(Error::Aborted(_))) {
if let Ok(value) = &res {
*inner = CachedInner::CachedValue(value.clone());
} else {
*inner = CachedInner::new();
}
}
}
if arc.1.receiver_count() > 0 {
arc.1.send(res.clone()).ok();
}
Some(res)
}
}
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::time::Duration;

    use tokio::sync::Notify;
    use tokio::task::JoinHandle;

    use super::{Cached, CachedState, Error};

    #[tokio::test]
    async fn test_cached() {
        let cached = Cached::<_, ()>::new_with_value(12);
        assert_eq!(cached.get(), Some(12));
        assert!(!cached.is_inflight());
        assert!(cached.is_value_cached());
        assert_eq!(cached.inflight_waiting_count(), 0);

        let cached = Cached::new();
        assert_eq!(cached.get(), None);
        assert!(!cached.is_inflight());
        assert!(!cached.is_value_cached());
        assert_eq!(cached.inflight_waiting_count(), 0);
        assert_eq!(cached.get_or_compute(|| async { Ok(12) }).await, Ok(12));
        assert_eq!(cached.get(), Some(12));
        assert_eq!(cached.invalidate(), Some(12));
        assert_eq!(cached.get(), None);
        assert_eq!(cached.invalidate(), None);
        assert_eq!(
            cached.get_or_compute(|| async { Err(42) }).await,
            Err(Error::Computation(42)),
        );
        assert_eq!(cached.get(), None);
        assert_eq!(cached.get_or_compute(|| async { Ok(1) }).await, Ok(1));
        assert_eq!(cached.get(), Some(1));
        // A cached value short-circuits: the new computation is never run.
        assert_eq!(cached.get_or_compute(|| async { Ok(32) }).await, Ok(1));
        assert_eq!(cached.invalidate(), Some(1));

        let (tokio_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(30)).await;
        assert_eq!(cached.get(), None);
        assert!(cached.is_inflight());
        assert_eq!(cached.inflight_waiting_count(), 1);
        let other_handle = {
            let cached = Cached::clone(&cached);
            tokio::spawn(async move { cached.get_or_compute(|| async move { Ok(24) }).await })
        };
        tokio_notify.notify_waiters();
        assert_eq!(handle.await.unwrap(), Ok(30));
        assert_eq!(other_handle.await.unwrap(), Ok(30));
        assert_eq!(cached.get(), Some(30));
    }

    #[tokio::test]
    async fn test_computation_panic() {
        let cached = Cached::<_, ()>::new();

        // Panic while creating the future.
        let is_panic = {
            let cached = Cached::clone(&cached);
            tokio::spawn(async move {
                cached
                    .get_or_compute(|| {
                        panic!("Panic in computation");
                        #[allow(unreachable_code)]
                        async {
                            unreachable!()
                        }
                    })
                    .await
            })
        }
        .await
        .expect_err("Should panic")
        .is_panic();
        assert!(is_panic, "Should panic");
        assert_eq!(cached.get(), None);
        assert!(!cached.is_inflight());
        assert_eq!(cached.inflight_waiting_count(), 0);
        assert_eq!(
            cached.get_or_compute(|| async move { Ok(21) }).await,
            Ok(21),
        );
        assert_eq!(cached.invalidate(), Some(21));

        // Panic while polling the future.
        let is_panic = {
            let cached = Cached::clone(&cached);
            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async { panic!("Panic in future") })
                    .await
            })
        }
        .await
        .expect_err("Should be panic")
        .is_panic();
        assert!(is_panic, "Should panic");
        assert_eq!(cached.get(), None);
        assert!(!cached.is_inflight());
        assert_eq!(cached.inflight_waiting_count(), 0);
        assert_eq!(
            cached.get_or_compute(|| async move { Ok(17) }).await,
            Ok(17),
        );
        assert_eq!(cached.invalidate(), Some(17));

        // Panic while another caller is subscribed: the subscriber sees a closed channel.
        let tokio_notify = Arc::new(Notify::new());
        let registered = Arc::new(Notify::new());
        let registered_fut = registered.notified();
        let panicking_handle = {
            let cached = Cached::clone(&cached);
            let tokio_notify = Arc::clone(&tokio_notify);
            let registered = Arc::clone(&registered);
            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async move {
                        let notify_fut = tokio_notify.notified();
                        registered.notify_waiters();
                        notify_fut.await;
                        panic!("Panic in future")
                    })
                    .await
            })
        };
        registered_fut.await;
        let waiting_handle = {
            let cached = Cached::clone(&cached);
            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async {
                        panic!("Entered computation when another inflight computation should already be running")
                    })
                    .await
            })
        };
        while cached.inflight_waiting_count() < 2 {
            tokio::task::yield_now().await;
        }
        tokio_notify.notify_waiters();
        assert!(panicking_handle.await.unwrap_err().is_panic());
        assert!(matches!(waiting_handle.await, Ok(Err(Error::Broadcast(_)))));
        assert_eq!(cached.get(), None);
    }

    #[tokio::test]
    async fn test_computation_drop() {
        let cached = Cached::<_, ()>::new();
        let computing = Arc::new(Notify::new());
        let computing_fut = computing.notified();
        let dropping_handle = {
            let cached = Cached::clone(&cached);
            let computing = Arc::clone(&computing);
            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async move {
                        computing.notify_waiters();
                        loop {
                            tokio::time::sleep(Duration::from_secs(1)).await;
                        }
                    })
                    .await
            })
        };
        computing_fut.await;
        let waiting_handle = {
            let cached = Cached::clone(&cached);
            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async {
                        panic!("Entered computation when another inflight computation should already be running");
                    })
                    .await
            })
        };
        while cached.inflight_waiting_count() < 2 {
            tokio::task::yield_now().await;
        }
        dropping_handle.abort();
        assert!(dropping_handle.await.unwrap_err().is_cancelled());
        assert!(matches!(waiting_handle.await, Ok(Err(Error::Broadcast(_)))));
        assert_eq!(cached.get(), None);
        assert_eq!(cached.get_or_compute(|| async { Ok(3) }).await, Ok(3));
        assert_eq!(cached.get(), Some(3));
    }

    #[tokio::test]
    async fn test_get_or_subscribe() {
        let cached = Cached::<_, ()>::new();
        assert_eq!(cached.get_or_subscribe().await, None);
        assert_eq!(cached.get_or_compute(|| async { Ok(0) }).await, Ok(0));
        assert_eq!(cached.get_or_subscribe().await, Some(Ok(0)));
        cached.invalidate();
        let (tokio_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(30)).await;
        assert!(cached.is_inflight());
        let get_or_subscribe_handle = {
            let cached = Cached::clone(&cached);
            tokio::spawn(async move { cached.get_or_subscribe().await })
        };
        tokio_notify.notify_waiters();
        assert_eq!(handle.await.unwrap(), Ok(30));
        assert_eq!(get_or_subscribe_handle.await.unwrap(), Some(Ok(30)));
        assert_eq!(cached.get(), Some(30));
    }

    #[tokio::test]
    async fn test_subscribe_or_recompute() {
        let cached = Cached::new();
        assert_eq!(
            cached.subscribe_or_recompute(|| async { Err(()) }).await,
            (None, Err(Error::Computation(()))),
        );
        assert_eq!(cached.get(), None);
        assert_eq!(
            cached.subscribe_or_recompute(|| async { Ok(0) }).await,
            (None, Ok(0)),
        );
        assert_eq!(cached.get(), Some(0));
        assert_eq!(
            cached.subscribe_or_recompute(|| async { Ok(30) }).await,
            (Some(0), Ok(30)),
        );
        assert_eq!(cached.get(), Some(30));
        assert_eq!(
            cached.subscribe_or_recompute(|| async { Err(()) }).await,
            (Some(30), Err(Error::Computation(()))),
        );
        assert_eq!(cached.get(), None);
        let (notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(12)).await;
        let second_handle = {
            let cached = Cached::clone(&cached);
            tokio::spawn(async move {
                cached
                    .subscribe_or_recompute(|| async {
                        panic!("Shouldn't execute, already inflight")
                    })
                    .await
            })
        };
        notify.notify_waiters();
        assert_eq!(handle.await.unwrap(), Ok(12));
        assert_eq!(second_handle.await.unwrap(), (None, Ok(12)));
        assert_eq!(cached.get(), Some(12));
    }

    #[tokio::test]
    async fn test_force_recompute() {
        let cached = Cached::<_, ()>::new();
        assert_eq!(
            cached.force_recompute(async { Err(()) }).await,
            (CachedState::EmptyCache, Err(Error::Computation(()))),
        );
        assert_eq!(cached.get(), None);
        assert_eq!(
            cached.force_recompute(async { Ok(0) }).await,
            (CachedState::EmptyCache, Ok(0))
        );
        assert_eq!(cached.get(), Some(0));
        assert_eq!(
            cached.force_recompute(async { Ok(15) }).await,
            (CachedState::ValueCached(0), Ok(15)),
        );
        assert_eq!(cached.get(), Some(15));
        assert_eq!(
            cached.force_recompute(async { Err(()) }).await,
            (CachedState::ValueCached(15), Err(Error::Computation(()))),
        );
        assert_eq!(cached.get(), None);
        // `_notify` is kept alive but never fired, so the first computation stays inflight
        // until `force_recompute` aborts it.
        let (_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(0)).await;
        assert_eq!(
            cached.force_recompute(async { Ok(21) }).await,
            (CachedState::Inflight, Ok(21))
        );
        assert!(matches!(handle.await.unwrap(), Err(Error::Aborted(_))));
        assert_eq!(cached.get(), Some(21));
    }

    #[tokio::test]
    async fn test_abort() {
        let cached = Cached::<_, ()>::new();
        assert!(!cached.abort());
        assert_eq!(cached.get(), None);
        let (_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(0)).await;
        assert!(cached.abort());
        assert!(!cached.is_inflight());
        assert!(matches!(handle.await.unwrap(), Err(Error::Aborted(_))));
        assert_eq!(cached.get(), None);
        assert_eq!(cached.inflight_waiting_count(), 0);
    }

    /// Spawns a `get_or_compute` on an empty `cached` that stays inflight until the returned
    /// `Notify` is fired, then yields `result`. Returns once the computation has started.
    async fn setup_inflight_request<T, E>(
        cached: Cached<T, E>,
        result: Result<T, E>,
    ) -> (Arc<Notify>, JoinHandle<Result<T, Error<E>>>)
    where
        T: Clone + Send + 'static,
        E: Clone + Send + 'static,
    {
        assert!(!cached.is_inflight());
        assert!(!cached.is_value_cached());
        let tokio_notify = Arc::new(Notify::new());
        let registered = Arc::new(Notify::new());
        let registered_fut = registered.notified();
        let handle = {
            let tokio_notify = Arc::clone(&tokio_notify);
            let registered = Arc::clone(&registered);
            let cached = Cached::clone(&cached);
            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async move {
                        let notified_fut = tokio_notify.notified();
                        registered.notify_waiters();
                        notified_fut.await;
                        result
                    })
                    .await
            })
        };
        registered_fut.await;
        (tokio_notify, handle)
    }
}