guarded_tls/lib.rs
1//! Thread-local storage with guarded scopes.
2//!
3//! This crate provides thread-local variables whose values can be temporarily
4//! overridden within a scope. Each time you call [set](GuardedKey::set), a new
5//! value is pushed onto the thread-local stack, and a [Guard] is returned.
6//! When the guard is dropped, the associated value is removed from the stack.
7//! This enables safe, nested overrides of thread-local state.
8//!
9//! # Usage
10//!
11//! Use the [guarded_thread_local] macro to define a thread-local key. Call
12//! [set](GuardedKey::set) to override the value for the current thread and
13//! receive a guard. The value is accessible via [get](GuardedKey::get) while
14//! the guard is alive.
15//!
16//! ```
17//! use guarded_tls::guarded_thread_local;
18//!
19//! guarded_thread_local!(static FOO: String);
20//!
21//! let _guard1 = FOO.set("abc".into());
22//! assert_eq!(FOO.get(), "abc");
23//!
24//! let guard2 = FOO.set("def".into());
25//! assert_eq!(FOO.get(), "def");
26//!
27//! drop(guard2);
28//! assert_eq!(FOO.get(), "abc");
29//! ```
30//!
31//! # Notes
32//!
33//! - [get](GuardedKey::get) requires the value type to implement [Clone].
34//! - Accessing the value without having a guard will panic.
35//! - Guards dropped out of order have well-defined behavior.
36//!
37//! # See Also
38//!
39//! - [scoped-tls](https://docs.rs/scoped-tls/): a similar crate for scoped
40//! thread-local values.
41//!
//! The main difference between this crate and `scoped-tls` is that this crate
//! doesn't require the nesting of functions, which makes some applications
//! easier to manage; for instance, creating a test fixture that holds a [Guard].
45//!
46//! ```
47//! guarded_tls::guarded_thread_local!(static FOO: u32);
48//!
49//! # use guarded_tls::Guard;
50//! # struct MyFixture { foo_guard: Guard<u32> }
51//! fn create_fixture() -> MyFixture {
52//! MyFixture { foo_guard: FOO.set(123) }
53//! }
54//!
55//! fn my_test() {
56//! let fixture = create_fixture();
57//!
58//! // Test code here that assumes `FOO` is set.
59//! assert_eq!(FOO.get(), 123);
60//! }
61//!
62//! my_test();
63//! ```
64use std::{cell::RefCell, thread::LocalKey};
65
/// Declares a guarded thread-local key.
///
/// The invocation mirrors a `static` declaration: optional attributes, an
/// optional visibility, the `static` keyword, a name and the value type. A
/// trailing semicolon is optional, as with [`std::thread_local!`].
///
/// The macro expands to a `static` of type [GuardedKey] that is backed by a
/// private thread-local stack. The stack starts out empty, so
/// [get](GuardedKey::get) panics until [set](GuardedKey::set) has been called
/// on the current thread.
///
/// # Examples
///
/// ```
/// guarded_tls::guarded_thread_local! {
///     /// Identifier of the request being handled.
///     static REQUEST_ID: u64;
/// }
///
/// let _guard = REQUEST_ID.set(7);
/// assert_eq!(REQUEST_ID.get(), 7);
/// ```
#[macro_export]
macro_rules! guarded_thread_local {
    ($(#[$attrs:meta])* $vis:vis static $name:ident: $ty:ty $(;)?) => {
        $(#[$attrs])*
        $vis static $name: $crate::GuardedKey<$ty> = {
            // The backing thread-local lives inside the initializer block so it
            // is unreachable from user code. Its name is deliberately obscure:
            // items are not hygienic, and a short name could shadow a constant
            // that `$ty` refers to.
            ::std::thread_local! {
                static __GUARDED_TLS_STACK: ::std::cell::RefCell<$crate::Inner<$ty>> = const {
                    ::std::cell::RefCell::new($crate::Inner::new())
                };
            }
            $crate::GuardedKey::new(&__GUARDED_TLS_STACK)
        };
    };
}
78
79/// A nested thread-local that spawns a [Guard] for each [set](GuardedKey::set).
80pub struct GuardedKey<T: 'static> {
81 inner: &'static LocalKey<RefCell<Inner<T>>>,
82}
83
84impl<T: 'static> GuardedKey<T> {
85 #[doc(hidden)]
86 pub const fn new(inner: &'static LocalKey<RefCell<Inner<T>>>) -> Self {
87 Self { inner }
88 }
89
90 /// Sets the value of this thread-local and returns a [Guard].
91 ///
92 /// After this call, [get](GuardedKey::get) will return the value that was
93 /// provided here.
94 #[must_use]
95 pub fn set(&'static self, t: T) -> Guard<T> {
96 self.inner.with_borrow_mut(move |inner| {
97 inner.item.push(Some(t));
98 Guard {
99 inner: self.inner,
100 index: inner.item.len() - 1,
101 }
102 })
103 }
104}
105
106impl<T: Clone + 'static> GuardedKey<T> {
107 /// Clones and returns the last value of thread-local stack.
108 ///
109 /// # Panics
110 ///
111 /// Panics if this thread-local has not previously been
112 /// [set](GuardedKey::set).
113 ///
114 /// Panics if the [Clone] implementation of `T` accesses this same thread
115 /// local.
116 pub fn get(&'static self) -> T {
117 let Some(val) = self.inner.with_borrow(|inner| inner.item.last().cloned()) else {
118 panic!("cannot access a guarded thread local variable without calling `set` first")
119 };
120
121 // The top of the stack cannot be None, as Guard::drop will pop from the stack
122 // until it finds a non-None entry.
123 val.expect("internal error: top of item list is none")
124 }
125}
126
/// Backing storage of a [GuardedKey]: one slot per value that has been set.
///
/// A slot holds `None` once the guard that owned it has been dropped while
/// newer slots were still in use.
#[doc(hidden)]
pub struct Inner<T>
where
    T: 'static,
{
    item: Vec<Option<T>>,
}

impl<T> Inner<T>
where
    T: 'static,
{
    /// Creates an empty stack. `const` so that it can be used in a `const`
    /// thread-local initializer.
    #[doc(hidden)]
    pub const fn new() -> Self {
        Inner { item: Vec::new() }
    }
}
138
139/// Keeps a thread local value alive. Removes its associated value from the
140/// stack upon being dropped.
141pub struct Guard<T: 'static> {
142 inner: &'static LocalKey<RefCell<Inner<T>>>,
143 index: usize,
144}
145
146impl<T> Drop for Guard<T> {
147 /// Removes associated value from the thread-local stack. If this is the
148 /// last existing guard for this thread-local, then any
149 /// subsequent [get](GuardedKey::get) will panic unless the thread-local
150 /// is [set](GuardedKey::set) again.
151 fn drop(&mut self) {
152 self.inner.with_borrow_mut(|inner| {
153 *inner.item.get_mut(self.index).unwrap() = None;
154
155 while let Some(item) = inner.item.last() {
156 if item.is_none() {
157 let _ = inner.item.pop();
158 } else {
159 break;
160 }
161 }
162 });
163 }
164}
165
#[cfg(test)]
mod tests {
    /// Nested `set` calls shadow one another, and dropping the inner guard
    /// restores the outer value.
    #[test]
    fn smoke() {
        guarded_thread_local!(static FOO: u32);
        let _foo_guard_1 = FOO.set(3);
        assert_eq!(FOO.get(), 3);
        // Reading the value does not consume it.
        assert_eq!(FOO.get(), 3);

        let foo_guard_2 = FOO.set(123);
        assert_eq!(FOO.get(), 123);

        drop(foo_guard_2);
        assert_eq!(FOO.get(), 3);
    }

    /// Reading a key that was never set panics with a descriptive message.
    #[test]
    #[should_panic(
        expected = "cannot access a guarded thread local variable without calling `set` first"
    )]
    fn get_without_set() {
        guarded_thread_local!(static FOO: u32);
        let _ = FOO.get();
    }

    /// Guards may be dropped in any order; the newest live value stays
    /// visible.
    #[test]
    fn out_of_order_guard_drop() {
        guarded_thread_local!(static FOO: u32);
        let guard_1 = FOO.set(1);
        let guard_2 = FOO.set(2);
        let guard_3 = FOO.set(3);
        assert_eq!(FOO.get(), 3);

        // Dropping the oldest guard leaves the top of the stack untouched.
        drop(guard_1);
        assert_eq!(FOO.get(), 3);

        drop(guard_3);
        assert_eq!(FOO.get(), 2);

        drop(guard_2);
    }

    /// The key works with value types that are `Clone` but not `Copy`.
    #[test]
    fn non_copy_type() {
        guarded_thread_local!(static FOO: String);
        let _guard_1 = FOO.set("x".into());
        let guard_2 = FOO.set("y".into());

        assert_eq!(FOO.get(), "y");
        drop(guard_2);
        assert_eq!(FOO.get(), "x");
    }

    /// A `Clone` implementation that re-enters the same key panics, as
    /// documented on `get`.
    // Only the "already borrowed" part of the message is matched: the exact
    // wording of `RefCell`'s borrow panic is an implementation detail of the
    // standard library.
    #[test]
    #[should_panic(expected = "already borrowed")]
    fn clone_access_same_thread_local() {
        guarded_thread_local!(static FOO: X);

        struct X;

        impl Clone for X {
            fn clone(&self) -> Self {
                let _ = FOO.set(X);
                X
            }
        }

        let _guard = FOO.set(X);
        let _ = FOO.get();
    }
}