use std::future::Future;
use tokio::runtime::{Handle, RuntimeFlavor};
/// Synchronously drives `future` to completion on the ambient tokio runtime.
///
/// This is the blocking bridge behind `TerminateOnDrop`: code that cannot `.await`
/// (such as a `Drop` impl) hands it the termination future. The future is run via
/// [`tokio::task::block_in_place`] wrapping [`Handle::block_on`].
///
/// If the current thread is already unwinding from a panic, the future is **not** run
/// and the function returns immediately. A second panic during unwinding would abort
/// the process.
///
/// # Panics
///
/// * When no tokio runtime is reachable from the current thread
///   ([`Handle::try_current`] fails).
/// * When the ambient runtime uses the `current_thread` flavor, which cannot support
///   `block_in_place`.
pub(crate) fn run_future(future: impl Future<Output = ()>) {
    // Never risk a panic-while-panicking (process abort) from drop-time cleanup.
    if std::thread::panicking() {
        return;
    }

    let handle = match Handle::try_current() {
        Ok(handle) => handle,
        Err(_) => panic!(
            "TerminateOnDrop requires a tokio runtime to be active when the handle drops. \
             No runtime was found in the current thread context. Drop the handle from inside a \
             #[tokio::main] or #[tokio::test(flavor = \"multi_thread\")] context, or use \
             `must_not_be_terminated()` to opt out of automatic termination."
        ),
    };

    // Check the flavor up front so the failure names TerminateOnDrop instead of
    // surfacing as tokio's generic `block_in_place` panic.
    if let RuntimeFlavor::CurrentThread = handle.runtime_flavor() {
        panic!(
            "TerminateOnDrop requires a multi-threaded tokio runtime to function correctly. \
             The current runtime is single-threaded, which cannot drive `block_in_place`. \
             Switch to `#[tokio::test(flavor = \"multi_thread\")]` or build the runtime with \
             `tokio::runtime::Builder::new_multi_thread()`."
        );
    }

    tokio::task::block_in_place(move || handle.block_on(future))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::any::Any;
    use std::panic::UnwindSafe;
    use std::sync::{Mutex, PoisonError};

    /// Serializes panic-hook swaps. The hook is process-global and the test harness runs
    /// tests on parallel threads, so two interleaved `take_hook`/`set_hook` pairs could
    /// otherwise leave the silencing hook installed for the rest of the run.
    static PANIC_HOOK_LOCK: Mutex<()> = Mutex::new(());

    /// Runs `f`, catching any panic it raises, with panic output suppressed so that
    /// expected panics do not clutter the test log.
    ///
    /// While `f` runs, panic output from other test threads is suppressed as well.
    fn run_silently<R>(f: impl FnOnce() -> R + UnwindSafe) -> std::thread::Result<R> {
        // Poisoning only means an earlier holder panicked; the guarded `()` cannot be
        // left in an inconsistent state, so carry on.
        let _guard = PANIC_HOOK_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
        let prev_hook = std::panic::take_hook();
        std::panic::set_hook(Box::new(|_info| {}));
        let result = std::panic::catch_unwind(f);
        std::panic::set_hook(prev_hook);
        result
    }

    /// Returns the text of a panic payload, or `""` if it is neither `&'static str` nor
    /// `String` (the two payload types `panic!` produces).
    ///
    /// Call it as `panic_message(&*payload)`, not `&payload`. A `&Box<dyn Any + Send>`
    /// would itself coerce to `&dyn Any`, and every downcast would then miss.
    fn panic_message(payload: &(dyn Any + Send)) -> &str {
        if let Some(message) = payload.downcast_ref::<&'static str>() {
            *message
        } else if let Some(message) = payload.downcast_ref::<String>() {
            message.as_str()
        } else {
            ""
        }
    }

    #[test]
    fn missing_runtime_panic_names_terminate_on_drop() {
        let payload = run_silently(|| run_future(async {})).expect_err("should panic");
        let message = panic_message(&*payload);
        assert!(
            message.contains("TerminateOnDrop"),
            "panic message should name TerminateOnDrop, got: {message}"
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn single_threaded_runtime_panic_names_terminate_on_drop() {
        let payload = run_silently(|| run_future(async {}))
            .expect_err("should panic on single-threaded runtime");
        let message = panic_message(&*payload);
        assert!(
            message.contains("TerminateOnDrop"),
            "panic message should name TerminateOnDrop, got: {message}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn multi_threaded_runtime_runs_future_to_completion() {
        let mut completed = false;
        run_future(async {
            completed = true;
        });
        assert!(completed, "run_future should drive the future to completion");
    }

    #[test]
    fn does_nothing_while_thread_is_already_panicking() {
        /// Calls `run_future` from `drop`, mimicking a handle dropped during unwinding.
        struct RunFutureOnDrop;

        impl Drop for RunFutureOnDrop {
            fn drop(&mut self) {
                run_future(async {});
            }
        }

        // No runtime is active on this thread. If `run_future` did not bail out early,
        // it would panic inside `drop` during unwinding and abort the test process.
        let result: std::thread::Result<()> = run_silently(|| {
            let _armed = RunFutureOnDrop;
            panic!("original panic");
        });
        let payload = result.expect_err("closure always panics");
        assert_eq!(panic_message(&*payload), "original panic");
    }
}