async_local/
lib.rs

1#![cfg_attr(test, feature(exit_status_error))]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4extern crate self as async_local;
5
/// Tokio Runtime builder support
///
/// The builder configures a barrier to rendezvous worker threads during shutdown, ensuring that tasks never outlive local data owned by worker threads
#[cfg(all(not(loom), feature = "rt"))]
#[doc(hidden)]
#[path = "runtime.rs"]
pub mod __runtime;
11
12#[cfg(not(feature = "compat"))]
13use std::ptr::addr_of;
14#[cfg(feature = "compat")]
15use std::sync::Arc;
16#[cfg(not(loom))]
17use std::thread::LocalKey;
18use std::{cell::RefCell, ops::Deref};
19
20pub use derive_async_local::{AsContext, main, test};
21use generativity::{Guard, Id, make_guard};
22#[doc(hidden)]
23pub use linkme;
24#[cfg(loom)]
25use loom::thread::LocalKey;
26#[doc(hidden)]
27#[cfg(all(not(loom), feature = "rt"))]
28pub use tokio::pin;
29#[cfg(all(not(loom), feature = "rt"))]
30use tokio::task::{JoinHandle, spawn_blocking};
31
/// The kind of thread the current thread has been registered as; stored per-thread in [`CONTEXT`]
#[derive(Debug, Eq, PartialEq)]
pub(crate) enum BarrierContext {
  // NOTE(review): undocumented in the original; presumably the thread that owns the runtime — confirm against `runtime.rs`
  Owner,
  /// Tokio Runtime Worker
  RuntimeWorker,
  /// Tokio Pool Worker
  PoolWorker,
}
40
41thread_local! {
42  pub(crate) static CONTEXT: RefCell<Option<BarrierContext>> = const { RefCell::new(None) };
43}
44
/// Wrapper around a value stored in a thread-local, from which pointers ([`LocalRef`]) to that value are created
///
/// Holds `T` inline by default; with the `compat` feature enabled the value is held in an [`std::sync::Arc`] instead
pub struct Context<T>(
  #[cfg(not(feature = "compat"))] T,
  #[cfg(feature = "compat")] Arc<T>,
)
where
  T: Sync + 'static;
50
51impl<T> Context<T>
52where
53  T: Sync,
54{
55  /// Create a new thread-local context
56  ///
57  /// If the `compat` feature flag is enabled, [`Context`] will downgrade to internally using [`std::sync::Arc`] to ensure the validity of `T`
58  ///
59  /// # Usage
60  ///
61  /// Either wrap a type with [`Context`] and assign to a thread-local, or use as an unwrapped field in a struct that derives [`AsContext`]
62  ///
63  /// # Example
64  ///
65  /// ```rust
66  /// use std::sync::atomic::{AtomicUsize, Ordering};
67  ///
68  /// use async_local::{AsyncLocal, Context};
69  /// use generativity::make_guard;
70  ///
71  /// thread_local! {
72  ///   static COUNTER: Context<AtomicUsize> = Context::new(AtomicUsize::new(0));
73  /// }
74  ///
75  /// #[async_local::main(flavor = "current_thread")]
76  /// async fn main() {
77  ///   make_guard!(guard);
78  ///   let counter = COUNTER.local_ref(guard);
79  ///
80  ///   let _count = counter
81  ///     .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
82  ///     .await
83  ///     .unwrap();
84  /// }
85  /// ```
86  pub fn new(inner: T) -> Context<T> {
87    #[cfg(not(feature = "compat"))]
88    {
89      Context(inner)
90    }
91    #[cfg(feature = "compat")]
92    {
93      Context(Arc::new(inner))
94    }
95  }
96
97  /// Construct [`LocalRef`] with an unbounded lifetime.
98  ///
99  /// # Safety
100  ///
101  /// This lifetime must be restricted to avoid unsoundness
102  pub unsafe fn local_ref<'a>(&self) -> LocalRef<'a, T> {
103    unsafe { LocalRef::new(self, Guard::new(Id::new())) }
104  }
105}
106
107impl<T> AsRef<Context<T>> for Context<T>
108where
109  T: Sync,
110{
111  fn as_ref(&self) -> &Context<T> {
112    self
113  }
114}
115
116impl<T> Deref for Context<T>
117where
118  T: Sync,
119{
120  type Target = T;
121  fn deref(&self) -> &Self::Target {
122    #[cfg(not(feature = "compat"))]
123    {
124      &self.0
125    }
126    #[cfg(feature = "compat")]
127    {
128      self.0.as_ref()
129    }
130  }
131}
132
/// A marker trait promising that [AsRef](https://doc.rust-lang.org/std/convert/trait.AsRef.html)<[`Context<T>`]> is implemented in a way that can't be invalidated
///
/// # Safety
///
/// The [`Context`] returned by `as_ref` must not be invalidated, because references to it may exist for the lifetime of the runtime.
pub unsafe trait AsContext: AsRef<Context<Self::Target>> {
  /// The type wrapped by the [`Context`] that `as_ref` returns
  type Target: Sync + 'static;
}
141
142unsafe impl<T> AsContext for Context<T>
143where
144  T: Sync,
145{
146  type Target = T;
147}
148
149/// A thread-safe pointer to a thread-local [`Context`] constrained by a "[generative](https://crates.io/crates/generativity)" lifetime brand that is [invariant](https://doc.rust-lang.org/nomicon/subtyping.html#variance) over the lifetime parameter and cannot be coerced into `'static`
150#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
151pub struct LocalRef<'id, T: Sync + 'static> {
152  #[cfg(not(feature = "compat"))]
153  inner: *const T,
154  #[cfg(feature = "compat")]
155  inner: Arc<T>,
156  /// Lifetime carrier
157  _brand: Id<'id>,
158}
159
160impl<'id, T> LocalRef<'id, T>
161where
162  T: Sync + 'static,
163{
164  unsafe fn new(context: &Context<T>, guard: Guard<'id>) -> Self {
165    LocalRef {
166      #[cfg(not(feature = "compat"))]
167      inner: addr_of!(context.0),
168      #[cfg(feature = "compat")]
169      inner: context.0.clone(),
170      _brand: guard.into(),
171    }
172  }
173
174  /// A wrapper around [`tokio::task::spawn_blocking`](https://docs.rs/tokio/latest/tokio/task/fn.spawn_blocking.html) that safely constrains the lifetime of [`LocalRef`]
175  #[cfg(all(not(loom), feature = "rt"))]
176  #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
177  pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
178  where
179    F: for<'a> FnOnce(LocalRef<'a, T>) -> R + Send + 'static,
180    R: Send + 'static,
181  {
182    use std::mem::transmute;
183
184    let local_ref = unsafe { transmute::<LocalRef<'_, T>, LocalRef<'_, T>>(self) };
185
186    spawn_blocking(move || f(local_ref))
187  }
188}
189
190impl<T> Deref for LocalRef<'_, T>
191where
192  T: Sync,
193{
194  type Target = T;
195  fn deref(&self) -> &Self::Target {
196    #[cfg(not(feature = "compat"))]
197    {
198      unsafe { &*self.inner }
199    }
200    #[cfg(feature = "compat")]
201    {
202      self.inner.deref()
203    }
204  }
205}
206
207impl<T> Clone for LocalRef<'_, T>
208where
209  T: Sync + 'static,
210{
211  fn clone(&self) -> Self {
212    LocalRef {
213      #[cfg(not(feature = "compat"))]
214      inner: self.inner,
215      #[cfg(feature = "compat")]
216      inner: self.inner.clone(),
217      _brand: self._brand,
218    }
219  }
220}
221
222unsafe impl<T> Send for LocalRef<'_, T> where T: Sync {}
223unsafe impl<T> Sync for LocalRef<'_, T> where T: Sync {}
224/// LocalKey extension for creating thread-safe pointers to thread-local [`Context`]
225pub trait AsyncLocal<T>
226where
227  T: AsContext,
228{
229  /// A wrapper around [`tokio::task::spawn_blocking`](https://docs.rs/tokio/latest/tokio/task/fn.spawn_blocking.html) that safely constrains the lifetime of [`LocalRef`]
230  #[cfg(all(not(loom), feature = "rt"))]
231  #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
232  fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
233  where
234    F: for<'id> FnOnce(LocalRef<'id, T::Target>) -> R + Send + 'static,
235    R: Send + 'static;
236
237  /// Acquire a reference to the value in this TLS key.
238  fn with_async<F, R>(&'static self, f: F) -> impl Future<Output = R>
239  where
240    F: for<'a> AsyncFnMut(LocalRef<'a, T::Target>) -> R;
241
242  /// Create a pointer to a thread local [`Context`] using a trusted lifetime carrier.
243  ///
244  /// # Usage
245  ///
246  /// Use [`generativity::make_guard`] to generate a unique [`invariant`](https://doc.rust-lang.org/nomicon/subtyping.html#variance) lifetime brand
247  ///
248  /// # Panic
249  ///
250  /// [`LocalRef`] must be created within the async context of the runtime.
251  fn local_ref<'id>(&'static self, guard: Guard<'id>) -> LocalRef<'id, T::Target>;
252}
253
254impl<T> AsyncLocal<T> for LocalKey<T>
255where
256  T: AsContext,
257{
258  #[cfg(all(not(loom), feature = "rt"))]
259  #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
260  fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
261  where
262    F: for<'id> FnOnce(LocalRef<'id, T::Target>) -> R + Send + 'static,
263    R: Send + 'static,
264  {
265    let guard = unsafe { Guard::new(Id::new()) };
266    let local_ref = self.local_ref(guard);
267    spawn_blocking(move || f(local_ref))
268  }
269
270  async fn with_async<F, R>(&'static self, mut f: F) -> R
271  where
272    F: for<'a> AsyncFnMut(LocalRef<'a, T::Target>) -> R,
273  {
274    make_guard!(guard);
275    let local_ref = self.local_ref(guard);
276    f(local_ref).await
277  }
278
279  #[track_caller]
280  #[inline(always)]
281  fn local_ref<'id>(&'static self, guard: Guard<'id>) -> LocalRef<'id, T::Target> {
282    #[cfg(not(feature = "compat"))]
283    {
284      if CONTEXT
285        .with(|context| matches!(&*context.borrow(), None | Some(BarrierContext::PoolWorker)))
286      {
287        panic!(
288          "LocalRef can only be created within the async context of a Tokio Runtime configured by `#[async_local::main]` or `#[async_local::test]`"
289        );
290      }
291    }
292
293    self.with(|value| unsafe { LocalRef::new(value.as_ref(), guard) })
294  }
295}
296
#[cfg(test)]
mod tests {
  use std::sync::atomic::{AtomicUsize, Ordering};

  use generativity::make_guard;
  use tokio::task::yield_now;

  use super::*;

  thread_local! {
    static COUNTER: Context<AtomicUsize> = Context::new(AtomicUsize::new(0));
  }

  /// Exercises both blocking entry points: directly on the `LocalKey`, and on an already created `LocalRef`
  #[async_local::test]
  async fn with_blocking() {
    // Via the thread-local key itself
    COUNTER
      .with_blocking(|count| count.fetch_add(1, Ordering::Relaxed))
      .await
      .unwrap();

    // Via a branded reference
    make_guard!(brand);
    let pointer = COUNTER.local_ref(brand);

    pointer
      .with_blocking(|count| count.fetch_add(1, Ordering::Relaxed))
      .await
      .unwrap();
  }

  /// A `LocalRef` can be held across an await point and used afterwards
  #[async_local::test]
  async fn ref_spans_await() {
    make_guard!(brand);
    let pointer = COUNTER.local_ref(brand);

    yield_now().await;

    pointer.fetch_add(1, Ordering::SeqCst);
  }

  /// A `LocalRef` can be passed by value into an async trait method that awaits before using it
  #[async_local::test]
  async fn with_async_trait() {
    struct Counter;

    trait Countable {
      async fn add_one(ref_guard: LocalRef<'_, AtomicUsize>) -> usize;
    }

    impl Countable for Counter {
      async fn add_one(pointer: LocalRef<'_, AtomicUsize>) -> usize {
        yield_now().await;
        pointer.fetch_add(1, Ordering::Release)
      }
    }

    make_guard!(brand);
    let pointer = COUNTER.local_ref(brand);

    Counter::add_one(pointer).await;
  }
}