1#[cfg(not(wasm_browser))]
5pub use tokio::spawn;
6#[cfg(not(wasm_browser))]
7pub use tokio::task::{AbortHandle, Id, JoinError, JoinHandle, JoinSet};
8#[cfg(not(wasm_browser))]
9pub use tokio_util::task::AbortOnDropHandle;
10#[cfg(wasm_browser)]
11pub use wasm::*;
12
13#[cfg(wasm_browser)]
14mod wasm {
15 use std::{
16 cell::RefCell,
17 fmt::{self, Debug},
18 future::{Future, IntoFuture},
19 pin::Pin,
20 rc::Rc,
21 sync::Mutex,
22 task::{Context, Poll, Waker},
23 };
24
25 use futures_lite::{stream::StreamExt, FutureExt};
26 use send_wrapper::SendWrapper;
27
28 static TASK_ID_COUNTER: Mutex<u64> = Mutex::new(0);
29
30 fn next_task_id() -> u64 {
31 let mut counter = TASK_ID_COUNTER.lock().unwrap();
32 *counter += 1;
33 *counter
34 }
35
36 #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, derive_more::Display)]
38 pub struct Id(u64);
39
40 pub struct JoinSet<T> {
45 handles: futures_buffered::FuturesUnordered<JoinHandleWithId<T>>,
46 to_cancel: Vec<JoinHandle<T>>,
48 }
49
50 impl<T> Debug for JoinSet<T> {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.debug_struct("JoinSet").field("len", &self.len()).finish()
53 }
54 }
55
56 impl<T> Default for JoinSet<T> {
57 fn default() -> Self {
58 Self::new()
59 }
60 }
61
62 impl<T> JoinSet<T> {
63 pub fn new() -> Self {
65 Self {
66 handles: futures_buffered::FuturesUnordered::new(),
67 to_cancel: Vec::new(),
68 }
69 }
70
71 pub fn spawn(&mut self, fut: impl IntoFuture<Output = T> + 'static) -> AbortHandle
73 where
74 T: 'static,
75 {
76 let handle = JoinHandle::new();
77 let state = handle.task.state.clone();
78 let handle_for_spawn = JoinHandle {
79 task: handle.task.clone(),
80 };
81 let handle_for_cancel = JoinHandle {
82 task: handle.task.clone(),
83 };
84
85 wasm_bindgen_futures::spawn_local(SpawnFuture {
86 handle: handle_for_spawn,
87 fut: fut.into_future(),
88 });
89
90 self.handles.push(JoinHandleWithId(handle));
91 self.to_cancel.push(handle_for_cancel);
92 AbortHandle { state }
93 }
94
95 pub fn abort_all(&self) {
97 self.to_cancel.iter().for_each(JoinHandle::abort);
98 }
99
100 pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
113 self.join_next_with_id()
114 .await
115 .map(|ret| ret.map(|(_id, out)| out))
116 }
117
118 pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
129 futures_lite::future::poll_fn(|cx| {
130 let ret = self.handles.poll_next(cx);
131 self.to_cancel.retain(JoinHandle::is_running);
133 ret
134 })
135 .await
136 }
137
138 pub fn is_empty(&self) -> bool {
141 self.handles.is_empty()
142 }
143
144 pub fn len(&self) -> usize {
147 self.handles.len()
148 }
149
150 pub async fn join_all(mut self) -> Vec<T> {
153 let mut output = Vec::new();
154 while let Some(res) = self.join_next().await {
155 match res {
156 Ok(t) => output.push(t),
157 Err(err) => panic!("{err}"),
158 }
159 }
160 output
161 }
162
163 pub async fn shutdown(&mut self) {
165 self.abort_all();
166 while let Some(_res) = self.join_next().await {}
167 }
168 }
169
170 impl<T> Drop for JoinSet<T> {
171 fn drop(&mut self) {
172 self.abort_all()
173 }
174 }
175
176 pub struct JoinHandle<T> {
178 task: Task<T>,
179 }
180
181 struct Task<T> {
182 state: SendWrapper<Rc<RefCell<State>>>,
190 result: SendWrapper<Rc<RefCell<Option<T>>>>,
191 }
192
193 impl<T> Clone for Task<T> {
194 fn clone(&self) -> Self {
195 Self {
196 state: self.state.clone(),
197 result: self.result.clone(),
198 }
199 }
200 }
201
202 #[derive(Debug)]
203 struct State {
204 id: Id,
205 cancelled: bool,
206 completed: bool,
207 waker_handler: Option<Waker>,
208 waker_spawn_fn: Option<Waker>,
209 }
210
211 impl State {
212 fn cancel(&mut self) {
213 if !self.cancelled {
214 self.cancelled = true;
215 self.wake();
216 }
217 }
218
219 fn complete(&mut self) {
220 self.completed = true;
221 self.wake();
222 }
223
224 fn is_complete(&self) -> bool {
225 self.completed || self.cancelled
226 }
227
228 fn wake(&mut self) {
229 if let Some(waker) = self.waker_handler.take() {
230 waker.wake();
231 }
232 if let Some(waker) = self.waker_spawn_fn.take() {
233 waker.wake();
234 }
235 }
236
237 fn register_handler(&mut self, cx: &mut Context<'_>) {
238 match self.waker_handler {
239 Some(ref mut waker) => waker.clone_from(cx.waker()),
241 None => self.waker_handler = Some(cx.waker().clone()),
242 }
243 }
244
245 fn register_spawn_fn(&mut self, cx: &mut Context<'_>) {
246 match self.waker_spawn_fn {
247 Some(ref mut waker) => waker.clone_from(cx.waker()),
249 None => self.waker_spawn_fn = Some(cx.waker().clone()),
250 }
251 }
252 }
253
254 impl<T> Debug for JoinHandle<T> {
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 if self.task.state.valid() {
257 let state = self.task.state.borrow();
258 f.debug_struct("JoinHandle")
259 .field("id", &state.id)
260 .field("cancelled", &state.cancelled)
261 .field("completed", &state.completed)
262 .finish()
263 } else {
264 f.debug_tuple("JoinHandle")
265 .field(&format_args!("<other thread>"))
266 .finish()
267 }
268 }
269 }
270
271 impl<T> JoinHandle<T> {
272 fn new() -> Self {
273 Self {
274 task: Task {
275 state: SendWrapper::new(Rc::new(RefCell::new(State {
276 cancelled: false,
277 completed: false,
278 waker_handler: None,
279 waker_spawn_fn: None,
280 id: Id(next_task_id()),
281 }))),
282 result: SendWrapper::new(Rc::new(RefCell::new(None))),
283 },
284 }
285 }
286
287 pub fn abort(&self) {
289 self.task.state.borrow_mut().cancel();
290 }
291
292 pub fn abort_handle(&self) -> AbortHandle {
300 AbortHandle {
301 state: self.task.state.clone(),
302 }
303 }
304
305 pub fn id(&self) -> Id {
310 let state = self.task.state.borrow();
311 state.id
312 }
313
314 pub fn is_finished(&self) -> bool {
316 let state = self.task.state.borrow();
317 state.is_complete()
318 }
319
320 fn is_running(&self) -> bool {
321 !self.is_finished()
322 }
323 }
324
325 #[derive(derive_more::Display, Debug, Clone, Copy)]
327 #[display("{cause}")]
328 pub struct JoinError {
329 cause: JoinErrorCause,
330 id: Id,
331 }
332
333 #[derive(derive_more::Display, Debug, Clone, Copy)]
334 enum JoinErrorCause {
335 #[display("task was cancelled")]
338 Cancelled,
339 }
340
341 impl std::error::Error for JoinError {}
342
343 impl JoinError {
344 pub fn is_cancelled(&self) -> bool {
350 matches!(self.cause, JoinErrorCause::Cancelled)
351 }
352
353 pub fn is_panic(&self) -> bool {
357 false
358 }
359
360 pub fn id(&self) -> Id {
362 self.id
363 }
364 }
365
366 impl<T> Future for JoinHandle<T> {
367 type Output = Result<T, JoinError>;
368
369 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
370 let mut state = self.task.state.borrow_mut();
371 if state.cancelled {
372 return Poll::Ready(Err(JoinError {
373 cause: JoinErrorCause::Cancelled,
374 id: state.id,
375 }));
376 }
377
378 let mut result = self.task.result.borrow_mut();
379 if let Some(result) = result.take() {
380 return Poll::Ready(Ok(result));
381 }
382
383 state.register_handler(cx);
384 Poll::Pending
385 }
386 }
387
388 struct JoinHandleWithId<T>(JoinHandle<T>);
389
390 impl<T> Future for JoinHandleWithId<T> {
391 type Output = Result<(Id, T), JoinError>;
392
393 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
394 match self.0.poll(cx) {
395 Poll::Ready(out) => Poll::Ready(out.map(|out| (self.0.id(), out))),
396 Poll::Pending => Poll::Pending,
397 }
398 }
399 }
400
401 #[pin_project::pin_project]
402 struct SpawnFuture<Fut: Future<Output = T>, T> {
403 handle: JoinHandle<T>,
404 #[pin]
405 fut: Fut,
406 }
407
408 impl<Fut: Future<Output = T>, T> Future for SpawnFuture<Fut, T> {
409 type Output = ();
410
411 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
412 let this = self.project();
413 let mut state = this.handle.task.state.borrow_mut();
414
415 if state.cancelled {
416 return Poll::Ready(());
417 }
418
419 match this.fut.poll(cx) {
420 Poll::Ready(value) => {
421 let _ = this.handle.task.result.borrow_mut().insert(value);
422 state.complete();
423 Poll::Ready(())
424 }
425 Poll::Pending => {
426 state.register_spawn_fn(cx);
427 Poll::Pending
428 }
429 }
430 }
431 }
432
433 #[derive(Clone)]
435 pub struct AbortHandle {
436 state: SendWrapper<Rc<RefCell<State>>>,
437 }
438
439 impl Debug for AbortHandle {
440 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
441 if self.state.valid() {
442 let state = self.state.borrow();
443 f.debug_struct("AbortHandle")
444 .field("id", &state.id)
445 .field("cancelled", &state.cancelled)
446 .field("completed", &state.completed)
447 .finish()
448 } else {
449 f.debug_tuple("AbortHandle")
450 .field(&format_args!("<other thread>"))
451 .finish()
452 }
453 }
454 }
455
456 impl AbortHandle {
457 pub fn abort(&self) {
459 self.state.borrow_mut().cancel();
460 }
461
462 pub fn id(&self) -> Id {
467 self.state.borrow().id
468 }
469
470 pub fn is_finished(&self) -> bool {
472 let state = self.state.borrow();
473 state.cancelled && state.completed
474 }
475 }
476
477 #[pin_project::pin_project(PinnedDrop)]
480 #[derive(derive_more::Debug, derive_more::Deref)]
481 #[debug("AbortOnDropHandle")]
482 #[must_use = "Dropping the handle aborts the task immediately"]
483 pub struct AbortOnDropHandle<T>(#[pin] JoinHandle<T>);
484
485 #[pin_project::pinned_drop]
486 impl<T> PinnedDrop for AbortOnDropHandle<T> {
487 fn drop(self: Pin<&mut Self>) {
488 self.0.abort();
489 }
490 }
491
492 impl<T> Future for AbortOnDropHandle<T> {
493 type Output = <JoinHandle<T> as Future>::Output;
494
495 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
496 self.project().0.poll(cx)
497 }
498 }
499
500 impl<T> AbortOnDropHandle<T> {
501 pub fn new(task: JoinHandle<T>) -> Self {
503 Self(task)
504 }
505
506 pub fn abort_handle(&self) -> AbortHandle {
509 self.0.abort_handle()
510 }
511
512 pub fn abort(&self) {
515 self.0.abort()
516 }
517
518 pub fn is_finished(&self) -> bool {
521 self.0.is_finished()
522 }
523 }
524
525 pub fn spawn<T: 'static>(fut: impl IntoFuture<Output = T> + 'static) -> JoinHandle<T> {
529 let handle = JoinHandle::new();
530
531 wasm_bindgen_futures::spawn_local(SpawnFuture {
532 handle: JoinHandle {
533 task: handle.task.clone(),
534 },
535 fut: fut.into_future(),
536 });
537
538 handle
539 }
540}
541
#[cfg(test)]
mod test {
    use std::time::Duration;

    #[cfg(not(wasm_browser))]
    use tokio::test;
    #[cfg(wasm_browser)]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use crate::task;

    /// Aborting one task makes it resolve to a cancellation error and leaves the
    /// other task untouched.
    #[test]
    async fn task_abort() {
        let nap = || async {
            crate::time::sleep(Duration::from_millis(10)).await;
        };
        let first = task::spawn(nap());
        let second = task::spawn(nap());
        assert!(first.id() != second.id());

        first.abort();
        let aborted = first.await;
        assert!(aborted.err().unwrap().is_cancelled());
        assert!(second.await.is_ok());
    }

    /// Aborting one of two `JoinSet` tasks yields exactly one cancellation error and
    /// exactly one successful output.
    #[test]
    async fn join_set_abort() {
        let fut = || async { 22 };
        let mut set = task::JoinSet::new();
        let kept = set.spawn(fut());
        let dropped = set.spawn(fut());
        assert!(kept.id() != dropped.id());
        dropped.abort();

        let mut saw_cancelled = false;
        let mut saw_output = false;
        while let Some(joined) = set.join_next_with_id().await {
            match joined {
                Err(err) if !saw_cancelled => {
                    assert!(err.is_cancelled());
                    saw_cancelled = true;
                }
                Ok((id, out)) if !saw_output => {
                    assert_eq!(id, kept.id());
                    assert_eq!(out, 22);
                    saw_output = true;
                }
                // A second result of the same kind is unexpected.
                _ => panic!(),
            }
        }
        assert!(saw_cancelled);
        assert!(saw_output);
    }
}