dropping_thread_local/
lib.rs

#![deny(
    // currently there is no unsafe code
    unsafe_code,
    // public library should have docs
    missing_docs,
)]
//! Dynamically allocated thread locals that properly run destructors when a thread is destroyed.
//!
//! This is in contrast to the [`thread_local`] crate, which has similar functionality,
//! but only runs destructors when the `ThreadLocal` object is dropped.
//! This crate guarantees that one thread will never see the thread-local data of another,
//! which can happen in the `thread_local` crate due to internal storage reuse.
//!
//! This crate attempts to implement "true" thread locals,
//! mirroring [`std::thread_local!`] as closely as possible.
//! I would say the `thread_local` crate is good for functionality like reusing allocations
//! or for having local caches that can be sensibly reused once a thread dies.
//!
//! This crate will attempt to run destructors as promptly as possible,
//! but taking snapshots may interfere with this (see below).
//! Panics in thread destructors will cause aborts, just like they do with [`std::thread_local!`].
//!
//! Right now, this crate has no unsafe code.
//! This may change if it can bring a significant performance improvement.
//!
//! # Snapshots
//! The most complicated feature of this library is snapshots.
//! Snapshots allow anyone with access to a [`DroppingThreadLocal`] to iterate over all currently
//! live values using the [`DroppingThreadLocal::snapshot_iter`] method.
//!
//! This returns a snapshot of the live values at the time the method is called,
//! although if a thread dies during iteration, its value may not show up.
//! See the method documentation for more details.
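//!
//! A minimal doc-test sketch of taking a snapshot (a hedged illustration using only the API
//! defined in this module; with a single thread, the snapshot contains just that thread's value):
//! ```rust
//! use dropping_thread_local::DroppingThreadLocal;
//!
//! let local: DroppingThreadLocal<u32> = DroppingThreadLocal::new();
//! local.get_or_init(|| 7);
//! // The snapshot only contains threads that were live (and initialized) when it was taken.
//! let values: Vec<u32> = local.snapshot_iter().map(|value| *value).collect();
//! assert_eq!(values, vec![7]);
//! ```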
//!
//! # Performance
//! Lookup is based around a hashmap, and I expect the current implementation to be noticeably
//! slower than either [`std::thread_local!`] or the [`thread_local`] crate.
//! The former is a low-cost abstraction over native thread-local storage,
//! and the latter is written by the author of `hashbrown` and `parking_lot`.
//!
//! A very basic benchmark on my M1 Mac and Linux laptop (`i5-7200U` circa 2017) gives the following results:
//!
//! | library                 | does `Arc::clone` | time (M1 Mac) | time (i5 ~2017 laptop) |
//! |-------------------------|-------------------|---------------|------------------------|
//! | `std`                   | no                |  0.42 ns      |  0.69 ns               |
//! | `std`                   | *yes*             | 11.49 ns      | 14.01 ns               |
//! | `thread_local`          | no                |  1.38 ns      |  1.38 ns               |
//! | `thread_local`          | *yes*             | 11.43 ns      | 14.02 ns               |
//! | `dropping_thread_local` | *yes*             | 13.14 ns      | 31.14 ns               |
//!
//! Every lookup in the current implementation of `dropping_thread_local` requires calling `Arc::clone`.
//! This has significant overhead in its own right, so I benchmarked the other libraries both storing
//! their data in a regular `Box` and storing it in an `Arc` and doing `Arc::clone`.
//!
//! On my Mac, this library ranges between 30% slower than `thread_local::ThreadLocal::get` + `Arc::clone`
//! and 30x slower than a plain `std::thread_local!`.
//! On my older Linux laptop, it ranges between 3x slower than `thread_local::ThreadLocal::get` + `Arc::clone`
//! and 60x slower than a plain `std::thread_local!`.
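//!
//! The two access patterns in the table can be sketched with `std` alone
//! (a hedged illustration of what is being measured, not the benchmark harness itself):
//! ```rust
//! use std::sync::Arc;
//!
//! thread_local! {
//!     // "does `Arc::clone`" = no: the value is only borrowed in place.
//!     static PLAIN: u64 = 42;
//!     // "does `Arc::clone`" = yes: an owned handle is cloned out on every access.
//!     static SHARED: Arc<u64> = Arc::new(42);
//! }
//!
//! let plain = PLAIN.with(|v| *v);
//! let shared = SHARED.with(Arc::clone);
//! assert_eq!(plain, *shared);
//! ```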
//!
//! This performance is a lot better than I expected (at least on the MacBook).
//! I am also disappointed by the performance of `Arc::clone`.
//! Further improvements beyond this will almost certainly require some amount of `unsafe` code.
//! I have three ideas for improvement:
//!
//! - Avoid requiring `Arc::clone` by using a [`LocalKey::with`] style API, and making `drop(DroppingThreadLocal)` delay freeing values from live threads until after that live thread dies.
//! - Use [biased reference counting] instead of an `Arc`. This would not require `unsafe` code directly in this crate. Unfortunately, biased reference counting can delay destruction or even leak if the heartbeat function is not called. The [`trc` crate] will not work, as `trc::Trc` is `!Send`.
//! - Use [`boxcar::Vec`] instead of a `HashMap` for lookup. This is essentially the same data structure that the `thread_local` crate uses, so should make the lookup performance similar.
//!
//! [biased reference counting]: https://dl.acm.org/doi/10.1145/3243176.3243195
//! [`LocalKey::with`]: https://doc.rust-lang.org/std/thread/struct.LocalKey.html#method.with
//! [`boxcar::Vec`]: https://docs.rs/boxcar/0.2.13/boxcar/struct.Vec.html
//! [`trc` crate]: https://github.com/ericlbuehler/trc
//!
//! ## Locking
//! The implementation needs to acquire a global lock to initialize/deinitialize threads and create new locals.
//! Accessing thread-local data is also protected by a per-thread lock.
//! This lock should be uncontended, and [`parking_lot::Mutex`] should make this relatively fast.
//! I have been careful to make sure that locks are not held while user code is being executed.
//! This includes releasing locks before any destructors are executed.
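//!
//! The pattern used for this is to take values out while holding the lock,
//! but let their destructors run only after the guard is released
//! (a hedged sketch with `std::sync::Mutex` standing in for [`parking_lot::Mutex`]):
//! ```rust
//! use std::collections::HashMap;
//! use std::sync::Mutex;
//!
//! let values: Mutex<HashMap<u32, String>> = Mutex::new(HashMap::from([(1, "a".to_string())]));
//! // Remove the entry under the lock, but bind it to a variable that outlives the guard.
//! let removed = {
//!     let mut guard = values.lock().unwrap();
//!     guard.remove(&1)
//! }; // the guard is dropped here, releasing the lock
//! drop(removed); // the value's destructor runs without the lock held
//! assert!(values.lock().unwrap().is_empty());
//! ```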
//!
//! # Limitations
//! The type that is stored must be `Send + Sync + 'static`.
//! The `Send` bound is necessary because the [`DroppingThreadLocal`] may be dropped from any thread.
//! The `Sync` bound is necessary to support snapshots,
//! and the `'static` bound is due to internal implementation choices (use of safe code).
//!
//! A `Mutex` can be used to work around the `Sync` limitation.
//! (I recommend [`parking_lot::Mutex`], which is optimized for uncontended locks.)
//! You can attempt to use the [`fragile`] crate to work around the `Send` limitation,
//! but this will cause panics if the value is dropped from another thread.
//! A value can be dropped from another thread if a snapshot keeps it alive,
//! or if the [`DroppingThreadLocal`] itself is dropped.
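//!
//! For example, `Cell<u32>` is `Send` but not `Sync`, so it cannot be stored directly;
//! wrapping it in a `Mutex` satisfies the bounds
//! (a hedged sketch using `std::sync::Mutex` in place of [`parking_lot::Mutex`]):
//! ```rust
//! use std::cell::Cell;
//! use std::sync::Mutex;
//!
//! // `Mutex<T>` is `Sync` whenever `T` is `Send`,
//! // so `Mutex<Cell<u32>>` meets the `Send + Sync + 'static` requirement.
//! let value: Mutex<Cell<u32>> = Mutex::new(Cell::new(1));
//! value.lock().unwrap().set(2);
//! assert_eq!(value.lock().unwrap().get(), 2);
//! ```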
//!
//! [`thread_local`]: https://docs.rs/thread_local/1.1/thread_local/
//! [`fragile`]: https://docs.rs/fragile/2/fragile/

extern crate alloc;
extern crate core;

use alloc::rc::Rc;
use alloc::sync::{Arc, Weak};
use core::any::Any;
use core::fmt::{Debug, Formatter};
use core::hash::{Hash, Hasher};
use core::marker::PhantomData;
use core::num::NonZero;
use core::ops::Deref;
use core::sync::atomic::Ordering;
use std::thread::ThreadId;

use parking_lot::Mutex;
use portable_atomic::AtomicU64;
/// A thread local that drops its value when the thread is destroyed.
///
/// See module-level documentation for more details.
///
/// Dropping this value will free all the associated values.
pub struct DroppingThreadLocal<T: Send + Sync + 'static> {
    id: UniqueLocalId,
    marker: PhantomData<Arc<T>>,
}
impl<T: Send + Sync + 'static + Debug> Debug for DroppingThreadLocal<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let value = self.get();
        f.debug_struct("DroppingThreadLocal")
            .field("local_data", &value.as_ref().map(|value| value.as_ref()))
            .finish()
    }
}
impl<T: Send + Sync + 'static> Default for DroppingThreadLocal<T> {
    #[inline]
    fn default() -> Self {
        DroppingThreadLocal::new()
    }
}
impl<T: Send + Sync + 'static> DroppingThreadLocal<T> {
    /// Create a new thread-local value.
    #[inline]
    pub fn new() -> Self {
        DroppingThreadLocal {
            id: UniqueLocalId::alloc(),
            marker: PhantomData,
        }
    }
    /// Get the value associated with the current thread,
    /// or `None` if not initialized.
    #[inline]
    pub fn get(&self) -> Option<SharedRef<T>> {
        THREAD_STATE.with(|thread| {
            Some(SharedRef {
                thread_id: thread.id,
                value: thread.get(self.id)?.downcast::<T>().expect("unexpected type"),
            })
        })
    }
    /// Get the value associated with the current thread,
    /// initializing it if not yet defined.
    ///
    /// Panics if double initialization is detected.
    pub fn get_or_init(&self, func: impl FnOnce() -> T) -> SharedRef<T> {
        match self.get_or_try_init::<core::convert::Infallible>(|| Ok(func())) {
            Ok(success) => success,
        }
    }

    /// Get the value associated with the current thread,
    /// attempting to initialize it if not yet defined.
    ///
    /// Panics if double initialization is detected.
    pub fn get_or_try_init<E>(&self, func: impl FnOnce() -> Result<T, E>) -> Result<SharedRef<T>, E> {
        THREAD_STATE.with(|thread| {
            let value = match thread.get(self.id) {
                Some(existing) => existing,
                None => {
                    let new_value = Arc::new(func()?) as DynArc;
                    thread.init(self.id, &new_value);
                    new_value
                }
            };
            let value = value.downcast::<T>().expect("unexpected type");
            Ok(SharedRef {
                thread_id: thread.id,
                value,
            })
        })
    }
    /// Iterate over currently live values and their associated thread ids.
    ///
    /// New threads that have been spawned after the snapshot was taken will not be present
    /// in the iterator.
    /// Threads that die after the snapshot is taken may or may not be present.
    /// Values from threads that die before the snapshot will not be present.
    ///
    /// The order of the iteration is undefined.
    pub fn snapshot_iter(&self) -> SnapshotIter<T> {
        let Some(snapshot) = snapshot_live_threads() else {
            return SnapshotIter {
                local_id: self.id,
                iter: None,
                marker: PhantomData,
            };
        };
        SnapshotIter {
            local_id: self.id,
            iter: Some(snapshot.into_iter()),
            marker: PhantomData,
        }
    }
}
impl<T: Send + Sync + 'static> Drop for DroppingThreadLocal<T> {
    fn drop(&mut self) {
        // we want to drop without holding the lock
        let Some(snapshot) = snapshot_live_threads() else {
            // no live threads -> nothing to free
            return;
        };
        // panics here won't cause aborts; there is no need for them
        for (thread_id, thread) in snapshot {
            if let Some(thread) = Weak::upgrade(&thread) {
                assert_eq!(thread.id, thread_id);
                let value: Option<DynArc> = {
                    let mut lock = thread.values.lock();
                    lock.remove(&self.id)
                };
                // drop the value once the lock is no longer held
                drop(value);
            }
        }
    }
}
/// Iterates over a snapshot of the values,
/// given by [`DroppingThreadLocal::snapshot_iter`].
///
/// Due to thread death, it is not possible to know the exact size of the iterator.
pub struct SnapshotIter<T: Send + Sync + 'static> {
    local_id: UniqueLocalId,
    iter: Option<imbl::hashmap::ConsumingIter<(ThreadId, Weak<LiveThreadState>), imbl::shared_ptr::DefaultSharedPtr>>,
    // do not make this Send + Sync, for flexibility in the future
    marker: PhantomData<Rc<T>>,
}
impl<T: Send + Sync + 'static> Iterator for SnapshotIter<T> {
    type Item = SharedRef<T>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // `None` from either of these means we have nothing left to iterate
            let (thread_id, thread) = self.iter.as_mut()?.next()?;
            let Some(thread) = Weak::upgrade(&thread) else { continue };
            let Some(arc) = ({
                let lock = thread.values.lock();
                lock.get(&self.local_id).cloned()
            }) else {
                continue;
            };
            return Some(SharedRef {
                thread_id,
                value: arc.downcast::<T>().expect("mismatched type"),
            });
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self.iter {
            Some(ref iter) => {
                // may be zero if all threads die
                (0, Some(iter.len()))
            }
            None => (0, Some(0)),
        }
    }
}
impl<T: Send + Sync + 'static> core::iter::FusedIterator for SnapshotIter<T> {}

/// A shared reference to a thread-local value.
///
/// This may be cloned and sent across threads.
/// It may delay destruction of the value past thread death.
#[derive(Clone, Debug)]
pub struct SharedRef<T> {
    thread_id: ThreadId,
    value: Arc<T>,
}
impl<T> SharedRef<T> {
    /// The id of the thread this value was created on.
    #[inline]
    pub fn thread_id(this: &Self) -> ThreadId {
        this.thread_id
    }
}

impl<T> Deref for SharedRef<T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}
impl<T> AsRef<T> for SharedRef<T> {
    #[inline]
    fn as_ref(&self) -> &T {
        &self.value
    }
}
impl<T: Hash> Hash for SharedRef<T> {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.value.hash(state);
    }
}
impl<T: Eq> Eq for SharedRef<T> {}
impl<T: PartialEq> PartialEq for SharedRef<T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
    }
}
impl<T: PartialOrd> PartialOrd for SharedRef<T> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        T::partial_cmp(&self.value, &other.value)
    }
}
impl<T: Ord> Ord for SharedRef<T> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        T::cmp(&self.value, &other.value)
    }
}

struct UniqueIdAllocator {
    next_id: AtomicU64,
}
impl UniqueIdAllocator {
    const fn new() -> Self {
        UniqueIdAllocator {
            next_id: AtomicU64::new(1),
        }
    }
    fn alloc(&self) -> NonZero<u64> {
        NonZero::new(
            self.next_id
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |x| x.checked_add(1))
                .expect("id overflow"),
        )
        .unwrap()
    }
}
#[derive(Copy, Clone, Debug, Eq, PartialOrd, PartialEq, Hash)]
struct UniqueLocalId(NonZero<u64>);
impl UniqueLocalId {
    fn alloc() -> Self {
        static ALLOCATOR: UniqueIdAllocator = UniqueIdAllocator::new();
        UniqueLocalId(ALLOCATOR.alloc())
    }
}

type LiveThreadMap = imbl::GenericHashMap<
    ThreadId,
    Weak<LiveThreadState>,
    foldhash::fast::RandomState,
    imbl::shared_ptr::DefaultSharedPtr,
>;
/// Map of currently live threads.
///
/// This is a persistent map, which allows quick snapshots to be taken by the drop function and by iteration.
/// I use `imbl` instead of `rpds`, because `rpds` doesn't support an owned iterator,
/// only a borrowed iterator, which would require an extra allocation.
/// I use a hashmap instead of a `BTreeMap` because the latter would require `ThreadId: Ord`,
/// which the stdlib doesn't implement.
static LIVE_THREADS: Mutex<Option<LiveThreadMap>> = Mutex::new(None);
fn snapshot_live_threads() -> Option<LiveThreadMap> {
    let lock = LIVE_THREADS.lock();
    lock.as_ref().cloned()
}

thread_local! {
    static THREAD_STATE: Arc<LiveThreadState> = {
        let id = std::thread::current().id();
        let state = Arc::new(LiveThreadState {
            id,
            values: Mutex::new(foldhash::HashMap::default()),
        });
        let mut live_threads = LIVE_THREADS.lock();
        let live_threads = live_threads.get_or_insert_default();
        use imbl::hashmap::Entry;
        match live_threads.entry(id) {
            Entry::Occupied(_) => panic!("reinitialized thread"),
            Entry::Vacant(entry) => {
                entry.insert(Arc::downgrade(&state));
            }
        }
        state
    };
}
type DynArc = Arc<dyn Any + Send + Sync + 'static>;
struct LiveThreadState {
    id: ThreadId,
    /// Maps from local ids to values.
    ///
    /// ## Performance
    /// This is noticeably slower than what `thread_local` offers.
    ///
    /// We could make it faster if we used a vector.
    /// A `boxcar::Vec` is essentially the same data structure as `thread_local` uses,
    /// and would allow us to get rid of the lock.
    /// The only difference is we would be mapping from local ids -> values.
    /// However, the lock is uncontended and `parking_lot` makes that relatively cheap,
    /// so a `Mutex<Vec<T>>` might also work and be simpler.
    /// To avoid unbounded memory usage if locals are constantly being allocated/dropped,
    /// this would require reusing indexes.
    values: Mutex<foldhash::HashMap<UniqueLocalId, DynArc>>,
}
impl LiveThreadState {
    #[inline]
    fn get(&self, id: UniqueLocalId) -> Option<DynArc> {
        let lock = self.values.lock();
        Some(Arc::clone(lock.get(&id)?))
    }
    // the arc has a dynamic type to avoid monomorphization
    #[cold]
    #[inline(never)]
    fn init(&self, id: UniqueLocalId, new_value: &DynArc) {
        let mut lock = self.values.lock();
        use std::collections::hash_map::Entry;
        match lock.entry(id) {
            Entry::Occupied(_) => {
                panic!("unexpected double initialization of thread-local value")
            }
            Entry::Vacant(entry) => {
                entry.insert(Arc::clone(new_value));
            }
        }
    }
}
impl Drop for LiveThreadState {
    fn drop(&mut self) {
        // clear all our values
        self.values.get_mut().clear();
        // remove ourselves from the list of live threads
        {
            let mut threads = LIVE_THREADS.lock();
            if let Some(threads) = threads.as_mut() {
                // no fear of dropping while locked, because we control the `Weak` type
                let state: Option<Weak<LiveThreadState>> = threads.remove(&self.id);
                drop(state);
            }
        }
    }
}