Skip to main content

roam_task_local/
lib.rs

1//! Task-local storage for asynchronous tasks.
2//!
3//! This crate provides a way to store task-local values across `.await` points.
4//! It was extracted from the `tokio::task_local` module and can be used independently
5//! of the Tokio runtime.
6//!
7//! Vendored from <https://github.com/BugenZhao/task-local> with minor modifications.
8
9use pin_project_lite::pin_project;
10use std::cell::RefCell;
11use std::error::Error;
12use std::future::Future;
13use std::marker::PhantomPinned;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16use std::{fmt, mem, thread};
17
/// Declares one or more task-local keys of type [`LocalKey`].
///
/// # Syntax
///
/// The macro accepts any number of `static` declarations and turns each of
/// them into a key that is local to the current task. The visibility and the
/// attributes written on each static are carried over to the generated key.
/// The semicolon after the last declaration may be omitted.
///
/// # Examples
///
/// ```
/// # use roam_task_local::task_local;
/// task_local! {
///     pub static ONE: u32;
///
///     #[allow(unused)]
///     static TWO: f32;
/// }
/// # fn main() {}
/// ```
///
/// See [`LocalKey` documentation][`LocalKey`] for more information.
#[macro_export]
macro_rules! task_local {
    // Nothing left to declare: this ends the recursion.
    () => {};

    // A final declaration written without its trailing semicolon.
    ($(#[$meta:meta])* $vis:vis static $key:ident: $ty:ty) => {
        $crate::__task_local_inner!($(#[$meta])* $vis $key, $ty);
    };

    // One declaration terminated by `;`: emit it, then recurse on what follows.
    ($(#[$meta:meta])* $vis:vis static $key:ident: $ty:ty; $($tail:tt)*) => {
        $crate::__task_local_inner!($(#[$meta])* $vis $key, $ty);
        $crate::task_local!($($tail)*);
    };
}
53
/// Implementation detail of [`task_local!`]: expands a single declaration into
/// a `static` [`LocalKey`] backed by a private thread-local slot.
///
/// Every standard-library item is named through an absolute `::std` path, so
/// the expansion does not depend on what `std`, `Option` or `None` happen to
/// resolve to in the module that invokes the macro.
#[doc(hidden)]
#[macro_export]
macro_rules! __task_local_inner {
    ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
        $(#[$attr])*
        $vis static $name: $crate::LocalKey<$t> = {
            // The slot starts out empty; `LocalKey` only swaps values in and
            // out of it at run time.
            ::std::thread_local! {
                static __KEY: ::std::cell::RefCell<::std::option::Option<$t>> =
                    const { ::std::cell::RefCell::new(::std::option::Option::None) };
            }

            $crate::LocalKey { inner: __KEY }
        };
    };
}
68
/// A handle to a piece of task-local data.
///
/// Keys of this type are produced by the [`task_local!`] macro.
///
/// In contrast to [`std::thread::LocalKey`], a `LocalKey` has no lazy
/// initializer. The key holds no value until a scope created with
/// [`scope`](LocalKey::scope) or [`sync_scope`](LocalKey::sync_scope) is
/// running; for a future, that means the value is installed each time the
/// future is polled.
///
/// # Examples
///
/// ```
/// # async fn dox() {
/// roam_task_local::task_local! {
///     static NUMBER: u32;
/// }
///
/// NUMBER.scope(1, async move {
///     assert_eq!(NUMBER.get(), 1);
/// }).await;
///
/// NUMBER.scope(2, async move {
///     assert_eq!(NUMBER.get(), 2);
///
///     NUMBER.scope(3, async move {
///         assert_eq!(NUMBER.get(), 3);
///     }).await;
/// }).await;
/// # }
/// ```
///
/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey
pub struct LocalKey<T>
where
    T: 'static,
{
    // Public only so that the `task_local!` expansion can build the key.
    #[doc(hidden)]
    pub inner: thread::LocalKey<RefCell<Option<T>>>,
}
105
106impl<T: 'static> LocalKey<T> {
107    /// Sets a value `T` as the task-local value for the future `F`.
108    ///
109    /// On completion of `scope`, the task-local will be dropped.
110    ///
111    /// ### Panics
112    ///
113    /// If you poll the returned future inside a call to [`with`] or
114    /// [`try_with`] on the same `LocalKey`, then the call to `poll` will panic.
115    ///
116    /// ### Examples
117    ///
118    /// ```
119    /// # async fn dox() {
120    /// roam_task_local::task_local! {
121    ///     static NUMBER: u32;
122    /// }
123    ///
124    /// NUMBER.scope(1, async move {
125    ///     println!("task local value: {}", NUMBER.get());
126    /// }).await;
127    /// # }
128    /// ```
129    ///
130    /// [`with`]: fn@Self::with
131    /// [`try_with`]: fn@Self::try_with
132    pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F>
133    where
134        F: Future,
135    {
136        TaskLocalFuture {
137            local: self,
138            slot: Some(value),
139            future: Some(f),
140            _pinned: PhantomPinned,
141        }
142    }
143
144    /// Sets a value `T` as the task-local value for the closure `F`.
145    ///
146    /// On completion of `sync_scope`, the task-local will be dropped.
147    ///
148    /// ### Panics
149    ///
150    /// This method panics if called inside a call to [`with`] or [`try_with`]
151    /// on the same `LocalKey`.
152    ///
153    /// ### Examples
154    ///
155    /// ```
156    /// # async fn dox() {
157    /// roam_task_local::task_local! {
158    ///     static NUMBER: u32;
159    /// }
160    ///
161    /// NUMBER.sync_scope(1, || {
162    ///     println!("task local value: {}", NUMBER.get());
163    /// });
164    /// # }
165    /// ```
166    ///
167    /// [`with`]: fn@Self::with
168    /// [`try_with`]: fn@Self::try_with
169    #[track_caller]
170    pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R
171    where
172        F: FnOnce() -> R,
173    {
174        let mut value = Some(value);
175        match self.scope_inner(&mut value, f) {
176            Ok(res) => res,
177            Err(err) => err.panic(),
178        }
179    }
180
181    fn scope_inner<F, R>(&'static self, slot: &mut Option<T>, f: F) -> Result<R, ScopeInnerErr>
182    where
183        F: FnOnce() -> R,
184    {
185        struct Guard<'a, T: 'static> {
186            local: &'static LocalKey<T>,
187            slot: &'a mut Option<T>,
188        }
189
190        impl<T: 'static> Drop for Guard<'_, T> {
191            fn drop(&mut self) {
192                // This should not panic.
193                //
194                // We know that the RefCell was not borrowed before the call to
195                // `scope_inner`, so the only way for this to panic is if the
196                // closure has created but not destroyed a RefCell guard.
197                // However, we never give user-code access to the guards, so
198                // there's no way for user-code to forget to destroy a guard.
199                //
200                // The call to `with` also should not panic, since the
201                // thread-local wasn't destroyed when we first called
202                // `scope_inner`, and it shouldn't have gotten destroyed since
203                // then.
204                self.local.inner.with(|inner| {
205                    let mut ref_mut = inner.borrow_mut();
206                    mem::swap(self.slot, &mut *ref_mut);
207                });
208            }
209        }
210
211        self.inner.try_with(|inner| {
212            inner
213                .try_borrow_mut()
214                .map(|mut ref_mut| mem::swap(slot, &mut *ref_mut))
215        })??;
216
217        let guard = Guard { local: self, slot };
218
219        let res = f();
220
221        drop(guard);
222
223        Ok(res)
224    }
225
226    /// Accesses the current task-local and runs the provided closure.
227    ///
228    /// # Panics
229    ///
230    /// This function will panic if the task local doesn't have a value set.
231    #[track_caller]
232    pub fn with<F, R>(&'static self, f: F) -> R
233    where
234        F: FnOnce(&T) -> R,
235    {
236        match self.try_with(f) {
237            Ok(res) => res,
238            Err(_) => panic!("cannot access a task-local storage value without setting it first"),
239        }
240    }
241
242    /// Accesses the current task-local and runs the provided closure.
243    ///
244    /// If the task-local with the associated key is not present, this
245    /// method will return an `AccessError`. For a panicking variant,
246    /// see `with`.
247    pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
248    where
249        F: FnOnce(&T) -> R,
250    {
251        // If called after the thread-local storing the task-local is destroyed,
252        // then we are outside of a closure where the task-local is set.
253        //
254        // Therefore, it is correct to return an AccessError if `try_with`
255        // returns an error.
256        let try_with_res = self.inner.try_with(|v| {
257            // This call to `borrow` cannot panic because no user-defined code
258            // runs while a `borrow_mut` call is active.
259            v.borrow().as_ref().map(f)
260        });
261
262        match try_with_res {
263            Ok(Some(res)) => Ok(res),
264            Ok(None) | Err(_) => Err(AccessError { _private: () }),
265        }
266    }
267}
268
269impl<T: Clone + 'static> LocalKey<T> {
270    /// Returns a copy of the task-local value
271    /// if the task-local value implements `Clone`.
272    ///
273    /// # Panics
274    ///
275    /// This function will panic if the task local doesn't have a value set.
276    #[track_caller]
277    pub fn get(&'static self) -> T {
278        self.with(|v| v.clone())
279    }
280
281    /// Returns a copy of the task-local value
282    /// if the task-local value implements `Clone`.
283    ///
284    /// If the task-local with the associated key is not present, this
285    /// method will return an [AccessError]. For a panicking variant,
286    /// see [get][Self::get].
287    #[track_caller]
288    pub fn try_get(&'static self) -> Result<T, AccessError> {
289        self.try_with(|v| v.clone())
290    }
291}
292
293impl<T: 'static> fmt::Debug for LocalKey<T> {
294    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295        f.pad("LocalKey { .. }")
296    }
297}
298
299pin_project! {
300    /// A future that sets a value `T` of a task local for the future `F` during
301    /// its execution.
302    ///
303    /// The value of the task-local must be `'static` and will be dropped on the
304    /// completion of the future.
305    ///
306    /// Created by the function [`LocalKey::scope`](self::LocalKey::scope).
307    ///
308    /// ### Examples
309    ///
310    /// ```
311    /// # async fn dox() {
312    /// roam_task_local::task_local! {
313    ///     static NUMBER: u32;
314    /// }
315    ///
316    /// NUMBER.scope(1, async move {
317    ///     println!("task local value: {}", NUMBER.get());
318    /// }).await;
319    /// # }
320    /// ```
321    pub struct TaskLocalFuture<T, F>
322    where
323        T: 'static,
324    {
325        local: &'static LocalKey<T>,
326        slot: Option<T>,
327        #[pin]
328        future: Option<F>,
329        #[pin]
330        _pinned: PhantomPinned,
331    }
332
333    impl<T: 'static, F> PinnedDrop for TaskLocalFuture<T, F> {
334        fn drop(this: Pin<&mut Self>) {
335            let this = this.project();
336            if mem::needs_drop::<F>() && this.future.is_some() {
337                // Drop the future while the task-local is set, if possible. Otherwise
338                // the future is dropped normally when the `Option<F>` field drops.
339                let mut future = this.future;
340                let _ = this.local.scope_inner(this.slot, || {
341                    future.set(None);
342                });
343            }
344        }
345    }
346}
347
348impl<T, F> TaskLocalFuture<T, F>
349where
350    T: 'static,
351{
352    /// Returns the value stored in the task local by this `TaskLocalFuture`.
353    ///
354    /// The function returns:
355    ///
356    /// * `Some(T)` if the task local value exists.
357    /// * `None` if the task local value has already been taken.
358    ///
359    /// Note that this function attempts to take the task local value even if
360    /// the future has not yet completed. In that case, the value will no longer
361    /// be available via the task local after the call to `take_value`.
362    ///
363    /// # Examples
364    ///
365    /// ```
366    /// # async fn dox() {
367    /// roam_task_local::task_local! {
368    ///     static KEY: u32;
369    /// }
370    ///
371    /// let fut = KEY.scope(42, async {
372    ///     // Do some async work
373    /// });
374    ///
375    /// let mut pinned = Box::pin(fut);
376    ///
377    /// // Complete the TaskLocalFuture
378    /// let _ = pinned.as_mut().await;
379    ///
380    /// // And here, we can take task local value
381    /// let value = pinned.as_mut().take_value();
382    ///
383    /// assert_eq!(value, Some(42));
384    /// # }
385    /// ```
386    pub fn take_value(self: Pin<&mut Self>) -> Option<T> {
387        let this = self.project();
388        this.slot.take()
389    }
390}
391
392impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
393    type Output = F::Output;
394
395    #[track_caller]
396    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
397        let this = self.project();
398        let mut future_opt = this.future;
399
400        let res = this
401            .local
402            .scope_inner(this.slot, || match future_opt.as_mut().as_pin_mut() {
403                Some(fut) => {
404                    let res = fut.poll(cx);
405                    if res.is_ready() {
406                        future_opt.set(None);
407                    }
408                    Some(res)
409                }
410                None => None,
411            });
412
413        match res {
414            Ok(Some(res)) => res,
415            Ok(None) => panic!("`TaskLocalFuture` polled after completion"),
416            Err(err) => err.panic(),
417        }
418    }
419}
420
421impl<T: 'static, F> fmt::Debug for TaskLocalFuture<T, F>
422where
423    T: fmt::Debug,
424{
425    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426        /// Format the Option without Some.
427        struct TransparentOption<'a, T> {
428            value: &'a Option<T>,
429        }
430        impl<T: fmt::Debug> fmt::Debug for TransparentOption<'_, T> {
431            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
432                match self.value.as_ref() {
433                    Some(value) => value.fmt(f),
434                    // Hitting the None branch should not be possible.
435                    None => f.pad("<missing>"),
436                }
437            }
438        }
439
440        f.debug_struct("TaskLocalFuture")
441            .field("value", &TransparentOption { value: &self.slot })
442            .finish()
443    }
444}
445
/// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with) and
/// [`LocalKey::try_get`](method@LocalKey::try_get) when the task-local has no
/// value that can be accessed.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct AccessError {
    // Private unit field: code outside this crate cannot construct the error.
    _private: (),
}
451
452impl fmt::Debug for AccessError {
453    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
454        f.debug_struct("AccessError").finish()
455    }
456}
457
458impl fmt::Display for AccessError {
459    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
460        fmt::Display::fmt("task-local value not set", f)
461    }
462}
463
// The impl body is empty, so every `Error` method keeps its default
// implementation; `AccessError` wraps no underlying error.
impl Error for AccessError {}
465
/// Reasons why `LocalKey::scope_inner` could not enter a task-local scope.
enum ScopeInnerErr {
    /// The `RefCell` behind the key was already borrowed
    /// (converted from `std::cell::BorrowMutError`).
    BorrowError,
    /// The underlying thread-local could not be accessed
    /// (converted from `std::thread::AccessError`).
    AccessError,
}
470
471impl ScopeInnerErr {
472    #[track_caller]
473    fn panic(&self) -> ! {
474        match self {
475            Self::BorrowError => {
476                panic!("cannot enter a task-local scope while the task-local storage is borrowed")
477            }
478            Self::AccessError => panic!(
479                "cannot enter a task-local scope during or after destruction of the underlying thread-local"
480            ),
481        }
482    }
483}
484
485impl From<std::cell::BorrowMutError> for ScopeInnerErr {
486    fn from(_: std::cell::BorrowMutError) -> Self {
487        Self::BorrowError
488    }
489}
490
491impl From<std::thread::AccessError> for ScopeInnerErr {
492    fn from(_: std::thread::AccessError) -> Self {
493        Self::AccessError
494    }
495}