use core::fmt;
use futures_util::stream::{AbortHandle, Abortable};
use futures_util::Future;
#[cfg(not(target_arch = "wasm32"))]
use tokio::task::JoinHandle as TokioJoinHandle;
#[cfg(target_arch = "wasm32")]
mod wasm;
#[cfg(not(target_arch = "wasm32"))]
use crate::runtime;
/// Error returned when a spawned task cannot be joined.
///
/// Fieldless, so it is cheap to copy and can be compared directly by callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The task could not be joined. The backend-specific error that caused
    /// this is not preserved.
    JoinError,
}

impl std::error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Constant message: `write_str` avoids the formatting machinery.
            Self::JoinError => f.write_str("impossible to join thread"),
        }
    }
}
pub enum JoinHandle<T> {
#[cfg(not(target_arch = "wasm32"))]
Tokio(TokioJoinHandle<T>),
#[cfg(target_arch = "wasm32")]
Wasm(self::wasm::JoinHandle<T>),
}
impl<T> JoinHandle<T> {
pub async fn join(self) -> Result<T, Error> {
match self {
#[cfg(not(target_arch = "wasm32"))]
Self::Tokio(handle) => handle.await.map_err(|_| Error::JoinError),
#[cfg(target_arch = "wasm32")]
Self::Wasm(handle) => handle.join().await.map_err(|_| Error::JoinError),
}
}
}
/// Spawns `future` as a task on the runtime obtained from `runtime::handle()`.
///
/// The crate's tests call this both inside and outside an ambient Tokio
/// context. NOTE(review): presumably `runtime::handle()` supplies a fallback
/// runtime when none is active; confirm against the `runtime` module.
///
/// The returned [`JoinHandle`] may be dropped to detach the task.
#[inline]
#[cfg(not(target_arch = "wasm32"))]
pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
where
    T: Future + Send + 'static,
    T::Output: Send + 'static,
{
    let rt = runtime::handle();
    let task = rt.spawn(future);
    JoinHandle::Tokio(task)
}
/// Spawns `future` through the local `wasm` module and wraps its handle.
///
/// Unlike the native version, this one does not require the future or its
/// output to be `Send`.
#[cfg(target_arch = "wasm32")]
pub fn spawn<T>(future: T) -> JoinHandle<T::Output>
where
    T: Future + 'static,
{
    JoinHandle::Wasm(self::wasm::spawn(future))
}
/// Spawns `future` wrapped in [`Abortable`] and returns the [`AbortHandle`]
/// that cancels it.
///
/// The task's [`JoinHandle`] is dropped right away, so the task runs detached
/// and its result (output or `Aborted`) is never observed.
#[cfg(not(target_arch = "wasm32"))]
pub fn abortable<T>(future: T) -> AbortHandle
where
    T: Future + Send + 'static,
    T::Output: Send + 'static,
{
    let (handle, registration) = AbortHandle::new_pair();
    let guarded = Abortable::new(future, registration);
    // Detach: nobody awaits the result; cancellation goes through `handle`.
    drop(spawn(guarded));
    handle
}
/// Spawns `future` wrapped in [`Abortable`] and returns the [`AbortHandle`]
/// that cancels it (`wasm32` version, no `Send` bound).
///
/// The task's [`JoinHandle`] is dropped right away, so the task runs detached
/// and its result (output or `Aborted`) is never observed.
#[cfg(target_arch = "wasm32")]
pub fn abortable<T>(future: T) -> AbortHandle
where
    T: Future + 'static,
{
    let (handle, registration) = AbortHandle::new_pair();
    let guarded = Abortable::new(future, registration);
    // Detach: nobody awaits the result; cancellation goes through `handle`.
    drop(spawn(guarded));
    handle
}
/// Runs the blocking closure `f` via `spawn_blocking` on the runtime obtained
/// from `runtime::handle()`.
///
/// Unlike [`spawn`], this returns Tokio's own `JoinHandle` (aliased here as
/// `TokioJoinHandle`), so callers `.await` it directly rather than calling
/// [`JoinHandle::join`].
#[inline]
#[cfg(not(target_arch = "wasm32"))]
pub fn spawn_blocking<F, R>(f: F) -> TokioJoinHandle<R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    let rt = runtime::handle();
    rt.spawn_blocking(f)
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::time;

    /// Delay used by the test futures before they complete.
    ///
    /// Long enough that a freshly spawned task is still pending, short enough
    /// to keep the suite fast.
    #[cfg(not(target_arch = "wasm32"))]
    const TASK_DELAY: Duration = Duration::from_millis(100);

    /// Waits [`TASK_DELAY`] and then yields `42`; shared body of the spawn tests.
    #[cfg(not(target_arch = "wasm32"))]
    async fn delayed_answer() -> i32 {
        time::sleep(TASK_DELAY).await;
        42
    }

    #[tokio::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_is_tokio_context_macros() {
        assert!(runtime::is_tokio_context());
    }

    #[async_std::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_is_tokio_context_in_async_std() {
        let handle = runtime::handle();
        let _guard = handle.enter();
        assert!(runtime::is_tokio_context());
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_is_tokio_context_once_lock() {
        let handle = runtime::handle();
        let _guard = handle.enter();
        assert!(runtime::is_tokio_context());
    }

    #[tokio::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_spawn() {
        let handle = spawn(delayed_answer());
        let result = handle.join().await.unwrap();
        assert_eq!(result, 42);
    }

    #[async_std::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_spawn_in_async_std() {
        let handle = spawn(delayed_answer());
        let result = handle.join().await.unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_spawn_in_smol() {
        smol::block_on(async {
            let handle = spawn(delayed_answer());
            let result = handle.join().await.unwrap();
            assert_eq!(result, 42);
        });
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_spawn_outside_tokio_ctx() {
        // No assertions: the test only fails if spawning without an ambient
        // Tokio context panics.
        let _handle = spawn(delayed_answer());
    }

    #[tokio::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_spawn_blocking() {
        let handle = spawn_blocking(|| 42);
        let result = handle.await.unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_spawn_blocking_outside_tokio_ctx() {
        let _handle = spawn_blocking(|| 42);
    }

    #[tokio::test]
    #[cfg(not(target_arch = "wasm32"))]
    async fn test_abortable() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;

        // Set only if the future runs to completion, i.e. if the abort failed.
        let completed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&completed);
        let future = async move {
            time::sleep(TASK_DELAY).await;
            flag.store(true, Ordering::SeqCst);
        };

        let abort_handle = abortable(future);
        abort_handle.abort();
        assert!(abort_handle.is_aborted());

        // Outlast the future's own delay: an aborted task must never reach the store.
        time::sleep(TASK_DELAY * 3).await;
        assert!(!completed.load(Ordering::SeqCst));
    }
}