dropping_thread_local/lib.rs
1#![deny(
2 // currently there is no unsafe code
3 unsafe_code,
4 // public library should have docs
5 missing_docs,
6)]
7//! Dynamically allocated thread locals that properly run destructors when a thread is destroyed.
8//!
9//! This is in contrast to the [`thread_local`] crate, which has similar functionality,
10//! but only runs destructors when the `ThreadLocal` object is dropped.
11//! This crate guarantees that one thread will never see the thread-local data of another,
12//! which can happen in the `thread_local` crate due to internal storage reuse.
13//!
14//! This crate attempts to implement "true" thread locals,
15//! mirroring [`std::thread_local!`] as closely as possible.
16//! I would say the `thread_local` crate is good for functionality like reusing allocations
17//! or for having local caches that can be sensibly reused once a thread dies.
18//!
19//! This crate will attempt to run destructors as promptly as possible,
20//! but taking snapshots may interfere with this (see below).
21//! Panics in thread destructors will cause aborts, just like they do with [`std::thread_local!`].
22//!
23//! Right now, this crate has no unsafe code.
24//! This may change if it can bring a significant performance improvement.
25//!
26//! # Snapshots
27//! The most complicated feature of this library is snapshots.
28//! It allows anyone who has access to a [`DroppingThreadLocal`] to iterate over all currently live
29//! values using the [`DroppingThreadLocal::snapshot_iter`] method.
30//!
31//! This will return a snapshot of the live values at the time the method is called,
32//! although if a thread dies during iteration, it may not show up.
33//! See the method documentation for more details.
34//!
35//! # Performance
36//! I expect the current implementation to be noticeably slower than either
37//! [`std::thread_local!`] or the [`thread_local`] crate.
38//! I have not done any benchmarks to compare the performance.
39//!
40//! ## Locking
41//! The implementation needs to acquire a global lock to initialize/deinitialize threads and create new locals.
42//! Accessing thread-local data is also protected by a per-thread lock.
43//! This lock should be uncontended, and [`parking_lot::Mutex`] should make this relatively fast.
44//! I have been careful to make sure that locks are not held while user code is being executed.
45//! This includes releasing locks before any destructors are executed.
46//!
47//! # Limitations
48//! The type that is stored must be `Send + Sync + 'static`.
49//! The `Send` bound is necessary because the [`DroppingThreadLocal`] may be dropped from any thread.
50//! The `Sync` bound is necessary to support snapshots,
//! and the `'static` bound is due to internal implementation choices (use of safe code).
52//!
//! A Mutex can be used to work around the `Sync` limitation.
//! (I recommend [`parking_lot::Mutex`], which is optimized for uncontended locks.)
//! You can attempt to use the [`fragile`] crate to work around the `Send` limitation,
//! but this will cause panics if the value is dropped from another thread.
//! A value can be dropped from another thread if a snapshot keeps the value alive,
//! or if the [`DroppingThreadLocal`] itself is dropped.
59//!
60//! [`thread_local`]: https://docs.rs/thread_local/1.1/thread_local/
61//! [`fragile`]: https://docs.rs/fragile/2/fragile/
62
63extern crate alloc;
64extern crate core;
65
66use alloc::rc::Rc;
67use alloc::sync::{Arc, Weak};
68use core::any::Any;
69use core::fmt::{Debug, Formatter};
70use core::hash::{Hash, Hasher};
71use core::marker::PhantomData;
72use core::num::NonZero;
73use core::ops::Deref;
74use core::sync::atomic::Ordering;
75use std::thread::ThreadId;
76
77use parking_lot::Mutex;
78use portable_atomic::AtomicU64;
79
/// A thread local that drops its value when the thread is destroyed.
///
/// See module-level documentation for more details.
///
/// Dropping this value will free all the associated values.
pub struct DroppingThreadLocal<T: Send + Sync + 'static> {
    // Process-wide unique id; keys this local's entry in each thread's value map.
    id: UniqueLocalId,
    // The values are stored (type-erased) in per-thread maps; this marker records
    // that we logically own `Arc<T>`s for variance/auto-trait purposes.
    marker: PhantomData<Arc<T>>,
}
89impl<T: Send + Sync + 'static + Debug> Debug for DroppingThreadLocal<T> {
90 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
91 let value = self.get();
92 f.debug_struct("DroppingThreadLocal")
93 .field("local_data", &value.as_ref().map(|value| value.as_ref()))
94 .finish()
95 }
96}
97impl<T: Send + Sync + 'static> Default for DroppingThreadLocal<T> {
98 #[inline]
99 fn default() -> Self {
100 DroppingThreadLocal::new()
101 }
102}
impl<T: Send + Sync + 'static> DroppingThreadLocal<T> {
    /// Create a new thread-local value.
    ///
    /// Only allocates a fresh process-wide id; no per-thread storage
    /// is touched until the first access.
    #[inline]
    pub fn new() -> Self {
        DroppingThreadLocal {
            id: UniqueLocalId::alloc(),
            marker: PhantomData,
        }
    }
    /// Get the value associated with the current thread,
    /// or `None` if not initialized.
    #[inline]
    pub fn get(&self) -> Option<SharedRef<T>> {
        THREAD_STATE.with(|thread| {
            Some(SharedRef {
                thread_id: thread.id,
                // `?` exits the closure with `None` when no value is stored for this local;
                // the downcast cannot fail because `self.id` only ever maps to a `T`
                value: thread.get(self.id)?.downcast::<T>().expect("unexpected type"),
            })
        })
    }
    /// Get the value associated with the current thread,
    /// initializing it if not yet defined.
    ///
    /// `func` runs without any internal lock held (see `LiveThreadState::init`).
    #[inline]
    pub fn get_or_init(&self, func: impl FnOnce() -> T) -> SharedRef<T> {
        THREAD_STATE.with(|thread| {
            let value = thread
                .get(self.id)
                .unwrap_or_else(|| {
                    // `init` takes a `&mut dyn FnMut`, so smuggle the `FnOnce` through
                    // an `Option`; `init` calls it exactly once, so `take` cannot fail.
                    let mut func = Some(func);
                    thread.init(self.id, &mut || {
                        let func = func.take().unwrap();
                        // erase the concrete type so `init` stays un-monomorphized
                        Arc::new(func()) as Arc<dyn Any + Send + Sync + 'static>
                    })
                })
                .downcast::<T>()
                .expect("unexpected type");
            SharedRef {
                thread_id: thread.id,
                value,
            }
        })
    }
    /// Iterate over currently live values and their associated thread ids.
    ///
    /// New threads that have been spawned after the snapshot was taken will not be present
    /// in the iterator.
    /// Threads that die after the snapshot is taken may or may not be present.
    /// Values from threads that die before the snapshot will not be present.
    ///
    /// The order of the iteration is undefined.
    pub fn snapshot_iter(&self) -> SnapshotIter<T> {
        let Some(snapshot) = snapshot_live_threads() else {
            // no thread has ever been registered -> an empty iterator
            return SnapshotIter {
                local_id: self.id,
                iter: None,
                marker: PhantomData,
            };
        };
        SnapshotIter {
            local_id: self.id,
            iter: Some(snapshot.into_iter()),
            marker: PhantomData,
        }
    }
}
impl<T: Send + Sync + 'static> Drop for DroppingThreadLocal<T> {
    /// Removes this local's value from every currently live thread.
    fn drop(&mut self) {
        // want to drop without holding the lock
        let Some(snapshot) = snapshot_live_threads() else {
            // no live threads -> nothing to free
            return;
        };
        // panics won't cause aborts here, there is no need for them
        for (thread_id, thread) in snapshot {
            // a failed upgrade means the thread died after the snapshot was taken
            if let Some(thread) = Weak::upgrade(&thread) {
                assert_eq!(thread.id, thread_id);
                // scope the per-thread lock so it is released before the value is dropped
                let value: Option<DynArc> = {
                    let mut lock = thread.values.lock();
                    lock.remove(&self.id)
                };
                // drop value once lock no longer held
                drop(value);
            }
        }
    }
}
/// Iterates over a snapshot of the values,
/// given by [`DroppingThreadLocal::snapshot_iter`].
///
/// Due to thread death, it is not possible to know the exact size of the iterator.
pub struct SnapshotIter<T: Send + Sync + 'static> {
    // Id of the local being iterated; used to look up the value in each thread's map.
    local_id: UniqueLocalId,
    // Owned iterator over the snapshot of live threads;
    // `None` means no thread was ever registered (always-empty iterator).
    iter: Option<imbl::hashmap::ConsumingIter<(ThreadId, Weak<LiveThreadState>), imbl::shared_ptr::DefaultSharedPtr>>,
    // do not make Send+Sync, for flexibility in the future
    marker: PhantomData<Rc<T>>,
}
impl<T: Send + Sync + 'static> Iterator for SnapshotIter<T> {
    type Item = SharedRef<T>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // None from either of these means we have nothing left to iterate
            let (thread_id, thread) = self.iter.as_mut()?.next()?;
            // skip threads that died since the snapshot was taken
            let Some(thread) = Weak::upgrade(&thread) else { continue };
            // clone the Arc inside a scoped lock so the lock is released before returning
            let Some(arc) = ({
                let lock = thread.values.lock();
                lock.get(&self.local_id).cloned()
            }) else {
                // this thread never initialized (or already dropped) this local
                continue;
            };
            return Some(SharedRef {
                thread_id,
                value: arc.downcast::<T>().expect("mismatched type"),
            });
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self.iter {
            Some(ref iter) => {
                // may be zero if all threads die
                (0, Some(iter.len()))
            }
            None => (0, Some(0)),
        }
    }
}
// `next` returns `None` once `self.iter` is exhausted; this marker assumes the
// underlying `imbl` consuming iterator stays exhausted after returning `None` — TODO confirm.
impl<T: Send + Sync + 'static> core::iter::FusedIterator for SnapshotIter<T> {}
233
/// A shared reference to a thread local value.
///
/// This may be cloned and sent across threads.
/// May delay destruction of value past thread death.
#[derive(Clone, Debug)]
pub struct SharedRef<T> {
    // Id of the thread whose local storage this value came from.
    thread_id: ThreadId,
    // Shared ownership of the value; keeps it alive even after the owning thread dies.
    value: Arc<T>,
}
impl<T> SharedRef<T> {
    /// The id of the thread this value was created on.
    ///
    /// Written as an associated function (`SharedRef::thread_id(&x)`) rather than
    /// a method, so it cannot shadow a `thread_id` method on `T` reached via `Deref`.
    #[inline]
    pub fn thread_id(this: &Self) -> ThreadId {
        this.thread_id
    }
}
250
251impl<T> Deref for SharedRef<T> {
252 type Target = T;
253
254 #[inline]
255 fn deref(&self) -> &Self::Target {
256 &self.value
257 }
258}
259impl<T> AsRef<T> for SharedRef<T> {
260 #[inline]
261 fn as_ref(&self) -> &T {
262 &self.value
263 }
264}
265impl<T: Hash> Hash for SharedRef<T> {
266 #[inline]
267 fn hash<H: Hasher>(&self, state: &mut H) {
268 self.value.hash(state);
269 }
270}
// Sound because equality delegates to `T`, whose `Eq` guarantees reflexivity.
impl<T: Eq> Eq for SharedRef<T> {}
272impl<T: PartialEq> PartialEq for SharedRef<T> {
273 #[inline]
274 fn eq(&self, other: &Self) -> bool {
275 self.value == other.value
276 }
277}
278impl<T: PartialOrd> PartialOrd for SharedRef<T> {
279 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
280 T::partial_cmp(&self.value, &other.value)
281 }
282}
283impl<T: Ord> Ord for SharedRef<T> {
284 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
285 T::cmp(&self.value, &other.value)
286 }
287}
288
/// Process-wide monotonic id source backed by an atomic counter.
struct UniqueIdAllocator {
    // Next id to hand out; starts at 1 so allocated ids are always nonzero.
    next_id: AtomicU64,
}
292impl UniqueIdAllocator {
293 const fn new() -> Self {
294 UniqueIdAllocator {
295 next_id: AtomicU64::new(1),
296 }
297 }
298 fn alloc(&self) -> NonZero<u64> {
299 NonZero::new(
300 self.next_id
301 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |x| x.checked_add(1))
302 .expect("id overflow"),
303 )
304 .unwrap()
305 }
306}
/// Process-wide unique identifier for a [`DroppingThreadLocal`].
#[derive(Copy, Clone, Debug, Eq, PartialOrd, PartialEq, Hash)]
struct UniqueLocalId(NonZero<u64>);
impl UniqueLocalId {
    /// Allocate a fresh id from a single global counter.
    fn alloc() -> Self {
        static ALLOCATOR: UniqueIdAllocator = UniqueIdAllocator::new();
        UniqueLocalId(ALLOCATOR.alloc())
    }
}
315
/// Persistent (cheaply clonable) map from thread id to that thread's state.
///
/// Entries are `Weak` so the registry never keeps a dead thread's state alive.
type LiveThreadMap = imbl::GenericHashMap<
    ThreadId,
    Weak<LiveThreadState>,
    foldhash::fast::RandomState,
    imbl::shared_ptr::DefaultSharedPtr,
>;
/// Map of currently live threads.
///
/// This is a persistent map to allow quick snapshots to be taken by the drop function and iteration.
/// I use `imbl` instead of `rpds`, because `rpds` doesn't support an owned iterator,
/// only a borrowed iterator which would require an extra allocation.
/// I use a hashmap instead of a `BTreeMap` because that would require `ThreadId: Ord`,
/// which the stdlib doesn't have.
static LIVE_THREADS: Mutex<Option<LiveThreadMap>> = Mutex::new(None);
330fn snapshot_live_threads() -> Option<LiveThreadMap> {
331 let lock = LIVE_THREADS.lock();
332 lock.as_ref().cloned()
333}
334
thread_local! {
    // Per-thread state, registered in `LIVE_THREADS` on first use and
    // deregistered (via `LiveThreadState`'s `Drop`) when the thread dies.
    static THREAD_STATE: Arc<LiveThreadState> = {
        let id = std::thread::current().id();
        let state =Arc::new(LiveThreadState {
            id,
            values: Mutex::new(foldhash::HashMap::default()),
        });
        // Register under the global lock; only a weak reference is stored so that
        // dropping `state` (thread death) is what removes the thread's values.
        let mut live_threads = LIVE_THREADS.lock();
        let live_threads = live_threads.get_or_insert_default();
        use imbl::hashmap::Entry;
        match live_threads.entry(id) {
            // a live thread must never already be registered under its own id
            Entry::Occupied(_) => panic!("reinitialized thread"),
            Entry::Vacant(entry) => {
                entry.insert(Arc::downgrade(&state));
            }
        }
        state
    };
}
/// Type-erased, shareable value as stored per (thread, local) pair.
type DynArc = Arc<dyn Any + Send + Sync + 'static>;
/// State owned by a single live thread (held in [`THREAD_STATE`]).
struct LiveThreadState {
    // Id of the owning thread; also the key this state is registered under.
    id: ThreadId,
    /// Maps from local ids to values.
    ///
    /// ## Performance
    /// This is noticeably slower than what `thread_local` does.
    ///
    /// We could make it faster if we used a vector.
    /// A `boxcar::Vec` is essentially the same data structure as `thread_local` uses,
    /// and would allow us to get rid of the lock.
    /// The only difference is we would be mapping from local ids -> values.
    /// However, the lock is uncontended and parking_lot makes that relatively cheap,
    /// so a `Mutex<Vec<T>>` might also work and be simpler.
    /// To avoid unbounded memory usage if locals are constantly being allocated/dropped,
    /// this would require reusing indexes.
    values: Mutex<foldhash::HashMap<UniqueLocalId, DynArc>>,
}
impl LiveThreadState {
    /// Look up the value for `id`, cloning the `Arc` under the per-thread lock.
    #[inline]
    fn get(&self, id: UniqueLocalId) -> Option<DynArc> {
        let lock = self.values.lock();
        Some(Arc::clone(lock.get(&id)?))
    }
    // the arc has dynamic type to avoid monomorphization
    /// Store the value produced by `init` for `id` and return it.
    ///
    /// `init` (user code) runs *before* the lock is taken, so no lock is held
    /// while it executes. Panics if a value for `id` is already present.
    #[cold]
    #[inline(never)]
    fn init(&self, id: UniqueLocalId, init: &mut dyn FnMut() -> DynArc) -> DynArc {
        let new_value = init();
        let mut lock = self.values.lock();
        use std::collections::hash_map::Entry;
        match lock.entry(id) {
            Entry::Occupied(_) => {
                panic!("unexpected double initialization of thread-local value")
            }
            Entry::Vacant(entry) => {
                entry.insert(Arc::clone(&new_value));
            }
        }
        new_value
    }
}
impl Drop for LiveThreadState {
    /// Runs when the last strong reference (normally the thread-local slot
    /// in [`THREAD_STATE`]) is dropped, i.e. when the owning thread dies.
    fn drop(&mut self) {
        // clear all our values
        // (`get_mut` takes no lock: `&mut self` proves exclusive access,
        // so value destructors run without any lock held)
        self.values.get_mut().clear();
        // remove from the list of live threads
        {
            let mut threads = LIVE_THREADS.lock();
            if let Some(threads) = threads.as_mut() {
                // no fear of dropping while locked because we control the type
                let state: Option<Weak<LiveThreadState>> = threads.remove(&self.id);
                drop(state)
            }
        }
    }
}