use crate::{Error, Task};
use core::{
future::Future,
panic::AssertUnwindSafe,
pin::Pin,
task::{Context, Poll},
};
use alloc::boxed::Box;
pub use async_task::{Runnable, Task as RawTask};
#[cfg(feature = "std")]
use crate::catch_unwind;
/// Fallback for `crate::catch_unwind` used when the `std` feature is off.
///
/// `core` offers no way to intercept a panic, so this just runs `f` and
/// reports its return value as success. It never produces `Err`; a panic in
/// `f` is not caught here and propagates as usual.
#[cfg(not(feature = "std"))]
fn catch_unwind<F, R>(f: F) -> Result<R, Error>
where
    F: FnOnce() -> R,
{
    let value = f();
    Result::Ok(value)
}
/// Handle to a task created by `spawn` / `spawn_local` in this module.
///
/// A thin newtype over `async_task::Task`. Awaiting it (through its `Future`
/// impl) yields the task's output. Dropping the wrapped `async_task::Task`
/// cancels the task, and this wrapper has no `Drop` impl of its own. The
/// wrapper therefore re-declares `#[must_use]`, which the newtype would
/// otherwise lose. Without it, silently discarding the handle would go
/// unwarned.
#[must_use = "tasks get canceled when dropped"]
pub struct AsyncTask<T>(async_task::Task<T>);

impl<T> AsyncTask<T> {
    /// Waits for the task and returns its output.
    ///
    /// Delegates to the crate-level `Task::result` trait method. The `Err`
    /// case carries whatever that method reports; presumably this is the
    /// panic payload captured by `poll_result` (confirm against `crate::Task`).
    // NOTE(review): the error type is spelled `Box<dyn Any + Send>` here,
    // while `poll_result` uses `crate::Error`. They are presumably the same
    // type. Consider spelling it `Error` for consistency (confirm first).
    pub async fn result(self) -> Result<T, Box<dyn core::any::Any + Send>> {
        crate::Task::result(self).await
    }

    /// Cancels the task and waits until it has stopped running.
    ///
    /// `async_task::Task::cancel` yields `Some(output)` if the task finished
    /// just before being cancelled. That output is dropped here; this method
    /// only guarantees the task is no longer running.
    pub async fn cancel(self) {
        // The `Option<T>` is intentionally discarded (see doc comment).
        let _ = self.0.cancel().await;
    }
}
impl<T> core::fmt::Debug for AsyncTask<T> {
    /// Formats as the opaque `AsyncTask { .. }`.
    ///
    /// No field is printed. As a result, no `Debug` bound is needed on `T`.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = formatter.debug_struct("AsyncTask");
        builder.finish_non_exhaustive()
    }
}
impl<T> From<async_task::Task<T>> for AsyncTask<T> {
fn from(task: async_task::Task<T>) -> Self {
Self(task)
}
}
impl<T> Future for AsyncTask<T> {
    type Output = T;

    /// Polls the wrapped `async_task::Task` directly.
    ///
    /// A panic raised by the inner poll is not intercepted here; see
    /// `poll_result` for the panic-catching variant.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Relies on the wrapped handle being `Unpin`, so the pin can be
        // peeled off and re-applied to the inner field.
        let this = self.get_mut();
        Future::poll(Pin::new(&mut this.0), cx)
    }
}
impl<T> Task<T> for AsyncTask<T> {
fn poll_result(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, Error>> {
match catch_unwind(AssertUnwindSafe(|| Pin::new(&mut self.0).poll(cx))) {
Ok(Poll::Ready(value)) => Poll::Ready(Ok(value)),
Ok(Poll::Pending) => Poll::Pending,
Err(error) => Poll::Ready(Err(error)),
}
}
}
pub fn spawn<F, S>(future: F, scheduler: S) -> (Runnable, AsyncTask<F::Output>)
where
F: Future + Send + 'static,
F::Output: Send + 'static,
S: Fn(Runnable) + Send + Sync + 'static,
{
let (runnable, task) = async_task::spawn(future, scheduler);
(runnable, AsyncTask::from(task))
}
/// Spawns a future that need not be `Send`, via `async_task::spawn_local`.
///
/// Only available with the `std` feature. Per the `async_task` docs, the
/// returned `Runnable` must be run and dropped on the thread that spawned
/// it; otherwise `async_task` panics.
#[cfg(feature = "std")]
pub fn spawn_local<F, S>(future: F, scheduler: S) -> (Runnable, AsyncTask<F::Output>)
where
    F: Future + 'static,
    S: Fn(Runnable) + Send + Sync + 'static,
{
    let (runnable, raw) = async_task::spawn_local(future, scheduler);
    let handle: AsyncTask<F::Output> = raw.into();
    (runnable, handle)
}