use std::{
cell::RefCell,
io,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Condvar, Mutex,
},
};
use tokio::runtime::Runtime;
/// Synchronization state shared by all runtime/pool threads so that shutdown
/// can be coordinated: the last thread to exit flips `shutdown_finalized` and
/// wakes every thread still parked on `cvar`.
#[derive(Default)]
struct ShutdownBarrier {
    // Number of currently-live threads (runtime workers + blocking-pool
    // workers). Incremented in `on_thread_start`, decremented in
    // `on_thread_stop`.
    guard_count: AtomicUsize,
    // Set to `true` by the final exiting thread; guarded by the mutex so it
    // can be paired with `cvar` waits.
    shutdown_finalized: Mutex<bool>,
    // Signalled (notify_all) once `shutdown_finalized` becomes true.
    cvar: Condvar,
}
/// Per-thread classification recorded at thread start: the first
/// `worker_threads` threads are runtime workers, any later thread belongs to
/// the blocking pool. Only runtime workers wait on the shutdown barrier.
#[derive(PartialEq, Eq, Debug)]
enum BarrierContext {
    /// Tokio runtime worker thread (drives async tasks).
    RuntimeWorker,
    /// Blocking-pool worker thread (`spawn_blocking` etc.).
    PoolWorker,
}
thread_local! {
    // Which kind of worker the current thread is; `None` until the runtime's
    // `on_thread_start` hook classifies it.
    static CONTEXT: RefCell<Option<BarrierContext>> = RefCell::new(None);
}
/// Flavor of runtime being built; mirrors tokio's current-thread vs.
/// multi-thread distinction.
#[derive(PartialEq, Eq)]
pub(crate) enum Kind {
    CurrentThread,
    MultiThread,
}
/// Wrapper around [`tokio::runtime::Builder`] that tracks the configured
/// worker-thread count so `build()` can install shutdown-barrier hooks.
pub struct Builder {
    // Which flavor of runtime this builder produces.
    kind: Kind,
    // Number of runtime worker threads; always 1 for `CurrentThread`.
    worker_threads: usize,
    // The underlying tokio builder all settings are forwarded to.
    inner: tokio::runtime::Builder,
}
impl Builder {
    /// Returns a builder for a single-threaded (current-thread) runtime.
    pub fn new_current_thread() -> Builder {
        Builder {
            kind: Kind::CurrentThread,
            worker_threads: 1,
            inner: tokio::runtime::Builder::new_current_thread(),
        }
    }

    /// Returns a builder for a multi-threaded runtime.
    ///
    /// The default worker count is taken from the `TOKIO_WORKER_THREADS`
    /// environment variable when it parses as a positive integer, falling
    /// back to the number of logical CPUs — mirroring tokio's own default.
    pub fn new_multi_thread() -> Builder {
        // BUG FIX: the variable name was previously misspelled
        // "TOKIO_WORKER_THEADS" (missing the 'R'), so the documented
        // override was silently ignored and the CPU-count fallback always
        // won. Tokio's documented variable is TOKIO_WORKER_THREADS.
        let worker_threads = std::env::var("TOKIO_WORKER_THREADS")
            .ok()
            .and_then(|worker_threads| worker_threads.parse().ok())
            .unwrap_or_else(num_cpus::get);
        Builder {
            kind: Kind::MultiThread,
            worker_threads,
            inner: tokio::runtime::Builder::new_multi_thread(),
        }
    }

    /// Enables both the I/O and time drivers on the underlying tokio builder.
    pub fn enable_all(&mut self) -> &mut Self {
        self.inner.enable_all();
        self
    }

    /// Sets the number of runtime worker threads.
    ///
    /// Silently ignored for current-thread runtimes, which always use one
    /// worker.
    ///
    /// # Panics
    ///
    /// Panics if `val` is 0.
    #[track_caller]
    pub fn worker_threads(&mut self, val: usize) -> &mut Self {
        assert!(val > 0, "Worker threads cannot be set to 0");
        if self.kind != Kind::CurrentThread {
            self.worker_threads = val;
            self.inner.worker_threads(val);
        }
        self
    }

    /// Builds the runtime, installing thread start/stop hooks that keep
    /// runtime worker threads alive until every blocking-pool thread has
    /// exited.
    ///
    /// On shutdown, runtime workers park on a condvar; the last thread to
    /// stop (of either kind) sets `shutdown_finalized` and wakes them all,
    /// so blocking tasks never observe a partially torn-down runtime.
    ///
    /// # Errors
    ///
    /// Returns any `io::Error` produced by the underlying tokio builder.
    pub fn build(&mut self) -> io::Result<Runtime> {
        let worker_threads = self.worker_threads;
        let barrier = Arc::new(ShutdownBarrier::default());

        let on_thread_start = {
            let barrier = barrier.clone();
            move || {
                // `fetch_add` returns the *previous* count, so the first
                // `worker_threads` threads to start are classified as
                // runtime workers; every later thread belongs to the
                // blocking pool.
                let thread_count = barrier.guard_count.fetch_add(1, Ordering::Release);
                CONTEXT.with(|context| {
                    if thread_count >= worker_threads {
                        *context.borrow_mut() = Some(BarrierContext::PoolWorker)
                    } else {
                        *context.borrow_mut() = Some(BarrierContext::RuntimeWorker)
                    }
                });
            }
        };

        let on_thread_stop = move || {
            // `fetch_sub` returns the previous count; a result of 1 means
            // this is the final thread to exit.
            let thread_count = barrier.guard_count.fetch_sub(1, Ordering::AcqRel);
            CONTEXT.with(|context| {
                if thread_count == 1 {
                    // Last live thread: release everyone still parked below.
                    *barrier.shutdown_finalized.lock().unwrap() = true;
                    barrier.cvar.notify_all();
                } else if *context.borrow() == Some(BarrierContext::RuntimeWorker) {
                    // Runtime workers wait for the final thread's signal;
                    // pool workers are free to exit immediately.
                    let mut shutdown_finalized = barrier.shutdown_finalized.lock().unwrap();
                    while !*shutdown_finalized {
                        shutdown_finalized = barrier.cvar.wait(shutdown_finalized).unwrap();
                    }
                }
            });
        };

        self.inner
            .on_thread_start(on_thread_start)
            .on_thread_stop(on_thread_stop)
            .build()
    }
}