n0_future/
task.rs

//! Async Rust task spawning and utilities that work natively (using tokio) and in browsers
//! (using wasm-bindgen-futures).
3
4#[cfg(not(wasm_browser))]
5pub use tokio::spawn;
6#[cfg(not(wasm_browser))]
7pub use tokio::task::{AbortHandle, Id, JoinError, JoinHandle, JoinSet};
8#[cfg(not(wasm_browser))]
9pub use tokio_util::task::AbortOnDropHandle;
10#[cfg(wasm_browser)]
11pub use wasm::*;
12
13#[cfg(wasm_browser)]
14mod wasm {
15    use std::{
16        cell::RefCell,
17        fmt::{self, Debug},
18        future::{Future, IntoFuture},
19        pin::Pin,
20        rc::Rc,
21        sync::Mutex,
22        task::{Context, Poll, Waker},
23    };
24
25    use futures_lite::{stream::StreamExt, FutureExt};
26    use send_wrapper::SendWrapper;
27
28    static TASK_ID_COUNTER: Mutex<u64> = Mutex::new(0);
29
30    fn next_task_id() -> u64 {
31        let mut counter = TASK_ID_COUNTER.lock().unwrap();
32        *counter += 1;
33        *counter
34    }
35
36    /// An opaque ID that uniquely identifies a task relative to all other currently running tasks.
37    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, derive_more::Display)]
38    pub struct Id(u64);
39
40    /// Wasm shim for tokio's `JoinSet`.
41    ///
42    /// Uses a [`futures_buffered::FuturesUnordered`] queue of
43    /// [`JoinHandle`]s inside.
44    pub struct JoinSet<T> {
45        handles: futures_buffered::FuturesUnordered<JoinHandleWithId<T>>,
46        // We need to keep a second list of JoinHandles so we can access them for cancellation
47        to_cancel: Vec<JoinHandle<T>>,
48    }
49
50    impl<T> Debug for JoinSet<T> {
51        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52            f.debug_struct("JoinSet").field("len", &self.len()).finish()
53        }
54    }
55
56    impl<T> Default for JoinSet<T> {
57        fn default() -> Self {
58            Self::new()
59        }
60    }
61
62    impl<T> JoinSet<T> {
63        /// Creates a new, empty `JoinSet`
64        pub fn new() -> Self {
65            Self {
66                handles: futures_buffered::FuturesUnordered::new(),
67                to_cancel: Vec::new(),
68            }
69        }
70
71        /// Spawns a task into this `JoinSet`.
72        pub fn spawn(&mut self, fut: impl IntoFuture<Output = T> + 'static) -> AbortHandle
73        where
74            T: 'static,
75        {
76            let handle = JoinHandle::new();
77            let state = handle.task.state.clone();
78            let handle_for_spawn = JoinHandle {
79                task: handle.task.clone(),
80            };
81            let handle_for_cancel = JoinHandle {
82                task: handle.task.clone(),
83            };
84
85            wasm_bindgen_futures::spawn_local(SpawnFuture {
86                handle: handle_for_spawn,
87                fut: fut.into_future(),
88            });
89
90            self.handles.push(JoinHandleWithId(handle));
91            self.to_cancel.push(handle_for_cancel);
92            AbortHandle { state }
93        }
94
95        /// Alias to [`Self::spawn`].
96        ///
97        /// Mirrors [`tokio::JoinSet::spawn_local](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.spawn_local).
98        /// Because all tasks in WebAssembly are local, this is a simple alias to [`Self::spawn`].
99        pub fn spawn_local(&mut self, fut: impl IntoFuture<Output = T> + 'static) -> AbortHandle
100        where
101            T: 'static,
102        {
103            self.spawn(fut)
104        }
105
106        /// Aborts all tasks inside this `JoinSet`
107        pub fn abort_all(&self) {
108            self.to_cancel.iter().for_each(JoinHandle::abort);
109        }
110
111        /// Awaits the next `JoinSet`'s completion.
112        ///
113        /// If you `.spawn` a new task onto this `JoinSet` while the future
114        /// returned from this is currently pending, then this future will
115        /// continue to be pending, even if the newly spawned future is already
116        /// finished.
117        ///
118        /// TODO(matheus23): Fix this limitation.
119        ///
120        /// Current work around is to recreate the `join_next` future when
121        /// you newly spawned a task onto it. This seems to be the usual way
122        /// the `JoinSet` is used *most of the time* in the iroh codebase anyways.
123        pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
124            self.join_next_with_id()
125                .await
126                .map(|ret| ret.map(|(_id, out)| out))
127        }
128
129        /// Waits until one of the tasks in the set completes and returns its
130        /// output, along with the [task ID] of the completed task.
131        ///
132        /// Returns `None` if the set is empty.
133        ///
134        /// When this method returns an error, then the id of the task that failed can be accessed
135        /// using the [`JoinError::id`] method.
136        ///
137        /// [task ID]: crate::task::Id
138        /// [`JoinError::id`]: fn@crate::task::JoinError::id
139        pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
140            futures_lite::future::poll_fn(|cx| {
141                let ret = self.handles.poll_next(cx);
142                // clean up handles that are either cancelled or have finished
143                self.to_cancel.retain(JoinHandle::is_running);
144                ret
145            })
146            .await
147        }
148
149        /// Returns whether there's any tasks that are either still running or
150        /// have pending results in this `JoinSet`.
151        pub fn is_empty(&self) -> bool {
152            self.handles.is_empty()
153        }
154
155        /// Returns the amount of tasks that are either still running or have
156        /// pending results in this `JoinSet`.
157        pub fn len(&self) -> usize {
158            self.handles.len()
159        }
160
161        /// Waits for all tasks to finish. If any of them returns a JoinError,
162        /// this will panic.
163        pub async fn join_all(mut self) -> Vec<T> {
164            let mut output = Vec::new();
165            while let Some(res) = self.join_next().await {
166                match res {
167                    Ok(t) => output.push(t),
168                    Err(err) => panic!("{err}"),
169                }
170            }
171            output
172        }
173
174        /// Aborts all tasks and then waits for them to finish, ignoring panics.
175        pub async fn shutdown(&mut self) {
176            self.abort_all();
177            while let Some(_res) = self.join_next().await {}
178        }
179    }
180
181    impl<T> Drop for JoinSet<T> {
182        fn drop(&mut self) {
183            self.abort_all()
184        }
185    }
186
187    /// A handle to a spawned task.
188    pub struct JoinHandle<T> {
189        task: Task<T>,
190    }
191
192    struct Task<T> {
193        // Using SendWrapper here is safe as long as you keep all of your
194        // work on the main UI worker in the browser.
195        // The only exception to that being the case would be if our user
196        // would use multiple Wasm instances with a single SharedArrayBuffer,
197        // put the instances on different Web Workers and finally shared
198        // the JoinHandle across the Web Worker boundary.
199        // In that case, using the JoinHandle would panic.
200        state: SendWrapper<Rc<RefCell<State>>>,
201        result: SendWrapper<Rc<RefCell<Option<T>>>>,
202    }
203
204    impl<T> Clone for Task<T> {
205        fn clone(&self) -> Self {
206            Self {
207                state: self.state.clone(),
208                result: self.result.clone(),
209            }
210        }
211    }
212
213    #[derive(Debug)]
214    struct State {
215        id: Id,
216        cancelled: bool,
217        completed: bool,
218        waker_handler: Option<Waker>,
219        waker_spawn_fn: Option<Waker>,
220    }
221
222    impl State {
223        fn cancel(&mut self) {
224            if !self.cancelled {
225                self.cancelled = true;
226                self.wake();
227            }
228        }
229
230        fn complete(&mut self) {
231            self.completed = true;
232            self.wake();
233        }
234
235        fn is_complete(&self) -> bool {
236            self.completed || self.cancelled
237        }
238
239        fn wake(&mut self) {
240            if let Some(waker) = self.waker_handler.take() {
241                waker.wake();
242            }
243            if let Some(waker) = self.waker_spawn_fn.take() {
244                waker.wake();
245            }
246        }
247
248        fn register_handler(&mut self, cx: &mut Context<'_>) {
249            match self.waker_handler {
250                // clone_from can be marginally faster in some cases
251                Some(ref mut waker) => waker.clone_from(cx.waker()),
252                None => self.waker_handler = Some(cx.waker().clone()),
253            }
254        }
255
256        fn register_spawn_fn(&mut self, cx: &mut Context<'_>) {
257            match self.waker_spawn_fn {
258                // clone_from can be marginally faster in some cases
259                Some(ref mut waker) => waker.clone_from(cx.waker()),
260                None => self.waker_spawn_fn = Some(cx.waker().clone()),
261            }
262        }
263    }
264
265    impl<T> Debug for JoinHandle<T> {
266        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267            if self.task.state.valid() {
268                let state = self.task.state.borrow();
269                f.debug_struct("JoinHandle")
270                    .field("id", &state.id)
271                    .field("cancelled", &state.cancelled)
272                    .field("completed", &state.completed)
273                    .finish()
274            } else {
275                f.debug_tuple("JoinHandle")
276                    .field(&format_args!("<other thread>"))
277                    .finish()
278            }
279        }
280    }
281
282    impl<T> JoinHandle<T> {
283        fn new() -> Self {
284            Self {
285                task: Task {
286                    state: SendWrapper::new(Rc::new(RefCell::new(State {
287                        cancelled: false,
288                        completed: false,
289                        waker_handler: None,
290                        waker_spawn_fn: None,
291                        id: Id(next_task_id()),
292                    }))),
293                    result: SendWrapper::new(Rc::new(RefCell::new(None))),
294                },
295            }
296        }
297
298        /// Aborts this task.
299        pub fn abort(&self) {
300            self.task.state.borrow_mut().cancel();
301        }
302
303        /// Returns a new [`AbortHandle`] that can be used to remotely abort this task.
304        ///
305        /// Awaiting a task cancelled by the [`AbortHandle`] might complete as usual if the task was
306        /// already completed at the time it was cancelled, but most likely it
307        /// will fail with a [cancelled] `JoinError`.
308        ///
309        /// [cancelled]: JoinError::is_cancelled
310        pub fn abort_handle(&self) -> AbortHandle {
311            AbortHandle {
312                state: self.task.state.clone(),
313            }
314        }
315
316        /// Returns a [task ID] that uniquely identifies this task relative to other
317        /// currently spawned tasks.
318        ///
319        /// [task ID]: crate::task::Id
320        pub fn id(&self) -> Id {
321            let state = self.task.state.borrow();
322            state.id
323        }
324
325        /// Checks if the task associated with this `JoinHandle` has finished.
326        pub fn is_finished(&self) -> bool {
327            let state = self.task.state.borrow();
328            state.is_complete()
329        }
330
331        fn is_running(&self) -> bool {
332            !self.is_finished()
333        }
334    }
335
336    /// An error that can occur when waiting for the completion of a task.
337    #[derive(derive_more::Display, Debug, Clone, Copy)]
338    #[display("{cause}")]
339    pub struct JoinError {
340        cause: JoinErrorCause,
341        id: Id,
342    }
343
344    #[derive(derive_more::Display, Debug, Clone, Copy)]
345    enum JoinErrorCause {
346        /// The error that's returned when the task that's being waited on
347        /// has been cancelled.
348        #[display("task was cancelled")]
349        Cancelled,
350    }
351
352    impl std::error::Error for JoinError {}
353
354    impl JoinError {
355        /// Returns whether this join error is due to cancellation.
356        ///
357        /// Always true in this Wasm implementation, because we don't
358        /// unwind panics in tasks.
359        /// All panics just happen on the main thread anyways.
360        pub fn is_cancelled(&self) -> bool {
361            matches!(self.cause, JoinErrorCause::Cancelled)
362        }
363
364        /// Returns whether this is a panic. Always `false` in Wasm,
365        /// because when a task panics, it's not unwound, instead it
366        /// panics directly to the main thread.
367        pub fn is_panic(&self) -> bool {
368            false
369        }
370
371        /// Returns a task ID that identifies the task which errored relative to other currently spawned tasks.
372        pub fn id(&self) -> Id {
373            self.id
374        }
375    }
376
377    impl<T> Future for JoinHandle<T> {
378        type Output = Result<T, JoinError>;
379
380        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
381            let mut state = self.task.state.borrow_mut();
382            if state.cancelled {
383                return Poll::Ready(Err(JoinError {
384                    cause: JoinErrorCause::Cancelled,
385                    id: state.id,
386                }));
387            }
388
389            let mut result = self.task.result.borrow_mut();
390            if let Some(result) = result.take() {
391                return Poll::Ready(Ok(result));
392            }
393
394            state.register_handler(cx);
395            Poll::Pending
396        }
397    }
398
399    struct JoinHandleWithId<T>(JoinHandle<T>);
400
401    impl<T> Future for JoinHandleWithId<T> {
402        type Output = Result<(Id, T), JoinError>;
403
404        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
405            match self.0.poll(cx) {
406                Poll::Ready(out) => Poll::Ready(out.map(|out| (self.0.id(), out))),
407                Poll::Pending => Poll::Pending,
408            }
409        }
410    }
411
412    #[pin_project::pin_project]
413    struct SpawnFuture<Fut: Future<Output = T>, T> {
414        handle: JoinHandle<T>,
415        #[pin]
416        fut: Fut,
417    }
418
419    impl<Fut: Future<Output = T>, T> Future for SpawnFuture<Fut, T> {
420        type Output = ();
421
422        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
423            let this = self.project();
424            let mut state = this.handle.task.state.borrow_mut();
425
426            if state.cancelled {
427                return Poll::Ready(());
428            }
429
430            match this.fut.poll(cx) {
431                Poll::Ready(value) => {
432                    let _ = this.handle.task.result.borrow_mut().insert(value);
433                    state.complete();
434                    Poll::Ready(())
435                }
436                Poll::Pending => {
437                    state.register_spawn_fn(cx);
438                    Poll::Pending
439                }
440            }
441        }
442    }
443
444    /// An owned permission to abort a spawned task, without awaiting its completion.
445    #[derive(Clone)]
446    pub struct AbortHandle {
447        state: SendWrapper<Rc<RefCell<State>>>,
448    }
449
450    impl Debug for AbortHandle {
451        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452            if self.state.valid() {
453                let state = self.state.borrow();
454                f.debug_struct("AbortHandle")
455                    .field("id", &state.id)
456                    .field("cancelled", &state.cancelled)
457                    .field("completed", &state.completed)
458                    .finish()
459            } else {
460                f.debug_tuple("AbortHandle")
461                    .field(&format_args!("<other thread>"))
462                    .finish()
463            }
464        }
465    }
466
467    impl AbortHandle {
468        /// Abort the task associated with the handle.
469        pub fn abort(&self) {
470            self.state.borrow_mut().cancel();
471        }
472
473        /// Returns a [task ID] that uniquely identifies this task relative to other
474        /// currently spawned tasks.
475        ///
476        /// [task ID]: crate::task::Id
477        pub fn id(&self) -> Id {
478            self.state.borrow().id
479        }
480
481        /// Checks if the task associated with this `AbortHandle` has finished.
482        pub fn is_finished(&self) -> bool {
483            let state = self.state.borrow();
484            state.cancelled && state.completed
485        }
486    }
487
488    /// Similar to a `JoinHandle`, except it automatically aborts
489    /// the task when it's dropped.
490    #[pin_project::pin_project(PinnedDrop)]
491    #[derive(derive_more::Debug, derive_more::Deref)]
492    #[debug("AbortOnDropHandle")]
493    #[must_use = "Dropping the handle aborts the task immediately"]
494    pub struct AbortOnDropHandle<T>(#[pin] JoinHandle<T>);
495
496    #[pin_project::pinned_drop]
497    impl<T> PinnedDrop for AbortOnDropHandle<T> {
498        fn drop(self: Pin<&mut Self>) {
499            self.0.abort();
500        }
501    }
502
503    impl<T> Future for AbortOnDropHandle<T> {
504        type Output = <JoinHandle<T> as Future>::Output;
505
506        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
507            self.project().0.poll(cx)
508        }
509    }
510
511    impl<T> AbortOnDropHandle<T> {
512        /// Converts a `JoinHandle` into one that aborts on drop.
513        pub fn new(task: JoinHandle<T>) -> Self {
514            Self(task)
515        }
516
517        /// Returns a new [`AbortHandle`] that can be used to remotely abort this task,
518        /// equivalent to [`JoinHandle::abort_handle`].
519        pub fn abort_handle(&self) -> AbortHandle {
520            self.0.abort_handle()
521        }
522
523        /// Abort the task associated with this handle,
524        /// equivalent to [`JoinHandle::abort`].
525        pub fn abort(&self) {
526            self.0.abort()
527        }
528
529        /// Checks if the task associated with this handle is finished,
530        /// equivalent to [`JoinHandle::is_finished`].
531        pub fn is_finished(&self) -> bool {
532            self.0.is_finished()
533        }
534    }
535
536    /// Spawns a future as a task in the browser runtime.
537    ///
538    /// This is powered by `wasm_bidngen_futures`.
539    pub fn spawn<T: 'static>(fut: impl IntoFuture<Output = T> + 'static) -> JoinHandle<T> {
540        let handle = JoinHandle::new();
541
542        wasm_bindgen_futures::spawn_local(SpawnFuture {
543            handle: JoinHandle {
544                task: handle.task.clone(),
545            },
546            fut: fut.into_future(),
547        });
548
549        handle
550    }
551}
552
#[cfg(test)]
mod test {
    use std::time::Duration;

    #[cfg(not(wasm_browser))]
    use tokio::test;
    #[cfg(wasm_browser)]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use crate::task;

    /// Aborting one task cancels it without affecting an unrelated sibling.
    #[test]
    async fn task_abort() {
        let sleepy = || async {
            crate::time::sleep(Duration::from_millis(10)).await;
        };
        let h1 = task::spawn(sleepy());
        let h2 = task::spawn(sleepy());
        assert_ne!(h1.id(), h2.id());

        h1.abort();
        let err = h1.await.expect_err("an aborted task must not produce output");
        assert!(err.is_cancelled());
        assert!(h2.await.is_ok());
    }

    /// Aborting one task of a `JoinSet` yields exactly one cancellation error and
    /// exactly one successful result from the other task.
    #[test]
    async fn join_set_abort() {
        let fut = || async { 22 };
        let mut set = task::JoinSet::new();
        let h1 = set.spawn(fut());
        let h2 = set.spawn(fut());
        assert_ne!(h1.id(), h2.id());
        h2.abort();

        let (mut errors, mut successes) = (0usize, 0usize);
        while let Some(ret) = set.join_next_with_id().await {
            match ret {
                Err(err) => {
                    assert!(err.is_cancelled());
                    errors += 1;
                }
                Ok((id, out)) => {
                    assert_eq!(id, h1.id());
                    assert_eq!(out, 22);
                    successes += 1;
                }
            }
        }
        assert_eq!(errors, 1);
        assert_eq!(successes, 1);
    }
}