dropping_thread_local/
lib.rs

#![deny(
    // currently there is no unsafe code
    unsafe_code,
    // public library should have docs
    missing_docs,
)]
//! Dynamically allocated thread locals that properly run destructors when a thread is destroyed.
//!
//! This is in contrast to the [`thread_local`] crate, which has similar functionality,
//! but only runs destructors when the `ThreadLocal` object is dropped.
//! This crate guarantees that one thread will never see the thread-local data of another,
//! which can happen in the `thread_local` crate due to internal storage reuse.
//!
//! This crate attempts to implement "true" thread locals,
//! mirroring [`std::thread_local!`] as closely as possible.
//! I would say the `thread_local` crate is good for functionality like reusing allocations
//! or for having local caches that can be sensibly reused once a thread dies.
//!
//! This crate will attempt to run destructors as promptly as possible,
//! but taking snapshots may interfere with this (see below).
//! Panics in thread destructors will cause aborts, just like they do with [`std::thread_local!`].
//!
//! Right now, this crate has no unsafe code.
//! This may change if it can bring a significant performance improvement.
//!
//! # Snapshots
//! The most complicated feature of this library is snapshots.
//! They allow anyone with access to a [`DroppingThreadLocal`] to iterate over all currently live
//! values using the [`DroppingThreadLocal::snapshot_iter`] method.
//!
//! This will return a snapshot of the live values at the time the method is called,
//! although if a thread dies during iteration, its value may not show up.
//! See the method documentation for more details.
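//!
//! For example (a sketch; this assumes the crate is imported as `dropping_thread_local`):
//! ```
//! use dropping_thread_local::DroppingThreadLocal;
//!
//! let local: DroppingThreadLocal<String> = DroppingThreadLocal::new();
//! local.get_or_init(|| String::from("main thread"));
//! // iterates over the values of all currently live threads;
//! // here only the main thread has initialized a value
//! for value in local.snapshot_iter() {
//!     assert_eq!(*value, "main thread");
//! }
//! ```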
//!
//! # Performance
//! Benchmarks show that lookup is between 10x and 30x slower than in the [`thread_local`] crate,
//! which is in turn about 2x slower than [`std::thread_local!`].
//!
//! Keep in mind that using a `std::thread_local` is a very fast operation.
//! It takes about 0.5 nanoseconds on both my M1 Mac and an Intel i5 from 2017.
//! For reference, calling [`Arc::clone`] takes about 11 ns on both machines.
//!
//! See `performance.md` in the repository root for benchmark results and more detailed
//! performance notes.
//!
//! ## Locking
//! The implementation needs to acquire a global lock to initialize/deinitialize threads and create new locals.
//! Accessing thread-local data is also protected by a per-thread lock.
//! This lock should be uncontended, and [`parking_lot::Mutex`] should make this relatively fast.
//! I have been careful to make sure that locks are not held while user code is being executed.
//! This includes releasing locks before any destructors are executed.
//!
//! # Limitations
//! The type that is stored must be `Send + Sync + 'static`.
//! The `Send` bound is necessary because the [`DroppingThreadLocal`] may be dropped from any thread.
//! The `Sync` bound is necessary to support snapshots,
//! and the `'static` bound is due to internal implementation choices (use of safe code).
//!
//! A `Mutex` can be used to work around the `Sync` limitation.
//! (I recommend [`parking_lot::Mutex`], which is optimized for uncontended locks.)
//! You can attempt to use the [`fragile`] crate to work around the `Send` limitation,
//! but this will cause panics if the value is dropped from another thread.
//! A value can be dropped from another thread if a snapshot keeps it alive,
//! or if the [`DroppingThreadLocal`] itself is dropped.
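//!
//! For example, a type that is `Send` but not `Sync` can be stored by wrapping it in a mutex
//! (a sketch using `std::sync::Mutex`; the same works with [`parking_lot::Mutex`]):
//! ```
//! use std::cell::Cell;
//! use std::sync::Mutex;
//!
//! fn assert_bounds<T: Send + Sync + 'static>() {}
//!
//! // Cell<u32> is Send but not Sync, so it cannot be stored directly;
//! // Mutex<Cell<u32>> satisfies both bounds
//! assert_bounds::<Mutex<Cell<u32>>>();
//! ```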
//!
//! [`thread_local`]: https://docs.rs/thread_local/1.1/thread_local/
//! [`fragile`]: https://docs.rs/fragile/2/fragile/

extern crate alloc;

use alloc::rc::Rc;
use alloc::sync::{Arc, Weak};
use core::any::Any;
use core::fmt::{Debug, Formatter};
use core::hash::{Hash, Hasher};
use core::marker::PhantomData;
use core::ops::Deref;
use std::thread::ThreadId;

use intid::IntegerId;
use parking_lot::Mutex;

use crate::local_ids::{LiveLocalId, OwnedLocalId};

mod local_ids;

/// A thread local that drops its value when the thread is destroyed.
///
/// See module-level documentation for more details.
///
/// Dropping this value will free all the associated values.
pub struct DroppingThreadLocal<T: Send + Sync + 'static> {
    id: OwnedLocalId,
    marker: PhantomData<Arc<T>>,
}
impl<T: Send + Sync + 'static + Debug> Debug for DroppingThreadLocal<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let value = self.get();
        f.debug_struct("DroppingThreadLocal")
            .field("local_data", &value.as_ref().map(|value| value.as_ref()))
            .finish()
    }
}
impl<T: Send + Sync + 'static> Default for DroppingThreadLocal<T> {
    #[inline]
    fn default() -> Self {
        DroppingThreadLocal::new()
    }
}
impl<T: Send + Sync + 'static> DroppingThreadLocal<T> {
    /// Create a new thread-local value.
    #[inline]
    pub fn new() -> Self {
        DroppingThreadLocal {
            id: OwnedLocalId::alloc(),
            marker: PhantomData,
        }
    }
    /// Get the value associated with the current thread,
    /// or `None` if not initialized.
    #[inline]
    pub fn get(&self) -> Option<SharedRef<T>> {
        THREAD_STATE.with(|thread| {
            Some(SharedRef {
                thread_id: thread.id,
                value: thread.get(self.id.id())?.downcast::<T>().expect("unexpected type"),
            })
        })
    }
    /// Set the value associated with this thread,
    /// returning a reference to the new value on success,
    /// or `Err(existing_val)` if already initialized.
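    ///
    /// A sketch of the expected behavior (assuming the crate is imported as `dropping_thread_local`):
    /// ```
    /// use dropping_thread_local::DroppingThreadLocal;
    ///
    /// let local = DroppingThreadLocal::new();
    /// assert!(local.set(1).is_ok());
    /// // a second call on the same thread returns the existing value as an error
    /// assert!(local.set(2).is_err());
    /// ```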
    pub fn set(&self, value: T) -> Result<SharedRef<T>, SharedRef<T>> {
        if let Some(existing) = self.get() {
            Err(existing)
        } else {
            THREAD_STATE.with(|thread| {
                let new_value = Arc::new(value) as DynArc;
                thread.init(&self.id, &new_value);
                Ok(SharedRef {
                    thread_id: thread.id,
                    value: new_value.downcast::<T>().unwrap(),
                })
            })
        }
    }
    /// Get the value associated with the current thread,
    /// initializing it if not yet defined.
    ///
    /// Panics if double initialization is detected.
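    ///
    /// A sketch of typical usage (assuming the crate is imported as `dropping_thread_local`):
    /// ```
    /// use dropping_thread_local::DroppingThreadLocal;
    ///
    /// let local = DroppingThreadLocal::new();
    /// // the closure runs only on the first access from this thread
    /// let value = local.get_or_init(|| 42);
    /// assert_eq!(*value, 42);
    /// // subsequent calls return the existing value
    /// assert_eq!(*local.get_or_init(|| 0), 42);
    /// ```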
    pub fn get_or_init(&self, func: impl FnOnce() -> T) -> SharedRef<T> {
        match self.get_or_try_init::<core::convert::Infallible>(|| Ok(func())) {
            Ok(success) => success,
        }
    }

    /// Get the value associated with the current thread,
    /// attempting to initialize it if not yet defined.
    ///
    /// Panics if double initialization is detected.
    pub fn get_or_try_init<E>(&self, func: impl FnOnce() -> Result<T, E>) -> Result<SharedRef<T>, E> {
        THREAD_STATE.with(|thread| {
            let value = match thread.get(self.id.id()) {
                Some(existing) => existing,
                None => {
                    let new_value = Arc::new(func()?) as DynArc;
                    thread.init(&self.id, &new_value);
                    new_value
                }
            };
            let value = value.downcast::<T>().expect("unexpected type");
            Ok(SharedRef {
                thread_id: thread.id,
                value,
            })
        })
    }
177
178    /// Iterate over currently live values and their associated thread ids.
179    ///
180    /// New threads that have been spanned after the snapshot was taken will not be present
181    /// in the iterator.
182    /// Threads that die after the snapshot is taken may or may not be present.
183    /// Values from threads that die before the snapshot will not be present.
184    ///
185    /// The order of the iteration is undefined.
186    pub fn snapshot_iter(&self) -> SnapshotIter<T> {
187        let Some(snapshot) = snapshot_live_threads() else {
188            return SnapshotIter {
189                local_id: self.id.clone(),
190                iter: None,
191                marker: PhantomData,
192            };
193        };
194        SnapshotIter {
195            local_id: self.id.clone(),
196            iter: Some(snapshot.into_iter()),
197            marker: PhantomData,
198        }
199    }
200}
impl<T: Send + Sync + 'static> Drop for DroppingThreadLocal<T> {
    fn drop(&mut self) {
        // want to drop without holding the lock
        let Some(snapshot) = snapshot_live_threads() else {
            // no live threads -> nothing to free
            return;
        };
        // unlike in thread destructors, panics here will not cause aborts; there is no need for that
        for (thread_id, thread) in snapshot {
            if let Some(thread) = Weak::upgrade(&thread) {
                assert_eq!(thread.id, thread_id);
                let value: Option<DynArc> = {
                    let mut lock = thread.values.lock();
                    // there is an opportunity to shrink the Vec here,
                    // but we don't do that because it could waste time growing it again
                    match lock.get_mut(self.id.index()) {
                        None => None, // out of bounds
                        Some(inner) => inner.take().map(|(_, value)| value),
                    }
                };
                // drop the value once the lock is no longer held
                drop(value);
            }
        }
    }
}
/// Iterates over a snapshot of the values,
/// given by [`DroppingThreadLocal::snapshot_iter`].
///
/// Due to thread death, it is not possible to know the exact size of the iterator.
pub struct SnapshotIter<T: Send + Sync + 'static> {
    local_id: OwnedLocalId,
    iter: Option<imbl::hashmap::ConsumingIter<(ThreadId, Weak<LiveThreadState>), imbl::shared_ptr::DefaultSharedPtr>>,
    // do not make Send+Sync, for flexibility in the future
    marker: PhantomData<Rc<T>>,
}
impl<T: Send + Sync + 'static> Iterator for SnapshotIter<T> {
    type Item = SharedRef<T>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // None from either of these means we have nothing left to iterate
            let (thread_id, thread) = self.iter.as_mut()?.next()?;
            let Some(thread) = Weak::upgrade(&thread) else { continue };
            let Some(arc) = ({
                let lock = thread.values.lock();
                lock.get(self.local_id.index())
                    .and_then(Option::as_ref)
                    .map(|(_id, value)| Arc::clone(value))
            }) else {
                continue;
            };
            return Some(SharedRef {
                thread_id,
                value: arc.downcast::<T>().expect("mismatched type"),
            });
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self.iter {
            Some(ref iter) => {
                // may be zero if all threads die
                (0, Some(iter.len()))
            }
            None => (0, Some(0)),
        }
    }
}
impl<T: Send + Sync + 'static> core::iter::FusedIterator for SnapshotIter<T> {}

/// A shared reference to a thread local value.
///
/// This may be cloned and sent across threads.
/// May delay destruction of the value past thread death.
#[derive(Clone, Debug)]
pub struct SharedRef<T> {
    thread_id: ThreadId,
    value: Arc<T>,
}
impl<T> SharedRef<T> {
    /// The id of the thread this value is associated with.
    #[inline]
    pub fn thread_id(this: &Self) -> ThreadId {
        this.thread_id
    }
}

impl<T> Deref for SharedRef<T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}
impl<T> AsRef<T> for SharedRef<T> {
    #[inline]
    fn as_ref(&self) -> &T {
        &self.value
    }
}
impl<T: Hash> Hash for SharedRef<T> {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.value.hash(state);
    }
}
impl<T: Eq> Eq for SharedRef<T> {}
impl<T: PartialEq> PartialEq for SharedRef<T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
    }
}
impl<T: PartialOrd> PartialOrd for SharedRef<T> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        T::partial_cmp(&self.value, &other.value)
    }
}
impl<T: Ord> Ord for SharedRef<T> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        T::cmp(&self.value, &other.value)
    }
}

type LiveThreadMap = imbl::GenericHashMap<
    ThreadId,
    Weak<LiveThreadState>,
    foldhash::fast::RandomState,
    imbl::shared_ptr::DefaultSharedPtr,
>;
/// Map of currently live threads.
///
/// This is a persistent map, allowing quick snapshots to be taken by the drop function and by iteration.
/// I use `imbl` instead of `rpds`, because `rpds` doesn't support an owned iterator,
/// only a borrowed iterator, which would require an extra allocation.
/// I use a hashmap instead of a `BTreeMap` because the latter would require `ThreadId: Ord`,
/// which the stdlib doesn't have.
static LIVE_THREADS: Mutex<Option<LiveThreadMap>> = Mutex::new(None);
fn snapshot_live_threads() -> Option<LiveThreadMap> {
    let lock = LIVE_THREADS.lock();
    lock.as_ref().cloned()
}

thread_local! {
    static THREAD_STATE: Arc<LiveThreadState> = {
        let id = std::thread::current().id();
        let state = Arc::new(LiveThreadState {
            id,
            values: Mutex::new(Vec::new()),
        });
        let mut live_threads = LIVE_THREADS.lock();
        let live_threads = live_threads.get_or_insert_default();
        use imbl::hashmap::Entry;
        match live_threads.entry(id) {
            Entry::Occupied(_) => panic!("reinitialized thread"),
            Entry::Vacant(entry) => {
                entry.insert(Arc::downgrade(&state));
            }
        }
        state
    };
}
type DynArc = Arc<dyn Any + Send + Sync + 'static>;
struct LiveThreadState {
    id: ThreadId,
    /// Maps from local ids to values.
    ///
    /// TODO: Cannot use the [idmap crate], because it doesn't support owned keys (issue [DuckLogic/intid#2]).
    ///
    /// ## Performance
    /// Surprisingly, using a vector is not much faster than using a hashmap.
    /// Reusing ids should avoid growing the vector too much,
    /// so memory is only proportional to the peak number of live ids.
    ///
    /// A [`boxcar::Vec`] is essentially the same data structure as `thread_local` uses.
    /// However, it does not allow modifying entries that already exist,
    /// making it unusable for our purposes without per-element synchronization.
    /// Since the lock is thread-local, it should be uncontended.
    /// Using parking_lot makes that relatively cheap,
    /// only requiring an acquire CAS and a release store in the common case.
    /// By comparison, `boxcar::Vec::get` only reduces this to a single acquire load.
    /// So even if we could use it, performance would not be much better.
    ///
    /// [`boxcar::Vec`]: https://docs.rs/boxcar/latest/boxcar/struct.Vec.html
    values: Mutex<Vec<Option<(OwnedLocalId, DynArc)>>>,
}
impl LiveThreadState {
    #[inline]
    fn get(&self, id: LiveLocalId) -> Option<DynArc> {
        let lock = self.values.lock();
        Some(Arc::clone(&lock.get(id.to_int())?.as_ref()?.1))
    }
    // the arc has dynamic type to avoid monomorphization
    #[cold]
    #[inline(never)]
    fn init(&self, id: &OwnedLocalId, new_value: &DynArc) {
        let mut lock = self.values.lock();
        let index = id.index();
        match lock.get(index).and_then(Option::as_ref) {
            Some(_existing) => {
                panic!("unexpected double initialization of thread-local value")
            }
            None => {
                // Unlike DirectIdMap::insert, I don't care whether growth is amortized here,
                // because this is the cold path
                while lock.len() <= index {
                    lock.push(None)
                }
                lock[index] = Some((id.clone(), Arc::clone(new_value)));
            }
        }
    }
}
impl Drop for LiveThreadState {
    fn drop(&mut self) {
        // clear all our values
        self.values.get_mut().clear();
        // remove from the list of live threads
        {
            let mut threads = LIVE_THREADS.lock();
            if let Some(threads) = threads.as_mut() {
                // no fear of dropping while locked because we control the type
                let state: Option<Weak<LiveThreadState>> = threads.remove(&self.id);
                drop(state)
            }
        }
    }
}