cache_compute/
lib.rs

1//! This crate implements request/async computation coalescing.
2//!
3//! The starting point for this implementation was fasterthanlime's excellent [article on request coalescing in async rust](https://fasterthanli.me/articles/request-coalescing-in-async-rust).
4//!
5//! Caching of async computations can be a bit of a tough problem.
6//! If no cached value is available when we need it, we would want to compute it, often asynchronously.
7//! This crate helps ensure that this computation doesn't happen more than it needs to
8//! by avoiding starting new computations when one is already happening.
9//! Instead, we will subscribe to that computation and work with the result of it as well.
10//!
11//! # Example
12//!
13//! ```
14//! # fn answer_too_old() -> bool { true }
15//! # fn refresh_answer_timer() {}
16//! use cache_compute::Cached;
17//!
18//! pub async fn get_answer(cached_answer: Cached<u32, ()>) -> u32 {
19//!     if answer_too_old() {
20//!         cached_answer.invalidate();
21//!     }
22//!
23//!     cached_answer.get_or_compute(|| async {
24//!         // Really long async computation
25//!         // Phew the computer and network sure need a lot of time to work on this
26//!         // Good thing we cache it
27//!         // ...
28//!         // Ok done
29//!         // Other calls to get_answer will now also use that same value
30//!         // without having to compute it, until it's too old again
31//!         refresh_answer_timer();
32//!         Ok(42)
33//!     })
34//!     .await
35//!     .unwrap()
36//! }
37//! ```
38
#![warn(
    clippy::pedantic,
    clippy::cargo,
    missing_docs,
    rustdoc::missing_crate_level_docs,
    rustdoc::private_doc_tests
)]
#![deny(
    rustdoc::broken_intra_doc_links,
    rustdoc::private_intra_doc_links,
    rustdoc::invalid_codeblock_attributes,
    rustdoc::invalid_rust_codeblocks
)]
#![forbid(unsafe_code)]
53
54use std::fmt::Debug;
55use std::future::Future;
56use std::sync::{Arc, Weak};
57
58use futures::stream::{AbortHandle, Abortable, Aborted};
59use parking_lot::{Mutex, MutexGuard};
60use thiserror::Error;
61use tokio::sync::broadcast::error::RecvError;
62use tokio::sync::broadcast::{self, Receiver, Sender};
63
64// TODO: More sane struct/impl ordering
65
66/// The error type for [`Cached`].
67///
68/// `E` specifies the error the computation may return.
69#[derive(Debug, PartialEq, Error, Clone)]
70pub enum Error<E> {
71    /// Notifying the other waiters failed with a [`RecvError`].
72    /// Either the inflight computation panicked or the [`Future`] returned by `get_or_compute` was dropped/canceled.
73    #[error("The computation for get_or_compute panicked or the Future returned by get_or_compute was dropped: {0}")]
74    Broadcast(#[from] RecvError),
75    /// The inflight computation returned an error value.
76    #[error("Inflight computation returned error value: {0}")]
77    Computation(E),
78    /// The inflight computation was aborted
79    #[error("Inflight computation was aborted")]
80    Aborted(#[from] Aborted),
81}
82
83/// The main struct implementing the async computation coalescing.
84///
85/// `T` is the value type and `E` is the error type of the computation.
86///
87/// A [`Cached`] computation is in one of three states:
88/// - There is no cached value and no inflight computation is happening
89/// - There is a cached value and no inflight computation is happening
90/// - There is no cached value, but an inflight computation is currently computing one
91///
92/// The [`Cached`] instance can be shared via cloning as it uses an [`Arc`] internally.
93///
94/// [`Cached::get_or_compute`] will
95/// - Start a new inflight computation if there is no cached value and no inflight computation is happening
96/// - Return the cached value immediately if there is a cached value available
97/// - Subscribe to an inflight computation if there is one happening and return the result of that when it concludes
98///
99/// The cache can be invalidated using [`Cached::invalidate`]
100///
101/// The instances of `T` and `E` are cloned for every time a user requests a value or gets handed an error `E`.
102/// Thus, consider using an [`Arc`] for expensive to clone variants of `T` and `E`.
103///
104/// The cached value is stored on the stack, so you may want to consider using a [`Box`] for large `T`.
105///
106/// [`Box`]: std::boxed::Box
107#[derive(Debug, Default)]
108pub struct Cached<T, E> {
109    inner: Arc<Mutex<CachedInner<T, E>>>,
110}
111
112impl<T, E> Clone for Cached<T, E> {
113    fn clone(&self) -> Self {
114        Self {
115            inner: Arc::clone(&self.inner),
116        }
117    }
118}
119
/// The state an instance of [`Cached`] was in, as returned by [`Cached::force_recompute`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachedState<T> {
    /// Nothing is cached and no inflight computation is happening.
    EmptyCache,
    /// This value is cached.
    ValueCached(T),
    /// An inflight computation is happening.
    Inflight,
}

impl<T> CachedState<T> {
    /// Returns `true` iff this state represents an inflight computation.
    #[must_use]
    pub fn is_inflight(&self) -> bool {
        match self {
            CachedState::Inflight => true,
            CachedState::EmptyCache | CachedState::ValueCached(_) => false,
        }
    }

    /// Returns a shared reference to the cached value, if this state holds one.
    #[must_use]
    pub fn get(&self) -> Option<&T> {
        match self {
            CachedState::ValueCached(value) => Some(value),
            CachedState::EmptyCache | CachedState::Inflight => None,
        }
    }

    /// Returns a mutable reference to the cached value, if this state holds one.
    #[must_use]
    pub fn get_mut(&mut self) -> Option<&mut T> {
        match self {
            CachedState::ValueCached(value) => Some(value),
            CachedState::EmptyCache | CachedState::Inflight => None,
        }
    }
}
158
159type InflightComputation<T, E> = (AbortHandle, Sender<Result<T, Error<E>>>);
160
161#[derive(Clone, Debug)]
162enum CachedInner<T, E> {
163    CachedValue(T),
164    EmptyOrInflight(Weak<InflightComputation<T, E>>),
165}
166
167impl<T, E> Default for CachedInner<T, E> {
168    fn default() -> Self {
169        CachedInner::new()
170    }
171}
172
173impl<T, E> CachedInner<T, E> {
174    #[must_use]
175    fn new() -> Self {
176        CachedInner::EmptyOrInflight(Weak::new())
177    }
178
179    #[must_use]
180    fn new_with_value(value: T) -> Self {
181        CachedInner::CachedValue(value)
182    }
183
184    fn invalidate(&mut self) -> Option<T> {
185        if matches!(self, CachedInner::EmptyOrInflight(_)) {
186            None
187        } else if let CachedInner::CachedValue(value) = std::mem::take(self) {
188            Some(value)
189        } else {
190            unreachable!()
191        }
192    }
193
194    fn is_inflight(&self) -> bool {
195        self.inflight_weak()
196            .map_or(false, |weak| weak.strong_count() > 0)
197    }
198
199    fn inflight_waiting_count(&self) -> usize {
200        self.inflight_arc()
201            .map_or(0, |arc| arc.1.receiver_count() + 1)
202    }
203
204    fn abort(&mut self) -> bool {
205        if let Some(arc) = self.inflight_arc() {
206            arc.0.abort();
207
208            // Immediately enter no inflight state
209            *self = CachedInner::new();
210
211            true
212        } else {
213            false
214        }
215    }
216
217    #[must_use]
218    fn is_value_cached(&self) -> bool {
219        matches!(self, CachedInner::CachedValue(_))
220    }
221
222    #[must_use]
223    fn inflight_weak(&self) -> Option<&Weak<InflightComputation<T, E>>> {
224        if let CachedInner::EmptyOrInflight(weak) = self {
225            Some(weak)
226        } else {
227            None
228        }
229    }
230
231    #[must_use]
232    fn inflight_arc(&self) -> Option<Arc<InflightComputation<T, E>>> {
233        self.inflight_weak().and_then(Weak::upgrade)
234    }
235
236    #[must_use]
237    fn get(&self) -> Option<&T> {
238        if let CachedInner::CachedValue(value) = self {
239            Some(value)
240        } else {
241            None
242        }
243    }
244
245    #[must_use]
246    fn get_receiver(&self) -> Option<Receiver<Result<T, Error<E>>>> {
247        self.inflight_arc().map(|arc| arc.1.subscribe())
248    }
249}
250
251impl<T, E> Cached<T, E> {
252    /// Creates a new instance with no cached value present.
253    #[must_use]
254    pub fn new() -> Self {
255        Self {
256            inner: Arc::new(Mutex::new(CachedInner::new())),
257        }
258    }
259
260    /// Creates a new instance with the given value in the cache.
261    #[must_use]
262    pub fn new_with_value(value: T) -> Self {
263        Cached {
264            inner: Arc::new(Mutex::new(CachedInner::new_with_value(value))),
265        }
266    }
267
268    /// Invalidates the cache immediately, returning its value without cloning if present.
269    #[allow(clippy::must_use_candidate)]
270    pub fn invalidate(&self) -> Option<T> {
271        self.inner.lock().invalidate()
272    }
273
274    /// Returns `true` iff there is an inflight computation happening.
275    #[must_use]
276    pub fn is_inflight(&self) -> bool {
277        self.inner.lock().is_inflight()
278    }
279
280    /// Returns the amount of instances waiting on an inflight computation, including the instance that started the computation.
281    #[must_use]
282    pub fn inflight_waiting_count(&self) -> usize {
283        self.inner.lock().inflight_waiting_count()
284    }
285
286    /// Aborts the current inflight computation.
287    /// Returns `true` iff there was an inflight computation to abort.
288    ///
289    /// After this function returns, the instance will *immediately* act like there is no inflight computation happening.
290    /// However, it might still take some time until the actual inflight computation finishes aborting.
291    #[allow(clippy::must_use_candidate)]
292    pub fn abort(&self) -> bool {
293        self.inner.lock().abort()
294    }
295
296    /// Returns `true` iff a value is currently cached.
297    #[must_use]
298    pub fn is_value_cached(&self) -> bool {
299        self.inner.lock().is_value_cached()
300    }
301}
302
303impl<T: Clone, E> Cached<T, E> {
304    /// Returns the value of the cache immediately if present, cloning the value.
305    #[must_use]
306    pub fn get(&self) -> Option<T> {
307        self.inner.lock().get().cloned()
308    }
309}
310
311enum GetOrSubscribeResult<'a, T, E> {
312    Success(Result<T, Error<E>>),
313    FailureKeepLock(MutexGuard<'a, CachedInner<T, E>>),
314}
315
316impl<T, E> Cached<T, E>
317where
318    T: Clone,
319    E: Clone,
320{
321    /// This function will
322    /// - Execute `computation` and the [`Future`] it returns if there is no cached value and no inflight computation is happening,
323    /// starting a new inflight computation and returning the result of that
324    /// - Not do anything with `computation` and return the cached value immediately if there is a cached value available
325    /// - Not do anything with `computation` and subscribe to an inflight computation if there is one happening and return the result of that when it concludes
326    ///
327    /// Note that the [`Future`] `computation` returns will *not* be executed via [`tokio::spawn`] or similar, but rather will become part of the [`Future`]
328    /// this function returns.
329    /// This means it does not need to be [`Send`].
330    ///
331    /// # Errors
332    ///
333    /// If the inflight computation this function subscribed to or started returns an error,
334    /// that error is cloned and returned by this function in an [`Error::Computation`].
335    ///
336    /// If this function does not start a computation, but subscribes to a computation which panics or gets dropped/cancelled,
337    /// it will return an [`Error::Broadcast`].
338    ///
339    /// If this function starts a computation or subscribes to a computation that gets aborted with [`Cached::abort`],
340    /// it will return an [`Error::Aborted`].
341    ///
342    /// # Panics
343    ///
344    /// This function panics if `computation` gets executed and panics, or if the [`Future`] returned by `computation` panics.
345    #[allow(clippy::await_holding_lock)] // Clippy you're literally wrong we're moving it before the await
346    pub async fn get_or_compute<Fut>(
347        &self,
348        computation: impl FnOnce() -> Fut,
349    ) -> Result<T, Error<E>>
350    where
351        Fut: Future<Output = Result<T, E>>,
352    {
353        let inner = match self.get_or_subscribe_keep_lock().await {
354            GetOrSubscribeResult::Success(res) => return res,
355            GetOrSubscribeResult::FailureKeepLock(lock) => lock,
356        };
357
358        // Neither cached nor inflight so this is safe to unwrap
359        self.compute_with_lock(computation, inner).await.unwrap()
360    }
361
362    /// This function will
363    /// - Return immediately with the cached value if a cached value is present
364    /// - Return `None` immediately if no cached value is present and no inflight computation is happening
365    /// - Subscribe to an inflight computation if there is one happening and return the result of that when it concludes
366    ///
367    /// # Errors
368    ///
369    /// If the inflight computation this function subscribed to returns an error,
370    /// that error is cloned and returned by this function in an [`Error::Computation`].
371    ///
372    /// If this function subscribes to a computation which panics or gets dropped/cancelled,
373    /// it will return an [`Error::Broadcast`].
374    ///
375    /// If this function subscribes to a computation that gets aborted with [`Cached::abort`],
376    /// it will return an [`Error::Aborted`].
377    pub async fn get_or_subscribe(&self) -> Option<Result<T, Error<E>>> {
378        if let GetOrSubscribeResult::Success(res) = self.get_or_subscribe_keep_lock().await {
379            Some(res)
380        } else {
381            None
382        }
383    }
384
385    /// This function will
386    /// - Invalidate the cache and execute `computation` and the [`Future`] it returns if no inflight computation is happening,
387    /// starting a new inflight computation and returning the result of that
388    /// - Subscribe to an inflight computation if there is one happening and return the result of that when it concludes
389    ///
390    /// Note that after calling this function, the cache will *always* be empty, even if the computation results in an error.
391    ///
392    /// This function will return the previously cached value as well as the result of the computation it starts or subscribes to.
393    ///
394    /// # Errors
395    ///
396    /// If the inflight computation this function starts or subscribes to returns an error,
397    /// that error is cloned and returned by this function in an [`Error::Computation`].
398    ///
399    /// If this function subscribes to a computation which panics or gets dropped/cancelled,
400    /// it will return an [`Error::Broadcast`].
401    ///
402    /// If this function subscribes to or starts a computation that gets aborted with [`Cached::abort`],
403    /// it will return an [`Error::Aborted`].
404    ///
405    /// # Panics
406    ///
407    /// This function panics if `computation` gets executed and panics, or if the [`Future`] returned by `computation` panics.
408    #[allow(clippy::await_holding_lock)] // Clippy you're literally wrong we're dropping/moving it before the await
409    pub async fn subscribe_or_recompute<Fut>(
410        &self,
411        computation: impl FnOnce() -> Fut,
412    ) -> (Option<T>, Result<T, Error<E>>)
413    where
414        Fut: Future<Output = Result<T, E>>,
415    {
416        let mut inner = self.inner.lock();
417
418        if let Some(mut receiver) = inner.get_receiver() {
419            drop(inner);
420
421            // Lock is dropped so async is legal again :)
422            (
423                None,
424                match receiver.recv().await {
425                    Err(why) => Err(Error::from(why)),
426                    Ok(res) => res,
427                },
428            )
429        } else {
430            let prev = inner.invalidate();
431
432            // Neither cached nor inflight, so unwrap is fine
433            let result = self.compute_with_lock(computation, inner).await.unwrap();
434
435            (prev, result)
436        }
437    }
438
439    /// This function will invalidate the cache, potentially abort the inflight request if one is happening, and start a new inflight computation, returning the result of that.
440    ///
441    /// It will return the previous [`CachedState`] as well as the result of the computation it starts.
442    ///
443    /// # Errors
444    ///
445    /// If the inflight computation this function starts returns an error,
446    /// that error is cloned and returned by this function in an [`Error::Computation`].
447    ///
448    /// If this function starts a computation which panics or gets dropped/cancelled,
449    /// it will return an [`Error::Broadcast`].
450    ///
451    /// If this function starts a computation that gets aborted with [`Cached::abort`],
452    /// it will return an [`Error::Aborted`].
453    ///
454    /// # Panics
455    ///
456    /// This function panics if `computation` or the [`Future`] returned by `computation` panics.
457    #[allow(clippy::await_holding_lock)] // Clippy you're literally wrong we're moving it before the await
458    pub async fn force_recompute<Fut>(
459        &self,
460        computation: Fut,
461    ) -> (CachedState<T>, Result<T, Error<E>>)
462    where
463        Fut: Future<Output = Result<T, E>>,
464    {
465        let mut inner = self.inner.lock();
466
467        let aborted = inner.abort();
468        let prev_cache = inner.invalidate();
469
470        let prev_state = match (aborted, prev_cache) {
471            (false, None) => CachedState::EmptyCache,
472            (false, Some(val)) => CachedState::ValueCached(val),
473            (true, None) => CachedState::Inflight,
474            (true, Some(_)) => unreachable!(),
475        };
476
477        // Neither cached nor inflight at this point, so safe to unwrap here
478        let result = self.compute_with_lock(|| computation, inner).await.unwrap();
479
480        (prev_state, result)
481    }
482
483    /// Like [`Cached::get_or_subscribe`], but keeps and returns the lock the function used iff nothing is cached and no inflight computation is present.
484    /// This allows [`Cached::get_or_compute`] to re-use that same lock to set up the computation without creating a race condition.
485    #[allow(clippy::await_holding_lock)] // Clippy you're literally wrong we're dropping it before the await
486    async fn get_or_subscribe_keep_lock(&self) -> GetOrSubscribeResult<'_, T, E> {
487        // Only sync code in this block
488        let inner = self.inner.lock();
489
490        // Return cached if available
491        if let CachedInner::CachedValue(value) = &*inner {
492            return GetOrSubscribeResult::Success(Ok(value.clone()));
493        }
494
495        let Some(mut receiver) = inner.get_receiver() else {
496            return GetOrSubscribeResult::FailureKeepLock(inner);
497        };
498
499        drop(inner);
500
501        let result = receiver.recv().await;
502
503        GetOrSubscribeResult::Success(match result {
504            Err(why) => Err(Error::from(why)),
505            Ok(res) => res,
506        })
507    }
508
509    /// Doesn't execute `computation` and returns [`None`] if a cached value is present or an inflight computation is already happening.
510    #[allow(clippy::await_holding_lock)] // Clippy you're literally wrong we're dropping it before the await
511    async fn compute_with_lock<'a, Fut>(
512        &'a self,
513        computation: impl FnOnce() -> Fut,
514        mut inner: MutexGuard<'a, CachedInner<T, E>>,
515    ) -> Option<Result<T, Error<E>>>
516    where
517        Fut: Future<Output = Result<T, E>>,
518    {
519        // Check that no value is cached and no computation is happening
520        if inner.is_value_cached() || inner.is_inflight() {
521            return None;
522        }
523
524        // Neither cached nor inflight, so compute
525        // Underscore binding drops immediately, which is important for the receiver count
526        let (tx, _) = broadcast::channel(1);
527
528        let (abort_handle, abort_registration) = AbortHandle::new_pair();
529
530        let arc = Arc::new((abort_handle, tx));
531
532        // In case we panic or get aborted, have way for receivers to notice (via the Weak getting dropped)
533        *inner = CachedInner::EmptyOrInflight(Arc::downgrade(&arc));
534
535        // Release lock so we can do async computation
536        drop(inner);
537
538        // Run the computation
539        let future = computation();
540
541        let res = match Abortable::new(future, abort_registration).await {
542            Ok(res) => res.map_err(Error::Computation),
543            Err(aborted) => Err(Error::from(aborted)),
544        };
545
546        {
547            // Only sync code in this block
548            let mut inner = self.inner.lock();
549
550            if !matches!(res, Err(Error::Aborted(_))) {
551                // If we aborted, we have to leave inner as is
552                // Otherwise big races come up as the next inflight computation might already be underway at this point
553                if let Ok(value) = &res {
554                    *inner = CachedInner::CachedValue(value.clone());
555                } else {
556                    *inner = CachedInner::new();
557                }
558            }
559        }
560
561        // Only clone if we have receivers
562        // This is not a race condition because after inner gets assigned above (or if the request has been aborted),
563        // this Arc will be inaccessible from the struct and no new receivers can subscribe
564        if arc.1.receiver_count() > 0 {
565            // That being said, others might still *un*subscribe after the if, so we cannot unwrap here
566            arc.1.send(res.clone()).ok();
567        }
568
569        Some(res)
570    }
571}
572
573#[cfg(test)]
574mod test {
575    use std::sync::Arc;
576    use std::time::Duration;
577    use tokio::sync::Notify;
578    use tokio::task::JoinHandle;
579
580    use crate::CachedState;
581
582    use super::{Cached, Error};
583
584    #[tokio::test]
585    async fn test_cached() {
586        let cached = Cached::<_, ()>::new_with_value(12);
587        assert_eq!(cached.get(), Some(12));
588        assert!(!cached.is_inflight());
589        assert!(cached.is_value_cached());
590        assert_eq!(cached.inflight_waiting_count(), 0);
591
592        let cached = Cached::new();
593        assert_eq!(cached.get(), None);
594        assert!(!cached.is_inflight());
595        assert!(!cached.is_value_cached());
596        assert_eq!(cached.inflight_waiting_count(), 0);
597
598        assert_eq!(cached.get_or_compute(|| async { Ok(12) }).await, Ok(12));
599        assert_eq!(cached.get(), Some(12));
600
601        assert_eq!(cached.invalidate(), Some(12));
602        assert_eq!(cached.get(), None);
603        assert_eq!(cached.invalidate(), None);
604
605        assert_eq!(
606            cached.get_or_compute(|| async { Err(42) }).await,
607            Err(Error::Computation(42)),
608        );
609        assert_eq!(cached.get(), None);
610
611        assert_eq!(cached.get_or_compute(|| async { Ok(1) }).await, Ok(1));
612        assert_eq!(cached.get(), Some(1));
613        assert_eq!(cached.get_or_compute(|| async { Ok(32) }).await, Ok(1));
614
615        assert_eq!(cached.invalidate(), Some(1));
616
617        let (tokio_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(30)).await;
618
619        assert_eq!(cached.get(), None);
620
621        // We also know we're inflight right now
622        assert!(cached.is_inflight());
623        assert_eq!(cached.inflight_waiting_count(), 1);
624
625        let other_handle = {
626            let cached = Cached::clone(&cached);
627
628            tokio::spawn(async move { cached.get_or_compute(|| async move { Ok(24) }).await })
629        };
630
631        tokio_notify.notify_waiters();
632
633        assert_eq!(handle.await.unwrap(), Ok(30));
634        assert_eq!(other_handle.await.unwrap(), Ok(30));
635        assert_eq!(cached.get(), Some(30));
636    }
637
638    #[tokio::test]
639    async fn test_computation_panic() {
640        let cached = Cached::<_, ()>::new();
641
642        // Panic during computation of Future
643        let is_panic = {
644            let cached = Cached::clone(&cached);
645
646            tokio::spawn(async move {
647                cached
648                    .get_or_compute(|| {
649                        panic!("Panic in computation");
650                        #[allow(unreachable_code)]
651                        async {
652                            unreachable!()
653                        }
654                    })
655                    .await
656            })
657        }
658        .await
659        .expect_err("Should panic")
660        .is_panic();
661
662        assert!(is_panic, "Should panic");
663
664        assert_eq!(cached.get(), None);
665        assert!(!cached.is_inflight());
666        assert_eq!(cached.inflight_waiting_count(), 0);
667
668        assert_eq!(
669            cached.get_or_compute(|| async move { Ok(21) }).await,
670            Ok(21),
671        );
672
673        // Panic in Future
674        assert_eq!(cached.invalidate(), Some(21));
675
676        let is_panic = {
677            let cached = Cached::clone(&cached);
678
679            tokio::spawn(async move {
680                cached
681                    .get_or_compute(|| async { panic!("Panic in future") })
682                    .await
683            })
684        }
685        .await
686        .expect_err("Should be panic")
687        .is_panic();
688
689        assert!(is_panic, "Should panic");
690
691        assert_eq!(cached.get(), None);
692        assert!(!cached.is_inflight());
693        assert_eq!(cached.inflight_waiting_count(), 0);
694
695        assert_eq!(
696            cached.get_or_compute(|| async move { Ok(17) }).await,
697            Ok(17),
698        );
699
700        // Panic in Future while others are waiting for inflight
701        assert_eq!(cached.invalidate(), Some(17));
702
703        let tokio_notify = Arc::new(Notify::new());
704        let registered = Arc::new(Notify::new());
705        let registered_fut = registered.notified();
706
707        let panicking_handle = {
708            let cached = Cached::clone(&cached);
709            let tokio_notify = Arc::clone(&tokio_notify);
710            let registered = Arc::clone(&registered);
711
712            tokio::spawn(async move {
713                cached
714                    .get_or_compute(|| async move {
715                        let notify_fut = tokio_notify.notified();
716                        registered.notify_waiters();
717                        notify_fut.await;
718                        panic!("Panic in future")
719                    })
720                    .await
721            })
722        };
723
724        // Make sure the notify is already registered and we're already computing
725        registered_fut.await;
726
727        let waiting_handle = {
728            let cached = Cached::clone(&cached);
729
730            tokio::spawn(async move {
731                cached
732                    .get_or_compute(|| async {
733                        panic!("Entered computation when another inflight computation should already be running")
734                    })
735                    .await
736            })
737        };
738
739        // Wait a bit for the waiting task to actually wait on rx
740        while cached.inflight_waiting_count() < 2 {
741            tokio::task::yield_now().await;
742        }
743
744        // Cause panic
745        tokio_notify.notify_waiters();
746
747        assert!(panicking_handle.await.unwrap_err().is_panic());
748        assert!(matches!(waiting_handle.await, Ok(Err(Error::Broadcast(_)))));
749        assert_eq!(cached.get(), None);
750    }
751
752    #[tokio::test]
753    async fn test_computation_drop() {
754        let cached = Cached::<_, ()>::new();
755
756        // Drop the Future while others are waiting for inflight
757        let computing = Arc::new(Notify::new());
758        let computing_fut = computing.notified();
759
760        let dropping_handle = {
761            let cached = Cached::clone(&cached);
762            let computing = Arc::clone(&computing);
763
764            tokio::spawn(async move {
765                cached
766                    .get_or_compute(|| async move {
767                        computing.notify_waiters();
768                        loop {
769                            tokio::time::sleep(Duration::from_secs(1)).await;
770                        }
771                    })
772                    .await
773            })
774        };
775
776        // Make sure we're already computing
777        computing_fut.await;
778
779        let waiting_handle = {
780            let cached = Cached::clone(&cached);
781
782            tokio::spawn(async move {
783                cached
784                    .get_or_compute(|| async {
785                        panic!("Entered computation when another inflight computation should already be running");
786                    })
787                    .await
788            })
789        };
790
791        // Wait a bit for the waiting task to actually wait on rx
792        while cached.inflight_waiting_count() < 2 {
793            tokio::task::yield_now().await;
794        }
795
796        // Drop future
797        dropping_handle.abort();
798
799        assert!(dropping_handle.await.unwrap_err().is_cancelled());
800        assert!(matches!(waiting_handle.await, Ok(Err(Error::Broadcast(_)))));
801        assert_eq!(cached.get(), None);
802        // Make sure cached still works as intended
803        assert_eq!(cached.get_or_compute(|| async { Ok(3) }).await, Ok(3));
804        assert_eq!(cached.get(), Some(3));
805    }
806
807    #[tokio::test]
808    async fn test_get_or_subscribe() {
809        let cached = Cached::<_, ()>::new();
810
811        // Test empty cache
812        assert_eq!(cached.get_or_subscribe().await, None);
813
814        // Test cached
815        assert_eq!(cached.get_or_compute(|| async { Ok(0) }).await, Ok(0));
816        assert_eq!(cached.get_or_subscribe().await, Some(Ok(0)));
817
818        // Test inflight
819        cached.invalidate();
820
821        let (tokio_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(30)).await;
822
823        // We know we're inflight right now
824        assert!(cached.is_inflight());
825
826        let get_or_subscribe_handle = {
827            let cached = Cached::clone(&cached);
828
829            tokio::spawn(async move { cached.get_or_subscribe().await })
830        };
831
832        // Complete original future, placing 30 in cache
833        tokio_notify.notify_waiters();
834
835        assert_eq!(handle.await.unwrap(), Ok(30));
836        assert_eq!(get_or_subscribe_handle.await.unwrap(), Some(Ok(30)));
837        assert_eq!(cached.get(), Some(30));
838    }
839
840    #[tokio::test]
841    async fn test_subscribe_or_recompute() {
842        let cached = Cached::new();
843
844        // Test empty cache
845        assert_eq!(
846            cached.subscribe_or_recompute(|| async { Err(()) }).await,
847            (None, Err(Error::Computation(()))),
848        );
849        assert_eq!(cached.get(), None);
850
851        assert_eq!(
852            cached.subscribe_or_recompute(|| async { Ok(0) }).await,
853            (None, Ok(0)),
854        );
855        assert_eq!(cached.get(), Some(0));
856
857        // Test cached
858        assert_eq!(
859            cached.subscribe_or_recompute(|| async { Ok(30) }).await,
860            (Some(0), Ok(30)),
861        );
862        assert_eq!(cached.get(), Some(30));
863
864        // Error should still invalidate cache
865        assert_eq!(
866            cached.subscribe_or_recompute(|| async { Err(()) }).await,
867            (Some(30), Err(Error::Computation(()))),
868        );
869        assert_eq!(cached.get(), None);
870
871        // Test inflight
872        let (notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(12)).await;
873
874        let second_handle = {
875            let cached = Cached::clone(&cached);
876
877            tokio::spawn(async move {
878                cached
879                    .subscribe_or_recompute(|| async {
880                        panic!("Shouldn't execute, already inflight")
881                    })
882                    .await
883            })
884        };
885
886        notify.notify_waiters();
887
888        assert_eq!(handle.await.unwrap(), Ok(12));
889        assert_eq!(second_handle.await.unwrap(), (None, Ok(12)));
890        assert_eq!(cached.get(), Some(12));
891    }
892
893    #[tokio::test]
894    async fn test_force_recompute() {
895        let cached = Cached::<_, ()>::new();
896
897        // Test empty cache
898        assert_eq!(
899            cached.force_recompute(async { Err(()) }).await,
900            (CachedState::EmptyCache, Err(Error::Computation(()))),
901        );
902        assert_eq!(cached.get(), None);
903        assert_eq!(
904            cached.force_recompute(async { Ok(0) }).await,
905            (CachedState::EmptyCache, Ok(0))
906        );
907        assert_eq!(cached.get(), Some(0));
908
909        // Test cached
910        assert_eq!(
911            cached.force_recompute(async { Ok(15) }).await,
912            (CachedState::ValueCached(0), Ok(15)),
913        );
914        assert_eq!(cached.get(), Some(15));
915        // Error should still invalidate cache
916        assert_eq!(
917            cached.force_recompute(async { Err(()) }).await,
918            (CachedState::ValueCached(15), Err(Error::Computation(()))),
919        );
920        assert_eq!(cached.get(), None);
921
922        // Test inflight
923        let (_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(0)).await;
924
925        assert_eq!(
926            cached.force_recompute(async { Ok(21) }).await,
927            (CachedState::Inflight, Ok(21))
928        );
929        assert!(matches!(handle.await.unwrap(), Err(Error::Aborted(_))));
930        assert_eq!(cached.get(), Some(21));
931    }
932
933    #[tokio::test]
934    async fn test_abort() {
935        let cached = Cached::<_, ()>::new();
936
937        // Test no inflight
938        assert!(!cached.abort());
939
940        // Test inflight
941        assert_eq!(cached.get(), None);
942        let (_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(0)).await;
943
944        assert!(cached.abort());
945        assert!(!cached.is_inflight());
946
947        assert!(matches!(handle.await.unwrap(), Err(Error::Aborted(_))));
948        assert_eq!(cached.get(), None);
949        assert_eq!(cached.inflight_waiting_count(), 0);
950    }
951
952    /// After this function, `cached` will have an active inflight computation.
953    /// The computation will finish with `result` once the `notify_waiters` is called on the returned [`Notify`].
954    /// The computation can be joined with the returned `JoinHandle`.
955    ///
956    /// # Panics
957    ///
958    /// This function panics if `cached` is already in an inflight state or a cached value is available at the start. Please don't race that.
959    async fn setup_inflight_request<T, E>(
960        cached: Cached<T, E>,
961        result: Result<T, E>,
962    ) -> (Arc<Notify>, JoinHandle<Result<T, Error<E>>>)
963    where
964        T: Clone + Send + 'static,
965        E: Clone + Send + 'static,
966    {
967        assert!(!cached.is_inflight());
968        assert!(!cached.is_value_cached());
969
970        let tokio_notify = Arc::new(Notify::new());
971        let registered = Arc::new(Notify::new());
972        let registered_fut = registered.notified();
973
974        let handle = {
975            let tokio_notify = Arc::clone(&tokio_notify);
976            let registered = Arc::clone(&registered);
977            let cached = Cached::clone(&cached);
978
979            tokio::spawn(async move {
980                cached
981                    .get_or_compute(|| async move {
982                        let notified_fut = tokio_notify.notified();
983                        registered.notify_waiters();
984                        notified_fut.await;
985                        result
986                    })
987                    .await
988            })
989        };
990
991        // Wait until the tokio_notify is registered
992        registered_fut.await;
993
994        (tokio_notify, handle)
995    }
996}