os_thread_local/
lib.rs

1// Parts of tests and documentation that were copied from the rust code-base
2// are Copyright The Rust Project Developers, and licensed under the MIT or
3// Apache-2.0, license, like the rest of this project. See the LICENSE-MIT and
4// LICENSE-APACHE files at the root of this crate.
5
6//! OS-backed thread-local storage
7//!
8//! This crate provides a [`ThreadLocal`] type as an alternative to
9//! `std::thread_local!` that allows per-object thread-local storage, while
10//! providing a similar API. It always uses the thread-local storage primitives
11//! provided by the OS.
12//!
13//! On Unix systems, pthread-based thread-local storage is used.
14//!
15//! On Windows, fiber-local storage is used. This acts like thread-local
16//! storage when fibers are unused, but also provides per-fiber values
17//! after fibers are created with e.g. `winapi::um::winbase::CreateFiber`.
18//!
19//! See [`ThreadLocal`] for more details.
20//!
21//!   [`ThreadLocal`]: struct.ThreadLocal.html
22//!
23//! The [`thread_local`](https://crates.io/crates/thread_local) crate also
24//! provides per-object thread-local storage, with a different API, and
25//! different features, but with more performance overhead than this one.
26
27#![deny(missing_docs)]
28
29use core::fmt;
30use core::ptr::NonNull;
31use std::boxed::Box;
32use std::error::Error;
33
#[cfg(windows)]
mod oskey {
    //! Fiber-local storage (FLS) backend for Windows.
    use winapi::um::fibersapi;

    pub(crate) type Key = winapi::shared::minwindef::DWORD;
    #[allow(non_camel_case_types)]
    pub(crate) type c_void = winapi::ctypes::c_void;

    /// Value returned by `FlsAlloc` when no FLS index could be allocated
    /// (`FLS_OUT_OF_INDEXES` in the Windows SDK). Defined here rather than
    /// imported so that no additional winapi module is required.
    const FLS_OUT_OF_INDEXES: Key = 0xFFFF_FFFF;

    /// Allocates a new FLS index, registering `dtor` as its callback.
    ///
    /// Panics if no index is available, like the pthread backend, which
    /// asserts that `pthread_key_create` succeeds.
    #[inline]
    pub(crate) unsafe fn create(dtor: Option<unsafe extern "system" fn(*mut c_void)>) -> Key {
        let key = fibersapi::FlsAlloc(dtor);
        // Without this check a failed allocation goes unnoticed: `set` only
        // debug-asserts, and `get` on the bogus index yields NULL. In release
        // builds every access would then re-run the initializer and leak the
        // freshly boxed value.
        assert_ne!(key, FLS_OUT_OF_INDEXES);
        key
    }

    /// Stores `value` in the current fiber's slot for `key`.
    #[inline]
    pub(crate) unsafe fn set(key: Key, value: *mut c_void) {
        let r = fibersapi::FlsSetValue(key, value);
        debug_assert_ne!(r, 0);
    }

    /// Reads the current fiber's value for `key`.
    #[inline]
    pub(crate) unsafe fn get(key: Key) -> *mut c_void {
        fibersapi::FlsGetValue(key)
    }

    /// Releases `key`.
    #[inline]
    pub(crate) unsafe fn destroy(key: Key) {
        let r = fibersapi::FlsFree(key);
        debug_assert_ne!(r, 0);
    }
}
64
65#[cfg(not(windows))]
66mod oskey {
67    use core::mem::{self, MaybeUninit};
68
69    pub(crate) type Key = libc::pthread_key_t;
70    #[allow(non_camel_case_types)]
71    pub(crate) type c_void = core::ffi::c_void;
72
73    #[inline]
74    pub(crate) unsafe fn create(dtor: Option<unsafe extern "system" fn(*mut c_void)>) -> Key {
75        let mut key = MaybeUninit::uninit();
76        assert_eq!(
77            libc::pthread_key_create(key.as_mut_ptr(), mem::transmute(dtor)),
78            0
79        );
80        key.assume_init()
81    }
82
83    #[inline]
84    pub(crate) unsafe fn set(key: Key, value: *mut c_void) {
85        let r = libc::pthread_setspecific(key, value);
86        debug_assert_eq!(r, 0);
87    }
88
89    #[inline]
90    pub(crate) unsafe fn get(key: Key) -> *mut c_void {
91        libc::pthread_getspecific(key)
92    }
93
94    #[inline]
95    pub(crate) unsafe fn destroy(key: Key) {
96        let r = libc::pthread_key_delete(key);
97        debug_assert_eq!(r, 0);
98    }
99}
100
101use oskey::c_void;
102
103/// A thread-local storage handle.
104///
105/// In many ways, this struct works similarly to [`std::thread::LocalKey`], but
106/// always relies on OS primitives (see [module-level documentation](index.html)).
107///
108/// The [`with`] method yields a reference to the contained value which cannot
109/// be sent across threads or escape the given closure.
110///
111/// # Initialization and Destruction
112///
113/// Initialization is dynamically performed on the first call to [`with`]
114/// within a thread, and values that implement [`Drop`] get destructed when a
115/// thread exits. Some caveats apply, which are explained below.
116///
117/// A `ThreadLocal`'s initializer cannot recursively depend on itself, and
118/// using a `ThreadLocal` in this way will cause the initializer to infinitely
119/// recurse on the first call to [`with`].
120///
121///   [`std::thread::LocalKey`]: https://doc.rust-lang.org/std/thread/struct.LocalKey.html
122///   [`with`]: #method.with
123///   [`Drop`]: https://doc.rust-lang.org/std/ops/trait.Drop.html
124///
125/// # Examples
126///
127/// This is the same as the example in the [`std::thread::LocalKey`] documentation,
128/// but adjusted to use `ThreadLocal` instead. To use it in a `static` context, a
129/// lazy initializer, such as [`once_cell::sync::Lazy`] or [`lazy_static!`] is
130/// required.
131///
132///   [`once_cell::sync::Lazy`]: https://docs.rs/once_cell/1.2.0/once_cell/sync/struct.Lazy.html
133///   [`lazy_static!`]: https://docs.rs/lazy_static/1.4.0/lazy_static/
134///
135/// ```rust
136/// use std::cell::RefCell;
137/// use std::thread;
138/// use once_cell::sync::Lazy;
139/// use os_thread_local::ThreadLocal;
140///
141/// static FOO: Lazy<ThreadLocal<RefCell<u32>>> =
142///     Lazy::new(|| ThreadLocal::new(|| RefCell::new(1)));
143///
144/// FOO.with(|f| {
145///     assert_eq!(*f.borrow(), 1);
146///     *f.borrow_mut() = 2;
147/// });
148///
149/// // each thread starts out with the initial value of 1
150/// let t = thread::spawn(move || {
151///     FOO.with(|f| {
152///         assert_eq!(*f.borrow(), 1);
153///         *f.borrow_mut() = 3;
154///     });
155/// });
156///
157/// // wait for the thread to complete and bail out on panic
158/// t.join().unwrap();
159///
160/// // we retain our original value of 2 despite the child thread
161/// FOO.with(|f| {
162///     assert_eq!(*f.borrow(), 2);
163/// });
164/// ```
165///
166/// A variation of the same with scoped threads and per-object thread-local
167/// storage:
168///
169/// ```rust
170/// use std::cell::RefCell;
171/// use crossbeam_utils::thread::scope;
172/// use os_thread_local::ThreadLocal;
173///
174/// struct Foo {
175///     data: u32,
176///     tls: ThreadLocal<RefCell<u32>>,
177/// }
178///
179/// let foo = Foo {
180///     data: 0,
181///     tls: ThreadLocal::new(|| RefCell::new(1)),
182/// };
183///
184/// foo.tls.with(|f| {
185///     assert_eq!(*f.borrow(), 1);
186///     *f.borrow_mut() = 2;
187/// });
188///
189/// scope(|s| {
190///     // each thread starts out with the initial value of 1
191///     let foo2 = &foo;
192///     let t = s.spawn(move |_| {
193///         foo2.tls.with(|f| {
194///             assert_eq!(*f.borrow(), 1);
195///             *f.borrow_mut() = 3;
196///         });
197///     });
198///
199///     // wait for the thread to complete and bail out on panic
200///     t.join().unwrap();
201///
202///     // we retain our original value of 2 despite the child thread
203///     foo.tls.with(|f| {
204///         assert_eq!(*f.borrow(), 2);
205///     });
206/// }).unwrap();
207/// ```
208///
209/// # Platform-specific behavior and caveats
210///
211/// Note that a "best effort" is made to ensure that destructors for types
212/// stored in thread-local storage are run, but it is not guaranteed that
213/// destructors will be run for all types in thread-local storage.
214///
215/// - Destructors may not run on the main thread when it exits.
216/// - Destructors will not run if the corresponding `ThreadLocal` is dropped
217///   in a child thread (this can happen e.g. if the object or binding holding
218///   it is moved into a child thread ; or when the `ThreadLocal` is created
219///   in a child thread).
220/// - Destructors may not run if a `ThreadLocal` is initialized during the `Drop`
221///   impl of a type held by another `ThreadLocal`.
222/// - The order in which destructors may run when using multiple `ThreadLocal`
223///   is not guaranteed.
224///
225/// On Windows, `ThreadLocal` provides per-thread storage as long as fibers
226/// are unused. When fibers are used, it provides per-fiber storage, which
227/// is similar but more fine-grained.
228pub struct ThreadLocal<T> {
229    key: oskey::Key,
230    init: fn() -> T,
231}
232
233impl<T: Default> Default for ThreadLocal<T> {
234    fn default() -> Self {
235        ThreadLocal::new(Default::default)
236    }
237}
238
239impl<T> fmt::Debug for ThreadLocal<T> {
240    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
241        f.pad("ThreadLocal {{ .. }}")
242    }
243}
244
/// The error returned by [`ThreadLocal::try_with`](struct.ThreadLocal.html#method.try_with)
/// when the calling thread's value has already been destroyed.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct AccessError {
    // Keeps the type constructible only from inside this crate.
    _private: (),
}

impl fmt::Debug for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Equivalent to an empty `debug_struct`: just the type name.
        f.write_str("AccessError")
    }
}

impl fmt::Display for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad`, like `str`'s own `Display`, honors width/fill/alignment.
        f.pad("already destroyed")
    }
}

impl Error for AccessError {}
264
265/// A wrapper holding values stored in TLS. We store a `Box<ThreadLocalValue<T>>`,
266/// turned into raw pointers.
267struct ThreadLocalValue<T> {
268    inner: T,
269    key: oskey::Key,
270}
271
272const GUARD: NonNull<c_void> = NonNull::dangling();
273
274unsafe extern "system" fn thread_local_drop<T>(ptr: *mut c_void) {
275    let ptr = NonNull::new_unchecked(ptr as *mut ThreadLocalValue<T>);
276    if ptr != GUARD.cast() {
277        let value = Box::from_raw(ptr.as_ptr());
278        oskey::set(value.key, GUARD.as_ptr());
279        // value is dropped here, and the `Box` destroyed.
280    }
281}
282
283impl<T> ThreadLocal<T> {
284    /// Creates a new thread-local storage handle.
285    ///
286    /// The provided function is used to initialize the value on the first use in
287    /// each thread.
288    ///
289    /// ```rust
290    /// use os_thread_local::ThreadLocal;
291    ///
292    /// let tls = ThreadLocal::new(|| 42);
293    /// ```
294    pub fn new(f: fn() -> T) -> Self {
295        ThreadLocal {
296            key: unsafe { oskey::create(Some(thread_local_drop::<T>)) },
297            init: f,
298        }
299    }
300
301    /// Acquires a reference to the value in this thread-local storage.
302    ///
303    /// This will lazily initialize the value if this thread has not accessed it
304    /// yet.
305    ///
306    /// ```rust
307    /// use os_thread_local::ThreadLocal;
308    /// use std::cell::Cell;
309    ///
310    /// let tls = ThreadLocal::new(|| Cell::new(42));
311    /// tls.with(|v| v.set(21));
312    /// ```
313    ///
314    /// # Panics
315    ///
316    /// This function will `panic!()` if the handle currently has its destructor
317    /// running, and it **may** panic if the destructor has previously been run for
318    /// this thread.
319    /// This function can also `panic!()` if the storage is uninitialized and there
320    /// is not enough available memory to allocate a new thread local storage for
321    /// the current thread, or if the OS primitives fail.
322    pub fn with<R, F: FnOnce(&T) -> R>(&self, f: F) -> R {
323        self.try_with(f)
324            .expect("cannot access a TLS value during or after it is destroyed")
325    }
326
327    /// Acquires a reference to the value in this thread-local storage.
328    ///
329    /// This will lazily initialize the value if this thread has not accessed it
330    /// yet. If the storage has been destroyed, this function will return an
331    /// `AccessError`.
332    ///
333    /// ```rust
334    /// use os_thread_local::ThreadLocal;
335    /// use std::cell::Cell;
336    ///
337    /// let tls = ThreadLocal::new(|| Cell::new(42));
338    /// tls.try_with(|v| v.set(21)).expect("storage destroyed");
339    /// ```
340    ///
341    /// # Panics
342    ///
343    /// This function will `panic!()` if the storage is uninitialized and the
344    /// initializer given to [`ThreadLocal::new`](#method.new) panics.
345    /// This function can also `panic!()` if the storage is uninitialized and there
346    /// is not enough available memory to allocate a new thread local storage for
347    /// the current thread, or if the OS primitives fail.
348    pub fn try_with<R, F: FnOnce(&T) -> R>(&self, f: F) -> Result<R, AccessError> {
349        let ptr = unsafe { oskey::get(self.key) as *mut ThreadLocalValue<T> };
350        let value = NonNull::new(ptr).unwrap_or_else(|| unsafe {
351            // Equivalent to currently unstable Box::into_raw_non_null.
352            // https://github.com/rust-lang/rust/issues/47336#issuecomment-373941458
353            let result = NonNull::new_unchecked(Box::into_raw(Box::new(ThreadLocalValue {
354                inner: (self.init)(),
355                key: self.key,
356            })));
357            oskey::set(self.key, result.as_ptr() as *mut _);
358            result
359        });
360        // Avoid reinitializing a TLS that was destroyed.
361        if value != GUARD.cast() {
362            Ok(f(&unsafe { value.as_ref() }.inner))
363        } else {
364            Err(AccessError { _private: () })
365        }
366    }
367}
368
369impl<T> Drop for ThreadLocal<T> {
370    fn drop(&mut self) {
371        unsafe {
372            oskey::destroy(self.key);
373        }
374    }
375}
376
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for the OS key wrappers and for `ThreadLocal`'s
    // initialization/destruction semantics. Most tests spawn threads and rely
    // on thread-exit destructors running.
    use super::ThreadLocal;
    use core::cell::{Cell, UnsafeCell};
    use crossbeam_utils::thread::scope;
    use once_cell::sync::Lazy;
    use std::sync::mpsc::{channel, Sender};
    use std::sync::RwLock;
    use std::thread;

    // Some tests use multiple ThreadLocal handles and rely on them having a
    // deterministic-ish order, which is not guaranteed when multiple tests
    // run in parallel. So we make all tests using a single handle hold a
    // read lock, and all tests using multiple handles hold a write lock,
    // avoiding race conditions between tests.
    pub static LOCK: Lazy<RwLock<()>> = Lazy::new(|| RwLock::new(()));

    // Checks the raw `oskey` behaviors the rest of the crate depends on:
    // fresh keys read as NULL, deleted keys get reused, and destructors only
    // run for non-NULL values.
    #[test]
    fn assumptions() {
        use super::oskey;
        use core::ptr::{self, NonNull};
        use core::sync::atomic::{AtomicBool, Ordering};

        let _l = LOCK.write().unwrap();
        // Set by `call` whenever the OS invokes the key destructor.
        static CALLED: AtomicBool = AtomicBool::new(false);
        unsafe extern "system" fn call(_: *mut oskey::c_void) {
            CALLED.store(true, Ordering::Release);
        }
        unsafe {
            // Test our assumptions wrt the OS TLS implementation.
            let key = oskey::create(None);
            // A newly created handle returns a NULL value.
            assert_eq!(oskey::get(key), ptr::null_mut());
            oskey::set(key, NonNull::dangling().as_ptr());
            assert_eq!(oskey::get(key), NonNull::dangling().as_ptr());
            oskey::destroy(key);
            let key2 = oskey::create(None);
            // Destroying a handle and creating a new one right after gives
            // the same handle.
            assert_eq!(key, key2);
            // A re-created handle with the same number still returns a NULL
            // value.
            assert_eq!(oskey::get(key), ptr::null_mut());
            oskey::destroy(key2);

            // NOTE(review): this key is never destroyed; it is leaked for the
            // rest of the test process.
            let key = oskey::create(Some(call));
            scope(|s| {
                s.spawn(|_| {
                    oskey::get(key);
                })
                .join()
                .unwrap();
                // The destructor of a handle that hasn't been set is not called.
                assert_eq!(CALLED.load(Ordering::Acquire), false);
                s.spawn(|_| {
                    oskey::set(key, NonNull::dangling().as_ptr());
                })
                .join()
                .unwrap();
                // The destructor is called if the handle has been set.
                assert_eq!(CALLED.load(Ordering::Acquire), true);
                CALLED.store(false, Ordering::Release);
                s.spawn(|_| {
                    oskey::set(key, NonNull::dangling().as_ptr());
                    oskey::set(key, ptr::null_mut());
                })
                .join()
                .unwrap();
                // The destructor of a handle explicitly (re)set to NULL is not called.
                assert_eq!(CALLED.load(Ordering::Acquire), false);
            })
            .unwrap();
        }
    }

    // The tests below were adapted from the tests in libstd/thread/local.rs
    // in the rust source.

    // Signals on its channel when dropped, so a test can observe that a
    // thread-local destructor actually ran.
    struct Foo(Sender<()>);

    impl Drop for Foo {
        fn drop(&mut self) {
            let Foo(ref s) = *self;
            s.send(()).unwrap();
        }
    }

    // A value stored from a child thread is dropped when that thread exits;
    // `rx.recv()` only returns once `Foo`'s destructor has sent.
    #[test]
    fn smoke_dtor() {
        let _l = LOCK.read().unwrap();
        let foo = ThreadLocal::new(|| UnsafeCell::new(None));
        scope(|s| {
            let foo = &foo;
            let (tx, rx) = channel();
            let _t = s.spawn(move |_| unsafe {
                let mut tx = Some(tx);
                foo.with(|f| {
                    *f.get() = Some(Foo(tx.take().unwrap()));
                });
            });
            rx.recv().unwrap();
        })
        .unwrap();
    }

    // Each thread sees its own independently initialized value; a child
    // thread's view does not affect the parent's.
    #[test]
    fn smoke_no_dtor() {
        let _l = LOCK.read().unwrap();
        let foo = ThreadLocal::new(|| Cell::new(1));
        scope(|s| {
            let foo = &foo;
            foo.with(|f| {
                assert_eq!(f.get(), 1);
                f.set(2);
            });
            let (tx, rx) = channel();
            let _t = s.spawn(move |_| {
                foo.with(|f| {
                    assert_eq!(f.get(), 1);
                });
                tx.send(()).unwrap();
            });
            rx.recv().unwrap();
            foo.with(|f| {
                assert_eq!(f.get(), 2);
            });
        })
        .unwrap();
    }

    // While a value is being destroyed, accessing its own handle reports an
    // error instead of reinitializing.
    #[test]
    fn states() {
        let _l = LOCK.read().unwrap();
        struct Foo;
        impl Drop for Foo {
            fn drop(&mut self) {
                assert!(FOO.try_with(|_| ()).is_err());
            }
        }
        static FOO: Lazy<ThreadLocal<Foo>> = Lazy::new(|| ThreadLocal::new(|| Foo));

        thread::spawn(|| {
            assert!(FOO.try_with(|_| ()).is_ok());
        })
        .join()
        .ok()
        .expect("thread panicked");
    }

    // Two handles whose destructors store values into each other.
    // Expected sequence: the explicit `drop(S1)` (HITS == 1) stores an S2 in
    // K2. K2's destructor drops S2 (HITS == 2), which stores an S1 in K1.
    // K1's destructor then drops that S1 (HITS == 3). At that point K2 may or
    // may not still be marked destroyed, depending on destructor ordering, so
    // both branches accept HITS == 3.
    #[test]
    fn circular() {
        let _l = LOCK.read().unwrap();
        struct S1;
        struct S2;
        static K1: Lazy<ThreadLocal<UnsafeCell<Option<S1>>>> =
            Lazy::new(|| ThreadLocal::new(|| UnsafeCell::new(None)));
        static K2: Lazy<ThreadLocal<UnsafeCell<Option<S2>>>> =
            Lazy::new(|| ThreadLocal::new(|| UnsafeCell::new(None)));
        static mut HITS: u32 = 0;

        impl Drop for S1 {
            fn drop(&mut self) {
                unsafe {
                    HITS += 1;
                    if K2.try_with(|_| ()).is_err() {
                        assert_eq!(HITS, 3);
                    } else {
                        if HITS == 1 {
                            K2.with(|s| *s.get() = Some(S2));
                        } else {
                            assert_eq!(HITS, 3);
                        }
                    }
                }
            }
        }
        impl Drop for S2 {
            fn drop(&mut self) {
                unsafe {
                    HITS += 1;
                    assert!(K1.try_with(|_| ()).is_ok());
                    assert_eq!(HITS, 2);
                    K1.with(|s| *s.get() = Some(S1));
                }
            }
        }

        thread::spawn(move || {
            drop(S1);
        })
        .join()
        .ok()
        .expect("thread panicked");
    }

    // A value whose destructor touches its own handle sees it as destroyed.
    #[test]
    fn self_referential() {
        let _l = LOCK.read().unwrap();
        struct S1;
        static K1: Lazy<ThreadLocal<UnsafeCell<Option<S1>>>> =
            Lazy::new(|| ThreadLocal::new(|| UnsafeCell::new(None)));

        impl Drop for S1 {
            fn drop(&mut self) {
                assert!(K1.try_with(|_| ()).is_err());
            }
        }

        thread::spawn(move || unsafe {
            K1.with(|s| *s.get() = Some(S1));
        })
        .join()
        .ok()
        .expect("thread panicked");
    }

    // A destructor (S1's, run for K.0) initializes another handle (K.1). The
    // value it stores must itself be destroyed later, which `rx.recv()`
    // observes via `Foo`'s send. Holds the write lock because it uses two
    // handles.
    #[test]
    fn dtors_in_dtors_in_dtors() {
        let _l = LOCK.write().unwrap();
        struct S1(Sender<()>);
        static K: Lazy<(
            ThreadLocal<UnsafeCell<Option<S1>>>,
            ThreadLocal<UnsafeCell<Option<Foo>>>,
        )> = Lazy::new(|| {
            (
                ThreadLocal::new(|| UnsafeCell::new(None)),
                ThreadLocal::new(|| UnsafeCell::new(None)),
            )
        });

        impl Drop for S1 {
            fn drop(&mut self) {
                let S1(ref tx) = *self;
                unsafe {
                    let _ = K.1.try_with(|s| *s.get() = Some(Foo(tx.clone())));
                }
            }
        }

        let (tx, rx) = channel();
        let _t = thread::spawn(move || unsafe {
            let mut tx = Some(tx);
            K.0.with(|s| *s.get() = Some(S1(tx.take().unwrap())));
        });
        rx.recv().unwrap();
    }
}
623
#[cfg(test)]
mod dynamic_tests {
    // Single-thread checks that values produced by various initializers are
    // visible through `with`.
    use super::tests::LOCK;
    use super::ThreadLocal;
    use core::cell::RefCell;
    use std::collections::HashMap;
    use std::vec;

    /// The initializer may call an ordinary helper function.
    #[test]
    fn smoke() {
        let _guard = LOCK.read().unwrap();
        fn square(x: i32) -> i32 {
            x * x
        }
        let tls = ThreadLocal::new(|| square(3));

        tls.with(|value| assert_eq!(*value, 9));
    }

    /// A heap-allocating value (a map) is built once and readable.
    #[test]
    fn hashmap() {
        let _guard = LOCK.read().unwrap();
        fn map() -> RefCell<HashMap<i32, i32>> {
            RefCell::new([(1, 2)].iter().cloned().collect())
        }
        let tls = ThreadLocal::new(map);

        tls.with(|m| assert_eq!(m.borrow()[&1], 2));
    }

    /// The value can be mutated in place through interior mutability.
    #[test]
    fn refcell_vec() {
        let _guard = LOCK.read().unwrap();
        let tls = ThreadLocal::new(|| RefCell::new(vec![1, 2, 3]));

        tls.with(|cell| {
            assert_eq!(cell.borrow().len(), 3);
            cell.borrow_mut().push(4);
            assert_eq!(cell.borrow()[3], 4);
        });
    }
}