n0_future/
task.rs

//! Async Rust task spawning and utilities that work both natively (using tokio) and in
//! browsers (using wasm-bindgen-futures).
3
4#[cfg(not(wasm_browser))]
5pub use tokio::spawn;
6#[cfg(not(wasm_browser))]
7pub use tokio::task::{JoinError, JoinHandle, JoinSet};
8#[cfg(not(wasm_browser))]
9pub use tokio_util::task::AbortOnDropHandle;
10#[cfg(wasm_browser)]
11pub use wasm::*;
12
13#[cfg(wasm_browser)]
14mod wasm {
15    use std::{
16        cell::RefCell,
17        fmt::Debug,
18        future::{Future, IntoFuture},
19        pin::Pin,
20        rc::Rc,
21        task::{Context, Poll, Waker},
22    };
23
24    use futures_lite::stream::StreamExt;
25    use send_wrapper::SendWrapper;
26
27    /// Wasm shim for tokio's `JoinSet`.
28    ///
29    /// Uses a `futures_buffered::FuturesUnordered` queue of
30    /// `JoinHandle`s inside.
31    pub struct JoinSet<T> {
32        handles: futures_buffered::FuturesUnordered<JoinHandle<T>>,
33        // We need to keep a second list of JoinHandles so we can access them for cancellation
34        to_cancel: Vec<JoinHandle<T>>,
35    }
36
37    impl<T> std::fmt::Debug for JoinSet<T> {
38        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39            f.debug_struct("JoinSet").field("len", &self.len()).finish()
40        }
41    }
42
43    impl<T> Default for JoinSet<T> {
44        fn default() -> Self {
45            Self::new()
46        }
47    }
48
49    impl<T> JoinSet<T> {
50        /// Creates a new, empty `JoinSet`
51        pub fn new() -> Self {
52            Self {
53                handles: futures_buffered::FuturesUnordered::new(),
54                to_cancel: Vec::new(),
55            }
56        }
57
58        /// Spawns a task into this `JoinSet`.
59        ///
60        /// (Doesn't return an `AbortHandle` unlike the original `tokio::task::JoinSet` yet.)
61        pub fn spawn(&mut self, fut: impl IntoFuture<Output = T> + 'static)
62        where
63            T: 'static,
64        {
65            let handle = JoinHandle::new();
66            let handle_for_spawn = JoinHandle {
67                task: handle.task.clone(),
68            };
69            let handle_for_cancel = JoinHandle {
70                task: handle.task.clone(),
71            };
72
73            wasm_bindgen_futures::spawn_local(SpawnFuture {
74                handle: handle_for_spawn,
75                fut: fut.into_future(),
76            });
77
78            self.handles.push(handle);
79            self.to_cancel.push(handle_for_cancel);
80        }
81
82        /// Aborts all tasks inside this `JoinSet`
83        pub fn abort_all(&self) {
84            self.to_cancel.iter().for_each(JoinHandle::abort);
85        }
86
87        /// Awaits the next `JoinSet`'s completion.
88        ///
89        /// If you `.spawn` a new task onto this `JoinSet` while the future
90        /// returned from this is currently pending, then this future will
91        /// continue to be pending, even if the newly spawned future is already
92        /// finished.
93        ///
94        /// TODO(matheus23): Fix this limitation.
95        ///
96        /// Current work around is to recreate the `join_next` future when
97        /// you newly spawned a task onto it. This seems to be the usual way
98        /// the `JoinSet` is used *most of the time* in the iroh codebase anyways.
99        pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
100            futures_lite::future::poll_fn(|cx| {
101                let ret = self.handles.poll_next(cx);
102                // clean up handles that are either cancelled or have finished
103                self.to_cancel.retain(JoinHandle::is_running);
104                ret
105            })
106            .await
107        }
108
109        /// Returns whether there's any tasks that are either still running or
110        /// have pending results in this `JoinSet`.
111        pub fn is_empty(&self) -> bool {
112            self.handles.is_empty()
113        }
114
115        /// Returns the amount of tasks that are either still running or have
116        /// pending results in this `JoinSet`.
117        pub fn len(&self) -> usize {
118            self.handles.len()
119        }
120
121        /// Waits for all tasks to finish. If any of them returns a JoinError,
122        /// this will panic.
123        pub async fn join_all(mut self) -> Vec<T> {
124            let mut output = Vec::new();
125            while let Some(res) = self.join_next().await {
126                match res {
127                    Ok(t) => output.push(t),
128                    Err(err) => panic!("{err}"),
129                }
130            }
131            output
132        }
133
134        /// Aborts all tasks and then waits for them to finish, ignoring panics.
135        pub async fn shutdown(&mut self) {
136            self.abort_all();
137            while let Some(_res) = self.join_next().await {}
138        }
139    }
140
141    impl<T> Drop for JoinSet<T> {
142        fn drop(&mut self) {
143            self.abort_all()
144        }
145    }
146
147    /// A handle to a spawned task.
148    pub struct JoinHandle<T> {
149        // Using SendWrapper here is safe as long as you keep all of your
150        // work on the main UI worker in the browser.
151        // The only exception to that being the case would be if our user
152        // would use multiple Wasm instances with a single SharedArrayBuffer,
153        // put the instances on different Web Workers and finally shared
154        // the JoinHandle across the Web Worker boundary.
155        // In that case, using the JoinHandle would panic.
156        task: SendWrapper<Rc<RefCell<Task<T>>>>,
157    }
158
159    struct Task<T> {
160        cancelled: bool,
161        completed: bool,
162        waker_handler: Option<Waker>,
163        waker_spawn_fn: Option<Waker>,
164        result: Option<T>,
165    }
166
167    impl<T> Task<T> {
168        fn cancel(&mut self) {
169            if !self.cancelled {
170                self.cancelled = true;
171                self.wake();
172            }
173        }
174
175        fn complete(&mut self, value: T) {
176            self.result = Some(value);
177            self.completed = true;
178            self.wake();
179        }
180
181        fn wake(&mut self) {
182            if let Some(waker) = self.waker_handler.take() {
183                waker.wake();
184            }
185            if let Some(waker) = self.waker_spawn_fn.take() {
186                waker.wake();
187            }
188        }
189
190        fn register_handler(&mut self, cx: &mut Context<'_>) {
191            match self.waker_handler {
192                // clone_from can be marginally faster in some cases
193                Some(ref mut waker) => waker.clone_from(cx.waker()),
194                None => self.waker_handler = Some(cx.waker().clone()),
195            }
196        }
197
198        fn register_spawn_fn(&mut self, cx: &mut Context<'_>) {
199            match self.waker_spawn_fn {
200                // clone_from can be marginally faster in some cases
201                Some(ref mut waker) => waker.clone_from(cx.waker()),
202                None => self.waker_spawn_fn = Some(cx.waker().clone()),
203            }
204        }
205    }
206
207    impl<T> Debug for JoinHandle<T> {
208        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209            if self.task.valid() {
210                let task = self.task.borrow();
211                let cancelled = task.cancelled;
212                let completed = task.completed;
213                f.debug_struct("JoinHandle")
214                    .field("cancelled", &cancelled)
215                    .field("completed", &completed)
216                    .finish()
217            } else {
218                f.debug_tuple("JoinHandle")
219                    .field(&format_args!("<other thread>"))
220                    .finish()
221            }
222        }
223    }
224
225    impl<T> JoinHandle<T> {
226        fn new() -> Self {
227            Self {
228                task: SendWrapper::new(Rc::new(RefCell::new(Task {
229                    cancelled: false,
230                    completed: false,
231                    waker_handler: None,
232                    waker_spawn_fn: None,
233                    result: None,
234                }))),
235            }
236        }
237
238        /// Aborts this task.
239        pub fn abort(&self) {
240            self.task.borrow_mut().cancel();
241        }
242
243        fn is_running(&self) -> bool {
244            let task = self.task.borrow();
245            !task.cancelled && !task.completed
246        }
247    }
248
249    /// An error that can occur when waiting for the completion of a task.
250    #[derive(derive_more::Display, Debug, Clone, Copy)]
251    pub enum JoinError {
252        /// The error that's returned when the task that's being waited on
253        /// has been cancelled.
254        #[display("task was cancelled")]
255        Cancelled,
256    }
257
258    impl std::error::Error for JoinError {}
259
260    impl JoinError {
261        /// Returns whether this join error is due to cancellation.
262        ///
263        /// Always true in this Wasm implementation, because we don't
264        /// unwind panics in tasks.
265        /// All panics just happen on the main thread anyways.
266        pub fn is_cancelled(&self) -> bool {
267            matches!(self, Self::Cancelled)
268        }
269
270        /// Returns whether this is a panic. Always `false` in Wasm,
271        /// because when a task panics, it's not unwound, instead it
272        /// panics directly to the main thread.
273        pub fn is_panic(&self) -> bool {
274            false
275        }
276    }
277
278    impl<T> Future for JoinHandle<T> {
279        type Output = Result<T, JoinError>;
280
281        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
282            let mut task = self.task.borrow_mut();
283            if task.cancelled {
284                return Poll::Ready(Err(JoinError::Cancelled));
285            }
286
287            if let Some(result) = task.result.take() {
288                return Poll::Ready(Ok(result));
289            }
290
291            task.register_handler(cx);
292            Poll::Pending
293        }
294    }
295
296    #[pin_project::pin_project]
297    struct SpawnFuture<Fut: Future<Output = T>, T> {
298        handle: JoinHandle<T>,
299        #[pin]
300        fut: Fut,
301    }
302
303    impl<Fut: Future<Output = T>, T> Future for SpawnFuture<Fut, T> {
304        type Output = ();
305
306        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
307            let this = self.project();
308            let mut task = this.handle.task.borrow_mut();
309
310            if task.cancelled {
311                return Poll::Ready(());
312            }
313
314            match this.fut.poll(cx) {
315                Poll::Ready(value) => {
316                    task.complete(value);
317                    Poll::Ready(())
318                }
319                Poll::Pending => {
320                    task.register_spawn_fn(cx);
321                    Poll::Pending
322                }
323            }
324        }
325    }
326
327    /// Similar to a `JoinHandle`, except it automatically aborts
328    /// the task when it's dropped.
329    #[pin_project::pin_project(PinnedDrop)]
330    #[derive(derive_more::Debug)]
331    #[debug("AbortOnDropHandle")]
332    pub struct AbortOnDropHandle<T>(#[pin] JoinHandle<T>);
333
334    #[pin_project::pinned_drop]
335    impl<T> PinnedDrop for AbortOnDropHandle<T> {
336        fn drop(self: Pin<&mut Self>) {
337            self.0.abort();
338        }
339    }
340
341    impl<T> Future for AbortOnDropHandle<T> {
342        type Output = <JoinHandle<T> as Future>::Output;
343
344        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
345            self.project().0.poll(cx)
346        }
347    }
348
349    impl<T> AbortOnDropHandle<T> {
350        /// Converts a `JoinHandle` into one that aborts on drop.
351        pub fn new(task: JoinHandle<T>) -> Self {
352            Self(task)
353        }
354    }
355
356    /// Spawns a future as a task in the browser runtime.
357    ///
358    /// This is powered by `wasm_bidngen_futures`.
359    pub fn spawn<T: 'static>(fut: impl IntoFuture<Output = T> + 'static) -> JoinHandle<T> {
360        let handle = JoinHandle::new();
361
362        wasm_bindgen_futures::spawn_local(SpawnFuture {
363            handle: JoinHandle {
364                task: handle.task.clone(),
365            },
366            fut: fut.into_future(),
367        });
368
369        handle
370    }
371}
372
#[cfg(test)]
mod test {
    // TODO(matheus23): Cover the Wasm shims with tests run through wasm-bindgen-test.
}