reactive_graph/computed/async_derived/
arc_async_derived.rs

1use super::{
2    inner::{ArcAsyncDerivedInner, AsyncDerivedState},
3    AsyncDerivedReadyFuture, ScopedFuture,
4};
5#[cfg(feature = "sandboxed-arenas")]
6use crate::owner::Sandboxed;
7use crate::{
8    channel::channel,
9    computed::suspense::SuspenseContext,
10    diagnostics::SpecialNonReactiveFuture,
11    graph::{
12        AnySource, AnySubscriber, ReactiveNode, Source, SourceSet, Subscriber,
13        SubscriberSet, ToAnySource, ToAnySubscriber, WithObserver,
14    },
15    owner::{use_context, Owner},
16    signal::{
17        guards::{AsyncPlain, ReadGuard, WriteGuard},
18        ArcTrigger,
19    },
20    traits::{
21        DefinedAt, IsDisposed, Notify, ReadUntracked, Track, UntrackableGuard,
22        Write,
23    },
24    transition::AsyncTransition,
25};
26use any_spawner::Executor;
27use async_lock::RwLock as AsyncRwLock;
28use core::fmt::Debug;
29use futures::{channel::oneshot, FutureExt, StreamExt};
30use or_poisoned::OrPoisoned;
31use send_wrapper::SendWrapper;
32use std::{
33    future::Future,
34    mem,
35    ops::DerefMut,
36    panic::Location,
37    sync::{
38        atomic::{AtomicBool, Ordering},
39        Arc, RwLock, Weak,
40    },
41    task::Waker,
42};
43
/// A reactive value that is derived by running an asynchronous computation in response to changes
/// in its sources.
///
/// When one of its dependencies changes, this will re-run its async computation, then notify other
/// values that depend on it that it has changed.
///
/// This is a reference-counted type, which is `Clone` but not `Copy`.
/// For the arena-allocated `Copy` equivalent, use [`AsyncDerived`](super::AsyncDerived).
///
/// ## Examples
/// ```rust
/// # use reactive_graph::computed::*;
/// # use reactive_graph::signal::*; let owner = reactive_graph::owner::Owner::new(); owner.set();
/// # use reactive_graph::prelude::*;
/// # tokio_test::block_on(async move {
/// # any_spawner::Executor::init_tokio(); let owner = reactive_graph::owner::Owner::new(); owner.set();
/// # let _guard = reactive_graph::diagnostics::SpecialNonReactiveZone::enter();
///
/// let signal1 = RwSignal::new(0);
/// let signal2 = RwSignal::new(0);
/// let derived = ArcAsyncDerived::new(move || async move {
///   // reactive values can be tracked anywhere in the `async` block
///   let value1 = signal1.get();
///   tokio::time::sleep(std::time::Duration::from_millis(25)).await;
///   let value2 = signal2.get();
///
///   value1 + value2
/// });
///
/// // the value can be accessed synchronously as `Option<T>`
/// assert_eq!(derived.get(), None);
/// // we can also .await the value, i.e., convert it into a Future
/// assert_eq!(derived.clone().await, 0);
/// assert_eq!(derived.get(), Some(0));
///
/// signal1.set(1);
/// // while the new value is still pending, the signal holds the old value
/// tokio::time::sleep(std::time::Duration::from_millis(5)).await;
/// assert_eq!(derived.get(), Some(0));
///
/// // setting multiple dependencies will hold until the latest change is ready
/// signal2.set(1);
/// assert_eq!(derived.await, 2);
/// # });
/// ```
///
/// ## Core Trait Implementations
/// - [`.get()`](crate::traits::Get) clones the current value as an `Option<T>`.
///   If you call it within an effect, it will cause that effect to subscribe
///   to this value, and to re-run whenever it changes.
///   - [`.get_untracked()`](crate::traits::GetUntracked) clones the value
///     without reactively tracking it.
/// - [`.read()`](crate::traits::Read) returns a guard that allows accessing the
///   value by reference. If you call it within an effect, it will
///   cause that effect to subscribe to this value, and to re-run whenever the
///   value changes.
///   - [`.read_untracked()`](crate::traits::ReadUntracked) gives access to the
///     current value without reactively tracking it.
/// - [`.with()`](crate::traits::With) allows you to reactively access the
///   value without cloning by applying a callback function.
///   - [`.with_untracked()`](crate::traits::WithUntracked) allows you to access
///     the value by applying a callback function without reactively
///     tracking it.
/// - [`IntoFuture`](std::future::IntoFuture) allows you to create a [`Future`] that resolves
///   when this resource is done loading.
pub struct ArcAsyncDerived<T> {
    // where this value was created; only stored in debug/debuginfo builds
    #[cfg(any(debug_assertions, leptos_debuginfo))]
    pub(crate) defined_at: &'static Location<'static>,
    // the current state of this signal: `None` until a value has been set
    pub(crate) value: Arc<AsyncRwLock<Option<T>>>,
    // holds wakers generated when you .await this; they are drained and woken
    // whenever subscribers are notified
    pub(crate) wakers: Arc<RwLock<Vec<Waker>>>,
    // shared reactive-graph node: owner, notifier, sources, subscribers,
    // dirty state, run version and pending suspense contexts
    pub(crate) inner: Arc<RwLock<ArcAsyncDerivedInner>>,
    // `true` while a run of the async computation is in flight; reset to
    // `false` when subscribers are notified
    pub(crate) loading: Arc<AtomicBool>,
}
119
/// Synchronous ways of acquiring a read or write guard on an async lock, for
/// call sites that cannot `.await`.
///
/// The implementation for [`AsyncRwLock`] in this file delegates to the
/// lock's `*_blocking` methods on non-wasm targets; on wasm targets it polls
/// the async acquisition exactly once and panics if the guard is not
/// immediately available.
// NOTE(review): `dead_code` is presumably allowed because not every method is
// used under every feature/target combination — confirm.
#[allow(dead_code)]
pub(crate) trait BlockingLock<T> {
    /// Acquires an owned (`Arc`-backed) read guard without `.await`ing.
    fn blocking_read_arc(self: &Arc<Self>)
        -> async_lock::RwLockReadGuardArc<T>;

    /// Acquires an owned (`Arc`-backed) write guard without `.await`ing.
    fn blocking_write_arc(
        self: &Arc<Self>,
    ) -> async_lock::RwLockWriteGuardArc<T>;

    /// Acquires a borrowed read guard without `.await`ing.
    fn blocking_read(&self) -> async_lock::RwLockReadGuard<'_, T>;

    /// Acquires a borrowed write guard without `.await`ing.
    fn blocking_write(&self) -> async_lock::RwLockWriteGuard<'_, T>;
}
133
134impl<T> BlockingLock<T> for AsyncRwLock<T> {
135    fn blocking_read_arc(
136        self: &Arc<Self>,
137    ) -> async_lock::RwLockReadGuardArc<T> {
138        #[cfg(not(target_family = "wasm"))]
139        {
140            self.read_arc_blocking()
141        }
142        #[cfg(target_family = "wasm")]
143        {
144            self.read_arc().now_or_never().unwrap()
145        }
146    }
147
148    fn blocking_write_arc(
149        self: &Arc<Self>,
150    ) -> async_lock::RwLockWriteGuardArc<T> {
151        #[cfg(not(target_family = "wasm"))]
152        {
153            self.write_arc_blocking()
154        }
155        #[cfg(target_family = "wasm")]
156        {
157            self.write_arc().now_or_never().unwrap()
158        }
159    }
160
161    fn blocking_read(&self) -> async_lock::RwLockReadGuard<'_, T> {
162        #[cfg(not(target_family = "wasm"))]
163        {
164            self.read_blocking()
165        }
166        #[cfg(target_family = "wasm")]
167        {
168            self.read().now_or_never().unwrap()
169        }
170    }
171
172    fn blocking_write(&self) -> async_lock::RwLockWriteGuard<'_, T> {
173        #[cfg(not(target_family = "wasm"))]
174        {
175            self.write_blocking()
176        }
177        #[cfg(target_family = "wasm")]
178        {
179            self.write().now_or_never().unwrap()
180        }
181    }
182}
183
184impl<T> Clone for ArcAsyncDerived<T> {
185    fn clone(&self) -> Self {
186        Self {
187            #[cfg(any(debug_assertions, leptos_debuginfo))]
188            defined_at: self.defined_at,
189            value: Arc::clone(&self.value),
190            wakers: Arc::clone(&self.wakers),
191            inner: Arc::clone(&self.inner),
192            loading: Arc::clone(&self.loading),
193        }
194    }
195}
196
197impl<T> Debug for ArcAsyncDerived<T> {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        let mut f = f.debug_struct("ArcAsyncDerived");
200        #[cfg(any(debug_assertions, leptos_debuginfo))]
201        f.field("defined_at", &self.defined_at);
202        f.finish_non_exhaustive()
203    }
204}
205
206impl<T> DefinedAt for ArcAsyncDerived<T> {
207    #[inline(always)]
208    fn defined_at(&self) -> Option<&'static Location<'static>> {
209        #[cfg(any(debug_assertions, leptos_debuginfo))]
210        {
211            Some(self.defined_at)
212        }
213        #[cfg(not(any(debug_assertions, leptos_debuginfo)))]
214        {
215            None
216        }
217    }
218}
219
// This helps create a derived async signal.
// It needs to be implemented as a macro because it needs to be flexible over
// whether `fun` returns a `Future` that is `Send`. Doing it as a function would,
// as far as I can tell, require repeating most of the function body.
//
// Arguments:
// - `$spawner`: callable that is handed the `Future` of the update loop
//   (only invoked when `$should_spawn` is `true`)
// - `$initial`: identifier bound to the initial `Option<T>` value
// - `$fun`: identifier bound to the closure that creates the `Future`
// - `$should_spawn`: whether the update loop is passed to `$spawner` at all
// - `$force_spawn`: when `true`, a `Some` initial value is NOT treated as
//   "ready", so the first `Future` is still polled
// - `$should_track`: selects `with_observer` (`true`) or
//   `with_observer_untracked` (`false`) when calling `$fun` and when checking
//   whether an update is necessary
// - `$source`: `Option` of a value with a `.track()` method, tracked once with
//   the new node as observer
//
// Evaluates to `(this, is_ready)`: the new `ArcAsyncDerived` and whether the
// initial value alone was considered ready.
macro_rules! spawn_derived {
    ($spawner:expr, $initial:ident, $fun:ident, $should_spawn:literal, $force_spawn:literal, $should_track:literal, $source:expr) => {{
        let (notifier, mut rx) = channel();

        // an initial value only counts as ready if a spawn is not being forced
        let is_ready = $initial.is_some() && !$force_spawn;

        let owner = Owner::new();
        let inner = Arc::new(RwLock::new(ArcAsyncDerivedInner {
            owner: owner.clone(),
            notifier,
            sources: SourceSet::new(),
            subscribers: SubscriberSet::new(),
            state: AsyncDerivedState::Clean,
            version: 0,
            suspenses: Vec::new()
        }));
        let value = Arc::new(AsyncRwLock::new($initial));
        let wakers = Arc::new(RwLock::new(Vec::new()));

        let this = ArcAsyncDerived {
            #[cfg(any(debug_assertions, leptos_debuginfo))]
            defined_at: Location::caller(),
            value: Arc::clone(&value),
            wakers,
            inner: Arc::clone(&inner),
            // starts out "loading" unless the initial value is already usable
            loading: Arc::new(AtomicBool::new(!is_ready)),
        };
        let any_subscriber = this.to_any_subscriber();
        // `$fun` is always called once here to build the first `Future`, inside
        // the new owner and with this node as the observer; whether that
        // `Future` is ever polled is decided below
        let initial_fut = if $should_track {
            owner.with_cleanup(|| {
                any_subscriber
                    .with_observer(|| ScopedFuture::new($fun()))
            })
        } else {
            owner.with_cleanup(|| {
                any_subscriber
                    .with_observer_untracked(|| ScopedFuture::new($fun()))
            })
        };
        #[cfg(feature = "sandboxed-arenas")]
        let initial_fut = Sandboxed::new(initial_fut);
        let mut initial_fut = Box::pin(initial_fut);

        // `was_ready`: a value is available synchronously (either the initial
        // value or the result of the single poll below).
        // `initial_fut`: `Some` only if the first `Future` is still pending and
        // must be driven to completion by the update loop.
        let (was_ready, mut initial_fut) = {
            if is_ready {
                (true, None)
            } else {
                // if we don't already know that it's ready, we need to poll once, initially
                // so that the correct value is set synchronously
                let initial = initial_fut.as_mut().now_or_never();
                match initial {
                    None => {
                        // still pending: send a notification so the update
                        // loop's first `rx.next()` resolves
                        inner.write().or_poisoned().notifier.notify();
                        (false, Some(initial_fut))
                    }
                    Some(orig_value) => {
                        let mut guard = this.inner.write().or_poisoned();

                        guard.state = AsyncDerivedState::Clean;
                        *value.blocking_write() = Some(orig_value);
                        this.loading.store(false, Ordering::Relaxed);
                        (true, None)
                    }
                }
            }
        };

        // `first_run` carries the sender that signals completion of the first
        // load; its receiver is only registered with `AsyncTransition` when the
        // value was not available synchronously
        let mut first_run = {
            let (ready_tx, ready_rx) = oneshot::channel();
            if !was_ready {
                AsyncTransition::register(ready_rx);
            }
            Some(ready_tx)
        };

        // nothing to signal if the value was already there
        if was_ready {
            first_run.take();
        }

        // manual dependency: track the explicit source with this node as observer
        if let Some(source) = $source {
            any_subscriber.with_observer(|| source.track());
        }

        if $should_spawn {
            $spawner({
                // the loop only holds weak references; it ends (see `break`
                // below) once any of them can no longer be upgraded
                let value = Arc::downgrade(&this.value);
                let inner = Arc::downgrade(&this.inner);
                let wakers = Arc::downgrade(&this.wakers);
                let loading = Arc::downgrade(&this.loading);
                let fut = async move {
                    // if the AsyncDerived has *already* been marked dirty (i.e., one of its
                    // sources has changed after creation), we should throw out the Future
                    // we already created, because its values might be stale
                    let already_dirty = inner.upgrade()
                        .as_ref()
                        .and_then(|inner| inner.read().ok())
                        .map(|inner| inner.state == AsyncDerivedState::Dirty)
                        .unwrap_or(false);
                    if already_dirty {
                        initial_fut.take();
                    }

                    // one iteration per notification received on the channel
                    while rx.next().await.is_some() {
                        // never update while the owner is paused
                        let update_if_necessary = !owner.paused() && if $should_track {
                            any_subscriber
                                .with_observer(|| any_subscriber.update_if_necessary())
                        } else {
                            any_subscriber
                                .with_observer_untracked(|| any_subscriber.update_if_necessary())
                        };
                        // the very first run happens even if no update is reported
                        if update_if_necessary || first_run.is_some() {
                            match (value.upgrade(), inner.upgrade(), wakers.upgrade(), loading.upgrade()) {
                                (Some(value), Some(inner), Some(wakers), Some(loading)) => {
                                    // generate new Future
                                    // (reusing the still-pending initial one, if there is one)
                                    let owner = inner.read().or_poisoned().owner.clone();
                                    let fut = initial_fut.take().unwrap_or_else(|| {
                                        let fut = if $should_track {
                                            owner.with_cleanup(|| {
                                                any_subscriber
                                                    .with_observer(|| ScopedFuture::new($fun()))
                                            })
                                        } else {
                                            owner.with_cleanup(|| {
                                                any_subscriber
                                                    .with_observer_untracked(|| ScopedFuture::new($fun()))
                                            })
                                        };
                                        #[cfg(feature = "sandboxed-arenas")]
                                        let fut = Sandboxed::new(fut);
                                        Box::pin(fut)
                                    });

                                    // register with global transition listener, if any
                                    let ready_tx = first_run.take().unwrap_or_else(|| {
                                        let (ready_tx, ready_rx) = oneshot::channel();
                                        if !was_ready {
                                            AsyncTransition::register(ready_rx);
                                        }
                                        ready_tx
                                    });

                                    // generate and assign new value
                                    loading.store(true, Ordering::Relaxed);

                                    // bump the version for this run, and take a
                                    // task id from every suspense context that
                                    // was recorded by readers; the ids are held
                                    // until the `Future` below has resolved
                                    let (this_version, suspense_ids) = {
                                        let mut guard = inner.write().or_poisoned();
                                        guard.version += 1;
                                        let version = guard.version;
                                        let suspense_ids = mem::take(&mut guard.suspenses)
                                            .into_iter()
                                            .map(|sc| sc.task_id())
                                            .collect::<Vec<_>>();
                                        (version, suspense_ids)
                                    };

                                    let new_value = fut.await;

                                    drop(suspense_ids);

                                    let latest_version = inner.read().or_poisoned().version;

                                    // only store the result if the version has
                                    // not changed while this run was awaiting
                                    if latest_version == this_version {
                                        Self::set_inner_value(new_value, value, wakers, inner, loading, Some(ready_tx)).await;
                                    }
                                }
                                _ => break,
                            }
                        }
                    }
                };

                #[cfg(feature = "sandboxed-arenas")]
                let fut = Sandboxed::new(fut);

                fut
            });
        }

        (this, is_ready)
    }};
}
405
406impl<T: 'static> ArcAsyncDerived<T> {
407    async fn set_inner_value(
408        new_value: T,
409        value: Arc<AsyncRwLock<Option<T>>>,
410        wakers: Arc<RwLock<Vec<Waker>>>,
411        inner: Arc<RwLock<ArcAsyncDerivedInner>>,
412        loading: Arc<AtomicBool>,
413        ready_tx: Option<oneshot::Sender<()>>,
414    ) {
415        *value.write().await = Some(new_value);
416        Self::notify_subs(&wakers, &inner, &loading, ready_tx);
417    }
418
419    fn notify_subs(
420        wakers: &Arc<RwLock<Vec<Waker>>>,
421        inner: &Arc<RwLock<ArcAsyncDerivedInner>>,
422        loading: &Arc<AtomicBool>,
423        ready_tx: Option<oneshot::Sender<()>>,
424    ) {
425        loading.store(false, Ordering::Relaxed);
426
427        let prev_state = mem::replace(
428            &mut inner.write().or_poisoned().state,
429            AsyncDerivedState::Notifying,
430        );
431
432        if let Some(ready_tx) = ready_tx {
433            // if it's an Err, that just means the Receiver was dropped
434            // we don't particularly care about that: the point is to notify if
435            // it still exists, but we don't need to know if Suspense is no
436            // longer listening
437            _ = ready_tx.send(());
438        }
439
440        // notify reactive subscribers that we're not loading any more
441        for sub in (&inner.read().or_poisoned().subscribers).into_iter() {
442            sub.mark_dirty();
443        }
444
445        // notify async .awaiters
446        for waker in mem::take(&mut *wakers.write().or_poisoned()) {
447            waker.wake();
448        }
449
450        // if this was marked dirty before notifications began, this means it
451        // had been notified while loading; marking it clean will cause it not to
452        // run on the next tick of the async loop, so here it should be left dirty
453        inner.write().or_poisoned().state = prev_state;
454    }
455}
456
457impl<T: 'static> ArcAsyncDerived<T> {
458    /// Creates a new async derived computation.
459    ///
460    /// This runs eagerly: i.e., calls `fun` once when created and immediately spawns the `Future`
461    /// as a new task.
462    #[track_caller]
463    pub fn new<Fut>(fun: impl Fn() -> Fut + Send + Sync + 'static) -> Self
464    where
465        T: Send + Sync + 'static,
466        Fut: Future<Output = T> + Send + 'static,
467    {
468        Self::new_with_initial(None, fun)
469    }
470
471    /// Creates a new async derived computation with an initial value as a fallback, and begins running the
472    /// `Future` eagerly to get the actual first value.
473    #[track_caller]
474    pub fn new_with_initial<Fut>(
475        initial_value: Option<T>,
476        fun: impl Fn() -> Fut + Send + Sync + 'static,
477    ) -> Self
478    where
479        T: Send + Sync + 'static,
480        Fut: Future<Output = T> + Send + 'static,
481    {
482        let (this, _) = spawn_derived!(
483            Executor::spawn,
484            initial_value,
485            fun,
486            true,
487            true,
488            true,
489            None::<ArcTrigger>
490        );
491        this
492    }
493
494    /// Creates a new async derived computation with an initial value, and does not spawn a task
495    /// initially.
496    ///
497    /// This is mostly used with manual dependency tracking, for primitives built on top of this
498    /// where you do not want to run the run the `Future` unnecessarily.
499    #[doc(hidden)]
500    #[track_caller]
501    pub fn new_with_manual_dependencies<Fut, S>(
502        initial_value: Option<T>,
503        fun: impl Fn() -> Fut + Send + Sync + 'static,
504        source: &S,
505    ) -> Self
506    where
507        T: Send + Sync + 'static,
508        Fut: Future<Output = T> + Send + 'static,
509        S: Track,
510    {
511        let (this, _) = spawn_derived!(
512            Executor::spawn,
513            initial_value,
514            fun,
515            true,
516            false,
517            false,
518            Some(source)
519        );
520        this
521    }
522
523    /// Creates a new async derived computation that will be guaranteed to run on the current
524    /// thread.
525    ///
526    /// This runs eagerly: i.e., calls `fun` once when created and immediately spawns the `Future`
527    /// as a new task.
528    #[track_caller]
529    pub fn new_unsync<Fut>(fun: impl Fn() -> Fut + 'static) -> Self
530    where
531        T: 'static,
532        Fut: Future<Output = T> + 'static,
533    {
534        Self::new_unsync_with_initial(None, fun)
535    }
536
537    /// Creates a new async derived computation with an initial value as a fallback, and begins running the
538    /// `Future` eagerly to get the actual first value.
539    #[track_caller]
540    pub fn new_unsync_with_initial<Fut>(
541        initial_value: Option<T>,
542        fun: impl Fn() -> Fut + 'static,
543    ) -> Self
544    where
545        T: 'static,
546        Fut: Future<Output = T> + 'static,
547    {
548        let (this, _) = spawn_derived!(
549            Executor::spawn_local,
550            initial_value,
551            fun,
552            true,
553            true,
554            true,
555            None::<ArcTrigger>
556        );
557        this
558    }
559
560    /// Returns a `Future` that is ready when this resource has next finished loading.
561    pub fn ready(&self) -> AsyncDerivedReadyFuture {
562        AsyncDerivedReadyFuture::new(
563            self.to_any_source(),
564            &self.loading,
565            &self.wakers,
566        )
567    }
568}
569
570impl<T: 'static> ArcAsyncDerived<SendWrapper<T>> {
571    #[doc(hidden)]
572    #[track_caller]
573    pub fn new_mock<Fut>(fun: impl Fn() -> Fut + 'static) -> Self
574    where
575        T: 'static,
576        Fut: Future<Output = T> + 'static,
577    {
578        let initial = None::<SendWrapper<T>>;
579        let fun = move || {
580            let fut = fun();
581            async move {
582                let value = fut.await;
583                SendWrapper::new(value)
584            }
585        };
586        let (this, _) = spawn_derived!(
587            Executor::spawn_local,
588            initial,
589            fun,
590            false,
591            false,
592            true,
593            None::<ArcTrigger>
594        );
595        this
596    }
597}
598
599impl<T: 'static> ReadUntracked for ArcAsyncDerived<T> {
600    type Value = ReadGuard<Option<T>, AsyncPlain<Option<T>>>;
601
602    fn try_read_untracked(&self) -> Option<Self::Value> {
603        if let Some(suspense_context) = use_context::<SuspenseContext>() {
604            let handle = suspense_context.task_id();
605            let ready = SpecialNonReactiveFuture::new(self.ready());
606            crate::spawn(async move {
607                ready.await;
608                drop(handle);
609            });
610            self.inner
611                .write()
612                .or_poisoned()
613                .suspenses
614                .push(suspense_context);
615        }
616        AsyncPlain::try_new(&self.value).map(ReadGuard::new)
617    }
618}
619
620impl<T: 'static> Notify for ArcAsyncDerived<T> {
621    fn notify(&self) {
622        Self::notify_subs(&self.wakers, &self.inner, &self.loading, None);
623    }
624}
625
626impl<T: 'static> Write for ArcAsyncDerived<T> {
627    type Value = Option<T>;
628
629    fn try_write(&self) -> Option<impl UntrackableGuard<Target = Self::Value>> {
630        Some(WriteGuard::new(self.clone(), self.value.blocking_write()))
631    }
632
633    fn try_write_untracked(
634        &self,
635    ) -> Option<impl DerefMut<Target = Self::Value>> {
636        Some(self.value.blocking_write())
637    }
638}
639
impl<T: 'static> IsDisposed for ArcAsyncDerived<T> {
    /// Always returns `false`.
    // NOTE(review): presumably because a reference-counted handle keeps its
    // own data alive and so is never "disposed" — confirm.
    #[inline(always)]
    fn is_disposed(&self) -> bool {
        false
    }
}
646
647impl<T: 'static> ToAnySource for ArcAsyncDerived<T> {
648    fn to_any_source(&self) -> AnySource {
649        AnySource(
650            Arc::as_ptr(&self.inner) as usize,
651            Arc::downgrade(&self.inner) as Weak<dyn Source + Send + Sync>,
652            #[cfg(any(debug_assertions, leptos_debuginfo))]
653            self.defined_at,
654        )
655    }
656}
657
658impl<T: 'static> ToAnySubscriber for ArcAsyncDerived<T> {
659    fn to_any_subscriber(&self) -> AnySubscriber {
660        AnySubscriber(
661            Arc::as_ptr(&self.inner) as usize,
662            Arc::downgrade(&self.inner) as Weak<dyn Subscriber + Send + Sync>,
663        )
664    }
665}
666
/// `Source` implementation: every method forwards to the shared `inner` node.
impl<T> Source for ArcAsyncDerived<T> {
    /// Forwards to `inner.add_subscriber`.
    fn add_subscriber(&self, subscriber: AnySubscriber) {
        self.inner.add_subscriber(subscriber);
    }

    /// Forwards to `inner.remove_subscriber`.
    fn remove_subscriber(&self, subscriber: &AnySubscriber) {
        self.inner.remove_subscriber(subscriber);
    }

    /// Forwards to `inner.clear_subscribers`.
    fn clear_subscribers(&self) {
        self.inner.clear_subscribers();
    }
}
680
/// `ReactiveNode` implementation: every method forwards to the shared `inner`
/// node.
impl<T> ReactiveNode for ArcAsyncDerived<T> {
    /// Forwards to `inner.mark_dirty`.
    fn mark_dirty(&self) {
        self.inner.mark_dirty();
    }

    /// Forwards to `inner.mark_check`.
    fn mark_check(&self) {
        self.inner.mark_check();
    }

    /// Forwards to `inner.mark_subscribers_check`.
    fn mark_subscribers_check(&self) {
        self.inner.mark_subscribers_check();
    }

    /// Forwards to `inner.update_if_necessary` and returns its result.
    fn update_if_necessary(&self) -> bool {
        self.inner.update_if_necessary()
    }
}
698
/// `Subscriber` implementation: every method forwards to the shared `inner`
/// node.
impl<T> Subscriber for ArcAsyncDerived<T> {
    /// Forwards to `inner.add_source`.
    fn add_source(&self, source: AnySource) {
        self.inner.add_source(source);
    }

    /// Forwards to `inner.clear_sources`.
    fn clear_sources(&self, subscriber: &AnySubscriber) {
        self.inner.clear_sources(subscriber);
    }
}
707}