use std::{
any::Any,
future::Future,
panic::{AssertUnwindSafe, catch_unwind, resume_unwind},
sync::OnceLock,
};
use tokio::runtime::{Builder, Handle, Runtime};
/// Lazily-built, process-wide runtime used whenever the caller is not already
/// inside a Tokio runtime context (see `global_runtime`, `handle`, `block_on`).
static RUNTIME: OnceLock<Runtime> = OnceLock::new();
/// Panic message raised by `block_on` when Tokio refuses to block the calling
/// thread; it points the caller at `spawn_blocking` as the fix.
const WORKER_THREAD_PANIC_MESSAGE: &str = "rjango::runtime::block_on() cannot be called from a Tokio worker thread; wrap synchronous facade code in tokio::task::spawn_blocking before awaiting async Rjango work";
/// Substrings of Tokio panic messages that signal an attempt to block a thread
/// that is already inside a runtime; matched with `str::contains` by
/// `is_tokio_reentrancy_panic`.
// NOTE(review): these mirror Tokio's internal panic wording, which is not a
// stable API — re-check them whenever the Tokio dependency is upgraded.
const TOKIO_REENTRANCY_MESSAGES: [&str; 3] = [
"Cannot start a runtime from within a runtime",
"Cannot block the current thread from within a runtime",
"thread is being used to drive asynchronous tasks",
];
/// Constructs the process-wide multi-threaded Tokio runtime with all drivers
/// enabled (`Builder::enable_all`).
///
/// # Panics
///
/// Panics if Tokio fails to build the runtime.
fn build_runtime() -> Runtime {
    let mut builder = Builder::new_multi_thread();
    builder.enable_all();
    builder
        .build()
        .expect("Rjango global Tokio runtime should build")
}
fn global_runtime() -> &'static Runtime {
RUNTIME.get_or_init(build_runtime)
}
/// Extracts the text carried by a panic payload, if any.
///
/// Panic payloads are typically either a `&'static str` or a `String`; both
/// are recognised. Any other payload type yields `None`.
fn panic_message(payload: &(dyn Any + Send)) -> Option<&str> {
    if let Some(&literal) = payload.downcast_ref::<&'static str>() {
        return Some(literal);
    }
    if let Some(owned) = payload.downcast_ref::<String>() {
        return Some(owned.as_str());
    }
    None
}
/// Reports whether a panic payload's text contains any of the Tokio
/// "blocking inside a runtime" phrases listed in `TOKIO_REENTRANCY_MESSAGES`.
///
/// Payloads without a textual message never match.
fn is_tokio_reentrancy_panic(payload: &(dyn Any + Send)) -> bool {
    let Some(message) = panic_message(payload) else {
        return false;
    };
    TOKIO_REENTRANCY_MESSAGES
        .iter()
        .copied()
        .any(|needle| message.contains(needle))
}
/// Eagerly builds the global Rjango runtime.
///
/// Calling this is optional: `handle` and `block_on` build the runtime lazily
/// when they need it. Calling it up front moves the one-time construction cost
/// (and any build panic) to a predictable point. Subsequent calls reuse the
/// already-built runtime.
pub fn init() {
    // Only the side effect of initialising the runtime matters here.
    global_runtime();
}
/// Returns a handle to the Tokio runtime Rjango should use from this thread.
///
/// When the calling thread is inside a Tokio runtime context, that runtime's
/// handle is returned. Otherwise the handle of the global Rjango runtime is
/// returned, building that runtime on first use.
#[must_use]
pub fn handle() -> Handle {
    match Handle::try_current() {
        Ok(ambient) => ambient,
        Err(_) => global_runtime().handle().clone(),
    }
}
/// Runs `fut` to completion on the calling thread and returns its output.
///
/// If the calling thread is inside a Tokio runtime context (for example a
/// `tokio::task::spawn_blocking` closure), that runtime's [`Handle`] drives
/// the future. Otherwise the global Rjango runtime drives it.
///
/// # Panics
///
/// - Panics with a descriptive message if Tokio refuses to block the calling
///   thread because it is already driving asynchronous tasks, e.g. when called
///   directly from inside an `async` task.
/// - Any panic raised by `fut` itself is propagated unchanged.
pub fn block_on<F: Future>(fut: F) -> F::Output {
    let Ok(handle) = Handle::try_current() else {
        return global_runtime().block_on(fut);
    };

    // Tokio's reentrancy check fires while entering the runtime, before the
    // future is first polled. Once `fut` has been polled, any panic came from
    // the user's own code, e.g. a nested `blocking_recv`, even if its text
    // mentions a runtime. Such panics must propagate as-is rather than be
    // rewritten into the misleading "called from a worker thread" message.
    let polled = std::cell::Cell::new(false);
    let tracked = async {
        polled.set(true);
        fut.await
    };

    match catch_unwind(AssertUnwindSafe(|| handle.block_on(tracked))) {
        Ok(output) => output,
        Err(payload) if !polled.get() && is_tokio_reentrancy_panic(payload.as_ref()) => {
            panic!("{WORKER_THREAD_PANIC_MESSAGE}")
        }
        Err(payload) => resume_unwind(payload),
    }
}
/// Spawns `fut` as a task on the runtime selected by [`handle`]. That is the
/// ambient Tokio runtime when there is one, otherwise the global Rjango
/// runtime.
///
/// The returned [`tokio::task::JoinHandle`] can be awaited (or passed to
/// [`block_on`]) to retrieve the task's output.
pub fn spawn<F>(fut: F) -> tokio::task::JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let runtime = handle();
    runtime.spawn(fut)
}
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::time::Duration;

    use tokio::sync::oneshot;
    use tokio::time::sleep;

    use super::{block_on, handle, init, is_tokio_reentrancy_panic, panic_message, spawn};

    fn assert_send_sync<T: Send + Sync>() {}

    /// Boxes a value the same way `catch_unwind` hands back a panic payload.
    fn payload<T: Any + Send>(value: T) -> Box<dyn Any + Send> {
        Box::new(value)
    }

    #[test]
    fn init_is_idempotent() {
        init();
        init();
    }

    #[test]
    fn block_on_runs_future_from_sync() {
        let result = block_on(async { 21 + 21 });
        assert_eq!(result, 42);
    }

    #[test]
    fn block_on_runs_async_sleep() {
        let result = block_on(async {
            sleep(Duration::from_millis(10)).await;
            42
        });
        assert_eq!(result, 42);
    }

    #[test]
    fn block_on_nested_from_spawn_blocking() {
        let result = block_on(async {
            tokio::task::spawn_blocking(|| block_on(async { 21 + 21 }))
                .await
                .expect("blocking task should join")
        });
        assert_eq!(result, 42);
    }

    // A panic raised by the user's future must surface with its own message,
    // not be replaced by the worker-thread diagnostic.
    #[test]
    #[should_panic(expected = "user future failure")]
    fn block_on_propagates_future_panic() {
        let missing: Option<i32> = None;
        let _ = block_on(async move { missing.expect("user future failure") });
    }

    #[test]
    fn spawn_runs_task() {
        let join = spawn(async { 21 + 21 });
        let result = block_on(join).expect("spawned task should complete");
        assert_eq!(result, 42);
    }

    #[test]
    fn spawn_multiple_tasks() {
        let joins: Vec<_> = (0..4)
            .map(|value| spawn(async move { value * value }))
            .collect();
        let results = block_on(async {
            let mut results = Vec::with_capacity(joins.len());
            for join in joins {
                results.push(join.await.expect("spawned task should complete"));
            }
            results
        });
        assert_eq!(results, vec![0, 1, 4, 9]);
    }

    #[test]
    fn handle_returns_valid_handle() {
        let runtime_handle = handle();
        let join = runtime_handle.spawn(async {
            sleep(Duration::from_millis(5)).await;
            42
        });
        let result = block_on(join).expect("spawned task should complete");
        assert_eq!(result, 42);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn handle_from_tokio_context() {
        let runtime_handle = handle();
        let join = runtime_handle.spawn(async {
            sleep(Duration::from_millis(5)).await;
            42
        });
        assert_eq!(join.await.expect("spawned task should complete"), 42);
    }

    #[test]
    fn block_on_with_channel() {
        let (tx, rx) = oneshot::channel();
        let _join = spawn(async move {
            tx.send(42).expect("receiver should still be waiting");
        });
        let result = block_on(async { rx.await.expect("sender should send a value") });
        assert_eq!(result, 42);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[should_panic(
        expected = "rjango::runtime::block_on() cannot be called from a Tokio worker thread"
    )]
    async fn block_on_panics_descriptive_on_worker_thread() {
        let _ = block_on(async { 42 });
    }

    #[test]
    fn runtime_is_send_sync() {
        assert_send_sync::<tokio::runtime::Handle>();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn spawn_blocking_then_block_on() {
        let result = tokio::task::spawn_blocking(|| {
            let join = spawn(async {
                sleep(Duration::from_millis(5)).await;
                42
            });
            block_on(async { join.await.expect("spawned task should complete") })
        })
        .await
        .expect("blocking task should join");
        assert_eq!(result, 42);
    }

    #[test]
    fn panic_message_reads_str_and_string_payloads() {
        let literal = payload("static message");
        assert_eq!(panic_message(literal.as_ref()), Some("static message"));

        let owned = payload(String::from("owned message"));
        assert_eq!(panic_message(owned.as_ref()), Some("owned message"));

        let opaque = payload(7_u32);
        assert_eq!(panic_message(opaque.as_ref()), None);
    }

    #[test]
    fn reentrancy_detection_matches_tokio_messages_only() {
        let tokio_message = payload(String::from(
            "Cannot start a runtime from within a runtime. This happens because a function \
             attempted to block the current thread while the thread is being used to drive \
             asynchronous tasks.",
        ));
        assert!(is_tokio_reentrancy_panic(tokio_message.as_ref()));

        let blocking_message = payload("Cannot block the current thread from within a runtime.");
        assert!(is_tokio_reentrancy_panic(blocking_message.as_ref()));

        let unrelated = payload("database connection refused");
        assert!(!is_tokio_reentrancy_panic(unrelated.as_ref()));

        let opaque = payload(42_i32);
        assert!(!is_tokio_reentrancy_panic(opaque.as_ref()));
    }
}