1#![cfg_attr(test, feature(exit_status_error))]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4extern crate self as async_local;
5
6#[cfg(all(not(loom), feature = "tokio-runtime"))]
8#[cfg_attr(docsrs, doc(cfg(any(feature = "barrier-protected-runtime"))))]
9pub mod runtime;
10
11use std::ops::Deref;
12#[cfg(feature = "barrier-protected-runtime")]
13use std::ptr::addr_of;
14#[cfg(not(feature = "barrier-protected-runtime"))]
15use std::sync::Arc;
16#[cfg(not(loom))]
17use std::thread::LocalKey;
18
19pub use derive_async_local::AsContext;
20use generativity::{Guard, Id, make_guard};
21#[cfg(loom)]
22use loom::thread::LocalKey;
23#[doc(hidden)]
24#[cfg(all(not(loom), feature = "tokio-runtime"))]
25pub use tokio::pin;
26#[cfg(all(not(loom), feature = "tokio-runtime"))]
27use tokio::task::{JoinHandle, spawn_blocking};
28
/// Wrapper for a value meant to be stored in a `thread_local!` and reached through
/// [`AsyncLocal`].
///
/// In this layout (no `barrier-protected-runtime`) the value lives behind an [`Arc`], so every
/// [`LocalRef`] holds its own share of it.
#[cfg(not(feature = "barrier-protected-runtime"))]
pub struct Context<T>(Arc<T>)
where
    T: Sync + 'static;

/// Wrapper for a value meant to be stored in a `thread_local!` and reached through
/// [`AsyncLocal`].
///
/// In this layout (`barrier-protected-runtime`) the value is stored inline, and each
/// [`LocalRef`] holds a raw pointer to it.
#[cfg(feature = "barrier-protected-runtime")]
pub struct Context<T>(T)
where
    T: Sync + 'static;
36
37impl<T> Context<T>
38where
39 T: Sync,
40{
41 #[cfg(feature = "barrier-protected-runtime")]
61 pub fn new(inner: T) -> Context<T> {
62 Context(inner)
63 }
64
65 #[cfg(not(feature = "barrier-protected-runtime"))]
66 pub fn new(inner: T) -> Context<T> {
67 Context(Arc::new(inner))
68 }
69
70 pub unsafe fn local_ref<'a>(&self) -> LocalRef<'a, T> {
76 unsafe { LocalRef::new(self, Guard::new(Id::new())) }
77 }
78}
79
80impl<T> AsRef<Context<T>> for Context<T>
81where
82 T: Sync,
83{
84 fn as_ref(&self) -> &Context<T> {
85 self
86 }
87}
88
89#[cfg(not(feature = "barrier-protected-runtime"))]
90impl<T> Deref for Context<T>
91where
92 T: Sync,
93{
94 type Target = T;
95 fn deref(&self) -> &Self::Target {
96 self.0.as_ref()
97 }
98}
99
/// Borrows the inline value.
#[cfg(feature = "barrier-protected-runtime")]
impl<T: Sync> Deref for Context<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Context(value) = self;
        value
    }
}
110
/// Types that expose a [`Context`] through `AsRef`. For such a type `X`, a
/// `LocalKey<X>` gains the [`AsyncLocal`] methods.
///
/// `Context<T>` implements this trait for itself. Other types can use
/// `#[derive(AsContext)]`, which is re-exported from `derive_async_local`.
///
/// # Safety
///
/// NOTE(review): the full contract is not visible in this file. `AsyncLocal::local_ref`
/// builds a [`LocalRef`] from `value.as_ref()` inside `LocalKey::with`. With
/// `barrier-protected-runtime`, that `LocalRef` keeps a raw pointer to the returned
/// `Context` after `with` returns. Implementors therefore presumably must return a `Context`
/// stored directly in `Self`, not one that could be moved or dropped independently (e.g.
/// via interior mutability). Confirm this against the derive macro.
pub unsafe trait AsContext: AsRef<Context<Self::Target>> {
    /// The value type held by the [`Context`] that `as_ref` returns.
    type Target: Sync + 'static;
}
119
120unsafe impl<T> AsContext for Context<T>
121where
122 T: Sync,
123{
124 type Target = T;
125}
126
127#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
129pub struct LocalRef<'id, T: Sync + 'static> {
130 #[cfg(feature = "barrier-protected-runtime")]
131 inner: *const T,
132 #[cfg(not(feature = "barrier-protected-runtime"))]
133 inner: Arc<T>,
134 _brand: Id<'id>,
136}
137
138impl<'id, T> LocalRef<'id, T>
139where
140 T: Sync + 'static,
141{
142 #[cfg(not(feature = "barrier-protected-runtime"))]
143 unsafe fn new(context: &Context<T>, guard: Guard<'id>) -> Self {
144 LocalRef {
145 inner: context.0.clone(),
146 _brand: guard.into(),
147 }
148 }
149
150 #[cfg(feature = "barrier-protected-runtime")]
151 unsafe fn new(context: &Context<T>, guard: Guard<'id>) -> Self {
152 LocalRef {
153 inner: addr_of!(context.0),
154 _brand: guard.into(),
155 }
156 }
157
158 #[cfg(all(not(loom), feature = "tokio-runtime"))]
160 #[cfg_attr(
161 docsrs,
162 doc(cfg(any(feature = "tokio-runtime", feature = "barrier-protected-runtime")))
163 )]
164 pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
165 where
166 F: for<'a> FnOnce(LocalRef<'a, T>) -> R + Send + 'static,
167 R: Send + 'static,
168 {
169 use std::mem::transmute;
170
171 let local_ref = unsafe { transmute::<LocalRef<'_, T>, LocalRef<'_, T>>(self) };
172
173 spawn_blocking(move || f(local_ref))
174 }
175}
176
/// Borrows the value behind the raw pointer.
#[cfg(feature = "barrier-protected-runtime")]
impl<T: Sync> Deref for LocalRef<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let ptr: *const T = self.inner;
        // SAFETY: `ptr` was taken from a live `Context` when the handle was built. Keeping
        // that context alive while the handle exists is the responsibility of whoever created
        // the handle (presumably via the barrier-protected runtime; TODO: confirm).
        unsafe { &*ptr }
    }
}
187#[cfg(not(feature = "barrier-protected-runtime"))]
188impl<T> Deref for LocalRef<'_, T>
189where
190 T: Sync,
191{
192 type Target = T;
193 fn deref(&self) -> &Self::Target {
194 self.inner.deref()
195 }
196}
197
/// Copies the raw pointer and the brand; no ownership is involved.
#[cfg(feature = "barrier-protected-runtime")]
impl<T: Sync + 'static> Clone for LocalRef<'_, T> {
    fn clone(&self) -> Self {
        let LocalRef { inner, _brand } = *self;
        LocalRef { inner, _brand }
    }
}
210
211#[cfg(not(feature = "barrier-protected-runtime"))]
212impl<T> Clone for LocalRef<'_, T>
213where
214 T: Sync + 'static,
215{
216 fn clone(&self) -> Self {
217 LocalRef {
218 inner: self.inner.clone(),
219 _brand: self._brand,
220 }
221 }
222}
223
// SAFETY: a `LocalRef` only hands out `&T` (see its `Deref` impls), so moving it to another
// thread exposes `&T` there, which `T: Sync` permits. In the `barrier-protected-runtime`
// layout, the `*const T` field would otherwise make the type `!Send`.
// NOTE(review): without `barrier-protected-runtime`, `inner` is an `Arc<T>`, and
// `Arc<T>: Send` normally requires `T: Send + Sync`. If a `LocalRef` ends up holding the last
// `Arc` (e.g. after the originating thread's thread-local was destroyed), `T` is dropped on
// whichever thread holds it. That is unsound for `T: Sync + !Send`. Consider requiring
// `T: Send` in that layout; tightening the bound is a breaking change, so confirm first.
unsafe impl<T> Send for LocalRef<'_, T> where T: Sync {}
// SAFETY: sharing `&LocalRef` across threads yields `&T` via `Deref`, which `T: Sync` allows.
// NOTE(review): in the `Arc` layout, `Clone` is callable through `&self`, so the
// drop-on-another-thread concern above applies here too.
unsafe impl<T> Sync for LocalRef<'_, T> where T: Sync {}
/// Extension methods for `thread_local!` keys whose value implements [`AsContext`]. Each
/// method hands out [`LocalRef`]s to the current thread's [`Context`].
pub trait AsyncLocal<T>
where
    T: AsContext,
{
    /// Runs `f` with a [`LocalRef`] to this thread's value and returns a [`JoinHandle`] for
    /// its result. The provided `LocalKey` implementation uses `spawn_blocking`.
    ///
    /// `f` must accept a `LocalRef` of any brand (`for<'id>`), so it cannot rely on the
    /// handle sharing a brand with anything created outside the closure.
    #[cfg(all(not(loom), feature = "tokio-runtime"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(feature = "tokio-runtime", feature = "barrier-protected-runtime")))
    )]
    fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
    where
        F: for<'id> FnOnce(LocalRef<'id, T::Target>) -> R + Send + 'static,
        R: Send + 'static;

    /// Awaits `f` with a [`LocalRef`] to this thread's value and resolves to `f`'s output.
    /// The provided `LocalKey` implementation brands the handle with `make_guard!`.
    ///
    /// NOTE(review): this signature places no `Send` bound on the returned future.
    fn with_async<F, R>(&'static self, f: F) -> impl Future<Output = R>
    where
        F: for<'a> AsyncFnMut(LocalRef<'a, T::Target>) -> R;

    /// Creates a [`LocalRef`] to this thread's value, branded with `guard`'s lifetime.
    /// `guard` is typically obtained from `generativity::make_guard!`.
    ///
    /// # Panics
    ///
    /// For `std::thread::LocalKey`, `LocalKey::with` panics if the thread-local is being, or
    /// has been, destroyed.
    fn local_ref<'id>(&'static self, guard: Guard<'id>) -> LocalRef<'id, T::Target>;
}
258
259impl<T> AsyncLocal<T> for LocalKey<T>
260where
261 T: AsContext,
262{
263 #[cfg(all(not(loom), feature = "tokio-runtime"))]
264 #[cfg_attr(
265 docsrs,
266 doc(cfg(any(feature = "tokio-runtime", feature = "barrier-protected-runtime")))
267 )]
268 fn with_blocking<F, R>(&'static self, f: F) -> JoinHandle<R>
269 where
270 F: for<'id> FnOnce(LocalRef<'id, T::Target>) -> R + Send + 'static,
271 R: Send + 'static,
272 {
273 let guard = unsafe { Guard::new(Id::new()) };
274 let local_ref = self.local_ref(guard);
275 spawn_blocking(move || f(local_ref))
276 }
277
278 async fn with_async<F, R>(&'static self, mut f: F) -> R
279 where
280 F: for<'a> AsyncFnMut(LocalRef<'a, T::Target>) -> R,
281 {
282 make_guard!(guard);
283 f(self.local_ref(guard)).await
284 }
285
286 fn local_ref<'id>(&'static self, guard: Guard<'id>) -> LocalRef<'id, T::Target> {
287 self.with(|value| unsafe { LocalRef::new(value.as_ref(), guard) })
288 }
289}
290
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use generativity::make_guard;
    use tokio::task::yield_now;

    use super::*;

    thread_local! {
        static COUNTER: Context<AtomicUsize> = Context::new(AtomicUsize::new(0));
    }

    // Both `AsyncLocal::with_blocking` and `LocalRef::with_blocking` run their closure on
    // tokio's blocking pool with access to the spawning thread's counter.
    #[tokio::test(crate = "async_local", flavor = "multi_thread")]
    async fn with_blocking() {
        COUNTER
            .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
            .await
            .unwrap();

        make_guard!(guard);
        let local_ref = COUNTER.local_ref(guard);

        local_ref
            .with_blocking(|counter| counter.fetch_add(1, Ordering::Relaxed))
            .await
            .unwrap();
    }

    // A `LocalRef` obtained before an `.await` is still usable after it.
    #[tokio::test(crate = "async_local", flavor = "multi_thread")]
    async fn ref_spans_await() {
        make_guard!(guard);
        let counter = COUNTER.local_ref(guard);
        yield_now().await;
        counter.fetch_add(1, Ordering::SeqCst);
    }

    // `AsyncLocal::with_async` passes the async closure a `LocalRef` to the current thread's
    // counter and resolves to the closure's output. This assumes the test body is driven by
    // `block_on` on the test's own thread (tokio's `Runtime::block_on` semantics), so no
    // other test touches this thread's `COUNTER`.
    #[tokio::test(crate = "async_local", flavor = "multi_thread")]
    async fn with_async() {
        let before = COUNTER.with(|counter| counter.load(Ordering::SeqCst));

        let previous = COUNTER
            .with_async(async |counter| {
                yield_now().await;
                counter.fetch_add(1, Ordering::SeqCst)
            })
            .await;

        assert_eq!(previous, before);
        assert_eq!(
            COUNTER.with(|counter| counter.load(Ordering::SeqCst)),
            before + 1
        );
    }

    // A `LocalRef` can be passed by value into an `async fn` declared in a trait.
    #[tokio::test(crate = "async_local", flavor = "multi_thread")]
    async fn with_async_trait() {
        struct Counter;

        trait Countable {
            async fn add_one(ref_guard: LocalRef<'_, AtomicUsize>) -> usize;
        }

        impl Countable for Counter {
            async fn add_one(counter: LocalRef<'_, AtomicUsize>) -> usize {
                yield_now().await;
                counter.fetch_add(1, Ordering::Release)
            }
        }

        make_guard!(guard);
        let counter = COUNTER.local_ref(guard);

        Counter::add_one(counter).await;
    }
}