per_thread_object/
lib.rs

//! Efficient per-object thread-local storage implementation.
//!
//! Every [`ThreadLocal`] instance keeps a separate value for each thread that
//! uses it, and the references it hands out are tied to a stack token created
//! with [`stack_token!`], so they cannot outlive the caller's stack frame.
//!
//! ```rust
//! use std::thread;
//! use std::sync::Arc;
//! use std::cell::RefCell;
//! use per_thread_object::ThreadLocal;
//!
//! fn default() -> RefCell<u32> {
//!     RefCell::new(0x0)
//! }
//!
//! let tl: Arc<ThreadLocal<RefCell<u32>>> = Arc::new(ThreadLocal::new());
//! let tl2 = tl.clone();
//!
//! thread::spawn(move || {
//!     per_thread_object::stack_token!(token);
//!
//!     *tl2.get_or_init(token, default).borrow_mut() += 1;
//!     let val = *tl2.get(token).unwrap().borrow();
//!     assert_eq!(0x1, val);
//! })
//!     .join()
//!     .unwrap();
//!
//! per_thread_object::stack_token!(token);
//!
//! *tl.get_or_init(token, default).borrow_mut() += 2;
//! assert_eq!(0x2, *tl.get_or_init(token, default).borrow());
//! ```

32#[cfg(not(feature = "loom"))]
33mod loom;
34
35#[cfg(feature = "loom")]
36use loom;
37
38mod thread;
39mod page;
40
41use std::ptr::NonNull;
42use loom::cell::UnsafeCell;
43use page::Storage;
44
45pub use page::DEFAULT_PAGE_CAP;
46
47
48/// Per-object thread-local storage
49///
50/// ## Capacity
51///
52/// `per-thread-object` has no max capacity limit,
53/// each `ThreadLocal` instance will create its own memory space
54/// instead of using global space.
55///
56/// this crate supports any number of threads,
57/// but only the [DEFAULT_PAGE_CAP] threads are lock-free.
58///
59/// ## Panic when dropping
60///
61/// `ThreadLocal` will release object at the end of thread.
62/// If panic occurs during this process, it may cause a memory leak.
63pub struct ThreadLocal<T: Send + 'static> {
64    pool: Storage<T>
65}
66
/// Zero-sized token that [`stack_token!`] places in the caller's stack frame.
///
/// [`ThreadLocal`] accessors take a `&StackToken` and tie the lifetime of the
/// references they return to it, so those references cannot outlive the frame
/// the token lives in. The raw-pointer marker makes the type neither `Send`
/// nor `Sync`, which keeps a token on the thread that created it.
pub struct StackToken {
    _marker: std::marker::PhantomData<*const ()>,
}

71impl StackToken {
72    #[doc(hidden)]
73    pub unsafe fn __private_new() -> StackToken {
74        StackToken {
75            _marker: std::marker::PhantomData,
76        }
77    }
78}
79
/// Declares `$name` as a `&StackToken` whose token lives until the end of
/// the enclosing block.
///
/// Pass `$name` to the [`ThreadLocal`] accessors; the references they return
/// are then limited to that block.
#[macro_export]
macro_rules! stack_token {
    ($name:ident) => {
        // The token is held in a hygienic local the caller cannot name;
        // only a borrow of it is exposed as `$name`.
        #[allow(unsafe_code)]
        let token_storage = unsafe { $crate::StackToken::__private_new() };
        let $name = &token_storage;
    };
}

88impl<T: Send + 'static> ThreadLocal<T> {
89    pub fn new() -> ThreadLocal<T> {
90        ThreadLocal {
91            pool: Storage::new()
92        }
93    }
94
95    #[inline]
96    pub fn get<'stack>(&'stack self, _token: &'stack StackToken) -> Option<&'stack T> {
97        unsafe {
98            self.pool.get(thread::get())
99        }
100    }
101
102    #[inline]
103    pub fn get_or_init<'stack, F>(&'stack self, token: &'stack StackToken, init: F)
104        -> &'stack T
105    where
106        F: FnOnce() -> T
107    {
108        use std::convert::Infallible;
109
110        match self.get_or_try_init::<_, Infallible>(token, || Ok(init())) {
111            Ok(val) => val,
112            Err(err) => match err {}
113        }
114    }
115
116    #[inline]
117    pub fn get_or_try_init<'stack, F, E>(&'stack self, _token: &'stack StackToken, init: F)
118        -> Result<&'stack T, E>
119    where
120        F: FnOnce() -> Result<T, E>
121    {
122        let id = thread::get();
123        let ptr = unsafe { self.pool.get_or_new(id) };
124
125        let obj = unsafe { &*ptr.as_ptr() };
126        let val = if let Some(val) = obj.with(|val| unsafe { &*val }) {
127            val
128        } else {
129            let val = obj.with_mut(|val| {
130                let val = unsafe { &mut *val }.get_or_insert(init()?);
131                Ok(val)
132            })?;
133
134            ThreadLocal::or_try(&self.pool, id, ptr);
135
136            val
137        };
138
139        Ok(val)
140    }
141
142    #[cold]
143    fn or_try(pool: &Storage<T>, id: usize, ptr: NonNull<UnsafeCell<Option<T>>>) {
144        let thread_handle = unsafe {
145            thread::push(pool.as_threads_ref(), ptr)
146        };
147
148        pool.insert_thread_handle(id, thread_handle);
149    }
150}
151
152impl<T: Send + 'static> Default for ThreadLocal<T> {
153    #[inline]
154    fn default() -> ThreadLocal<T> {
155        ThreadLocal::new()
156    }
157}
158
159unsafe impl<T: Send> Send for ThreadLocal<T> {}
160unsafe impl<T: Send> Sync for ThreadLocal<T> {}