n0_future/
task.rs

1//! Async rust task spawning and utilities that work natively (using tokio) and in browsers
2//! (using wasm-bindgen-futures).
3
4#[cfg(not(wasm_browser))]
5pub use tokio::spawn;
6#[cfg(not(wasm_browser))]
7pub use tokio::task::{AbortHandle, Id, JoinError, JoinHandle, JoinSet};
8#[cfg(not(wasm_browser))]
9pub use tokio_util::task::AbortOnDropHandle;
10#[cfg(wasm_browser)]
11pub use wasm::*;
12
13#[cfg(wasm_browser)]
14mod wasm {
15    use std::{
16        cell::RefCell,
17        fmt::{self, Debug},
18        future::{Future, IntoFuture},
19        pin::Pin,
20        rc::Rc,
21        sync::Mutex,
22        task::{Context, Poll, Waker},
23    };
24
25    use futures_lite::{stream::StreamExt, FutureExt};
26    use send_wrapper::SendWrapper;
27
28    static TASK_ID_COUNTER: Mutex<u64> = Mutex::new(0);
29
30    fn next_task_id() -> u64 {
31        let mut counter = TASK_ID_COUNTER.lock().unwrap();
32        *counter += 1;
33        *counter
34    }
35
    /// An opaque ID that uniquely identifies a task relative to all other currently running tasks.
    ///
    /// IDs are only meaningful for comparison, hashing and display; the wrapped
    /// number is not exposed.
    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, derive_more::Display)]
    pub struct Id(u64);
39
    /// Wasm shim for tokio's `JoinSet`.
    ///
    /// Uses a [`futures_buffered::FuturesUnordered`] queue of
    /// [`JoinHandle`]s inside.
    ///
    /// Dropping the set aborts every task that is still tracked by it.
    pub struct JoinSet<T> {
        // Handles that are polled to obtain task results (paired with the task ID).
        handles: futures_buffered::FuturesUnordered<JoinHandleWithId<T>>,
        // We need to keep a second list of JoinHandles so we can access them for cancellation.
        // Entries for finished or cancelled tasks are pruned whenever the set is polled.
        to_cancel: Vec<JoinHandle<T>>,
    }
49
50    impl<T> Debug for JoinSet<T> {
51        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52            f.debug_struct("JoinSet").field("len", &self.len()).finish()
53        }
54    }
55
56    impl<T> Default for JoinSet<T> {
57        fn default() -> Self {
58            Self::new()
59        }
60    }
61
62    impl<T> JoinSet<T> {
63        /// Creates a new, empty `JoinSet`
64        pub fn new() -> Self {
65            Self {
66                handles: futures_buffered::FuturesUnordered::new(),
67                to_cancel: Vec::new(),
68            }
69        }
70
71        /// Spawns a task into this `JoinSet`.
72        pub fn spawn(&mut self, fut: impl IntoFuture<Output = T> + 'static) -> AbortHandle
73        where
74            T: 'static,
75        {
76            let handle = JoinHandle::new();
77            let state = handle.task.state.clone();
78            let handle_for_spawn = JoinHandle {
79                task: handle.task.clone(),
80            };
81            let handle_for_cancel = JoinHandle {
82                task: handle.task.clone(),
83            };
84
85            wasm_bindgen_futures::spawn_local(SpawnFuture {
86                handle: handle_for_spawn,
87                fut: fut.into_future(),
88            });
89
90            self.handles.push(JoinHandleWithId(handle));
91            self.to_cancel.push(handle_for_cancel);
92            AbortHandle { state }
93        }
94
95        /// Aborts all tasks inside this `JoinSet`
96        pub fn abort_all(&self) {
97            self.to_cancel.iter().for_each(JoinHandle::abort);
98        }
99
100        /// Awaits the next `JoinSet`'s completion.
101        ///
102        /// If you `.spawn` a new task onto this `JoinSet` while the future
103        /// returned from this is currently pending, then this future will
104        /// continue to be pending, even if the newly spawned future is already
105        /// finished.
106        ///
107        /// TODO(matheus23): Fix this limitation.
108        ///
109        /// Current work around is to recreate the `join_next` future when
110        /// you newly spawned a task onto it. This seems to be the usual way
111        /// the `JoinSet` is used *most of the time* in the iroh codebase anyways.
112        pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
113            self.join_next_with_id()
114                .await
115                .map(|ret| ret.map(|(_id, out)| out))
116        }
117
118        /// Waits until one of the tasks in the set completes and returns its
119        /// output, along with the [task ID] of the completed task.
120        ///
121        /// Returns `None` if the set is empty.
122        ///
123        /// When this method returns an error, then the id of the task that failed can be accessed
124        /// using the [`JoinError::id`] method.
125        ///
126        /// [task ID]: crate::task::Id
127        /// [`JoinError::id`]: fn@crate::task::JoinError::id
128        pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
129            futures_lite::future::poll_fn(|cx| {
130                let ret = self.handles.poll_next(cx);
131                // clean up handles that are either cancelled or have finished
132                self.to_cancel.retain(JoinHandle::is_running);
133                ret
134            })
135            .await
136        }
137
138        /// Returns whether there's any tasks that are either still running or
139        /// have pending results in this `JoinSet`.
140        pub fn is_empty(&self) -> bool {
141            self.handles.is_empty()
142        }
143
144        /// Returns the amount of tasks that are either still running or have
145        /// pending results in this `JoinSet`.
146        pub fn len(&self) -> usize {
147            self.handles.len()
148        }
149
150        /// Waits for all tasks to finish. If any of them returns a JoinError,
151        /// this will panic.
152        pub async fn join_all(mut self) -> Vec<T> {
153            let mut output = Vec::new();
154            while let Some(res) = self.join_next().await {
155                match res {
156                    Ok(t) => output.push(t),
157                    Err(err) => panic!("{err}"),
158                }
159            }
160            output
161        }
162
163        /// Aborts all tasks and then waits for them to finish, ignoring panics.
164        pub async fn shutdown(&mut self) {
165            self.abort_all();
166            while let Some(_res) = self.join_next().await {}
167        }
168    }
169
170    impl<T> Drop for JoinSet<T> {
171        fn drop(&mut self) {
172            self.abort_all()
173        }
174    }
175
    /// A handle to a spawned task.
    ///
    /// Awaiting the handle yields the task's output, or a [`JoinError`] if the
    /// task was cancelled.
    pub struct JoinHandle<T> {
        // State and result slot shared with the future that drives the task.
        task: Task<T>,
    }
180
    /// The pieces of a task shared between its handles and the future driving it.
    struct Task<T> {
        // Using SendWrapper here is safe as long as you keep all of your
        // work on the main UI worker in the browser.
        // The only exception to that being the case would be if our user
        // would use multiple Wasm instances with a single SharedArrayBuffer,
        // put the instances on different Web Workers and finally shared
        // the JoinHandle across the Web Worker boundary.
        // In that case, using the JoinHandle would panic.
        state: SendWrapper<Rc<RefCell<State>>>,
        // Slot the task's output is written to once the future finishes.
        result: SendWrapper<Rc<RefCell<Option<T>>>>,
    }
192
193    impl<T> Clone for Task<T> {
194        fn clone(&self) -> Self {
195            Self {
196                state: self.state.clone(),
197                result: self.result.clone(),
198            }
199        }
200    }
201
    /// Bookkeeping for a single task, shared by all of its handles.
    #[derive(Debug)]
    struct State {
        // ID assigned when the task's handle was created.
        id: Id,
        // Set once the task has been aborted.
        cancelled: bool,
        // Set once the task's future has produced its output.
        completed: bool,
        // Waker of whoever is awaiting the `JoinHandle`.
        waker_handler: Option<Waker>,
        // Waker of the spawned future itself, so that a cancellation gets it polled again.
        waker_spawn_fn: Option<Waker>,
    }
210
211    impl State {
212        fn cancel(&mut self) {
213            if !self.cancelled {
214                self.cancelled = true;
215                self.wake();
216            }
217        }
218
219        fn complete(&mut self) {
220            self.completed = true;
221            self.wake();
222        }
223
224        fn is_complete(&self) -> bool {
225            self.completed || self.cancelled
226        }
227
228        fn wake(&mut self) {
229            if let Some(waker) = self.waker_handler.take() {
230                waker.wake();
231            }
232            if let Some(waker) = self.waker_spawn_fn.take() {
233                waker.wake();
234            }
235        }
236
237        fn register_handler(&mut self, cx: &mut Context<'_>) {
238            match self.waker_handler {
239                // clone_from can be marginally faster in some cases
240                Some(ref mut waker) => waker.clone_from(cx.waker()),
241                None => self.waker_handler = Some(cx.waker().clone()),
242            }
243        }
244
245        fn register_spawn_fn(&mut self, cx: &mut Context<'_>) {
246            match self.waker_spawn_fn {
247                // clone_from can be marginally faster in some cases
248                Some(ref mut waker) => waker.clone_from(cx.waker()),
249                None => self.waker_spawn_fn = Some(cx.waker().clone()),
250            }
251        }
252    }
253
254    impl<T> Debug for JoinHandle<T> {
255        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256            if self.task.state.valid() {
257                let state = self.task.state.borrow();
258                f.debug_struct("JoinHandle")
259                    .field("id", &state.id)
260                    .field("cancelled", &state.cancelled)
261                    .field("completed", &state.completed)
262                    .finish()
263            } else {
264                f.debug_tuple("JoinHandle")
265                    .field(&format_args!("<other thread>"))
266                    .finish()
267            }
268        }
269    }
270
271    impl<T> JoinHandle<T> {
272        fn new() -> Self {
273            Self {
274                task: Task {
275                    state: SendWrapper::new(Rc::new(RefCell::new(State {
276                        cancelled: false,
277                        completed: false,
278                        waker_handler: None,
279                        waker_spawn_fn: None,
280                        id: Id(next_task_id()),
281                    }))),
282                    result: SendWrapper::new(Rc::new(RefCell::new(None))),
283                },
284            }
285        }
286
287        /// Aborts this task.
288        pub fn abort(&self) {
289            self.task.state.borrow_mut().cancel();
290        }
291
292        /// Returns a new [`AbortHandle`] that can be used to remotely abort this task.
293        ///
294        /// Awaiting a task cancelled by the [`AbortHandle`] might complete as usual if the task was
295        /// already completed at the time it was cancelled, but most likely it
296        /// will fail with a [cancelled] `JoinError`.
297        ///
298        /// [cancelled]: JoinError::is_cancelled
299        pub fn abort_handle(&self) -> AbortHandle {
300            AbortHandle {
301                state: self.task.state.clone(),
302            }
303        }
304
305        /// Returns a [task ID] that uniquely identifies this task relative to other
306        /// currently spawned tasks.
307        ///
308        /// [task ID]: crate::task::Id
309        pub fn id(&self) -> Id {
310            let state = self.task.state.borrow();
311            state.id
312        }
313
314        /// Checks if the task associated with this `JoinHandle` has finished.
315        pub fn is_finished(&self) -> bool {
316            let state = self.task.state.borrow();
317            state.is_complete()
318        }
319
320        fn is_running(&self) -> bool {
321            !self.is_finished()
322        }
323    }
324
    /// An error that can occur when waiting for the completion of a task.
    ///
    /// Its `Display` output is the display text of the underlying cause.
    #[derive(derive_more::Display, Debug, Clone, Copy)]
    #[display("{cause}")]
    pub struct JoinError {
        // Why joining the task failed.
        cause: JoinErrorCause,
        // ID of the task that failed.
        id: Id,
    }
332
    /// The reasons joining a task can fail.
    #[derive(derive_more::Display, Debug, Clone, Copy)]
    enum JoinErrorCause {
        /// The error that's returned when the task that's being waited on
        /// has been cancelled.
        #[display("task was cancelled")]
        Cancelled,
    }
340
    // `JoinError` wraps no other error, so the default trait methods suffice.
    impl std::error::Error for JoinError {}
342
343    impl JoinError {
344        /// Returns whether this join error is due to cancellation.
345        ///
346        /// Always true in this Wasm implementation, because we don't
347        /// unwind panics in tasks.
348        /// All panics just happen on the main thread anyways.
349        pub fn is_cancelled(&self) -> bool {
350            matches!(self.cause, JoinErrorCause::Cancelled)
351        }
352
353        /// Returns whether this is a panic. Always `false` in Wasm,
354        /// because when a task panics, it's not unwound, instead it
355        /// panics directly to the main thread.
356        pub fn is_panic(&self) -> bool {
357            false
358        }
359
360        /// Returns a task ID that identifies the task which errored relative to other currently spawned tasks.
361        pub fn id(&self) -> Id {
362            self.id
363        }
364    }
365
366    impl<T> Future for JoinHandle<T> {
367        type Output = Result<T, JoinError>;
368
369        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
370            let mut state = self.task.state.borrow_mut();
371            if state.cancelled {
372                return Poll::Ready(Err(JoinError {
373                    cause: JoinErrorCause::Cancelled,
374                    id: state.id,
375                }));
376            }
377
378            let mut result = self.task.result.borrow_mut();
379            if let Some(result) = result.take() {
380                return Poll::Ready(Ok(result));
381            }
382
383            state.register_handler(cx);
384            Poll::Pending
385        }
386    }
387
    /// Wrapper around a [`JoinHandle`] whose future output also carries the task's [`Id`].
    struct JoinHandleWithId<T>(JoinHandle<T>);
389
390    impl<T> Future for JoinHandleWithId<T> {
391        type Output = Result<(Id, T), JoinError>;
392
393        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
394            match self.0.poll(cx) {
395                Poll::Ready(out) => Poll::Ready(out.map(|out| (self.0.id(), out))),
396                Poll::Pending => Poll::Pending,
397            }
398        }
399    }
400
    /// The future that is actually handed to the browser runtime: it drives
    /// `fut` and reports its outcome through `handle`.
    #[pin_project::pin_project]
    struct SpawnFuture<Fut: Future<Output = T>, T> {
        // Handle sharing state and result slot with the caller's `JoinHandle`.
        handle: JoinHandle<T>,
        // The user's future being driven.
        #[pin]
        fut: Fut,
    }
407
408    impl<Fut: Future<Output = T>, T> Future for SpawnFuture<Fut, T> {
409        type Output = ();
410
411        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
412            let this = self.project();
413            let mut state = this.handle.task.state.borrow_mut();
414
415            if state.cancelled {
416                return Poll::Ready(());
417            }
418
419            match this.fut.poll(cx) {
420                Poll::Ready(value) => {
421                    let _ = this.handle.task.result.borrow_mut().insert(value);
422                    state.complete();
423                    Poll::Ready(())
424                }
425                Poll::Pending => {
426                    state.register_spawn_fn(cx);
427                    Poll::Pending
428                }
429            }
430        }
431    }
432
    /// An owned permission to abort a spawned task, without awaiting its completion.
    ///
    /// Cloning the handle yields another handle to the same task.
    #[derive(Clone)]
    pub struct AbortHandle {
        // State shared with the task's `JoinHandle` and the future driving the task.
        state: SendWrapper<Rc<RefCell<State>>>,
    }
438
439    impl Debug for AbortHandle {
440        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
441            if self.state.valid() {
442                let state = self.state.borrow();
443                f.debug_struct("AbortHandle")
444                    .field("id", &state.id)
445                    .field("cancelled", &state.cancelled)
446                    .field("completed", &state.completed)
447                    .finish()
448            } else {
449                f.debug_tuple("AbortHandle")
450                    .field(&format_args!("<other thread>"))
451                    .finish()
452            }
453        }
454    }
455
456    impl AbortHandle {
457        /// Abort the task associated with the handle.
458        pub fn abort(&self) {
459            self.state.borrow_mut().cancel();
460        }
461
462        /// Returns a [task ID] that uniquely identifies this task relative to other
463        /// currently spawned tasks.
464        ///
465        /// [task ID]: crate::task::Id
466        pub fn id(&self) -> Id {
467            self.state.borrow().id
468        }
469
470        /// Checks if the task associated with this `AbortHandle` has finished.
471        pub fn is_finished(&self) -> bool {
472            let state = self.state.borrow();
473            state.cancelled && state.completed
474        }
475    }
476
    /// Similar to a `JoinHandle`, except it automatically aborts
    /// the task when it's dropped.
    ///
    /// Awaiting it yields the same output as awaiting the wrapped `JoinHandle`.
    #[pin_project::pin_project(PinnedDrop)]
    #[derive(derive_more::Debug, derive_more::Deref)]
    #[debug("AbortOnDropHandle")]
    #[must_use = "Dropping the handle aborts the task immediately"]
    pub struct AbortOnDropHandle<T>(#[pin] JoinHandle<T>);
484
485    #[pin_project::pinned_drop]
486    impl<T> PinnedDrop for AbortOnDropHandle<T> {
487        fn drop(self: Pin<&mut Self>) {
488            self.0.abort();
489        }
490    }
491
492    impl<T> Future for AbortOnDropHandle<T> {
493        type Output = <JoinHandle<T> as Future>::Output;
494
495        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
496            self.project().0.poll(cx)
497        }
498    }
499
500    impl<T> AbortOnDropHandle<T> {
501        /// Converts a `JoinHandle` into one that aborts on drop.
502        pub fn new(task: JoinHandle<T>) -> Self {
503            Self(task)
504        }
505
506        /// Returns a new [`AbortHandle`] that can be used to remotely abort this task,
507        /// equivalent to [`JoinHandle::abort_handle`].
508        pub fn abort_handle(&self) -> AbortHandle {
509            self.0.abort_handle()
510        }
511
512        /// Abort the task associated with this handle,
513        /// equivalent to [`JoinHandle::abort`].
514        pub fn abort(&self) {
515            self.0.abort()
516        }
517
518        /// Checks if the task associated with this handle is finished,
519        /// equivalent to [`JoinHandle::is_finished`].
520        pub fn is_finished(&self) -> bool {
521            self.0.is_finished()
522        }
523    }
524
525    /// Spawns a future as a task in the browser runtime.
526    ///
527    /// This is powered by `wasm_bidngen_futures`.
528    pub fn spawn<T: 'static>(fut: impl IntoFuture<Output = T> + 'static) -> JoinHandle<T> {
529        let handle = JoinHandle::new();
530
531        wasm_bindgen_futures::spawn_local(SpawnFuture {
532            handle: JoinHandle {
533                task: handle.task.clone(),
534            },
535            fut: fut.into_future(),
536        });
537
538        handle
539    }
540}
541
#[cfg(test)]
mod test {
    use std::time::Duration;

    #[cfg(not(wasm_browser))]
    use tokio::test;
    #[cfg(wasm_browser)]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use crate::task;

    /// Aborting one task yields a cancelled error for it and leaves the other untouched.
    #[test]
    async fn task_abort() {
        let aborted = task::spawn(async {
            crate::time::sleep(Duration::from_millis(10)).await;
        });
        let untouched = task::spawn(async {
            crate::time::sleep(Duration::from_millis(10)).await;
        });
        assert!(aborted.id() != untouched.id());

        aborted.abort();
        assert!(aborted.await.err().unwrap().is_cancelled());
        assert!(untouched.await.is_ok());
    }

    /// Of two tasks in a `JoinSet`, aborting one yields exactly one cancelled
    /// error and exactly one successful result carrying the other task's ID.
    #[test]
    async fn join_set_abort() {
        let make_fut = || async { 22 };
        let mut set = task::JoinSet::new();
        let kept = set.spawn(make_fut());
        let aborted = set.spawn(make_fut());
        assert!(kept.id() != aborted.id());
        aborted.abort();

        let mut saw_err = false;
        let mut saw_ok = false;
        while let Some(joined) = set.join_next_with_id().await {
            match joined {
                Err(err) => {
                    // Only a single task was aborted, so a second error is a failure.
                    if saw_err {
                        panic!()
                    }
                    assert!(err.is_cancelled());
                    saw_err = true;
                }
                Ok((id, out)) => {
                    // Only a single task may succeed.
                    if saw_ok {
                        panic!()
                    }
                    assert_eq!(id, kept.id());
                    assert_eq!(out, 22);
                    saw_ok = true;
                }
            }
        }
        assert!(saw_err);
        assert!(saw_ok);
    }
}