use super::pthread;
use core::fmt;
use core::mem::MaybeUninit;
use core::ptr;
use hierr::{Error, Result};
use hipool::Arc;
/// Thread scheduling policy; each variant's discriminant is the matching
/// `pthread::SCHED_*` constant, so it can be handed straight to the C API.
///
/// The type is `Copy` so one policy value can be applied to several
/// [`ThrdAttr`] builders without re-creating it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(i32)]
#[allow(non_camel_case_types)] // variant names intentionally mirror the C constants
pub enum ThrdSchedPolicy {
    /// `SCHED_OTHER`; the policy a freshly created `ThrdAttr` reports.
    SCHED_OTHER = pthread::SCHED_OTHER as i32,
    /// `SCHED_FIFO`.
    SCHED_FIFO = pthread::SCHED_FIFO as i32,
    /// `SCHED_RR`.
    SCHED_RR = pthread::SCHED_RR as i32,
}
/// Scheduling parameters for a thread attribute, mirroring `sched_param`.
///
/// `Default` yields priority `0`, the value a freshly created `ThrdAttr`
/// reports from `get_schedparam`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ThrdSchedParam {
    /// Value stored into / read from `sched_param::sched_priority`.
    pub priority: i32,
}
/// Owned `pthread_attr_t` used to configure threads created by `spawn_with`.
///
/// The underlying object is set up by `pthread_attr_init` in [`ThrdAttr::new`]
/// and torn down by `pthread_attr_destroy` when the value is dropped.
pub struct ThrdAttr {
    attr: MaybeUninit<pthread::pthread_attr_t>,
}

impl Default for ThrdAttr {
    /// Same as [`ThrdAttr::new`].
    fn default() -> Self {
        ThrdAttr::new()
    }
}

impl Drop for ThrdAttr {
    fn drop(&mut self) {
        let raw = self.as_mut_ptr();
        // SAFETY: `raw` points at the attribute object initialised in
        // `ThrdAttr::new`; it is destroyed exactly once, here. The return
        // code is deliberately ignored.
        let _ = unsafe { pthread::pthread_attr_destroy(raw) };
    }
}
impl ThrdAttr {
pub fn new() -> Self {
let mut this = Self {
attr: MaybeUninit::uninit(),
};
let _ = unsafe { pthread::pthread_attr_init(this.as_mut_ptr()) };
#[cfg(windows)]
let this = this.set_inheritsched(true);
this
}
pub fn set_detachstate(mut self, detached: bool) -> Self {
let state = if detached {
pthread::PTHREAD_CREATE_DETACHED
} else {
pthread::PTHREAD_CREATE_JOINABLE
};
let _ = unsafe { pthread::pthread_attr_setdetachstate(self.as_mut_ptr(), state) };
self
}
pub fn get_detachstate(&self) -> bool {
let mut state = 0_i32;
let _ = unsafe { pthread::pthread_attr_getdetachstate(self.as_ptr(), &mut state) };
state == pthread::PTHREAD_CREATE_DETACHED
}
pub fn set_stacksize(mut self, stacksize: usize) -> Self {
let _ = unsafe { pthread::pthread_attr_setstacksize(self.as_mut_ptr(), stacksize) };
self
}
pub fn get_stacksize(&self) -> usize {
let mut stacksize = 0;
let _ = unsafe { pthread::pthread_attr_getstacksize(self.as_ptr(), &mut stacksize) };
stacksize
}
pub fn set_inheritsched(mut self, inherit: bool) -> Self {
let sched = if inherit {
pthread::PTHREAD_INHERIT_SCHED
} else {
pthread::PTHREAD_EXPLICIT_SCHED
};
let _ = unsafe { pthread::pthread_attr_setinheritsched(self.as_mut_ptr(), sched) };
self
}
pub fn get_inheritsched(&self) -> bool {
let mut sched: pthread::c_int = 0;
let _ = unsafe { pthread::pthread_attr_getinheritsched(self.as_ptr(), &mut sched) };
sched != pthread::PTHREAD_EXPLICIT_SCHED
}
pub fn set_schedpolicy(mut self, policy: ThrdSchedPolicy) -> Self {
let _ = unsafe {
pthread::pthread_attr_setschedpolicy(self.as_mut_ptr(), policy as pthread::c_int)
};
self
}
pub fn get_schedpolicy(&self) -> ThrdSchedPolicy {
let mut policy: pthread::c_int = 0;
let _ = unsafe { pthread::pthread_attr_getschedpolicy(self.as_ptr(), &mut policy) };
match policy {
pthread::SCHED_FIFO => ThrdSchedPolicy::SCHED_FIFO,
pthread::SCHED_RR => ThrdSchedPolicy::SCHED_RR,
_ => ThrdSchedPolicy::SCHED_OTHER,
}
}
pub fn set_schedparam(mut self, param: ThrdSchedParam) -> Self {
let mut param = pthread::sched_param {
sched_priority: param.priority as pthread::c_int,
};
let _ = unsafe { pthread::pthread_attr_setschedparam(self.as_mut_ptr(), &mut param) };
self
}
pub fn get_schedparam(&self) -> ThrdSchedParam {
let mut param = pthread::sched_param { sched_priority: 0 };
let _ = unsafe { pthread::pthread_attr_getschedparam(self.as_ptr(), &mut param) };
ThrdSchedParam {
priority: param.sched_priority as i32,
}
}
fn as_ptr(&self) -> *const pthread::pthread_attr_t {
self.attr.as_ptr().cast_mut()
}
fn as_mut_ptr(&mut self) -> *mut pthread::pthread_attr_t {
self.attr.as_mut_ptr()
}
}
/// Type-erased view of a thread context. `JoinHandle` uses it to fetch the
/// routine's result without naming the closure type `F`.
trait JoinContext {
    type ReturnType;
    /// Moves the stored result out of the context.
    fn output(&mut self) -> Self::ReturnType;
}

/// State shared between the spawning side and the spawned thread.
#[repr(C)]
struct ThrdContext<T: 'static, F: 'static> {
    /// Self-reference kept on behalf of the new thread; the thread takes it
    /// when it starts and releases it when its entry point returns.
    this: Option<Arc<'static, ThrdContext<T, F>>>,
    /// Closure to run; taken exactly once by the thread.
    routine: Option<F>,
    /// The closure's return value, written once the closure returns.
    output: Option<T>,
}

impl<T: 'static, F: 'static> ThrdContext<T, F> {
    /// Builds a context with the routine pending, no output, and no self-reference.
    fn new(routine: F) -> Self {
        let routine = Some(routine);
        Self {
            this: None,
            routine,
            output: None,
        }
    }
}

// SAFETY: the context is handed to another thread by pointer. That is only
// sound when the closure `F` and its result `T` may cross threads.
// NOTE(review): this also assumes `hipool::Arc` (the `this` field) uses
// thread-safe reference counting — TODO confirm.
unsafe impl<T: Send, F: Send> Send for ThrdContext<T, F> {}

impl<T: 'static, F: 'static> JoinContext for ThrdContext<T, F> {
    type ReturnType = T;

    /// Panics if the output has not been produced yet or was already taken.
    fn output(&mut self) -> T {
        self.output.take().unwrap()
    }
}

impl<T, F> fmt::Debug for ThrdContext<T, F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_output = self.output.is_some();
        write!(f, "output.is_some() = {has_output}")
    }
}
pub struct JoinHandle<T: 'static> {
tid: pthread::pthread_t,
ctx: Arc<'static, dyn JoinContext<ReturnType = T>>,
}
unsafe impl<T: 'static> Send for JoinHandle<T> {}
impl<T: 'static> JoinHandle<T> {
fn new<F: 'static>(tid: pthread::pthread_t, ctx: Arc<'static, ThrdContext<T, F>>) -> Self {
Self {
tid,
ctx: ctx
.upcast::<dyn JoinContext<ReturnType = T>>(|val| val)
.unwrap(),
}
}
}
impl<T: 'static> JoinHandle<T> {
pub fn join(mut self) -> Result<T> {
let err = unsafe { pthread::pthread_join(self.tid, ptr::null_mut()) };
if err == 0 {
let ctx = unsafe { self.ctx.get_mut_unchecked() };
Ok(ctx.output())
} else {
Err(Error::new(err))
}
}
}
/// C-ABI entry-point signature handed to `pthread_create`.
type LibcRunnable = extern "C" fn(*mut pthread::c_void) -> *mut pthread::c_void;

/// Starts a new thread running `routine`, using a default [`ThrdAttr`].
///
/// # Errors
/// Fails if the shared thread context cannot be created or if
/// `pthread_create` reports an error.
pub fn spawn<F, T>(routine: F) -> Result<JoinHandle<T>>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    let attr = ThrdAttr::new();
    do_spawn(routine, &attr)
}

/// Starts a new thread running `routine`, configured by `attr`.
///
/// # Errors
/// Fails if the shared thread context cannot be created or if
/// `pthread_create` reports an error.
pub fn spawn_with<F, T>(routine: F, attr: &ThrdAttr) -> Result<JoinHandle<T>>
where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    do_spawn(routine, attr)
}
/// Shared implementation of `spawn` and `spawn_with`.
///
/// Steps:
/// 1. Allocates a reference-counted `ThrdContext` holding `routine`.
/// 2. Stores a clone of that `Arc` inside the context (`this`) so the new
///    thread keeps the context alive.
/// 3. Starts the thread with `pthread_create`.
///
/// # Errors
/// Propagates the error from `Arc::new`. Otherwise returns `Error::new(code)`
/// with the non-zero code from `pthread_create`.
fn do_spawn<F, T>(routine: F, attr: &ThrdAttr) -> Result<JoinHandle<T>>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    // Thread entry point: runs the routine and stores its output in the context.
    extern "C" fn runnable<T, F>(data: *mut pthread::c_void) -> *mut pthread::c_void
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        // SAFETY: `data` is the `ThrdContext<T, F>` pointer passed by `do_spawn`.
        // The `this` reference taken below keeps it alive until this function returns.
        let ctx = unsafe { &mut *data.cast::<ThrdContext<T, F>>() };
        // Holding the self-reference until the end of this function releases the
        // thread's share of the context only after `output` has been written.
        let _guard = ctx.this.take();
        let routine = ctx.routine.take().unwrap();
        ctx.output = Some(routine());
        ptr::null_mut()
    }

    let mut ctx = Arc::new(ThrdContext::new(routine))?;
    let cloned = ctx.clone();
    // SAFETY: no thread has been started yet, so nothing else accesses the context.
    unsafe { ctx.get_mut_unchecked() }.this = Some(cloned);

    // The former `0` literal only type-checks when `pthread_t` is an integer.
    // An all-zero value works for integer and raw-pointer `pthread_t` alike,
    // and `pthread_create` overwrites it on success.
    // SAFETY: all-zero is a valid bit pattern for integer and raw-pointer types.
    let mut tid: pthread::pthread_t = unsafe { core::mem::zeroed() };
    // SAFETY: `attr` is an initialised attribute object. The context pointer stays
    // valid for the thread's lifetime thanks to the `this` self-reference.
    let errno = unsafe {
        pthread::pthread_create(
            &mut tid,
            attr.as_ptr(),
            runnable::<T, F> as LibcRunnable,
            (&*ctx) as *const ThrdContext<T, F> as *mut pthread::c_void,
        )
    };
    if errno == 0 {
        Ok(JoinHandle::new(tid, ctx))
    } else {
        // No thread was started: break the self-reference cycle so the context
        // is released when `ctx` goes out of scope.
        // SAFETY: still no other thread can observe the context.
        unsafe { ctx.get_mut_unchecked() }.this = None;
        Err(Error::new(errno))
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// A fresh attribute object reports joinable, SCHED_OTHER, inherited
    /// scheduling and priority 0.
    #[test]
    fn test_thrdattr_default() {
        let attr = ThrdAttr::new();
        assert!(!attr.get_detachstate());
        assert_eq!(attr.get_schedpolicy(), ThrdSchedPolicy::SCHED_OTHER);
        assert!(attr.get_inheritsched());
        assert_eq!(attr.get_schedparam().priority, 0);
    }

    /// The detach state set through the builder is read back unchanged.
    #[test]
    fn test_detachstate_roundtrip() {
        let attr = ThrdAttr::new().set_detachstate(true);
        assert!(attr.get_detachstate());
        let attr = attr.set_detachstate(false);
        assert!(!attr.get_detachstate());
    }

    /// A joinable thread's return value is delivered through `join`.
    #[test]
    fn test_create() {
        let handle = spawn(|| 100);
        assert!(handle.is_ok());
        let retn = handle.unwrap().join();
        assert!(retn.is_ok());
        assert_eq!(retn.unwrap(), 100);
    }

    /// Joining a thread spawned detached reports an error.
    ///
    /// NOTE(review): POSIX leaves `pthread_join` on a detached thread undefined;
    /// this test relies on the platform returning an error code.
    #[test]
    fn test_create_detached() {
        let handle = spawn_with(|| 100, &ThrdAttr::new().set_detachstate(true));
        assert!(handle.is_ok());
        let retn = handle.unwrap().join();
        assert!(retn.is_err());
    }
}