use std::{future::Future, time::Duration};
/// Abstraction over an async executor (tokio, async-std or smol in this file).
///
/// Each implementation is enabled by a cargo feature; the free functions
/// [`spawn`], [`spawn_blocking`], [`sleep`] and [`runtime_name`] dispatch to
/// whichever one is selected at compile time.
pub trait Runtime: Send + Sync + 'static {
    /// Future returned by [`Runtime::spawn`] and [`Runtime::spawn_blocking`];
    /// it resolves to the task's output, or to a [`JoinError`] on failure.
    type JoinHandle<T>: Future<Output = Result<T, JoinError>> + Send
    where
        T: Send + 'static;

    /// Spawns `future` onto this runtime and returns a handle to its output.
    fn spawn<F: Future<Output = T> + Send + 'static, T: Send + 'static>(
        &self,
        future: F,
    ) -> Self::JoinHandle<T>;

    /// Hands the closure `f` to this runtime's blocking-task facility and
    /// returns a handle to its result.
    fn spawn_blocking<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(
        &self,
        f: F,
    ) -> Self::JoinHandle<T>;

    /// Returns a future that completes once `duration` has elapsed.
    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send;

    /// Short identifier of the runtime, e.g. `"tokio"`.
    fn name(&self) -> &'static str;

    /// Reports whether the runtime can be used from the calling context.
    ///
    /// The tokio implementation in this file checks for an entered tokio
    /// runtime; the async-std and smol implementations always return `true`.
    fn is_available() -> bool;
}
/// Failure reported when awaiting a task spawned through a [`Runtime`].
#[derive(Debug)]
pub enum JoinError {
    /// The task was cancelled before producing a value.
    Cancelled,
    /// The task panicked; holds the payload that was passed to `panic!`.
    Panic(Box<dyn std::any::Any + Send>),
    /// Any other failure reported by the runtime, carried as a message.
    Runtime(String),
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Cancelled => write!(f, "Task was cancelled"),
            Self::Panic(payload) => {
                // `panic!("literal")` produces a `&'static str` payload and
                // `panic!("{x}")` a `String`; anything else has no printable form.
                let payload: &(dyn std::any::Any + Send) = &**payload;
                let message = payload
                    .downcast_ref::<&str>()
                    .copied()
                    .or_else(|| payload.downcast_ref::<String>().map(String::as_str));
                match message {
                    Some(message) => write!(f, "Task panicked: {message}"),
                    None => write!(f, "Task panicked"),
                }
            },
            Self::Runtime(msg) => write!(f, "Runtime error: {}", msg),
        }
    }
}

impl std::error::Error for JoinError {}
/// Spawns `future` on the runtime selected by cargo features and returns a
/// future resolving to its output (or a [`JoinError`]).
///
/// Precedence when several runtime features are enabled: `runtime-tokio`,
/// then `runtime-async-std`, then `runtime-smol`.
pub fn spawn<F, T>(future: F) -> impl Future<Output = Result<T, JoinError>> + Send
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    // At most one of these blocks survives cfg-expansion and becomes the tail
    // expression; with no runtime feature enabled there is no body at all.
    #[cfg(all(
        feature = "runtime-smol",
        not(feature = "runtime-tokio"),
        not(feature = "runtime-async-std")
    ))]
    {
        Runtime::spawn(&smol_runtime::SmolRuntime, future)
    }
    #[cfg(all(feature = "runtime-async-std", not(feature = "runtime-tokio")))]
    {
        Runtime::spawn(&async_std_runtime::AsyncStdRuntime, future)
    }
    #[cfg(feature = "runtime-tokio")]
    {
        Runtime::spawn(&tokio_runtime::TokioRuntime, future)
    }
}
/// Runs the blocking closure `f` through the selected runtime's blocking-task
/// facility and returns a future resolving to its result (or a [`JoinError`]).
///
/// Precedence when several runtime features are enabled: `runtime-tokio`,
/// then `runtime-async-std`, then `runtime-smol`.
pub fn spawn_blocking<F, T>(f: F) -> impl Future<Output = Result<T, JoinError>> + Send
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    // At most one of these blocks survives cfg-expansion and becomes the tail
    // expression; with no runtime feature enabled there is no body at all.
    #[cfg(all(
        feature = "runtime-smol",
        not(feature = "runtime-tokio"),
        not(feature = "runtime-async-std")
    ))]
    {
        Runtime::spawn_blocking(&smol_runtime::SmolRuntime, f)
    }
    #[cfg(all(feature = "runtime-async-std", not(feature = "runtime-tokio")))]
    {
        Runtime::spawn_blocking(&async_std_runtime::AsyncStdRuntime, f)
    }
    #[cfg(feature = "runtime-tokio")]
    {
        Runtime::spawn_blocking(&tokio_runtime::TokioRuntime, f)
    }
}
/// Waits for `duration` using the timer of the runtime selected by cargo
/// features.
///
/// Precedence when several runtime features are enabled: `runtime-tokio`,
/// then `runtime-async-std`, then `runtime-smol`.
pub async fn sleep(duration: Duration) {
    // At most one of these blocks survives cfg-expansion and becomes the tail
    // expression.
    #[cfg(all(
        feature = "runtime-smol",
        not(feature = "runtime-tokio"),
        not(feature = "runtime-async-std")
    ))]
    {
        Runtime::sleep(&smol_runtime::SmolRuntime, duration).await
    }
    #[cfg(all(feature = "runtime-async-std", not(feature = "runtime-tokio")))]
    {
        Runtime::sleep(&async_std_runtime::AsyncStdRuntime, duration).await
    }
    #[cfg(feature = "runtime-tokio")]
    {
        Runtime::sleep(&tokio_runtime::TokioRuntime, duration).await
    }
}
/// Returns the name of the runtime selected by cargo features
/// (`"tokio"`, `"async-std"` or `"smol"`).
///
/// The string comes from the selected runtime's [`Runtime::name`], so the
/// free function and the trait implementation cannot drift apart.
/// Precedence when several runtime features are enabled: tokio, then
/// async-std, then smol (matching [`spawn`] and friends).
pub fn runtime_name() -> &'static str {
    #[cfg(feature = "runtime-tokio")]
    {
        tokio_runtime::TokioRuntime.name()
    }
    #[cfg(all(feature = "runtime-async-std", not(feature = "runtime-tokio")))]
    {
        async_std_runtime::AsyncStdRuntime.name()
    }
    #[cfg(all(
        feature = "runtime-smol",
        not(feature = "runtime-tokio"),
        not(feature = "runtime-async-std")
    ))]
    {
        smol_runtime::SmolRuntime.name()
    }
}
/// [`Runtime`] implementation backed by tokio.
#[cfg(feature = "runtime-tokio")]
pub mod tokio_runtime {
    use super::*;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Unit handle that forwards to tokio's `task` and `time` functions, which
    /// operate on the tokio runtime of the calling context (see `is_available`).
    pub struct TokioRuntime;

    /// Future over a `tokio::task::JoinHandle<T>` that reports failures as
    /// this crate's [`JoinError`].
    pub struct TokioJoinHandle<T>(tokio::task::JoinHandle<T>);

    /// Translates tokio's join error: cancellation first, then panics (keeping
    /// the payload), and anything else as a `Runtime` message.
    fn map_join_error(err: tokio::task::JoinError) -> JoinError {
        if err.is_cancelled() {
            JoinError::Cancelled
        } else if err.is_panic() {
            JoinError::Panic(err.into_panic())
        } else {
            JoinError::Runtime(err.to_string())
        }
    }

    impl<T> Future for TokioJoinHandle<T>
    where
        T: Send + 'static,
    {
        type Output = Result<T, JoinError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            Pin::new(&mut self.0)
                .poll(cx)
                .map(|outcome| outcome.map_err(map_join_error))
        }
    }

    impl Runtime for TokioRuntime {
        type JoinHandle<T>
            = TokioJoinHandle<T>
        where
            T: Send + 'static;

        fn spawn<F, T>(&self, future: F) -> Self::JoinHandle<T>
        where
            F: Future<Output = T> + Send + 'static,
            T: Send + 'static,
        {
            let handle = tokio::task::spawn(future);
            TokioJoinHandle(handle)
        }

        fn spawn_blocking<F, T>(&self, f: F) -> Self::JoinHandle<T>
        where
            F: FnOnce() -> T + Send + 'static,
            T: Send + 'static,
        {
            let handle = tokio::task::spawn_blocking(f);
            TokioJoinHandle(handle)
        }

        // Kept as `async fn` so the tokio timer is only created on first poll.
        async fn sleep(&self, duration: Duration) {
            tokio::time::sleep(duration).await;
        }

        fn name(&self) -> &'static str {
            "tokio"
        }

        /// `true` when called from inside an entered tokio runtime.
        fn is_available() -> bool {
            tokio::runtime::Handle::try_current().is_ok()
        }
    }
}
/// [`Runtime`] implementation backed by async-std.
#[cfg(feature = "runtime-async-std")]
pub mod async_std_runtime {
    use super::*;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Unit handle that forwards to `async_std::task`.
    pub struct AsyncStdRuntime;

    /// Future over an `async_std::task::JoinHandle<T>` producing `Result<T, JoinError>`.
    pub struct AsyncStdJoinHandle<T>(async_std::task::JoinHandle<T>);

    impl<T> Future for AsyncStdJoinHandle<T>
    where
        T: Send + 'static,
    {
        type Output = Result<T, JoinError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // async-std's handle yields `T` directly, so every completion is `Ok`.
            // NOTE(review): a panicking task is never mapped to `JoinError::Panic`
            // here — confirm how async-std surfaces task panics.
            Pin::new(&mut self.0).poll(cx).map(Ok)
        }
    }

    impl Runtime for AsyncStdRuntime {
        type JoinHandle<T>
            = AsyncStdJoinHandle<T>
        where
            T: Send + 'static;

        fn spawn<F, T>(&self, future: F) -> Self::JoinHandle<T>
        where
            F: Future<Output = T> + Send + 'static,
            T: Send + 'static,
        {
            let handle = async_std::task::spawn(future);
            AsyncStdJoinHandle(handle)
        }

        fn spawn_blocking<F, T>(&self, f: F) -> Self::JoinHandle<T>
        where
            F: FnOnce() -> T + Send + 'static,
            T: Send + 'static,
        {
            let handle = async_std::task::spawn_blocking(f);
            AsyncStdJoinHandle(handle)
        }

        async fn sleep(&self, duration: Duration) {
            async_std::task::sleep(duration).await;
        }

        fn name(&self) -> &'static str {
            "async-std"
        }

        /// Always `true`; no runtime-context check is performed for async-std.
        fn is_available() -> bool {
            true
        }
    }
}
/// [`Runtime`] implementation backed by smol's global executor.
#[cfg(feature = "runtime-smol")]
pub mod smol_runtime {
    use super::*;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Unit handle that forwards to `smol::spawn`, `smol::unblock` and `smol::Timer`.
    pub struct SmolRuntime;

    /// Handle to a task spawned on smol.
    ///
    /// A bare `smol::Task` cancels its task when dropped, whereas tokio's and
    /// async-std's join handles detach. To give every runtime the same
    /// fire-and-forget semantics (`spawn(fut);` keeps running), the task is held
    /// in an `Option` and detached in `Drop`. The `Option` is only emptied there.
    pub struct SmolJoinHandle<T>(Option<smol::Task<T>>);

    impl<T> Future for SmolJoinHandle<T>
    where
        T: Send + 'static,
    {
        type Output = Result<T, JoinError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let task = self
                .0
                .as_mut()
                .expect("SmolJoinHandle task is only taken in Drop");
            // smol's task yields `T` directly, so every completion is `Ok`.
            Pin::new(task).poll(cx).map(Ok)
        }
    }

    impl<T> Drop for SmolJoinHandle<T> {
        fn drop(&mut self) {
            // Detach rather than cancel, so dropping the handle does not stop the task.
            if let Some(task) = self.0.take() {
                task.detach();
            }
        }
    }

    impl Runtime for SmolRuntime {
        type JoinHandle<T>
            = SmolJoinHandle<T>
        where
            T: Send + 'static;

        fn spawn<F, T>(&self, future: F) -> Self::JoinHandle<T>
        where
            F: Future<Output = T> + Send + 'static,
            T: Send + 'static,
        {
            SmolJoinHandle(Some(smol::spawn(future)))
        }

        fn spawn_blocking<F, T>(&self, f: F) -> Self::JoinHandle<T>
        where
            F: FnOnce() -> T + Send + 'static,
            T: Send + 'static,
        {
            SmolJoinHandle(Some(smol::unblock(f)))
        }

        async fn sleep(&self, duration: Duration) {
            smol::Timer::after(duration).await;
        }

        fn name(&self) -> &'static str {
            "smol"
        }

        /// Always `true`; no runtime-context check is performed for smol.
        fn is_available() -> bool {
            true
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Every runtime-specific test is gated with the same precedence the
    // crate-level dispatch functions use (tokio > async-std > smol). Otherwise,
    // enabling several runtime features at once (e.g. `--all-features`) would
    // run a test whose runtime was not actually selected, and it would fail.

    #[cfg(feature = "runtime-tokio")]
    #[tokio::test]
    async fn test_tokio_runtime() {
        assert_eq!(runtime_name(), "tokio");
        let handle = spawn(async { 42 });
        let result = handle.await.unwrap();
        assert_eq!(result, 42);
        sleep(Duration::from_millis(10)).await;
    }

    #[cfg(all(feature = "runtime-async-std", not(feature = "runtime-tokio")))]
    #[async_std::test]
    async fn test_async_std_runtime() {
        assert_eq!(runtime_name(), "async-std");
        let handle = spawn(async { 42 });
        let result = handle.await.unwrap();
        assert_eq!(result, 42);
        sleep(Duration::from_millis(10)).await;
    }

    #[cfg(all(
        feature = "runtime-smol",
        not(feature = "runtime-tokio"),
        not(feature = "runtime-async-std")
    ))]
    #[test]
    fn test_smol_runtime() {
        smol::block_on(async {
            assert_eq!(runtime_name(), "smol");
            let handle = spawn(async { 42 });
            let result = handle.await.unwrap();
            assert_eq!(result, 42);
            sleep(Duration::from_millis(10)).await;
        });
    }

    #[cfg(feature = "runtime-tokio")]
    #[tokio::test]
    async fn test_spawn_blocking() {
        let handle = spawn_blocking(|| {
            std::thread::sleep(Duration::from_millis(10));
            42
        });
        let result = handle.await.unwrap();
        assert_eq!(result, 42);
    }

    #[cfg(all(feature = "runtime-async-std", not(feature = "runtime-tokio")))]
    #[async_std::test]
    async fn test_spawn_blocking_async_std() {
        let handle = spawn_blocking(|| 6 * 7);
        assert_eq!(handle.await.unwrap(), 42);
    }

    #[cfg(all(
        feature = "runtime-smol",
        not(feature = "runtime-tokio"),
        not(feature = "runtime-async-std")
    ))]
    #[test]
    fn test_spawn_blocking_smol() {
        smol::block_on(async {
            let handle = spawn_blocking(|| 6 * 7);
            assert_eq!(handle.await.unwrap(), 42);
        });
    }
}