dropping_thread_local/lib.rs
#![deny(
    // currently there is no unsafe code
    unsafe_code,
    // public library should have docs
    missing_docs,
)]
//! Dynamically allocated thread locals that properly run destructors when a thread is destroyed.
//!
//! This is in contrast to the [`thread_local`] crate, which has similar functionality,
//! but only runs destructors when the `ThreadLocal` object is dropped.
//! This crate guarantees that one thread will never see the thread-local data of another,
//! which can happen in the `thread_local` crate due to internal storage reuse.
//!
//! This crate attempts to implement "true" thread locals,
//! mirroring [`std::thread_local!`] as closely as possible.
//! I would say the `thread_local` crate is good for functionality like reusing allocations
//! or for having local caches that can be sensibly reused once a thread dies.
//!
//! This crate will attempt to run destructors as promptly as possible,
//! but taking snapshots may interfere with this (see below).
//! Panics in thread destructors will cause aborts, just like they do with [`std::thread_local!`].
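//!
//! The destructor behavior this crate mirrors can be seen with [`std::thread_local!`] alone.
//! This standalone sketch (using only the standard library, not this crate's API)
//! shows a value being dropped when its owning thread exits:
//!
//! ```rust
//! use std::sync::atomic::{AtomicUsize, Ordering};
//!
//! static DROPS: AtomicUsize = AtomicUsize::new(0);
//!
//! struct Tracked;
//! impl Drop for Tracked {
//!     fn drop(&mut self) {
//!         DROPS.fetch_add(1, Ordering::SeqCst);
//!     }
//! }
//!
//! std::thread_local! {
//!     static LOCAL: Tracked = Tracked;
//! }
//!
//! fn main() {
//!     let handle = std::thread::spawn(|| {
//!         LOCAL.with(|_| ()); // force initialization on the spawned thread
//!     });
//!     handle.join().unwrap();
//!     // the spawned thread's value was dropped when that thread exited
//!     assert_eq!(DROPS.load(Ordering::SeqCst), 1);
//! }
//! ```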
//!
//! Right now, this crate has no unsafe code.
//! This may change if it can bring a significant performance improvement.
//!
//! # Snapshots
//! The most complicated feature of this library is snapshots.
//! They allow anyone who has access to a [`DroppingThreadLocal`] to iterate over all currently live
//! values using the [`DroppingThreadLocal::snapshot_iter`] method.
//!
//! This will return a snapshot of the live values at the time the method is called,
//! although if a thread dies during iteration, its value may not show up.
//! See the method documentation for more details.
//!
//! # Performance
//! Lookup is based around a hashmap, and I expect the current implementation to be noticeably slower than either [`std::thread_local!`] or the [`thread_local`] crate.
//! The former is a low-cost abstraction over native thread-local storage, and the latter is written by the author of `hashbrown` and `parking_lot`.
//!
//! A very basic benchmark on my M1 Mac and Linux laptop (`i5-7200U`, circa 2017) gives the following results:
//!
//! | library                 | does `Arc::clone` | time (M1 Mac) | time (i5 ~2017 laptop) |
//! |-------------------------|-------------------|---------------|------------------------|
//! | `std`                   | no                | 0.42 ns       | 0.69 ns                |
//! | `std`                   | *yes*             | 11.49 ns      | 14.01 ns               |
//! | `thread_local`          | no                | 1.38 ns       | 1.38 ns                |
//! | `thread_local`          | *yes*             | 11.43 ns      | 14.02 ns               |
//! | `dropping_thread_local` | *yes*             | 13.14 ns      | 31.14 ns               |
//!
//! Every lookup in the current implementation of `dropping_thread_local` requires calling `Arc::clone`.
//! This has significant overhead in its own right, so I benchmarked the other libraries both storing their data in a regular `Box` and storing it in an `Arc` and doing `Arc::clone`.
//!
//! On my Mac, this library ranges between 30% slower than calling `thread_local::ThreadLocal::get` + `Arc::clone` and 30x slower than a plain `std::thread_local!`. On my older Linux laptop, it ranges between 3x slower than `thread_local::ThreadLocal::get` + `Arc::clone` and 60x slower than a plain `std::thread_local!`.
//!
//! This performance is a lot better than I expected (at least on the MacBook). I am also disappointed by the performance of `Arc::clone`. Further improvements beyond this will almost certainly require some amount of `unsafe` code. I have three ideas for improvement:
//!
//! - Avoid requiring `Arc::clone` by using a [`LocalKey::with`]-style API, and making `drop(DroppingThreadLocal)` delay freeing values from live threads until after those threads die.
//! - Use [biased reference counting] instead of an `Arc`. This would not require `unsafe` code directly in this crate. Unfortunately, biased reference counting can delay destruction or even leak if the heartbeat function is not called. The [`trc` crate] will not work, as `trc::Trc` is `!Send`.
//! - Use [`boxcar::Vec`] instead of a `HashMap` for lookup. This is essentially the same data structure that the `thread_local` crate uses, so it should make lookup performance similar.
//!
//! *NOTE*: Simply removing `Arc::clone` doesn't help that much. On the M1 Mac it reduces the time to 9 ns. (I did this unsoundly, so it can't be published.)
//!
//! [biased reference counting]: https://dl.acm.org/doi/10.1145/3243176.3243195
//! [`LocalKey::with`]: https://doc.rust-lang.org/std/thread/struct.LocalKey.html#method.with
//! [`boxcar::Vec`]: https://docs.rs/boxcar/0.2.13/boxcar/struct.Vec.html
//! [`trc` crate]: https://github.com/ericlbuehler/trc
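//!
//! For a concrete picture of the per-lookup cost being measured, the rows marked *yes* above
//! pay roughly the following `Arc::clone`-per-access pattern (sketched here with
//! `std::thread_local!` and only the standard library; the names are illustrative):
//!
//! ```rust
//! use std::sync::Arc;
//!
//! std::thread_local! {
//!     static LOCAL: Arc<String> = Arc::new(String::from("per-thread data"));
//! }
//!
//! // every lookup clones the `Arc`, paying an atomic reference-count increment
//! fn get() -> Arc<String> {
//!     LOCAL.with(Arc::clone)
//! }
//!
//! fn main() {
//!     let a = get();
//!     let b = get();
//!     // both clones point at the same per-thread allocation
//!     assert!(Arc::ptr_eq(&a, &b));
//! }
//! ```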
//!
//! ## Locking
//! The implementation needs to acquire a global lock to initialize/deinitialize threads and to create new locals.
//! Accessing thread-local data is also protected by a per-thread lock.
//! This lock should be uncontended, and [`parking_lot::Mutex`] should make this relatively fast.
//! I have been careful to make sure that locks are not held while user code is being executed.
//! This includes releasing locks before any destructors are executed.
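//!
//! In miniature, the release-locks-before-destructors pattern looks like this
//! (a standalone sketch using `std::sync::Mutex`; the crate itself uses `parking_lot`):
//!
//! ```rust
//! use std::collections::HashMap;
//! use std::sync::Mutex;
//!
//! // remove the value while holding the lock, but drop it afterwards,
//! // so a user-provided `Drop` impl can never run while the map is locked
//! fn remove_then_drop(map: &Mutex<HashMap<u64, String>>, key: u64) {
//!     let value = {
//!         let mut guard = map.lock().unwrap();
//!         guard.remove(&key)
//!     }; // guard is released here
//!     drop(value); // the user-visible destructor runs without the lock held
//! }
//!
//! fn main() {
//!     let map = Mutex::new(HashMap::from([(1, String::from("hello"))]));
//!     remove_then_drop(&map, 1);
//!     assert!(map.lock().unwrap().is_empty());
//! }
//! ```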
//!
//! # Limitations
//! The type that is stored must be `Send + Sync + 'static`.
//! The `Send` bound is necessary because the [`DroppingThreadLocal`] may be dropped from any thread.
//! The `Sync` bound is necessary to support snapshots,
//! and the `'static` bound is due to internal implementation choices (use of safe code).
//!
//! A `Mutex` can be used to work around the `Sync` limitation.
//! (I recommend [`parking_lot::Mutex`], which is optimized for uncontended locks.)
//! You can attempt to use the [`fragile`] crate to work around the `Send` limitation,
//! but this will cause panics if the value is dropped from another thread.
//! A value can be dropped from another thread if a snapshot keeps it alive,
//! or if the [`DroppingThreadLocal`] itself is dropped.
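//!
//! The `Mutex` workaround for the `Sync` bound can be sketched like this
//! (using `std::sync::Mutex` to keep the example self-contained; `parking_lot::Mutex` works the same way):
//!
//! ```rust
//! use std::cell::Cell;
//! use std::sync::Mutex;
//!
//! // `Cell<u32>` is `Send` but not `Sync`; wrapping it in a `Mutex`
//! // makes the whole value `Sync`, satisfying a `Send + Sync + 'static` bound
//! fn assert_send_sync<T: Send + Sync + 'static>(_: &T) {}
//!
//! fn main() {
//!     let wrapped = Mutex::new(Cell::new(0u32));
//!     assert_send_sync(&wrapped);
//!     wrapped.lock().unwrap().set(42);
//!     assert_eq!(wrapped.lock().unwrap().get(), 42);
//! }
//! ```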
//!
//! [`thread_local`]: https://docs.rs/thread_local/1.1/thread_local/
//! [`fragile`]: https://docs.rs/fragile/2/fragile/

extern crate alloc;
extern crate core;

use alloc::rc::Rc;
use alloc::sync::{Arc, Weak};
use core::any::Any;
use core::fmt::{Debug, Formatter};
use core::hash::{Hash, Hasher};
use core::marker::PhantomData;
use core::num::NonZero;
use core::ops::Deref;
use core::sync::atomic::Ordering;
use std::thread::ThreadId;

use parking_lot::Mutex;
use portable_atomic::AtomicU64;
/// A thread local that drops its value when the thread is destroyed.
///
/// See module-level documentation for more details.
///
/// Dropping this value will free all the associated values.
pub struct DroppingThreadLocal<T: Send + Sync + 'static> {
    id: UniqueLocalId,
    marker: PhantomData<Arc<T>>,
}
impl<T: Send + Sync + 'static + Debug> Debug for DroppingThreadLocal<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let value = self.get();
        f.debug_struct("DroppingThreadLocal")
            .field("local_data", &value.as_ref().map(|value| value.as_ref()))
            .finish()
    }
}
impl<T: Send + Sync + 'static> Default for DroppingThreadLocal<T> {
    #[inline]
    fn default() -> Self {
        DroppingThreadLocal::new()
    }
}
impl<T: Send + Sync + 'static> DroppingThreadLocal<T> {
    /// Create a new thread-local value.
    #[inline]
    pub fn new() -> Self {
        DroppingThreadLocal {
            id: UniqueLocalId::alloc(),
            marker: PhantomData,
        }
    }
    /// Get the value associated with the current thread,
    /// or `None` if not initialized.
    #[inline]
    pub fn get(&self) -> Option<SharedRef<T>> {
        THREAD_STATE.with(|thread| {
            Some(SharedRef {
                thread_id: thread.id,
                value: thread.get(self.id)?.downcast::<T>().expect("unexpected type"),
            })
        })
    }
    /// Set the value associated with this thread,
    /// returning an `Err(existing_val)` if already initialized,
    /// and a reference to the new value if successful.
    pub fn set(&self, value: T) -> Result<SharedRef<T>, SharedRef<T>> {
        if let Some(existing) = self.get() {
            Err(existing)
        } else {
            THREAD_STATE.with(|thread| {
                let new_value = Arc::new(value) as DynArc;
                thread.init(self.id, &new_value);
                Ok(SharedRef {
                    thread_id: thread.id,
                    value: new_value.downcast::<T>().unwrap(),
                })
            })
        }
    }
    /// Get the value associated with the current thread,
    /// initializing it if not yet defined.
    ///
    /// Panics if double initialization is detected.
    pub fn get_or_init(&self, func: impl FnOnce() -> T) -> SharedRef<T> {
        match self.get_or_try_init::<core::convert::Infallible>(|| Ok(func())) {
            Ok(success) => success,
        }
    }

    /// Get the value associated with the current thread,
    /// attempting to initialize it if not yet defined.
    ///
    /// Panics if double initialization is detected.
    pub fn get_or_try_init<E>(&self, func: impl FnOnce() -> Result<T, E>) -> Result<SharedRef<T>, E> {
        THREAD_STATE.with(|thread| {
            let value = match thread.get(self.id) {
                Some(existing) => existing,
                None => {
                    let new_value = Arc::new(func()?) as DynArc;
                    thread.init(self.id, &new_value);
                    new_value
                }
            };
            let value = value.downcast::<T>().expect("unexpected type");
            Ok(SharedRef {
                thread_id: thread.id,
                value,
            })
        })
    }
    /// Iterate over currently live values and their associated thread ids.
    ///
    /// New threads that have been spawned after the snapshot was taken will not be present
    /// in the iterator.
    /// Threads that die after the snapshot is taken may or may not be present.
    /// Values from threads that die before the snapshot will not be present.
    ///
    /// The order of the iteration is undefined.
    pub fn snapshot_iter(&self) -> SnapshotIter<T> {
        let Some(snapshot) = snapshot_live_threads() else {
            return SnapshotIter {
                local_id: self.id,
                iter: None,
                marker: PhantomData,
            };
        };
        SnapshotIter {
            local_id: self.id,
            iter: Some(snapshot.into_iter()),
            marker: PhantomData,
        }
    }
}
impl<T: Send + Sync + 'static> Drop for DroppingThreadLocal<T> {
    fn drop(&mut self) {
        // we want to drop values without holding the lock
        let Some(snapshot) = snapshot_live_threads() else {
            // no live threads -> nothing to free
            return;
        };
        // panics won't cause aborts here, there is no need for them
        for (thread_id, thread) in snapshot {
            if let Some(thread) = Weak::upgrade(&thread) {
                assert_eq!(thread.id, thread_id);
                let value: Option<DynArc> = {
                    let mut lock = thread.values.lock();
                    lock.remove(&self.id)
                };
                // drop the value once the lock is no longer held
                drop(value);
            }
        }
    }
}
/// Iterates over a snapshot of the values,
/// given by [`DroppingThreadLocal::snapshot_iter`].
///
/// Due to thread death, it is not possible to know the exact size of the iterator.
pub struct SnapshotIter<T: Send + Sync + 'static> {
    local_id: UniqueLocalId,
    iter: Option<imbl::hashmap::ConsumingIter<(ThreadId, Weak<LiveThreadState>), imbl::shared_ptr::DefaultSharedPtr>>,
    // do not make this Send + Sync, for flexibility in the future
    marker: PhantomData<Rc<T>>,
}
impl<T: Send + Sync + 'static> Iterator for SnapshotIter<T> {
    type Item = SharedRef<T>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // `None` from either of these means we have nothing left to iterate
            let (thread_id, thread) = self.iter.as_mut()?.next()?;
            let Some(thread) = Weak::upgrade(&thread) else { continue };
            let Some(arc) = ({
                let lock = thread.values.lock();
                lock.get(&self.local_id).cloned()
            }) else {
                continue;
            };
            return Some(SharedRef {
                thread_id,
                value: arc.downcast::<T>().expect("mismatched type"),
            });
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self.iter {
            Some(ref iter) => {
                // may be zero if all threads die
                (0, Some(iter.len()))
            }
            None => (0, Some(0)),
        }
    }
}
impl<T: Send + Sync + 'static> core::iter::FusedIterator for SnapshotIter<T> {}

/// A shared reference to a thread-local value.
///
/// This may be cloned and sent across threads.
/// It may delay destruction of the value past thread death.
#[derive(Clone, Debug)]
pub struct SharedRef<T> {
    thread_id: ThreadId,
    value: Arc<T>,
}
impl<T> SharedRef<T> {
    /// The id of the thread this value is associated with.
    #[inline]
    pub fn thread_id(this: &Self) -> ThreadId {
        this.thread_id
    }
}

impl<T> Deref for SharedRef<T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}
impl<T> AsRef<T> for SharedRef<T> {
    #[inline]
    fn as_ref(&self) -> &T {
        &self.value
    }
}
impl<T: Hash> Hash for SharedRef<T> {
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.value.hash(state);
    }
}
impl<T: Eq> Eq for SharedRef<T> {}
impl<T: PartialEq> PartialEq for SharedRef<T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
    }
}
impl<T: PartialOrd> PartialOrd for SharedRef<T> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        T::partial_cmp(&self.value, &other.value)
    }
}
impl<T: Ord> Ord for SharedRef<T> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        T::cmp(&self.value, &other.value)
    }
}
struct UniqueIdAllocator {
    next_id: AtomicU64,
}
impl UniqueIdAllocator {
    const fn new() -> Self {
        UniqueIdAllocator {
            next_id: AtomicU64::new(1),
        }
    }
    fn alloc(&self) -> NonZero<u64> {
        NonZero::new(
            self.next_id
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |x| x.checked_add(1))
                .expect("id overflow"),
        )
        .unwrap()
    }
}
#[derive(Copy, Clone, Debug, Eq, PartialOrd, PartialEq, Hash)]
struct UniqueLocalId(NonZero<u64>);
impl UniqueLocalId {
    fn alloc() -> Self {
        static ALLOCATOR: UniqueIdAllocator = UniqueIdAllocator::new();
        UniqueLocalId(ALLOCATOR.alloc())
    }
}

type LiveThreadMap = imbl::GenericHashMap<
    ThreadId,
    Weak<LiveThreadState>,
    foldhash::fast::RandomState,
    imbl::shared_ptr::DefaultSharedPtr,
>;
/// Map of currently live threads.
///
/// This is a persistent map to allow quick snapshots to be taken by the drop function and iteration.
/// I use `imbl` instead of `rpds` because `rpds` doesn't support an owned iterator,
/// only a borrowed iterator, which would require an extra allocation.
/// I use a hashmap instead of a `BTreeMap` because that would require `ThreadId: Ord`,
/// which the stdlib doesn't implement.
static LIVE_THREADS: Mutex<Option<LiveThreadMap>> = Mutex::new(None);
fn snapshot_live_threads() -> Option<LiveThreadMap> {
    let lock = LIVE_THREADS.lock();
    lock.as_ref().cloned()
}

thread_local! {
    static THREAD_STATE: Arc<LiveThreadState> = {
        let id = std::thread::current().id();
        let state = Arc::new(LiveThreadState {
            id,
            values: Mutex::new(foldhash::HashMap::default()),
        });
        let mut live_threads = LIVE_THREADS.lock();
        let live_threads = live_threads.get_or_insert_default();
        use imbl::hashmap::Entry;
        match live_threads.entry(id) {
            Entry::Occupied(_) => panic!("reinitialized thread"),
            Entry::Vacant(entry) => {
                entry.insert(Arc::downgrade(&state));
            }
        }
        state
    };
}
type DynArc = Arc<dyn Any + Send + Sync + 'static>;
struct LiveThreadState {
    id: ThreadId,
    /// Maps from local ids to values.
    ///
    /// ## Performance
    /// This is noticeably slower than what `thread_local` offers.
    ///
    /// We could make it faster if we used a vector.
    /// A `boxcar::Vec` is essentially the same data structure as `thread_local` uses,
    /// and would allow us to get rid of the lock.
    /// The only difference is we would be mapping from local ids -> values.
    /// However, the lock is uncontended and `parking_lot` makes that relatively cheap,
    /// so a `Mutex<Vec<T>>` might also work and be simpler.
    /// To avoid unbounded memory usage if locals are constantly being allocated/dropped,
    /// this would require reusing indexes.
    values: Mutex<foldhash::HashMap<UniqueLocalId, DynArc>>,
}
impl LiveThreadState {
    #[inline]
    fn get(&self, id: UniqueLocalId) -> Option<DynArc> {
        let lock = self.values.lock();
        Some(Arc::clone(lock.get(&id)?))
    }
    // the arc has dynamic type to avoid monomorphization
    #[cold]
    #[inline(never)]
    fn init(&self, id: UniqueLocalId, new_value: &DynArc) {
        let mut lock = self.values.lock();
        use std::collections::hash_map::Entry;
        match lock.entry(id) {
            Entry::Occupied(_) => {
                panic!("unexpected double initialization of thread-local value")
            }
            Entry::Vacant(entry) => {
                entry.insert(Arc::clone(new_value));
            }
        }
    }
}
impl Drop for LiveThreadState {
    fn drop(&mut self) {
        // clear all our values
        self.values.get_mut().clear();
        // remove ourselves from the list of live threads
        {
            let mut threads = LIVE_THREADS.lock();
            if let Some(threads) = threads.as_mut() {
                // no fear of dropping while locked because we control the type
                let state: Option<Weak<LiveThreadState>> = threads.remove(&self.id);
                drop(state);
            }
        }
    }
}