dropping_thread_local/
lib.rs

#![deny(
    // currently there is no unsafe code
    unsafe_code,
    // a public library should have docs
    missing_docs,
)]
//! Dynamically allocated thread locals that properly run destructors when a thread is destroyed.
//!
//! This is in contrast to the [`thread_local`] crate, which has similar functionality,
//! but only runs destructors when the `ThreadLocal` object is dropped.
//! This crate guarantees that one thread will never see the thread-local data of another,
//! which can happen in the `thread_local` crate due to internal storage reuse.
//!
//! This crate attempts to implement "true" thread locals,
//! mirroring [`std::thread_local!`] as closely as possible.
//! I would say the `thread_local` crate is good for functionality like reusing allocations
//! or for having local caches that can be sensibly reused once a thread dies.
//!
//! This crate will attempt to run destructors as promptly as possible,
//! but taking snapshots may interfere with this (see below).
//! Panics in thread destructors will cause aborts, just like they do with [`std::thread_local!`].
//!
//! Right now, this crate has no unsafe code.
//! This may change if it can bring a significant performance improvement.
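//!
//! # Example
//! A minimal usage sketch (a doc-test assuming this crate is named `dropping_thread_local`,
//! per the listing header):
//! ```rust
//! use dropping_thread_local::DroppingThreadLocal;
//! use std::sync::Arc;
//!
//! let local = Arc::new(DroppingThreadLocal::new());
//! let value = local.get_or_init(|| 1_u32);
//! assert_eq!(*value, 1);
//!
//! // each thread gets its own value
//! let local2 = Arc::clone(&local);
//! std::thread::spawn(move || {
//!     let value = local2.get_or_init(|| 2_u32);
//!     assert_eq!(*value, 2);
//!     // this thread's value is dropped promptly when the thread exits
//! })
//! .join()
//! .unwrap();
//! ```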
//!
//! # Snapshots
//! The most complicated feature of this library is snapshots.
//! They allow anyone with access to a [`DroppingThreadLocal`] to iterate over all currently live
//! values using the [`DroppingThreadLocal::snapshot_iter`] method.
//!
//! This returns a snapshot of the live values at the time the method is called,
//! although if a thread dies during iteration, its value may not show up.
//! See the method documentation for more details.
//!
//! # Performance
//! Lookup is based around a hashmap, and I expect the current implementation to be noticeably slower than either [`std::thread_local!`] or the [`thread_local`] crate.
//! The former is a low-cost abstraction over native thread-local storage and the latter is written by the author of `hashbrown` and `parking_lot`.
//!
//! A very basic benchmark on my M1 Mac and Linux laptop (`i5-7200U` circa 2017) gives the following results:
//!
//! | library                 | does `Arc::clone` | time (M1 Mac) | time (i5 ~2017 laptop)    |
//! |-------------------------|-------------------|---------------|---------------------------|
//! | `std`                   | no                |  0.42 ns      |  0.69 ns                  |
//! | `std`                   | *yes*             | 11.49 ns      | 14.01 ns                  |
//! | `thread_local`          | no                |  1.38 ns      |  1.38 ns                  |
//! | `thread_local`          | *yes*             | 11.43 ns      | 14.02 ns                  |
//! | `dropping_thread_local` | *yes*             | 13.14 ns      | 31.14 ns                  |
//!
//! Every lookup in the current implementation of `dropping_thread_local` requires calling `Arc::clone`.
//! This has significant overhead in its own right, so I benchmarked the other libraries both storing their data in a regular `Box` and storing it in an `Arc` and calling `Arc::clone`.
//!
//! On my Mac, this library ranges between 30% slower than calling `thread_local::ThreadLocal::get` + `Arc::clone` and 30x slower than a plain `std::thread_local!`. On my older Linux laptop, it ranges between 3x slower than `thread_local::ThreadLocal::get` + `Arc::clone` and 60x slower than a plain `std::thread_local!`.
//!
//! This performance is a lot better than I expected (at least on the MacBook). I am also disappointed by the performance of `Arc::clone`. Further improvements will almost certainly require some amount of `unsafe` code. I have three ideas for improvement:
//!
55//!
56//! - Avoid requiring `Arc::clone` by using a [`LocalKey::with`] style API, and making `drop(DroppingThreadLocal)` delay freeing values from live threads until after that live thread dies.
57//! - Use [biased reference counting] instead of an `Arc`. This would not require `unsafe` code directly in this crate. Unfortunately, biased reference counting can delay destruction or even leak if the heartbeat function is not called. The [`trc` crate] will not work, as `trc::Trc` is `!Send`.
58//! - Using [`boxcar::Vec`] instead of a `HashMap` for lookup. This is essentially the same data structure that the `thread_local` crate uses, so should make the lookup performance similar.
59//!
60//! *NOTE*: Simply removing `Arc::clone` doesn't help that much. On the M1 mac it reduces the time to 9 ns. (I did this unsoundly, so it cant be published)
61//!
62//! [biased reference counting]: https://dl.acm.org/doi/10.1145/3243176.3243195
63//! [`LocalKey::with`]: https://doc.rust-lang.org/std/thread/struct.LocalKey.html#method.with
64//! [`boxcar::Vec`]: https://docs.rs/boxcar/0.2.13/boxcar/struct.Vec.html
65//! [`trc` crate]: https://github.com/ericlbuehler/trc
//!
//! ## Locking
//! The implementation needs to acquire a global lock to initialize/deinitialize threads and create new locals.
//! Accessing thread-local data is also protected by a per-thread lock.
//! This lock should be uncontended, and [`parking_lot::Mutex`] should make this relatively fast.
//! I have been careful to make sure that locks are not held while user code is being executed.
//! This includes releasing locks before any destructors are executed.
//!
//! # Limitations
//! The type that is stored must be `Send + Sync + 'static`.
//! The `Send` bound is necessary because the [`DroppingThreadLocal`] may be dropped from any thread.
//! The `Sync` bound is necessary to support snapshots,
//! and the `'static` bound is due to internal implementation choices (use of safe code).
//!
//! A `Mutex` can be used to work around the `Sync` limitation.
//! (I recommend [`parking_lot::Mutex`], which is optimized for uncontended locks.)
//! You can attempt to use the [`fragile`] crate to work around the `Send` limitation,
//! but this will cause panics if the value is dropped from another thread.
//! A value can be dropped from another thread if a snapshot keeps it alive,
//! or if the [`DroppingThreadLocal`] itself is dropped.
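//!
//! For example, a `Send + !Sync` type can be wrapped in a `Mutex` to satisfy the `Sync` bound
//! (a sketch using `std::sync::Mutex` to stay self-contained; `parking_lot::Mutex` works the same way):
//! ```rust
//! use std::cell::Cell;
//! use std::sync::Mutex;
//! use dropping_thread_local::DroppingThreadLocal;
//!
//! // `Cell<u32>` is `Send` but not `Sync`; `Mutex<Cell<u32>>` is both.
//! let local: DroppingThreadLocal<Mutex<Cell<u32>>> = DroppingThreadLocal::new();
//! let value = local.get_or_init(|| Mutex::new(Cell::new(0)));
//! value.lock().unwrap().set(5);
//! assert_eq!(value.lock().unwrap().get(), 5);
//! ```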
//!
//! [`thread_local`]: https://docs.rs/thread_local/1.1/thread_local/
//! [`fragile`]: https://docs.rs/fragile/2/fragile/

extern crate alloc;
extern crate core;

use alloc::rc::Rc;
use alloc::sync::{Arc, Weak};
use core::any::Any;
use core::fmt::{Debug, Formatter};
use core::hash::{Hash, Hasher};
use core::marker::PhantomData;
use core::num::NonZero;
use core::ops::Deref;
use core::sync::atomic::Ordering;
use std::thread::ThreadId;

use parking_lot::Mutex;
use portable_atomic::AtomicU64;
/// A thread local that drops its value when the thread is destroyed.
///
/// See the module-level documentation for more details.
///
/// Dropping this value will free all the associated values.
pub struct DroppingThreadLocal<T: Send + Sync + 'static> {
    id: UniqueLocalId,
    marker: PhantomData<Arc<T>>,
}
impl<T: Send + Sync + 'static + Debug> Debug for DroppingThreadLocal<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let value = self.get();
        f.debug_struct("DroppingThreadLocal")
            .field("local_data", &value.as_ref().map(|value| value.as_ref()))
            .finish()
    }
}
impl<T: Send + Sync + 'static> Default for DroppingThreadLocal<T> {
    #[inline]
    fn default() -> Self {
        DroppingThreadLocal::new()
    }
}
impl<T: Send + Sync + 'static> DroppingThreadLocal<T> {
    /// Create a new thread-local value.
    #[inline]
    pub fn new() -> Self {
        DroppingThreadLocal {
            id: UniqueLocalId::alloc(),
            marker: PhantomData,
        }
    }
    /// Get the value associated with the current thread,
    /// or `None` if it has not been initialized.
    #[inline]
    pub fn get(&self) -> Option<SharedRef<T>> {
        THREAD_STATE.with(|thread| {
            Some(SharedRef {
                thread_id: thread.id,
                value: thread.get(self.id)?.downcast::<T>().expect("unexpected type"),
            })
        })
    }
    /// Set the value associated with the current thread,
    /// returning `Err` with the existing value if already initialized,
    /// or `Ok` with a reference to the new value on success.
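    ///
    /// # Example
    /// A sketch of the success and failure cases (assuming the crate name `dropping_thread_local`):
    /// ```rust
    /// use dropping_thread_local::DroppingThreadLocal;
    ///
    /// let local = DroppingThreadLocal::new();
    /// assert!(local.set(1_u32).is_ok());
    /// // a second `set` on the same thread fails, returning the existing value
    /// let existing = local.set(2_u32).unwrap_err();
    /// assert_eq!(*existing, 1);
    /// ```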
    pub fn set(&self, value: T) -> Result<SharedRef<T>, SharedRef<T>> {
        if let Some(existing) = self.get() {
            Err(existing)
        } else {
            THREAD_STATE.with(|thread| {
                let new_value = Arc::new(value) as DynArc;
                thread.init(self.id, &new_value);
                Ok(SharedRef {
                    thread_id: thread.id,
                    value: new_value.downcast::<T>().unwrap(),
                })
            })
        }
    }
    /// Get the value associated with the current thread,
    /// initializing it if not yet defined.
    ///
    /// Panics if double initialization is detected.
    pub fn get_or_init(&self, func: impl FnOnce() -> T) -> SharedRef<T> {
        match self.get_or_try_init::<core::convert::Infallible>(|| Ok(func())) {
            Ok(success) => success,
        }
    }

    /// Get the value associated with the current thread,
    /// attempting to initialize it if not yet defined.
    ///
    /// Panics if double initialization is detected.
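    ///
    /// # Example
    /// A sketch using a fallible initializer (assuming the crate name `dropping_thread_local`):
    /// ```rust
    /// use dropping_thread_local::DroppingThreadLocal;
    ///
    /// let local: DroppingThreadLocal<u32> = DroppingThreadLocal::new();
    /// let value = local.get_or_try_init(|| "7".parse::<u32>()).unwrap();
    /// assert_eq!(*value, 7);
    /// ```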
    pub fn get_or_try_init<E>(&self, func: impl FnOnce() -> Result<T, E>) -> Result<SharedRef<T>, E> {
        THREAD_STATE.with(|thread| {
            let value = match thread.get(self.id) {
                Some(existing) => existing,
                None => {
                    let new_value = Arc::new(func()?) as DynArc;
                    thread.init(self.id, &new_value);
                    new_value
                }
            };
            let value = value.downcast::<T>().expect("unexpected type");
            Ok(SharedRef {
                thread_id: thread.id,
                value,
            })
        })
    }
    /// Iterate over the currently live values and their associated thread ids.
    ///
    /// New threads that are spawned after the snapshot is taken will not be present
    /// in the iterator.
    /// Threads that die after the snapshot is taken may or may not be present.
    /// Values from threads that die before the snapshot will not be present.
    ///
    /// The order of iteration is undefined.
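    ///
    /// # Example
    /// A sketch of iterating over a snapshot (assuming the crate name `dropping_thread_local`):
    /// ```rust
    /// use dropping_thread_local::{DroppingThreadLocal, SharedRef};
    ///
    /// let local = DroppingThreadLocal::new();
    /// local.get_or_init(|| String::from("hello"));
    /// for value in local.snapshot_iter() {
    ///     // at least the current thread's value is visible
    ///     println!("{:?}: {}", SharedRef::thread_id(&value), &*value);
    /// }
    /// ```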
    pub fn snapshot_iter(&self) -> SnapshotIter<T> {
        SnapshotIter {
            local_id: self.id,
            iter: snapshot_live_threads().map(|snapshot| snapshot.into_iter()),
            marker: PhantomData,
        }
    }
}
impl<T: Send + Sync + 'static> Drop for DroppingThreadLocal<T> {
    fn drop(&mut self) {
        // we want to drop values without holding the lock
        let Some(snapshot) = snapshot_live_threads() else {
            // no live threads -> nothing to free
            return;
        };
        // panics here won't cause aborts; there is no need for that
        for (thread_id, thread) in snapshot {
            if let Some(thread) = Weak::upgrade(&thread) {
                assert_eq!(thread.id, thread_id);
                let value: Option<DynArc> = {
                    let mut lock = thread.values.lock();
                    lock.remove(&self.id)
                };
                // drop the value once the lock is no longer held
                drop(value);
            }
        }
    }
}
/// Iterates over a snapshot of the values,
/// given by [`DroppingThreadLocal::snapshot_iter`].
///
/// Due to thread death, it is not possible to know the exact size of the iterator.
pub struct SnapshotIter<T: Send + Sync + 'static> {
    local_id: UniqueLocalId,
    iter: Option<imbl::hashmap::ConsumingIter<(ThreadId, Weak<LiveThreadState>), imbl::shared_ptr::DefaultSharedPtr>>,
    // deliberately not Send + Sync, for flexibility in the future
    marker: PhantomData<Rc<T>>,
}
impl<T: Send + Sync + 'static> Iterator for SnapshotIter<T> {
    type Item = SharedRef<T>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // `None` from either of these means we have nothing left to iterate
            let (thread_id, thread) = self.iter.as_mut()?.next()?;
            let Some(thread) = Weak::upgrade(&thread) else { continue };
            let Some(arc) = ({
                let lock = thread.values.lock();
                lock.get(&self.local_id).cloned()
            }) else {
                continue;
            };
            return Some(SharedRef {
                thread_id,
                value: arc.downcast::<T>().expect("mismatched type"),
            });
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self.iter {
            Some(ref iter) => {
                // may be zero if all threads die
                (0, Some(iter.len()))
            }
            None => (0, Some(0)),
        }
    }
}
impl<T: Send + Sync + 'static> core::iter::FusedIterator for SnapshotIter<T> {}

/// A shared reference to a thread-local value.
///
/// This may be cloned and sent across threads.
/// It may delay destruction of the value past thread death.
#[derive(Clone, Debug)]
pub struct SharedRef<T> {
    thread_id: ThreadId,
    value: Arc<T>,
}
impl<T> SharedRef<T> {
    /// The id of the thread this value is associated with.
    #[inline]
    pub fn thread_id(this: &Self) -> ThreadId {
        this.thread_id
    }
}

impl<T> Deref for SharedRef<T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}
impl<T> AsRef<T> for SharedRef<T> {
    #[inline]
    fn as_ref(&self) -> &T {
        &self.value
    }
}
impl<T: Hash> Hash for SharedRef<T> {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.value.hash(state);
    }
}
impl<T: Eq> Eq for SharedRef<T> {}
impl<T: PartialEq> PartialEq for SharedRef<T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
    }
}
impl<T: PartialOrd> PartialOrd for SharedRef<T> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        T::partial_cmp(&self.value, &other.value)
    }
}
impl<T: Ord> Ord for SharedRef<T> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        T::cmp(&self.value, &other.value)
    }
}

struct UniqueIdAllocator {
    next_id: AtomicU64,
}
impl UniqueIdAllocator {
    const fn new() -> Self {
        UniqueIdAllocator {
            next_id: AtomicU64::new(1),
        }
    }
    fn alloc(&self) -> NonZero<u64> {
        NonZero::new(
            self.next_id
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |x| x.checked_add(1))
                .expect("id overflow"),
        )
        .unwrap()
    }
}
#[derive(Copy, Clone, Debug, Eq, PartialOrd, PartialEq, Hash)]
struct UniqueLocalId(NonZero<u64>);
impl UniqueLocalId {
    fn alloc() -> Self {
        static ALLOCATOR: UniqueIdAllocator = UniqueIdAllocator::new();
        UniqueLocalId(ALLOCATOR.alloc())
    }
}

type LiveThreadMap = imbl::GenericHashMap<
    ThreadId,
    Weak<LiveThreadState>,
    foldhash::fast::RandomState,
    imbl::shared_ptr::DefaultSharedPtr,
>;
/// Map of currently live threads.
///
/// This is a persistent map, allowing quick snapshots to be taken for the drop function and for iteration.
/// I use `imbl` instead of `rpds`, because `rpds` doesn't support an owned iterator,
/// only a borrowed iterator, which would require an extra allocation.
/// I use a hashmap instead of a `BTreeMap` because the latter would require `ThreadId: Ord`,
/// which the stdlib doesn't provide.
static LIVE_THREADS: Mutex<Option<LiveThreadMap>> = Mutex::new(None);
fn snapshot_live_threads() -> Option<LiveThreadMap> {
    let lock = LIVE_THREADS.lock();
    lock.as_ref().cloned()
}

thread_local! {
    static THREAD_STATE: Arc<LiveThreadState> = {
        let id = std::thread::current().id();
        let state = Arc::new(LiveThreadState {
            id,
            values: Mutex::new(foldhash::HashMap::default()),
        });
        let mut live_threads = LIVE_THREADS.lock();
        let live_threads = live_threads.get_or_insert_default();
        use imbl::hashmap::Entry;
        match live_threads.entry(id) {
            Entry::Occupied(_) => panic!("reinitialized thread"),
            Entry::Vacant(entry) => {
                entry.insert(Arc::downgrade(&state));
            }
        }
        state
    };
}
type DynArc = Arc<dyn Any + Send + Sync + 'static>;
struct LiveThreadState {
    id: ThreadId,
    /// Maps from local ids to values.
    ///
    /// ## Performance
    /// This is noticeably slower than what `thread_local` offers.
    ///
    /// We could make this faster by using a vector.
    /// A `boxcar::Vec` is essentially the same data structure as `thread_local` uses,
    /// and would allow us to get rid of the lock.
    /// The only difference is that we would be mapping from local ids -> values.
    /// However, the lock is uncontended and `parking_lot` makes that relatively cheap,
    /// so a `Mutex<Vec<T>>` might also work and be simpler.
    /// To avoid unbounded memory usage when locals are constantly being allocated/dropped,
    /// this would require reusing indexes.
    values: Mutex<foldhash::HashMap<UniqueLocalId, DynArc>>,
}
impl LiveThreadState {
    #[inline]
    fn get(&self, id: UniqueLocalId) -> Option<DynArc> {
        let lock = self.values.lock();
        Some(Arc::clone(lock.get(&id)?))
    }
    // the `Arc` has a dynamic type to avoid monomorphization
    #[cold]
    #[inline(never)]
    fn init(&self, id: UniqueLocalId, new_value: &DynArc) {
        let mut lock = self.values.lock();
        use std::collections::hash_map::Entry;
        match lock.entry(id) {
            Entry::Occupied(_) => {
                panic!("unexpected double initialization of thread-local value")
            }
            Entry::Vacant(entry) => {
                entry.insert(Arc::clone(new_value));
            }
        }
    }
}
impl Drop for LiveThreadState {
    fn drop(&mut self) {
        // clear all our values
        self.values.get_mut().clear();
        // remove ourselves from the list of live threads
        {
            let mut threads = LIVE_THREADS.lock();
            if let Some(threads) = threads.as_mut() {
                // no fear of dropping while locked, because we control the type
                let state: Option<Weak<LiveThreadState>> = threads.remove(&self.id);
                drop(state);
            }
        }
    }
}