use futures::channel::oneshot;
#[cfg(web)]
mod implementation {
    pub use futures::future::AbortHandle;
    use futures::future;

    use super::*;

    /// Tracks tasks spawned on the local (browser) executor.
    ///
    /// Each entry is the receiving end of a channel that its task completes with `()` when it
    /// finishes, or whose sender is dropped if the task never gets that far.
    #[derive(Default)]
    pub struct JoinSet(Vec<oneshot::Receiver<()>>);

    /// Task-spawning operations shared by the web and native `JoinSet`s.
    pub trait JoinSetExt: Sized {
        /// Spawns `future` as a local task tracked by this set and returns a handle that
        /// resolves to the future's output.
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output>;

        /// Waits until every task tracked by this set has stopped, then forgets them all.
        fn await_all_tasks(&mut self) -> impl Future<Output = ()>;

        /// Stops tracking tasks that are no longer running, without waiting.
        fn reap_finished_tasks(&mut self);
    }

    impl JoinSetExt for JoinSet {
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output> {
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let (send_done, recv_done) = oneshot::channel();
            let (send_output, recv_output) = oneshot::channel();
            let future = async move {
                // A failed send only means the receiving side was dropped; nothing to do.
                send_output.send(future.await).ok();
                send_done.send(()).ok();
            };
            self.0.push(recv_done);
            // An aborted task never reaches `send_done.send`, so its `recv_done` reports
            // `Canceled` once the task's future is dropped.
            wasm_bindgen_futures::spawn_local(
                future::Abortable::new(future, abort_registration).map(drop),
            );
            TaskHandle {
                output_receiver: recv_output,
                abort_handle,
            }
        }

        async fn await_all_tasks(&mut self) {
            for done in &mut self.0 {
                // `Err(Canceled)` means the task stopped without signalling (aborted or
                // dropped); either way it is no longer running.
                done.await.ok();
            }
            // Forget the settled tasks so the set ends up empty, matching the native tokio
            // `JoinSet` once `join_next` returns `None`. Clearing only after every wait lets
            // a cancelled call leave the not-yet-awaited tasks tracked.
            self.0.clear();
        }

        fn reap_finished_tasks(&mut self) {
            // `Ok(None)`: still running, keep it. `Ok(Some(()))`: finished.
            // `Err(Canceled)`: stopped without signalling, or already observed.
            self.0.retain_mut(|done| matches!(done.try_recv(), Ok(None)));
        }
    }
}
#[cfg(not(web))]
mod implementation {
    pub use tokio::task::AbortHandle;

    use super::*;

    /// The native task set: a tokio `JoinSet` of unit tasks. Each task's real output is
    /// delivered through its `TaskHandle` rather than through the set.
    pub type JoinSet = tokio::task::JoinSet<()>;

    /// Task-spawning operations shared by the web and native `JoinSet`s.
    #[trait_variant::make(Send)]
    pub trait JoinSetExt: Sized {
        /// Spawns `future` into this set and returns a handle that resolves to the
        /// future's output.
        fn spawn_task<F: Future<Output: Send> + Send + 'static>(
            &mut self,
            future: F,
        ) -> TaskHandle<F::Output>;

        /// Joins every task in the set, waiting for each to finish.
        async fn await_all_tasks(&mut self);

        /// Joins the tasks that have already completed, without waiting.
        fn reap_finished_tasks(&mut self);
    }

    impl JoinSetExt for JoinSet {
        fn spawn_task<F: Future<Output: Send> + Send + 'static>(
            &mut self,
            future: F,
        ) -> TaskHandle<F::Output> {
            let (tx, rx) = oneshot::channel();
            let task = async move {
                let output = future.await;
                // A failed send only means the `TaskHandle` was dropped; the output is
                // discarded in that case.
                let _ = tx.send(output);
            };
            TaskHandle {
                abort_handle: self.spawn(task),
                output_receiver: rx,
            }
        }

        async fn await_all_tasks(&mut self) {
            // Each joined task's `Result` is discarded.
            while let Some(_joined) = self.join_next().await {}
        }

        fn reap_finished_tasks(&mut self) {
            // `try_join_next` yields `None` as soon as no completed task is left to join.
            std::iter::from_fn(|| self.try_join_next()).for_each(drop);
        }
    }
}
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use futures::FutureExt as _;
pub use implementation::*;
/// A handle to a task spawned through [`JoinSetExt::spawn_task`].
///
/// Awaiting the handle yields the task's output; [`TaskHandle::abort`] requests that the
/// task be cancelled.
pub struct TaskHandle<T> {
    /// Receiving end of the channel the spawned task sends its output on.
    output_receiver: oneshot::Receiver<T>,
    /// Handle used to abort the spawned task.
    abort_handle: AbortHandle,
}
impl<T> Future for TaskHandle<T> {
    /// `Err(Canceled)` when the task stopped without delivering an output, or the output
    /// was already taken out of the channel.
    type Output = Result<T, oneshot::Canceled>;

    /// Resolves with whatever the task's output channel yields.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `TaskHandle` is `Unpin`, so the pin can be lifted to reach the receiver directly.
        let receiver = &mut self.get_mut().output_receiver;
        Pin::new(receiver).poll(cx)
    }
}
impl<Output> TaskHandle<Output> {
    /// Requests that the spawned task be aborted.
    pub fn abort(&self) {
        self.abort_handle.abort();
    }

    /// Returns `true` while the task has neither delivered its output nor stopped.
    ///
    /// Returns `false` once the task has sent its output, or once it stopped (e.g. was
    /// aborted) without sending one.
    ///
    /// NOTE: the check uses the output channel's `try_recv`, which takes the output out of
    /// the channel if it has already arrived. That output is then discarded, and awaiting
    /// this handle afterwards yields `Err(Canceled)`. Await the handle instead when the
    /// output is needed.
    pub fn is_running(&mut self) -> bool {
        // `futures` oneshot `try_recv`:
        // - `Ok(None)`: nothing sent yet and the sender is alive, so the task is running.
        // - `Ok(Some(_))`: output delivered, so the task has finished.
        // - `Err(Canceled)`: the sender was dropped without sending, or the output was
        //   already taken.
        matches!(self.output_receiver.try_recv(), Ok(None))
    }
}