1#[cfg(not(wasm_browser))]
5pub use tokio::spawn;
6#[cfg(not(wasm_browser))]
7pub use tokio::task::{JoinError, JoinHandle, JoinSet};
8#[cfg(not(wasm_browser))]
9pub use tokio_util::task::AbortOnDropHandle;
10#[cfg(wasm_browser)]
11pub use wasm::*;
12
13#[cfg(wasm_browser)]
14mod wasm {
15 use std::{
16 cell::RefCell,
17 fmt::Debug,
18 future::{Future, IntoFuture},
19 pin::Pin,
20 rc::Rc,
21 task::{Context, Poll, Waker},
22 };
23
24 use futures_lite::stream::StreamExt;
25 use send_wrapper::SendWrapper;
26
    /// Wasm-browser replacement for `tokio::task::JoinSet`: a collection of tasks spawned
    /// with `wasm_bindgen_futures::spawn_local` whose results can be awaited as they finish.
    pub struct JoinSet<T> {
        // Handles polled by `join_next` to yield each task's outcome.
        handles: futures_buffered::FuturesUnordered<JoinHandle<T>>,
        // A second handle per task, kept only so the task can be aborted.
        // Entries for tasks that are no longer running are pruned in `join_next`.
        to_cancel: Vec<JoinHandle<T>>,
    }
36
37 impl<T> std::fmt::Debug for JoinSet<T> {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("JoinSet").field("len", &self.len()).finish()
40 }
41 }
42
43 impl<T> Default for JoinSet<T> {
44 fn default() -> Self {
45 Self::new()
46 }
47 }
48
49 impl<T> JoinSet<T> {
50 pub fn new() -> Self {
52 Self {
53 handles: futures_buffered::FuturesUnordered::new(),
54 to_cancel: Vec::new(),
55 }
56 }
57
58 pub fn spawn(&mut self, fut: impl IntoFuture<Output = T> + 'static)
62 where
63 T: 'static,
64 {
65 let handle = JoinHandle::new();
66 let handle_for_spawn = JoinHandle {
67 task: handle.task.clone(),
68 };
69 let handle_for_cancel = JoinHandle {
70 task: handle.task.clone(),
71 };
72
73 wasm_bindgen_futures::spawn_local(SpawnFuture {
74 handle: handle_for_spawn,
75 fut: fut.into_future(),
76 });
77
78 self.handles.push(handle);
79 self.to_cancel.push(handle_for_cancel);
80 }
81
82 pub fn abort_all(&self) {
84 self.to_cancel.iter().for_each(JoinHandle::abort);
85 }
86
87 pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
100 futures_lite::future::poll_fn(|cx| {
101 let ret = self.handles.poll_next(cx);
102 self.to_cancel.retain(JoinHandle::is_running);
104 ret
105 })
106 .await
107 }
108
109 pub fn is_empty(&self) -> bool {
112 self.handles.is_empty()
113 }
114
115 pub fn len(&self) -> usize {
118 self.handles.len()
119 }
120
121 pub async fn join_all(mut self) -> Vec<T> {
124 let mut output = Vec::new();
125 while let Some(res) = self.join_next().await {
126 match res {
127 Ok(t) => output.push(t),
128 Err(err) => panic!("{err}"),
129 }
130 }
131 output
132 }
133
134 pub async fn shutdown(&mut self) {
136 self.abort_all();
137 while let Some(_res) = self.join_next().await {}
138 }
139 }
140
141 impl<T> Drop for JoinSet<T> {
142 fn drop(&mut self) {
143 self.abort_all()
144 }
145 }
146
    /// Handle to a task spawned on the browser's local executor. Awaiting it yields the
    /// task's output or `JoinError::Cancelled`. This type has no `Drop` impl, so dropping
    /// it does not abort the task; wrap it in `AbortOnDropHandle` for that.
    pub struct JoinHandle<T> {
        // Task state shared with the spawned driver and any other handles for the same task.
        // `SendWrapper` lets the `Rc` live in a `Send` type; presumably (per the send_wrapper
        // crate) dereferencing it off the creating thread panics, hence the `valid()` check
        // in the `Debug` impl.
        task: SendWrapper<Rc<RefCell<Task<T>>>>,
    }
158
    /// State shared (through `Rc<RefCell<_>>`) between the spawned driver future and
    /// every `JoinHandle` for the same task.
    struct Task<T> {
        // Set by `cancel`; the driver then stops polling and handles yield `JoinError::Cancelled`.
        cancelled: bool,
        // Set by `complete` when the user future returns.
        completed: bool,
        // Waker of the last context that polled a `JoinHandle` for this task.
        waker_handler: Option<Waker>,
        // Waker of the last context that polled the `SpawnFuture` driver.
        waker_spawn_fn: Option<Waker>,
        // Output of the user future; taken by the first `JoinHandle` poll that sees it.
        result: Option<T>,
    }
166
167 impl<T> Task<T> {
168 fn cancel(&mut self) {
169 if !self.cancelled {
170 self.cancelled = true;
171 self.wake();
172 }
173 }
174
175 fn complete(&mut self, value: T) {
176 self.result = Some(value);
177 self.completed = true;
178 self.wake();
179 }
180
181 fn wake(&mut self) {
182 if let Some(waker) = self.waker_handler.take() {
183 waker.wake();
184 }
185 if let Some(waker) = self.waker_spawn_fn.take() {
186 waker.wake();
187 }
188 }
189
190 fn register_handler(&mut self, cx: &mut Context<'_>) {
191 match self.waker_handler {
192 Some(ref mut waker) => waker.clone_from(cx.waker()),
194 None => self.waker_handler = Some(cx.waker().clone()),
195 }
196 }
197
198 fn register_spawn_fn(&mut self, cx: &mut Context<'_>) {
199 match self.waker_spawn_fn {
200 Some(ref mut waker) => waker.clone_from(cx.waker()),
202 None => self.waker_spawn_fn = Some(cx.waker().clone()),
203 }
204 }
205 }
206
207 impl<T> Debug for JoinHandle<T> {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 if self.task.valid() {
210 let task = self.task.borrow();
211 let cancelled = task.cancelled;
212 let completed = task.completed;
213 f.debug_struct("JoinHandle")
214 .field("cancelled", &cancelled)
215 .field("completed", &completed)
216 .finish()
217 } else {
218 f.debug_tuple("JoinHandle")
219 .field(&format_args!("<other thread>"))
220 .finish()
221 }
222 }
223 }
224
225 impl<T> JoinHandle<T> {
226 fn new() -> Self {
227 Self {
228 task: SendWrapper::new(Rc::new(RefCell::new(Task {
229 cancelled: false,
230 completed: false,
231 waker_handler: None,
232 waker_spawn_fn: None,
233 result: None,
234 }))),
235 }
236 }
237
238 pub fn abort(&self) {
240 self.task.borrow_mut().cancel();
241 }
242
243 fn is_running(&self) -> bool {
244 let task = self.task.borrow();
245 !task.cancelled && !task.completed
246 }
247 }
248
    /// Error produced when awaiting a `JoinHandle` whose task did not deliver a value.
    #[derive(derive_more::Display, Debug, Clone, Copy)]
    pub enum JoinError {
        /// The task was cancelled via its shared state (e.g. `JoinHandle::abort`,
        /// `JoinSet::abort_all`, or dropping an `AbortOnDropHandle` / `JoinSet`).
        #[display("task was cancelled")]
        Cancelled,
    }
257
    // `JoinError` is a standard error type; its message comes from the derived `Display`.
    impl std::error::Error for JoinError {}
259
260 impl JoinError {
261 pub fn is_cancelled(&self) -> bool {
267 matches!(self, Self::Cancelled)
268 }
269
270 pub fn is_panic(&self) -> bool {
274 false
275 }
276 }
277
278 impl<T> Future for JoinHandle<T> {
279 type Output = Result<T, JoinError>;
280
281 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
282 let mut task = self.task.borrow_mut();
283 if task.cancelled {
284 return Poll::Ready(Err(JoinError::Cancelled));
285 }
286
287 if let Some(result) = task.result.take() {
288 return Poll::Ready(Ok(result));
289 }
290
291 task.register_handler(cx);
292 Poll::Pending
293 }
294 }
295
    /// Driver future handed to `wasm_bindgen_futures::spawn_local`: polls the user future
    /// `fut` and publishes its output (or stops on cancellation) through `handle`'s shared state.
    #[pin_project::pin_project]
    struct SpawnFuture<Fut: Future<Output = T>, T> {
        // Handle sharing task state with the caller-facing `JoinHandle`(s).
        handle: JoinHandle<T>,
        // The user's future; structurally pinned so it can be polled in place.
        #[pin]
        fut: Fut,
    }
302
303 impl<Fut: Future<Output = T>, T> Future for SpawnFuture<Fut, T> {
304 type Output = ();
305
306 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
307 let this = self.project();
308 let mut task = this.handle.task.borrow_mut();
309
310 if task.cancelled {
311 return Poll::Ready(());
312 }
313
314 match this.fut.poll(cx) {
315 Poll::Ready(value) => {
316 task.complete(value);
317 Poll::Ready(())
318 }
319 Poll::Pending => {
320 task.register_spawn_fn(cx);
321 Poll::Pending
322 }
323 }
324 }
325 }
326
    /// Wrapper around a `JoinHandle` that aborts the task when dropped; the wasm-browser
    /// replacement for `tokio_util::task::AbortOnDropHandle`. Its `Debug` output is the
    /// fixed text `AbortOnDropHandle`.
    #[pin_project::pin_project(PinnedDrop)]
    #[derive(derive_more::Debug)]
    #[debug("AbortOnDropHandle")]
    pub struct AbortOnDropHandle<T>(#[pin] JoinHandle<T>);
333
334 #[pin_project::pinned_drop]
335 impl<T> PinnedDrop for AbortOnDropHandle<T> {
336 fn drop(self: Pin<&mut Self>) {
337 self.0.abort();
338 }
339 }
340
341 impl<T> Future for AbortOnDropHandle<T> {
342 type Output = <JoinHandle<T> as Future>::Output;
343
344 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
345 self.project().0.poll(cx)
346 }
347 }
348
349 impl<T> AbortOnDropHandle<T> {
350 pub fn new(task: JoinHandle<T>) -> Self {
352 Self(task)
353 }
354 }
355
356 pub fn spawn<T: 'static>(fut: impl IntoFuture<Output = T> + 'static) -> JoinHandle<T> {
360 let handle = JoinHandle::new();
361
362 wasm_bindgen_futures::spawn_local(SpawnFuture {
363 handle: JoinHandle {
364 task: handle.task.clone(),
365 },
366 fut: fut.into_future(),
367 });
368
369 handle
370 }
371}
372
#[cfg(test)]
mod test {
    // NOTE(review): this test module is currently empty; no tests are defined here.
}