1#[cfg(not(wasm_browser))]
5pub use tokio::spawn;
6#[cfg(not(wasm_browser))]
7pub use tokio::task::{JoinError, JoinHandle, JoinSet};
8#[cfg(not(wasm_browser))]
9pub use tokio_util::task::AbortOnDropHandle;
10
11#[cfg(wasm_browser)]
12pub use wasm::*;
13
14#[cfg(wasm_browser)]
15mod wasm {
16 use std::{
17 cell::RefCell,
18 fmt::Debug,
19 future::{Future, IntoFuture},
20 pin::Pin,
21 rc::Rc,
22 task::{Context, Poll, Waker},
23 };
24
25 use futures_lite::stream::StreamExt;
26 use send_wrapper::SendWrapper;
27
    /// A set of tasks spawned with `wasm_bindgen_futures::spawn_local`.
    ///
    /// This is the `wasm_browser` counterpart of the `tokio::task::JoinSet` that is
    /// re-exported on every other target.
    pub struct JoinSet<T> {
        /// Handles that are awaited by `join_next` to obtain task results.
        handles: futures_buffered::FuturesUnordered<JoinHandle<T>>,
        /// A second handle per task, used by `abort_all`; entries whose task is no
        /// longer running are pruned whenever `join_next` is polled.
        to_cancel: Vec<JoinHandle<T>>,
    }
37
38 impl<T> std::fmt::Debug for JoinSet<T> {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("JoinSet").field("len", &self.len()).finish()
41 }
42 }
43
44 impl<T> Default for JoinSet<T> {
45 fn default() -> Self {
46 Self::new()
47 }
48 }
49
50 impl<T> JoinSet<T> {
51 pub fn new() -> Self {
53 Self {
54 handles: futures_buffered::FuturesUnordered::new(),
55 to_cancel: Vec::new(),
56 }
57 }
58
59 pub fn spawn(&mut self, fut: impl IntoFuture<Output = T> + 'static)
63 where
64 T: 'static,
65 {
66 let handle = JoinHandle::new();
67 let handle_for_spawn = JoinHandle {
68 task: handle.task.clone(),
69 };
70 let handle_for_cancel = JoinHandle {
71 task: handle.task.clone(),
72 };
73
74 wasm_bindgen_futures::spawn_local(SpawnFuture {
75 handle: handle_for_spawn,
76 fut: fut.into_future(),
77 });
78
79 self.handles.push(handle);
80 self.to_cancel.push(handle_for_cancel);
81 }
82
83 pub fn abort_all(&self) {
85 self.to_cancel.iter().for_each(JoinHandle::abort);
86 }
87
88 pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
101 futures_lite::future::poll_fn(|cx| {
102 let ret = self.handles.poll_next(cx);
103 self.to_cancel.retain(JoinHandle::is_running);
105 ret
106 })
107 .await
108 }
109
110 pub fn is_empty(&self) -> bool {
113 self.handles.is_empty()
114 }
115
116 pub fn len(&self) -> usize {
119 self.handles.len()
120 }
121
122 pub async fn join_all(mut self) -> Vec<T> {
125 let mut output = Vec::new();
126 while let Some(res) = self.join_next().await {
127 match res {
128 Ok(t) => output.push(t),
129 Err(err) => panic!("{err}"),
130 }
131 }
132 output
133 }
134
135 pub async fn shutdown(&mut self) {
137 self.abort_all();
138 while let Some(_res) = self.join_next().await {}
139 }
140 }
141
142 impl<T> Drop for JoinSet<T> {
143 fn drop(&mut self) {
144 self.abort_all()
145 }
146 }
147
    /// Handle to a task started by `spawn` or `JoinSet::spawn`.
    ///
    /// Awaiting the handle yields `Ok(output)` once the task has stored its result,
    /// or `Err(JoinError::Cancelled)` if the task's cancelled flag is set.
    pub struct JoinHandle<T> {
        /// Task state shared with the spawned future and any other handle to the
        /// same task. The `Debug` impl checks `SendWrapper::valid` before touching it.
        task: SendWrapper<Rc<RefCell<Task<T>>>>,
    }
159
160 struct Task<T> {
161 cancelled: bool,
162 completed: bool,
163 waker_handler: Option<Waker>,
164 waker_spawn_fn: Option<Waker>,
165 result: Option<T>,
166 }
167
168 impl<T> Task<T> {
169 fn cancel(&mut self) {
170 if !self.cancelled {
171 self.cancelled = true;
172 self.wake();
173 }
174 }
175
176 fn complete(&mut self, value: T) {
177 self.result = Some(value);
178 self.completed = true;
179 self.wake();
180 }
181
182 fn wake(&mut self) {
183 if let Some(waker) = self.waker_handler.take() {
184 waker.wake();
185 }
186 if let Some(waker) = self.waker_spawn_fn.take() {
187 waker.wake();
188 }
189 }
190
191 fn register_handler(&mut self, cx: &mut Context<'_>) {
192 match self.waker_handler {
193 Some(ref mut waker) => waker.clone_from(cx.waker()),
195 None => self.waker_handler = Some(cx.waker().clone()),
196 }
197 }
198
199 fn register_spawn_fn(&mut self, cx: &mut Context<'_>) {
200 match self.waker_spawn_fn {
201 Some(ref mut waker) => waker.clone_from(cx.waker()),
203 None => self.waker_spawn_fn = Some(cx.waker().clone()),
204 }
205 }
206 }
207
208 impl<T> Debug for JoinHandle<T> {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 if self.task.valid() {
211 let task = self.task.borrow();
212 let cancelled = task.cancelled;
213 let completed = task.completed;
214 f.debug_struct("JoinHandle")
215 .field("cancelled", &cancelled)
216 .field("completed", &completed)
217 .finish()
218 } else {
219 f.debug_tuple("JoinHandle")
220 .field(&format_args!("<other thread>"))
221 .finish()
222 }
223 }
224 }
225
226 impl<T> JoinHandle<T> {
227 fn new() -> Self {
228 Self {
229 task: SendWrapper::new(Rc::new(RefCell::new(Task {
230 cancelled: false,
231 completed: false,
232 waker_handler: None,
233 waker_spawn_fn: None,
234 result: None,
235 }))),
236 }
237 }
238
239 pub fn abort(&self) {
241 self.task.borrow_mut().cancel();
242 }
243
244 fn is_running(&self) -> bool {
245 let task = self.task.borrow();
246 !task.cancelled && !task.completed
247 }
248 }
249
    /// Error returned when awaiting a `JoinHandle` does not yield the task's output.
    #[derive(derive_more::Display, Debug, Clone, Copy)]
    pub enum JoinError {
        /// The task's cancelled flag was set when the handle was polled.
        #[display("task was cancelled")]
        Cancelled,
    }
258
    // `JoinError` carries no source error, so the default `Error` methods suffice.
    impl std::error::Error for JoinError {}
260
261 impl JoinError {
262 pub fn is_cancelled(&self) -> bool {
268 matches!(self, Self::Cancelled)
269 }
270
271 pub fn is_panic(&self) -> bool {
275 false
276 }
277 }
278
279 impl<T> Future for JoinHandle<T> {
280 type Output = Result<T, JoinError>;
281
282 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
283 let mut task = self.task.borrow_mut();
284 if task.cancelled {
285 return Poll::Ready(Err(JoinError::Cancelled));
286 }
287
288 if let Some(result) = task.result.take() {
289 return Poll::Ready(Ok(result));
290 }
291
292 task.register_handler(cx);
293 Poll::Pending
294 }
295 }
296
    /// Future handed to `wasm_bindgen_futures::spawn_local`: it drives `fut` and
    /// reports the outcome through the task state shared via `handle`.
    #[pin_project::pin_project]
    struct SpawnFuture<Fut: Future<Output = T>, T> {
        /// Handle sharing its task state with the `JoinHandle` given to the caller.
        handle: JoinHandle<T>,
        /// The wrapped future; marked `#[pin]` so it can be polled through the projection.
        #[pin]
        fut: Fut,
    }
303
304 impl<Fut: Future<Output = T>, T> Future for SpawnFuture<Fut, T> {
305 type Output = ();
306
307 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
308 let this = self.project();
309 let mut task = this.handle.task.borrow_mut();
310
311 if task.cancelled {
312 return Poll::Ready(());
313 }
314
315 match this.fut.poll(cx) {
316 Poll::Ready(value) => {
317 task.complete(value);
318 Poll::Ready(())
319 }
320 Poll::Pending => {
321 task.register_spawn_fn(cx);
322 Poll::Pending
323 }
324 }
325 }
326 }
327
    /// Wrapper around a `JoinHandle` that aborts the task when the wrapper is dropped.
    ///
    /// Awaiting it behaves like awaiting the wrapped handle.
    #[pin_project::pin_project(PinnedDrop)]
    #[derive(derive_more::Debug)]
    #[debug("AbortOnDropHandle")]
    pub struct AbortOnDropHandle<T>(#[pin] JoinHandle<T>);
334
335 #[pin_project::pinned_drop]
336 impl<T> PinnedDrop for AbortOnDropHandle<T> {
337 fn drop(self: Pin<&mut Self>) {
338 self.0.abort();
339 }
340 }
341
342 impl<T> Future for AbortOnDropHandle<T> {
343 type Output = <JoinHandle<T> as Future>::Output;
344
345 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
346 self.project().0.poll(cx)
347 }
348 }
349
350 impl<T> AbortOnDropHandle<T> {
351 pub fn new(task: JoinHandle<T>) -> Self {
353 Self(task)
354 }
355 }
356
357 pub fn spawn<T: 'static>(fut: impl IntoFuture<Output = T> + 'static) -> JoinHandle<T> {
361 let handle = JoinHandle::new();
362
363 wasm_bindgen_futures::spawn_local(SpawnFuture {
364 handle: JoinHandle {
365 task: handle.task.clone(),
366 },
367 fut: fut.into_future(),
368 });
369
370 handle
371 }
372}
373
// NOTE(review): this test module is empty in the visible source; the wasm
// implementation above has no tests here.
#[cfg(test)]
mod test {
}