n0_future/task.rs
//! Async Rust task spawning and utilities that work both natively (using tokio) and in
//! browsers (using wasm-bindgen-futures).
3
4#[cfg(not(wasm_browser))]
5pub use tokio::spawn;
6#[cfg(not(wasm_browser))]
7pub use tokio::task::{AbortHandle, Id, JoinError, JoinHandle, JoinSet};
8#[cfg(not(wasm_browser))]
9pub use tokio_util::task::AbortOnDropHandle;
10#[cfg(wasm_browser)]
11pub use wasm::*;
12
13#[cfg(wasm_browser)]
14mod wasm {
15 use std::{
16 cell::RefCell,
17 fmt::{self, Debug},
18 future::{Future, IntoFuture},
19 pin::Pin,
20 rc::Rc,
21 sync::Mutex,
22 task::{Context, Poll, Waker},
23 };
24
25 use futures_lite::{stream::StreamExt, FutureExt};
26 use send_wrapper::SendWrapper;
27
/// Source of task ids: holds the most recently issued id (`0` before the first task).
static TASK_ID_COUNTER: Mutex<u64> = Mutex::new(0);

/// Reserves and returns a fresh task id; the first call returns `1`, and every call
/// returns one more than the previous one.
///
/// Panics if the counter's mutex is poisoned.
fn next_task_id() -> u64 {
    let mut last_issued = TASK_ID_COUNTER.lock().unwrap();
    let fresh = *last_issued + 1;
    *last_issued = fresh;
    fresh
}
35
36 /// An opaque ID that uniquely identifies a task relative to all other currently running tasks.
37 #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, derive_more::Display)]
38 pub struct Id(u64);
39
40 /// Wasm shim for tokio's `JoinSet`.
41 ///
42 /// Uses a [`futures_buffered::FuturesUnordered`] queue of
43 /// [`JoinHandle`]s inside.
44 pub struct JoinSet<T> {
45 handles: futures_buffered::FuturesUnordered<JoinHandleWithId<T>>,
46 // We need to keep a second list of JoinHandles so we can access them for cancellation
47 to_cancel: Vec<JoinHandle<T>>,
48 }
49
50 impl<T> Debug for JoinSet<T> {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.debug_struct("JoinSet").field("len", &self.len()).finish()
53 }
54 }
55
56 impl<T> Default for JoinSet<T> {
57 fn default() -> Self {
58 Self::new()
59 }
60 }
61
62 impl<T> JoinSet<T> {
63 /// Creates a new, empty `JoinSet`
64 pub fn new() -> Self {
65 Self {
66 handles: futures_buffered::FuturesUnordered::new(),
67 to_cancel: Vec::new(),
68 }
69 }
70
71 /// Spawns a task into this `JoinSet`.
72 pub fn spawn(&mut self, fut: impl IntoFuture<Output = T> + 'static) -> AbortHandle
73 where
74 T: 'static,
75 {
76 let handle = JoinHandle::new();
77 let state = handle.task.state.clone();
78 let handle_for_spawn = JoinHandle {
79 task: handle.task.clone(),
80 };
81 let handle_for_cancel = JoinHandle {
82 task: handle.task.clone(),
83 };
84
85 wasm_bindgen_futures::spawn_local(SpawnFuture {
86 handle: handle_for_spawn,
87 fut: fut.into_future(),
88 });
89
90 self.handles.push(JoinHandleWithId(handle));
91 self.to_cancel.push(handle_for_cancel);
92 AbortHandle { state }
93 }
94
95 /// Alias to [`Self::spawn`].
96 ///
97 /// Mirrors [`tokio::JoinSet::spawn_local](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.spawn_local).
98 /// Because all tasks in WebAssembly are local, this is a simple alias to [`Self::spawn`].
99 pub fn spawn_local(&mut self, fut: impl IntoFuture<Output = T> + 'static) -> AbortHandle
100 where
101 T: 'static,
102 {
103 self.spawn(fut)
104 }
105
106 /// Aborts all tasks inside this `JoinSet`
107 pub fn abort_all(&self) {
108 self.to_cancel.iter().for_each(JoinHandle::abort);
109 }
110
111 /// Awaits the next `JoinSet`'s completion.
112 ///
113 /// If you `.spawn` a new task onto this `JoinSet` while the future
114 /// returned from this is currently pending, then this future will
115 /// continue to be pending, even if the newly spawned future is already
116 /// finished.
117 ///
118 /// TODO(matheus23): Fix this limitation.
119 ///
120 /// Current work around is to recreate the `join_next` future when
121 /// you newly spawned a task onto it. This seems to be the usual way
122 /// the `JoinSet` is used *most of the time* in the iroh codebase anyways.
123 pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
124 self.join_next_with_id()
125 .await
126 .map(|ret| ret.map(|(_id, out)| out))
127 }
128
129 /// Waits until one of the tasks in the set completes and returns its
130 /// output, along with the [task ID] of the completed task.
131 ///
132 /// Returns `None` if the set is empty.
133 ///
134 /// When this method returns an error, then the id of the task that failed can be accessed
135 /// using the [`JoinError::id`] method.
136 ///
137 /// [task ID]: crate::task::Id
138 /// [`JoinError::id`]: fn@crate::task::JoinError::id
139 pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
140 futures_lite::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
141 }
142
143 /// Polls for one of the tasks in the set to complete.
144 ///
145 /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the set.
146 ///
147 /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
148 /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to
149 /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
150 /// scheduled to receive a wakeup.
151 ///
152 /// # Returns
153 ///
154 /// This function returns:
155 ///
156 /// * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is
157 /// available right now.
158 /// * `Poll::Ready(Some(Ok(value)))` if one of the tasks in this `JoinSet` has completed.
159 /// The `value` is the return value of one of the tasks that completed.
160 /// * `Poll::Ready(Some(Err(err)))` if one of the tasks in this `JoinSet` has panicked or been
161 /// aborted. The `err` is the `JoinError` from the panicked/aborted task.
162 /// * `Poll::Ready(None)` if the `JoinSet` is empty.
163 pub fn poll_join_next(
164 &mut self,
165 cx: &mut Context<'_>,
166 ) -> Poll<Option<Result<T, JoinError>>> {
167 match self.poll_join_next_with_id(cx) {
168 Poll::Ready(Some(Ok((_, ret)))) => Poll::Ready(Some(Ok(ret))),
169 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
170 Poll::Ready(None) => Poll::Ready(None),
171 Poll::Pending => Poll::Pending,
172 }
173 }
174
175 /// Polls for one of the tasks in the set to complete.
176 ///
177 /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the set.
178 ///
179 /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
180 /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to
181 /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
182 /// scheduled to receive a wakeup.
183 ///
184 /// # Returns
185 ///
186 /// This function returns:
187 ///
188 /// * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is
189 /// available right now.
190 /// * `Poll::Ready(Some(Ok((id, value))))` if one of the tasks in this `JoinSet` has completed.
191 /// The `value` is the return value of one of the tasks that completed, and
192 /// `id` is the [task ID] of that task.
193 /// * `Poll::Ready(Some(Err(err)))` if one of the tasks in this `JoinSet` has panicked or been
194 /// aborted. The `err` is the `JoinError` from the panicked/aborted task.
195 /// * `Poll::Ready(None)` if the `JoinSet` is empty.
196 ///
197 /// [task ID]: crate::task::Id
198 pub fn poll_join_next_with_id(
199 &mut self,
200 cx: &mut Context<'_>,
201 ) -> Poll<Option<Result<(Id, T), JoinError>>> {
202 let ret = self.handles.poll_next(cx);
203 // clean up handles that are either cancelled or have finished
204 self.to_cancel.retain(JoinHandle::is_running);
205 ret
206 }
207
208 /// Returns whether there's any tasks that are either still running or
209 /// have pending results in this `JoinSet`.
210 pub fn is_empty(&self) -> bool {
211 self.handles.is_empty()
212 }
213
214 /// Returns the amount of tasks that are either still running or have
215 /// pending results in this `JoinSet`.
216 pub fn len(&self) -> usize {
217 self.handles.len()
218 }
219
220 /// Waits for all tasks to finish. If any of them returns a JoinError,
221 /// this will panic.
222 pub async fn join_all(mut self) -> Vec<T> {
223 let mut output = Vec::new();
224 while let Some(res) = self.join_next().await {
225 match res {
226 Ok(t) => output.push(t),
227 Err(err) => panic!("{err}"),
228 }
229 }
230 output
231 }
232
233 /// Aborts all tasks and then waits for them to finish, ignoring panics.
234 pub async fn shutdown(&mut self) {
235 self.abort_all();
236 while let Some(_res) = self.join_next().await {}
237 }
238 }
239
240 impl<T> Drop for JoinSet<T> {
241 fn drop(&mut self) {
242 self.abort_all()
243 }
244 }
245
246 /// A handle to a spawned task.
247 pub struct JoinHandle<T> {
248 task: Task<T>,
249 }
250
251 struct Task<T> {
252 // Using SendWrapper here is safe as long as you keep all of your
253 // work on the main UI worker in the browser.
254 // The only exception to that being the case would be if our user
255 // would use multiple Wasm instances with a single SharedArrayBuffer,
256 // put the instances on different Web Workers and finally shared
257 // the JoinHandle across the Web Worker boundary.
258 // In that case, using the JoinHandle would panic.
259 state: SendWrapper<Rc<RefCell<State>>>,
260 result: SendWrapper<Rc<RefCell<Option<T>>>>,
261 }
262
263 impl<T> Clone for Task<T> {
264 fn clone(&self) -> Self {
265 Self {
266 state: self.state.clone(),
267 result: self.result.clone(),
268 }
269 }
270 }
271
272 #[derive(Debug)]
273 struct State {
274 id: Id,
275 cancelled: bool,
276 completed: bool,
277 waker_handler: Option<Waker>,
278 waker_spawn_fn: Option<Waker>,
279 }
280
281 impl State {
282 fn cancel(&mut self) {
283 if !self.cancelled {
284 self.cancelled = true;
285 self.wake();
286 }
287 }
288
289 fn complete(&mut self) {
290 self.completed = true;
291 self.wake();
292 }
293
294 fn is_complete(&self) -> bool {
295 self.completed || self.cancelled
296 }
297
298 fn wake(&mut self) {
299 if let Some(waker) = self.waker_handler.take() {
300 waker.wake();
301 }
302 if let Some(waker) = self.waker_spawn_fn.take() {
303 waker.wake();
304 }
305 }
306
307 fn register_handler(&mut self, cx: &mut Context<'_>) {
308 match self.waker_handler {
309 // clone_from can be marginally faster in some cases
310 Some(ref mut waker) => waker.clone_from(cx.waker()),
311 None => self.waker_handler = Some(cx.waker().clone()),
312 }
313 }
314
315 fn register_spawn_fn(&mut self, cx: &mut Context<'_>) {
316 match self.waker_spawn_fn {
317 // clone_from can be marginally faster in some cases
318 Some(ref mut waker) => waker.clone_from(cx.waker()),
319 None => self.waker_spawn_fn = Some(cx.waker().clone()),
320 }
321 }
322 }
323
324 impl<T> Debug for JoinHandle<T> {
325 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326 if self.task.state.valid() {
327 let state = self.task.state.borrow();
328 f.debug_struct("JoinHandle")
329 .field("id", &state.id)
330 .field("cancelled", &state.cancelled)
331 .field("completed", &state.completed)
332 .finish()
333 } else {
334 f.debug_tuple("JoinHandle")
335 .field(&format_args!("<other thread>"))
336 .finish()
337 }
338 }
339 }
340
341 impl<T> JoinHandle<T> {
342 fn new() -> Self {
343 Self {
344 task: Task {
345 state: SendWrapper::new(Rc::new(RefCell::new(State {
346 cancelled: false,
347 completed: false,
348 waker_handler: None,
349 waker_spawn_fn: None,
350 id: Id(next_task_id()),
351 }))),
352 result: SendWrapper::new(Rc::new(RefCell::new(None))),
353 },
354 }
355 }
356
357 /// Aborts this task.
358 pub fn abort(&self) {
359 self.task.state.borrow_mut().cancel();
360 }
361
362 /// Returns a new [`AbortHandle`] that can be used to remotely abort this task.
363 ///
364 /// Awaiting a task cancelled by the [`AbortHandle`] might complete as usual if the task was
365 /// already completed at the time it was cancelled, but most likely it
366 /// will fail with a [cancelled] `JoinError`.
367 ///
368 /// [cancelled]: JoinError::is_cancelled
369 pub fn abort_handle(&self) -> AbortHandle {
370 AbortHandle {
371 state: self.task.state.clone(),
372 }
373 }
374
375 /// Returns a [task ID] that uniquely identifies this task relative to other
376 /// currently spawned tasks.
377 ///
378 /// [task ID]: crate::task::Id
379 pub fn id(&self) -> Id {
380 let state = self.task.state.borrow();
381 state.id
382 }
383
384 /// Checks if the task associated with this `JoinHandle` has finished.
385 pub fn is_finished(&self) -> bool {
386 let state = self.task.state.borrow();
387 state.is_complete()
388 }
389
390 fn is_running(&self) -> bool {
391 !self.is_finished()
392 }
393 }
394
395 /// An error that can occur when waiting for the completion of a task.
396 #[derive(derive_more::Display, Debug, Clone, Copy)]
397 #[display("{cause}")]
398 pub struct JoinError {
399 cause: JoinErrorCause,
400 id: Id,
401 }
402
403 #[derive(derive_more::Display, Debug, Clone, Copy)]
404 enum JoinErrorCause {
405 /// The error that's returned when the task that's being waited on
406 /// has been cancelled.
407 #[display("task was cancelled")]
408 Cancelled,
409 }
410
411 impl std::error::Error for JoinError {}
412
413 impl JoinError {
414 /// Returns whether this join error is due to cancellation.
415 ///
416 /// Always true in this Wasm implementation, because we don't
417 /// unwind panics in tasks.
418 /// All panics just happen on the main thread anyways.
419 pub fn is_cancelled(&self) -> bool {
420 matches!(self.cause, JoinErrorCause::Cancelled)
421 }
422
423 /// Returns whether this is a panic. Always `false` in Wasm,
424 /// because when a task panics, it's not unwound, instead it
425 /// panics directly to the main thread.
426 pub fn is_panic(&self) -> bool {
427 false
428 }
429
430 /// Returns the panic, if the task has panicked.
431 ///
432 /// Always returns `Err(self)`, in Wasm, because when a task panics, it's not unwound,
433 /// instead it panics directly to the main thread.
434 pub fn try_into_panic(self) -> Result<Box<dyn std::any::Any + Send + 'static>, JoinError> {
435 Err(self)
436 }
437
438 /// Returns a task ID that identifies the task which errored relative to other currently spawned tasks.
439 pub fn id(&self) -> Id {
440 self.id
441 }
442 }
443
444 impl<T> Future for JoinHandle<T> {
445 type Output = Result<T, JoinError>;
446
447 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
448 let mut state = self.task.state.borrow_mut();
449 if state.cancelled {
450 return Poll::Ready(Err(JoinError {
451 cause: JoinErrorCause::Cancelled,
452 id: state.id,
453 }));
454 }
455
456 let mut result = self.task.result.borrow_mut();
457 if let Some(result) = result.take() {
458 return Poll::Ready(Ok(result));
459 }
460
461 state.register_handler(cx);
462 Poll::Pending
463 }
464 }
465
466 struct JoinHandleWithId<T>(JoinHandle<T>);
467
468 impl<T> Future for JoinHandleWithId<T> {
469 type Output = Result<(Id, T), JoinError>;
470
471 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
472 match self.0.poll(cx) {
473 Poll::Ready(out) => Poll::Ready(out.map(|out| (self.0.id(), out))),
474 Poll::Pending => Poll::Pending,
475 }
476 }
477 }
478
479 #[pin_project::pin_project]
480 struct SpawnFuture<Fut: Future<Output = T>, T> {
481 handle: JoinHandle<T>,
482 #[pin]
483 fut: Fut,
484 }
485
486 impl<Fut: Future<Output = T>, T> Future for SpawnFuture<Fut, T> {
487 type Output = ();
488
489 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
490 let this = self.project();
491 let mut state = this.handle.task.state.borrow_mut();
492
493 if state.cancelled {
494 return Poll::Ready(());
495 }
496
497 match this.fut.poll(cx) {
498 Poll::Ready(value) => {
499 let _ = this.handle.task.result.borrow_mut().insert(value);
500 state.complete();
501 Poll::Ready(())
502 }
503 Poll::Pending => {
504 state.register_spawn_fn(cx);
505 Poll::Pending
506 }
507 }
508 }
509 }
510
511 /// An owned permission to abort a spawned task, without awaiting its completion.
512 #[derive(Clone)]
513 pub struct AbortHandle {
514 state: SendWrapper<Rc<RefCell<State>>>,
515 }
516
517 impl Debug for AbortHandle {
518 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
519 if self.state.valid() {
520 let state = self.state.borrow();
521 f.debug_struct("AbortHandle")
522 .field("id", &state.id)
523 .field("cancelled", &state.cancelled)
524 .field("completed", &state.completed)
525 .finish()
526 } else {
527 f.debug_tuple("AbortHandle")
528 .field(&format_args!("<other thread>"))
529 .finish()
530 }
531 }
532 }
533
534 impl AbortHandle {
535 /// Abort the task associated with the handle.
536 pub fn abort(&self) {
537 self.state.borrow_mut().cancel();
538 }
539
540 /// Returns a [task ID] that uniquely identifies this task relative to other
541 /// currently spawned tasks.
542 ///
543 /// [task ID]: crate::task::Id
544 pub fn id(&self) -> Id {
545 self.state.borrow().id
546 }
547
548 /// Checks if the task associated with this `AbortHandle` has finished.
549 pub fn is_finished(&self) -> bool {
550 let state = self.state.borrow();
551 state.cancelled && state.completed
552 }
553 }
554
555 /// Similar to a `JoinHandle`, except it automatically aborts
556 /// the task when it's dropped.
557 #[pin_project::pin_project(PinnedDrop)]
558 #[derive(derive_more::Debug, derive_more::Deref)]
559 #[debug("AbortOnDropHandle")]
560 #[must_use = "Dropping the handle aborts the task immediately"]
561 pub struct AbortOnDropHandle<T>(#[pin] JoinHandle<T>);
562
563 #[pin_project::pinned_drop]
564 impl<T> PinnedDrop for AbortOnDropHandle<T> {
565 fn drop(self: Pin<&mut Self>) {
566 self.0.abort();
567 }
568 }
569
570 impl<T> Future for AbortOnDropHandle<T> {
571 type Output = <JoinHandle<T> as Future>::Output;
572
573 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
574 self.project().0.poll(cx)
575 }
576 }
577
578 impl<T> AbortOnDropHandle<T> {
579 /// Converts a `JoinHandle` into one that aborts on drop.
580 pub fn new(task: JoinHandle<T>) -> Self {
581 Self(task)
582 }
583
584 /// Returns a new [`AbortHandle`] that can be used to remotely abort this task,
585 /// equivalent to [`JoinHandle::abort_handle`].
586 pub fn abort_handle(&self) -> AbortHandle {
587 self.0.abort_handle()
588 }
589
590 /// Abort the task associated with this handle,
591 /// equivalent to [`JoinHandle::abort`].
592 pub fn abort(&self) {
593 self.0.abort()
594 }
595
596 /// Checks if the task associated with this handle is finished,
597 /// equivalent to [`JoinHandle::is_finished`].
598 pub fn is_finished(&self) -> bool {
599 self.0.is_finished()
600 }
601 }
602
603 /// Spawns a future as a task in the browser runtime.
604 ///
605 /// This is powered by `wasm_bidngen_futures`.
606 pub fn spawn<T: 'static>(fut: impl IntoFuture<Output = T> + 'static) -> JoinHandle<T> {
607 let handle = JoinHandle::new();
608
609 wasm_bindgen_futures::spawn_local(SpawnFuture {
610 handle: JoinHandle {
611 task: handle.task.clone(),
612 },
613 fut: fut.into_future(),
614 });
615
616 handle
617 }
618}
619
620#[cfg(test)]
621mod test {
622 use std::time::Duration;
623
624 #[cfg(not(wasm_browser))]
625 use tokio::test;
626 #[cfg(wasm_browser)]
627 use wasm_bindgen_test::wasm_bindgen_test as test;
628
629 use crate::task;
630
/// Aborting one task yields a cancellation error for it and leaves the other task intact.
#[test]
async fn task_abort() {
    let sleepy = || async {
        crate::time::sleep(Duration::from_millis(10)).await;
    };
    let first = task::spawn(sleepy());
    let second = task::spawn(sleepy());
    assert!(first.id() != second.id());

    first.abort();
    let err = first.await.err().unwrap();
    assert!(err.is_cancelled());
    assert!(second.await.is_ok());
}
645
/// In a `JoinSet` with one aborted and one normal task, `join_next_with_id` yields exactly
/// one cancellation error and exactly one successful output.
#[test]
async fn join_set_abort() {
    let make = || async { 22 };
    let mut set = task::JoinSet::new();
    let kept = set.spawn(make());
    let aborted = set.spawn(make());
    assert!(kept.id() != aborted.id());
    aborted.abort();

    let mut seen_err = false;
    let mut seen_ok = false;
    while let Some(ret) = set.join_next_with_id().await {
        match ret {
            Err(_) if seen_err => panic!(),
            Err(err) => {
                assert!(err.is_cancelled());
                seen_err = true;
            }
            Ok(_) if seen_ok => panic!(),
            Ok((id, out)) => {
                assert_eq!(id, kept.id());
                assert_eq!(out, 22);
                seen_ok = true;
            }
        }
    }
    assert!(seen_err);
    assert!(seen_ok);
}
681}