async_local/
runtime.rs

1use std::{
2  fmt::{self, Debug},
3  io,
4  sync::{
5    Arc, Condvar, Mutex,
6    atomic::{AtomicUsize, Ordering},
7  },
8};
9
10use linkme::distributed_slice;
11
12use crate::{BarrierContext, CONTEXT};
13
/// State shared between the thread start/stop hooks installed by
/// [`Builder::build`].
///
/// The default value is a zero guard count, a `false` finalized flag and a
/// fresh condition variable.
#[derive(Default)]
struct ShutdownBarrier {
  /// Number of threads that have run the start hook and not yet run the stop
  /// hook.
  guard_count: AtomicUsize,
  /// Set to `true` by the stop hook of the last registered thread.
  shutdown_finalized: Mutex<bool>,
  /// Paired with `shutdown_finalized`; notified once the flag has been set.
  cvar: Condvar,
}
20
/// Scheduler flavor a [`Builder`] was created for.
#[derive(Eq, PartialEq)]
pub(crate) enum Kind {
  /// Selected by `Builder::new_current_thread`.
  CurrentThread,
  /// Selected by `Builder::new_multi_thread`; only present when the
  /// `rt-multi-thread` feature is enabled.
  #[cfg(feature = "rt-multi-thread")]
  MultiThread,
}
27
28#[doc(hidden)]
29/// Builds Tokio runtime configured with a shutdown barrier
30pub struct Builder {
31  kind: Kind,
32  worker_threads: usize,
33  inner: tokio::runtime::Builder,
34}
35
36impl Builder {
37  /// Returns a new builder with the current thread scheduler selected.
38  pub fn new_current_thread() -> Builder {
39    Builder {
40      kind: Kind::CurrentThread,
41      worker_threads: 1,
42      inner: tokio::runtime::Builder::new_current_thread(),
43    }
44  }
45
46  /// Returns a new builder with the multi thread scheduler selected.
47  #[cfg(feature = "rt-multi-thread")]
48  pub fn new_multi_thread() -> Builder {
49    let worker_threads = std::env::var("TOKIO_WORKER_THEADS")
50      .ok()
51      .and_then(|worker_threads| worker_threads.parse().ok())
52      .unwrap_or_else(num_cpus::get);
53
54    Builder {
55      kind: Kind::MultiThread,
56      worker_threads,
57      inner: tokio::runtime::Builder::new_multi_thread(),
58    }
59  }
60
61  /// Enables both I/O and time drivers.
62  pub fn enable_all(&mut self) -> &mut Self {
63    self.inner.enable_all();
64    self
65  }
66
67  /// Sets the number of worker threads the [`Runtime`] will use.
68  ///
69  /// This can be any number above 0 though it is advised to keep this value
70  /// on the smaller side.
71  ///
72  /// This will override the value read from environment variable `TOKIO_WORKER_THREADS`.
73  ///
74  /// # Default
75  ///
76  /// The default value is the number of cores available to the system.
77  ///
78  /// When using the `current_thread` runtime this method has no effect.
79  ///
80  /// # Panics
81  ///
82  /// This will panic if `val` is not larger than `0`.
83  #[track_caller]
84  pub fn worker_threads(&mut self, val: usize) -> &mut Self {
85    assert!(val > 0, "Worker threads cannot be set to 0");
86    if self.kind.ne(&Kind::CurrentThread) {
87      self.worker_threads = val;
88      self.inner.worker_threads(val);
89    }
90    self
91  }
92
93  /// Creates a Tokio Runtime configured with a barrier that rendezvous worker threads during shutdown as to ensure tasks never outlive local data owned by worker threads
94  pub fn build(&mut self) -> io::Result<Runtime> {
95    let worker_threads = self.worker_threads;
96    let barrier = Arc::new(ShutdownBarrier::default());
97
98    let on_thread_start = {
99      let barrier = barrier.clone();
100      move || {
101        let thread_count = barrier.guard_count.fetch_add(1, Ordering::Release);
102
103        CONTEXT.with(|context| {
104          if thread_count.ge(&worker_threads) {
105            *context.borrow_mut() = Some(BarrierContext::PoolWorker)
106          } else {
107            *context.borrow_mut() = Some(BarrierContext::RuntimeWorker)
108          }
109        });
110      }
111    };
112
113    let on_thread_stop = move || {
114      let thread_count = barrier.guard_count.fetch_sub(1, Ordering::AcqRel);
115
116      CONTEXT.with(|context| {
117        if thread_count.eq(&1) {
118          *barrier.shutdown_finalized.lock().unwrap() = true;
119          barrier.cvar.notify_all();
120        } else if context.borrow().eq(&Some(BarrierContext::RuntimeWorker)) {
121          let mut shutdown_finalized = barrier.shutdown_finalized.lock().unwrap();
122          while !*shutdown_finalized {
123            shutdown_finalized = barrier.cvar.wait(shutdown_finalized).unwrap();
124          }
125        }
126      });
127    };
128
129    self
130      .inner
131      .on_thread_start(on_thread_start)
132      .on_thread_stop(on_thread_stop)
133      .build()
134      .map(Runtime::new)
135  }
136}
137
138#[doc(hidden)]
139pub struct Runtime(tokio::runtime::Runtime);
140
141impl Debug for Runtime {
142  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143    self.0.fmt(f)
144  }
145}
146
147impl Runtime {
148  fn new(inner: tokio::runtime::Runtime) -> Self {
149    Runtime(inner)
150  }
151  /// Runs a future to completion on the Tokio runtime. This is the
152  /// runtime's entry point.
153  ///
154  /// This runs the given future on the current thread, blocking until it is
155  /// complete, and yielding its resolved result. Any tasks or timers
156  /// which the future spawns internally will be executed on the runtime.
157  ///
158  /// # Non-worker future
159  ///
160  /// Note that the future required by this function does not run as a
161  /// worker. The expectation is that other tasks are spawned by the future here.
162  /// Awaiting on other futures from the future provided here will not
163  /// perform as fast as those spawned as workers.
164  ///
165  /// # Panics
166  ///
167  /// This function panics if the provided future panics, or if called within an
168  /// asynchronous execution context.
169  ///
170  /// # Safety
171  /// This is internal to async_local and is meant to be used exclusively with #[async_local::main] and #[async_local::test].
172  #[track_caller]
173  pub unsafe fn block_on<F: Future>(self, future: F) -> F::Output {
174    unsafe { self.run(|handle| handle.block_on(future)) }
175  }
176
177  pub unsafe fn run<F, Output>(self, f: F) -> Output
178  where
179    F: for<'a> FnOnce(&'a tokio::runtime::Runtime) -> Output,
180  {
181    CONTEXT.with(|context| *context.borrow_mut() = Some(BarrierContext::Owner));
182
183    let output = f(&self.0);
184
185    drop(self);
186
187    CONTEXT.with(|context| *context.borrow_mut() = None::<BarrierContext>);
188
189    output
190  }
191}
192
/// Kind of entry point recorded in [`RUNTIMES`].
#[doc(hidden)]
#[derive(Debug, Eq, PartialEq)]
pub enum RuntimeContext {
  /// Entry counted by the start-up check that rejects more than one
  /// `#[async_local::main]`.
  Main,
  /// Any number of these may be registered.
  Test,
}
199
200#[doc(hidden)]
201#[distributed_slice]
202pub static RUNTIMES: [RuntimeContext];
203
204#[cfg(not(feature = "compat"))]
205#[ctor::ctor]
206fn assert_runtime_configured() {
207  if RUNTIMES.is_empty() {
208    panic!(
209      "The #[async_local::main] or #[async_local::test] macro must be used to configure the Tokio runtime for use with the `async-local` crate. For compatibilty with other async runtime configurations, the `compat` feature can be used to disable the optimizations this crate provides"
210    );
211  }
212
213  if RUNTIMES
214    .iter()
215    .fold(0, |acc, context| {
216      if context.eq(&RuntimeContext::Main) {
217        acc + 1
218      } else {
219        acc
220      }
221    })
222    .gt(&1)
223  {
224    panic!("The #[async_local::main] macro cannot be used more than once");
225  }
226}