// async_local/runtime.rs

use std::{
  cell::RefCell,
  io,
  sync::{
    Arc, Condvar, Mutex,
    atomic::{AtomicUsize, Ordering},
  },
};

use tokio::runtime::Runtime;

/// Rendezvous point for worker threads at shutdown: no thread exits (and
/// drops its thread-local data) until every runtime worker has finished.
#[derive(Default)]
struct ShutdownBarrier {
  guard_count: AtomicUsize,
  shutdown_finalized: Mutex<bool>,
  cvar: Condvar,
}

#[derive(PartialEq, Eq, Debug)]
enum BarrierContext {
  /// Tokio runtime worker thread
  RuntimeWorker,
  /// Tokio blocking pool worker thread
  PoolWorker,
}

thread_local! {
  static CONTEXT: RefCell<Option<BarrierContext>> = const { RefCell::new(None) };
}

#[derive(PartialEq, Eq)]
pub(crate) enum Kind {
  CurrentThread,
  MultiThread,
}

/// Builds a Tokio [`Runtime`] configured with a shutdown barrier
pub struct Builder {
  kind: Kind,
  worker_threads: usize,
  inner: tokio::runtime::Builder,
}

impl Builder {
  /// Returns a new builder with the current thread scheduler selected.
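  ///
  /// # Examples
  ///
  /// A minimal sketch: the resulting runtime must be driven from a call to
  /// [`Runtime::block_on`].
  ///
  /// ```
  /// use async_local::runtime;
  ///
  /// let rt = runtime::Builder::new_current_thread().build().unwrap();
  ///
  /// rt.block_on(async move {});
  /// ```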
  pub fn new_current_thread() -> Builder {
    Builder {
      kind: Kind::CurrentThread,
      worker_threads: 1,
      inner: tokio::runtime::Builder::new_current_thread(),
    }
  }

  /// Returns a new builder with the multi thread scheduler selected.
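  ///
  /// The worker thread count defaults to the `TOKIO_WORKER_THREADS`
  /// environment variable when set, otherwise to the number of logical CPUs.
  ///
  /// # Examples
  ///
  /// A minimal sketch of building and driving the work-stealing runtime:
  ///
  /// ```
  /// use async_local::runtime;
  ///
  /// let rt = runtime::Builder::new_multi_thread().build().unwrap();
  ///
  /// rt.block_on(async move {});
  /// ```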
  pub fn new_multi_thread() -> Builder {
    // Default to the `TOKIO_WORKER_THREADS` environment variable when set,
    // falling back to the number of logical CPUs.
    let worker_threads = std::env::var("TOKIO_WORKER_THREADS")
      .ok()
      .and_then(|worker_threads| worker_threads.parse().ok())
      .unwrap_or_else(num_cpus::get);

    Builder {
      kind: Kind::MultiThread,
      worker_threads,
      inner: tokio::runtime::Builder::new_multi_thread(),
    }
  }

  /// Enables both I/O and time drivers.
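  ///
  /// A minimal sketch (this assumes the I/O and time drivers are available in
  /// the underlying Tokio build):
  ///
  /// ```
  /// use async_local::runtime;
  ///
  /// let rt = runtime::Builder::new_multi_thread()
  ///   .enable_all()
  ///   .build()
  ///   .unwrap();
  ///
  /// rt.spawn(async move {});
  /// ```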
  pub fn enable_all(&mut self) -> &mut Self {
    self.inner.enable_all();
    self
  }

  /// Sets the number of worker threads the [`Runtime`] will use.
  ///
  /// This can be any number above 0, though it is advised to keep this value
  /// on the smaller side.
  ///
  /// This will override the value read from the environment variable
  /// `TOKIO_WORKER_THREADS`.
  ///
  /// # Default
  ///
  /// The default value is the number of cores available to the system.
  ///
  /// When using the `current_thread` runtime, this method has no effect.
  ///
  /// # Examples
  ///
  /// ## Multi threaded runtime with 4 threads
  ///
  /// ```
  /// use async_local::runtime;
  ///
  /// // This will spawn a work-stealing runtime with 4 worker threads.
  /// let rt = runtime::Builder::new_multi_thread()
  ///   .worker_threads(4)
  ///   .build()
  ///   .unwrap();
  ///
  /// rt.spawn(async move {});
  /// ```
  ///
  /// ## Current thread runtime (will only run on the current thread via [`Runtime::block_on`])
  ///
  /// ```
  /// use async_local::runtime;
  ///
  /// // Create a runtime that _must_ be driven from a call to `Runtime::block_on`.
  /// let rt = runtime::Builder::new_current_thread().build().unwrap();
  ///
  /// // This will run the runtime and future on the current thread
  /// rt.block_on(async move {});
  /// ```
  ///
  /// # Panics
  ///
  /// This will panic if `val` is not larger than `0`.
  #[track_caller]
  pub fn worker_threads(&mut self, val: usize) -> &mut Self {
    assert!(val > 0, "Worker threads cannot be set to 0");
    if self.kind != Kind::CurrentThread {
      self.worker_threads = val;
      self.inner.worker_threads(val);
    }
    self
  }

  /// Creates a Tokio [`Runtime`] configured with a barrier that rendezvouses
  /// worker threads during shutdown, ensuring tasks never outlive local data
  /// owned by worker threads.
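  ///
  /// A minimal sketch: the barrier takes effect when the runtime is dropped.
  ///
  /// ```
  /// use async_local::runtime;
  ///
  /// let rt = runtime::Builder::new_multi_thread().build().unwrap();
  ///
  /// rt.block_on(async move {});
  ///
  /// // Dropping the runtime rendezvouses worker threads via the barrier.
  /// drop(rt);
  /// ```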
  pub fn build(&mut self) -> io::Result<Runtime> {
    let worker_threads = self.worker_threads;
    let barrier = Arc::new(ShutdownBarrier::default());

    let on_thread_start = {
      let barrier = barrier.clone();
      move || {
        let thread_count = barrier.guard_count.fetch_add(1, Ordering::Release);

        // The first `worker_threads` threads to start are the runtime workers;
        // any thread started beyond that count belongs to the blocking pool.
        CONTEXT.with(|context| {
          if thread_count >= worker_threads {
            *context.borrow_mut() = Some(BarrierContext::PoolWorker)
          } else {
            *context.borrow_mut() = Some(BarrierContext::RuntimeWorker)
          }
        });
      }
    };

    let on_thread_stop = move || {
      let thread_count = barrier.guard_count.fetch_sub(1, Ordering::AcqRel);

      CONTEXT.with(|context| {
        if thread_count == 1 {
          // The last thread to stop finalizes shutdown and wakes all waiters.
          *barrier.shutdown_finalized.lock().unwrap() = true;
          barrier.cvar.notify_all();
        } else if *context.borrow() == Some(BarrierContext::RuntimeWorker) {
          // Runtime workers park here until every thread has stopped, keeping
          // their thread-local data alive for any task still in flight.
          let mut shutdown_finalized = barrier.shutdown_finalized.lock().unwrap();
          while !*shutdown_finalized {
            shutdown_finalized = barrier.cvar.wait(shutdown_finalized).unwrap();
          }
        }
      });
    };

    self
      .inner
      .on_thread_start(on_thread_start)
      .on_thread_stop(on_thread_stop)
      .build()
  }
}
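
// Smoke-test sketch: assumes the multi-thread scheduler is available and
// checks that a barrier-configured runtime can be built, driven, and shut
// down without deadlocking in the shutdown rendezvous.
#[cfg(test)]
mod tests {
  use super::Builder;

  #[test]
  fn builds_and_shuts_down() {
    let rt = Builder::new_multi_thread()
      .worker_threads(2)
      .build()
      .unwrap();

    rt.block_on(async move {});

    // Shutdown runs the rendezvous installed via `on_thread_stop`.
    drop(rt);
  }
}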