linera_core/
join_set_ext.rs1use futures::channel::oneshot;
12
#[cfg(web)]
mod implementation {
    pub use futures::future::AbortHandle;
    use futures::{future, FutureExt as _};

    use super::*;

    /// A set of tasks spawned with [`wasm_bindgen_futures::spawn_local`].
    ///
    /// Each entry is a receiver that resolves once the corresponding spawned future has run to
    /// completion or has been dropped without finishing (for example after being aborted).
    #[derive(Default)]
    pub struct JoinSet(Vec<oneshot::Receiver<()>>);

    /// Extension methods for spawning and tracking tasks in a [`JoinSet`].
    pub trait JoinSetExt: Sized {
        /// Spawns `future` as a local task tracked by this set.
        ///
        /// Returns a [`TaskHandle`] that resolves to the future's output and can abort the task.
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output>;

        /// Waits for every task in the set to finish, removing each one from the set.
        fn await_all_tasks(&mut self) -> impl Future<Output = ()>;

        /// Removes every task that has already finished, without waiting.
        fn reap_finished_tasks(&mut self);
    }

    impl JoinSetExt for JoinSet {
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output> {
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let (send_done, recv_done) = oneshot::channel();
            let (send_output, recv_output) = oneshot::channel();
            let future = async move {
                // Sending fails only if the `TaskHandle` is gone; the output is then discarded.
                send_output.send(future.await).ok();
                send_done.send(()).ok();
            };
            self.0.push(recv_done);
            // If the task is aborted, `send_done` is dropped together with the future, so
            // `recv_done` still resolves (with `Err(Canceled)`) instead of hanging forever.
            wasm_bindgen_futures::spawn_local(
                future::Abortable::new(future, abort_registration).map(drop),
            );

            TaskHandle {
                output_receiver: recv_output,
                abort_handle,
            }
        }

        async fn await_all_tasks(&mut self) {
            // Await the last receiver and only pop it afterwards: if this future is dropped
            // mid-way, tasks that have not finished yet remain tracked. Popping finished
            // receivers leaves the set empty once all tasks are done, so it does not keep
            // growing across calls.
            while let Some(task) = self.0.last_mut() {
                // `Err(Canceled)` only means the task was dropped before finishing (e.g. aborted).
                let _ = task.await;
                self.0.pop();
            }
        }

        fn reap_finished_tasks(&mut self) {
            // `try_recv` yields `Ok(None)` only while the task is still pending; receivers of
            // finished (`Ok(Some(()))`) or dropped (`Err(Canceled)`) tasks are removed.
            self.0.retain_mut(|task| task.try_recv() == Ok(None));
        }
    }
}
72
73#[cfg(not(web))]
74mod implementation {
75 pub use tokio::task::AbortHandle;
76
77 use super::*;
78
79 pub type JoinSet = tokio::task::JoinSet<()>;
80
81 #[trait_variant::make(Send)]
83 pub trait JoinSetExt: Sized {
84 fn spawn_task<F: Future<Output: Send> + Send + 'static>(
89 &mut self,
90 future: F,
91 ) -> TaskHandle<F::Output>;
92
93 async fn await_all_tasks(&mut self);
95
96 fn reap_finished_tasks(&mut self);
98 }
99
100 impl JoinSetExt for JoinSet {
101 fn spawn_task<F>(&mut self, future: F) -> TaskHandle<F::Output>
102 where
103 F: Future + Send + 'static,
104 F::Output: Send,
105 {
106 let (output_sender, output_receiver) = oneshot::channel();
107
108 let abort_handle = self.spawn(async move {
109 output_sender.send(future.await).ok();
111 });
112
113 TaskHandle {
114 output_receiver,
115 abort_handle,
116 }
117 }
118
119 async fn await_all_tasks(&mut self) {
120 while self.join_next().await.is_some() {}
121 }
122
123 fn reap_finished_tasks(&mut self) {
124 while self.try_join_next().is_some() {}
125 }
126 }
127}
128
129use std::{
130 future::Future,
131 pin::Pin,
132 task::{Context, Poll},
133};
134
135use futures::FutureExt as _;
136pub use implementation::*;
137
/// A handle to a task spawned with `JoinSetExt::spawn_task`.
///
/// Awaiting the handle yields the task's output, or `Err(oneshot::Canceled)` if the task
/// ended without sending one.
pub struct TaskHandle<Output> {
    /// Receives the task's output once the spawned future has completed.
    output_receiver: oneshot::Receiver<Output>,
    /// Used by `TaskHandle::abort` to cancel the spawned task.
    abort_handle: AbortHandle,
}
145
146impl<Output> Future for TaskHandle<Output> {
147 type Output = Result<Output, oneshot::Canceled>;
148
149 fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
150 self.as_mut().output_receiver.poll_unpin(context)
151 }
152}
153
154impl<Output> TaskHandle<Output> {
155 pub fn abort(&self) {
157 self.abort_handle.abort();
158 }
159
160 pub fn is_running(&mut self) -> bool {
162 self.output_receiver.try_recv().is_err()
163 }
164}