per_thread_object/
lib.rs

1//! Efficient per-object thread-local storage implementation.
2//!
3//! ```rust
4//! # if cfg!(feature = "loom") || cfg!(feature = "shuttle") { return }
5//! use std::thread;
6//! use std::sync::Arc;
7//! use std::cell::RefCell;
8//! use per_thread_object::ThreadLocal;
9//!
10//! fn default() -> RefCell<u32> {
11//!     RefCell::new(0x0)
12//! }
13//!
14//! let tl: Arc<ThreadLocal<RefCell<u32>>> = Arc::new(ThreadLocal::new());
15//! let tl2 = tl.clone();
16//!
17//! thread::spawn(move || {
18//!     per_thread_object::stack_token!(token);
19//!
20//!     *tl2.get_or_init(token, default).borrow_mut() += 1;
21//!     let val = *tl2.get(token).unwrap().borrow();
22//!     assert_eq!(0x1, val);
23//! })
24//!     .join()
25//!     .unwrap();
26//!
27//! per_thread_object::stack_token!(token);
28//!
29//! *tl.get_or_init(token, default).borrow_mut() += 2;
30//! assert_eq!(0x2, *tl.get_or_init(token, default).borrow());
31//! ```
32
33#[cfg(not(feature = "loom"))]
34mod loom;
35
36#[cfg(feature = "loom")]
37use loom;
38
39mod util;
40mod thread;
41mod page;
42
43use std::ptr::NonNull;
44use loom::cell::UnsafeCell;
45use page::Storage;
46
47
48/// Per-object thread-local storage
49///
50/// ## Capacity
51///
52/// `per-thread-object` has no max capacity limit,
53/// each `ThreadLocal` instance will create its own memory space
54/// instead of using global space.
55///
56/// this crate supports any number of threads,
57/// but only the specified number of threads are lock-free.
58///
59/// ## Panic when dropping
60///
61/// `ThreadLocal` will release object at the end of thread.
62/// If panic occurs during this process, it may cause a memory leak.
63pub struct ThreadLocal<T: Send + 'static> {
64    pool: Storage<T>
65}
66
/// Zero-sized token that ties references handed out by [`ThreadLocal`] to the
/// current stack frame.
///
/// Create one with [`stack_token!`], which binds a `&StackToken` for the rest
/// of the enclosing block. `ThreadLocal::get` and the `get_or_*` methods give
/// the returned reference the same lifetime as this borrow. A reference to a
/// thread's value therefore cannot outlive the block in which the token was
/// created.
///
/// The `*const ()` marker makes the token neither `Send` nor `Sync`, so it
/// cannot be moved to or shared with another thread.
pub struct StackToken {
    _marker: std::marker::PhantomData<*const ()>,
}

impl StackToken {
    /// Implementation detail of [`stack_token!`]; do not call directly.
    ///
    /// # Safety
    ///
    /// The token must be created and immediately borrowed in the caller's
    /// stack frame, as `stack_token!` does. A hand-made token (for example
    /// one leaked to obtain a `&'static StackToken`) could let references
    /// returned by `ThreadLocal` outlive the per-thread value they point to.
    #[doc(hidden)]
    pub unsafe fn __private_new() -> StackToken {
        StackToken {
            _marker: std::marker::PhantomData,
        }
    }
}

/// Declares `$name` as a `&StackToken` borrowed for the rest of the
/// enclosing block.
///
/// Pass the token to `ThreadLocal::get` / `get_or_init` / `get_or_try_init`.
#[macro_export]
macro_rules! stack_token {
    ($name:ident) => {
        // The token is borrowed right here, in the caller's frame, which is
        // the contract `__private_new` requires. The allow is presumably
        // there so callers that `deny(unsafe_code)` can still use the macro.
        #[allow(unsafe_code)]
        let $name = &unsafe { $crate::StackToken::__private_new() };
    };
}
87
88impl<T: Send + 'static> ThreadLocal<T> {
89    pub fn new() -> ThreadLocal<T> {
90        #[cfg(not(feature = "loom"))]
91        #[cfg(not(feature = "shuttle"))]
92        let default = 16;
93
94        #[cfg(any(feature = "loom", feature = "shuttle"))]
95        let default = 3;
96
97        ThreadLocal::with_threads(default)
98    }
99
100    pub fn with_threads(num: usize) -> ThreadLocal<T> {
101        ThreadLocal {
102            pool: Storage::with_threads(num)
103        }
104    }
105
106    #[inline]
107    pub fn get<'stack>(&'stack self, _token: &'stack StackToken) -> Option<&'stack T> {
108        unsafe {
109            self.pool.get(thread::get())
110        }
111    }
112
113    #[inline]
114    pub fn get_or_init<'stack, F>(&'stack self, token: &'stack StackToken, init: F)
115        -> &'stack T
116    where
117        F: FnOnce() -> T
118    {
119        use std::convert::Infallible;
120
121        match self.get_or_try_init::<_, Infallible>(token, || Ok(init())) {
122            Ok(val) => val,
123            Err(err) => match err {}
124        }
125    }
126
127    #[inline]
128    pub fn get_or_try_init<'stack, F, E>(&'stack self, _token: &'stack StackToken, init: F)
129        -> Result<&'stack T, E>
130    where
131        F: FnOnce() -> Result<T, E>
132    {
133        let id = thread::get();
134        let ptr = unsafe { self.pool.get_or_new(id) };
135
136        let obj = unsafe { &*ptr.as_ptr() };
137        let val = if let Some(val) = obj.with(|val| unsafe { &*val }) {
138            val
139        } else {
140            let newval = init()?;
141            let val = obj.with_mut(|val| unsafe { &mut *val }.get_or_insert(newval));
142
143            ThreadLocal::or_try(&self.pool, id, ptr);
144
145            val
146        };
147
148        Ok(val)
149    }
150
151    #[cold]
152    fn or_try(pool: &Storage<T>, id: usize, ptr: NonNull<UnsafeCell<Option<T>>>) {
153        let thread_handle = unsafe {
154            thread::push(pool.as_threads_ref(), ptr)
155        };
156
157        pool.insert_thread_handle(id, thread_handle);
158    }
159}
160
161impl<T: Send + 'static> Default for ThreadLocal<T> {
162    #[inline]
163    fn default() -> ThreadLocal<T> {
164        ThreadLocal::new()
165    }
166}
167
168unsafe impl<T: Send> Send for ThreadLocal<T> {}
169unsafe impl<T: Send> Sync for ThreadLocal<T> {}