Skip to main content

commonware_utils/
thread_local.rs

1//! RAII guard for thread-local caching of expensive-to-construct objects.
2//!
3//! # Overview
4//!
5//! When an object is expensive to construct but cheap to reset and must be
6//! used within a stateless function, keeping one instance per thread avoids
7//! repeated allocation. The manual take-then-return pattern is fragile:
8//! forgetting the return silently degrades to constructing a new instance.
9//!
10//! [`Cached`] is an RAII guard whose [`Drop`] automatically returns the
11//! value to the thread-local slot, so forgetting the return is impossible.
12//!
13//! # Synchronization
14//!
15//! This cache provides no synchronization guarantees across threads.
16//! Each thread has an independent slot.
17//!
18//! Within one thread, only one guard per cache can be held at a time.
19//! Attempting to acquire a second guard before dropping the first will panic.
20//!
21//! # Examples
22//!
23//! ```
24//! use commonware_utils::{thread_local_cache, Cached};
25//!
26//! thread_local_cache!(static POOL: String);
27//!
28//! let guard = Cached::take(&POOL, || Ok::<_, ()>(String::new()), |s| { s.clear(); Ok(()) }).unwrap();
29//! assert_eq!(&*guard, "");
30//! drop(guard);
31//!
32//! // Second take reuses the cached instance.
33//! let guard = Cached::take(&POOL, || Ok::<_, ()>(String::new()), |s| { s.clear(); Ok(()) }).unwrap();
34//! drop(guard);
35//! ```
36
37use std::{
38    cell::RefCell,
39    marker::PhantomData,
40    ops::{Deref, DerefMut},
41    thread::LocalKey,
42};
43
/// Rollback guard that `Cached::take` keeps alive between marking the slot
/// as held and successfully producing a value.
///
/// Dropping it while still armed (early return or unwind) clears the "held"
/// bit and stores whatever is left in `cached` back into the TLS slot.
/// After [`TakeCleanup::disarm`] the drop does nothing.
struct TakeCleanup<T: 'static> {
    /// Slot whose state is rolled back.
    cache: &'static LocalKey<RefCell<(bool, Option<T>)>>,
    /// Value to put back into the slot on rollback, if any.
    cached: Option<T>,
    /// Whether dropping this guard should perform the rollback.
    armed: bool,
}

impl<T: 'static> TakeCleanup<T> {
    /// Cancels the rollback; the subsequent drop becomes a no-op.
    const fn disarm(&mut self) {
        self.armed = false;
    }
}

impl<T: 'static> Drop for TakeCleanup<T> {
    fn drop(&mut self) {
        if self.armed {
            // Move the value out first so the closure can own it.
            let restored = self.cached.take();
            self.cache.with(|cell| {
                let mut state = cell.borrow_mut();
                debug_assert!(state.0, "cache expected to be held");
                // Release the slot and hand back the (possibly absent) value.
                *state = (false, restored);
            });
        }
    }
}
73
74/// RAII guard that borrows a value from a thread-local cache and returns it
75/// on drop.
76///
77/// Guards are thread-affine and must be dropped on the same thread where
78/// they were created.
79pub struct Cached<T: 'static> {
80    value: Option<T>,
81    cache: &'static LocalKey<RefCell<(bool, Option<T>)>>,
82    _not_send: PhantomData<*const ()>,
83}
84
85impl<T: 'static> Cached<T> {
86    /// Take a value from the thread-local `cache`.
87    ///
88    /// On a cache hit the `reset` closure reconfigures the existing instance.
89    /// On a miss the `create` closure constructs a new one. Both closures may
90    /// fail with `E`.
91    ///
92    /// This cache provides no synchronization guarantees.
93    /// Attempting to take a second guard from the same cache on the same
94    /// thread while one is already held will panic.
95    pub fn take<E>(
96        cache: &'static LocalKey<RefCell<(bool, Option<T>)>>,
97        create: impl FnOnce() -> Result<T, E>,
98        reset: impl FnOnce(&mut T) -> Result<(), E>,
99    ) -> Result<Self, E> {
100        let cached = cache.with(|cell| {
101            let mut slot = cell.borrow_mut();
102            assert!(!slot.0, "cache already held on this thread");
103            slot.0 = true;
104            slot.1.take()
105        });
106        let mut cleanup = TakeCleanup {
107            cache,
108            cached,
109            armed: true,
110        };
111        let value = match cleanup.cached.take() {
112            Some(mut v) => {
113                if let Err(err) = reset(&mut v) {
114                    cleanup.cached = Some(v);
115                    return Err(err);
116                }
117                v
118            }
119            None => create()?,
120        };
121        cleanup.disarm();
122        Ok(Self {
123            value: Some(value),
124            cache,
125            _not_send: PhantomData,
126        })
127    }
128}
129
130impl<T: 'static> Deref for Cached<T> {
131    type Target = T;
132
133    fn deref(&self) -> &T {
134        self.value.as_ref().expect("value taken after drop")
135    }
136}
137
138impl<T: 'static> DerefMut for Cached<T> {
139    fn deref_mut(&mut self) -> &mut T {
140        self.value.as_mut().expect("value taken after drop")
141    }
142}
143
144impl<T: 'static> Drop for Cached<T> {
145    fn drop(&mut self) {
146        if let Some(v) = self.value.take() {
147            self.cache.with(|cell| {
148                let mut slot = cell.borrow_mut();
149                debug_assert!(slot.0, "cache expected to be held");
150                slot.0 = false;
151                slot.1 = Some(v);
152            });
153        }
154    }
155}
156
/// Declare a thread-local slot for use with [`Cached`].
///
/// ```ignore
/// thread_local_cache!(static SLOT: MyType);
///
/// // Attributes (including doc comments), a visibility, and a trailing
/// // semicolon are accepted and forwarded to `thread_local!`.
/// thread_local_cache!(
///     /// Shared scratch buffer.
///     pub(crate) static SCRATCH: Vec<u8>;
/// );
/// ```
///
/// Expands to a `thread_local!` declaration wrapping
/// `RefCell<(bool, Option<MyType>)>` where:
/// - `(false, None)` means uninitialized
/// - `(false, Some(_))` means available
/// - `(true, None)` means held
#[macro_export]
macro_rules! thread_local_cache {
    ($(#[$attr:meta])* $vis:vis static $name:ident : $ty:ty $(;)?) => {
        ::std::thread_local! {
            $(#[$attr])*
            $vis static $name: ::std::cell::RefCell<(bool, ::std::option::Option<$ty>)> =
                const { ::std::cell::RefCell::new((false, ::std::option::Option::None)) };
        }
    };
}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::panic::catch_unwind;

    // One dedicated slot per test, declared up front.
    thread_local_cache!(static TEST_CACHE: Vec<u8>);
    thread_local_cache!(static REUSE_CACHE: Vec<u8>);
    thread_local_cache!(static DROP_CACHE: String);
    thread_local_cache!(static ERR_CACHE: u32);
    thread_local_cache!(static RESET_ERR_CACHE: u32);
    thread_local_cache!(static NESTED_CACHE: Vec<u8>);
    thread_local_cache!(static PANIC_CREATE_CACHE: u32);
    thread_local_cache!(static PANIC_RESET_CACHE: u32);

    /// A miss hands out exactly what `create` built.
    #[test]
    fn test_take_creates_on_miss() {
        let fresh = Cached::take(&TEST_CACHE, || Ok::<_, ()>(vec![1, 2, 3]), |_| Ok(()));
        let guard = fresh.unwrap();
        assert_eq!(&*guard, &[1, 2, 3]);
    }

    /// A hit skips `create` and runs `reset` on the recycled instance.
    #[test]
    fn test_take_reuses_on_hit() {
        let wipe = |buf: &mut Vec<u8>| -> Result<(), ()> {
            buf.clear();
            Ok(())
        };

        // Miss: build the vec, dirty it, and hand it back at scope end.
        {
            let mut first = Cached::take(&REUSE_CACHE, || Ok(vec![1, 2, 3]), wipe).unwrap();
            first.push(4);
        }

        // Hit: the recycled vec comes back emptied by `reset`.
        let second = Cached::take(&REUSE_CACHE, || Ok(vec![99]), wipe).unwrap();
        assert!(second.is_empty(), "reset should have cleared the vec");
    }

    /// Dropping the guard stores the value in the slot.
    #[test]
    fn test_drop_returns_to_cache() {
        let guard = Cached::take(
            &DROP_CACHE,
            || Ok::<_, ()>(String::from("hello")),
            |_| Ok(()),
        )
        .unwrap();
        drop(guard);

        let has_value = DROP_CACHE.with(|cell| cell.borrow().1.is_some());
        assert!(has_value, "drop should return value to cache");
    }

    /// A failing `create` surfaces its error and leaves the slot usable.
    #[test]
    fn test_create_error_propagates() {
        let failed = Cached::take(&ERR_CACHE, || Err::<u32, &str>("create failed"), |_| Ok(()));
        assert!(failed.is_err());

        // The slot must not still be marked as held.
        let guard = Cached::take(&ERR_CACHE, || Ok::<u32, &str>(7), |_| Ok(())).unwrap();
        assert_eq!(*guard, 7);
    }

    /// A failing `reset` surfaces its error and keeps the cached value.
    #[test]
    fn test_reset_error_propagates() {
        // Seed the slot with a value.
        drop(Cached::take(&RESET_ERR_CACHE, || Ok::<_, &str>(42), |_| Ok(())).unwrap());

        // Second take hits the cache, and `reset` rejects the value.
        let outcome = Cached::take(
            &RESET_ERR_CACHE,
            || Ok::<_, &str>(0),
            |_| Err("reset failed"),
        );
        assert!(outcome.is_err());

        // The seeded value is still there.
        let cached = RESET_ERR_CACHE.with(|cell| cell.borrow().1);
        assert_eq!(cached, Some(42));
    }

    /// Taking a second guard while one is held panics; the first guard still
    /// returns its value while unwinding.
    #[test]
    fn test_nested_guards_rejected() {
        NESTED_CACHE.with(|cell| *cell.borrow_mut() = (false, None));

        let attempt = catch_unwind(|| {
            let mut outer =
                Cached::take(&NESTED_CACHE, || Ok::<_, ()>(vec![1]), |_| Ok(())).unwrap();
            outer.push(10);
            let _inner = Cached::take(&NESTED_CACHE, || Ok::<_, ()>(vec![2]), |_| Ok(())).unwrap();
        });
        assert!(attempt.is_err(), "nested take on same thread should panic");

        let cached = NESTED_CACHE.with(|cell| cell.borrow().1.clone());
        assert_eq!(cached, Some(vec![1, 10]));
    }

    /// A panicking `create` must not leave the slot marked as held.
    #[test]
    fn test_create_panic_does_not_poison_held_flag() {
        let attempt = catch_unwind(|| {
            let _ = Cached::take(
                &PANIC_CREATE_CACHE,
                || -> Result<u32, ()> { panic!("create panic") },
                |_| Ok(()),
            );
        });
        assert!(attempt.is_err());

        let guard = Cached::take(&PANIC_CREATE_CACHE, || Ok::<_, ()>(7), |_| Ok(())).unwrap();
        assert_eq!(*guard, 7);
    }

    /// A panicking `reset` must not leave the slot marked as held; the next
    /// take builds a fresh value.
    #[test]
    fn test_reset_panic_does_not_poison_held_flag() {
        // Seed the slot so the next take goes through `reset`.
        drop(Cached::take(&PANIC_RESET_CACHE, || Ok::<_, ()>(42), |_| Ok(())).unwrap());

        let attempt = catch_unwind(|| {
            let _ = Cached::take(
                &PANIC_RESET_CACHE,
                || Ok::<_, ()>(0),
                |_| -> Result<(), ()> { panic!("reset panic") },
            );
        });
        assert!(attempt.is_err());

        let guard = Cached::take(&PANIC_RESET_CACHE, || Ok::<_, ()>(9), |_| Ok(())).unwrap();
        assert_eq!(*guard, 9);
    }
}