async_local/
lib.rs

// When rustdoc runs with the `docsrs` cfg (as configured for docs.rs builds), enable the
// nightly `doc_cfg` feature so the `doc(cfg(...))` annotations in this crate are rendered.
#![cfg_attr(docsrs, feature(doc_cfg))]
// Test builds opt into the unstable `exit_status_error` library feature.
#![cfg_attr(test, feature(exit_status_error))]

// Makes this crate nameable as `async_local` from inside itself, so paths rooted at
// `async_local::` (e.g. those produced by `#[tokio::test(crate = "async_local")]` in the
// tests of this file) resolve here as well as in downstream crates.
extern crate self as async_local;
5
/// A Tokio Runtime builder that configures a barrier to rendezvous worker threads during shutdown to ensure tasks never outlive local data owned by worker threads
#[cfg(all(not(loom), feature = "tokio-runtime"))]
// The rendered feature requirement mirrors the real `#[cfg]` above: the module is compiled
// whenever `tokio-runtime` is enabled, not only with `barrier-protected-runtime`.
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
pub mod runtime;
10
11use std::ops::Deref;
12#[cfg(feature = "barrier-protected-runtime")]
13use std::ptr::addr_of;
14#[cfg(not(feature = "barrier-protected-runtime"))]
15use std::sync::Arc;
16#[cfg(not(loom))]
17use std::thread::LocalKey;
18
19pub use derive_async_local::AsContext;
20use generativity::{Guard, Id, make_guard};
21#[cfg(loom)]
22use loom::thread::LocalKey;
23#[doc(hidden)]
24#[cfg(all(not(loom), feature = "tokio-runtime"))]
25pub use tokio::pin;
26#[cfg(all(not(loom), feature = "tokio-runtime"))]
27use tokio::task::{JoinHandle, spawn_blocking};
28
/// A wrapper type used for creating pointers to thread-locals
///
/// In this configuration the value lives behind an [`Arc`](std::sync::Arc); every
/// [`LocalRef`] created from it holds its own clone of that `Arc`.
#[cfg(not(feature = "barrier-protected-runtime"))]
pub struct Context<T>(Arc<T>)
where
  T: Sync + 'static;

/// A wrapper type used for creating pointers to thread-locals
///
/// In this configuration the value is stored inline and [`LocalRef`]s hold a raw pointer to it,
/// relying on the barrier-protected runtime to keep the owning thread-local alive.
#[cfg(feature = "barrier-protected-runtime")]
pub struct Context<T>(T)
where
  T: Sync + 'static;
36
37impl<T> Context<T>
38where
39  T: Sync,
40{
41  /// Create a new thread-local context
42  ///
43  /// If the `barrier-protected-runtime` feature flag isn't enabled, [`Context`] will use [`std::sync::Arc`] to ensure the validity of `T`
44  ///
45  /// # Usage
46  ///
47  /// Either wrap a type with [`Context`] and assign to a thread-local, or use as an unwrapped field in a struct that derives [`AsContext`]
48  ///
49  /// # Example
50  ///
51  /// ```rust
52  /// use std::sync::atomic::AtomicUsize;
53  ///
54  /// use async_local::Context;
55  ///
56  /// thread_local! {
57  ///   static COUNTER: Context<AtomicUsize> = Context::new(AtomicUsize::new(0));
58  /// }
59  /// ```
60  #[cfg(feature = "barrier-protected-runtime")]
61  pub fn new(inner: T) -> Context<T> {
62    Context(inner)
63  }
64
65  #[cfg(not(feature = "barrier-protected-runtime"))]
66  pub fn new(inner: T) -> Context<T> {
67    Context(Arc::new(inner))
68  }
69
70  /// Construct [`LocalRef`] with an unbounded lifetime.
71  ///
72  /// # Safety
73  ///
74  /// This lifetime must be restricted to avoid unsoundness
75  pub unsafe fn local_ref<'a>(&self) -> LocalRef<'a, T> {
76    unsafe { LocalRef::new(self, Guard::new(Id::new())) }
77  }
78}
79
80impl<T> AsRef<Context<T>> for Context<T>
81where
82  T: Sync,
83{
84  fn as_ref(&self) -> &Context<T> {
85    self
86  }
87}
88
89#[cfg(not(feature = "barrier-protected-runtime"))]
90impl<T> Deref for Context<T>
91where
92  T: Sync,
93{
94  type Target = T;
95  fn deref(&self) -> &Self::Target {
96    self.0.as_ref()
97  }
98}
99
100#[cfg(feature = "barrier-protected-runtime")]
101impl<T> Deref for Context<T>
102where
103  T: Sync,
104{
105  type Target = T;
106  fn deref(&self) -> &Self::Target {
107    &self.0
108  }
109}
110
111/// A marker trait promising that [AsRef](https://doc.rust-lang.org/std/convert/trait.AsRef.html)<[`Context<T>`]> is implemented in a way that can't be invalidated
112///
113/// # Safety
114///
115/// [`Context`] must not be invalidated as references may exist for the lifetime of the runtime.
116pub unsafe trait AsContext: AsRef<Context<Self::Target>> {
117  type Target: Sync + 'static;
118}
119
120unsafe impl<T> AsContext for Context<T>
121where
122  T: Sync,
123{
124  type Target = T;
125}
126
127/// A thread-safe pointer to a thread-local [`Context`] constrained by a "[generative](https://crates.io/crates/generativity)" lifetime brand that is [invariant](https://doc.rust-lang.org/nomicon/subtyping.html#variance) over the lifetime parameter and cannot be coerced into `'static`
128#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
129pub struct LocalRef<'id, T: Sync + 'static> {
130  #[cfg(feature = "barrier-protected-runtime")]
131  inner: *const T,
132  #[cfg(not(feature = "barrier-protected-runtime"))]
133  inner: Arc<T>,
134  /// Lifetime carrier
135  _brand: Id<'id>,
136}
137
138impl<'id, T> LocalRef<'id, T>
139where
140  T: Sync + 'static,
141{
142  #[cfg(not(feature = "barrier-protected-runtime"))]
143  unsafe fn new(context: &Context<T>, guard: Guard<'id>) -> Self {
144    LocalRef {
145      inner: context.0.clone(),
146      _brand: guard.into(),
147    }
148  }
149
150  #[cfg(feature = "barrier-protected-runtime")]
151  unsafe fn new(context: &Context<T>, guard: Guard<'id>) -> Self {
152    LocalRef {
153      inner: addr_of!(context.0),
154      _brand: guard.into(),
155    }
156  }
157
158  /// A wrapper around [`tokio::task::spawn_blocking`](https://docs.rs/tokio/latest/tokio/task/fn.spawn_blocking.html) that safely constrains the lifetime of [`LocalRef`]
159  #[cfg(all(not(loom), feature = "tokio-runtime"))]
160  #[cfg_attr(
161    docsrs,
162    doc(cfg(any(feature = "tokio-runtime", feature = "barrier-protected-runtime")))
163  )]
164  pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
165  where
166    F: for<'a> FnOnce(LocalRef<'a, T>) -> R + Send + 'static,
167    R: Send + 'static,
168  {
169    use std::mem::transmute;
170
171    let local_ref = unsafe { transmute::<LocalRef<'_, T>, LocalRef<'_, T>>(self) };
172
173    spawn_blocking(move || f(local_ref))
174  }
175}
176
177#[cfg(feature = "barrier-protected-runtime")]
178impl<T> Deref for LocalRef<'_, T>
179where
180  T: Sync,
181{
182  type Target = T;
183  fn deref(&self) -> &Self::Target {
184    unsafe { &*self.inner }
185  }
186}
187#[cfg(not(feature = "barrier-protected-runtime"))]
188impl<T> Deref for LocalRef<'_, T>
189where
190  T: Sync,
191{
192  type Target = T;
193  fn deref(&self) -> &Self::Target {
194    self.inner.deref()
195  }
196}
197
198#[cfg(feature = "barrier-protected-runtime")]
199impl<T> Clone for LocalRef<'_, T>
200where
201  T: Sync + 'static,
202{
203  fn clone(&self) -> Self {
204    LocalRef {
205      inner: self.inner,
206      _brand: self._brand,
207    }
208  }
209}
210
211#[cfg(not(feature = "barrier-protected-runtime"))]
212impl<T> Clone for LocalRef<'_, T>
213where
214  T: Sync + 'static,
215{
216  fn clone(&self) -> Self {
217    LocalRef {
218      inner: self.inner.clone(),
219      _brand: self._brand,
220    }
221  }
222}
223
224unsafe impl<T> Send for LocalRef<'_, T> where T: Sync {}
225unsafe impl<T> Sync for LocalRef<'_, T> where T: Sync {}
226/// LocalKey extension for creating thread-safe pointers to thread-local [`Context`]
227pub trait AsyncLocal<T>
228where
229  T: AsContext,
230{
231  /// A wrapper around [`tokio::task::spawn_blocking`](https://docs.rs/tokio/latest/tokio/task/fn.spawn_blocking.html) that safely constrains the lifetime of [`LocalRef`]
232  #[cfg(all(not(loom), feature = "tokio-runtime"))]
233  #[cfg_attr(
234    docsrs,
235    doc(cfg(any(feature = "tokio-runtime", feature = "barrier-protected-runtime")))
236  )]
237  fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
238  where
239    F: for<'id> FnOnce(LocalRef<'id, T::Target>) -> R + Send + 'static,
240    R: Send + 'static;
241
242  /// Acquire a reference to the value in this TLS key.
243  fn with_async<F, R>(&'static self, f: F) -> impl Future<Output = R>
244  where
245    F: for<'a> AsyncFnMut(LocalRef<'a, T::Target>) -> R;
246
247  /// Create a pointer to a thread local [`Context`] using a trusted lifetime carrier.
248  ///
249  /// # Usage
250  ///
251  /// Use [`generativity::make_guard`] to generate a unique [`invariant`](https://doc.rust-lang.org/nomicon/subtyping.html#variance) lifetime brand
252  ///
253  /// # Safety
254  ///
255  /// When `barrier-protected-runtime` is enabled, [`tokio::main`](https://docs.rs/tokio/1/tokio/attr.test.html) and [`tokio::test`](https://docs.rs/tokio/1/tokio/attr.test.html) must be used with `crate = "async_local"` set to configure the runtime to synchronize shutdown. This ensures the validity of all invariant lifetimes
256  fn local_ref<'id>(&'static self, guard: Guard<'id>) -> LocalRef<'id, T::Target>;
257}
258
259impl<T> AsyncLocal<T> for LocalKey<T>
260where
261  T: AsContext,
262{
263  #[cfg(all(not(loom), feature = "tokio-runtime"))]
264  #[cfg_attr(
265    docsrs,
266    doc(cfg(any(feature = "tokio-runtime", feature = "barrier-protected-runtime")))
267  )]
268  fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
269  where
270    F: for<'id> FnOnce(LocalRef<'id, T::Target>) -> R + Send + 'static,
271    R: Send + 'static,
272  {
273    let guard = unsafe { Guard::new(Id::new()) };
274    let local_ref = self.local_ref(guard);
275    spawn_blocking(move || f(local_ref))
276  }
277
278  async fn with_async<F, R>(&'static self, mut f: F) -> R
279  where
280    F: for<'a> AsyncFnMut(LocalRef<'a, T::Target>) -> R,
281  {
282    make_guard!(guard);
283    f(self.local_ref(guard)).await
284  }
285
286  fn local_ref<'id>(&'static self, guard: Guard<'id>) -> LocalRef<'id, T::Target> {
287    self.with(|value| unsafe { LocalRef::new(value.as_ref(), guard) })
288  }
289}
290
#[cfg(test)]
mod tests {
  use std::sync::atomic::{AtomicUsize, Ordering};

  use generativity::make_guard;
  use tokio::task::yield_now;

  use super::*;

  thread_local! {
      static COUNTER: Context<AtomicUsize> = Context::new(AtomicUsize::new(0));
  }

  /// Exercises both blocking entry points: on the `LocalKey` and on a branded `LocalRef`.
  #[tokio::test(crate = "async_local", flavor = "multi_thread")]
  async fn with_blocking() {
    let from_key = COUNTER.with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed));
    from_key.await.unwrap();

    make_guard!(guard);
    let branded = COUNTER.local_ref(guard);
    let from_ref = branded.with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed));
    from_ref.await.unwrap();
  }

  /// A branded reference remains usable after the task yields.
  #[tokio::test(crate = "async_local", flavor = "multi_thread")]
  async fn ref_spans_await() {
    make_guard!(guard);
    let local = COUNTER.local_ref(guard);
    yield_now().await;
    local.fetch_add(1, Ordering::SeqCst);
  }

  /// A `LocalRef` can be passed into an `async fn` declared on a trait.
  #[tokio::test(crate = "async_local", flavor = "multi_thread")]
  async fn with_async_trait() {
    struct Incrementer;

    trait Countable {
      async fn add_one(counter: LocalRef<'_, AtomicUsize>) -> usize;
    }

    impl Countable for Incrementer {
      async fn add_one(counter: LocalRef<'_, AtomicUsize>) -> usize {
        yield_now().await;
        counter.fetch_add(1, Ordering::Release)
      }
    }

    make_guard!(guard);
    let _ = Incrementer::add_one(COUNTER.local_ref(guard)).await;
  }
}