use std::{
fmt::{self, Debug},
io,
sync::{
Arc, Condvar, Mutex,
atomic::{AtomicUsize, Ordering},
},
};
use linkme::distributed_slice;
use crate::{BarrierContext, CONTEXT};
/// Synchronization primitive used during runtime shutdown: runtime worker
/// threads block on `cvar` until every other tracked thread has exited.
#[derive(Default)]
struct ShutdownBarrier {
    // Count of live threads tracked by the runtime hooks; incremented in
    // `on_thread_start`, decremented in `on_thread_stop` (see `Builder::build`).
    guard_count: AtomicUsize,
    // Set to true by the last exiting thread to release waiters on `cvar`.
    shutdown_finalized: Mutex<bool>,
    // Notified once `shutdown_finalized` becomes true.
    cvar: Condvar,
}
/// Flavor of Tokio runtime being built; mirrors tokio's own runtime kinds.
#[derive(PartialEq, Eq)]
pub(crate) enum Kind {
    /// Single-threaded runtime driving tasks on the caller's thread.
    CurrentThread,
    /// Work-stealing multi-threaded runtime (requires the `rt-multi-thread` feature).
    #[cfg(feature = "rt-multi-thread")]
    MultiThread,
}
/// Wrapper around `tokio::runtime::Builder` that installs thread start/stop
/// hooks implementing the shutdown barrier.
#[doc(hidden)]
pub struct Builder {
    // Which runtime flavor `inner` was created for.
    kind: Kind,
    // Expected number of runtime worker threads (always 1 for current-thread).
    worker_threads: usize,
    // The underlying tokio builder being configured.
    inner: tokio::runtime::Builder,
}
impl Builder {
    /// Create a builder for a single-threaded (current-thread) runtime.
    pub fn new_current_thread() -> Builder {
        Builder {
            kind: Kind::CurrentThread,
            worker_threads: 1,
            inner: tokio::runtime::Builder::new_current_thread(),
        }
    }

    /// Create a builder for a multi-threaded runtime.
    ///
    /// The worker-thread count defaults to the `TOKIO_WORKER_THREADS`
    /// environment variable when it parses as a number, otherwise to the
    /// number of logical CPUs.
    #[cfg(feature = "rt-multi-thread")]
    pub fn new_multi_thread() -> Builder {
        // Bug fix: the variable was previously misspelled "TOKIO_WORKER_THEADS",
        // so the override tokio itself honors was silently ignored here.
        let worker_threads = std::env::var("TOKIO_WORKER_THREADS")
            .ok()
            .and_then(|worker_threads| worker_threads.parse().ok())
            .unwrap_or_else(num_cpus::get);
        Builder {
            kind: Kind::MultiThread,
            worker_threads,
            inner: tokio::runtime::Builder::new_multi_thread(),
        }
    }

    /// Enable both I/O and time drivers on the underlying builder.
    pub fn enable_all(&mut self) -> &mut Self {
        self.inner.enable_all();
        self
    }

    /// Set the number of worker threads; ignored for current-thread runtimes.
    ///
    /// # Panics
    ///
    /// Panics if `val` is 0.
    #[track_caller]
    pub fn worker_threads(&mut self, val: usize) -> &mut Self {
        assert!(val > 0, "Worker threads cannot be set to 0");
        if self.kind != Kind::CurrentThread {
            self.worker_threads = val;
            self.inner.worker_threads(val);
        }
        self
    }

    /// Build the runtime, installing thread hooks that classify each spawned
    /// thread as a runtime worker or a pool (blocking) worker, and that make
    /// runtime workers wait for all other threads during shutdown.
    pub fn build(&mut self) -> io::Result<Runtime> {
        let worker_threads = self.worker_threads;
        let barrier = Arc::new(ShutdownBarrier::default());
        let on_thread_start = {
            let barrier = barrier.clone();
            move || {
                // `fetch_add` returns the previous count, so the first
                // `worker_threads` threads are classified as runtime workers
                // and any later threads as pool workers.
                let thread_count = barrier.guard_count.fetch_add(1, Ordering::Release);
                CONTEXT.with(|context| {
                    if thread_count >= worker_threads {
                        *context.borrow_mut() = Some(BarrierContext::PoolWorker)
                    } else {
                        *context.borrow_mut() = Some(BarrierContext::RuntimeWorker)
                    }
                });
            }
        };
        let on_thread_stop = move || {
            // Previous count; 1 means this is the last tracked thread to exit.
            let thread_count = barrier.guard_count.fetch_sub(1, Ordering::AcqRel);
            CONTEXT.with(|context| {
                if thread_count == 1 {
                    // Last thread out: release every runtime worker blocked below.
                    *barrier.shutdown_finalized.lock().unwrap() = true;
                    barrier.cvar.notify_all();
                } else if *context.borrow() == Some(BarrierContext::RuntimeWorker) {
                    // Runtime workers park until all other threads have exited,
                    // guarding against spurious wakeups by re-checking the flag.
                    let mut shutdown_finalized = barrier.shutdown_finalized.lock().unwrap();
                    while !*shutdown_finalized {
                        shutdown_finalized = barrier.cvar.wait(shutdown_finalized).unwrap();
                    }
                }
            });
        };
        self
            .inner
            .on_thread_start(on_thread_start)
            .on_thread_stop(on_thread_stop)
            .build()
            .map(Runtime::new)
    }
}
/// Thin wrapper over `tokio::runtime::Runtime` whose shutdown is coordinated
/// by the barrier hooks installed in `Builder::build`.
#[doc(hidden)]
pub struct Runtime(tokio::runtime::Runtime);
impl Debug for Runtime {
    /// Delegate formatting to the wrapped tokio runtime.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Debug::fmt(&self.0, f)
    }
}
impl Runtime {
    /// Wrap an already-configured tokio runtime.
    fn new(inner: tokio::runtime::Runtime) -> Self {
        Runtime(inner)
    }

    /// Run `future` to completion on this runtime via [`Runtime::run`].
    ///
    /// # Safety
    ///
    /// Marked `unsafe` by design of this crate — presumably the caller must
    /// uphold the invariants required for thread-local access that this crate
    /// optimizes (TODO: confirm against the crate-level safety documentation).
    #[track_caller]
    pub unsafe fn block_on<F: Future>(self, future: F) -> F::Output {
        unsafe { self.run(|handle| handle.block_on(future)) }
    }

    /// Execute `f` with the runtime while the barrier context is set to
    /// `Owner`, then shut the runtime down before clearing the context.
    ///
    /// # Safety
    ///
    /// Same contract as [`Runtime::block_on`]; see the note there.
    pub unsafe fn run<F, Output>(self, f: F) -> Output
    where
        F: for<'a> FnOnce(&'a tokio::runtime::Runtime) -> Output,
    {
        // Mark this thread as the runtime owner for the duration of `f`.
        CONTEXT.with(|context| *context.borrow_mut() = Some(BarrierContext::Owner));
        let output = f(&self.0);
        // Drop the runtime (triggering the shutdown barrier in the worker
        // hooks) BEFORE clearing this thread's context — the order matters.
        drop(self);
        CONTEXT.with(|context| *context.borrow_mut() = None::<BarrierContext>);
        output
    }
}
/// Marker recording how a runtime entry point was registered; entries are
/// collected into the link-time `RUNTIMES` slice.
#[doc(hidden)]
#[derive(Debug, PartialEq, Eq)]
pub enum RuntimeContext {
    /// Registered by the `#[async_local::main]` macro; at most one per binary.
    Main,
    /// Registered by the `#[async_local::test]` macro.
    Test,
}
// Registry of runtime entry points, populated at link time by the
// main/test attribute macros (via `linkme::distributed_slice`) and
// validated at startup unless the `compat` feature is enabled.
#[doc(hidden)]
#[distributed_slice]
pub static RUNTIMES: [RuntimeContext];
/// Pre-main check (via `ctor`) that exactly one supported runtime entry
/// point was registered through the attribute macros.
///
/// # Panics
///
/// Panics if no runtime was registered, or if `#[async_local::main]` was
/// used more than once.
#[cfg(not(feature = "compat"))]
#[ctor::ctor]
fn assert_runtime_configured() {
    if RUNTIMES.is_empty() {
        // Fixed typo in the user-facing message: "compatibilty" -> "compatibility".
        panic!(
            "The #[async_local::main] or #[async_local::test] macro must be used to configure the Tokio runtime for use with the `async-local` crate. For compatibility with other async runtime configurations, the `compat` feature can be used to disable the optimizations this crate provides"
        );
    }
    // Idiomatic filter/count replaces the manual fold over the slice.
    let main_count = RUNTIMES
        .iter()
        .filter(|context| **context == RuntimeContext::Main)
        .count();
    if main_count > 1 {
        panic!("The #[async_local::main] macro cannot be used more than once");
    }
}