1use anyhow::Result;
2use std::future::Future;
3
4#[cfg(target_arch = "wasm32")]
5use std::pin::Pin;
6
7#[cfg(target_arch = "wasm32")]
8use tokio::sync::oneshot::channel;
9
10#[cfg(target_arch = "wasm32")]
11use futures::future::join_all;
12
13#[cfg(not(target_arch = "wasm32"))]
14use tokio::task::JoinSet;
15
/// Spawns `future` with `wasm_bindgen_futures::spawn_local` and waits for its
/// output, which is sent back through a oneshot channel.
///
/// # Errors
///
/// Returns an error if the channel's sender is dropped without sending, i.e.
/// the spawned future never delivered its output.
#[cfg(target_arch = "wasm32")]
pub async fn spawn<F>(future: F) -> Result<F::Output>
where
    F: Future + 'static,
    F::Output: Send + 'static,
{
    use anyhow::Context as _;

    let (tx, rx) = channel();

    wasm_bindgen_futures::spawn_local(async move {
        // If the caller stopped awaiting (dropping `rx`), the output has
        // nowhere to go; it is discarded and the event is only logged.
        if tx.send(future.await).is_err() {
            warn!("Receiver dropped before spawned task completed");
        }
    });

    rx.await
        .context("Spawned task was dropped before producing its output")
}
34
35#[cfg(not(target_arch = "wasm32"))]
36pub async fn spawn<F>(future: F) -> Result<F::Output>
39where
40 F: Future + Send + 'static,
41 F::Output: Send + 'static,
42{
43 Ok(tokio::spawn(future).await?)
44}
45
46#[derive(Default)]
71pub struct TaskQueue {
72 #[cfg(not(target_arch = "wasm32"))]
73 tasks: JoinSet<Result<()>>,
74
75 #[cfg(target_arch = "wasm32")]
76 tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>,
77}
78
79impl TaskQueue {
80 #[cfg(not(target_arch = "wasm32"))]
81 pub fn spawn<F>(&mut self, future: F)
84 where
85 F: Future<Output = Result<()>> + Send + 'static,
86 {
87 self.tasks.spawn(future);
88 }
89
90 #[cfg(not(target_arch = "wasm32"))]
91 pub async fn join(&mut self) -> Result<()> {
93 while let Some(result) = self.tasks.join_next().await {
94 trace!("Task completed, {} remaining in queue...", self.tasks.len());
95 result??;
96 }
97 Ok(())
98 }
99
100 #[cfg(target_arch = "wasm32")]
101 pub fn spawn<F>(&mut self, future: F)
104 where
105 F: Future<Output = Result<()>> + 'static,
106 {
107 let task_count = self.tasks.len();
108
109 self.tasks.push(Box::pin(async move {
110 if let Err(error) = spawn(future).await {
111 error!("Queued task failed: {:?}", error);
112 }
113 trace!("Task {} completed...", task_count + 1);
114 }));
115 }
116
117 #[cfg(target_arch = "wasm32")]
118 pub async fn join(&mut self) -> Result<()> {
120 let tasks = std::mem::replace(&mut self.tasks, Vec::new());
121
122 debug!("Joining {} queued tasks...", tasks.len());
123
124 join_all(tasks).await;
125
126 Ok(())
127 }
128}