use once_cell::sync::OnceCell;
pub use tokio::{
runtime::{Handle as TokioHandle, Runtime as TokioRuntime},
sync::{
mpsc::{channel, Receiver, Sender},
Mutex, RwLock,
},
task::JoinHandle as TokioJoinHandle,
};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
/// Process-wide runtime storage.
///
/// Filled either lazily with [`default_runtime`] by the module-level accessors, or once,
/// explicitly, through [`set`].
static RUNTIME: OnceCell<GlobalRuntime> = OnceCell::new();

/// The value stored in [`RUNTIME`].
///
/// `runtime` is `Some` when this module created the Tokio runtime itself ([`default_runtime`]),
/// and `None` when only a handle to an externally owned runtime was registered ([`set`]).
/// `handle` is populated in both cases.
struct GlobalRuntime {
    runtime: Option<Runtime>,
    handle: RuntimeHandle,
}

impl GlobalRuntime {
    /// Returns a handle, derived from the owned runtime when this module owns one.
    fn handle(&self) -> RuntimeHandle {
        match self.runtime.as_ref() {
            Some(owned) => owned.handle(),
            None => self.handle.clone(),
        }
    }

    /// Spawns `task`, preferring the owned runtime and falling back to the stored handle.
    fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        match self.runtime.as_ref() {
            Some(owned) => owned.spawn(task),
            None => self.handle.spawn(task),
        }
    }

    /// Runs the blocking closure `func`, preferring the owned runtime and falling back to the
    /// stored handle.
    pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        match self.runtime.as_ref() {
            Some(owned) => owned.spawn_blocking(func),
            None => self.handle.spawn_blocking(func),
        }
    }

    /// Drives `task` to completion on the calling thread.
    ///
    /// The owned runtime's `block_on` is used when available. Otherwise the stored handle's
    /// `block_on` is used.
    fn block_on<F: Future>(&self, task: F) -> F::Output {
        match self.runtime.as_ref() {
            Some(owned) => owned.block_on(task),
            None => self.handle.block_on(task),
        }
    }
}
/// A runtime owned by this module.
#[derive(Debug)]
pub enum Runtime {
    /// An owned Tokio runtime.
    Tokio(TokioRuntime),
}

impl Runtime {
    /// Returns a reference to the wrapped Tokio runtime.
    pub fn inner(&self) -> &TokioRuntime {
        let Self::Tokio(r) = self;
        r
    }

    /// Returns a new [`RuntimeHandle`] pointing at this runtime.
    pub fn handle(&self) -> RuntimeHandle {
        match self {
            Self::Tokio(r) => RuntimeHandle::Tokio(r.handle().clone()),
        }
    }

    /// Spawns `task` onto this runtime and returns a handle to await its output.
    pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        match self {
            // Spawn directly on this runtime rather than entering its context and going
            // through the ambient `tokio::spawn`; the target runtime is then explicit.
            Self::Tokio(r) => JoinHandle::Tokio(r.spawn(task)),
        }
    }

    /// Runs the blocking closure `func` through this runtime's `spawn_blocking`.
    pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        match self {
            Self::Tokio(r) => JoinHandle::Tokio(r.spawn_blocking(func)),
        }
    }

    /// Runs `task` to completion on this runtime, blocking the calling thread.
    pub fn block_on<F: Future>(&self, task: F) -> F::Output {
        match self {
            Self::Tokio(r) => r.block_on(task),
        }
    }
}
/// Handle to a task spawned through this module.
///
/// Awaiting it yields `crate::Result<T>`, with the runtime's join error converted via `Into`.
#[derive(Debug)]
pub enum JoinHandle<T> {
    /// A Tokio task handle.
    Tokio(TokioJoinHandle<T>),
}

impl<T> JoinHandle<T> {
    /// Borrows the underlying Tokio join handle.
    pub fn inner(&self) -> &TokioJoinHandle<T> {
        match self {
            Self::Tokio(task) => task,
        }
    }

    /// Requests cancellation of the task.
    pub fn abort(&self) {
        let Self::Tokio(task) = self;
        task.abort();
    }
}

impl<T> Future for JoinHandle<T> {
    type Output = crate::Result<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Tokio's join handle is `Unpin`, so it can be re-pinned from a plain `&mut`.
        let Self::Tokio(task) = self.get_mut();
        match Pin::new(task).poll(cx) {
            Poll::Ready(outcome) => Poll::Ready(outcome.map_err(Into::into)),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// A cloneable handle to a runtime.
///
/// The runtime is either the one created by this module or one registered via [`set`].
#[derive(Clone, Debug)]
pub enum RuntimeHandle {
    /// A Tokio runtime handle.
    Tokio(TokioHandle),
}

impl RuntimeHandle {
    /// Returns a reference to the wrapped Tokio handle.
    pub fn inner(&self) -> &TokioHandle {
        let Self::Tokio(h) = self;
        h
    }

    /// Runs the blocking closure `func` through the handle's `spawn_blocking`.
    pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        match self {
            Self::Tokio(h) => JoinHandle::Tokio(h.spawn_blocking(func)),
        }
    }

    /// Spawns `task` onto the runtime behind this handle.
    pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        match self {
            // `Handle::spawn` targets this handle's runtime directly; entering the context and
            // calling the ambient `tokio::spawn` is unnecessary.
            Self::Tokio(h) => JoinHandle::Tokio(h.spawn(task)),
        }
    }

    /// Runs `task` to completion through this handle, blocking the calling thread.
    ///
    /// Tokio's `Handle::block_on` restrictions apply; for example, it must not be called from
    /// within an asynchronous execution context.
    pub fn block_on<F: Future>(&self, task: F) -> F::Output {
        match self {
            Self::Tokio(h) => h.block_on(task),
        }
    }
}
/// Builds the global runtime used when none was registered through [`set`].
///
/// A fresh Tokio runtime is created with `TokioRuntime::new()` and stored as owned, together
/// with a handle to it.
///
/// # Panics
///
/// Panics if the Tokio runtime cannot be created. The panic message includes the underlying
/// error.
fn default_runtime() -> GlobalRuntime {
    let tokio = TokioRuntime::new().expect("failed to create the default Tokio runtime");
    let runtime = Runtime::Tokio(tokio);
    let handle = runtime.handle();
    GlobalRuntime {
        runtime: Some(runtime),
        handle,
    }
}
/// Registers an externally owned Tokio runtime as the global runtime.
///
/// Only the handle is stored, so this module never owns or shuts down that runtime.
///
/// # Panics
///
/// Panics if the global runtime is already initialized. That happens after a previous call to
/// `set`, or after any accessor in this module has already created the default runtime.
pub fn set(handle: TokioHandle) {
    let global = GlobalRuntime {
        runtime: None,
        handle: RuntimeHandle::Tokio(handle),
    };
    if RUNTIME.set(global).is_err() {
        panic!("runtime already initialized");
    }
}
/// Returns a handle to the global runtime, creating the default runtime if none exists yet.
pub fn handle() -> RuntimeHandle {
    RUNTIME.get_or_init(default_runtime).handle()
}
/// Runs `task` to completion on the global runtime, blocking the calling thread.
///
/// The default runtime is created on first use if none was registered.
pub fn block_on<F: Future>(task: F) -> F::Output {
    RUNTIME.get_or_init(default_runtime).block_on(task)
}
/// Spawns `task` onto the global runtime and returns a handle to await its output.
///
/// The default runtime is created on first use if none was registered.
pub fn spawn<F>(task: F) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    RUNTIME.get_or_init(default_runtime).spawn(task)
}
/// Runs the blocking closure `func` through the global runtime's `spawn_blocking`.
///
/// The default runtime is created on first use if none was registered.
pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    RUNTIME.get_or_init(default_runtime).spawn_blocking(func)
}
/// Runs `task` to completion and returns its output, even when called from inside a Tokio
/// runtime context, where a plain `block_on` is not allowed.
///
/// - If a current Tokio handle exists (`Handle::try_current()` succeeds), `task` is moved into a
///   `spawn_blocking` closure. That closure drives it with `Handle::block_on` and sends the
///   output back over a capacity-1 `sync_channel`. The calling thread blocks on `rx.recv()`
///   until the value arrives.
/// - Otherwise this falls back to the module-level [`block_on`] on the global runtime.
///
/// `F: Send + 'static` and `F::Output: Send + 'static` are required because `task` and its
/// output cross into the `spawn_blocking` closure and back through the channel.
///
/// # Panics
///
/// If `task` panics inside the blocking closure, nothing is sent and `tx` is dropped. The
/// `rx.recv().unwrap()` then panics with a `RecvError`, so the original panic message is not
/// propagated to the caller.
///
/// NOTE(review): `rx.recv()` blocks the calling thread synchronously. When this is called from a
/// runtime worker thread, that worker stalls until `task` finishes. On a current-thread runtime,
/// Tokio documents that `Handle::block_on` cannot drive the IO/timer drivers. A `task` that
/// depends on them may therefore never complete while the only worker waits here. Confirm callers
/// only use this on a multi-thread runtime, or with tasks that need no IO or timers.
// NOTE(review): `dead_code` is allowed presumably because the only callers are behind feature
// flags or cfgs; confirm and document the reason at the attribute.
#[allow(dead_code)]
pub(crate) fn safe_block_on<F>(task: F) -> F::Output
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
if let Ok(handle) = tokio::runtime::Handle::try_current() {
let (tx, rx) = std::sync::mpsc::sync_channel(1);
// A second handle is moved into the blocking closure to drive `task` there.
let handle_ = handle.clone();
handle.spawn_blocking(move || {
tx.send(handle_.block_on(task)).unwrap();
});
// Block the caller until the blocking-pool closure sends the output.
rx.recv().unwrap()
} else {
// No ambient runtime: block on the global runtime instead.
block_on(task)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A future spawned through the module-level `spawn` resolves to its output.
    #[tokio::test]
    async fn runtime_spawn() {
        let output = spawn(async { 5 }).await.unwrap();
        assert_eq!(output, 5);
    }

    /// The module-level `block_on` returns the future's output.
    #[test]
    fn runtime_block_on() {
        let output = block_on(async { 0 });
        assert_eq!(output, 0);
    }

    /// Spawning through a `RuntimeHandle` resolves to the task's output.
    #[tokio::test]
    async fn handle_spawn() {
        let rt_handle = handle();
        let output = rt_handle.spawn(async { 5 }).await.unwrap();
        assert_eq!(output, 5);
    }

    /// `RuntimeHandle::block_on` returns the future's output.
    #[test]
    fn handle_block_on() {
        let output = handle().block_on(async { 0 });
        assert_eq!(output, 0);
    }

    /// Aborting a sleeping task makes awaiting it fail with a cancelled join error.
    #[tokio::test]
    async fn handle_abort() {
        let rt_handle = handle();
        let task = rt_handle.spawn(async {
            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
            5
        });
        task.abort();
        match task.await.unwrap_err() {
            crate::Error::JoinError(raw_error) => assert!(raw_error.is_cancelled()),
            _ => panic!("Abort did not result in the expected `JoinError`"),
        }
    }
}