hipthread/
local.rs

1use super::{pthread, LazyLock};
2use core::alloc::Layout;
3use core::cell::{Cell, RefCell};
4use core::marker::PhantomData;
5use core::ops::Deref;
6use hierr::{Error, Result};
7
8/// 封装pthread_key_t, 提供线程局部存储机制.
9pub struct ThrdLocal<T: Sized> {
10    key: pthread::pthread_key_t,
11    mark: PhantomData<*const T>,
12}
13pub type ThrdLocalDtor = unsafe extern "C" fn(_: *const core::ffi::c_void);
14
15unsafe impl<T> Send for ThrdLocal<T> {}
16unsafe impl<T> Sync for ThrdLocal<T> {}
17
18impl<T> Drop for ThrdLocal<T> {
19    fn drop(&mut self) {
20        unsafe { pthread::pthread_key_delete(self.key) };
21    }
22}
23
24impl<T> ThrdLocal<T> {
25    pub fn new() -> Result<Self> {
26        Self::new_with(None)
27    }
28    pub fn new_with(dtor: Option<ThrdLocalDtor>) -> Result<Self> {
29        let mut key: pthread::pthread_key_t = 0;
30        let ret = unsafe { pthread::pthread_key_create(&mut key, dtor) };
31        if ret == 0 {
32            Ok(Self {
33                key,
34                mark: PhantomData,
35            })
36        } else {
37            Err(Error::new(ret))
38        }
39    }
40
41    pub fn set(&self, val: *const T) {
42        unsafe {
43            pthread::pthread_setspecific(self.key, val as *const _ as *const pthread::c_void);
44        }
45    }
46
47    pub fn get(&self) -> *const T {
48        unsafe {
49            pthread::pthread_getspecific(self.key)
50                .cast::<T>()
51                .cast_mut()
52        }
53    }
54}
55
56/// 封装ThrdLocal方便使用 
57pub struct LocalKey<T> {
58    key: LazyLock<ThrdLocal<T>>,
59}
60
61unsafe impl<T> Send for LocalKey<T> {}
62unsafe impl<T> Sync for LocalKey<T> {}
63
64impl<T> LocalKey<T> {
65    pub const fn new() -> Self {
66        Self {
67            key: LazyLock::new(|| ThrdLocal::new().unwrap()),
68        }
69    }
70
71    /// 获取当前值
72    pub fn get(&self) -> *const T {
73        self.key.get()
74    }
75
76    /// 设置为新值
77    pub fn set(&self, val: *const T) {
78        self.key.set(val);
79    }
80
81    /// 替换为新值,返回原来的值.
82    pub fn replace(&self, val: *const T) -> *const T {
83        let old = self.key.get();
84        self.key.set(val);
85        old
86    }
87}
88
/// Allocation and release hooks for dynamically allocated TLS variables.
///
/// # Safety
/// Implementations must hand out memory that is valid for the requested
/// layout, and `dealloc` must accept every pointer previously returned by
/// `alloc` of the same implementation.
pub unsafe trait ThrdLocalAlloc {
    /// Allocates storage for a TLS variable described by `layout`.
    ///
    /// # Safety
    /// The contract of this interface is that a null return leads to a panic,
    /// so implementations should return a non-null pointer suitable for
    /// `layout`.
    unsafe fn alloc(layout: Layout) -> *mut u8;

    /// Releases storage previously obtained from [`ThrdLocalAlloc::alloc`].
    ///
    /// # Safety
    /// `p` must come from `alloc` with the same `layout`. Whether this is
    /// invoked when a thread exits depends on the platform implementation.
    unsafe fn dealloc(p: *mut u8, layout: Layout);
}
99
100/// 缺省利用aligned_alloc/aligned_free来分配和释放TLS变量的内存空间.
101pub struct NativeAlloc;
102
103unsafe impl ThrdLocalAlloc for NativeAlloc {
104    unsafe fn alloc(layout: Layout) -> *mut u8 {
105        crate::stdlib::aligned_alloc(layout.align(), layout.size())
106    }
107    unsafe fn dealloc(p: *mut u8, _: Layout) {
108        crate::stdlib::aligned_free(p);
109    }
110}
111
112/// 提供RefCell<T>类型的TLS变量操作接口.
113pub struct LocalRefCell<T, A = NativeAlloc, F = fn() -> T> {
114    local: LazyLock<ThrdLocal<RefCell<T>>>,
115    f: F,
116    _alloc: A,
117}
118
119unsafe impl<T, A: Send, F: Send> Send for LocalRefCell<T, A, F> {}
120unsafe impl<T, A: Sync, F: Sync> Sync for LocalRefCell<T, A, F> {}
121
122impl<T, F: Fn() -> T> LocalRefCell<T, NativeAlloc, F> {
123    pub const fn new(f: F) -> Self {
124        Self::new_with(NativeAlloc, f)
125    }
126}
127
128impl<T, A: ThrdLocalAlloc, F: Fn() -> T> LocalRefCell<T, A, F> {
129    unsafe extern "C" fn dealloc(val: *const core::ffi::c_void) {
130        let p = val.cast_mut().cast::<RefCell<T>>();
131        unsafe { core::ptr::drop_in_place(p) };
132        A::dealloc(p.cast::<u8>(), Layout::new::<RefCell<T>>());
133    }
134
135    pub const fn new_with(_alloc: A, f: F) -> Self {
136        Self {
137            local: LazyLock::new(|| ThrdLocal::new_with(Some(Self::dealloc)).unwrap()),
138            f,
139            _alloc,
140        }
141    }
142
143    fn get_refcell(&self) -> &RefCell<T> {
144        let mut p = self.local.get().cast_mut();
145        if p.is_null() {
146            p = unsafe { A::alloc(Layout::new::<RefCell<T>>()).cast::<RefCell<T>>() };
147            unsafe { p.write(RefCell::new((self.f)())) };
148            self.local.set(p);
149        }
150        unsafe { &*p }
151    }
152}
153
154impl<T, A: ThrdLocalAlloc, F: Fn() -> T> Deref for LocalRefCell<T, A, F> {
155    type Target = RefCell<T>;
156    fn deref(&self) -> &Self::Target {
157        self.get_refcell()
158    }
159}
160
161/// 提供Cell<T>类型的TLS变量操作接口.
162pub struct LocalCell<T, A = NativeAlloc, F = fn() -> T> {
163    local: LazyLock<ThrdLocal<Cell<T>>>,
164    f: F,
165    _alloc: A,
166}
167
168unsafe impl<T, A: Send, F: Send> Send for LocalCell<T, A, F> {}
169unsafe impl<T, A: Sync, F: Sync> Sync for LocalCell<T, A, F> {}
170
171impl<T, F: Fn() -> T> LocalCell<T, NativeAlloc, F> {
172    pub const fn new(f: F) -> Self {
173        Self::new_with(NativeAlloc, f)
174    }
175}
176
177impl<T, A: ThrdLocalAlloc, F: Fn() -> T> LocalCell<T, A, F> {
178    unsafe extern "C" fn dealloc(val: *const core::ffi::c_void) {
179        let p = val.cast_mut().cast::<Cell<T>>();
180        unsafe { core::ptr::drop_in_place(p) };
181        A::dealloc(p.cast::<u8>(), Layout::new::<Cell<T>>());
182    }
183
184    pub const fn new_with(_alloc: A, f: F) -> Self {
185        Self {
186            local: LazyLock::new(|| ThrdLocal::new_with(Some(Self::dealloc)).unwrap()),
187            f,
188            _alloc,
189        }
190    }
191
192    fn get_cell(&self) -> &Cell<T> {
193        let mut p = self.local.get().cast_mut();
194        if p.is_null() {
195            p = unsafe { A::alloc(Layout::new::<Cell<T>>()).cast::<Cell<T>>() };
196            unsafe { p.write(Cell::new((self.f)())) };
197            self.local.set(p);
198        }
199        unsafe { &*p }
200    }
201}
202
203impl<T, A: ThrdLocalAlloc, F: Fn() -> T> Deref for LocalCell<T, A, F> {
204    type Target = Cell<T>;
205    fn deref(&self) -> &Self::Target {
206        self.get_cell()
207    }
208}
209
#[cfg(test)]
mod test {
    use crate::*;

    /// Each thread sees its own pointer for a `LocalKey`; a freshly spawned
    /// thread starts with a null slot.
    #[test]
    fn test_local() {
        // Named statics have one stable address each. Comparing the addresses
        // of two separately written `&100` literals would depend on the
        // compiler merging the promoted constants, which is not guaranteed.
        static MAIN_VAL: i32 = 100;
        static CHILD_VAL: i32 = 101;
        static KEY: LocalKey<i32> = LocalKey::new();

        KEY.set(&MAIN_VAL);
        let h = spawn(|| {
            let addr = KEY.replace(&CHILD_VAL);
            assert!(addr.is_null());
            let addr = KEY.get();
            assert!(core::ptr::eq(addr, &CHILD_VAL));
            101
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);

        // The spawned thread's store must not have affected this thread.
        let addr = KEY.get();
        assert!(!addr.is_null());
        assert!(core::ptr::eq(addr, &MAIN_VAL));
    }

    /// Every thread gets an independent `RefCell` produced by the initializer.
    #[test]
    fn test_refcell() {
        static KEY: LocalRefCell<i32> = LocalRefCell::new(|| 100);
        assert_eq!(*KEY.borrow(), 100);
        KEY.replace(200);
        let h = spawn(|| {
            KEY.replace(101);
            *KEY.borrow()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);
        let h = spawn(|| {
            KEY.replace(102);
            *KEY.borrow()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 102);
        // Writes made by the other threads are invisible here.
        assert_eq!(*KEY.borrow(), 200);
    }

    /// Every thread gets an independent `Cell` produced by the initializer.
    #[test]
    fn test_cell() {
        static KEY: LocalCell<i32> = LocalCell::new(|| 100);
        assert_eq!(KEY.get(), 100);
        KEY.replace(200);
        let h = spawn(|| {
            KEY.replace(101);
            KEY.get()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);
        let h = spawn(|| {
            KEY.replace(102);
            KEY.get()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 102);
        // Writes made by the other threads are invisible here.
        assert_eq!(KEY.get(), 200);
    }
}