use futures::FutureExt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use once_cell::sync::OnceCell;
use tokio::task::JoinHandle;
use tokio::time;
/// Returns the process-wide tokio runtime, creating it on first use.
///
/// # Panics
///
/// Panics if tokio fails to build the runtime (`Runtime::new` returns an error).
#[cfg(feature = "tokio")]
// NOTE(review): only this function is gated on the `tokio` feature, yet `block_on`
// and `spawn` call it unconditionally — confirm the enclosing module is gated too.
fn tokio() -> &'static tokio::runtime::Runtime {
    static INSTANCE: OnceCell<tokio::runtime::Runtime> = OnceCell::new();
    INSTANCE.get_or_init(|| {
        tokio::runtime::Runtime::new().expect("failed to build the global tokio runtime")
    })
}
#[derive(Debug)]
pub struct TaskHandle<T>(JoinHandle<T>);
impl<T> TaskHandle<T> {
pub async fn cancel(self) -> Option<T> {
self.0.abort();
self.await
}
pub async fn wait(self) -> Option<T> {
self.await
}
}
impl<T> Future for TaskHandle<T> {
    type Output = Option<T>;

    /// Polls the wrapped tokio join handle, turning a join error into `None`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `JoinHandle` is `Unpin`, so it can be polled via `poll_unpin` without pin
        // projection; `Result::ok` keeps a successful output and discards the error.
        self.0.poll_unpin(cx).map(Result::ok)
    }
}
/// Runs `future` to completion on the current thread and returns its output.
///
/// Outside any tokio runtime context, the future is driven by the shared global runtime.
/// When the calling thread is already inside a runtime context (a nested `block_on`, or
/// a task running on a multi-threaded runtime's worker), starting a runtime there would
/// panic. In that case the future is instead driven by a local executor inside
/// [`tokio::task::block_in_place`], which tells the runtime that this thread is about to
/// block.
///
/// # Panics
///
/// Panics if `future` panics. Also panics if called from a task on a `current_thread`
/// tokio runtime, where `block_in_place` is not supported.
pub fn block_on<F, T>(future: F) -> T
where
    F: Future<Output = T>,
{
    // Ask tokio directly whether this thread is in a runtime context, instead of
    // counting nested calls in a thread-local. No per-thread state exists that a
    // panicking future could leave stale, and worker threads that never went through
    // this function are detected as well.
    if tokio::runtime::Handle::try_current().is_ok() {
        tokio::task::block_in_place(|| futures::executor::block_on(future))
    } else {
        tokio().block_on(future)
    }
}
/// Spawns `future` onto the shared global runtime and returns a handle to it.
///
/// The task is handed to the runtime immediately and makes progress without being
/// awaited. Await the returned [`TaskHandle`] to get its output, or call
/// [`TaskHandle::cancel`] to abort it.
pub fn spawn<F, T>(future: F) -> TaskHandle<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    // `future` already meets `Runtime::spawn`'s bounds, so it is spawned directly;
    // wrapping it in `async move { future.await }` only added an extra future layer.
    TaskHandle(tokio().spawn(future))
}
/// Suspends the current task until `dur` has elapsed, using tokio's timer.
///
/// Per tokio's documentation, the underlying `tokio::time::sleep` panics when no tokio
/// timer is available, so this must be awaited inside a tokio runtime context.
pub async fn sleep(dur: Duration) {
    let delay = time::sleep(dur);
    delay.await;
}
/// Awaits `f`, giving up if it has not completed within `dur`.
///
/// Returns `Ok(output)` when `f` finishes in time. Returns `Err(())` when the deadline
/// elapses first, in which case `f` is dropped unfinished.
pub async fn timeout<F, T>(dur: Duration, f: F) -> Result<T, ()>
where
    F: Future<Output = T>,
{
    match time::timeout(dur, f).await {
        Ok(output) => Ok(output),
        // The `Elapsed` value carries no extra information callers need here.
        Err(_elapsed) => Err(()),
    }
}