1#![cfg_attr(test, feature(exit_status_error))]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4extern crate self as async_local;
5
6#[doc(hidden)]
8#[cfg(all(not(loom), feature = "rt"))]
9#[path = "runtime.rs"]
10pub mod __runtime;
11
12#[cfg(not(feature = "compat"))]
13use std::ptr::addr_of;
14#[cfg(feature = "compat")]
15use std::sync::Arc;
16#[cfg(not(loom))]
17use std::thread::LocalKey;
18use std::{cell::RefCell, ops::Deref};
19
20pub use derive_async_local::{AsContext, main, test};
21use generativity::{Guard, Id, make_guard};
22#[doc(hidden)]
23pub use linkme;
24#[cfg(loom)]
25use loom::thread::LocalKey;
26#[doc(hidden)]
27#[cfg(all(not(loom), feature = "rt"))]
28pub use tokio::pin;
29#[cfg(all(not(loom), feature = "rt"))]
30use tokio::task::{JoinHandle, spawn_blocking};
31
/// Role of the current thread, as recorded in the `CONTEXT` thread-local.
///
/// In builds without the `compat` feature, `AsyncLocal::local_ref` panics when
/// the current thread's role is unset (`None`) or is
/// [`BarrierContext::PoolWorker`]; the other roles are allowed to create a
/// `LocalRef`.
// NOTE(review): the code that assigns these roles is not part of this file
// (presumably `runtime.rs`) — confirm the per-variant descriptions there.
#[derive(PartialEq, Eq, Debug)]
pub(crate) enum BarrierContext {
    /// Presumably the thread that owns the runtime — TODO confirm.
    Owner,
    /// Presumably a runtime worker thread — TODO confirm.
    RuntimeWorker,
    /// Presumably a blocking-pool worker thread — TODO confirm. `LocalRef`
    /// creation is rejected for this role.
    PoolWorker,
}
40
thread_local! {
    /// Per-thread [`BarrierContext`] slot; every thread starts with `None`.
    ///
    /// Read by `AsyncLocal::local_ref` to decide whether a `LocalRef` may be
    /// created on the current thread.
    // NOTE(review): the writer side of this slot is not visible in this file.
    pub(crate) static CONTEXT: RefCell<Option<BarrierContext>> = const { RefCell::new(None) };
}
44
/// Wrapper around a `Sync + 'static` value from which [`LocalRef`]s are created.
///
/// The storage strategy depends on the `compat` feature: the value is stored
/// inline when the feature is disabled and behind an `Arc` when it is enabled.
/// Either way the wrapper dereferences to `T`.
pub struct Context<T: Sync + 'static>(
    // Inline storage; a `LocalRef` holds a raw pointer to this field.
    #[cfg(not(feature = "compat"))] T,
    // Shared storage; each `LocalRef` holds its own clone of this `Arc`.
    #[cfg(feature = "compat")] Arc<T>,
);
50
51impl<T> Context<T>
52where
53 T: Sync,
54{
55 pub fn new(inner: T) -> Context<T> {
87 #[cfg(not(feature = "compat"))]
88 {
89 Context(inner)
90 }
91 #[cfg(feature = "compat")]
92 {
93 Context(Arc::new(inner))
94 }
95 }
96
97 pub unsafe fn local_ref<'a>(&self) -> LocalRef<'a, T> {
103 unsafe { LocalRef::new(self, Guard::new(Id::new())) }
104 }
105}
106
107impl<T> AsRef<Context<T>> for Context<T>
108where
109 T: Sync,
110{
111 fn as_ref(&self) -> &Context<T> {
112 self
113 }
114}
115
116impl<T> Deref for Context<T>
117where
118 T: Sync,
119{
120 type Target = T;
121 fn deref(&self) -> &Self::Target {
122 #[cfg(not(feature = "compat"))]
123 {
124 &self.0
125 }
126 #[cfg(feature = "compat")]
127 {
128 self.0.as_ref()
129 }
130 }
131}
132
/// Marks a type that exposes a [`Context`] through [`AsRef`], which is what
/// allows a thread-local holding it to be used via [`AsyncLocal`].
///
/// # Safety
///
/// NOTE(review): the precise contract is not spelled out in the visible code.
/// `AsyncLocal::local_ref` trusts implementors: it builds `LocalRef`s (raw
/// pointers when `compat` is disabled) from the `Context` returned by
/// `as_ref`. Presumably that `Context` must stay valid and at a stable address
/// for as long as such references can exist — confirm before implementing this
/// trait by hand.
pub unsafe trait AsContext: AsRef<Context<Self::Target>> {
    /// The type wrapped by the exposed [`Context`].
    type Target: Sync + 'static;
}
141
142unsafe impl<T> AsContext for Context<T>
143where
144 T: Sync,
145{
146 type Target = T;
147}
148
/// Handle to the value stored in a [`Context`], branded with the
/// `generativity` lifetime `'id`.
///
/// Dereferences to `T`. Without the `compat` feature the handle is a raw
/// pointer into the `Context`; with `compat` it owns an `Arc` clone.
///
/// Note that the derived comparison, ordering, hashing and `Debug` impls
/// operate on the `inner` field, so they are address-based without `compat`
/// (raw pointer) and value-based with `compat` (`Arc<T>`), and they are only
/// available when `T` implements the respective trait.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct LocalRef<'id, T: Sync + 'static> {
    // Raw pointer to the value stored inline in the originating `Context`.
    #[cfg(not(feature = "compat"))]
    inner: *const T,
    // Shared ownership of the value held by the originating `Context`.
    #[cfg(feature = "compat")]
    inner: Arc<T>,
    // Lifetime brand taken from the `Guard` this handle was created with.
    _brand: Id<'id>,
}
159
160impl<'id, T> LocalRef<'id, T>
161where
162 T: Sync + 'static,
163{
164 unsafe fn new(context: &Context<T>, guard: Guard<'id>) -> Self {
165 LocalRef {
166 #[cfg(not(feature = "compat"))]
167 inner: addr_of!(context.0),
168 #[cfg(feature = "compat")]
169 inner: context.0.clone(),
170 _brand: guard.into(),
171 }
172 }
173
174 #[cfg(all(not(loom), feature = "rt"))]
176 #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
177 pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
178 where
179 F: for<'a> FnOnce(LocalRef<'a, T>) -> R + Send + 'static,
180 R: Send + 'static,
181 {
182 use std::mem::transmute;
183
184 let local_ref = unsafe { transmute::<LocalRef<'_, T>, LocalRef<'_, T>>(self) };
185
186 spawn_blocking(move || f(local_ref))
187 }
188}
189
190impl<T> Deref for LocalRef<'_, T>
191where
192 T: Sync,
193{
194 type Target = T;
195 fn deref(&self) -> &Self::Target {
196 #[cfg(not(feature = "compat"))]
197 {
198 unsafe { &*self.inner }
199 }
200 #[cfg(feature = "compat")]
201 {
202 self.inner.deref()
203 }
204 }
205}
206
207impl<T> Clone for LocalRef<'_, T>
208where
209 T: Sync + 'static,
210{
211 fn clone(&self) -> Self {
212 LocalRef {
213 #[cfg(not(feature = "compat"))]
214 inner: self.inner,
215 #[cfg(feature = "compat")]
216 inner: self.inner.clone(),
217 _brand: self._brand,
218 }
219 }
220}
221
222unsafe impl<T> Send for LocalRef<'_, T> where T: Sync {}
223unsafe impl<T> Sync for LocalRef<'_, T> where T: Sync {}
/// Extension methods for `'static` thread-local keys whose value implements
/// [`AsContext`], giving access to the stored value through [`LocalRef`]s.
///
/// This file implements the trait for `LocalKey<T>`.
pub trait AsyncLocal<T>
where
    T: AsContext,
{
    /// Runs `f` on Tokio's blocking pool with a [`LocalRef`] to the
    /// thread-local's value and returns the task's `JoinHandle`.
    ///
    /// `f` must accept the reference under any brand lifetime, so the
    /// reference cannot escape the closure.
    #[cfg(all(not(loom), feature = "rt"))]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
    fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
    where
        F: for<'id> FnOnce(LocalRef<'id, T::Target>) -> R + Send + 'static,
        R: Send + 'static;

    /// Calls the async closure `f` with a [`LocalRef`] to the thread-local's
    /// value and resolves to the closure's output.
    fn with_async<F, R>(&'static self, f: F) -> impl Future<Output = R>
    where
        F: for<'a> AsyncFnMut(LocalRef<'a, T::Target>) -> R;

    /// Creates a [`LocalRef`] to the thread-local's value, branded with the
    /// lifetime of `guard` (typically obtained from
    /// `generativity::make_guard!`).
    ///
    /// # Panics
    ///
    /// The `LocalKey` implementation in this file panics in builds without the
    /// `compat` feature when the calling thread's `CONTEXT` is unset or marks
    /// it as a pool worker.
    fn local_ref<'id>(&'static self, guard: Guard<'id>) -> LocalRef<'id, T::Target>;
}
253
254impl<T> AsyncLocal<T> for LocalKey<T>
255where
256 T: AsContext,
257{
258 #[cfg(all(not(loom), feature = "rt"))]
259 #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
260 fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
261 where
262 F: for<'id> FnOnce(LocalRef<'id, T::Target>) -> R + Send + 'static,
263 R: Send + 'static,
264 {
265 let guard = unsafe { Guard::new(Id::new()) };
266 let local_ref = self.local_ref(guard);
267 spawn_blocking(move || f(local_ref))
268 }
269
270 async fn with_async<F, R>(&'static self, mut f: F) -> R
271 where
272 F: for<'a> AsyncFnMut(LocalRef<'a, T::Target>) -> R,
273 {
274 make_guard!(guard);
275 let local_ref = self.local_ref(guard);
276 f(local_ref).await
277 }
278
279 #[track_caller]
280 #[inline(always)]
281 fn local_ref<'id>(&'static self, guard: Guard<'id>) -> LocalRef<'id, T::Target> {
282 #[cfg(not(feature = "compat"))]
283 {
284 if CONTEXT
285 .with(|context| matches!(&*context.borrow(), None | Some(BarrierContext::PoolWorker)))
286 {
287 panic!(
288 "LocalRef can only be created within the async context of a Tokio Runtime configured by `#[async_local::main]` or `#[async_local::test]`"
289 );
290 }
291 }
292
293 self.with(|value| unsafe { LocalRef::new(value.as_ref(), guard) })
294 }
295}
296
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use generativity::make_guard;
    use tokio::task::yield_now;

    use super::*;

    thread_local! {
        // Thread-local counter exercised by every test in this module.
        static COUNTER: Context<AtomicUsize> = Context::new(AtomicUsize::new(0));
    }

    // `with_blocking` must work both on the thread-local key itself and on an
    // already created `LocalRef`.
    #[async_local::test]
    async fn with_blocking() {
        let via_key = COUNTER.with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed));
        via_key.await.unwrap();

        make_guard!(guard);
        let via_ref = COUNTER
            .local_ref(guard)
            .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed));
        via_ref.await.unwrap();
    }

    // A `LocalRef` may be held across an await point.
    #[async_local::test]
    async fn ref_spans_await() {
        make_guard!(guard);
        let held = COUNTER.local_ref(guard);

        yield_now().await;

        held.fetch_add(1, Ordering::SeqCst);
    }

    // A `LocalRef` can be passed by value into an async trait method that
    // itself awaits before using it.
    #[async_local::test]
    async fn with_async_trait() {
        struct Counter;

        trait Countable {
            async fn add_one(ref_guard: LocalRef<'_, AtomicUsize>) -> usize;
        }

        impl Countable for Counter {
            async fn add_one(handle: LocalRef<'_, AtomicUsize>) -> usize {
                yield_now().await;
                handle.fetch_add(1, Ordering::Release)
            }
        }

        make_guard!(guard);
        let handle = COUNTER.local_ref(guard);

        Counter::add_one(handle).await;
    }
}