1use super::{pthread, LazyLock};
2use core::alloc::Layout;
3use core::cell::{Cell, RefCell};
4use core::marker::PhantomData;
5use core::ops::Deref;
6use hierr::{Error, Result};
7
8pub struct ThrdLocal<T: Sized> {
10 key: pthread::pthread_key_t,
11 mark: PhantomData<*const T>,
12}
13pub type ThrdLocalDtor = unsafe extern "C" fn(_: *const core::ffi::c_void);
14
15unsafe impl<T> Send for ThrdLocal<T> {}
16unsafe impl<T> Sync for ThrdLocal<T> {}
17
18impl<T> Drop for ThrdLocal<T> {
19 fn drop(&mut self) {
20 unsafe { pthread::pthread_key_delete(self.key) };
21 }
22}
23
24impl<T> ThrdLocal<T> {
25 pub fn new() -> Result<Self> {
26 Self::new_with(None)
27 }
28 pub fn new_with(dtor: Option<ThrdLocalDtor>) -> Result<Self> {
29 let mut key: pthread::pthread_key_t = 0;
30 let ret = unsafe { pthread::pthread_key_create(&mut key, dtor) };
31 if ret == 0 {
32 Ok(Self {
33 key,
34 mark: PhantomData,
35 })
36 } else {
37 Err(Error::new(ret))
38 }
39 }
40
41 pub fn set(&self, val: *const T) {
42 unsafe {
43 pthread::pthread_setspecific(self.key, val as *const _ as *const pthread::c_void);
44 }
45 }
46
47 pub fn get(&self) -> *const T {
48 unsafe {
49 pthread::pthread_getspecific(self.key)
50 .cast::<T>()
51 .cast_mut()
52 }
53 }
54}
55
56pub struct LocalKey<T> {
58 key: LazyLock<ThrdLocal<T>>,
59}
60
61unsafe impl<T> Send for LocalKey<T> {}
62unsafe impl<T> Sync for LocalKey<T> {}
63
64impl<T> LocalKey<T> {
65 pub const fn new() -> Self {
66 Self {
67 key: LazyLock::new(|| ThrdLocal::new().unwrap()),
68 }
69 }
70
71 pub fn get(&self) -> *const T {
73 self.key.get()
74 }
75
76 pub fn set(&self, val: *const T) {
78 self.key.set(val);
79 }
80
81 pub fn replace(&self, val: *const T) -> *const T {
83 let old = self.key.get();
84 self.key.set(val);
85 old
86 }
87}
88
/// Source of raw memory for the per-thread values of `LocalRefCell` and `LocalCell`.
///
/// Both functions take no receiver, so implementors are used purely at the
/// type level (`A::alloc` / `A::dealloc`).
///
/// # Safety
///
/// Implementors must guarantee that:
/// - a non-null pointer returned by [`alloc`](Self::alloc) is valid for reads
///   and writes of `layout.size()` bytes and aligned to `layout.align()`;
/// - `alloc` reports failure by returning null;
/// - [`dealloc`](Self::dealloc) accepts any pointer previously returned by
///   `alloc` together with the same `layout`.
///
/// `dealloc` may be invoked from a pthread key destructor while a thread is exiting.
pub unsafe trait ThrdLocalAlloc {
    /// Allocates a block described by `layout`, returning null on failure.
    ///
    /// # Safety
    ///
    /// The returned memory is uninitialised. The caller must initialise it
    /// before reading and eventually release it with [`dealloc`](Self::dealloc).
    unsafe fn alloc(layout: Layout) -> *mut u8;

    /// Releases a block previously obtained from [`alloc`](Self::alloc).
    ///
    /// # Safety
    ///
    /// `ptr` must have been returned by `alloc` with this same `layout` and
    /// must not have been released already.
    unsafe fn dealloc(ptr: *mut u8, layout: Layout);
}
99
100pub struct NativeAlloc;
102
103unsafe impl ThrdLocalAlloc for NativeAlloc {
104 unsafe fn alloc(layout: Layout) -> *mut u8 {
105 crate::stdlib::aligned_alloc(layout.align(), layout.size())
106 }
107 unsafe fn dealloc(p: *mut u8, _: Layout) {
108 crate::stdlib::aligned_free(p);
109 }
110}
111
112pub struct LocalRefCell<T, A = NativeAlloc, F = fn() -> T> {
114 local: LazyLock<ThrdLocal<RefCell<T>>>,
115 f: F,
116 _alloc: A,
117}
118
119unsafe impl<T, A: Send, F: Send> Send for LocalRefCell<T, A, F> {}
120unsafe impl<T, A: Sync, F: Sync> Sync for LocalRefCell<T, A, F> {}
121
122impl<T, F: Fn() -> T> LocalRefCell<T, NativeAlloc, F> {
123 pub const fn new(f: F) -> Self {
124 Self::new_with(NativeAlloc, f)
125 }
126}
127
128impl<T, A: ThrdLocalAlloc, F: Fn() -> T> LocalRefCell<T, A, F> {
129 unsafe extern "C" fn dealloc(val: *const core::ffi::c_void) {
130 let p = val.cast_mut().cast::<RefCell<T>>();
131 unsafe { core::ptr::drop_in_place(p) };
132 A::dealloc(p.cast::<u8>(), Layout::new::<RefCell<T>>());
133 }
134
135 pub const fn new_with(_alloc: A, f: F) -> Self {
136 Self {
137 local: LazyLock::new(|| ThrdLocal::new_with(Some(Self::dealloc)).unwrap()),
138 f,
139 _alloc,
140 }
141 }
142
143 fn get_refcell(&self) -> &RefCell<T> {
144 let mut p = self.local.get().cast_mut();
145 if p.is_null() {
146 p = unsafe { A::alloc(Layout::new::<RefCell<T>>()).cast::<RefCell<T>>() };
147 unsafe { p.write(RefCell::new((self.f)())) };
148 self.local.set(p);
149 }
150 unsafe { &*p }
151 }
152}
153
154impl<T, A: ThrdLocalAlloc, F: Fn() -> T> Deref for LocalRefCell<T, A, F> {
155 type Target = RefCell<T>;
156 fn deref(&self) -> &Self::Target {
157 self.get_refcell()
158 }
159}
160
161pub struct LocalCell<T, A = NativeAlloc, F = fn() -> T> {
163 local: LazyLock<ThrdLocal<Cell<T>>>,
164 f: F,
165 _alloc: A,
166}
167
168unsafe impl<T, A: Send, F: Send> Send for LocalCell<T, A, F> {}
169unsafe impl<T, A: Sync, F: Sync> Sync for LocalCell<T, A, F> {}
170
171impl<T, F: Fn() -> T> LocalCell<T, NativeAlloc, F> {
172 pub const fn new(f: F) -> Self {
173 Self::new_with(NativeAlloc, f)
174 }
175}
176
177impl<T, A: ThrdLocalAlloc, F: Fn() -> T> LocalCell<T, A, F> {
178 unsafe extern "C" fn dealloc(val: *const core::ffi::c_void) {
179 let p = val.cast_mut().cast::<Cell<T>>();
180 unsafe { core::ptr::drop_in_place(p) };
181 A::dealloc(p.cast::<u8>(), Layout::new::<Cell<T>>());
182 }
183
184 pub const fn new_with(_alloc: A, f: F) -> Self {
185 Self {
186 local: LazyLock::new(|| ThrdLocal::new_with(Some(Self::dealloc)).unwrap()),
187 f,
188 _alloc,
189 }
190 }
191
192 fn get_cell(&self) -> &Cell<T> {
193 let mut p = self.local.get().cast_mut();
194 if p.is_null() {
195 p = unsafe { A::alloc(Layout::new::<Cell<T>>()).cast::<Cell<T>>() };
196 unsafe { p.write(Cell::new((self.f)())) };
197 self.local.set(p);
198 }
199 unsafe { &*p }
200 }
201}
202
203impl<T, A: ThrdLocalAlloc, F: Fn() -> T> Deref for LocalCell<T, A, F> {
204 type Target = Cell<T>;
205 fn deref(&self) -> &Self::Target {
206 self.get_cell()
207 }
208}
209
#[cfg(test)]
mod test {
    use crate::*;

    // Pointer-identity assertions need one stable address per value.
    // Separately promoted `&100` literals are not guaranteed to share an address.
    static V100: i32 = 100;
    static V101: i32 = 101;

    #[test]
    fn test_local() {
        static KEY: LocalKey<i32> = LocalKey::new();
        KEY.set(&V100);
        let h = spawn(|| {
            // A fresh thread starts with an empty slot.
            let addr = KEY.replace(&V101);
            assert!(addr.is_null());
            let addr = KEY.get();
            assert_eq!(addr, &V101 as *const i32);
            101
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);
        // The spawned thread's write must not leak into this thread.
        let addr = KEY.get();
        assert!(!addr.is_null());
        assert_eq!(addr, &V100 as *const i32);
    }

    #[test]
    fn test_refcell() {
        static KEY: LocalRefCell<i32> = LocalRefCell::new(|| 100);
        assert_eq!(*KEY.borrow(), 100);
        KEY.replace(200);
        let h = spawn(|| {
            KEY.replace(101);
            *KEY.borrow()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);
        let h = spawn(|| {
            KEY.replace(102);
            *KEY.borrow()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 102);
        assert_eq!(*KEY.borrow(), 200);
    }

    #[test]
    fn test_cell() {
        static KEY: LocalCell<i32> = LocalCell::new(|| 100);
        assert_eq!(KEY.get(), 100);
        KEY.replace(200);
        let h = spawn(|| {
            KEY.replace(101);
            KEY.get()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);
        let h = spawn(|| {
            KEY.replace(102);
            KEY.get()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 102);
        assert_eq!(KEY.get(), 200);
    }
}