use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
/// Native (non-`wasm`) task primitives: everything is re-exported directly
/// from `tokio`, so callers get tokio's own `spawn`, `JoinHandle`,
/// `JoinError`, `AbortHandle`, `Handle` and `Runtime`.
#[cfg(not(target_family = "wasm"))]
mod sys {
pub use tokio::{
runtime::{Handle, Runtime},
task::{AbortHandle, JoinError, JoinHandle, spawn},
};
}
/// WebAssembly task primitives.
///
/// `tokio`'s task APIs are not available on `wasm`, so this module provides
/// look-alike [`JoinHandle`], [`JoinError`] and [`spawn`] items built on
/// `futures_util` and `wasm_bindgen_futures::spawn_local`, mirroring the
/// subset of the `tokio::task` API used by the native re-exports.
#[cfg(target_family = "wasm")]
mod sys {
    use std::{
        future::Future,
        pin::Pin,
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        task::{Context, Poll},
    };

    pub use futures_util::future::AbortHandle;
    use futures_util::{
        FutureExt,
        future::{Abortable, RemoteHandle},
    };

    /// Error returned when awaiting a [`JoinHandle`] does not yield the task's
    /// output.
    #[derive(Debug)]
    pub enum JoinError {
        /// The task was aborted before it ran to completion.
        Cancelled,
        /// The task's output could no longer be retrieved from the handle.
        Panic,
    }

    impl JoinError {
        /// Returns `true` if the task was aborted before completing.
        pub fn is_cancelled(&self) -> bool {
            matches!(self, JoinError::Cancelled)
        }

        /// Returns `true` if the error is [`JoinError::Panic`]; provided for
        /// parity with `tokio::task::JoinError::is_panic`.
        pub fn is_panic(&self) -> bool {
            matches!(self, JoinError::Panic)
        }
    }

    impl std::fmt::Display for JoinError {
        fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                JoinError::Cancelled => write!(fmt, "task was cancelled"),
                JoinError::Panic => write!(fmt, "task panicked"),
            }
        }
    }

    // `tokio::task::JoinError` implements `Error`; doing the same here keeps
    // `?` / `Box<dyn Error>` call sites compiling on both targets.
    impl std::error::Error for JoinError {}

    /// Handle to a task spawned with [`spawn`].
    ///
    /// Awaiting it yields the task's output, or [`JoinError::Cancelled`] if the
    /// task was aborted before completing. Dropping the handle does not stop
    /// the task: the underlying `RemoteHandle` is forgotten, so the task keeps
    /// running.
    #[derive(Debug)]
    pub struct JoinHandle<T> {
        /// Receives the spawned future's output; only taken in `Drop`.
        remote_handle: Option<RemoteHandle<T>>,
        /// Aborts the spawned future when triggered.
        abort_handle: AbortHandle,
        /// Set by the spawned task once its future has run to completion
        /// without being aborted.
        completed: Arc<AtomicBool>,
    }

    impl<T> JoinHandle<T> {
        /// Requests that the task be aborted.
        pub fn abort(&self) {
            self.abort_handle.abort();
        }

        /// Returns a clone of the handle that can abort this task.
        pub fn abort_handle(&self) -> AbortHandle {
            self.abort_handle.clone()
        }

        /// Returns `true` once the task has run to completion or has been
        /// aborted.
        pub fn is_finished(&self) -> bool {
            self.completed.load(Ordering::SeqCst) || self.abort_handle.is_aborted()
        }
    }

    impl<T> Drop for JoinHandle<T> {
        fn drop(&mut self) {
            // Detach instead of cancelling: dropping a `RemoteHandle` normally
            // cancels its future, `forget` lets it keep running.
            if let Some(h) = self.remote_handle.take() {
                h.forget();
            }
        }
    }

    impl<T: 'static> Future for JoinHandle<T> {
        type Output = Result<T, JoinError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // A task that completed before `abort()` was called has already sent
            // its output through the channel, so (like tokio) only report
            // cancellation when the abort actually prevented completion.
            if !self.completed.load(Ordering::SeqCst) && self.abort_handle.is_aborted() {
                return Poll::Ready(Err(JoinError::Cancelled));
            }

            match self.remote_handle.as_mut() {
                Some(handle) => Pin::new(handle).poll(cx).map(Ok),
                None => Poll::Ready(Err(JoinError::Panic)),
            }
        }
    }

    /// Spawns `future` with `wasm_bindgen_futures::spawn_local` and returns a
    /// [`JoinHandle`] that can be awaited for its output or used to abort it.
    pub fn spawn<F, T>(future: F) -> JoinHandle<T>
    where
        F: Future<Output = T> + 'static,
    {
        let (future, remote_handle) = future.remote_handle();
        let (abort_handle, abort_registration) = AbortHandle::new_pair();
        let future = Abortable::new(future, abort_registration);

        let completed = Arc::new(AtomicBool::new(false));
        let task_completed = Arc::clone(&completed);

        wasm_bindgen_futures::spawn_local(async move {
            // `Abortable` resolves to `Err(Aborted)` when aborted; only a
            // natural completion marks the task as completed.
            if future.await.is_ok() {
                task_completed.store(true, Ordering::SeqCst);
            }
        });

        JoinHandle { remote_handle: Some(remote_handle), abort_handle, completed }
    }
}
pub use sys::*;
/// Wrapper around a [`JoinHandle`] that aborts the task when it is dropped.
///
/// Wrapping a handle in `AbortOnDrop` ties the task's lifetime to this value.
/// The wrapper can still be awaited to obtain the task's output.
#[derive(Debug)]
#[must_use = "dropping an `AbortOnDrop` immediately aborts the wrapped task"]
pub struct AbortOnDrop<T>(JoinHandle<T>);

impl<T> AbortOnDrop<T> {
    /// Wraps `join_handle` so that its task is aborted once the returned
    /// value is dropped.
    pub fn new(join_handle: JoinHandle<T>) -> Self {
        Self(join_handle)
    }
}
impl<T> Drop for AbortOnDrop<T> {
fn drop(&mut self) {
self.0.abort();
}
}
impl<T: 'static> Future for AbortOnDrop<T> {
    type Output = Result<T, JoinError>;

    /// Forwards polling to the wrapped [`JoinHandle`].
    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        let join_handle = &mut self.get_mut().0;
        Pin::new(join_handle).poll(context)
    }
}
pub trait JoinHandleExt<T> {
fn abort_on_drop(self) -> AbortOnDrop<T>;
}
impl<T> JoinHandleExt<T> for JoinHandle<T> {
fn abort_on_drop(self) -> AbortOnDrop<T> {
AbortOnDrop::new(self)
}
}
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use matrix_sdk_test_macros::async_test;

    use super::spawn;

    /// A spawned future's output is returned when its handle is awaited.
    #[async_test]
    async fn test_spawn() {
        let handle = spawn(async { 42 });
        assert_matches!(handle.await, Ok(42));
    }

    /// Aborting a freshly spawned task makes awaiting its handle fail.
    #[async_test]
    async fn test_abort() {
        let handle = spawn(async { 42 });
        handle.abort();
        let outcome = handle.await;
        assert!(outcome.is_err());
    }
}