guarded_tls/
lib.rs

1//! Thread-local storage with guarded scopes.
2//!
3//! This crate provides thread-local variables whose values can be temporarily
4//! overridden within a scope. Each time you call [set](GuardedKey::set), a new
5//! value is pushed onto the thread-local stack, and a [Guard] is returned.
6//! When the guard is dropped, the associated value is removed from the stack.
7//! This enables safe, nested overrides of thread-local state.
8//!
9//! # Usage
10//!
11//! Use the [guarded_thread_local] macro to define a thread-local key. Call
12//! [set](GuardedKey::set) to override the value for the current thread and
13//! receive a guard. The value is accessible via [get](GuardedKey::get) while
14//! the guard is alive.
15//!
16//! ```
17//! use guarded_tls::guarded_thread_local;
18//!
19//! guarded_thread_local!(static FOO: String);
20//!
21//! let _guard1 = FOO.set("abc".into());
22//! assert_eq!(FOO.get(), "abc");
23//!
24//! let guard2 = FOO.set("def".into());
25//! assert_eq!(FOO.get(), "def");
26//!
27//! drop(guard2);
28//! assert_eq!(FOO.get(), "abc");
29//! ```
30//!
31//! # Notes
32//!
33//! - [get](GuardedKey::get) requires the value type to implement [Clone].
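//! - Calling [get](GuardedKey::get) on a thread where no guard for the key is
//!   alive panics.
//! - Guards may be dropped in any order: dropping a guard removes only its own
//!   value, and [get](GuardedKey::get) returns the value of the most recently
//!   created guard that is still alive.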
36//!
37//! # See Also
38//!
39//! - [scoped-tls](https://docs.rs/scoped-tls/): a similar crate for scoped
40//!   thread-local values.
41//!
//! The main difference between this crate and `scoped-tls` is that this crate
//! does not require the code that uses the value to be nested inside a closure,
//! which makes some use cases easier to manage. For instance, a test fixture can
//! hold a [Guard] so that the value stays set for as long as the fixture lives.
45//!
46//! ```
47//! guarded_tls::guarded_thread_local!(static FOO: u32);
48//!
49//! # use guarded_tls::Guard;
50//! # struct MyFixture { foo_guard: Guard<u32> }
51//! fn create_fixture() -> MyFixture {
52//!     MyFixture { foo_guard: FOO.set(123) }
53//! }
54//!
55//! fn my_test() {
//!     let _fixture = create_fixture(); // holds the guard until the test ends
57//!
58//!     // Test code here that assumes `FOO` is set.
59//!     assert_eq!(FOO.get(), 123);
60//! }
61//!
62//! my_test();
63//! ```
64use std::{cell::RefCell, thread::LocalKey};
65
/// Declares one or more guarded thread-local keys of type [GuardedKey].
///
/// Each declaration has the form `[attributes] [visibility] static NAME: Type`.
/// Several declarations may be separated by `;`, and a trailing `;` is
/// accepted, mirroring [std::thread_local].
///
/// ```
/// guarded_tls::guarded_thread_local! {
///     pub static NAME: String;
///     static COUNT: u32;
/// }
///
/// let _name = NAME.set("abc".into());
/// let _count = COUNT.set(1);
/// assert_eq!(NAME.get(), "abc");
/// assert_eq!(COUNT.get(), 1);
/// ```
#[macro_export]
macro_rules! guarded_thread_local {
    // Base case of the recursion below: nothing left to declare.
    () => {};

    // One declaration terminated by `;`, possibly followed by more.
    ($(#[$attrs:meta])* $vis:vis static $name:ident: $ty:ty; $($rest:tt)*) => (
        $crate::guarded_thread_local!($(#[$attrs])* $vis static $name: $ty);
        $crate::guarded_thread_local!($($rest)*);
    );

    // A single declaration without a trailing `;` (the original form).
    ($(#[$attrs:meta])* $vis:vis static $name:ident: $ty:ty) => (
        $(#[$attrs])*
        $vis static $name: $crate::GuardedKey<$ty> = {
            // Block-scoped backing storage; it is only reachable through `$name`.
            ::std::thread_local!(
                static __GUARDED_TLS_KEY: ::std::cell::RefCell<$crate::Inner<$ty>> = const {
                    ::std::cell::RefCell::new($crate::Inner::new())
                }
            );
            $crate::GuardedKey::new(&__GUARDED_TLS_KEY)
        };
    );
}
78
/// A nested thread-local that spawns a [Guard] for each [set](GuardedKey::set).
///
/// Every thread has its own stack of values for a key, and
/// [get](GuardedKey::get) returns the value of the most recently created
/// [Guard] that is still alive on the current thread. Keys are declared with
/// the [guarded_thread_local] macro.
pub struct GuardedKey<T: 'static> {
    inner: &'static LocalKey<RefCell<Inner<T>>>,
}

impl<T: 'static> std::fmt::Debug for GuardedKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The stored values are per-thread and need not be `Debug`, so only the
        // type itself is reported.
        f.debug_struct("GuardedKey").finish_non_exhaustive()
    }
}

impl<T: 'static> GuardedKey<T> {
    #[doc(hidden)]
    pub const fn new(inner: &'static LocalKey<RefCell<Inner<T>>>) -> Self {
        Self { inner }
    }

    /// Sets the value of this thread-local and returns a [Guard].
    ///
    /// The value is pushed onto the current thread's stack for this key. Until
    /// the returned guard is dropped, or a newer value is set,
    /// [get](GuardedKey::get) on this thread returns the value provided here.
    ///
    /// # Panics
    ///
    /// Panics if the stack for this key is currently borrowed on this thread,
    /// e.g. when called from the [Clone] implementation of `T` during
    /// [get](GuardedKey::get). Like [LocalKey::with_borrow_mut], it also panics
    /// when called while the key's thread-local storage is being destroyed.
    #[must_use = "the value is removed again as soon as the returned guard is dropped"]
    pub fn set(&'static self, t: T) -> Guard<T> {
        let index = self.inner.with_borrow_mut(|inner| {
            inner.stack.push(Some(t));
            inner.stack.len() - 1
        });
        Guard {
            inner: self.inner,
            index,
            _not_send: std::marker::PhantomData,
        }
    }
}

impl<T: Clone + 'static> GuardedKey<T> {
    /// Clones and returns the value at the top of this thread's stack, i.e. the
    /// value of the most recently created guard that is still alive.
    ///
    /// # Panics
    ///
    /// Panics if no guard for this key is alive on the current thread (it was
    /// never [set](GuardedKey::set), or every guard has been dropped).
    ///
    /// Panics if the [Clone] implementation of `T` calls
    /// [set](GuardedKey::set) on this same key or drops one of its guards,
    /// because the stack stays borrowed while the value is cloned. Reading the
    /// key with `get` from inside `clone` is fine.
    pub fn get(&'static self) -> T {
        let Some(top) = self.inner.with_borrow(|inner| inner.stack.last().cloned()) else {
            panic!("cannot access a guarded thread local variable without calling `set` first")
        };

        // The top slot is never `None`: dropping a guard trims every empty slot
        // that ends up at the top of the stack.
        top.expect("internal error: top of item list is none")
    }
}

#[doc(hidden)]
pub struct Inner<T: 'static> {
    // One slot per guard, in creation order. A slot becomes `None` when its
    // guard is dropped while a newer guard is still alive; empty slots are
    // trimmed as soon as they reach the top of the stack.
    stack: Vec<Option<T>>,
}

impl<T: 'static> Inner<T> {
    #[doc(hidden)]
    pub const fn new() -> Self {
        Self { stack: Vec::new() }
    }

    /// Empties slot `index`, trims the empty slots left at the top of the
    /// stack, and returns the value that was in the slot (if any).
    fn remove(&mut self, index: usize) -> Option<T> {
        let removed = self.stack.get_mut(index).and_then(Option::take);
        let len = self
            .stack
            .iter()
            .rposition(Option::is_some)
            .map_or(0, |top| top + 1);
        self.stack.truncate(len);
        removed
    }
}

impl<T: 'static> Default for Inner<T> {
    fn default() -> Self {
        Self::new()
    }
}

/// Keeps a thread local value alive. Removes its associated value from the
/// stack upon being dropped.
///
/// Guards may be dropped in any order; dropping one removes only its own value.
/// A guard belongs to the thread that created it and is neither [Send] nor
/// [Sync]. Dropping it on another thread would otherwise edit that thread's
/// stack instead of the creator's.
pub struct Guard<T: 'static> {
    inner: &'static LocalKey<RefCell<Inner<T>>>,
    index: usize,
    // `*const ()` is neither `Send` nor `Sync`, which pins the guard to the
    // thread whose stack slot `index` refers to.
    _not_send: std::marker::PhantomData<*const ()>,
}

impl<T: 'static> std::fmt::Debug for Guard<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Guard")
            .field("index", &self.index)
            .finish_non_exhaustive()
    }
}

impl<T: 'static> Drop for Guard<T> {
    /// Removes associated value from the thread-local stack. If this is the
    /// last existing guard for this thread-local, then any
    /// subsequent [get](GuardedKey::get) will panic unless the thread-local
    /// is [set](GuardedKey::set) again.
    fn drop(&mut self) {
        let index = self.index;
        // `try_with` instead of `with_borrow_mut`: if the key's storage has
        // already been destroyed (possible when the guard itself lives in
        // another thread local), there is nothing left to remove, and panicking
        // inside `drop` is best avoided.
        let _ = self.inner.try_with(|cell| {
            // The `RefMut` temporary is released at the end of this statement,
            // so the removed value is dropped below without the stack being
            // borrowed, and its destructor may use this key again.
            let removed = cell.borrow_mut().remove(index);
            drop(removed);
        });
    }
}
165
#[cfg(test)]
mod tests {
    /// Setting a value, overriding it, and releasing the override.
    #[test]
    fn smoke() {
        guarded_thread_local!(static FOO: u32);

        let _base = FOO.set(3);
        // Repeated reads must keep returning the same value.
        for _ in 0..2 {
            assert_eq!(FOO.get(), 3);
        }

        let overriding = FOO.set(123);
        assert_eq!(FOO.get(), 123);

        drop(overriding);
        assert_eq!(FOO.get(), 3);
    }

    /// Reading a key that was never set is a usage error.
    #[test]
    #[should_panic(
        expected = "cannot access a guarded thread local variable without calling `set` first"
    )]
    fn get_without_set() {
        guarded_thread_local!(static FOO: u32);
        FOO.get();
    }

    /// Dropping guards out of creation order removes only the dropped guard's value.
    #[test]
    fn out_of_order_guard_drop() {
        guarded_thread_local!(static FOO: u32);

        let [oldest, middle, newest] = [1, 2, 3].map(|value| FOO.set(value));
        assert_eq!(FOO.get(), 3);

        drop(oldest);
        assert_eq!(FOO.get(), 3);

        drop(newest);
        assert_eq!(FOO.get(), 2);

        drop(middle);
    }

    /// Values that are not `Copy` are cloned out by `get`.
    #[test]
    fn non_copy_type() {
        guarded_thread_local!(static FOO: String);

        let _outer = FOO.set(String::from("x"));
        let inner = FOO.set(String::from("y"));

        assert_eq!(FOO.get(), "y");
        drop(inner);
        assert_eq!(FOO.get(), "x");
    }

    /// `get` keeps the stack borrowed while cloning, so a `Clone` impl that
    /// calls `set` on the same key trips the `RefCell` borrow check.
    #[test]
    #[should_panic(expected = "already borrowed: BorrowMutError")]
    fn clone_access_same_thread_local() {
        struct Reentrant;

        impl Clone for Reentrant {
            fn clone(&self) -> Self {
                let _ = FOO.set(Reentrant);
                Reentrant
            }
        }

        guarded_thread_local!(static FOO: Reentrant);

        let _guard = FOO.set(Reentrant);
        let _ = FOO.get();
    }
}