hipthread 0.1.3

A `no_std` thread library built on POSIX threads (pthread).
Documentation
use super::{pthread, LazyLock};
use core::alloc::Layout;
use core::cell::{Cell, RefCell};
use core::marker::PhantomData;
use core::ops::Deref;
use hierr::{Error, Result};

/// Wrapper around a `pthread_key_t` that provides thread-local storage of raw
/// `*const T` pointers.
///
/// Every thread sees its own pointer for the same key; a thread that has never
/// stored one reads back a null pointer. This type never dereferences or frees
/// the stored pointers — that is left to the caller or to the optional
/// destructor passed to [`ThrdLocal::new_with`].
pub struct ThrdLocal<T> {
    key: pthread::pthread_key_t,
    mark: PhantomData<*const T>,
}

/// Destructor registered through `pthread_key_create`; pthread passes it a
/// thread's non-null value when that thread exits.
pub type ThrdLocalDtor = unsafe extern "C" fn(_: *const core::ffi::c_void);

// SAFETY: only the key handle is stored here; the per-thread pointers live in
// pthread-managed storage and are never handed across threads by this type.
unsafe impl<T> Send for ThrdLocal<T> {}
unsafe impl<T> Sync for ThrdLocal<T> {}

impl<T> Drop for ThrdLocal<T> {
    /// Releases the pthread key. Per POSIX, deleting a key does not run its
    /// destructor for values still stored by other threads.
    fn drop(&mut self) {
        // SAFETY: `self.key` was created in `new_with` and is deleted exactly
        // once, here.
        unsafe { pthread::pthread_key_delete(self.key) };
    }
}

impl<T> ThrdLocal<T> {
    /// Creates a key with no destructor.
    ///
    /// # Errors
    /// Returns the non-zero code reported by `pthread_key_create`.
    pub fn new() -> Result<Self> {
        Self::new_with(None)
    }

    /// Creates a key, optionally registering `dtor` as its per-thread
    /// destructor.
    ///
    /// # Errors
    /// Returns the non-zero code reported by `pthread_key_create`.
    pub fn new_with(dtor: Option<ThrdLocalDtor>) -> Result<Self> {
        let mut key: pthread::pthread_key_t = 0;
        // SAFETY: `key` is a valid out-pointer for the duration of the call.
        let ret = unsafe { pthread::pthread_key_create(&mut key, dtor) };
        if ret == 0 {
            Ok(Self {
                key,
                mark: PhantomData,
            })
        } else {
            Err(Error::new(ret))
        }
    }

    /// Stores `val` as the calling thread's value and reports failure.
    ///
    /// # Errors
    /// Returns the non-zero code reported by `pthread_setspecific`.
    pub fn try_set(&self, val: *const T) -> Result<()> {
        // SAFETY: `self.key` is a live key owned by `self`.
        let ret =
            unsafe { pthread::pthread_setspecific(self.key, val.cast::<pthread::c_void>()) };
        if ret == 0 {
            Ok(())
        } else {
            Err(Error::new(ret))
        }
    }

    /// Stores `val` as the calling thread's value (best effort).
    ///
    /// Errors from `pthread_setspecific` are ignored; use
    /// [`ThrdLocal::try_set`] to observe them.
    pub fn set(&self, val: *const T) {
        let _ = self.try_set(val);
    }

    /// Returns the calling thread's value, or null if none was stored.
    pub fn get(&self) -> *const T {
        // SAFETY: `self.key` is a live key owned by `self`.
        unsafe { pthread::pthread_getspecific(self.key).cast::<T>() }
    }
}

/// Convenience wrapper around [`ThrdLocal`] that creates its pthread key
/// lazily, so it can live in a `static`.
///
/// # Panics
/// The lazy initializer unwraps `ThrdLocal::new()`, so a failure to create the
/// pthread key panics when the key is first materialised.
pub struct LocalKey<T> {
    key: LazyLock<ThrdLocal<T>>,
}

// SAFETY: only the (lazily created) key handle is shared between threads; the
// stored pointers themselves are per-thread.
unsafe impl<T> Send for LocalKey<T> {}
unsafe impl<T> Sync for LocalKey<T> {}

impl<T> LocalKey<T> {
    /// Builds the wrapper; usable in `const`/`static` context.
    pub const fn new() -> Self {
        Self {
            key: LazyLock::new(|| ThrdLocal::new().unwrap()),
        }
    }

    /// Returns the calling thread's pointer (null if it never stored one).
    pub fn get(&self) -> *const T {
        self.key.get()
    }

    /// Makes `val` the calling thread's pointer.
    pub fn set(&self, val: *const T) {
        self.key.set(val)
    }

    /// Makes `val` the calling thread's pointer and returns the one it displaced.
    pub fn replace(&self, val: *const T) -> *const T {
        let previous = self.get();
        self.set(val);
        previous
    }
}

/// Allocation hooks used to obtain and release the heap storage that backs a
/// lazily-initialised thread-local value.
///
/// # Safety
/// `alloc` must return memory valid for reads and writes of `layout.size()`
/// bytes and aligned to `layout.align()`, because callers write a value of that
/// layout straight into it. `dealloc` must accept any pointer produced by
/// `alloc` together with the same layout.
pub unsafe trait ThrdLocalAlloc {
    /// Allocates storage for one thread-local value.
    ///
    /// # Safety
    /// `layout` must describe the value that will be stored. Returning null
    /// signals allocation failure, which is expected to end in a panic rather
    /// than the null pointer being used.
    unsafe fn alloc(layout: Layout) -> *mut u8;

    /// Releases storage previously obtained from [`ThrdLocalAlloc::alloc`].
    ///
    /// # Safety
    /// `ptr` must come from `alloc` with the same `layout`. Whether this is
    /// invoked when a thread exits depends on the platform's pthread
    /// implementation.
    unsafe fn dealloc(ptr: *mut u8, layout: Layout);
}

/// Default [`ThrdLocalAlloc`]: obtains TLS value storage from the crate's
/// `aligned_alloc` and returns it through `aligned_free`.
pub struct NativeAlloc;

unsafe impl ThrdLocalAlloc for NativeAlloc {
    unsafe fn alloc(layout: Layout) -> *mut u8 {
        let (align, size) = (layout.align(), layout.size());
        crate::stdlib::aligned_alloc(align, size)
    }

    unsafe fn dealloc(p: *mut u8, _layout: Layout) {
        // `aligned_free` only needs the pointer; the layout is not consulted.
        crate::stdlib::aligned_free(p);
    }
}

/// Lazily-initialised thread-local `RefCell<T>`.
///
/// On a thread's first access, the initializer `f` builds the value, which is
/// placed in storage obtained from `A`. The pthread key destructor drops the
/// value and hands the storage back to `A`.
///
/// # Panics
/// Access panics if the pthread key cannot be created or if `A::alloc`
/// returns null.
pub struct LocalRefCell<T, A = NativeAlloc, F = fn() -> T> {
    local: LazyLock<ThrdLocal<RefCell<T>>>,
    f: F,
    _alloc: A,
}

// SAFETY: only the key handle, the initializer and the allocator marker are
// shared; each `RefCell<T>` is created and reached through per-thread storage.
unsafe impl<T, A: Send, F: Send> Send for LocalRefCell<T, A, F> {}
unsafe impl<T, A: Sync, F: Sync> Sync for LocalRefCell<T, A, F> {}

impl<T, F: Fn() -> T> LocalRefCell<T, NativeAlloc, F> {
    /// Creates a thread-local backed by [`NativeAlloc`].
    pub const fn new(f: F) -> Self {
        Self::new_with(NativeAlloc, f)
    }
}

impl<T, A: ThrdLocalAlloc, F: Fn() -> T> LocalRefCell<T, A, F> {
    /// pthread key destructor: drops the thread's `RefCell<T>` and frees it.
    unsafe extern "C" fn dealloc(val: *const core::ffi::c_void) {
        let p = val.cast_mut().cast::<RefCell<T>>();
        // SAFETY: pthread passes back a non-null value stored by
        // `get_refcell`, i.e. an initialised allocation obtained from `A`.
        unsafe {
            core::ptr::drop_in_place(p);
            A::dealloc(p.cast::<u8>(), Layout::new::<RefCell<T>>());
        }
    }

    /// Creates a thread-local that uses the allocator `A`.
    pub const fn new_with(_alloc: A, f: F) -> Self {
        Self {
            local: LazyLock::new(|| ThrdLocal::new_with(Some(Self::dealloc)).unwrap()),
            f,
            _alloc,
        }
    }

    /// Returns this thread's cell, initialising it on first use.
    fn get_refcell(&self) -> &RefCell<T> {
        let current = self.local.get();
        if !current.is_null() {
            // SAFETY: stored by this function on this thread; it stays alive
            // until the thread's key destructor runs.
            return unsafe { &*current };
        }

        // Build the value *before* allocating: a panicking initializer then
        // leaks nothing, and if the initializer re-entrantly initialised this
        // key we must keep that value instead of overwriting (and leaking) it.
        let value = RefCell::new((self.f)());
        let current = self.local.get();
        if !current.is_null() {
            // SAFETY: as above; the freshly built `value` is simply dropped.
            return unsafe { &*current };
        }

        let p = unsafe { A::alloc(Layout::new::<RefCell<T>>()) }.cast::<RefCell<T>>();
        // Writing through a null pointer would be UB; the allocator contract
        // says a null return ends in a panic.
        assert!(!p.is_null(), "LocalRefCell: thread-local allocation failed");
        // SAFETY: `p` is non-null and, per the `ThrdLocalAlloc` contract,
        // valid and aligned for a `RefCell<T>`.
        unsafe { p.write(value) };
        self.local.set(p);
        // SAFETY: initialised just above; ownership now belongs to the key.
        unsafe { &*p }
    }
}

impl<T, A: ThrdLocalAlloc, F: Fn() -> T> Deref for LocalRefCell<T, A, F> {
    type Target = RefCell<T>;
    fn deref(&self) -> &Self::Target {
        self.get_refcell()
    }
}

/// Lazily-initialised thread-local `Cell<T>`.
///
/// On a thread's first access, the initializer `f` builds the value, which is
/// placed in storage obtained from `A`; zero-sized values need no storage. The
/// pthread key destructor drops the value and hands any storage back to `A`.
///
/// # Panics
/// Access panics if the pthread key cannot be created or if `A::alloc`
/// returns null.
pub struct LocalCell<T, A = NativeAlloc, F = fn() -> T> {
    local: LazyLock<ThrdLocal<Cell<T>>>,
    f: F,
    _alloc: A,
}

// SAFETY: only the key handle, the initializer and the allocator marker are
// shared; each `Cell<T>` is created and reached through per-thread storage.
unsafe impl<T, A: Send, F: Send> Send for LocalCell<T, A, F> {}
unsafe impl<T, A: Sync, F: Sync> Sync for LocalCell<T, A, F> {}

impl<T, F: Fn() -> T> LocalCell<T, NativeAlloc, F> {
    /// Creates a thread-local backed by [`NativeAlloc`].
    pub const fn new(f: F) -> Self {
        Self::new_with(NativeAlloc, f)
    }
}

impl<T, A: ThrdLocalAlloc, F: Fn() -> T> LocalCell<T, A, F> {
    /// pthread key destructor: drops the thread's `Cell<T>` and frees it.
    unsafe extern "C" fn dealloc(val: *const core::ffi::c_void) {
        let p = val.cast_mut().cast::<Cell<T>>();
        let layout = Layout::new::<Cell<T>>();
        // SAFETY: pthread passes back a non-null value stored by `get_cell`,
        // i.e. an initialised value that, if sized, was allocated from `A`.
        unsafe {
            core::ptr::drop_in_place(p);
            // Zero-sized values use a dangling pointer, never memory from `A`.
            if layout.size() != 0 {
                A::dealloc(p.cast::<u8>(), layout);
            }
        }
    }

    /// Creates a thread-local that uses the allocator `A`.
    pub const fn new_with(_alloc: A, f: F) -> Self {
        Self {
            local: LazyLock::new(|| ThrdLocal::new_with(Some(Self::dealloc)).unwrap()),
            f,
            _alloc,
        }
    }

    /// Returns this thread's cell, initialising it on first use.
    fn get_cell(&self) -> &Cell<T> {
        let current = self.local.get();
        if !current.is_null() {
            // SAFETY: stored by this function on this thread; it stays alive
            // until the thread's key destructor runs.
            return unsafe { &*current };
        }

        // Build the value *before* allocating: a panicking initializer then
        // leaks nothing, and if the initializer re-entrantly initialised this
        // key we must keep that value instead of overwriting (and leaking) it.
        let value = Cell::new((self.f)());
        let current = self.local.get();
        if !current.is_null() {
            // SAFETY: as above; the freshly built `value` is simply dropped.
            return unsafe { &*current };
        }

        let layout = Layout::new::<Cell<T>>();
        let p = if layout.size() == 0 {
            // A zero-sized value needs no storage, just a non-null, aligned
            // pointer (null would read back as "not yet initialised", and a
            // zero-byte allocation may legitimately return null).
            core::ptr::NonNull::<Cell<T>>::dangling().as_ptr()
        } else {
            let p = unsafe { A::alloc(layout) }.cast::<Cell<T>>();
            // Writing through a null pointer would be UB; the allocator
            // contract says a null return ends in a panic.
            assert!(!p.is_null(), "LocalCell: thread-local allocation failed");
            p
        };
        // SAFETY: `p` is non-null and valid and aligned for a `Cell<T>` (by the
        // `ThrdLocalAlloc` contract, or trivially for a zero-sized type).
        unsafe { p.write(value) };
        self.local.set(p);
        // SAFETY: initialised just above; ownership now belongs to the key.
        unsafe { &*p }
    }
}

impl<T, A: ThrdLocalAlloc, F: Fn() -> T> Deref for LocalCell<T, A, F> {
    type Target = Cell<T>;
    fn deref(&self) -> &Self::Target {
        self.get_cell()
    }
}

#[cfg(test)]
mod test {
    use crate::*;

    #[test]
    fn test_local() {
        // Named statics have distinct, stable addresses. Comparing against two
        // separate `&101` literals would rely on the compiler deduplicating
        // promoted constants, which is not guaranteed.
        static MAIN_VAL: i32 = 100;
        static CHILD_VAL: i32 = 101;
        static KEY: LocalKey<i32> = LocalKey::new();

        KEY.set(&MAIN_VAL);
        let h = spawn(|| {
            // A fresh thread starts without a value for the key.
            let prev = KEY.replace(&CHILD_VAL);
            assert!(prev.is_null());
            assert!(core::ptr::eq(KEY.get(), &CHILD_VAL));
            101
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);

        // The child's write must not be visible from this thread.
        let addr = KEY.get();
        assert!(!addr.is_null());
        assert!(core::ptr::eq(addr, &MAIN_VAL));
    }

    #[test]
    fn test_refcell() {
        static KEY: LocalRefCell<i32> = LocalRefCell::new(|| 100);
        assert_eq!(*KEY.borrow(), 100);
        KEY.replace(200);
        let h = spawn(|| {
            KEY.replace(101);
            *KEY.borrow()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);
        let h = spawn(|| {
            KEY.replace(102);
            *KEY.borrow()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 102);
        assert_eq!(*KEY.borrow(), 200);
    }

    #[test]
    fn test_cell() {
        static KEY: LocalCell<i32> = LocalCell::new(|| 100);
        assert_eq!(KEY.get(), 100);
        KEY.replace(200);
        let h = spawn(|| {
            KEY.replace(101);
            KEY.get()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 101);
        let h = spawn(|| {
            KEY.replace(102);
            KEY.get()
        })
        .unwrap();
        assert_eq!(h.join().unwrap(), 102);
        assert_eq!(KEY.get(), 200);
    }
}