roam_task_local/lib.rs
1//! Task-local storage for asynchronous tasks.
2//!
3//! This crate provides a way to store task-local values across `.await` points.
4//! It was extracted from the `tokio::task_local` module and can be used independently
5//! of the Tokio runtime.
6//!
7//! Vendored from <https://github.com/BugenZhao/task-local> with minor modifications.
8
9use pin_project_lite::pin_project;
10use std::cell::RefCell;
11use std::error::Error;
12use std::future::Future;
13use std::marker::PhantomPinned;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16use std::{fmt, mem, thread};
17
/// Declares a new task-local key of type [`LocalKey`].
///
/// # Syntax
///
/// The macro wraps any number of static declarations and makes them local to the current task.
/// Publicity and attributes for each static are preserved. Each declaration has the form
/// `$(#[attr])* $vis static NAME: Type;`, and the semicolon after the last declaration may be
/// omitted.
///
/// # Examples
///
/// ```
/// # use roam_task_local::task_local;
/// task_local! {
///     pub static ONE: u32;
///
///     #[allow(unused)]
///     static TWO: f32;
/// }
/// # fn main() {}
/// ```
///
/// See [`LocalKey` documentation][`LocalKey`] for more information.
#[macro_export]
macro_rules! task_local {
    // empty (base case for the recursion)
    () => {};

    // One declaration terminated by `;`: expand it, then recurse on the remaining tokens.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
        $crate::task_local!($($rest)*);
    };

    // A final declaration written without a trailing `;`.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
    }
}
53
/// Implementation detail of [`task_local!`]: expands a single declaration into
/// a `static` [`LocalKey`] backed by a private `thread_local!` slot that starts
/// out empty (`None`).
#[doc(hidden)]
#[macro_export]
macro_rules! __task_local_inner {
    ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
        $(#[$attr])*
        $vis static $name: $crate::LocalKey<$t> = {
            // Absolute `::std` paths keep this expansion working even when the
            // calling module defines its own item named `std` or shadows `Option`.
            ::std::thread_local! {
                static __KEY: ::std::cell::RefCell<::std::option::Option<$t>> =
                    const { ::std::cell::RefCell::new(::std::option::Option::None) };
            }

            $crate::LocalKey { inner: __KEY }
        };
    };
}
68
/// A key for task-local data.
///
/// This type is generated by the [`task_local!`] macro.
///
/// Unlike [`std::thread::LocalKey`], `LocalKey` will
/// _not_ lazily initialize the value on first access. A value is only
/// present while a future created by [`LocalKey::scope`] is being polled,
/// or while the closure passed to [`LocalKey::sync_scope`] runs; outside
/// of those, accessing the key fails (see [`LocalKey::try_with`]).
///
/// # Examples
///
/// ```
/// # async fn dox() {
/// roam_task_local::task_local! {
///     static NUMBER: u32;
/// }
///
/// NUMBER.scope(1, async move {
///     assert_eq!(NUMBER.get(), 1);
/// }).await;
///
/// NUMBER.scope(2, async move {
///     assert_eq!(NUMBER.get(), 2);
///
///     NUMBER.scope(3, async move {
///         assert_eq!(NUMBER.get(), 3);
///     }).await;
/// }).await;
/// # }
/// ```
///
/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey
pub struct LocalKey<T: 'static> {
    // Public only so that the exported `task_local!` expansion can construct
    // the key from its private `thread_local!`; not meant to be used directly.
    // `None` means "no task-local value currently set".
    #[doc(hidden)]
    pub inner: thread::LocalKey<RefCell<Option<T>>>,
}
105
106impl<T: 'static> LocalKey<T> {
107 /// Sets a value `T` as the task-local value for the future `F`.
108 ///
109 /// On completion of `scope`, the task-local will be dropped.
110 ///
111 /// ### Panics
112 ///
113 /// If you poll the returned future inside a call to [`with`] or
114 /// [`try_with`] on the same `LocalKey`, then the call to `poll` will panic.
115 ///
116 /// ### Examples
117 ///
118 /// ```
119 /// # async fn dox() {
120 /// roam_task_local::task_local! {
121 /// static NUMBER: u32;
122 /// }
123 ///
124 /// NUMBER.scope(1, async move {
125 /// println!("task local value: {}", NUMBER.get());
126 /// }).await;
127 /// # }
128 /// ```
129 ///
130 /// [`with`]: fn@Self::with
131 /// [`try_with`]: fn@Self::try_with
132 pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F>
133 where
134 F: Future,
135 {
136 TaskLocalFuture {
137 local: self,
138 slot: Some(value),
139 future: Some(f),
140 _pinned: PhantomPinned,
141 }
142 }
143
144 /// Sets a value `T` as the task-local value for the closure `F`.
145 ///
146 /// On completion of `sync_scope`, the task-local will be dropped.
147 ///
148 /// ### Panics
149 ///
150 /// This method panics if called inside a call to [`with`] or [`try_with`]
151 /// on the same `LocalKey`.
152 ///
153 /// ### Examples
154 ///
155 /// ```
156 /// # async fn dox() {
157 /// roam_task_local::task_local! {
158 /// static NUMBER: u32;
159 /// }
160 ///
161 /// NUMBER.sync_scope(1, || {
162 /// println!("task local value: {}", NUMBER.get());
163 /// });
164 /// # }
165 /// ```
166 ///
167 /// [`with`]: fn@Self::with
168 /// [`try_with`]: fn@Self::try_with
169 #[track_caller]
170 pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R
171 where
172 F: FnOnce() -> R,
173 {
174 let mut value = Some(value);
175 match self.scope_inner(&mut value, f) {
176 Ok(res) => res,
177 Err(err) => err.panic(),
178 }
179 }
180
181 fn scope_inner<F, R>(&'static self, slot: &mut Option<T>, f: F) -> Result<R, ScopeInnerErr>
182 where
183 F: FnOnce() -> R,
184 {
185 struct Guard<'a, T: 'static> {
186 local: &'static LocalKey<T>,
187 slot: &'a mut Option<T>,
188 }
189
190 impl<T: 'static> Drop for Guard<'_, T> {
191 fn drop(&mut self) {
192 // This should not panic.
193 //
194 // We know that the RefCell was not borrowed before the call to
195 // `scope_inner`, so the only way for this to panic is if the
196 // closure has created but not destroyed a RefCell guard.
197 // However, we never give user-code access to the guards, so
198 // there's no way for user-code to forget to destroy a guard.
199 //
200 // The call to `with` also should not panic, since the
201 // thread-local wasn't destroyed when we first called
202 // `scope_inner`, and it shouldn't have gotten destroyed since
203 // then.
204 self.local.inner.with(|inner| {
205 let mut ref_mut = inner.borrow_mut();
206 mem::swap(self.slot, &mut *ref_mut);
207 });
208 }
209 }
210
211 self.inner.try_with(|inner| {
212 inner
213 .try_borrow_mut()
214 .map(|mut ref_mut| mem::swap(slot, &mut *ref_mut))
215 })??;
216
217 let guard = Guard { local: self, slot };
218
219 let res = f();
220
221 drop(guard);
222
223 Ok(res)
224 }
225
226 /// Accesses the current task-local and runs the provided closure.
227 ///
228 /// # Panics
229 ///
230 /// This function will panic if the task local doesn't have a value set.
231 #[track_caller]
232 pub fn with<F, R>(&'static self, f: F) -> R
233 where
234 F: FnOnce(&T) -> R,
235 {
236 match self.try_with(f) {
237 Ok(res) => res,
238 Err(_) => panic!("cannot access a task-local storage value without setting it first"),
239 }
240 }
241
242 /// Accesses the current task-local and runs the provided closure.
243 ///
244 /// If the task-local with the associated key is not present, this
245 /// method will return an `AccessError`. For a panicking variant,
246 /// see `with`.
247 pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
248 where
249 F: FnOnce(&T) -> R,
250 {
251 // If called after the thread-local storing the task-local is destroyed,
252 // then we are outside of a closure where the task-local is set.
253 //
254 // Therefore, it is correct to return an AccessError if `try_with`
255 // returns an error.
256 let try_with_res = self.inner.try_with(|v| {
257 // This call to `borrow` cannot panic because no user-defined code
258 // runs while a `borrow_mut` call is active.
259 v.borrow().as_ref().map(f)
260 });
261
262 match try_with_res {
263 Ok(Some(res)) => Ok(res),
264 Ok(None) | Err(_) => Err(AccessError { _private: () }),
265 }
266 }
267}
268
269impl<T: Clone + 'static> LocalKey<T> {
270 /// Returns a copy of the task-local value
271 /// if the task-local value implements `Clone`.
272 ///
273 /// # Panics
274 ///
275 /// This function will panic if the task local doesn't have a value set.
276 #[track_caller]
277 pub fn get(&'static self) -> T {
278 self.with(|v| v.clone())
279 }
280
281 /// Returns a copy of the task-local value
282 /// if the task-local value implements `Clone`.
283 ///
284 /// If the task-local with the associated key is not present, this
285 /// method will return an [AccessError]. For a panicking variant,
286 /// see [get][Self::get].
287 #[track_caller]
288 pub fn try_get(&'static self) -> Result<T, AccessError> {
289 self.try_with(|v| v.clone())
290 }
291}
292
293impl<T: 'static> fmt::Debug for LocalKey<T> {
294 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295 f.pad("LocalKey { .. }")
296 }
297}
298
pin_project! {
    /// A future that sets a value `T` of a task local for the future `F` during
    /// its execution.
    ///
    /// The value of the task-local must be `'static`. It is stored inside this
    /// future and dropped together with it, unless it has been removed earlier
    /// with [`take_value`](TaskLocalFuture::take_value).
    ///
    /// Created by the function [`LocalKey::scope`](self::LocalKey::scope).
    ///
    /// ### Examples
    ///
    /// ```
    /// # async fn dox() {
    /// roam_task_local::task_local! {
    ///     static NUMBER: u32;
    /// }
    ///
    /// NUMBER.scope(1, async move {
    ///     println!("task local value: {}", NUMBER.get());
    /// }).await;
    /// # }
    /// ```
    pub struct TaskLocalFuture<T, F>
    where
        T: 'static,
    {
        // The key whose storage the value in `slot` is swapped into.
        local: &'static LocalKey<T>,
        // Holds the task-local value whenever it is not installed in the key.
        slot: Option<T>,
        // The wrapped future; set to `None` once it has completed, or when it
        // is dropped early by the `PinnedDrop` impl below.
        #[pin]
        future: Option<F>,
        // Makes this future `!Unpin` regardless of `F`.
        #[pin]
        _pinned: PhantomPinned,
    }

    impl<T: 'static, F> PinnedDrop for TaskLocalFuture<T, F> {
        fn drop(this: Pin<&mut Self>) {
            let this = this.project();
            if mem::needs_drop::<F>() && this.future.is_some() {
                // Drop the future while the task-local is set, if possible. Otherwise
                // the future is dropped normally when the `Option<F>` field drops.
                let mut future = this.future;
                // A failed `scope_inner` (storage borrowed or thread-local
                // destroyed) is deliberately ignored: see the comment above.
                let _ = this.local.scope_inner(this.slot, || {
                    future.set(None);
                });
            }
        }
    }
}
347
348impl<T, F> TaskLocalFuture<T, F>
349where
350 T: 'static,
351{
352 /// Returns the value stored in the task local by this `TaskLocalFuture`.
353 ///
354 /// The function returns:
355 ///
356 /// * `Some(T)` if the task local value exists.
357 /// * `None` if the task local value has already been taken.
358 ///
359 /// Note that this function attempts to take the task local value even if
360 /// the future has not yet completed. In that case, the value will no longer
361 /// be available via the task local after the call to `take_value`.
362 ///
363 /// # Examples
364 ///
365 /// ```
366 /// # async fn dox() {
367 /// roam_task_local::task_local! {
368 /// static KEY: u32;
369 /// }
370 ///
371 /// let fut = KEY.scope(42, async {
372 /// // Do some async work
373 /// });
374 ///
375 /// let mut pinned = Box::pin(fut);
376 ///
377 /// // Complete the TaskLocalFuture
378 /// let _ = pinned.as_mut().await;
379 ///
380 /// // And here, we can take task local value
381 /// let value = pinned.as_mut().take_value();
382 ///
383 /// assert_eq!(value, Some(42));
384 /// # }
385 /// ```
386 pub fn take_value(self: Pin<&mut Self>) -> Option<T> {
387 let this = self.project();
388 this.slot.take()
389 }
390}
391
392impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
393 type Output = F::Output;
394
395 #[track_caller]
396 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
397 let this = self.project();
398 let mut future_opt = this.future;
399
400 let res = this
401 .local
402 .scope_inner(this.slot, || match future_opt.as_mut().as_pin_mut() {
403 Some(fut) => {
404 let res = fut.poll(cx);
405 if res.is_ready() {
406 future_opt.set(None);
407 }
408 Some(res)
409 }
410 None => None,
411 });
412
413 match res {
414 Ok(Some(res)) => res,
415 Ok(None) => panic!("`TaskLocalFuture` polled after completion"),
416 Err(err) => err.panic(),
417 }
418 }
419}
420
421impl<T: 'static, F> fmt::Debug for TaskLocalFuture<T, F>
422where
423 T: fmt::Debug,
424{
425 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426 /// Format the Option without Some.
427 struct TransparentOption<'a, T> {
428 value: &'a Option<T>,
429 }
430 impl<T: fmt::Debug> fmt::Debug for TransparentOption<'_, T> {
431 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
432 match self.value.as_ref() {
433 Some(value) => value.fmt(f),
434 // Hitting the None branch should not be possible.
435 None => f.pad("<missing>"),
436 }
437 }
438 }
439
440 f.debug_struct("TaskLocalFuture")
441 .field("value", &TransparentOption { value: &self.slot })
442 .finish()
443 }
444}
445
446/// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with).
447#[derive(Clone, Copy, Eq, PartialEq)]
448pub struct AccessError {
449 _private: (),
450}
451
452impl fmt::Debug for AccessError {
453 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
454 f.debug_struct("AccessError").finish()
455 }
456}
457
458impl fmt::Display for AccessError {
459 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
460 fmt::Display::fmt("task-local value not set", f)
461 }
462}
463
464impl Error for AccessError {}
465
/// Why `LocalKey::scope_inner` refused to enter a task-local scope.
enum ScopeInnerErr {
    /// The task-local's `RefCell` is currently borrowed, e.g. because the
    /// scope is being entered from inside `LocalKey::with`.
    BorrowError,
    /// The underlying thread-local is being, or has already been, destroyed.
    AccessError,
}

impl ScopeInnerErr {
    /// Panics with a message describing this error; the reported location is
    /// the caller's.
    #[track_caller]
    fn panic(&self) -> ! {
        match *self {
            ScopeInnerErr::AccessError => panic!(
                "cannot enter a task-local scope during or after destruction of the underlying thread-local"
            ),
            ScopeInnerErr::BorrowError => {
                panic!("cannot enter a task-local scope while the task-local storage is borrowed")
            }
        }
    }
}

// A failed `RefCell::try_borrow_mut` means the storage is already borrowed.
impl From<std::cell::BorrowMutError> for ScopeInnerErr {
    fn from(_err: std::cell::BorrowMutError) -> ScopeInnerErr {
        ScopeInnerErr::BorrowError
    }
}

// A failed `thread::LocalKey::try_with` means the thread-local is gone.
impl From<std::thread::AccessError> for ScopeInnerErr {
    fn from(_err: std::thread::AccessError) -> ScopeInnerErr {
        ScopeInnerErr::AccessError
    }
}