dropping_thread_local/
lib.rs

#![deny(
    // currently there is no unsafe code
    unsafe_code,
    // public library should have docs
    missing_docs,
)]
//! Dynamically allocated thread locals that properly run destructors when a thread is destroyed.
//!
//! This is in contrast to the [`thread_local`] crate, which has similar functionality,
//! but only runs destructors when the `ThreadLocal` object is dropped.
//! This crate guarantees that one thread will never see the thread-local data of another,
//! which can happen in the `thread_local` crate due to internal storage reuse.
//!
//! This crate attempts to implement "true" thread locals,
//! mirroring [`std::thread_local!`] as closely as possible.
//! I would say the `thread_local` crate is good for functionality like reusing allocations
//! or for having local caches that can be sensibly reused once a thread dies.
//!
//! This crate will attempt to run destructors as promptly as possible,
//! but taking snapshots may interfere with this (see below).
//! Panics in thread destructors will cause aborts, just like they do with [`std::thread_local!`].
//!
//! Right now, this crate has no unsafe code.
//! This may change if it can bring a significant performance improvement.
//!
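//! # Example
//! A minimal usage sketch (an illustration, using only the API documented below):
//! ```
//! use dropping_thread_local::DroppingThreadLocal;
//!
//! let local: DroppingThreadLocal<u32> = DroppingThreadLocal::new();
//! // no value has been initialized for this thread yet
//! assert!(local.get().is_none());
//! let value = local.get_or_init(|| 42);
//! assert_eq!(*value, 42);
//! // the value is now stored for this thread
//! assert_eq!(*local.get().unwrap(), 42);
//! ```
//!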
//! # Snapshots
//! The most complicated feature of this library is snapshots.
//! Snapshots allow anyone who has access to a [`DroppingThreadLocal`] to iterate over all
//! currently live values using the [`DroppingThreadLocal::snapshot_iter`] method.
//!
//! This will return a snapshot of the live values at the time the method is called,
//! although if a thread dies during iteration, its value may not show up.
//! See the method documentation for more details.
//!
//! # Performance
//! I expect the current implementation to be noticeably slower than either
//! [`std::thread_local!`] or the [`thread_local`] crate.
//! I have not done any benchmarks to compare the performance.
//!
//! ## Locking
//! The implementation needs to acquire a global lock to initialize/deinitialize threads and create new locals.
//! Accessing thread-local data is also protected by a per-thread lock.
//! This lock should be uncontended, and [`parking_lot::Mutex`] should make this relatively fast.
//! I have been careful to make sure that locks are not held while user code is being executed.
//! This includes releasing locks before any destructors are executed.
//!
//! # Limitations
//! The type that is stored must be `Send + Sync + 'static`.
//! The `Send` bound is necessary because the [`DroppingThreadLocal`] may be dropped from any thread.
//! The `Sync` bound is necessary to support snapshots,
//! and the `'static` bound is due to internal implementation choices (use of safe code).
//!
//! A `Mutex` can be used to work around the `Sync` limitation
//! (I recommend [`parking_lot::Mutex`], which is optimized for uncontended locks).
//! You can attempt to use the [`fragile`] crate to work around the `Send` limitation,
//! but this will cause panics if the value is dropped from another thread.
//! A value can be dropped from another thread if a snapshot keeps it alive
//! past thread death, or if the [`DroppingThreadLocal`] itself is dropped.
//!
//! [`thread_local`]: https://docs.rs/thread_local/1.1/thread_local/
//! [`fragile`]: https://docs.rs/fragile/2/fragile/

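//! For example, a non-`Sync` payload such as `Cell` can be stored behind a `Mutex`
//! (a minimal sketch; `Cell<u32>` is not `Sync`, but `Mutex<Cell<u32>>` is):
//! ```
//! use dropping_thread_local::DroppingThreadLocal;
//! use parking_lot::Mutex;
//! use core::cell::Cell;
//!
//! let local: DroppingThreadLocal<Mutex<Cell<u32>>> = DroppingThreadLocal::new();
//! // the `SharedRef` derefs to the `Mutex`, whose guard derefs to the `Cell`
//! local.get_or_init(|| Mutex::new(Cell::new(0))).lock().set(5);
//! assert_eq!(local.get().unwrap().lock().get(), 5);
//! ```
//!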
extern crate alloc;
extern crate core;

use alloc::rc::Rc;
use alloc::sync::{Arc, Weak};
use core::any::Any;
use core::fmt::{Debug, Formatter};
use core::hash::{Hash, Hasher};
use core::marker::PhantomData;
use core::num::NonZero;
use core::ops::Deref;
use core::sync::atomic::Ordering;
use std::thread::ThreadId;

use parking_lot::Mutex;
use portable_atomic::AtomicU64;

/// A thread local that drops its value when the thread is destroyed.
///
/// See module-level documentation for more details.
///
/// Dropping this value will free all the associated values.
pub struct DroppingThreadLocal<T: Send + Sync + 'static> {
    id: UniqueLocalId,
    marker: PhantomData<Arc<T>>,
}
impl<T: Send + Sync + 'static + Debug> Debug for DroppingThreadLocal<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let value = self.get();
        f.debug_struct("DroppingThreadLocal")
            .field("local_data", &value.as_ref().map(|value| value.as_ref()))
            .finish()
    }
}
impl<T: Send + Sync + 'static> Default for DroppingThreadLocal<T> {
    #[inline]
    fn default() -> Self {
        DroppingThreadLocal::new()
    }
}
impl<T: Send + Sync + 'static> DroppingThreadLocal<T> {
    /// Create a new thread-local value.
    #[inline]
    pub fn new() -> Self {
        DroppingThreadLocal {
            id: UniqueLocalId::alloc(),
            marker: PhantomData,
        }
    }
    /// Get the value associated with the current thread,
    /// or `None` if not initialized.
    #[inline]
    pub fn get(&self) -> Option<SharedRef<T>> {
        THREAD_STATE.with(|thread| {
            Some(SharedRef {
                thread_id: thread.id,
                value: thread.get(self.id)?.downcast::<T>().expect("unexpected type"),
            })
        })
    }
    /// Get the value associated with the current thread,
    /// initializing it if not yet defined.
    #[inline]
    pub fn get_or_init(&self, func: impl FnOnce() -> T) -> SharedRef<T> {
        THREAD_STATE.with(|thread| {
            let value = thread
                .get(self.id)
                .unwrap_or_else(|| {
                    let mut func = Some(func);
                    thread.init(self.id, &mut || {
                        let func = func.take().unwrap();
                        Arc::new(func()) as Arc<dyn Any + Send + Sync + 'static>
                    })
                })
                .downcast::<T>()
                .expect("unexpected type");
            SharedRef {
                thread_id: thread.id,
                value,
            }
        })
    }
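    /// # Example
    /// A minimal sketch: the initializer runs at most once per thread.
    /// ```
    /// use dropping_thread_local::DroppingThreadLocal;
    ///
    /// let local: DroppingThreadLocal<String> = DroppingThreadLocal::new();
    /// // the first call runs the closure
    /// assert_eq!(*local.get_or_init(|| "hello".to_string()), "hello");
    /// // later calls return the existing value; the closure is not run
    /// assert_eq!(*local.get_or_init(|| unreachable!()), "hello");
    /// ```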
    /// Iterate over currently live values and their associated thread ids.
    ///
    /// New threads that have been spawned after the snapshot was taken will not be present
    /// in the iterator.
    /// Threads that die after the snapshot is taken may or may not be present.
    /// Values from threads that die before the snapshot will not be present.
    ///
    /// The order of the iteration is undefined.
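    /// # Example
    /// A minimal sketch; in a process with no other threads,
    /// the snapshot contains only the current thread's value:
    /// ```
    /// use dropping_thread_local::DroppingThreadLocal;
    ///
    /// let local: DroppingThreadLocal<u32> = DroppingThreadLocal::new();
    /// local.get_or_init(|| 7);
    /// let values: Vec<u32> = local.snapshot_iter().map(|shared| *shared).collect();
    /// assert_eq!(values, [7]);
    /// ```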
    pub fn snapshot_iter(&self) -> SnapshotIter<T> {
        let Some(snapshot) = snapshot_live_threads() else {
            return SnapshotIter {
                local_id: self.id,
                iter: None,
                marker: PhantomData,
            };
        };
        SnapshotIter {
            local_id: self.id,
            iter: Some(snapshot.into_iter()),
            marker: PhantomData,
        }
    }
}
impl<T: Send + Sync + 'static> Drop for DroppingThreadLocal<T> {
    fn drop(&mut self) {
        // want to drop without holding the lock
        let Some(snapshot) = snapshot_live_threads() else {
            // no live threads -> nothing to free
            return;
        };
        // panics won't cause aborts here, there is no need for them
        for (thread_id, thread) in snapshot {
            if let Some(thread) = Weak::upgrade(&thread) {
                assert_eq!(thread.id, thread_id);
                let value: Option<DynArc> = {
                    let mut lock = thread.values.lock();
                    lock.remove(&self.id)
                };
                // drop value once lock no longer held
                drop(value);
            }
        }
    }
}
/// Iterates over a snapshot of the values,
/// given by [`DroppingThreadLocal::snapshot_iter`].
///
/// Due to thread death, it is not possible to know the exact size of the iterator.
pub struct SnapshotIter<T: Send + Sync + 'static> {
    local_id: UniqueLocalId,
    iter: Option<imbl::hashmap::ConsumingIter<(ThreadId, Weak<LiveThreadState>), imbl::shared_ptr::DefaultSharedPtr>>,
    // do not make Send+Sync, for flexibility in the future
    marker: PhantomData<Rc<T>>,
}
impl<T: Send + Sync + 'static> Iterator for SnapshotIter<T> {
    type Item = SharedRef<T>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // None from either of these means we have nothing left to iterate
            let (thread_id, thread) = self.iter.as_mut()?.next()?;
            let Some(thread) = Weak::upgrade(&thread) else { continue };
            let Some(arc) = ({
                let lock = thread.values.lock();
                lock.get(&self.local_id).cloned()
            }) else {
                continue;
            };
            return Some(SharedRef {
                thread_id,
                value: arc.downcast::<T>().expect("mismatched type"),
            });
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self.iter {
            Some(ref iter) => {
                // may be zero if all threads die
                (0, Some(iter.len()))
            }
            None => (0, Some(0)),
        }
    }
}
impl<T: Send + Sync + 'static> core::iter::FusedIterator for SnapshotIter<T> {}

/// A shared reference to a thread local value.
///
/// This may be cloned and sent across threads.
/// May delay destruction of value past thread death.
#[derive(Clone, Debug)]
pub struct SharedRef<T> {
    thread_id: ThreadId,
    value: Arc<T>,
}
impl<T> SharedRef<T> {
    /// The id of the thread whose local data this value came from.
    #[inline]
    pub fn thread_id(this: &Self) -> ThreadId {
        this.thread_id
    }
}

impl<T> Deref for SharedRef<T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}
impl<T> AsRef<T> for SharedRef<T> {
    #[inline]
    fn as_ref(&self) -> &T {
        &self.value
    }
}
impl<T: Hash> Hash for SharedRef<T> {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.value.hash(state);
    }
}
impl<T: Eq> Eq for SharedRef<T> {}
impl<T: PartialEq> PartialEq for SharedRef<T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
    }
}
impl<T: PartialOrd> PartialOrd for SharedRef<T> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        T::partial_cmp(&self.value, &other.value)
    }
}
impl<T: Ord> Ord for SharedRef<T> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        T::cmp(&self.value, &other.value)
    }
}

struct UniqueIdAllocator {
    next_id: AtomicU64,
}
impl UniqueIdAllocator {
    const fn new() -> Self {
        UniqueIdAllocator {
            next_id: AtomicU64::new(1),
        }
    }
    fn alloc(&self) -> NonZero<u64> {
        NonZero::new(
            self.next_id
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |x| x.checked_add(1))
                .expect("id overflow"),
        )
        .unwrap()
    }
}
#[derive(Copy, Clone, Debug, Eq, PartialOrd, PartialEq, Hash)]
struct UniqueLocalId(NonZero<u64>);
impl UniqueLocalId {
    fn alloc() -> Self {
        static ALLOCATOR: UniqueIdAllocator = UniqueIdAllocator::new();
        UniqueLocalId(ALLOCATOR.alloc())
    }
}

type LiveThreadMap = imbl::GenericHashMap<
    ThreadId,
    Weak<LiveThreadState>,
    foldhash::fast::RandomState,
    imbl::shared_ptr::DefaultSharedPtr,
>;
/// Map of currently live threads.
///
/// This is a persistent map, which allows quick snapshots to be taken by the drop
/// implementation and by iteration.
/// I use `imbl` instead of `rpds`, because `rpds` doesn't support an owned iterator,
/// only a borrowed iterator, which would require an extra allocation.
/// I use a hashmap instead of a `BTreeMap` because that would require `ThreadId: Ord`,
/// which the stdlib doesn't have.
static LIVE_THREADS: Mutex<Option<LiveThreadMap>> = Mutex::new(None);
fn snapshot_live_threads() -> Option<LiveThreadMap> {
    let lock = LIVE_THREADS.lock();
    lock.as_ref().cloned()
}

thread_local! {
    static THREAD_STATE: Arc<LiveThreadState> = {
        let id = std::thread::current().id();
        let state = Arc::new(LiveThreadState {
            id,
            values: Mutex::new(foldhash::HashMap::default()),
        });
        let mut live_threads = LIVE_THREADS.lock();
        let live_threads = live_threads.get_or_insert_default();
        use imbl::hashmap::Entry;
        match live_threads.entry(id) {
            Entry::Occupied(_) => panic!("reinitialized thread"),
            Entry::Vacant(entry) => {
                entry.insert(Arc::downgrade(&state));
            }
        }
        state
    };
}
type DynArc = Arc<dyn Any + Send + Sync + 'static>;
struct LiveThreadState {
    id: ThreadId,
    /// Maps from local ids to values.
    ///
    /// ## Performance
    /// This is noticeably slower than what `thread_local` uses.
    ///
    /// We could make it faster by using a vector.
    /// A `boxcar::Vec` is essentially the same data structure as `thread_local` uses,
    /// and would allow us to get rid of the lock;
    /// the only difference is that we would be mapping from local ids to values.
    /// However, the lock is uncontended and parking_lot makes that relatively cheap,
    /// so a `Mutex<Vec<T>>` might also work and be simpler.
    /// To avoid unbounded memory usage when locals are constantly being allocated/dropped,
    /// this would require reusing indexes.
    values: Mutex<foldhash::HashMap<UniqueLocalId, DynArc>>,
}
impl LiveThreadState {
    #[inline]
    fn get(&self, id: UniqueLocalId) -> Option<DynArc> {
        let lock = self.values.lock();
        Some(Arc::clone(lock.get(&id)?))
    }
    // the arc has dynamic type to avoid monomorphization
    #[cold]
    #[inline(never)]
    fn init(&self, id: UniqueLocalId, init: &mut dyn FnMut() -> DynArc) -> DynArc {
        let new_value = init();
        let mut lock = self.values.lock();
        use std::collections::hash_map::Entry;
        match lock.entry(id) {
            Entry::Occupied(_) => {
                panic!("unexpected double initialization of thread-local value")
            }
            Entry::Vacant(entry) => {
                entry.insert(Arc::clone(&new_value));
            }
        }
        new_value
    }
}
impl Drop for LiveThreadState {
    fn drop(&mut self) {
        // clear all our values
        self.values.get_mut().clear();
        // remove from the list of live threads
        {
            let mut threads = LIVE_THREADS.lock();
            if let Some(threads) = threads.as_mut() {
                // no fear of dropping while locked because we control the type
                let state: Option<Weak<LiveThreadState>> = threads.remove(&self.id);
                drop(state);
            }
        }
    }
}