1#[cfg(not(wasm_browser))]
5pub use tokio::spawn;
6#[cfg(not(wasm_browser))]
7pub use tokio::task::{AbortHandle, Id, JoinError, JoinHandle, JoinSet};
8#[cfg(not(wasm_browser))]
9pub use tokio_util::task::AbortOnDropHandle;
10#[cfg(wasm_browser)]
11pub use wasm::*;
12
13#[cfg(wasm_browser)]
14mod wasm {
15 use std::{
16 cell::RefCell,
17 fmt::{self, Debug},
18 future::{Future, IntoFuture},
19 pin::Pin,
20 rc::Rc,
21 sync::Mutex,
22 task::{Context, Poll, Waker},
23 };
24
25 use futures_lite::{stream::StreamExt, FutureExt};
26 use send_wrapper::SendWrapper;
27
/// Global source of task ids; holds the most recently issued value.
static TASK_ID_COUNTER: Mutex<u64> = Mutex::new(0);

/// Allocates a new numeric task id.
///
/// Ids start at 1 and every call returns exactly one more than the previous call.
fn next_task_id() -> u64 {
    let mut guard = TASK_ID_COUNTER.lock().unwrap();
    let id = *guard + 1;
    *guard = id;
    id
}
35
36 #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, derive_more::Display)]
38 pub struct Id(u64);
39
40 pub struct JoinSet<T> {
45 handles: futures_buffered::FuturesUnordered<JoinHandleWithId<T>>,
46 to_cancel: Vec<JoinHandle<T>>,
48 }
49
50 impl<T> Debug for JoinSet<T> {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.debug_struct("JoinSet").field("len", &self.len()).finish()
53 }
54 }
55
56 impl<T> Default for JoinSet<T> {
57 fn default() -> Self {
58 Self::new()
59 }
60 }
61
62 impl<T> JoinSet<T> {
63 pub fn new() -> Self {
65 Self {
66 handles: futures_buffered::FuturesUnordered::new(),
67 to_cancel: Vec::new(),
68 }
69 }
70
71 pub fn spawn(&mut self, fut: impl IntoFuture<Output = T> + 'static) -> AbortHandle
73 where
74 T: 'static,
75 {
76 let handle = JoinHandle::new();
77 let state = handle.task.state.clone();
78 let handle_for_spawn = JoinHandle {
79 task: handle.task.clone(),
80 };
81 let handle_for_cancel = JoinHandle {
82 task: handle.task.clone(),
83 };
84
85 wasm_bindgen_futures::spawn_local(SpawnFuture {
86 handle: handle_for_spawn,
87 fut: fut.into_future(),
88 });
89
90 self.handles.push(JoinHandleWithId(handle));
91 self.to_cancel.push(handle_for_cancel);
92 AbortHandle { state }
93 }
94
95 pub fn spawn_local(&mut self, fut: impl IntoFuture<Output = T> + 'static) -> AbortHandle
100 where
101 T: 'static,
102 {
103 self.spawn(fut)
104 }
105
106 pub fn abort_all(&self) {
108 self.to_cancel.iter().for_each(JoinHandle::abort);
109 }
110
111 pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
124 self.join_next_with_id()
125 .await
126 .map(|ret| ret.map(|(_id, out)| out))
127 }
128
129 pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
140 futures_lite::future::poll_fn(|cx| {
141 let ret = self.handles.poll_next(cx);
142 self.to_cancel.retain(JoinHandle::is_running);
144 ret
145 })
146 .await
147 }
148
149 pub fn is_empty(&self) -> bool {
152 self.handles.is_empty()
153 }
154
155 pub fn len(&self) -> usize {
158 self.handles.len()
159 }
160
161 pub async fn join_all(mut self) -> Vec<T> {
164 let mut output = Vec::new();
165 while let Some(res) = self.join_next().await {
166 match res {
167 Ok(t) => output.push(t),
168 Err(err) => panic!("{err}"),
169 }
170 }
171 output
172 }
173
174 pub async fn shutdown(&mut self) {
176 self.abort_all();
177 while let Some(_res) = self.join_next().await {}
178 }
179 }
180
181 impl<T> Drop for JoinSet<T> {
182 fn drop(&mut self) {
183 self.abort_all()
184 }
185 }
186
187 pub struct JoinHandle<T> {
189 task: Task<T>,
190 }
191
192 struct Task<T> {
193 state: SendWrapper<Rc<RefCell<State>>>,
201 result: SendWrapper<Rc<RefCell<Option<T>>>>,
202 }
203
204 impl<T> Clone for Task<T> {
205 fn clone(&self) -> Self {
206 Self {
207 state: self.state.clone(),
208 result: self.result.clone(),
209 }
210 }
211 }
212
213 #[derive(Debug)]
214 struct State {
215 id: Id,
216 cancelled: bool,
217 completed: bool,
218 waker_handler: Option<Waker>,
219 waker_spawn_fn: Option<Waker>,
220 }
221
222 impl State {
223 fn cancel(&mut self) {
224 if !self.cancelled {
225 self.cancelled = true;
226 self.wake();
227 }
228 }
229
230 fn complete(&mut self) {
231 self.completed = true;
232 self.wake();
233 }
234
235 fn is_complete(&self) -> bool {
236 self.completed || self.cancelled
237 }
238
239 fn wake(&mut self) {
240 if let Some(waker) = self.waker_handler.take() {
241 waker.wake();
242 }
243 if let Some(waker) = self.waker_spawn_fn.take() {
244 waker.wake();
245 }
246 }
247
248 fn register_handler(&mut self, cx: &mut Context<'_>) {
249 match self.waker_handler {
250 Some(ref mut waker) => waker.clone_from(cx.waker()),
252 None => self.waker_handler = Some(cx.waker().clone()),
253 }
254 }
255
256 fn register_spawn_fn(&mut self, cx: &mut Context<'_>) {
257 match self.waker_spawn_fn {
258 Some(ref mut waker) => waker.clone_from(cx.waker()),
260 None => self.waker_spawn_fn = Some(cx.waker().clone()),
261 }
262 }
263 }
264
265 impl<T> Debug for JoinHandle<T> {
266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267 if self.task.state.valid() {
268 let state = self.task.state.borrow();
269 f.debug_struct("JoinHandle")
270 .field("id", &state.id)
271 .field("cancelled", &state.cancelled)
272 .field("completed", &state.completed)
273 .finish()
274 } else {
275 f.debug_tuple("JoinHandle")
276 .field(&format_args!("<other thread>"))
277 .finish()
278 }
279 }
280 }
281
282 impl<T> JoinHandle<T> {
283 fn new() -> Self {
284 Self {
285 task: Task {
286 state: SendWrapper::new(Rc::new(RefCell::new(State {
287 cancelled: false,
288 completed: false,
289 waker_handler: None,
290 waker_spawn_fn: None,
291 id: Id(next_task_id()),
292 }))),
293 result: SendWrapper::new(Rc::new(RefCell::new(None))),
294 },
295 }
296 }
297
298 pub fn abort(&self) {
300 self.task.state.borrow_mut().cancel();
301 }
302
303 pub fn abort_handle(&self) -> AbortHandle {
311 AbortHandle {
312 state: self.task.state.clone(),
313 }
314 }
315
316 pub fn id(&self) -> Id {
321 let state = self.task.state.borrow();
322 state.id
323 }
324
325 pub fn is_finished(&self) -> bool {
327 let state = self.task.state.borrow();
328 state.is_complete()
329 }
330
331 fn is_running(&self) -> bool {
332 !self.is_finished()
333 }
334 }
335
336 #[derive(derive_more::Display, Debug, Clone, Copy)]
338 #[display("{cause}")]
339 pub struct JoinError {
340 cause: JoinErrorCause,
341 id: Id,
342 }
343
344 #[derive(derive_more::Display, Debug, Clone, Copy)]
345 enum JoinErrorCause {
346 #[display("task was cancelled")]
349 Cancelled,
350 }
351
352 impl std::error::Error for JoinError {}
353
354 impl JoinError {
355 pub fn is_cancelled(&self) -> bool {
361 matches!(self.cause, JoinErrorCause::Cancelled)
362 }
363
364 pub fn is_panic(&self) -> bool {
368 false
369 }
370
371 pub fn id(&self) -> Id {
373 self.id
374 }
375 }
376
377 impl<T> Future for JoinHandle<T> {
378 type Output = Result<T, JoinError>;
379
380 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
381 let mut state = self.task.state.borrow_mut();
382 if state.cancelled {
383 return Poll::Ready(Err(JoinError {
384 cause: JoinErrorCause::Cancelled,
385 id: state.id,
386 }));
387 }
388
389 let mut result = self.task.result.borrow_mut();
390 if let Some(result) = result.take() {
391 return Poll::Ready(Ok(result));
392 }
393
394 state.register_handler(cx);
395 Poll::Pending
396 }
397 }
398
399 struct JoinHandleWithId<T>(JoinHandle<T>);
400
401 impl<T> Future for JoinHandleWithId<T> {
402 type Output = Result<(Id, T), JoinError>;
403
404 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
405 match self.0.poll(cx) {
406 Poll::Ready(out) => Poll::Ready(out.map(|out| (self.0.id(), out))),
407 Poll::Pending => Poll::Pending,
408 }
409 }
410 }
411
412 #[pin_project::pin_project]
413 struct SpawnFuture<Fut: Future<Output = T>, T> {
414 handle: JoinHandle<T>,
415 #[pin]
416 fut: Fut,
417 }
418
419 impl<Fut: Future<Output = T>, T> Future for SpawnFuture<Fut, T> {
420 type Output = ();
421
422 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
423 let this = self.project();
424 let mut state = this.handle.task.state.borrow_mut();
425
426 if state.cancelled {
427 return Poll::Ready(());
428 }
429
430 match this.fut.poll(cx) {
431 Poll::Ready(value) => {
432 let _ = this.handle.task.result.borrow_mut().insert(value);
433 state.complete();
434 Poll::Ready(())
435 }
436 Poll::Pending => {
437 state.register_spawn_fn(cx);
438 Poll::Pending
439 }
440 }
441 }
442 }
443
444 #[derive(Clone)]
446 pub struct AbortHandle {
447 state: SendWrapper<Rc<RefCell<State>>>,
448 }
449
450 impl Debug for AbortHandle {
451 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452 if self.state.valid() {
453 let state = self.state.borrow();
454 f.debug_struct("AbortHandle")
455 .field("id", &state.id)
456 .field("cancelled", &state.cancelled)
457 .field("completed", &state.completed)
458 .finish()
459 } else {
460 f.debug_tuple("AbortHandle")
461 .field(&format_args!("<other thread>"))
462 .finish()
463 }
464 }
465 }
466
467 impl AbortHandle {
468 pub fn abort(&self) {
470 self.state.borrow_mut().cancel();
471 }
472
473 pub fn id(&self) -> Id {
478 self.state.borrow().id
479 }
480
481 pub fn is_finished(&self) -> bool {
483 let state = self.state.borrow();
484 state.cancelled && state.completed
485 }
486 }
487
488 #[pin_project::pin_project(PinnedDrop)]
491 #[derive(derive_more::Debug, derive_more::Deref)]
492 #[debug("AbortOnDropHandle")]
493 #[must_use = "Dropping the handle aborts the task immediately"]
494 pub struct AbortOnDropHandle<T>(#[pin] JoinHandle<T>);
495
496 #[pin_project::pinned_drop]
497 impl<T> PinnedDrop for AbortOnDropHandle<T> {
498 fn drop(self: Pin<&mut Self>) {
499 self.0.abort();
500 }
501 }
502
503 impl<T> Future for AbortOnDropHandle<T> {
504 type Output = <JoinHandle<T> as Future>::Output;
505
506 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
507 self.project().0.poll(cx)
508 }
509 }
510
511 impl<T> AbortOnDropHandle<T> {
512 pub fn new(task: JoinHandle<T>) -> Self {
514 Self(task)
515 }
516
517 pub fn abort_handle(&self) -> AbortHandle {
520 self.0.abort_handle()
521 }
522
523 pub fn abort(&self) {
526 self.0.abort()
527 }
528
529 pub fn is_finished(&self) -> bool {
532 self.0.is_finished()
533 }
534 }
535
536 pub fn spawn<T: 'static>(fut: impl IntoFuture<Output = T> + 'static) -> JoinHandle<T> {
540 let handle = JoinHandle::new();
541
542 wasm_bindgen_futures::spawn_local(SpawnFuture {
543 handle: JoinHandle {
544 task: handle.task.clone(),
545 },
546 fut: fut.into_future(),
547 });
548
549 handle
550 }
551}
552
553#[cfg(test)]
554mod test {
555 use std::time::Duration;
556
557 #[cfg(not(wasm_browser))]
558 use tokio::test;
559 #[cfg(wasm_browser)]
560 use wasm_bindgen_test::wasm_bindgen_test as test;
561
562 use crate::task;
563
#[test]
async fn task_abort() {
    // Both tasks just sleep briefly, so they cannot finish before the abort below.
    let nap = || async {
        crate::time::sleep(Duration::from_millis(10)).await;
    };
    let h1 = task::spawn(nap());
    let h2 = task::spawn(nap());
    assert!(h1.id() != h2.id());

    h1.abort();
    let err = h1.await.expect_err("aborted task must not complete");
    assert!(err.is_cancelled());
    assert!(h2.await.is_ok());
}
578
#[test]
async fn join_set_abort() {
    let fut = || async { 22 };
    let mut set = task::JoinSet::new();
    let h1 = set.spawn(fut());
    let h2 = set.spawn(fut());
    assert!(h1.id() != h2.id());
    // Abort the second task right away, before the test yields to the executor.
    h2.abort();

    let mut cancelled = 0;
    let mut finished = Vec::new();
    while let Some(ret) = set.join_next_with_id().await {
        match ret {
            Ok(pair) => finished.push(pair),
            Err(err) => {
                assert!(err.is_cancelled());
                cancelled += 1;
            }
        }
    }

    // Exactly one cancellation, and exactly one successful output from `h1`.
    assert_eq!(cancelled, 1);
    assert_eq!(finished, vec![(h1.id(), 22)]);
}
614}