Skip to main content

linera_core/
join_set_ext.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
//! An extension trait to determine at compile time how tasks are spawned.
//!
//! Natively, tasks are spawned on the Tokio runtime, and in most cases the [`Future`] task to
//! be spawned should implement [`Send`]. That's not possible when compiling for the Web, so
//! there the task is spawned on the browser event loop instead.
10
11use futures::channel::oneshot;
12
#[cfg(web)]
mod implementation {
    pub use futures::future::AbortHandle;
    use futures::future;

    use super::*;

    /// A set of tasks spawned on the browser event loop.
    ///
    /// Each entry is a receiver that completes when its task finishes, or that resolves to
    /// `Err(Canceled)` when the task is aborted (its sender is dropped without sending).
    #[derive(Default)]
    pub struct JoinSet(Vec<oneshot::Receiver<()>>);

    /// An extension trait for the [`JoinSet`] type.
    pub trait JoinSetExt: Sized {
        /// Spawns a `future` task on this [`JoinSet`] using
        /// [`wasm_bindgen_futures::spawn_local`].
        ///
        /// Returns a [`TaskHandle`] that resolves to the `future`'s output and that can be
        /// used to abort the task.
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output>;

        /// Awaits all tasks spawned in this [`JoinSet`], leaving it empty.
        ///
        /// Aborted tasks are counted as finished.
        fn await_all_tasks(&mut self) -> impl Future<Output = ()>;

        /// Reaps tasks that have finished or been aborted, without waiting for running ones.
        fn reap_finished_tasks(&mut self);
    }

    impl JoinSetExt for JoinSet {
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output> {
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let (send_done, recv_done) = oneshot::channel();
            let (send_output, recv_output) = oneshot::channel();
            let future = async move {
                // The receiver is gone if the `TaskHandle` was dropped (detaching the task);
                // the output is then discarded.
                send_output.send(future.await).ok();
                // Tells the `JoinSet` this task is done. If the task is aborted instead,
                // `Abortable` drops this future, and with it `send_done`, which also
                // completes `recv_done` (with `Err(Canceled)`).
                send_done.send(()).ok();
            };
            self.0.push(recv_done);
            wasm_bindgen_futures::spawn_local(
                future::Abortable::new(future, abort_registration).map(drop),
            );

            TaskHandle {
                output_receiver: recv_output,
                abort_handle,
            }
        }

        async fn await_all_tasks(&mut self) {
            // Each task is removed only after it has completed, so if this future is
            // dropped midway the not-yet-awaited tasks remain tracked by the set.
            while let Some(task) = self.0.last_mut() {
                // `Err(Canceled)` means the task was aborted, which also counts as done.
                task.await.ok();
                self.0.pop();
            }
        }

        fn reap_finished_tasks(&mut self) {
            // `Ok(None)`: the task has not signalled completion yet, so keep it.
            // `Ok(Some(()))` (finished) and `Err(Canceled)` (aborted) are both removed.
            self.0
                .retain_mut(|task| matches!(task.try_recv(), Ok(None)));
        }
    }
}
72
73#[cfg(not(web))]
74mod implementation {
75    pub use tokio::task::AbortHandle;
76
77    use super::*;
78
79    pub type JoinSet = tokio::task::JoinSet<()>;
80
81    /// An extension trait for the [`JoinSet`] type.
82    #[trait_variant::make(Send)]
83    pub trait JoinSetExt: Sized {
84        /// Spawns a `future` task on this [`JoinSet`] using [`JoinSet::spawn`].
85        ///
86        /// Returns a [`oneshot::Receiver`] to receive the `future`'s output, and an
87        /// [`AbortHandle`] to cancel execution of the task.
88        fn spawn_task<F: Future<Output: Send> + Send + 'static>(
89            &mut self,
90            future: F,
91        ) -> TaskHandle<F::Output>;
92
93        /// Awaits all tasks spawned in this [`JoinSet`].
94        async fn await_all_tasks(&mut self);
95
96        /// Reaps tasks that have finished.
97        fn reap_finished_tasks(&mut self);
98    }
99
100    impl JoinSetExt for JoinSet {
101        fn spawn_task<F>(&mut self, future: F) -> TaskHandle<F::Output>
102        where
103            F: Future + Send + 'static,
104            F::Output: Send,
105        {
106            let (output_sender, output_receiver) = oneshot::channel();
107
108            let abort_handle = self.spawn(async move {
109                // Receiver may have been dropped if the task was aborted.
110                output_sender.send(future.await).ok();
111            });
112
113            TaskHandle {
114                output_receiver,
115                abort_handle,
116            }
117        }
118
119        async fn await_all_tasks(&mut self) {
120            while self.join_next().await.is_some() {}
121        }
122
123        fn reap_finished_tasks(&mut self) {
124            while self.try_join_next().is_some() {}
125        }
126    }
127}
128
129use std::{
130    future::Future,
131    pin::Pin,
132    task::{Context, Poll},
133};
134
135use futures::FutureExt as _;
136pub use implementation::*;
137
138/// A handle to a task spawned with [`JoinSetExt`].
139///
140/// Dropping a handle detaches its respective task.
141pub struct TaskHandle<Output> {
142    output_receiver: oneshot::Receiver<Output>,
143    abort_handle: AbortHandle,
144}
145
146impl<Output> Future for TaskHandle<Output> {
147    type Output = Result<Output, oneshot::Canceled>;
148
149    fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
150        self.as_mut().output_receiver.poll_unpin(context)
151    }
152}
153
154impl<Output> TaskHandle<Output> {
155    /// Aborts the task.
156    pub fn abort(&self) {
157        self.abort_handle.abort();
158    }
159
160    /// Returns [`true`] if the task is still running.
161    pub fn is_running(&mut self) -> bool {
162        self.output_receiver.try_recv().is_err()
163    }
164}