1#![warn(clippy::pedantic)]
40#![warn(clippy::cargo)]
41#![warn(
42 missing_docs,
43 rustdoc::missing_crate_level_docs,
44 rustdoc::private_doc_tests
45)]
46#![deny(
47 rustdoc::broken_intra_doc_links,
48 rustdoc::private_intra_doc_links,
49 rustdoc::invalid_codeblock_attributes,
50 rustdoc::invalid_rust_codeblocks
51)]
52#![forbid(unsafe_code)]
53
54use std::fmt::Debug;
55use std::future::Future;
56use std::sync::{Arc, Weak};
57
58use futures::stream::{AbortHandle, Abortable, Aborted};
59use parking_lot::{Mutex, MutexGuard};
60use thiserror::Error;
61use tokio::sync::broadcast::error::RecvError;
62use tokio::sync::broadcast::{self, Receiver, Sender};
63
/// Errors returned by the computing and subscribing methods of [`Cached`].
///
/// `E` is the error type produced by the user-supplied computation.
#[derive(Debug, PartialEq, Error, Clone)]
pub enum Error<E> {
    /// Receiving the result of a computation run by another task failed.
    ///
    /// The broadcast channel is closed when the computing task drops its sender,
    /// i.e. when the computation panicked or its future was dropped before
    /// producing a value.
    #[error("The computation for get_or_compute panicked or the Future returned by get_or_compute was dropped: {0}")]
    Broadcast(#[from] RecvError),
    /// The computation ran to completion and returned `Err` with this value.
    #[error("Inflight computation returned error value: {0}")]
    Computation(E),
    /// The computation was aborted, e.g. through [`Cached::abort`] or
    /// [`Cached::force_recompute`].
    #[error("Inflight computation was aborted")]
    Aborted(#[from] Aborted),
}
82
83#[derive(Debug, Default)]
108pub struct Cached<T, E> {
109 inner: Arc<Mutex<CachedInner<T, E>>>,
110}
111
112impl<T, E> Clone for Cached<T, E> {
113 fn clone(&self) -> Self {
114 Self {
115 inner: Arc::clone(&self.inner),
116 }
117 }
118}
119
/// Snapshot of what a [`Cached`] held at one moment.
///
/// [`Cached::force_recompute`] returns one of these to describe the state it replaced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachedState<T> {
    /// Nothing was cached and no computation was running.
    EmptyCache,
    /// A computed value was cached.
    ValueCached(T),
    /// A computation was running (`force_recompute` aborts it).
    Inflight,
}

impl<T> CachedState<T> {
    /// Reports whether this snapshot is [`CachedState::Inflight`].
    #[must_use]
    pub fn is_inflight(&self) -> bool {
        matches!(self, Self::Inflight)
    }

    /// Borrows the cached value; `None` for the empty and inflight states.
    #[must_use]
    pub fn get(&self) -> Option<&T> {
        match self {
            Self::ValueCached(cached) => Some(cached),
            Self::EmptyCache | Self::Inflight => None,
        }
    }

    /// Mutably borrows the cached value; `None` for the empty and inflight states.
    #[must_use]
    pub fn get_mut(&mut self) -> Option<&mut T> {
        match self {
            Self::ValueCached(cached) => Some(cached),
            Self::EmptyCache | Self::Inflight => None,
        }
    }
}
158
/// Shared state of a running computation: the handle that aborts it and the
/// sender that broadcasts its result to subscribed tasks.
type InflightComputation<T, E> = (AbortHandle, Sender<Result<T, Error<E>>>);

/// Internal state of a [`Cached`] slot.
#[derive(Clone, Debug)]
enum CachedInner<T, E> {
    /// A computed value is stored.
    CachedValue(T),
    /// No value is stored. If the `Weak` can still be upgraded, a computation is
    /// running. The computing task owns the long-lived strong `Arc` (created in
    /// `compute_with_lock`); other strong references are only temporary upgrades.
    EmptyOrInflight(Weak<InflightComputation<T, E>>),
}
166
167impl<T, E> Default for CachedInner<T, E> {
168 fn default() -> Self {
169 CachedInner::new()
170 }
171}
172
173impl<T, E> CachedInner<T, E> {
174 #[must_use]
175 fn new() -> Self {
176 CachedInner::EmptyOrInflight(Weak::new())
177 }
178
179 #[must_use]
180 fn new_with_value(value: T) -> Self {
181 CachedInner::CachedValue(value)
182 }
183
184 fn invalidate(&mut self) -> Option<T> {
185 if matches!(self, CachedInner::EmptyOrInflight(_)) {
186 None
187 } else if let CachedInner::CachedValue(value) = std::mem::take(self) {
188 Some(value)
189 } else {
190 unreachable!()
191 }
192 }
193
194 fn is_inflight(&self) -> bool {
195 self.inflight_weak()
196 .map_or(false, |weak| weak.strong_count() > 0)
197 }
198
199 fn inflight_waiting_count(&self) -> usize {
200 self.inflight_arc()
201 .map_or(0, |arc| arc.1.receiver_count() + 1)
202 }
203
204 fn abort(&mut self) -> bool {
205 if let Some(arc) = self.inflight_arc() {
206 arc.0.abort();
207
208 *self = CachedInner::new();
210
211 true
212 } else {
213 false
214 }
215 }
216
217 #[must_use]
218 fn is_value_cached(&self) -> bool {
219 matches!(self, CachedInner::CachedValue(_))
220 }
221
222 #[must_use]
223 fn inflight_weak(&self) -> Option<&Weak<InflightComputation<T, E>>> {
224 if let CachedInner::EmptyOrInflight(weak) = self {
225 Some(weak)
226 } else {
227 None
228 }
229 }
230
231 #[must_use]
232 fn inflight_arc(&self) -> Option<Arc<InflightComputation<T, E>>> {
233 self.inflight_weak().and_then(Weak::upgrade)
234 }
235
236 #[must_use]
237 fn get(&self) -> Option<&T> {
238 if let CachedInner::CachedValue(value) = self {
239 Some(value)
240 } else {
241 None
242 }
243 }
244
245 #[must_use]
246 fn get_receiver(&self) -> Option<Receiver<Result<T, Error<E>>>> {
247 self.inflight_arc().map(|arc| arc.1.subscribe())
248 }
249}
250
251impl<T, E> Cached<T, E> {
252 #[must_use]
254 pub fn new() -> Self {
255 Self {
256 inner: Arc::new(Mutex::new(CachedInner::new())),
257 }
258 }
259
260 #[must_use]
262 pub fn new_with_value(value: T) -> Self {
263 Cached {
264 inner: Arc::new(Mutex::new(CachedInner::new_with_value(value))),
265 }
266 }
267
268 #[allow(clippy::must_use_candidate)]
270 pub fn invalidate(&self) -> Option<T> {
271 self.inner.lock().invalidate()
272 }
273
274 #[must_use]
276 pub fn is_inflight(&self) -> bool {
277 self.inner.lock().is_inflight()
278 }
279
280 #[must_use]
282 pub fn inflight_waiting_count(&self) -> usize {
283 self.inner.lock().inflight_waiting_count()
284 }
285
286 #[allow(clippy::must_use_candidate)]
292 pub fn abort(&self) -> bool {
293 self.inner.lock().abort()
294 }
295
296 #[must_use]
298 pub fn is_value_cached(&self) -> bool {
299 self.inner.lock().is_value_cached()
300 }
301}
302
303impl<T: Clone, E> Cached<T, E> {
304 #[must_use]
306 pub fn get(&self) -> Option<T> {
307 self.inner.lock().get().cloned()
308 }
309}
310
/// Outcome of looking up a [`Cached`] without starting a computation.
enum GetOrSubscribeResult<'a, T, E> {
    /// A clone of the cached value, or the result received from an inflight
    /// computation.
    Success(Result<T, Error<E>>),
    /// Nothing cached and nothing inflight. Carries the still-held lock so the
    /// caller can register a new computation before any other task observes
    /// the empty slot.
    FailureKeepLock(MutexGuard<'a, CachedInner<T, E>>),
}
315
316impl<T, E> Cached<T, E>
317where
318 T: Clone,
319 E: Clone,
320{
321 #[allow(clippy::await_holding_lock)] pub async fn get_or_compute<Fut>(
347 &self,
348 computation: impl FnOnce() -> Fut,
349 ) -> Result<T, Error<E>>
350 where
351 Fut: Future<Output = Result<T, E>>,
352 {
353 let inner = match self.get_or_subscribe_keep_lock().await {
354 GetOrSubscribeResult::Success(res) => return res,
355 GetOrSubscribeResult::FailureKeepLock(lock) => lock,
356 };
357
358 self.compute_with_lock(computation, inner).await.unwrap()
360 }
361
362 pub async fn get_or_subscribe(&self) -> Option<Result<T, Error<E>>> {
378 if let GetOrSubscribeResult::Success(res) = self.get_or_subscribe_keep_lock().await {
379 Some(res)
380 } else {
381 None
382 }
383 }
384
385 #[allow(clippy::await_holding_lock)] pub async fn subscribe_or_recompute<Fut>(
410 &self,
411 computation: impl FnOnce() -> Fut,
412 ) -> (Option<T>, Result<T, Error<E>>)
413 where
414 Fut: Future<Output = Result<T, E>>,
415 {
416 let mut inner = self.inner.lock();
417
418 if let Some(mut receiver) = inner.get_receiver() {
419 drop(inner);
420
421 (
423 None,
424 match receiver.recv().await {
425 Err(why) => Err(Error::from(why)),
426 Ok(res) => res,
427 },
428 )
429 } else {
430 let prev = inner.invalidate();
431
432 let result = self.compute_with_lock(computation, inner).await.unwrap();
434
435 (prev, result)
436 }
437 }
438
439 #[allow(clippy::await_holding_lock)] pub async fn force_recompute<Fut>(
459 &self,
460 computation: Fut,
461 ) -> (CachedState<T>, Result<T, Error<E>>)
462 where
463 Fut: Future<Output = Result<T, E>>,
464 {
465 let mut inner = self.inner.lock();
466
467 let aborted = inner.abort();
468 let prev_cache = inner.invalidate();
469
470 let prev_state = match (aborted, prev_cache) {
471 (false, None) => CachedState::EmptyCache,
472 (false, Some(val)) => CachedState::ValueCached(val),
473 (true, None) => CachedState::Inflight,
474 (true, Some(_)) => unreachable!(),
475 };
476
477 let result = self.compute_with_lock(|| computation, inner).await.unwrap();
479
480 (prev_state, result)
481 }
482
483 #[allow(clippy::await_holding_lock)] async fn get_or_subscribe_keep_lock(&self) -> GetOrSubscribeResult<'_, T, E> {
487 let inner = self.inner.lock();
489
490 if let CachedInner::CachedValue(value) = &*inner {
492 return GetOrSubscribeResult::Success(Ok(value.clone()));
493 }
494
495 let Some(mut receiver) = inner.get_receiver() else {
496 return GetOrSubscribeResult::FailureKeepLock(inner);
497 };
498
499 drop(inner);
500
501 let result = receiver.recv().await;
502
503 GetOrSubscribeResult::Success(match result {
504 Err(why) => Err(Error::from(why)),
505 Ok(res) => res,
506 })
507 }
508
509 #[allow(clippy::await_holding_lock)] async fn compute_with_lock<'a, Fut>(
512 &'a self,
513 computation: impl FnOnce() -> Fut,
514 mut inner: MutexGuard<'a, CachedInner<T, E>>,
515 ) -> Option<Result<T, Error<E>>>
516 where
517 Fut: Future<Output = Result<T, E>>,
518 {
519 if inner.is_value_cached() || inner.is_inflight() {
521 return None;
522 }
523
524 let (tx, _) = broadcast::channel(1);
527
528 let (abort_handle, abort_registration) = AbortHandle::new_pair();
529
530 let arc = Arc::new((abort_handle, tx));
531
532 *inner = CachedInner::EmptyOrInflight(Arc::downgrade(&arc));
534
535 drop(inner);
537
538 let future = computation();
540
541 let res = match Abortable::new(future, abort_registration).await {
542 Ok(res) => res.map_err(Error::Computation),
543 Err(aborted) => Err(Error::from(aborted)),
544 };
545
546 {
547 let mut inner = self.inner.lock();
549
550 if !matches!(res, Err(Error::Aborted(_))) {
551 if let Ok(value) = &res {
554 *inner = CachedInner::CachedValue(value.clone());
555 } else {
556 *inner = CachedInner::new();
557 }
558 }
559 }
560
561 if arc.1.receiver_count() > 0 {
565 arc.1.send(res.clone()).ok();
567 }
568
569 Some(res)
570 }
571}
572
// Behavioural tests for `Cached`, driven on the tokio test runtime.
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Notify;
    use tokio::task::JoinHandle;

    use crate::CachedState;

    use super::{Cached, Error};

    #[tokio::test]
    async fn test_cached() {
        let cached = Cached::<_, ()>::new_with_value(12);
        assert_eq!(cached.get(), Some(12));
        assert!(!cached.is_inflight());
        assert!(cached.is_value_cached());
        assert_eq!(cached.inflight_waiting_count(), 0);

        let cached = Cached::new();
        assert_eq!(cached.get(), None);
        assert!(!cached.is_inflight());
        assert!(!cached.is_value_cached());
        assert_eq!(cached.inflight_waiting_count(), 0);

        assert_eq!(cached.get_or_compute(|| async { Ok(12) }).await, Ok(12));
        assert_eq!(cached.get(), Some(12));

        assert_eq!(cached.invalidate(), Some(12));
        assert_eq!(cached.get(), None);
        assert_eq!(cached.invalidate(), None);

        assert_eq!(
            cached.get_or_compute(|| async { Err(42) }).await,
            Err(Error::Computation(42)),
        );
        assert_eq!(cached.get(), None);

        assert_eq!(cached.get_or_compute(|| async { Ok(1) }).await, Ok(1));
        assert_eq!(cached.get(), Some(1));
        assert_eq!(cached.get_or_compute(|| async { Ok(32) }).await, Ok(1));

        assert_eq!(cached.invalidate(), Some(1));

        let (tokio_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(30)).await;

        assert_eq!(cached.get(), None);

        assert!(cached.is_inflight());
        assert_eq!(cached.inflight_waiting_count(), 1);

        let other_handle = {
            let cached = Cached::clone(&cached);

            tokio::spawn(async move { cached.get_or_compute(|| async move { Ok(24) }).await })
        };

        tokio_notify.notify_waiters();

        assert_eq!(handle.await.unwrap(), Ok(30));
        assert_eq!(other_handle.await.unwrap(), Ok(30));
        assert_eq!(cached.get(), Some(30));
    }

    #[tokio::test]
    async fn test_computation_panic() {
        let cached = Cached::<_, ()>::new();

        let is_panic = {
            let cached = Cached::clone(&cached);

            tokio::spawn(async move {
                cached
                    .get_or_compute(|| {
                        panic!("Panic in computation");
                        #[allow(unreachable_code)]
                        async {
                            unreachable!()
                        }
                    })
                    .await
            })
        }
        .await
        .expect_err("Should panic")
        .is_panic();

        assert!(is_panic, "Should panic");

        assert_eq!(cached.get(), None);
        assert!(!cached.is_inflight());
        assert_eq!(cached.inflight_waiting_count(), 0);

        assert_eq!(
            cached.get_or_compute(|| async move { Ok(21) }).await,
            Ok(21),
        );

        assert_eq!(cached.invalidate(), Some(21));

        let is_panic = {
            let cached = Cached::clone(&cached);

            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async { panic!("Panic in future") })
                    .await
            })
        }
        .await
        .expect_err("Should be panic")
        .is_panic();

        assert!(is_panic, "Should panic");

        assert_eq!(cached.get(), None);
        assert!(!cached.is_inflight());
        assert_eq!(cached.inflight_waiting_count(), 0);

        assert_eq!(
            cached.get_or_compute(|| async move { Ok(17) }).await,
            Ok(17),
        );

        assert_eq!(cached.invalidate(), Some(17));

        let tokio_notify = Arc::new(Notify::new());
        let registered = Arc::new(Notify::new());
        let registered_fut = registered.notified();

        let panicking_handle = {
            let cached = Cached::clone(&cached);
            let tokio_notify = Arc::clone(&tokio_notify);
            let registered = Arc::clone(&registered);

            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async move {
                        let notify_fut = tokio_notify.notified();
                        registered.notify_waiters();
                        notify_fut.await;
                        panic!("Panic in future")
                    })
                    .await
            })
        };

        registered_fut.await;

        let waiting_handle = {
            let cached = Cached::clone(&cached);

            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async {
                        panic!("Entered computation when another inflight computation should already be running")
                    })
                    .await
            })
        };

        while cached.inflight_waiting_count() < 2 {
            tokio::task::yield_now().await;
        }

        tokio_notify.notify_waiters();

        assert!(panicking_handle.await.unwrap_err().is_panic());
        assert!(matches!(waiting_handle.await, Ok(Err(Error::Broadcast(_)))));
        assert_eq!(cached.get(), None);
    }

    #[tokio::test]
    async fn test_computation_drop() {
        let cached = Cached::<_, ()>::new();

        let computing = Arc::new(Notify::new());
        let computing_fut = computing.notified();

        let dropping_handle = {
            let cached = Cached::clone(&cached);
            let computing = Arc::clone(&computing);

            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async move {
                        computing.notify_waiters();
                        loop {
                            tokio::time::sleep(Duration::from_secs(1)).await;
                        }
                    })
                    .await
            })
        };

        computing_fut.await;

        let waiting_handle = {
            let cached = Cached::clone(&cached);

            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async {
                        panic!("Entered computation when another inflight computation should already be running");
                    })
                    .await
            })
        };

        while cached.inflight_waiting_count() < 2 {
            tokio::task::yield_now().await;
        }

        dropping_handle.abort();

        assert!(dropping_handle.await.unwrap_err().is_cancelled());
        assert!(matches!(waiting_handle.await, Ok(Err(Error::Broadcast(_)))));
        assert_eq!(cached.get(), None);
        assert_eq!(cached.get_or_compute(|| async { Ok(3) }).await, Ok(3));
        assert_eq!(cached.get(), Some(3));
    }

    #[tokio::test]
    async fn test_get_or_subscribe() {
        let cached = Cached::<_, ()>::new();

        assert_eq!(cached.get_or_subscribe().await, None);

        assert_eq!(cached.get_or_compute(|| async { Ok(0) }).await, Ok(0));
        assert_eq!(cached.get_or_subscribe().await, Some(Ok(0)));

        cached.invalidate();

        let (tokio_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(30)).await;

        assert!(cached.is_inflight());

        let get_or_subscribe_handle = {
            let cached = Cached::clone(&cached);

            tokio::spawn(async move { cached.get_or_subscribe().await })
        };

        tokio_notify.notify_waiters();

        assert_eq!(handle.await.unwrap(), Ok(30));
        assert_eq!(get_or_subscribe_handle.await.unwrap(), Some(Ok(30)));
        assert_eq!(cached.get(), Some(30));
    }

    #[tokio::test]
    async fn test_subscribe_or_recompute() {
        let cached = Cached::new();

        assert_eq!(
            cached.subscribe_or_recompute(|| async { Err(()) }).await,
            (None, Err(Error::Computation(()))),
        );
        assert_eq!(cached.get(), None);

        assert_eq!(
            cached.subscribe_or_recompute(|| async { Ok(0) }).await,
            (None, Ok(0)),
        );
        assert_eq!(cached.get(), Some(0));

        assert_eq!(
            cached.subscribe_or_recompute(|| async { Ok(30) }).await,
            (Some(0), Ok(30)),
        );
        assert_eq!(cached.get(), Some(30));

        assert_eq!(
            cached.subscribe_or_recompute(|| async { Err(()) }).await,
            (Some(30), Err(Error::Computation(()))),
        );
        assert_eq!(cached.get(), None);

        let (notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(12)).await;

        let second_handle = {
            let cached = Cached::clone(&cached);

            tokio::spawn(async move {
                cached
                    .subscribe_or_recompute(|| async {
                        panic!("Shouldn't execute, already inflight")
                    })
                    .await
            })
        };

        notify.notify_waiters();

        assert_eq!(handle.await.unwrap(), Ok(12));
        assert_eq!(second_handle.await.unwrap(), (None, Ok(12)));
        assert_eq!(cached.get(), Some(12));
    }

    #[tokio::test]
    async fn test_force_recompute() {
        let cached = Cached::<_, ()>::new();

        assert_eq!(
            cached.force_recompute(async { Err(()) }).await,
            (CachedState::EmptyCache, Err(Error::Computation(()))),
        );
        assert_eq!(cached.get(), None);
        assert_eq!(
            cached.force_recompute(async { Ok(0) }).await,
            (CachedState::EmptyCache, Ok(0))
        );
        assert_eq!(cached.get(), Some(0));

        assert_eq!(
            cached.force_recompute(async { Ok(15) }).await,
            (CachedState::ValueCached(0), Ok(15)),
        );
        assert_eq!(cached.get(), Some(15));
        assert_eq!(
            cached.force_recompute(async { Err(()) }).await,
            (CachedState::ValueCached(15), Err(Error::Computation(()))),
        );
        assert_eq!(cached.get(), None);

        let (_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(0)).await;

        assert_eq!(
            cached.force_recompute(async { Ok(21) }).await,
            (CachedState::Inflight, Ok(21))
        );
        assert!(matches!(handle.await.unwrap(), Err(Error::Aborted(_))));
        assert_eq!(cached.get(), Some(21));
    }

    #[tokio::test]
    async fn test_abort() {
        let cached = Cached::<_, ()>::new();

        assert!(!cached.abort());

        assert_eq!(cached.get(), None);
        let (_notify, handle) = setup_inflight_request(Cached::clone(&cached), Ok(0)).await;

        assert!(cached.abort());
        assert!(!cached.is_inflight());

        assert!(matches!(handle.await.unwrap(), Err(Error::Aborted(_))));
        assert_eq!(cached.get(), None);
        assert_eq!(cached.inflight_waiting_count(), 0);
    }

    /// Starts a `get_or_compute` on `cached` whose computation blocks until the
    /// returned `Notify` fires, then resolves to `result`.
    ///
    /// Returns only once the computation is running, so the cache is inflight.
    async fn setup_inflight_request<T, E>(
        cached: Cached<T, E>,
        result: Result<T, E>,
    ) -> (Arc<Notify>, JoinHandle<Result<T, Error<E>>>)
    where
        T: Clone + Send + 'static,
        E: Clone + Send + 'static,
    {
        assert!(!cached.is_inflight());
        assert!(!cached.is_value_cached());

        let tokio_notify = Arc::new(Notify::new());
        let registered = Arc::new(Notify::new());
        let registered_fut = registered.notified();

        let handle = {
            let tokio_notify = Arc::clone(&tokio_notify);
            let registered = Arc::clone(&registered);
            let cached = Cached::clone(&cached);

            tokio::spawn(async move {
                cached
                    .get_or_compute(|| async move {
                        let notified_fut = tokio_notify.notified();
                        registered.notify_waiters();
                        notified_fut.await;
                        result
                    })
                    .await
            })
        };

        registered_fut.await;

        (tokio_notify, handle)
    }
}