n0_future/
task.rs

1//! Async rust task spawning and utilities that work natively (using tokio) and in browsers
2//! (using wasm-bindgen-futures).
3
4#[cfg(not(wasm_browser))]
5pub use tokio::spawn;
6#[cfg(not(wasm_browser))]
7pub use tokio::task::{AbortHandle, Id, JoinError, JoinHandle, JoinSet};
8#[cfg(not(wasm_browser))]
9pub use tokio_util::task::AbortOnDropHandle;
10#[cfg(wasm_browser)]
11pub use wasm::*;
12
13#[cfg(wasm_browser)]
14mod wasm {
15    use std::{
16        cell::RefCell,
17        fmt::{self, Debug},
18        future::{Future, IntoFuture},
19        pin::Pin,
20        rc::Rc,
21        sync::Mutex,
22        task::{Context, Poll, Waker},
23    };
24
25    use futures_lite::{stream::StreamExt, FutureExt};
26    use send_wrapper::SendWrapper;
27
28    static TASK_ID_COUNTER: Mutex<u64> = Mutex::new(0);
29
30    fn next_task_id() -> u64 {
31        let mut counter = TASK_ID_COUNTER.lock().unwrap();
32        *counter += 1;
33        *counter
34    }
35
36    /// An opaque ID that uniquely identifies a task relative to all other currently running tasks.
37    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, derive_more::Display)]
38    pub struct Id(u64);
39
    /// Wasm shim for tokio's `JoinSet`.
    ///
    /// Uses a [`futures_buffered::FuturesUnordered`] queue of
    /// [`JoinHandle`]s inside.
    pub struct JoinSet<T> {
        // One handle per spawned task whose result has not been taken yet; polling this queue
        // is how `join_next` and friends obtain finished results.
        handles: futures_buffered::FuturesUnordered<JoinHandleWithId<T>>,
        // We need to keep a second list of JoinHandles so we can access them for cancellation.
        // Handles of finished tasks are pruned from it whenever the set is polled.
        to_cancel: Vec<JoinHandle<T>>,
    }
49
50    impl<T> Debug for JoinSet<T> {
51        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52            f.debug_struct("JoinSet").field("len", &self.len()).finish()
53        }
54    }
55
56    impl<T> Default for JoinSet<T> {
57        fn default() -> Self {
58            Self::new()
59        }
60    }
61
62    impl<T> JoinSet<T> {
63        /// Creates a new, empty `JoinSet`
64        pub fn new() -> Self {
65            Self {
66                handles: futures_buffered::FuturesUnordered::new(),
67                to_cancel: Vec::new(),
68            }
69        }
70
71        /// Spawns a task into this `JoinSet`.
72        pub fn spawn(&mut self, fut: impl IntoFuture<Output = T> + 'static) -> AbortHandle
73        where
74            T: 'static,
75        {
76            let handle = JoinHandle::new();
77            let state = handle.task.state.clone();
78            let handle_for_spawn = JoinHandle {
79                task: handle.task.clone(),
80            };
81            let handle_for_cancel = JoinHandle {
82                task: handle.task.clone(),
83            };
84
85            wasm_bindgen_futures::spawn_local(SpawnFuture {
86                handle: handle_for_spawn,
87                fut: fut.into_future(),
88            });
89
90            self.handles.push(JoinHandleWithId(handle));
91            self.to_cancel.push(handle_for_cancel);
92            AbortHandle { state }
93        }
94
95        /// Alias to [`Self::spawn`].
96        ///
97        /// Mirrors [`tokio::JoinSet::spawn_local](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.spawn_local).
98        /// Because all tasks in WebAssembly are local, this is a simple alias to [`Self::spawn`].
99        pub fn spawn_local(&mut self, fut: impl IntoFuture<Output = T> + 'static) -> AbortHandle
100        where
101            T: 'static,
102        {
103            self.spawn(fut)
104        }
105
106        /// Aborts all tasks inside this `JoinSet`
107        pub fn abort_all(&self) {
108            self.to_cancel.iter().for_each(JoinHandle::abort);
109        }
110
111        /// Awaits the next `JoinSet`'s completion.
112        ///
113        /// If you `.spawn` a new task onto this `JoinSet` while the future
114        /// returned from this is currently pending, then this future will
115        /// continue to be pending, even if the newly spawned future is already
116        /// finished.
117        ///
118        /// TODO(matheus23): Fix this limitation.
119        ///
120        /// Current work around is to recreate the `join_next` future when
121        /// you newly spawned a task onto it. This seems to be the usual way
122        /// the `JoinSet` is used *most of the time* in the iroh codebase anyways.
123        pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
124            self.join_next_with_id()
125                .await
126                .map(|ret| ret.map(|(_id, out)| out))
127        }
128
129        /// Waits until one of the tasks in the set completes and returns its
130        /// output, along with the [task ID] of the completed task.
131        ///
132        /// Returns `None` if the set is empty.
133        ///
134        /// When this method returns an error, then the id of the task that failed can be accessed
135        /// using the [`JoinError::id`] method.
136        ///
137        /// [task ID]: crate::task::Id
138        /// [`JoinError::id`]: fn@crate::task::JoinError::id
139        pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
140            futures_lite::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
141        }
142
143        /// Polls for one of the tasks in the set to complete.
144        ///
145        /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the set.
146        ///
147        /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
148        /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to
149        /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
150        /// scheduled to receive a wakeup.
151        ///
152        /// # Returns
153        ///
154        /// This function returns:
155        ///
156        ///  * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is
157        ///    available right now.
158        ///  * `Poll::Ready(Some(Ok(value)))` if one of the tasks in this `JoinSet` has completed.
159        ///    The `value` is the return value of one of the tasks that completed.
160        ///  * `Poll::Ready(Some(Err(err)))` if one of the tasks in this `JoinSet` has panicked or been
161        ///    aborted. The `err` is the `JoinError` from the panicked/aborted task.
162        ///  * `Poll::Ready(None)` if the `JoinSet` is empty.
163        pub fn poll_join_next(
164            &mut self,
165            cx: &mut Context<'_>,
166        ) -> Poll<Option<Result<T, JoinError>>> {
167            match self.poll_join_next_with_id(cx) {
168                Poll::Ready(Some(Ok((_, ret)))) => Poll::Ready(Some(Ok(ret))),
169                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
170                Poll::Ready(None) => Poll::Ready(None),
171                Poll::Pending => Poll::Pending,
172            }
173        }
174
175        /// Polls for one of the tasks in the set to complete.
176        ///
177        /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the set.
178        ///
179        /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
180        /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to
181        /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
182        /// scheduled to receive a wakeup.
183        ///
184        /// # Returns
185        ///
186        /// This function returns:
187        ///
188        ///  * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is
189        ///    available right now.
190        ///  * `Poll::Ready(Some(Ok((id, value))))` if one of the tasks in this `JoinSet` has completed.
191        ///    The `value` is the return value of one of the tasks that completed, and
192        ///    `id` is the [task ID] of that task.
193        ///  * `Poll::Ready(Some(Err(err)))` if one of the tasks in this `JoinSet` has panicked or been
194        ///    aborted. The `err` is the `JoinError` from the panicked/aborted task.
195        ///  * `Poll::Ready(None)` if the `JoinSet` is empty.
196        ///
197        /// [task ID]: crate::task::Id
198        pub fn poll_join_next_with_id(
199            &mut self,
200            cx: &mut Context<'_>,
201        ) -> Poll<Option<Result<(Id, T), JoinError>>> {
202            let ret = self.handles.poll_next(cx);
203            // clean up handles that are either cancelled or have finished
204            self.to_cancel.retain(JoinHandle::is_running);
205            ret
206        }
207
208        /// Returns whether there's any tasks that are either still running or
209        /// have pending results in this `JoinSet`.
210        pub fn is_empty(&self) -> bool {
211            self.handles.is_empty()
212        }
213
214        /// Returns the amount of tasks that are either still running or have
215        /// pending results in this `JoinSet`.
216        pub fn len(&self) -> usize {
217            self.handles.len()
218        }
219
220        /// Waits for all tasks to finish. If any of them returns a JoinError,
221        /// this will panic.
222        pub async fn join_all(mut self) -> Vec<T> {
223            let mut output = Vec::new();
224            while let Some(res) = self.join_next().await {
225                match res {
226                    Ok(t) => output.push(t),
227                    Err(err) => panic!("{err}"),
228                }
229            }
230            output
231        }
232
233        /// Aborts all tasks and then waits for them to finish, ignoring panics.
234        pub async fn shutdown(&mut self) {
235            self.abort_all();
236            while let Some(_res) = self.join_next().await {}
237        }
238    }
239
240    impl<T> Drop for JoinSet<T> {
241        fn drop(&mut self) {
242            self.abort_all()
243        }
244    }
245
    /// A handle to a spawned task.
    ///
    /// Awaiting it yields the task's output, or a [`JoinError`] if the task was aborted.
    /// Dropping it does not cancel the task (see `AbortOnDropHandle` for that).
    pub struct JoinHandle<T> {
        // State shared with the spawned `SpawnFuture` and any other handle onto the same task.
        task: Task<T>,
    }
250
    /// Shared, reference-counted state of one spawned task, cloned into each handle onto it.
    struct Task<T> {
        // Using SendWrapper here is safe as long as you keep all of your
        // work on the main UI worker in the browser.
        // The only exception to that being the case would be if our user
        // would use multiple Wasm instances with a single SharedArrayBuffer,
        // put the instances on different Web Workers and finally shared
        // the JoinHandle across the Web Worker boundary.
        // In that case, using the JoinHandle would panic.
        state: SendWrapper<Rc<RefCell<State>>>,
        // Slot the spawned future's output is written into once it completes; the first
        // `JoinHandle` poll that finds it `Some` takes it out.
        result: SendWrapper<Rc<RefCell<Option<T>>>>,
    }
259        state: SendWrapper<Rc<RefCell<State>>>,
260        result: SendWrapper<Rc<RefCell<Option<T>>>>,
261    }
262
263    impl<T> Clone for Task<T> {
264        fn clone(&self) -> Self {
265            Self {
266                state: self.state.clone(),
267                result: self.result.clone(),
268            }
269        }
270    }
271
272    #[derive(Debug)]
273    struct State {
274        id: Id,
275        cancelled: bool,
276        completed: bool,
277        waker_handler: Option<Waker>,
278        waker_spawn_fn: Option<Waker>,
279    }
280
281    impl State {
282        fn cancel(&mut self) {
283            if !self.cancelled {
284                self.cancelled = true;
285                self.wake();
286            }
287        }
288
289        fn complete(&mut self) {
290            self.completed = true;
291            self.wake();
292        }
293
294        fn is_complete(&self) -> bool {
295            self.completed || self.cancelled
296        }
297
298        fn wake(&mut self) {
299            if let Some(waker) = self.waker_handler.take() {
300                waker.wake();
301            }
302            if let Some(waker) = self.waker_spawn_fn.take() {
303                waker.wake();
304            }
305        }
306
307        fn register_handler(&mut self, cx: &mut Context<'_>) {
308            match self.waker_handler {
309                // clone_from can be marginally faster in some cases
310                Some(ref mut waker) => waker.clone_from(cx.waker()),
311                None => self.waker_handler = Some(cx.waker().clone()),
312            }
313        }
314
315        fn register_spawn_fn(&mut self, cx: &mut Context<'_>) {
316            match self.waker_spawn_fn {
317                // clone_from can be marginally faster in some cases
318                Some(ref mut waker) => waker.clone_from(cx.waker()),
319                None => self.waker_spawn_fn = Some(cx.waker().clone()),
320            }
321        }
322    }
323
324    impl<T> Debug for JoinHandle<T> {
325        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326            if self.task.state.valid() {
327                let state = self.task.state.borrow();
328                f.debug_struct("JoinHandle")
329                    .field("id", &state.id)
330                    .field("cancelled", &state.cancelled)
331                    .field("completed", &state.completed)
332                    .finish()
333            } else {
334                f.debug_tuple("JoinHandle")
335                    .field(&format_args!("<other thread>"))
336                    .finish()
337            }
338        }
339    }
340
341    impl<T> JoinHandle<T> {
342        fn new() -> Self {
343            Self {
344                task: Task {
345                    state: SendWrapper::new(Rc::new(RefCell::new(State {
346                        cancelled: false,
347                        completed: false,
348                        waker_handler: None,
349                        waker_spawn_fn: None,
350                        id: Id(next_task_id()),
351                    }))),
352                    result: SendWrapper::new(Rc::new(RefCell::new(None))),
353                },
354            }
355        }
356
357        /// Aborts this task.
358        pub fn abort(&self) {
359            self.task.state.borrow_mut().cancel();
360        }
361
362        /// Returns a new [`AbortHandle`] that can be used to remotely abort this task.
363        ///
364        /// Awaiting a task cancelled by the [`AbortHandle`] might complete as usual if the task was
365        /// already completed at the time it was cancelled, but most likely it
366        /// will fail with a [cancelled] `JoinError`.
367        ///
368        /// [cancelled]: JoinError::is_cancelled
369        pub fn abort_handle(&self) -> AbortHandle {
370            AbortHandle {
371                state: self.task.state.clone(),
372            }
373        }
374
375        /// Returns a [task ID] that uniquely identifies this task relative to other
376        /// currently spawned tasks.
377        ///
378        /// [task ID]: crate::task::Id
379        pub fn id(&self) -> Id {
380            let state = self.task.state.borrow();
381            state.id
382        }
383
384        /// Checks if the task associated with this `JoinHandle` has finished.
385        pub fn is_finished(&self) -> bool {
386            let state = self.task.state.borrow();
387            state.is_complete()
388        }
389
390        fn is_running(&self) -> bool {
391            !self.is_finished()
392        }
393    }
394
395    /// An error that can occur when waiting for the completion of a task.
396    #[derive(derive_more::Display, Debug, Clone, Copy)]
397    #[display("{cause}")]
398    pub struct JoinError {
399        cause: JoinErrorCause,
400        id: Id,
401    }
402
403    #[derive(derive_more::Display, Debug, Clone, Copy)]
404    enum JoinErrorCause {
405        /// The error that's returned when the task that's being waited on
406        /// has been cancelled.
407        #[display("task was cancelled")]
408        Cancelled,
409    }
410
411    impl std::error::Error for JoinError {}
412
413    impl JoinError {
414        /// Returns whether this join error is due to cancellation.
415        ///
416        /// Always true in this Wasm implementation, because we don't
417        /// unwind panics in tasks.
418        /// All panics just happen on the main thread anyways.
419        pub fn is_cancelled(&self) -> bool {
420            matches!(self.cause, JoinErrorCause::Cancelled)
421        }
422
423        /// Returns whether this is a panic. Always `false` in Wasm,
424        /// because when a task panics, it's not unwound, instead it
425        /// panics directly to the main thread.
426        pub fn is_panic(&self) -> bool {
427            false
428        }
429
430        /// Returns the panic, if the task has panicked.
431        ///
432        /// Always returns `Err(self)`, in Wasm, because when a task panics, it's not unwound,
433        /// instead it panics directly to the main thread.
434        pub fn try_into_panic(self) -> Result<Box<dyn std::any::Any + Send + 'static>, JoinError> {
435            Err(self)
436        }
437
438        /// Returns a task ID that identifies the task which errored relative to other currently spawned tasks.
439        pub fn id(&self) -> Id {
440            self.id
441        }
442    }
443
444    impl<T> Future for JoinHandle<T> {
445        type Output = Result<T, JoinError>;
446
447        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
448            let mut state = self.task.state.borrow_mut();
449            if state.cancelled {
450                return Poll::Ready(Err(JoinError {
451                    cause: JoinErrorCause::Cancelled,
452                    id: state.id,
453                }));
454            }
455
456            let mut result = self.task.result.borrow_mut();
457            if let Some(result) = result.take() {
458                return Poll::Ready(Ok(result));
459            }
460
461            state.register_handler(cx);
462            Poll::Pending
463        }
464    }
465
466    struct JoinHandleWithId<T>(JoinHandle<T>);
467
468    impl<T> Future for JoinHandleWithId<T> {
469        type Output = Result<(Id, T), JoinError>;
470
471        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
472            match self.0.poll(cx) {
473                Poll::Ready(out) => Poll::Ready(out.map(|out| (self.0.id(), out))),
474                Poll::Pending => Poll::Pending,
475            }
476        }
477    }
478
479    #[pin_project::pin_project]
480    struct SpawnFuture<Fut: Future<Output = T>, T> {
481        handle: JoinHandle<T>,
482        #[pin]
483        fut: Fut,
484    }
485
486    impl<Fut: Future<Output = T>, T> Future for SpawnFuture<Fut, T> {
487        type Output = ();
488
489        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
490            let this = self.project();
491            let mut state = this.handle.task.state.borrow_mut();
492
493            if state.cancelled {
494                return Poll::Ready(());
495            }
496
497            match this.fut.poll(cx) {
498                Poll::Ready(value) => {
499                    let _ = this.handle.task.result.borrow_mut().insert(value);
500                    state.complete();
501                    Poll::Ready(())
502                }
503                Poll::Pending => {
504                    state.register_spawn_fn(cx);
505                    Poll::Pending
506                }
507            }
508        }
509    }
510
511    /// An owned permission to abort a spawned task, without awaiting its completion.
512    #[derive(Clone)]
513    pub struct AbortHandle {
514        state: SendWrapper<Rc<RefCell<State>>>,
515    }
516
517    impl Debug for AbortHandle {
518        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
519            if self.state.valid() {
520                let state = self.state.borrow();
521                f.debug_struct("AbortHandle")
522                    .field("id", &state.id)
523                    .field("cancelled", &state.cancelled)
524                    .field("completed", &state.completed)
525                    .finish()
526            } else {
527                f.debug_tuple("AbortHandle")
528                    .field(&format_args!("<other thread>"))
529                    .finish()
530            }
531        }
532    }
533
534    impl AbortHandle {
535        /// Abort the task associated with the handle.
536        pub fn abort(&self) {
537            self.state.borrow_mut().cancel();
538        }
539
540        /// Returns a [task ID] that uniquely identifies this task relative to other
541        /// currently spawned tasks.
542        ///
543        /// [task ID]: crate::task::Id
544        pub fn id(&self) -> Id {
545            self.state.borrow().id
546        }
547
548        /// Checks if the task associated with this `AbortHandle` has finished.
549        pub fn is_finished(&self) -> bool {
550            let state = self.state.borrow();
551            state.cancelled && state.completed
552        }
553    }
554
555    /// Similar to a `JoinHandle`, except it automatically aborts
556    /// the task when it's dropped.
557    #[pin_project::pin_project(PinnedDrop)]
558    #[derive(derive_more::Debug, derive_more::Deref)]
559    #[debug("AbortOnDropHandle")]
560    #[must_use = "Dropping the handle aborts the task immediately"]
561    pub struct AbortOnDropHandle<T>(#[pin] JoinHandle<T>);
562
563    #[pin_project::pinned_drop]
564    impl<T> PinnedDrop for AbortOnDropHandle<T> {
565        fn drop(self: Pin<&mut Self>) {
566            self.0.abort();
567        }
568    }
569
570    impl<T> Future for AbortOnDropHandle<T> {
571        type Output = <JoinHandle<T> as Future>::Output;
572
573        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
574            self.project().0.poll(cx)
575        }
576    }
577
578    impl<T> AbortOnDropHandle<T> {
579        /// Converts a `JoinHandle` into one that aborts on drop.
580        pub fn new(task: JoinHandle<T>) -> Self {
581            Self(task)
582        }
583
584        /// Returns a new [`AbortHandle`] that can be used to remotely abort this task,
585        /// equivalent to [`JoinHandle::abort_handle`].
586        pub fn abort_handle(&self) -> AbortHandle {
587            self.0.abort_handle()
588        }
589
590        /// Abort the task associated with this handle,
591        /// equivalent to [`JoinHandle::abort`].
592        pub fn abort(&self) {
593            self.0.abort()
594        }
595
596        /// Checks if the task associated with this handle is finished,
597        /// equivalent to [`JoinHandle::is_finished`].
598        pub fn is_finished(&self) -> bool {
599            self.0.is_finished()
600        }
601    }
602
603    /// Spawns a future as a task in the browser runtime.
604    ///
605    /// This is powered by `wasm_bidngen_futures`.
606    pub fn spawn<T: 'static>(fut: impl IntoFuture<Output = T> + 'static) -> JoinHandle<T> {
607        let handle = JoinHandle::new();
608
609        wasm_bindgen_futures::spawn_local(SpawnFuture {
610            handle: JoinHandle {
611                task: handle.task.clone(),
612            },
613            fut: fut.into_future(),
614        });
615
616        handle
617    }
618}
619
#[cfg(test)]
mod test {
    use std::time::Duration;

    #[cfg(not(wasm_browser))]
    use tokio::test;
    #[cfg(wasm_browser)]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use crate::task;

    /// Aborting one task cancels only that task; an independent sibling still completes.
    #[test]
    async fn task_abort() {
        let short_sleep = || async {
            crate::time::sleep(Duration::from_millis(10)).await;
        };
        let aborted = task::spawn(short_sleep());
        let survivor = task::spawn(short_sleep());
        assert!(aborted.id() != survivor.id());

        aborted.abort();
        let aborted_result = aborted.await;
        assert!(aborted_result.err().unwrap().is_cancelled());
        assert!(survivor.await.is_ok());
    }

    /// Aborting one of two `JoinSet` tasks yields exactly one cancellation error and exactly
    /// one successful result, tagged with the surviving task's id.
    #[test]
    async fn join_set_abort() {
        let make = || async { 22 };
        let mut set = task::JoinSet::new();
        let kept = set.spawn(make());
        let cancelled = set.spawn(make());
        assert!(kept.id() != cancelled.id());
        cancelled.abort();

        let mut saw_err = false;
        let mut saw_ok = false;
        while let Some(next) = set.join_next_with_id().await {
            match next {
                Err(err) if !saw_err => {
                    assert!(err.is_cancelled());
                    saw_err = true;
                }
                Ok((id, value)) if !saw_ok => {
                    assert_eq!(id, kept.id());
                    assert_eq!(value, 22);
                    saw_ok = true;
                }
                // A second result of either kind means the set reported a task twice.
                _ => panic!(),
            }
        }
        assert!(saw_err);
        assert!(saw_ok);
    }
}