use std::{
    cell::RefCell,
    io,
    sync::{
        Arc, Condvar, Mutex,
        atomic::{AtomicUsize, Ordering},
    },
};

use tokio::runtime::Runtime;

/// Shared state used to keep runtime worker threads alive until every
/// blocking-pool thread has shut down.
#[derive(Default)]
struct ShutdownBarrier {
    guard_count: AtomicUsize,
    shutdown_finalized: Mutex<bool>,
    cvar: Condvar,
}

/// Which kind of thread the current thread is, recorded in `on_thread_start`.
#[derive(PartialEq, Eq, Debug)]
enum BarrierContext {
    /// Tokio runtime worker.
    RuntimeWorker,
    /// Tokio blocking-pool worker.
    PoolWorker,
}

thread_local! {
    static CONTEXT: RefCell<Option<BarrierContext>> = const { RefCell::new(None) };
}

#[derive(PartialEq, Eq)]
pub(crate) enum Kind {
    CurrentThread,
    MultiThread,
}

pub struct Builder {
    kind: Kind,
    worker_threads: usize,
    inner: tokio::runtime::Builder,
}

impl Builder {
    pub fn new_current_thread() -> Builder {
        Builder {
            kind: Kind::CurrentThread,
            worker_threads: 1,
            inner: tokio::runtime::Builder::new_current_thread(),
        }
    }

    pub fn new_multi_thread() -> Builder {
        // Mirror Tokio's default: honor TOKIO_WORKER_THREADS, falling back to
        // the number of available CPUs.
        let worker_threads = std::env::var("TOKIO_WORKER_THREADS")
            .ok()
            .and_then(|worker_threads| worker_threads.parse().ok())
            .unwrap_or_else(num_cpus::get);

        Builder {
            kind: Kind::MultiThread,
            worker_threads,
            inner: tokio::runtime::Builder::new_multi_thread(),
        }
    }

    pub fn enable_all(&mut self) -> &mut Self {
        self.inner.enable_all();
        self
    }

    #[track_caller]
    pub fn worker_threads(&mut self, val: usize) -> &mut Self {
        assert!(val > 0, "Worker threads cannot be set to 0");
        if self.kind != Kind::CurrentThread {
            self.worker_threads = val;
            self.inner.worker_threads(val);
        }
        self
    }

    pub fn build(&mut self) -> io::Result<Runtime> {
        let worker_threads = self.worker_threads;
        let barrier = Arc::new(ShutdownBarrier::default());

        let on_thread_start = {
            let barrier = barrier.clone();
            move || {
                // Threads started while the count is still below `worker_threads`
                // are classified as runtime workers; any thread beyond that is
                // treated as a blocking-pool worker.
                let thread_count = barrier.guard_count.fetch_add(1, Ordering::Release);

                CONTEXT.with(|context| {
                    if thread_count >= worker_threads {
                        *context.borrow_mut() = Some(BarrierContext::PoolWorker)
                    } else {
                        *context.borrow_mut() = Some(BarrierContext::RuntimeWorker)
                    }
                });
            }
        };

        let on_thread_stop = move || {
            let thread_count = barrier.guard_count.fetch_sub(1, Ordering::AcqRel);

            CONTEXT.with(|context| {
                if thread_count == 1 {
                    // Last thread out: mark shutdown as finished and release any
                    // runtime workers still parked below.
                    *barrier.shutdown_finalized.lock().unwrap() = true;
                    barrier.cvar.notify_all();
                } else if *context.borrow() == Some(BarrierContext::RuntimeWorker) {
                    // Runtime workers wait here so they outlive the blocking pool.
                    let mut shutdown_finalized = barrier.shutdown_finalized.lock().unwrap();
                    while !*shutdown_finalized {
                        shutdown_finalized = barrier.cvar.wait(shutdown_finalized).unwrap();
                    }
                }
            });
        };

        self.inner
            .on_thread_start(on_thread_start)
            .on_thread_stop(on_thread_stop)
            .build()
    }
}
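// A minimal usage sketch, not part of the listing above: it assumes this file
// is compiled as a crate module with `tokio` (rt-multi-thread plus the features
// `enable_all` turns on) and `num_cpus` as dependencies, and that dropping the
// `Runtime` runs `on_thread_stop` for each spawned thread, so runtime workers
// stay parked in the barrier until the blocking pool has drained.
#[cfg(test)]
mod tests {
    use super::Builder;
    use std::time::Duration;

    #[test]
    fn runtime_workers_wait_for_blocking_pool() {
        let rt = Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();

        rt.block_on(async {
            // Fire-and-forget blocking work that may still be running when the
            // runtime is dropped.
            tokio::task::spawn_blocking(|| {
                std::thread::sleep(Duration::from_millis(100));
            });
        });

        // Dropping the runtime triggers the shutdown hooks installed in
        // `Builder::build`.
        drop(rt);
    }
}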