//! hipthread 0.1.3
//!
//! A `no_std` thread library based on pthread.
use super::pthread;
use core::fmt;
use core::mem::MaybeUninit;
use core::ptr;
use hierr::{Error, Result};
use hipool::Arc;

/// Scheduling policy of a thread, mirroring the pthread `SCHED_*` constants.
///
/// Each discriminant is the raw pthread value, so a policy converts to the C
/// constant with a plain `as` cast.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(i32)]
#[allow(non_camel_case_types)] // variant names deliberately match the C constants
pub enum ThrdSchedPolicy {
    /// `SCHED_OTHER`.
    SCHED_OTHER = pthread::SCHED_OTHER as i32,
    /// `SCHED_FIFO`.
    SCHED_FIFO = pthread::SCHED_FIFO as i32,
    /// `SCHED_RR`.
    SCHED_RR = pthread::SCHED_RR as i32,
}

/// Scheduling parameters of a thread (the Rust-side view of `struct sched_param`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ThrdSchedParam {
    /// Scheduling priority, stored in `sched_param::sched_priority`.
    pub priority: i32,
}

/// Thread-creation attributes, wrapping an initialized `pthread_attr_t`.
///
/// Setters consume and return `self`, so attributes are configured as a chain,
/// e.g. `ThrdAttr::new().set_detachstate(true).set_stacksize(1 << 20)`.
/// Error codes returned by the underlying `pthread_attr_*` calls are ignored; a
/// getter whose call fails reports whatever its zero-initialized output maps to.
pub struct ThrdAttr {
    // Initialized by `pthread_attr_init` in `new`, destroyed in `Drop`.
    attr: MaybeUninit<pthread::pthread_attr_t>,
}

impl Default for ThrdAttr {
    /// Same as [`ThrdAttr::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for ThrdAttr {
    fn drop(&mut self) {
        // SAFETY: `attr` was initialized by `pthread_attr_init` in `new`.
        let _ = unsafe { pthread::pthread_attr_destroy(self.as_mut_ptr()) };
    }
}

impl ThrdAttr {
    /// Creates an attribute object holding the platform defaults.
    ///
    /// On Windows (mingw) the default is `PTHREAD_EXPLICIT_SCHED`; it is switched
    /// to `PTHREAD_INHERIT_SCHED` so new threads inherit the creator's scheduling.
    ///
    /// # Panics
    ///
    /// Panics if `pthread_attr_init` fails. The object would otherwise stay
    /// uninitialized, and every later `pthread_attr_*` call on it - including
    /// `pthread_attr_destroy` in `Drop` - would be undefined behavior.
    pub fn new() -> Self {
        let mut attr = MaybeUninit::<pthread::pthread_attr_t>::uninit();
        // SAFETY: `attr` provides writable storage for one `pthread_attr_t`.
        let err = unsafe { pthread::pthread_attr_init(attr.as_mut_ptr()) };
        // Checked before `Self` exists, so a failure never reaches `Drop`.
        assert_eq!(err, 0, "pthread_attr_init failed");
        let this = Self { attr };
        // mingw defaults to PTHREAD_EXPLICIT_SCHED; switch to PTHREAD_INHERIT_SCHED.
        if cfg!(windows) {
            this.set_inheritsched(true)
        } else {
            this
        }
    }

    /// Sets whether threads are created detached (`true`) or joinable (`false`).
    #[must_use = "the configured attributes are returned; dropping them discards the change"]
    pub fn set_detachstate(mut self, detached: bool) -> Self {
        let state = if detached {
            pthread::PTHREAD_CREATE_DETACHED
        } else {
            pthread::PTHREAD_CREATE_JOINABLE
        };
        // SAFETY: `attr` is initialized (see `new`).
        let _ = unsafe { pthread::pthread_attr_setdetachstate(self.as_mut_ptr(), state) };
        self
    }
    /// Returns `true` if the detach state is `PTHREAD_CREATE_DETACHED`.
    pub fn get_detachstate(&self) -> bool {
        let mut state: pthread::c_int = 0;
        // SAFETY: `attr` is initialized and `state` is a valid output location.
        let _ = unsafe { pthread::pthread_attr_getdetachstate(self.as_ptr(), &mut state) };
        state == pthread::PTHREAD_CREATE_DETACHED
    }

    /// Sets the stack size, in bytes, for new threads.
    #[must_use = "the configured attributes are returned; dropping them discards the change"]
    pub fn set_stacksize(mut self, stacksize: usize) -> Self {
        // SAFETY: `attr` is initialized (see `new`).
        let _ = unsafe { pthread::pthread_attr_setstacksize(self.as_mut_ptr(), stacksize) };
        self
    }
    /// Returns the configured stack size in bytes (0 if the query fails).
    pub fn get_stacksize(&self) -> usize {
        let mut stacksize = 0;
        // SAFETY: `attr` is initialized and `stacksize` is a valid output location.
        let _ = unsafe { pthread::pthread_attr_getstacksize(self.as_ptr(), &mut stacksize) };
        stacksize
    }

    /// Selects `PTHREAD_INHERIT_SCHED` (`true`) or `PTHREAD_EXPLICIT_SCHED` (`false`).
    #[must_use = "the configured attributes are returned; dropping them discards the change"]
    pub fn set_inheritsched(mut self, inherit: bool) -> Self {
        let sched = if inherit {
            pthread::PTHREAD_INHERIT_SCHED
        } else {
            pthread::PTHREAD_EXPLICIT_SCHED
        };
        // SAFETY: `attr` is initialized (see `new`).
        let _ = unsafe { pthread::pthread_attr_setinheritsched(self.as_mut_ptr(), sched) };
        self
    }
    /// Returns `false` only when the setting is `PTHREAD_EXPLICIT_SCHED`.
    pub fn get_inheritsched(&self) -> bool {
        let mut sched: pthread::c_int = 0;
        // SAFETY: `attr` is initialized and `sched` is a valid output location.
        let _ = unsafe { pthread::pthread_attr_getinheritsched(self.as_ptr(), &mut sched) };
        sched != pthread::PTHREAD_EXPLICIT_SCHED
    }

    /// Sets the scheduling policy for new threads.
    #[must_use = "the configured attributes are returned; dropping them discards the change"]
    pub fn set_schedpolicy(mut self, policy: ThrdSchedPolicy) -> Self {
        // SAFETY: `attr` is initialized (see `new`).
        let _ = unsafe {
            pthread::pthread_attr_setschedpolicy(self.as_mut_ptr(), policy as pthread::c_int)
        };
        self
    }
    /// Returns the scheduling policy; unrecognized values map to `SCHED_OTHER`.
    pub fn get_schedpolicy(&self) -> ThrdSchedPolicy {
        let mut policy: pthread::c_int = 0;
        // SAFETY: `attr` is initialized and `policy` is a valid output location.
        let _ = unsafe { pthread::pthread_attr_getschedpolicy(self.as_ptr(), &mut policy) };
        match policy {
            pthread::SCHED_FIFO => ThrdSchedPolicy::SCHED_FIFO,
            pthread::SCHED_RR => ThrdSchedPolicy::SCHED_RR,
            _ => ThrdSchedPolicy::SCHED_OTHER,
        }
    }

    /// Sets the scheduling priority for new threads.
    #[must_use = "the configured attributes are returned; dropping them discards the change"]
    pub fn set_schedparam(mut self, param: ThrdSchedParam) -> Self {
        let mut param = pthread::sched_param {
            sched_priority: param.priority as pthread::c_int,
        };
        // SAFETY: `attr` is initialized and `param` is a valid `sched_param`.
        let _ = unsafe { pthread::pthread_attr_setschedparam(self.as_mut_ptr(), &mut param) };
        self
    }
    /// Returns the configured scheduling priority (0 if the query fails).
    pub fn get_schedparam(&self) -> ThrdSchedParam {
        let mut param = pthread::sched_param { sched_priority: 0 };
        // SAFETY: `attr` is initialized and `param` is a valid output location.
        let _ = unsafe { pthread::pthread_attr_getschedparam(self.as_ptr(), &mut param) };
        ThrdSchedParam {
            priority: param.sched_priority as i32,
        }
    }

    fn as_ptr(&self) -> *const pthread::pthread_attr_t {
        self.attr.as_ptr()
    }
    fn as_mut_ptr(&mut self) -> *mut pthread::pthread_attr_t {
        self.attr.as_mut_ptr()
    }
}

/// Type-erased view of a thread's context that lets [`JoinHandle`] retrieve the
/// result without naming the closure type.
trait JoinContext {
    /// Type returned by the thread routine.
    type ReturnType;

    /// Moves the stored return value out of the context.
    ///
    /// Panics if no value is stored.
    fn output(&mut self) -> Self::ReturnType;
}

/// State shared between a spawned thread and its [`JoinHandle`].
#[repr(C)]
struct ThrdContext<T: 'static, F: 'static> {
    /// Reference owned by the spawned thread: the thread takes it when it starts
    /// and drops it when it returns, keeping the context alive while it runs.
    this: Option<Arc<'static, ThrdContext<T, F>>>,
    /// Routine to run; moved out by the thread when it starts.
    routine: Option<F>,
    /// The routine's return value, written by the thread when the routine returns.
    output: Option<T>,
}

impl<T: 'static, F: 'static> ThrdContext<T, F> {
    /// Builds a context around `routine`; `this` is filled in by the spawner.
    fn new(routine: F) -> Self {
        Self {
            this: None,
            routine: Some(routine),
            output: None,
        }
    }
}

// SAFETY: the context is handed to the spawned thread, which is sound as long as
// the routine `F` and its result `T` may cross threads.
// NOTE(review): this impl also vouches for the `hipool::Arc` stored in `this`;
// confirm its reference count is updated atomically.
unsafe impl<T: Send, F: Send> Send for ThrdContext<T, F> {}

impl<T: 'static, F: 'static> JoinContext for ThrdContext<T, F> {
    type ReturnType = T;

    fn output(&mut self) -> T {
        let stored = self.output.take();
        stored.unwrap()
    }
}

impl<T, F> fmt::Debug for ThrdContext<T, F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_output = self.output.is_some();
        write!(f, "output.is_some() = {}", has_output)
    }
}

pub struct JoinHandle<T: 'static> {
    tid: pthread::pthread_t,
    ctx: Arc<'static, dyn JoinContext<ReturnType = T>>,
}

unsafe impl<T: 'static> Send for JoinHandle<T> {}

impl<T: 'static> JoinHandle<T> {
    fn new<F: 'static>(tid: pthread::pthread_t, ctx: Arc<'static, ThrdContext<T, F>>) -> Self {
        Self {
            tid,
            ctx: ctx
                .upcast::<dyn JoinContext<ReturnType = T>>(|val| val)
                .unwrap(),
        }
    }
}

impl<T: 'static> JoinHandle<T> {
    pub fn join(mut self) -> Result<T> {
        let err = unsafe { pthread::pthread_join(self.tid, ptr::null_mut()) };
        if err == 0 {
            // SAFETY: 线程已经退出,可以安全读取线程函数的返回值.
            let ctx = unsafe { self.ctx.get_mut_unchecked() };
            Ok(ctx.output())
        } else {
            Err(Error::new(err))
        }
    }
}

type LibcRunnable = extern "C" fn(*mut pthread::c_void) -> *mut pthread::c_void;

/// 创建一个线程.
pub fn spawn<F, T>(routine: F) -> Result<JoinHandle<T>>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    do_spawn(routine, &ThrdAttr::new())
}

/// 创建线程,可指定线程属性.
pub fn spawn_with<F, T>(routine: F, attr: &ThrdAttr) -> Result<JoinHandle<T>>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    do_spawn(routine, attr)
}

fn do_spawn<F, T>(routine: F, attr: &ThrdAttr) -> Result<JoinHandle<T>>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    extern "C" fn runnable<T, F>(data: *mut pthread::c_void) -> *mut pthread::c_void
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        // SAFETY: JoinHandle只有在线程退出后才能访问,这里获取可写引用没有任何并发冲突.
        let ctx = unsafe { &mut *(data as *mut _ as *mut ThrdContext<T, F>) };
        // 获取Arc, 需要确保线程退出时释放资源.
        let _guard = ctx.this.take();
        let routine = ctx.routine.take().unwrap();
        ctx.output = Some(routine());
        ptr::null_mut()
    }

    let mut ctx = Arc::new(ThrdContext::new(routine))?;
    let cloned = ctx.clone();
    // SAFETY: cloned并未传递到线程,可以安全的获取可写引用.
    unsafe { ctx.get_mut_unchecked() }.this = Some(cloned);

    let mut tid: pthread::pthread_t = 0;
    let errno = unsafe {
        pthread::pthread_create(
            &mut tid,
            attr.as_ptr(),
            runnable::<T, F> as LibcRunnable,
            (&*ctx) as *const _ as *mut pthread::c_void,
        )
    };
    if errno == 0 {
        Ok(JoinHandle::new(tid, ctx))
    } else {
        // SAFETY: 线程创建失败,可以安全获取可写引用. 需要释放cloned.
        unsafe { ctx.get_mut_unchecked() }.this = None;
        Err(Error::new(errno))
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_thrdattr_default() {
        let attr = ThrdAttr::new();
        assert!(!attr.get_detachstate());
        assert_eq!(attr.get_schedpolicy(), ThrdSchedPolicy::SCHED_OTHER);
        assert!(attr.get_inheritsched());
        assert_eq!(attr.get_schedparam().priority, 0);
    }

    #[test]
    fn test_thrdattr_setters() {
        // 1 MiB is above PTHREAD_STACK_MIN and a multiple of common page sizes.
        let attr = ThrdAttr::new()
            .set_detachstate(true)
            .set_stacksize(1 << 20);
        assert!(attr.get_detachstate());
        assert_eq!(attr.get_stacksize(), 1 << 20);
    }

    #[test]
    fn test_create() {
        let handle = spawn(|| 100);
        assert!(handle.is_ok());
        let retn = handle.unwrap().join();
        assert!(retn.is_ok());
        assert_eq!(retn.unwrap(), 100);
    }

    #[test]
    fn test_create_captures() {
        let base = 21_u64;
        let handle = spawn(move || base * 2).unwrap();
        assert_eq!(handle.join().unwrap(), 42);
    }

    #[test]
    fn test_drop_without_join() {
        // A handle may be dropped without joining; the thread keeps running.
        let handle = spawn(|| 1).unwrap();
        drop(handle);
    }

    #[test]
    fn test_create_detached() {
        let handle = spawn_with(|| 100, &ThrdAttr::new().set_detachstate(true));
        assert!(handle.is_ok());
        let retn = handle.unwrap().join();
        assert!(retn.is_err());
    }
}