use super::{pthread, LazyLock};
use core::alloc::Layout;
use core::cell::{Cell, RefCell};
use core::marker::PhantomData;
use core::ops::Deref;
use hierr::{Error, Result};
/// A handle to one pthread thread-specific-data key whose slot holds a raw
/// `*const T`.
///
/// The key is created by `ThrdLocal::new_with` and deleted again when the
/// handle is dropped. The handle never owns or frees the pointed-to value
/// itself; that is left to the optional destructor passed at creation.
pub struct ThrdLocal<T: Sized> {
/// Key filled in by `pthread_key_create`.
key: pthread::pthread_key_t,
/// Records the pointee type without storing a `T`.
mark: PhantomData<*const T>,
}
/// Signature of the destructor callback that `ThrdLocal::new_with` hands to
/// `pthread_key_create`; it receives one untyped pointer.
///
/// Per POSIX, such a destructor is invoked at thread exit with the exiting
/// thread's non-NULL slot value.
pub type ThrdLocalDtor = unsafe extern "C" fn(_: *const core::ffi::c_void);
// `PhantomData<*const T>` suppresses the automatic `Send`/`Sync` impls, so
// they are asserted by hand here.
// NOTE(review): no bound is placed on `T`; soundness relies on how callers use
// the raw pointers they store and read back — confirm.
unsafe impl<T> Send for ThrdLocal<T> {}
unsafe impl<T> Sync for ThrdLocal<T> {}
/// Deletes the pthread key owned by this handle.
impl<T> Drop for ThrdLocal<T> {
    fn drop(&mut self) {
        let key = self.key;
        // Whatever the call returns is ignored: `drop` has no way to report it.
        // NOTE(review): POSIX key deletion does not run destructors for values
        // still stored by threads — confirm callers release them beforehand.
        unsafe {
            pthread::pthread_key_delete(key);
        }
    }
}
impl<T> ThrdLocal<T> {
pub fn new() -> Result<Self> {
Self::new_with(None)
}
pub fn new_with(dtor: Option<ThrdLocalDtor>) -> Result<Self> {
let mut key: pthread::pthread_key_t = 0;
let ret = unsafe { pthread::pthread_key_create(&mut key, dtor) };
if ret == 0 {
Ok(Self {
key,
mark: PhantomData,
})
} else {
Err(Error::new(ret))
}
}
pub fn set(&self, val: *const T) {
unsafe {
pthread::pthread_setspecific(self.key, val as *const _ as *const pthread::c_void);
}
}
pub fn get(&self) -> *const T {
unsafe {
pthread::pthread_getspecific(self.key)
.cast::<T>()
.cast_mut()
}
}
}
/// A [`ThrdLocal`] key wrapped in a `LazyLock`, storing a raw `*const T`.
///
/// `LocalKey::new` is a `const fn`, so a `LocalKey` can live in a `static`,
/// as the tests in this file do.
pub struct LocalKey<T> {
/// Underlying key, produced by the `LazyLock` initializer.
key: LazyLock<ThrdLocal<T>>,
}
// Asserted by hand with no bound on `T`.
// NOTE(review): soundness depends on the wrapped `LazyLock<ThrdLocal<T>>`
// being safe to share and on how callers use the stored raw pointers — confirm.
unsafe impl<T> Send for LocalKey<T> {}
unsafe impl<T> Sync for LocalKey<T> {}
impl<T> LocalKey<T> {
    /// Builds a key whose underlying `ThrdLocal` is produced by a `LazyLock`
    /// initializer.
    ///
    /// # Panics
    /// The initializer unwraps `ThrdLocal::new()`, so it panics when key
    /// creation fails.
    pub const fn new() -> Self {
        Self {
            key: LazyLock::new(|| ThrdLocal::new().unwrap()),
        }
    }

    /// Returns the pointer stored under this key (see `ThrdLocal::get`).
    pub fn get(&self) -> *const T {
        self.key.get()
    }

    /// Stores `val` under this key (see `ThrdLocal::set`).
    pub fn set(&self, val: *const T) {
        self.key.set(val);
    }

    /// Stores `val` and returns the pointer that was stored before.
    pub fn replace(&self, val: *const T) -> *const T {
        let previous = self.get();
        self.set(val);
        previous
    }
}
/// Allocation hooks used by `LocalRefCell` and `LocalCell` to obtain and
/// release the storage for each thread's value.
///
/// # Safety
/// NOTE(review): the implementer contract is not spelled out in this file.
/// The callers here write a value of the requested layout through the pointer
/// returned by `alloc` and later hand the same pointer and layout to
/// `dealloc` — presumably `alloc` must return memory valid for that layout;
/// confirm.
pub unsafe trait ThrdLocalAlloc {
/// Returns a pointer to storage for `layout`.
unsafe fn alloc(layout: Layout) -> *mut u8;
/// Releases storage at `p`; `layout` is the layout it was requested with.
unsafe fn dealloc(p: *mut u8, layout: Layout);
}
/// Allocator that forwards to `crate::stdlib::aligned_alloc` and
/// `crate::stdlib::aligned_free`; the default `A` parameter of
/// `LocalRefCell` and `LocalCell`.
pub struct NativeAlloc;
/// Forwards both hooks to the crate's `stdlib` aligned allocation functions.
unsafe impl ThrdLocalAlloc for NativeAlloc {
// Passes the layout's alignment and size on to `aligned_alloc`.
unsafe fn alloc(layout: Layout) -> *mut u8 {
crate::stdlib::aligned_alloc(layout.align(), layout.size())
}
// The layout is not used when freeing; only the pointer is passed on.
unsafe fn dealloc(p: *mut u8, _: Layout) {
crate::stdlib::aligned_free(p);
}
}
/// A thread-local `RefCell<T>` that dereferences to the calling thread's
/// own cell.
///
/// The cell is allocated through `A` and initialized with `f()` the first
/// time a thread finds its slot empty.
pub struct LocalRefCell<T, A = NativeAlloc, F = fn() -> T> {
/// Key whose slot holds the pointer to the thread's `RefCell<T>`.
local: LazyLock<ThrdLocal<RefCell<T>>>,
/// Produces the initial value for a thread's cell.
f: F,
/// Carried only to fix the allocator type; its associated functions are
/// called on the type `A`, not on this value.
_alloc: A,
}
// `Send`/`Sync` follow `A` and `F`; `T` is left unconstrained.
// NOTE(review): each thread builds its own `RefCell<T>` through `f`, which is
// presumably why no bound on `T` is required — confirm.
unsafe impl<T, A: Send, F: Send> Send for LocalRefCell<T, A, F> {}
unsafe impl<T, A: Sync, F: Sync> Sync for LocalRefCell<T, A, F> {}
impl<T, F: Fn() -> T> LocalRefCell<T, NativeAlloc, F> {
/// Creates a `LocalRefCell` backed by `NativeAlloc`; `f` supplies the
/// initial value for each thread's cell.
pub const fn new(f: F) -> Self {
Self::new_with(NativeAlloc, f)
}
}
impl<T, A: ThrdLocalAlloc, F: Fn() -> T> LocalRefCell<T, A, F> {
    /// Destructor registered with the pthread key: drops the `RefCell<T>`
    /// behind `val` and returns its storage to `A`.
    ///
    /// # Safety
    /// `val` must be null or a pointer that `get_refcell` stored in the slot
    /// and that has not been freed yet.
    unsafe extern "C" fn dealloc(val: *const core::ffi::c_void) {
        let p = val.cast_mut().cast::<RefCell<T>>();
        // Nothing was ever allocated for a null slot value.
        if p.is_null() {
            return;
        }
        // SAFETY: per this function's contract `p` came from `A::alloc` with
        // this exact layout and holds an initialized `RefCell<T>`.
        unsafe {
            core::ptr::drop_in_place(p);
            A::dealloc(p.cast::<u8>(), Layout::new::<RefCell<T>>());
        }
    }

    /// Creates a `LocalRefCell` that allocates through `A` and initializes
    /// each thread's cell with `f()`.
    ///
    /// # Panics
    /// The `LazyLock` initializer unwraps key creation, so it panics if
    /// `ThrdLocal::new_with` fails.
    pub const fn new_with(_alloc: A, f: F) -> Self {
        Self {
            local: LazyLock::new(|| ThrdLocal::new_with(Some(Self::dealloc)).unwrap()),
            f,
            _alloc,
        }
    }

    /// Returns the calling thread's cell, creating it when the slot is empty.
    ///
    /// # Panics
    /// Panics if `A::alloc` returns a null pointer.
    fn get_refcell(&self) -> &RefCell<T> {
        let existing = self.local.get().cast_mut();
        if !existing.is_null() {
            // SAFETY: a non-null slot value was written by this function and
            // points at an initialized `RefCell<T>`.
            return unsafe { &*existing };
        }

        // Run the initializer before allocating so that a panicking `f`
        // cannot leak a freshly allocated block.
        let initial = RefCell::new((self.f)());

        // SAFETY: the layout is that of the value written below.
        let p = unsafe { A::alloc(Layout::new::<RefCell<T>>()) }.cast::<RefCell<T>>();
        // Writing through a null pointer would be undefined behavior.
        assert!(!p.is_null(), "thread-local storage allocation failed");

        // SAFETY: `p` is non-null and was allocated for a `RefCell<T>`.
        unsafe { p.write(initial) };
        self.local.set(p);
        // SAFETY: `p` was initialized just above.
        unsafe { &*p }
    }
}
/// Dereferencing yields the calling thread's `RefCell<T>`, creating it on
/// demand through `get_refcell`.
impl<T, A, F> Deref for LocalRefCell<T, A, F>
where
    A: ThrdLocalAlloc,
    F: Fn() -> T,
{
    type Target = RefCell<T>;

    fn deref(&self) -> &RefCell<T> {
        let cell: &RefCell<T> = self.get_refcell();
        cell
    }
}
/// A thread-local `Cell<T>` that dereferences to the calling thread's own
/// cell.
///
/// The cell is allocated through `A` and initialized with `f()` the first
/// time a thread finds its slot empty.
pub struct LocalCell<T, A = NativeAlloc, F = fn() -> T> {
/// Key whose slot holds the pointer to the thread's `Cell<T>`.
local: LazyLock<ThrdLocal<Cell<T>>>,
/// Produces the initial value for a thread's cell.
f: F,
/// Carried only to fix the allocator type; its associated functions are
/// called on the type `A`, not on this value.
_alloc: A,
}
// `Send`/`Sync` follow `A` and `F`; `T` is left unconstrained.
// NOTE(review): each thread builds its own `Cell<T>` through `f`, which is
// presumably why no bound on `T` is required — confirm.
unsafe impl<T, A: Send, F: Send> Send for LocalCell<T, A, F> {}
unsafe impl<T, A: Sync, F: Sync> Sync for LocalCell<T, A, F> {}
impl<T, F: Fn() -> T> LocalCell<T, NativeAlloc, F> {
/// Creates a `LocalCell` backed by `NativeAlloc`; `f` supplies the initial
/// value for each thread's cell.
pub const fn new(f: F) -> Self {
Self::new_with(NativeAlloc, f)
}
}
impl<T, A: ThrdLocalAlloc, F: Fn() -> T> LocalCell<T, A, F> {
    /// Destructor registered with the pthread key: drops the `Cell<T>`
    /// behind `val` and returns its storage to `A`.
    ///
    /// # Safety
    /// `val` must be null or a pointer that `get_cell` stored in the slot and
    /// that has not been freed yet.
    unsafe extern "C" fn dealloc(val: *const core::ffi::c_void) {
        let p = val.cast_mut().cast::<Cell<T>>();
        // Nothing was ever allocated for a null slot value.
        if p.is_null() {
            return;
        }
        // SAFETY: per this function's contract `p` came from `A::alloc` with
        // this exact layout and holds an initialized `Cell<T>`.
        unsafe {
            core::ptr::drop_in_place(p);
            A::dealloc(p.cast::<u8>(), Layout::new::<Cell<T>>());
        }
    }

    /// Creates a `LocalCell` that allocates through `A` and initializes each
    /// thread's cell with `f()`.
    ///
    /// # Panics
    /// The `LazyLock` initializer unwraps key creation, so it panics if
    /// `ThrdLocal::new_with` fails.
    pub const fn new_with(_alloc: A, f: F) -> Self {
        Self {
            local: LazyLock::new(|| ThrdLocal::new_with(Some(Self::dealloc)).unwrap()),
            f,
            _alloc,
        }
    }

    /// Returns the calling thread's cell, creating it when the slot is empty.
    ///
    /// # Panics
    /// Panics if `A::alloc` returns a null pointer.
    fn get_cell(&self) -> &Cell<T> {
        let existing = self.local.get().cast_mut();
        if !existing.is_null() {
            // SAFETY: a non-null slot value was written by this function and
            // points at an initialized `Cell<T>`.
            return unsafe { &*existing };
        }

        // Run the initializer before allocating so that a panicking `f`
        // cannot leak a freshly allocated block.
        let initial = Cell::new((self.f)());

        // SAFETY: the layout is that of the value written below.
        let p = unsafe { A::alloc(Layout::new::<Cell<T>>()) }.cast::<Cell<T>>();
        // A null pointer must never be turned into a reference.
        assert!(!p.is_null(), "thread-local storage allocation failed");

        // SAFETY: `p` is non-null and was allocated for a `Cell<T>`.
        unsafe { p.write(initial) };
        self.local.set(p);
        // SAFETY: `p` was initialized just above.
        unsafe { &*p }
    }
}
/// Dereferencing yields the calling thread's `Cell<T>`, creating it on
/// demand through `get_cell`.
impl<T, A, F> Deref for LocalCell<T, A, F>
where
    A: ThrdLocalAlloc,
    F: Fn() -> T,
{
    type Target = Cell<T>;

    fn deref(&self) -> &Cell<T> {
        let cell: &Cell<T> = self.get_cell();
        cell
    }
}
#[cfg(test)]
mod test {
    use crate::*;

    /// A raw-pointer slot is per thread: a new thread starts with null and
    /// its writes do not disturb the spawning thread's value.
    #[test]
    fn test_local() {
        // Named statics each have one well-defined address. Two separately
        // written `&101` literals are not guaranteed to share an address, so
        // comparing against a fresh literal would rely on unspecified behavior.
        static MAIN_VAL: i32 = 100;
        static CHILD_VAL: i32 = 101;
        static KEY: LocalKey<i32> = LocalKey::new();

        KEY.set(&MAIN_VAL);
        let h = spawn(|| {
            let addr = KEY.replace(&CHILD_VAL);
            assert!(addr.is_null());
            let addr = KEY.get();
            assert!(::core::ptr::eq(addr, &CHILD_VAL));
            101
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);

        let addr = KEY.get();
        assert!(!addr.is_null());
        assert!(::core::ptr::eq(addr, &MAIN_VAL));
    }

    /// Every thread gets its own `RefCell`, seeded by the initializer and
    /// unaffected by replacements made on other threads.
    #[test]
    fn test_refcell() {
        static KEY: LocalRefCell<i32> = LocalRefCell::new(|| 100);
        assert_eq!(*KEY.borrow(), 100);
        KEY.replace(200);

        let h = spawn(|| {
            KEY.replace(101);
            *KEY.borrow()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);

        let h = spawn(|| {
            KEY.replace(102);
            *KEY.borrow()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 102);

        // The spawning thread still sees its own replacement.
        assert_eq!(*KEY.borrow(), 200);
    }

    /// Every thread gets its own `Cell`, seeded by the initializer and
    /// unaffected by replacements made on other threads.
    #[test]
    fn test_cell() {
        static KEY: LocalCell<i32> = LocalCell::new(|| 100);
        assert_eq!(KEY.get(), 100);
        KEY.replace(200);

        let h = spawn(|| {
            KEY.replace(101);
            KEY.get()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);

        let h = spawn(|| {
            KEY.replace(102);
            KEY.get()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 102);

        // The spawning thread still sees its own replacement.
        assert_eq!(KEY.get(), 200);
    }
}