1#[cfg(not(feature = "loom"))]
33mod loom;
34
35#[cfg(feature = "loom")]
36use loom;
37
38mod thread;
39mod page;
40
41use std::ptr::NonNull;
42use loom::cell::UnsafeCell;
43use page::Storage;
44
45pub use page::DEFAULT_PAGE_CAP;
46
47
48pub struct ThreadLocal<T: Send + 'static> {
64 pool: Storage<T>
65}
66
/// A zero-sized token that lives on the caller's stack frame.
///
/// `ThreadLocal` accessors tie the references they return to the lifetime of
/// a `&StackToken`, so those references cannot outlive the scope that
/// created the token. Because the marker is a raw-pointer `PhantomData`, the
/// token is neither `Send` nor `Sync`, so it cannot be handed to another
/// thread.
///
/// Create one with the [`stack_token!`] macro instead of calling
/// [`StackToken::__private_new`] directly.
pub struct StackToken {
    _marker: std::marker::PhantomData<*const ()>,
}

impl StackToken {
    /// Builds a token. This is an implementation detail of [`stack_token!`].
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not spelled out in this file. It
    /// presumably requires the token to be used only by reference from the
    /// stack frame that created it, as `stack_token!` does. Confirm before
    /// calling this directly.
    #[doc(hidden)]
    pub unsafe fn __private_new() -> Self {
        let _marker = std::marker::PhantomData;
        Self { _marker }
    }
}
79
/// Declares `$name` as a `&StackToken` owned by the enclosing block.
///
/// The token value is a local of the surrounding scope, so references
/// obtained from a `ThreadLocal` with it cannot escape that scope.
///
/// ```ignore
/// stack_token!(token);
/// let value = tl.get_or_init(token, || 0);
/// ```
#[macro_export]
macro_rules! stack_token {
    ($name:ident) => {
        #[allow(unsafe_code)]
        let $name = unsafe { $crate::StackToken::__private_new() };
        // Shadow the owned token with a borrow of it; the owned value lives
        // until the end of the caller's block.
        let $name = &$name;
    };
}
87
88impl<T: Send + 'static> ThreadLocal<T> {
89 pub fn new() -> ThreadLocal<T> {
90 ThreadLocal {
91 pool: Storage::new()
92 }
93 }
94
95 #[inline]
96 pub fn get<'stack>(&'stack self, _token: &'stack StackToken) -> Option<&'stack T> {
97 unsafe {
98 self.pool.get(thread::get())
99 }
100 }
101
102 #[inline]
103 pub fn get_or_init<'stack, F>(&'stack self, token: &'stack StackToken, init: F)
104 -> &'stack T
105 where
106 F: FnOnce() -> T
107 {
108 use std::convert::Infallible;
109
110 match self.get_or_try_init::<_, Infallible>(token, || Ok(init())) {
111 Ok(val) => val,
112 Err(err) => match err {}
113 }
114 }
115
116 #[inline]
117 pub fn get_or_try_init<'stack, F, E>(&'stack self, _token: &'stack StackToken, init: F)
118 -> Result<&'stack T, E>
119 where
120 F: FnOnce() -> Result<T, E>
121 {
122 let id = thread::get();
123 let ptr = unsafe { self.pool.get_or_new(id) };
124
125 let obj = unsafe { &*ptr.as_ptr() };
126 let val = if let Some(val) = obj.with(|val| unsafe { &*val }) {
127 val
128 } else {
129 let val = obj.with_mut(|val| {
130 let val = unsafe { &mut *val }.get_or_insert(init()?);
131 Ok(val)
132 })?;
133
134 ThreadLocal::or_try(&self.pool, id, ptr);
135
136 val
137 };
138
139 Ok(val)
140 }
141
142 #[cold]
143 fn or_try(pool: &Storage<T>, id: usize, ptr: NonNull<UnsafeCell<Option<T>>>) {
144 let thread_handle = unsafe {
145 thread::push(pool.as_threads_ref(), ptr)
146 };
147
148 pool.insert_thread_handle(id, thread_handle);
149 }
150}
151
152impl<T: Send + 'static> Default for ThreadLocal<T> {
153 #[inline]
154 fn default() -> ThreadLocal<T> {
155 ThreadLocal::new()
156 }
157}
158
159unsafe impl<T: Send> Send for ThreadLocal<T> {}
160unsafe impl<T: Send> Sync for ThreadLocal<T> {}