use std::{cell::RefCell, future::Future};
use tokio::runtime::{Builder, Handle, Runtime};
/// Panic message emitted when a runtime helper runs on a thread that has not
/// registered a runtime handle via [`init_runtime`].
const RUNTIME_NOT_INITIALIZED: &str = concat!(
    "rustrails_support::runtime::init_runtime() must be called on this thread ",
    "before using runtime helpers"
);
// Per-thread registry of the runtime handle written by `init_runtime` and read
// by `with_handle` / `is_initialized`.
thread_local! {
/// Handle to the runtime registered by `init_runtime` on this thread.
///
/// Starts as `None` on every thread. Because this is a thread-local, a handle
/// registered on one thread is not visible from any other thread.
static RT_HANDLE: RefCell<Option<Handle>> = const { RefCell::new(None) };
}
/// Runs `f` with the runtime [`Handle`] registered on the current thread.
///
/// The handle is cloned out of the thread-local and the `RefCell` borrow is
/// released *before* `f` runs. `f` can run for a long time; for [`block_on`] it
/// drives an entire future on this thread. Holding the borrow across it would
/// make any re-entrant [`init_runtime`] call (which needs `borrow_mut`) from
/// inside that future panic with a `BorrowMutError`. Cloning a `Handle` is a
/// cheap reference-count bump.
///
/// # Panics
///
/// Panics with [`RUNTIME_NOT_INITIALIZED`] if [`init_runtime`] has not been
/// called on the current thread.
fn with_handle<R>(f: impl FnOnce(&Handle) -> R) -> R {
    let handle = RT_HANDLE
        .with(|cell| cell.borrow().clone())
        .expect(RUNTIME_NOT_INITIALIZED);
    f(&handle)
}
/// Builds a multi-threaded Tokio runtime with all drivers enabled and
/// registers its handle for the current thread.
///
/// The registered handle is what [`block_on`] and [`spawn`] use on this
/// thread. The caller owns the returned [`Runtime`] and should keep it alive
/// for as long as those helpers are used. Dropping it shuts the runtime down,
/// but nothing in this module clears the registered handle, so
/// [`is_initialized`] keeps reporting `true` afterwards.
///
/// Calling this again on the same thread overwrites the previously registered
/// handle with the new runtime's handle.
///
/// # Panics
///
/// Panics if Tokio fails to build the runtime.
pub fn init_runtime() -> Runtime {
    let runtime = Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap_or_else(|error| panic!("failed to build Tokio runtime: {error}"));

    RT_HANDLE.with(|slot| *slot.borrow_mut() = Some(runtime.handle().clone()));

    runtime
}
/// Runs `future` to completion on this thread's registered runtime and returns
/// its output.
///
/// This works from plain synchronous code as well as from inside an async
/// context:
///
/// * Outside any Tokio runtime context, the future is driven directly with
///   [`Handle::block_on`].
/// * Inside a runtime context, whether the registered runtime or a different
///   one, a bare `Handle::block_on` would panic with "Cannot start a runtime
///   from within a runtime". The call is therefore wrapped in
///   [`tokio::task::block_in_place`], which Tokio documents as the supported
///   way to re-enter `block_on` from async code. Previously only the
///   same-runtime case was wrapped, so calls from inside a different runtime
///   always panicked.
///
/// # Panics
///
/// * If [`init_runtime`] has not been called on the current thread.
/// * If called from inside a `current_thread` runtime, where
///   `block_in_place` is not permitted.
pub fn block_on<F: Future>(future: F) -> F::Output {
    with_handle(|handle| {
        if Handle::try_current().is_ok() {
            // Some runtime context is active on this thread; step out of it
            // before blocking on our runtime.
            tokio::task::block_in_place(|| handle.block_on(future))
        } else {
            handle.block_on(future)
        }
    })
}
pub fn spawn<F>(future: F) -> tokio::task::JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
with_handle(|handle| handle.spawn(future))
}
/// Reports whether a runtime handle has been registered on the *current*
/// thread by [`init_runtime`].
///
/// Only the calling thread's slot is inspected; registration on another thread
/// is not visible here. Nothing in this module clears the slot, so this keeps
/// returning `true` even after the `Runtime` returned by `init_runtime` has
/// been dropped.
pub fn is_initialized() -> bool {
    RT_HANDLE.with(|slot| matches!(*slot.borrow(), Some(_)))
}
#[cfg(test)]
mod tests {
    use std::{
        any::Any,
        sync::mpsc,
        thread,
        time::{Duration, Instant},
    };

    use super::{block_on, init_runtime, is_initialized, spawn};

    /// Runs `test` on a freshly spawned OS thread and returns its result.
    ///
    /// The runtime handle lives in a thread-local, so giving every test body
    /// its own thread guarantees it starts with no runtime registered,
    /// regardless of which harness thread runs the test. A panic inside `test`
    /// is re-raised on the calling thread so the harness still reports it.
    fn run_isolated<R>(test: impl FnOnce() -> R + Send + 'static) -> R
    where
        R: Send + 'static,
    {
        match thread::spawn(test).join() {
            Ok(result) => result,
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    /// Extracts the message text from a caught panic payload.
    ///
    /// Handles both `String` and `&'static str` payloads; anything else is
    /// reported with a placeholder.
    fn panic_message(payload: Box<dyn Any + Send>) -> String {
        if let Some(message) = payload.downcast_ref::<String>() {
            message.clone()
        } else if let Some(message) = payload.downcast_ref::<&str>() {
            (*message).to_owned()
        } else {
            "non-string panic payload".to_owned()
        }
    }

    #[test]
    fn init_runtime_sets_initialized_to_true() {
        run_isolated(|| {
            assert!(!is_initialized());
            let _runtime = init_runtime();
            assert!(is_initialized());
        });
    }

    #[test]
    fn is_initialized_is_false_before_init() {
        run_isolated(|| {
            assert!(!is_initialized());
        });
    }

    #[test]
    fn initialization_is_per_thread() {
        run_isolated(|| {
            let _runtime = init_runtime();
            assert!(is_initialized());
            // A different OS thread has its own, still-empty slot.
            let seen_elsewhere = thread::spawn(is_initialized)
                .join()
                .expect("probe thread should not panic");
            assert!(!seen_elsewhere);
        });
    }

    #[test]
    fn block_on_executes_simple_future() {
        run_isolated(|| {
            let _runtime = init_runtime();
            assert_eq!(block_on(async { 42 }), 42);
        });
    }

    #[test]
    fn block_on_propagates_result_errors() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let result = block_on(async { Result::<(), &'static str>::Err("boom") });
            assert_eq!(result, Err("boom"));
        });
    }

    #[test]
    fn block_on_panics_with_clear_message_before_init() {
        let message = run_isolated(|| {
            let panic = std::panic::catch_unwind(|| {
                let _: i32 = block_on(async { 42 });
            })
            .expect_err("block_on should panic before init_runtime");
            panic_message(panic)
        });
        assert!(message.contains("init_runtime() must be called on this thread"));
    }

    #[test]
    fn spawn_panics_with_clear_message_before_init() {
        let message = run_isolated(|| {
            let panic = std::panic::catch_unwind(|| {
                let _join = spawn(async {});
            })
            .expect_err("spawn should panic before init_runtime");
            panic_message(panic)
        });
        assert!(message.contains("init_runtime() must be called on this thread"));
    }

    #[test]
    fn spawn_runs_task_to_completion() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let join = spawn(async { 7_i32 * 6 });
            let value = block_on(async { join.await.expect("task should complete") });
            assert_eq!(value, 42);
        });
    }

    #[test]
    fn multiple_sequential_block_on_calls_work() {
        run_isolated(|| {
            let _runtime = init_runtime();
            assert_eq!(block_on(async { 1 }), 1);
            assert_eq!(block_on(async { 2 }), 2);
            assert_eq!(block_on(async { 3 }), 3);
        });
    }

    #[test]
    fn block_on_supports_sleeping_futures() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let start = Instant::now();
            block_on(async {
                tokio::time::sleep(Duration::from_millis(10)).await;
            });
            // A timer-backed future must not resolve before its deadline.
            assert!(start.elapsed() >= Duration::from_millis(10));
        });
    }

    #[test]
    fn block_on_reenters_the_same_runtime_inside_async_context() {
        run_isolated(|| {
            let runtime = init_runtime();
            // Calling the helper from inside the runtime's own `block_on`
            // exercises the re-entrant path.
            let value = runtime.block_on(async { block_on(async { 21 * 2 }) });
            assert_eq!(value, 42);
        });
    }

    #[test]
    fn spawn_can_signal_back_to_sync_code() {
        run_isolated(|| {
            let _runtime = init_runtime();
            let (sender, receiver) = mpsc::channel();
            let join = spawn(async move {
                sender.send("done").expect("channel send should succeed");
            });
            block_on(async { join.await.expect("task should complete") });
            assert_eq!(
                receiver.recv().expect("channel receive should succeed"),
                "done"
            );
        });
    }
}