noosphere_common/
task.rs

1use anyhow::Result;
2use std::future::Future;
3
4#[cfg(target_arch = "wasm32")]
5use std::pin::Pin;
6
7#[cfg(target_arch = "wasm32")]
8use tokio::sync::oneshot::channel;
9
10#[cfg(target_arch = "wasm32")]
11use futures::future::join_all;
12
13#[cfg(not(target_arch = "wasm32"))]
14use tokio::task::JoinSet;
15
#[cfg(target_arch = "wasm32")]
/// Spawn a future by scheduling it with the local executor
/// (`wasm_bindgen_futures::spawn_local`). The returned future will be pending
/// until the spawned future completes, and then resolves to its output.
///
/// # Errors
///
/// Returns an error if the oneshot sender is dropped without sending a value,
/// i.e. the spawned task went away before its future produced an output.
pub async fn spawn<F>(future: F) -> Result<F::Output>
where
    F: Future + 'static,
    F::Output: Send + 'static,
{
    let (tx, rx) = channel();

    wasm_bindgen_futures::spawn_local(async move {
        // `send` only fails when the receiver has been dropped, i.e. the caller
        // stopped awaiting this `spawn` future; the task itself still ran.
        if tx.send(future.await).is_err() {
            warn!("Receiver dropped before spawned task completed");
        }
    });

    Ok(rx.await?)
}
34
35#[cfg(not(target_arch = "wasm32"))]
36/// Spawn a future by scheduling it with the local executor. The returned
37/// future will be pending until the spawned future completes.
38pub async fn spawn<F>(future: F) -> Result<F::Output>
39where
40    F: Future + Send + 'static,
41    F::Output: Send + 'static,
42{
43    Ok(tokio::spawn(future).await?)
44}
45
46/// An aggregator of async work that can be used to observe the moment when all
47/// the aggregated work is completed. It is similar to tokio's [JoinSet], but is
48/// relatively constrained and also works on `wasm32-unknown-unknown`. Unlike
49/// [JoinSet], the results can not be observed individually.
50///
51/// ```rust
52/// # use anyhow::Result;
53/// # use noosphere_common::TaskQueue;
54/// #
55/// # #[tokio::main(flavor = "multi_thread")]
56/// # async fn main() -> Result<()> {
57/// #
58/// let mut task_queue = TaskQueue::default();
59/// for i in 0..10 {
60///     task_queue.spawn(async move {
61///         println!("{}", i);
62///         Ok(())
63///     });
64/// }
65/// task_queue.join().await?;
66/// #
67/// #   Ok(())
68/// # }
69/// ```
70#[derive(Default)]
71pub struct TaskQueue {
72    #[cfg(not(target_arch = "wasm32"))]
73    tasks: JoinSet<Result<()>>,
74
75    #[cfg(target_arch = "wasm32")]
76    tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>,
77}
78
79impl TaskQueue {
80    #[cfg(not(target_arch = "wasm32"))]
81    /// Queue a future to be spawned in the local executor. All queued futures will be polled
82    /// to completion before the [TaskQueue] can be joined.
83    pub fn spawn<F>(&mut self, future: F)
84    where
85        F: Future<Output = Result<()>> + Send + 'static,
86    {
87        self.tasks.spawn(future);
88    }
89
90    #[cfg(not(target_arch = "wasm32"))]
91    /// Returns a future that finishes when all queued futures have finished.
92    pub async fn join(&mut self) -> Result<()> {
93        while let Some(result) = self.tasks.join_next().await {
94            trace!("Task completed, {} remaining in queue...", self.tasks.len());
95            result??;
96        }
97        Ok(())
98    }
99
100    #[cfg(target_arch = "wasm32")]
101    /// Queue a future to be spawned in the local executor. All queued futures will be polled
102    /// to completion before the [TaskQueue] can be joined.
103    pub fn spawn<F>(&mut self, future: F)
104    where
105        F: Future<Output = Result<()>> + 'static,
106    {
107        let task_count = self.tasks.len();
108
109        self.tasks.push(Box::pin(async move {
110            if let Err(error) = spawn(future).await {
111                error!("Queued task failed: {:?}", error);
112            }
113            trace!("Task {} completed...", task_count + 1);
114        }));
115    }
116
117    #[cfg(target_arch = "wasm32")]
118    /// Returns a future that finishes when all queued futures have finished.
119    pub async fn join(&mut self) -> Result<()> {
120        let tasks = std::mem::replace(&mut self.tasks, Vec::new());
121
122        debug!("Joining {} queued tasks...", tasks.len());
123
124        join_all(tasks).await;
125
126        Ok(())
127    }
128}