#[cfg(not(target_arch = "wasm32"))]
// Callers use different subsets of these re-exports, so unused ones are expected.
#[allow(dead_code, unused_imports)]
mod imp {
    //! Native backend: thin re-exports of the Tokio task and time APIs.

    pub use std::time::Instant;
    pub use tokio::task::{JoinError, JoinHandle, JoinSet};
    pub use tokio::time::{sleep, timeout};

    /// Spawns `fut` as a new Tokio task and returns a handle to its output.
    ///
    /// This is a direct forward to Tokio's task spawner. It exists so that both the
    /// native and the `wasm32` backend expose a free `spawn` function.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `tokio::spawn`, e.g. when called
    /// outside a Tokio runtime.
    pub fn spawn<Fut>(fut: Fut) -> JoinHandle<Fut::Output>
    where
        Fut: std::future::Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        tokio::task::spawn(fut)
    }
}
#[cfg(target_arch = "wasm32")]
// Callers use different subsets of these items, so unused ones are expected.
#[allow(dead_code, unused_imports)]
mod imp {
    //! `wasm32` backend: single-threaded stand-ins for the Tokio task and time
    //! APIs, built on `wasm_bindgen_futures::spawn_local` and the JS global
    //! `setTimeout` / `clearTimeout` functions.

    pub use web_time::Instant;

    use std::cell::Cell;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_util::stream::{FuturesUnordered, StreamExt};
    use tokio::sync::oneshot;

    /// Error returned when a task does not run to completion: it was aborted,
    /// or it was dropped before it could report a value.
    #[derive(Debug, thiserror::Error)]
    #[error("task was aborted")]
    pub struct JoinError;

    /// Handle to a task started by [`spawn`]. Awaiting it yields the task's output.
    ///
    /// Dropping the handle detaches the task, so it keeps running, as with
    /// `tokio::task::JoinHandle`. Use [`JoinHandle::abort`] to cancel it.
    pub struct JoinHandle<T> {
        /// Delivers the task's outcome once it finishes or is aborted.
        rx: oneshot::Receiver<Result<T, JoinError>>,
        /// Abort signal for the task, emptied by the first `abort` call.
        /// The `Cell` lets `abort` take `&self`, matching Tokio's signature.
        abort_tx: Cell<Option<oneshot::Sender<()>>>,
    }

    impl<T> JoinHandle<T> {
        /// Requests cancellation of the task.
        ///
        /// After this, awaiting the handle yields `Err(JoinError)` unless the
        /// task finishes first. Later calls are no-ops.
        pub fn abort(&self) {
            if let Some(tx) = self.abort_tx.take() {
                // The task may already have finished and dropped its receiver.
                let _ = tx.send(());
            }
        }
    }

    impl<T> Future for JoinHandle<T> {
        type Output = Result<T, JoinError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // A `RecvError` means the task went away without reporting a result.
            Pin::new(&mut self.rx)
                .poll(cx)
                .map(|outcome| outcome.unwrap_or(Err(JoinError)))
        }
    }

    /// Runs `fut` on the current thread via `wasm_bindgen_futures::spawn_local`
    /// and returns a handle that resolves to its output.
    pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
        F::Output: 'static,
    {
        let (out_tx, out_rx) = oneshot::channel();
        let (abort_tx, abort_rx) = oneshot::channel::<()>();
        wasm_bindgen_futures::spawn_local(async move {
            let outcome = tokio::select! {
                v = fut => Ok(v),
                // Only an explicit `abort()` sends `()`. If the handle is merely
                // dropped, `abort_rx` yields `Err`, this pattern does not match,
                // and `select!` disables the branch. The task then keeps running
                // (detach semantics, as in Tokio) instead of being cancelled.
                Ok(()) = abort_rx => Err(JoinError),
            };
            // Nobody may be listening if the handle was dropped; that's fine.
            let _ = out_tx.send(outcome);
        });
        JoinHandle {
            rx: out_rx,
            abort_tx: Cell::new(Some(abort_tx)),
        }
    }

    /// A set of spawned tasks whose outputs are yielded in completion order.
    ///
    /// Dropping the set aborts every task still in it, as `tokio::task::JoinSet` does.
    pub struct JoinSet<T> {
        inner: FuturesUnordered<JoinHandle<T>>,
    }

    impl<T: 'static> JoinSet<T> {
        /// Creates an empty set.
        #[must_use]
        pub fn new() -> Self {
            Self {
                inner: FuturesUnordered::new(),
            }
        }

        /// Spawns `fut` as a task and adds its handle to the set.
        pub fn spawn<F>(&mut self, fut: F)
        where
            F: Future<Output = T> + 'static,
        {
            // Bare `spawn` resolves to the module-level function, not this method.
            self.inner.push(spawn(fut));
        }

        /// Waits for the next task to finish.
        ///
        /// Returns `None` once the set is empty.
        pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
            self.inner.next().await
        }

        /// Returns `true` if the set holds no tasks.
        #[must_use]
        pub fn is_empty(&self) -> bool {
            self.inner.is_empty()
        }

        /// Returns the number of tasks in the set.
        #[must_use]
        pub fn len(&self) -> usize {
            self.inner.len()
        }

        /// Aborts every task in the set and removes them.
        ///
        /// Unlike Tokio's `abort_all`, the set is emptied immediately, so
        /// `join_next` does not report the aborted tasks.
        pub fn abort_all(&mut self) {
            self.abort_remaining();
            self.inner = FuturesUnordered::new();
        }

        /// Signals abort to every task still held by the set.
        fn abort_remaining(&self) {
            for handle in self.inner.iter() {
                handle.abort();
            }
        }
    }

    impl<T> Drop for JoinSet<T> {
        fn drop(&mut self) {
            // Dropping a `JoinHandle` only detaches its task, so abort explicitly.
            for handle in self.inner.iter() {
                handle.abort();
            }
        }
    }

    impl<T: 'static> Default for JoinSet<T> {
        fn default() -> Self {
            Self::new()
        }
    }

    /// Converts `dur` into a `setTimeout` delay in whole milliseconds.
    ///
    /// Rounds up so the timer never fires before `dur` has elapsed. It also
    /// saturates at `i32::MAX`, because `setTimeout` stores its delay as a 32-bit
    /// signed integer and larger values overflow.
    fn timeout_delay_ms(dur: std::time::Duration) -> i32 {
        let mut ms = dur.as_millis();
        if dur.subsec_nanos() % 1_000_000 != 0 {
            ms += 1;
        }
        i32::try_from(ms).unwrap_or(i32::MAX)
    }

    /// Cancels a pending `setTimeout` timer when dropped (best effort).
    struct ClearTimeoutOnDrop(wasm_bindgen::JsValue);

    impl Drop for ClearTimeoutOnDrop {
        fn drop(&mut self) {
            use wasm_bindgen::JsCast;

            let clear_timeout = js_sys::Reflect::get(
                &js_sys::global(),
                &wasm_bindgen::JsValue::from_str("clearTimeout"),
            )
            .ok()
            .and_then(|f| f.dyn_into::<js_sys::Function>().ok());
            if let Some(clear_timeout) = clear_timeout {
                // Clearing an already-fired timer is harmless.
                let _ = clear_timeout.call1(&wasm_bindgen::JsValue::NULL, &self.0);
            }
        }
    }

    /// Completes after at least `dur` has elapsed, using the JS `setTimeout`.
    ///
    /// If this future is dropped early, the underlying timer is cleared.
    ///
    /// # Panics
    ///
    /// Panics if the global scope has no `setTimeout` function.
    pub async fn sleep(dur: std::time::Duration) {
        use js_sys::Promise;
        use wasm_bindgen::JsCast;
        use wasm_bindgen::prelude::*;
        use wasm_bindgen_futures::JsFuture;

        let ms = timeout_delay_ms(dur);
        let mut timer_id = JsValue::UNDEFINED;
        let promise = Promise::new(&mut |resolve, _reject| {
            let global = js_sys::global();
            let set_timeout = js_sys::Reflect::get(&global, &JsValue::from_str("setTimeout"))
                .expect("setTimeout missing from global scope");
            let set_timeout: js_sys::Function =
                set_timeout.dyn_into().expect("setTimeout not a function");
            // Keep the handle returned by `setTimeout` so the timer can be cleared.
            if let Ok(id) = set_timeout.call2(&JsValue::NULL, &resolve, &JsValue::from(ms)) {
                timer_id = id;
            }
        });
        // Cleared if this future is dropped first (e.g. `timeout`'s other branch
        // won). Once the timer has fired, clearing it is a no-op.
        let _cancel = ClearTimeoutOnDrop(timer_id);
        let _ = JsFuture::from(promise).await;
    }

    /// Error returned by [`timeout`] when the deadline passes first.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct Elapsed;

    impl core::fmt::Display for Elapsed {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.write_str("deadline has elapsed")
        }
    }

    impl std::error::Error for Elapsed {}

    /// Runs `fut`, giving up with [`Elapsed`] if it has not finished within `dur`.
    ///
    /// # Errors
    ///
    /// Returns `Err(Elapsed)` if `dur` passes before `fut` completes.
    pub async fn timeout<F>(dur: std::time::Duration, fut: F) -> Result<F::Output, Elapsed>
    where
        F: Future,
    {
        tokio::select! {
            // Poll the inner future first, like `tokio::time::timeout`, so a
            // ready value wins over a simultaneously elapsed deadline.
            biased;
            v = fut => Ok(v),
            () = sleep(dur) => Err(Elapsed),
        }
    }
}
#[allow(unused_imports)]
pub use imp::*;