use core::time::Duration;
use std::fmt::Debug;
use std::future::pending;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::task::Poll;
#[cfg(not(target_arch = "wasm32"))]
use async_compat::CompatExt;
use bevy_tasks::ConditionalSend;
use bevy_tasks::ConditionalSendFuture;
use futures::channel::oneshot;
use futures::task::AtomicWaker;
use crate::AsyncReceiver;
use crate::error::TimeoutError;
use crate::sleep;
use crate::util::timeout;
/// A boxed future bounded by a timeout.
///
/// Created with [`TimedAsyncTask::new`], through `From` on a suitable future, or from an
/// [`AsyncTask`] via `AsyncTask::with_timeout`.
pub struct TimedAsyncTask<T>
where
    T: ConditionalSend,
{
    /// The wrapped future, boxed and pinned so its concrete type is erased.
    fut: Pin<Box<dyn ConditionalSendFuture<Output = T> + 'static>>,
    /// Deadline handed to `timeout` when the task is split.
    timeout: Duration,
}
impl<T> Debug for TimedAsyncTask<T>
where
T: Debug + Send,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TimedAsyncTask")
.field("fut", &"<future>")
.field("timeout", &self.timeout)
.finish()
}
}
/// A boxed future with no timeout.
///
/// Created with [`AsyncTask::new`] or through `From` on a suitable future; drive it by calling
/// `split` and polling the returned future.
pub struct AsyncTask<T>
where
    T: ConditionalSend,
{
    /// The wrapped future, boxed and pinned so its concrete type is erased.
    fut: Pin<Box<dyn ConditionalSendFuture<Output = T> + 'static>>,
}
impl<T> Debug for AsyncTask<T>
where
T: Debug + Send,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncTask")
.field("fut", &"<future>")
.finish()
}
}
impl AsyncTask<()> {
pub async fn sleep(duration: Duration) -> Self {
Self::new(sleep(duration))
}
}
impl<T> AsyncTask<T>
where
    T: ConditionalSend + 'static,
{
    /// Wraps `fut` in a new task without a timeout.
    ///
    /// The future is boxed and pinned, so tasks built from different future types share one type.
    pub fn new<F>(fut: F) -> Self
    where
        F: ConditionalSendFuture<Output = T> + 'static,
        F::Output: ConditionalSend + 'static,
    {
        Self { fut: Box::pin(fut) }
    }

    /// Creates a task whose future never completes.
    pub fn pending() -> Self {
        Self::new(pending())
    }

    /// Splits the task into a driver future and an [`AsyncReceiver`] for its output.
    ///
    /// The driver awaits the wrapped future (through an `async_compat` adapter on non-wasm
    /// targets) and sends the result over a oneshot channel. If the send fails because the
    /// receiver is already gone, the driver completes immediately. Otherwise it stays pending
    /// until the shared `received` flag becomes true; that flag is shared with the returned
    /// receiver, which presumably sets it once it has taken the value.
    #[must_use]
    pub fn split(self) -> (impl ConditionalSendFuture<Output = ()>, AsyncReceiver<T>) {
        let (tx, rx) = oneshot::channel();
        let waker = Arc::new(AtomicWaker::new());
        let received = Arc::new(AtomicBool::new(false));
        let fut = {
            let waker = waker.clone();
            let received = received.clone();
            async move {
                #[cfg(target_arch = "wasm32")]
                let result = self.fut.await;
                // NOTE(review): presumably the compat adapter supplies a Tokio context for
                // futures that need one on native targets — confirm.
                #[cfg(not(target_arch = "wasm32"))]
                let result = self.fut.compat().await;
                if tx.send(result).is_ok() {
                    futures::future::poll_fn(|cx| {
                        // Register before checking the flag so a wake between the two is not lost.
                        waker.register(cx.waker());
                        if received.load(Ordering::Relaxed) {
                            Poll::Ready(())
                        } else {
                            Poll::Pending::<()>
                        }
                    })
                    .await;
                }
            }
        };
        let fut = Box::pin(fut);
        let receiver = AsyncReceiver {
            received,
            waker,
            receiver: rx,
        };
        (fut, receiver)
    }

    /// Converts this task into a [`TimedAsyncTask`] that is bounded by `dur`.
    ///
    /// `dur` is truncated to whole milliseconds, exactly as `TimedAsyncTask::with_timeout` does.
    ///
    /// # Panics
    ///
    /// Panics if `dur.as_millis()` exceeds `i32::MAX`.
    #[must_use]
    pub fn with_timeout(self, dur: Duration) -> TimedAsyncTask<T> {
        // Delegate so the millisecond conversion and its limit live in one place.
        TimedAsyncTask {
            fut: self.fut,
            timeout: Duration::ZERO,
        }
        .with_timeout(dur)
    }
}
impl<T, Fnc> From<Fnc> for AsyncTask<T>
where
Fnc: ConditionalSendFuture<Output = T> + 'static,
Fnc::Output: ConditionalSend + 'static,
{
fn from(value: Fnc) -> Self {
Self::new(value)
}
}
impl<T> TimedAsyncTask<T>
where
    T: ConditionalSend + 'static,
{
    /// Wraps `fut` in a task bounded by `dur`, truncated to whole milliseconds.
    ///
    /// # Panics
    ///
    /// Panics if `dur.as_millis()` exceeds `i32::MAX`.
    pub fn new<F>(dur: Duration, fut: F) -> Self
    where
        F: ConditionalSendFuture<Output = T> + 'static,
        F::Output: ConditionalSend + 'static,
    {
        Self {
            fut: Box::pin(fut),
            timeout: Self::whole_millis(dur),
        }
    }

    /// Creates a task whose future never completes, bounded by `crate::MAX_TIMEOUT`.
    pub fn pending() -> Self {
        Self::new(crate::MAX_TIMEOUT, pending())
    }

    /// Splits the task into a driver future and an [`AsyncReceiver`] for its result.
    ///
    /// The driver runs the wrapped future under `timeout(self.timeout, ..)` (through an
    /// `async_compat` adapter on non-wasm targets) and sends the outcome over a oneshot channel.
    /// The outcome is presumably `Err(TimeoutError)` when the deadline passes first.
    /// If the send fails because the receiver is already gone, the driver completes immediately.
    /// Otherwise it stays pending until the `received` flag shared with the returned receiver
    /// becomes true.
    #[must_use]
    pub fn split(
        self,
    ) -> (
        impl ConditionalSendFuture<Output = ()>,
        AsyncReceiver<Result<T, TimeoutError>>,
    ) {
        let (tx, rx) = oneshot::channel();
        let waker = Arc::new(AtomicWaker::new());
        let received = Arc::new(AtomicBool::new(false));
        let fut = {
            let waker = waker.clone();
            let received = received.clone();
            async move {
                #[cfg(target_arch = "wasm32")]
                let result = timeout(self.timeout, self.fut).await;
                #[cfg(not(target_arch = "wasm32"))]
                let result = timeout(self.timeout, self.fut.compat()).await;
                if tx.send(result).is_ok() {
                    futures::future::poll_fn(|cx| {
                        // Register before checking the flag so a wake between the two is not lost.
                        waker.register(cx.waker());
                        if received.load(Ordering::Relaxed) {
                            Poll::Ready(())
                        } else {
                            Poll::Pending::<()>
                        }
                    })
                    .await;
                }
            }
        };
        let fut = Box::pin(fut);
        let receiver = AsyncReceiver {
            received,
            waker,
            receiver: rx,
        };
        (fut, receiver)
    }

    /// Replaces the timeout with `dur`, truncated to whole milliseconds.
    ///
    /// # Panics
    ///
    /// Panics if `dur.as_millis()` exceeds `i32::MAX`.
    #[must_use]
    pub fn with_timeout(mut self, dur: Duration) -> Self {
        self.timeout = Self::whole_millis(dur);
        self
    }

    /// Drops the timeout, turning this back into a plain [`AsyncTask`] around the same future.
    #[must_use]
    pub fn without_timeout(self) -> AsyncTask<T> {
        AsyncTask { fut: self.fut }
    }

    /// Truncates `dur` to whole milliseconds, enforcing the `i32` millisecond limit shared by
    /// every timeout setter.
    ///
    /// NOTE(review): the `i32` limit presumably exists so the value fits millisecond APIs such
    /// as JS timers on wasm — confirm.
    fn whole_millis(dur: Duration) -> Duration {
        let millis = i32::try_from(dur.as_millis()).unwrap_or_else(|_e| {
            panic!("failed to cast the duration into a i32 with Duration::as_millis.")
        });
        // `millis` comes from an unsigned value, so it is non-negative; widen losslessly.
        Duration::from_millis(u64::from(millis.unsigned_abs()))
    }
}
impl<T, Fnc> From<Fnc> for TimedAsyncTask<T>
where
Fnc: ConditionalSendFuture<Output = T> + 'static,
Fnc::Output: ConditionalSend + 'static,
{
fn from(value: Fnc) -> Self {
Self::new(crate::DEFAULT_TIMEOUT, value)
}
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod test {
    use core::time::Duration;
    use futures::FutureExt;
    use futures::pin_mut;
    use futures_timer::Delay;
    use tokio::select;
    use super::*;

    /// Polls `rx` about once per millisecond until it yields a value.
    ///
    /// # Panics
    ///
    /// Panics with `"timeout"` if no value arrives within 100ms.
    async fn recv_within_100ms<T>(rx: &mut AsyncReceiver<T>) -> T
    where
        T: ConditionalSend + 'static,
    {
        let fetch = Delay::new(Duration::from_millis(1));
        let timeout = Delay::new(Duration::from_millis(100)).fuse();
        pin_mut!(timeout, fetch);
        loop {
            select! {
                _ = (&mut fetch).fuse() => {
                    if let Some(v) = rx.try_recv() {
                        return v;
                    }
                    fetch.reset(Duration::from_millis(1));
                }
                _ = &mut timeout => panic!("timeout"),
            }
        }
    }

    #[tokio::test]
    async fn test_oneshot() {
        let (tx, rx) = oneshot::channel();
        tokio::spawn(async move {
            if tx.send(3).is_err() {
                panic!("the receiver dropped");
            }
        });
        match rx.await {
            Ok(v) => assert_eq!(3, v),
            Err(e) => panic!("the sender dropped ({e})"),
        }
    }

    #[tokio::test]
    async fn test_try_recv() {
        let task = AsyncTask::new(async move { 5 });
        let (fut, mut rx) = task.split();
        assert_eq!(None, rx.try_recv());
        tokio::spawn(fut);
        assert_eq!(5, recv_within_100ms(&mut rx).await);
    }

    /// Timeout set through the constructor.
    #[tokio::test]
    async fn test_timeout() {
        let task = TimedAsyncTask::new(Duration::from_millis(5), pending::<()>());
        let (fut, mut rx) = task.split();
        assert_eq!(None, rx.try_recv());
        tokio::spawn(fut);
        let v = recv_within_100ms(&mut rx).await;
        assert!(matches!(v, Err(TimeoutError)), "timeout should have triggered!");
    }

    /// Timeout set through `with_timeout`; previously this test duplicated `test_timeout`
    /// and never exercised `with_timeout` at all.
    #[tokio::test]
    async fn test_with_timeout() {
        let task = TimedAsyncTask::<()>::pending().with_timeout(Duration::from_millis(5));
        let (fut, mut rx) = task.split();
        assert_eq!(None, rx.try_recv());
        tokio::spawn(fut);
        let v = recv_within_100ms(&mut rx).await;
        assert!(matches!(v, Err(TimeoutError)), "timeout should have triggered!");
    }
}
#[cfg(target_arch = "wasm32")]
#[cfg(test)]
mod test {
    use wasm_bindgen::JsValue;
    use wasm_bindgen_futures::JsFuture;
    use wasm_bindgen_test::wasm_bindgen_test;
    use super::*;

    /// Polls `rx` until it yields a value, sleeping about a millisecond between attempts.
    ///
    /// The driver future returned by `split` only completes once the `received` flag it
    /// shares with `rx` is set, presumably by `rx` taking the value. The driver therefore has
    /// to run in the background while the receiver is polled; awaiting the driver first would
    /// never finish.
    ///
    /// # Panics
    ///
    /// Panics with `"timeout"` if no value arrives within 100 attempts.
    async fn recv_within_100_polls<T>(rx: &mut AsyncReceiver<T>) -> T
    where
        T: ConditionalSend + 'static,
    {
        for _ in 0..100 {
            if let Some(v) = rx.try_recv() {
                return v;
            }
            sleep(Duration::from_millis(1)).await;
        }
        panic!("timeout");
    }

    #[wasm_bindgen_test]
    async fn test_oneshot() {
        let (tx, rx) = oneshot::channel();
        JsFuture::from(wasm_bindgen_futures::future_to_promise(async move {
            if tx.send(3).is_err() {
                panic!("the receiver dropped");
            }
            match rx.await {
                Ok(v) => assert_eq!(3, v),
                Err(e) => panic!("the sender dropped ({e})"),
            }
            Ok(JsValue::NULL)
        }))
        .await
        .unwrap_or_else(|e| {
            panic!("awaiting promise failed: {e:?}");
        });
    }

    #[wasm_bindgen_test]
    async fn test_try_recv() {
        let task = AsyncTask::new(async move { 5 });
        let (fut, mut rx) = task.split();
        assert_eq!(None, rx.try_recv());
        wasm_bindgen_futures::spawn_local(fut);
        assert_eq!(5, recv_within_100_polls(&mut rx).await);
    }

    /// Timeout set through the constructor.
    #[wasm_bindgen_test]
    async fn test_timeout() {
        let task = TimedAsyncTask::new(Duration::from_millis(5), pending::<()>());
        let (fut, mut rx) = task.split();
        assert_eq!(None, rx.try_recv());
        wasm_bindgen_futures::spawn_local(fut);
        let v = recv_within_100_polls(&mut rx).await;
        assert!(matches!(v, Err(TimeoutError)), "timeout should have triggered!");
    }

    /// Timeout set through `with_timeout`.
    #[wasm_bindgen_test]
    async fn test_with_timeout() {
        let task = TimedAsyncTask::<()>::pending().with_timeout(Duration::from_millis(5));
        let (fut, mut rx) = task.split();
        assert_eq!(None, rx.try_recv());
        wasm_bindgen_futures::spawn_local(fut);
        let v = recv_within_100_polls(&mut rx).await;
        assert!(matches!(v, Err(TimeoutError)), "timeout should have triggered!");
    }
}