#![allow(clippy::tabs_in_doc_comments)]
use std::{
future::Future,
pin::Pin,
task::{Context, Poll}
};
use futures_lite::future::FutureExt;
use once_cell::sync::OnceCell;
pub use tokio::{
runtime::{Handle as TokioHandle, Runtime as TokioRuntime},
sync::{
mpsc::{channel, Receiver, Sender},
Mutex, RwLock
},
task::JoinHandle as TokioJoinHandle
};
/// Process-wide runtime used by the free functions in this module.
///
/// Filled either explicitly by `set` (with an externally owned tokio handle) or lazily,
/// on first use of `handle`/`block_on`/`spawn`/`spawn_blocking`, with `default_runtime`.
static RUNTIME: OnceCell<GlobalRuntime> = OnceCell::new();
/// State stored in `RUNTIME`.
struct GlobalRuntime {
/// The runtime created and owned by this module; `None` when only an external handle
/// was installed through `set`.
runtime: Option<Runtime>,
/// Handle to the active runtime; always present, and the only way to reach the runtime
/// when `runtime` is `None`.
handle: RuntimeHandle
}
impl GlobalRuntime {
    /// Returns a handle to the active runtime.
    ///
    /// Uses the owned runtime when this module created one, otherwise the stored handle.
    fn handle(&self) -> RuntimeHandle {
        match &self.runtime {
            Some(owned) => owned.handle(),
            None => self.handle.clone()
        }
    }

    /// Spawns `task` on the owned runtime if present, otherwise through the stored handle.
    fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static
    {
        match &self.runtime {
            Some(owned) => owned.spawn(task),
            None => self.handle.spawn(task)
        }
    }

    /// Runs `func` on the blocking pool of the owned runtime if present, otherwise
    /// through the stored handle.
    pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static
    {
        match &self.runtime {
            Some(owned) => owned.spawn_blocking(func),
            None => self.handle.spawn_blocking(func)
        }
    }

    /// Drives `task` to completion, blocking the current thread.
    ///
    /// Uses the owned runtime's `block_on` when available, otherwise the handle's.
    fn block_on<F: Future>(&self, task: F) -> F::Output {
        match &self.runtime {
            Some(owned) => owned.block_on(task),
            None => self.handle.block_on(task)
        }
    }
}
/// An async runtime owned by this value.
///
/// Derives `Debug` so it can be printed like the sibling `JoinHandle` type.
#[derive(Debug)]
pub enum Runtime {
    /// A tokio runtime.
    Tokio(TokioRuntime)
}
impl Runtime {
    /// Borrows the underlying tokio runtime.
    pub fn inner(&self) -> &TokioRuntime {
        match self {
            Self::Tokio(rt) => rt
        }
    }

    /// Returns a cloneable handle to this runtime.
    pub fn handle(&self) -> RuntimeHandle {
        let tokio_handle = self.inner().handle().clone();
        RuntimeHandle::Tokio(tokio_handle)
    }

    /// Spawns `task` onto this runtime and returns a handle to await its output.
    pub fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static
    {
        JoinHandle::Tokio(self.inner().spawn(task))
    }

    /// Runs the synchronous closure `func` on this runtime's blocking thread pool.
    pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static
    {
        JoinHandle::Tokio(self.inner().spawn_blocking(func))
    }

    /// Drives `task` to completion on this runtime, blocking the current thread.
    pub fn block_on<F: Future>(&self, task: F) -> F::Output {
        self.inner().block_on(task)
    }
}
/// Handle to a task spawned through this module.
///
/// Awaiting it yields the task's output, or a `crate::Error` if the underlying task
/// failed to join (for example after `abort`).
#[derive(Debug)]
pub enum JoinHandle<T> {
/// A tokio task handle.
Tokio(TokioJoinHandle<T>)
}
impl<T> JoinHandle<T> {
    /// Borrows the underlying tokio join handle.
    pub fn inner(&self) -> &TokioJoinHandle<T> {
        match self {
            Self::Tokio(handle) => handle
        }
    }

    /// Requests cancellation of the task.
    ///
    /// Awaiting the handle afterwards yields an error instead of the task's output.
    pub fn abort(&self) {
        self.inner().abort();
    }
}
impl<T> Future for JoinHandle<T> {
    type Output = crate::Result<T>;

    /// Polls the wrapped tokio task, converting its join error into the crate error type.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let Self::Tokio(task) = self.get_mut();
        // `poll` on `&mut` comes from `futures_lite::future::FutureExt`, which works here
        // because the tokio join handle is `Unpin`.
        task.poll(cx).map_err(Into::into)
    }
}
/// A cloneable handle to an async runtime.
///
/// Derives `Debug`, matching `JoinHandle`, so it can be printed or embedded in
/// `Debug` types.
#[derive(Clone, Debug)]
pub enum RuntimeHandle {
    /// A tokio runtime handle.
    Tokio(TokioHandle)
}
impl RuntimeHandle {
    /// Borrows the underlying tokio handle.
    pub fn inner(&self) -> &TokioHandle {
        match self {
            Self::Tokio(tokio_handle) => tokio_handle
        }
    }

    /// Runs the synchronous closure `func` on the runtime's blocking thread pool.
    pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static
    {
        JoinHandle::Tokio(self.inner().spawn_blocking(func))
    }

    /// Spawns `task` onto the runtime and returns a handle to await its output.
    pub fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static
    {
        JoinHandle::Tokio(self.inner().spawn(task))
    }

    /// Drives `task` to completion through this handle, blocking the current thread.
    pub fn block_on<F: Future>(&self, task: F) -> F::Output {
        self.inner().block_on(task)
    }
}
/// Creates the runtime installed lazily when no external handle was provided via `set`.
///
/// The returned state owns the new runtime and keeps a handle to it.
///
/// # Panics
///
/// Panics if the tokio runtime cannot be created.
fn default_runtime() -> GlobalRuntime {
    // Creation is fallible (it allocates OS resources). There is no caller to report to
    // from inside `get_or_init`, so fail loudly with a descriptive message.
    let tokio_runtime = TokioRuntime::new().expect("failed to create the default tokio runtime");
    let runtime = Runtime::Tokio(tokio_runtime);
    let handle = runtime.handle();
    GlobalRuntime { runtime: Some(runtime), handle }
}
/// Installs an externally owned tokio runtime, by handle, as the global runtime.
///
/// # Panics
///
/// Panics if the global runtime was already initialized, either by an earlier `set`
/// or lazily by any of `handle`, `block_on`, `spawn` or `spawn_blocking`.
pub fn set(handle: TokioHandle) {
    let external = GlobalRuntime {
        handle: RuntimeHandle::Tokio(handle),
        runtime: None
    };
    if RUNTIME.set(external).is_err() {
        panic!("runtime already initialized");
    }
}
/// Returns a handle to the global runtime, creating the default one on first use.
pub fn handle() -> RuntimeHandle {
    RUNTIME.get_or_init(default_runtime).handle()
}
/// Drives `task` to completion on the global runtime, blocking the current thread.
///
/// Creates the default runtime on first use.
pub fn block_on<F: Future>(task: F) -> F::Output {
    RUNTIME.get_or_init(default_runtime).block_on(task)
}
/// Spawns `task` onto the global runtime and returns a handle to await its output.
///
/// Creates the default runtime on first use.
pub fn spawn<F>(task: F) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static
{
    RUNTIME.get_or_init(default_runtime).spawn(task)
}
/// Runs the synchronous closure `func` on the global runtime's blocking thread pool.
///
/// Creates the default runtime on first use.
pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static
{
    RUNTIME.get_or_init(default_runtime).spawn_blocking(func)
}
/// Runs `task` to completion from synchronous code, whether or not the calling thread is
/// already inside a tokio runtime context.
///
/// When a current runtime is detected (`Handle::try_current` succeeds), `task` is handed
/// to that runtime's blocking pool and driven there with `Handle::block_on`. The calling
/// thread waits on a channel for the result. Otherwise this falls back to the
/// module-level `block_on`.
///
/// NOTE(review): this presumably exists because tokio's `block_on` must not be called
/// directly from within a runtime context — confirm against the tokio docs. The calling
/// thread is still blocked while waiting.
///
/// # Panics
///
/// If `task` panics, the panic is re-raised on the calling thread with its original
/// payload. Also panics if the blocking job ends without producing a result, for example
/// if it never ran.
// Suppresses the unused warning when no caller in the crate currently uses this helper.
#[allow(dead_code)]
pub(crate) fn safe_block_on<F>(task: F) -> F::Output
where
    F: Future + Send + 'static,
    F::Output: Send + 'static
{
    if let Ok(handle) = tokio::runtime::Handle::try_current() {
        // Exactly one message is ever sent, so capacity 1 never blocks the sender.
        let (tx, rx) = std::sync::mpsc::sync_channel(1);
        let handle_ = handle.clone();
        handle.spawn_blocking(move || {
            // Catch a panic from the task so its payload can be forwarded to the waiting
            // thread. Without this, the sender is dropped and the caller only sees an
            // opaque `RecvError`. `AssertUnwindSafe` is fine: the future is consumed and
            // never touched again after a panic.
            let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                handle_.block_on(task)
            }));
            // The receiver is only gone if the waiting thread has already unwound;
            // nobody is left to observe the result in that case.
            let _ = tx.send(outcome);
        });
        match rx.recv() {
            Ok(Ok(output)) => output,
            Ok(Err(payload)) => std::panic::resume_unwind(payload),
            Err(_) => panic!("safe_block_on: blocking task finished without producing a result")
        }
    } else {
        block_on(task)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Free-function spawn goes through the global runtime.
    #[tokio::test]
    async fn runtime_spawn() {
        let output = spawn(async { 5 }).await;
        assert_eq!(output.unwrap(), 5);
    }

    #[test]
    fn runtime_block_on() {
        let output = block_on(async { 0 });
        assert_eq!(output, 0);
    }

    // Spawning through an explicit handle behaves like the free function.
    #[tokio::test]
    async fn handle_spawn() {
        let output = handle().spawn(async { 5 }).await;
        assert_eq!(output.unwrap(), 5);
    }

    #[test]
    fn handle_block_on() {
        let output = handle().block_on(async { 0 });
        assert_eq!(output, 0);
    }

    // Aborting a pending task must surface as a cancelled `JoinError`.
    #[tokio::test]
    async fn handle_abort() {
        let task = handle().spawn(async {
            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
            5
        });
        task.abort();
        match task.await.unwrap_err() {
            crate::Error::JoinError(raw_error) => assert!(raw_error.is_cancelled()),
            _ => panic!("Abort did not result in the expected `JoinError`")
        }
    }
}