1use std::{
2 fmt::{self, Debug},
3 io,
4 sync::{
5 Arc, Condvar, Mutex,
6 atomic::{AtomicUsize, Ordering},
7 },
8};
9
10use linkme::distributed_slice;
11
12use crate::{BarrierContext, CONTEXT};
13
/// Shared state that lets runtime threads hold each other back at shutdown.
///
/// One instance is created per `Builder::build` call and shared (through an
/// `Arc`) between the thread start and thread stop hooks.
#[derive(Default)]
struct ShutdownBarrier {
    /// Flag flipped to `true` by the last stopping thread. It is kept behind a
    /// mutex so that other threads can wait for it through `cvar`.
    shutdown_finalized: Mutex<bool>,
    /// Woken (`notify_all`) once `shutdown_finalized` has been set.
    cvar: Condvar,
    /// Number of hooked threads that have started but not yet stopped.
    guard_count: AtomicUsize,
}
20
/// Flavour of Tokio runtime that a `Builder` was created for.
#[derive(Eq, PartialEq)]
pub(crate) enum Kind {
    /// Created by `Builder::new_current_thread`.
    CurrentThread,
    /// Created by `Builder::new_multi_thread`; only exists when the
    /// `rt-multi-thread` feature is enabled.
    #[cfg(feature = "rt-multi-thread")]
    MultiThread,
}
27
28#[doc(hidden)]
29pub struct Builder {
31 kind: Kind,
32 worker_threads: usize,
33 inner: tokio::runtime::Builder,
34}
35
36impl Builder {
37 pub fn new_current_thread() -> Builder {
39 Builder {
40 kind: Kind::CurrentThread,
41 worker_threads: 1,
42 inner: tokio::runtime::Builder::new_current_thread(),
43 }
44 }
45
46 #[cfg(feature = "rt-multi-thread")]
48 pub fn new_multi_thread() -> Builder {
49 let worker_threads = std::env::var("TOKIO_WORKER_THEADS")
50 .ok()
51 .and_then(|worker_threads| worker_threads.parse().ok())
52 .unwrap_or_else(num_cpus::get);
53
54 Builder {
55 kind: Kind::MultiThread,
56 worker_threads,
57 inner: tokio::runtime::Builder::new_multi_thread(),
58 }
59 }
60
61 pub fn enable_all(&mut self) -> &mut Self {
63 self.inner.enable_all();
64 self
65 }
66
67 #[track_caller]
84 pub fn worker_threads(&mut self, val: usize) -> &mut Self {
85 assert!(val > 0, "Worker threads cannot be set to 0");
86 if self.kind.ne(&Kind::CurrentThread) {
87 self.worker_threads = val;
88 self.inner.worker_threads(val);
89 }
90 self
91 }
92
93 pub fn build(&mut self) -> io::Result<Runtime> {
95 let worker_threads = self.worker_threads;
96 let barrier = Arc::new(ShutdownBarrier::default());
97
98 let on_thread_start = {
99 let barrier = barrier.clone();
100 move || {
101 let thread_count = barrier.guard_count.fetch_add(1, Ordering::Release);
102
103 CONTEXT.with(|context| {
104 if thread_count.ge(&worker_threads) {
105 *context.borrow_mut() = Some(BarrierContext::PoolWorker)
106 } else {
107 *context.borrow_mut() = Some(BarrierContext::RuntimeWorker)
108 }
109 });
110 }
111 };
112
113 let on_thread_stop = move || {
114 let thread_count = barrier.guard_count.fetch_sub(1, Ordering::AcqRel);
115
116 CONTEXT.with(|context| {
117 if thread_count.eq(&1) {
118 *barrier.shutdown_finalized.lock().unwrap() = true;
119 barrier.cvar.notify_all();
120 } else if context.borrow().eq(&Some(BarrierContext::RuntimeWorker)) {
121 let mut shutdown_finalized = barrier.shutdown_finalized.lock().unwrap();
122 while !*shutdown_finalized {
123 shutdown_finalized = barrier.cvar.wait(shutdown_finalized).unwrap();
124 }
125 }
126 });
127 };
128
129 self
130 .inner
131 .on_thread_start(on_thread_start)
132 .on_thread_stop(on_thread_stop)
133 .build()
134 .map(Runtime::new)
135 }
136}
137
/// Thin wrapper around a [`tokio::runtime::Runtime`] produced by `Builder::build`.
///
/// It is consumed by `block_on` or `run`. Both drop the inner runtime once the
/// supplied future or closure has finished.
#[doc(hidden)]
pub struct Runtime(tokio::runtime::Runtime);
140
141impl Debug for Runtime {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 self.0.fmt(f)
144 }
145}
146
147impl Runtime {
148 fn new(inner: tokio::runtime::Runtime) -> Self {
149 Runtime(inner)
150 }
151 #[track_caller]
173 pub unsafe fn block_on<F: Future>(self, future: F) -> F::Output {
174 unsafe { self.run(|handle| handle.block_on(future)) }
175 }
176
177 pub unsafe fn run<F, Output>(self, f: F) -> Output
178 where
179 F: for<'a> FnOnce(&'a tokio::runtime::Runtime) -> Output,
180 {
181 CONTEXT.with(|context| *context.borrow_mut() = Some(BarrierContext::Owner));
182
183 let output = f(&self.0);
184
185 drop(self);
186
187 CONTEXT.with(|context| *context.borrow_mut() = None::<BarrierContext>);
188
189 output
190 }
191}
192
/// Kind of entry point recorded in the `RUNTIMES` registry.
#[doc(hidden)]
#[derive(Eq, PartialEq, Debug)]
pub enum RuntimeContext {
    /// Entry for `#[async_local::main]`. Registering more than one triggers
    /// the start-up panic in `assert_runtime_configured`.
    Main,
    /// Presumably the entry for `#[async_local::test]` — confirm against the
    /// macro crate.
    Test,
}
199
/// Link-time registry of runtime entry points, collected by `linkme`'s
/// `#[distributed_slice]`.
///
/// Read by `assert_runtime_configured` (compiled only without the `compat`
/// feature). That check panics if the slice is empty or contains more than
/// one `RuntimeContext::Main`.
///
/// NOTE(review): elements are presumably contributed by the
/// `#[async_local::main]` / `#[async_local::test]` macro expansions — confirm
/// against the macro crate.
#[doc(hidden)]
#[distributed_slice]
pub static RUNTIMES: [RuntimeContext];
203
204#[cfg(not(feature = "compat"))]
205#[ctor::ctor]
206fn assert_runtime_configured() {
207 if RUNTIMES.is_empty() {
208 panic!(
209 "The #[async_local::main] or #[async_local::test] macro must be used to configure the Tokio runtime for use with the `async-local` crate. For compatibilty with other async runtime configurations, the `compat` feature can be used to disable the optimizations this crate provides"
210 );
211 }
212
213 if RUNTIMES
214 .iter()
215 .fold(0, |acc, context| {
216 if context.eq(&RuntimeContext::Main) {
217 acc + 1
218 } else {
219 acc
220 }
221 })
222 .gt(&1)
223 {
224 panic!("The #[async_local::main] macro cannot be used more than once");
225 }
226}