n0_future/
task.rs

//! Async Rust task spawning and utilities that work natively (using tokio) and in browsers
//! (using wasm-bindgen-futures).
3
4#[cfg(not(wasm_browser))]
5pub use tokio::spawn;
6#[cfg(not(wasm_browser))]
7pub use tokio::task::{JoinError, JoinHandle, JoinSet};
8#[cfg(not(wasm_browser))]
9pub use tokio_util::task::AbortOnDropHandle;
10
11#[cfg(wasm_browser)]
12pub use wasm::*;
13
#[cfg(wasm_browser)]
mod wasm {
    use std::{
        cell::RefCell,
        fmt::Debug,
        future::{Future, IntoFuture},
        pin::Pin,
        rc::Rc,
        task::{Context, Poll, Waker},
    };

    use futures_lite::stream::StreamExt;
    use send_wrapper::SendWrapper;

    /// Wasm shim for tokio's `JoinSet`.
    ///
    /// Uses a `futures_buffered::FuturesUnordered` queue of
    /// [`JoinHandle`]s inside.
    pub struct JoinSet<T> {
        handles: futures_buffered::FuturesUnordered<JoinHandle<T>>,
        // The handles in `handles` are owned by the `FuturesUnordered` and can't
        // be reached from outside of it, so we keep a second handle to every
        // task here so we can access them for cancellation.
        to_cancel: Vec<JoinHandle<T>>,
    }

    impl<T> std::fmt::Debug for JoinSet<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("JoinSet").field("len", &self.len()).finish()
        }
    }

    impl<T> Default for JoinSet<T> {
        fn default() -> Self {
            Self::new()
        }
    }

    impl<T> JoinSet<T> {
        /// Creates a new, empty `JoinSet`.
        pub fn new() -> Self {
            Self {
                handles: futures_buffered::FuturesUnordered::new(),
                to_cancel: Vec::new(),
            }
        }

        /// Spawns a task into this `JoinSet`.
        ///
        /// (Doesn't return an `AbortHandle` unlike the original `tokio::task::JoinSet` yet.)
        pub fn spawn(&mut self, fut: impl IntoFuture<Output = T> + 'static)
        where
            T: 'static,
        {
            let handle = JoinHandle::new();
            let handle_for_cancel = handle.share();

            wasm_bindgen_futures::spawn_local(SpawnFuture {
                handle: handle.share(),
                fut: fut.into_future(),
            });

            self.handles.push(handle);
            self.to_cancel.push(handle_for_cancel);
        }

        /// Aborts all tasks inside this `JoinSet`.
        ///
        /// Tasks that already finished are not affected: their output is
        /// still returned by [`join_next`](Self::join_next).
        pub fn abort_all(&self) {
            self.to_cancel.iter().for_each(JoinHandle::abort);
        }

        /// Waits until one of the tasks in this `JoinSet` finishes and returns
        /// its result, or `Err(JoinError::Cancelled)` if it was aborted.
        ///
        /// Returns `None` when the `JoinSet` holds no more tasks.
        ///
        /// If you `.spawn` a new task onto this `JoinSet` while the future
        /// returned from this is currently pending, then this future will
        /// continue to be pending, even if the newly spawned future is already
        /// finished.
        ///
        /// TODO(matheus23): Fix this limitation.
        ///
        /// Current work around is to recreate the `join_next` future when
        /// you newly spawned a task onto it. This seems to be the usual way
        /// the `JoinSet` is used *most of the time* in the iroh codebase anyways.
        pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
            futures_lite::future::poll_fn(|cx| {
                let ret = self.handles.poll_next(cx);
                // clean up handles that are either cancelled or have finished
                self.to_cancel.retain(JoinHandle::is_running);
                ret
            })
            .await
        }

        /// Returns whether there's any tasks that are either still running or
        /// have pending results in this `JoinSet`.
        pub fn is_empty(&self) -> bool {
            self.handles.is_empty()
        }

        /// Returns the amount of tasks that are either still running or have
        /// pending results in this `JoinSet`.
        pub fn len(&self) -> usize {
            self.handles.len()
        }

        /// Waits for all tasks to finish and returns their outputs in the
        /// order the tasks completed.
        ///
        /// # Panics
        ///
        /// Panics if any task returns a `JoinError` (i.e. it was aborted).
        pub async fn join_all(mut self) -> Vec<T> {
            let mut output = Vec::new();
            while let Some(res) = self.join_next().await {
                match res {
                    Ok(t) => output.push(t),
                    Err(err) => panic!("{err}"),
                }
            }
            output
        }

        /// Aborts all tasks and then waits for them to finish, discarding
        /// their results.
        pub async fn shutdown(&mut self) {
            self.abort_all();
            while let Some(_res) = self.join_next().await {}
        }
    }

    impl<T> Drop for JoinSet<T> {
        fn drop(&mut self) {
            self.abort_all()
        }
    }

    /// A handle to a spawned task.
    pub struct JoinHandle<T> {
        // Using SendWrapper here is safe as long as you keep all of your
        // work on the main UI worker in the browser.
        // The only exception to that being the case would be if our user
        // would use multiple Wasm instances with a single SharedArrayBuffer,
        // put the instances on different Web Workers and finally shared
        // the JoinHandle across the Web Worker boundary.
        // In that case, using the JoinHandle would panic.
        task: SendWrapper<Rc<RefCell<Task<T>>>>,
    }

    /// State shared between the [`JoinHandle`]s of a task and the
    /// [`SpawnFuture`] that drives it.
    struct Task<T> {
        /// Set by `abort`; the spawned future is dropped at its next poll.
        cancelled: bool,
        /// Set once the spawned future returned `Poll::Ready`.
        completed: bool,
        /// Waker of whoever is awaiting the `JoinHandle`.
        waker_handler: Option<Waker>,
        /// Waker of the `spawn_local` task driving the future.
        waker_spawn_fn: Option<Waker>,
        /// The task's output, until a `JoinHandle` takes it.
        result: Option<T>,
    }

    impl<T> Task<T> {
        /// Marks the task as cancelled and wakes both sides.
        ///
        /// No-op if the task was already cancelled or has already completed:
        /// like tokio, aborting a finished task must not discard its output.
        fn cancel(&mut self) {
            if !self.cancelled && !self.completed {
                self.cancelled = true;
                self.wake();
            }
        }

        /// Stores the task's output and wakes both sides.
        fn complete(&mut self, value: T) {
            self.result = Some(value);
            self.completed = true;
            self.wake();
        }

        fn wake(&mut self) {
            if let Some(waker) = self.waker_handler.take() {
                waker.wake();
            }
            if let Some(waker) = self.waker_spawn_fn.take() {
                waker.wake();
            }
        }

        fn register_handler(&mut self, cx: &mut Context<'_>) {
            match self.waker_handler {
                // clone_from can be marginally faster in some cases
                Some(ref mut waker) => waker.clone_from(cx.waker()),
                None => self.waker_handler = Some(cx.waker().clone()),
            }
        }

        fn register_spawn_fn(&mut self, cx: &mut Context<'_>) {
            match self.waker_spawn_fn {
                // clone_from can be marginally faster in some cases
                Some(ref mut waker) => waker.clone_from(cx.waker()),
                None => self.waker_spawn_fn = Some(cx.waker().clone()),
            }
        }
    }

    impl<T> Debug for JoinHandle<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            if self.task.valid() {
                let task = self.task.borrow();
                let cancelled = task.cancelled;
                let completed = task.completed;
                f.debug_struct("JoinHandle")
                    .field("cancelled", &cancelled)
                    .field("completed", &completed)
                    .finish()
            } else {
                f.debug_tuple("JoinHandle")
                    .field(&format_args!("<other thread>"))
                    .finish()
            }
        }
    }

    impl<T> JoinHandle<T> {
        fn new() -> Self {
            Self {
                task: SendWrapper::new(Rc::new(RefCell::new(Task {
                    cancelled: false,
                    completed: false,
                    waker_handler: None,
                    waker_spawn_fn: None,
                    result: None,
                }))),
            }
        }

        /// Creates another handle to the same task.
        ///
        /// Kept private on purpose: at most one handle per task may be
        /// awaited, because `Task` stores only a single awaiting waker and
        /// the output can be taken only once.
        fn share(&self) -> Self {
            Self {
                task: self.task.clone(),
            }
        }

        /// Aborts this task.
        ///
        /// Has no effect if the task has already finished; awaiting this
        /// handle then still yields the task's output.
        pub fn abort(&self) {
            self.task.borrow_mut().cancel();
        }

        /// Returns `true` while the task is neither cancelled nor completed.
        fn is_running(&self) -> bool {
            let task = self.task.borrow();
            !task.cancelled && !task.completed
        }
    }

    /// An error that can occur when waiting for the completion of a task.
    #[derive(derive_more::Display, Debug, Clone, Copy)]
    pub enum JoinError {
        /// The error that's returned when the task that's being waited on
        /// has been cancelled.
        #[display("task was cancelled")]
        Cancelled,
    }

    impl std::error::Error for JoinError {}

    impl JoinError {
        /// Returns whether this join error is due to cancellation.
        ///
        /// Always true in this Wasm implementation, because we don't
        /// unwind panics in tasks.
        /// All panics just happen on the main thread anyway.
        pub fn is_cancelled(&self) -> bool {
            matches!(self, Self::Cancelled)
        }

        /// Returns whether this is a panic. Always `false` in Wasm,
        /// because when a task panics, it's not unwound, instead it
        /// panics directly to the main thread.
        pub fn is_panic(&self) -> bool {
            false
        }
    }

    impl<T> Future for JoinHandle<T> {
        type Output = Result<T, JoinError>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let mut task = self.task.borrow_mut();

            // A stored result wins over the cancellation flag: a task that was
            // aborted during its final poll still finished, and its output is
            // kept (this matches tokio's behavior).
            if let Some(result) = task.result.take() {
                return Poll::Ready(Ok(result));
            }

            if task.cancelled {
                return Poll::Ready(Err(JoinError::Cancelled));
            }

            task.register_handler(cx);
            Poll::Pending
        }
    }

    /// The future handed to `wasm_bindgen_futures::spawn_local`: it drives
    /// the user's future and reports its outcome into the shared `Task`.
    #[pin_project::pin_project]
    struct SpawnFuture<Fut: Future<Output = T>, T> {
        handle: JoinHandle<T>,
        #[pin]
        fut: Fut,
    }

    impl<Fut: Future<Output = T>, T> Future for SpawnFuture<Fut, T> {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.project();

            // Only take short-lived borrows of the shared task state. Holding a
            // borrow across `fut.poll` would panic with a `BorrowMutError` if
            // the future aborted or inspected its own task while running.
            if this.handle.task.borrow().cancelled {
                return Poll::Ready(());
            }

            let polled = this.fut.poll(cx);

            let mut task = this.handle.task.borrow_mut();
            match polled {
                Poll::Ready(value) => {
                    task.complete(value);
                    Poll::Ready(())
                }
                // The task may have been aborted while it was being polled.
                // On a first poll no spawn waker was registered yet, so no one
                // would ever wake us again; stop here so the future gets dropped.
                Poll::Pending if task.cancelled => Poll::Ready(()),
                Poll::Pending => {
                    task.register_spawn_fn(cx);
                    Poll::Pending
                }
            }
        }
    }

    /// Similar to a `JoinHandle`, except it automatically aborts
    /// the task when it's dropped.
    #[pin_project::pin_project(PinnedDrop)]
    #[derive(derive_more::Debug)]
    #[debug("AbortOnDropHandle")]
    pub struct AbortOnDropHandle<T>(#[pin] JoinHandle<T>);

    #[pin_project::pinned_drop]
    impl<T> PinnedDrop for AbortOnDropHandle<T> {
        fn drop(self: Pin<&mut Self>) {
            self.0.abort();
        }
    }

    impl<T> Future for AbortOnDropHandle<T> {
        type Output = <JoinHandle<T> as Future>::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.project().0.poll(cx)
        }
    }

    impl<T> AbortOnDropHandle<T> {
        /// Converts a `JoinHandle` into one that aborts on drop.
        pub fn new(task: JoinHandle<T>) -> Self {
            Self(task)
        }
    }

    /// Spawns a future as a task in the browser runtime.
    ///
    /// This is powered by `wasm_bindgen_futures`.
    pub fn spawn<T: 'static>(fut: impl IntoFuture<Output = T> + 'static) -> JoinHandle<T> {
        let handle = JoinHandle::new();

        wasm_bindgen_futures::spawn_local(SpawnFuture {
            handle: handle.share(),
            fut: fut.into_future(),
        });

        handle
    }
}
373
#[cfg(test)]
mod test {
    // TODO(matheus23): Test wasm shims using wasm-bindgen-test
    // The shims in `mod wasm` are only compiled under `cfg(wasm_browser)`,
    // so ordinary native `#[test]`s in this module cannot exercise them.
}