use futures::channel::oneshot;
use futures::future::{AbortHandle, Abortable, Aborted};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum JoinError {
#[error("The task is cancelled")]
Canceled,
#[error("The task is aborted")]
Aborted,
}
/// Handle to a task started by [`spawn_local`], [`spawn`] or [`spawn_blocking`].
///
/// Awaiting the handle yields the task's output (or a [`JoinError`]); dropping it
/// detaches the task without stopping it.
pub struct JoinHandle<T> {
    /// Used by [`JoinHandle::abort`] to stop the wrapped (`Abortable`) future.
    abort_handle: AbortHandle,
    /// One-shot channel on which the spawned task publishes its outcome: the
    /// future's output, or `Aborted` if it was aborted first.
    receiver: oneshot::Receiver<Result<T, Aborted>>,
}
impl<T> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    /// Resolves once the spawned task has delivered its outcome:
    ///
    /// - the task ran to completion: `Ok(value)`;
    /// - the task was aborted through [`JoinHandle::abort`]: `Err(JoinError::Aborted)`;
    /// - the result channel closed without a value: `Err(JoinError::Canceled)`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `oneshot::Receiver` is `Unpin`, so it can be re-pinned in place here.
        Pin::new(&mut self.receiver)
            .poll(cx)
            .map(|delivered| match delivered {
                Ok(Ok(value)) => Ok(value),
                Ok(Err(_aborted)) => Err(JoinError::Aborted),
                Err(_canceled) => Err(JoinError::Canceled),
            })
    }
}
impl<T> JoinHandle<T> {
    /// Requests that the spawned task be stopped.
    ///
    /// If the task has not finished yet, its wrapped future is abandoned the next
    /// time it is polled, and awaiting this handle yields `Err(JoinError::Aborted)`.
    /// Calling this after the task has already produced its output has no effect
    /// on the result.
    pub fn abort(&self) {
        let handle = &self.abort_handle;
        handle.abort();
    }
}
/// Spawns `future` on the current thread via `wasm_bindgen_futures::spawn_local`
/// and returns a [`JoinHandle`] that resolves to its output.
///
/// The future is wrapped in [`Abortable`] so that [`JoinHandle::abort`] can stop
/// it. Dropping the returned handle detaches the task: it keeps running and its
/// result is discarded.
pub fn spawn_local<F>(future: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
    F::Output: 'static,
{
    let (abort_handle, registration) = AbortHandle::new_pair();
    let (tx, rx) = oneshot::channel();

    wasm_bindgen_futures::spawn_local(async move {
        let outcome = Abortable::new(future, registration).await;
        // The receiver is gone if the caller dropped the JoinHandle (detached
        // the task); losing the outcome in that case is intentional.
        let _ = tx.send(outcome);
    });

    JoinHandle {
        receiver: rx,
        abort_handle,
    }
}
/// Spawns `future` as a background task and returns a handle to its output.
///
/// This simply forwards to [`spawn_local`], so the future does not need to be
/// `Send`; it runs on the current thread's `wasm_bindgen_futures` executor.
pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
    F::Output: 'static,
{
    spawn_local::<F>(future)
}
/// Runs `blocking_func` as a task and returns a handle to its return value.
///
/// No dedicated thread is used. The closure is called synchronously inside a
/// task spawned with [`spawn_local`], on the same executor as every other task.
/// A long-running closure therefore stalls those tasks until it returns.
pub fn spawn_blocking<F, R>(blocking_func: F) -> JoinHandle<R>
where
    F: FnOnce() -> R + 'static,
    R: 'static,
{
    // Defer the call into a future so it runs when the executor polls the task,
    // not synchronously inside this function.
    let deferred = async move { blocking_func() };
    spawn_local(deferred)
}