hipthread/
thrd.rs

1use super::pthread;
2use core::fmt;
3use core::mem::MaybeUninit;
4use core::ptr;
5use hierr::{Error, Result};
6use hipool::Arc;
7
8#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
9#[repr(i32)]
10#[allow(non_camel_case_types)]
11pub enum ThrdSchedPolicy {
12    SCHED_OTHER = pthread::SCHED_OTHER as i32,
13    SCHED_FIFO = pthread::SCHED_FIFO as i32,
14    SCHED_RR = pthread::SCHED_RR as i32,
15}
16
/// Scheduling parameters that can be stored in a [`ThrdAttr`].
///
/// The type is a plain value wrapper around one integer, so it is cheap to
/// copy and compare; `Default` yields priority `0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ThrdSchedParam {
    /// Scheduling priority, converted to/from `sched_param.sched_priority`.
    pub priority: i32,
}
21
22/// 封装pthread_attr_t
23pub struct ThrdAttr {
24    attr: MaybeUninit<pthread::pthread_attr_t>,
25}
26
27impl Default for ThrdAttr {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl Drop for ThrdAttr {
34    fn drop(&mut self) {
35        let _ = unsafe { pthread::pthread_attr_destroy(self.as_mut_ptr()) };
36    }
37}
38
39impl ThrdAttr {
40    pub fn new() -> Self {
41        let mut this = Self {
42            attr: MaybeUninit::uninit(),
43        };
44        let _ = unsafe { pthread::pthread_attr_init(this.as_mut_ptr()) };
45        // mingw中,缺省是PTHREAD_EXPLICIT_SCHED, 修改为PTHREAD_INHERIT_SCHED.
46        #[cfg(windows)]
47        let this = this.set_inheritsched(true);
48        this
49    }
50    pub fn set_detachstate(mut self, detached: bool) -> Self {
51        let state = if detached {
52            pthread::PTHREAD_CREATE_DETACHED
53        } else {
54            pthread::PTHREAD_CREATE_JOINABLE
55        };
56        let _ = unsafe { pthread::pthread_attr_setdetachstate(self.as_mut_ptr(), state) };
57        self
58    }
59    pub fn get_detachstate(&self) -> bool {
60        let mut state = 0_i32;
61        let _ = unsafe { pthread::pthread_attr_getdetachstate(self.as_ptr(), &mut state) };
62        state == pthread::PTHREAD_CREATE_DETACHED
63    }
64
65    pub fn set_stacksize(mut self, stacksize: usize) -> Self {
66        let _ = unsafe { pthread::pthread_attr_setstacksize(self.as_mut_ptr(), stacksize) };
67        self
68    }
69    pub fn get_stacksize(&self) -> usize {
70        let mut stacksize = 0;
71        let _ = unsafe { pthread::pthread_attr_getstacksize(self.as_ptr(), &mut stacksize) };
72        stacksize
73    }
74
75    pub fn set_inheritsched(mut self, inherit: bool) -> Self {
76        let sched = if inherit {
77            pthread::PTHREAD_INHERIT_SCHED
78        } else {
79            pthread::PTHREAD_EXPLICIT_SCHED
80        };
81        let _ = unsafe { pthread::pthread_attr_setinheritsched(self.as_mut_ptr(), sched) };
82        self
83    }
84    pub fn get_inheritsched(&self) -> bool {
85        let mut sched: pthread::c_int = 0;
86        let _ = unsafe { pthread::pthread_attr_getinheritsched(self.as_ptr(), &mut sched) };
87        sched != pthread::PTHREAD_EXPLICIT_SCHED
88    }
89
90    pub fn set_schedpolicy(mut self, policy: ThrdSchedPolicy) -> Self {
91        let _ = unsafe {
92            pthread::pthread_attr_setschedpolicy(self.as_mut_ptr(), policy as pthread::c_int)
93        };
94        self
95    }
96    pub fn get_schedpolicy(&self) -> ThrdSchedPolicy {
97        let mut policy: pthread::c_int = 0;
98        let _ = unsafe { pthread::pthread_attr_getschedpolicy(self.as_ptr(), &mut policy) };
99        match policy {
100            pthread::SCHED_FIFO => ThrdSchedPolicy::SCHED_FIFO,
101            pthread::SCHED_RR => ThrdSchedPolicy::SCHED_RR,
102            _ => ThrdSchedPolicy::SCHED_OTHER,
103        }
104    }
105
106    pub fn set_schedparam(mut self, param: ThrdSchedParam) -> Self {
107        let mut param = pthread::sched_param {
108            sched_priority: param.priority as pthread::c_int,
109        };
110        let _ = unsafe { pthread::pthread_attr_setschedparam(self.as_mut_ptr(), &mut param) };
111        self
112    }
113    pub fn get_schedparam(&self) -> ThrdSchedParam {
114        let mut param = pthread::sched_param { sched_priority: 0 };
115        let _ = unsafe { pthread::pthread_attr_getschedparam(self.as_ptr(), &mut param) };
116        ThrdSchedParam {
117            priority: param.sched_priority as i32,
118        }
119    }
120
121    fn as_ptr(&self) -> *const pthread::pthread_attr_t {
122        self.attr.as_ptr().cast_mut()
123    }
124    fn as_mut_ptr(&mut self) -> *mut pthread::pthread_attr_t {
125        self.attr.as_mut_ptr()
126    }
127}
128
/// Type-erased view of a thread context, exposing only the value produced by
/// the thread routine.
trait JoinContext {
    /// Type of the value handed back to the joiner.
    type ReturnType;

    /// Hands the stored result over to the caller.
    fn output(&mut self) -> Self::ReturnType;
}
133
134#[repr(C)]
135struct ThrdContext<T: 'static, F: 'static> {
136    this: Option<Arc<'static, ThrdContext<T, F>>>,
137    routine: Option<F>,
138    output: Option<T>,
139}
140
141impl<T: 'static, F: 'static> ThrdContext<T, F> {
142    fn new(routine: F) -> Self {
143        Self {
144            routine: Some(routine),
145            output: None,
146            this: None,
147        }
148    }
149}
150
151unsafe impl<T: Send, F: Send> Send for ThrdContext<T, F> {}
152
153impl<T: 'static, F: 'static> JoinContext for ThrdContext<T, F> {
154    type ReturnType = T;
155    fn output(&mut self) -> Self::ReturnType {
156        self.output.take().unwrap()
157    }
158}
159
160impl<T, F> fmt::Debug for ThrdContext<T, F> {
161    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162        write!(f, "output.is_some() = {}", self.output.is_some())
163    }
164}
165
166pub struct JoinHandle<T: 'static> {
167    tid: pthread::pthread_t,
168    ctx: Arc<'static, dyn JoinContext<ReturnType = T>>,
169}
170
171unsafe impl<T: 'static> Send for JoinHandle<T> {}
172
173impl<T: 'static> JoinHandle<T> {
174    fn new<F: 'static>(tid: pthread::pthread_t, ctx: Arc<'static, ThrdContext<T, F>>) -> Self {
175        Self {
176            tid,
177            ctx: ctx
178                .upcast::<dyn JoinContext<ReturnType = T>>(|val| val)
179                .unwrap(),
180        }
181    }
182}
183
184impl<T: 'static> JoinHandle<T> {
185    pub fn join(mut self) -> Result<T> {
186        let err = unsafe { pthread::pthread_join(self.tid, ptr::null_mut()) };
187        if err == 0 {
188            // SAFETY: 线程已经退出,可以安全读取线程函数的返回值.
189            let ctx = unsafe { self.ctx.get_mut_unchecked() };
190            Ok(ctx.output())
191        } else {
192            Err(Error::new(err))
193        }
194    }
195}
196
197type LibcRunnable = extern "C" fn(*mut pthread::c_void) -> *mut pthread::c_void;
198
199/// 创建一个线程.
200pub fn spawn<F, T>(routine: F) -> Result<JoinHandle<T>>
201where
202    F: FnOnce() -> T + Send + 'static,
203    T: Send + 'static,
204{
205    do_spawn(routine, &ThrdAttr::new())
206}
207
208/// 创建线程,可指定线程属性.
209pub fn spawn_with<F, T>(routine: F, attr: &ThrdAttr) -> Result<JoinHandle<T>>
210where
211    F: FnOnce() -> T + Send + 'static,
212    T: Send + 'static,
213{
214    do_spawn(routine, attr)
215}
216
217fn do_spawn<F, T>(routine: F, attr: &ThrdAttr) -> Result<JoinHandle<T>>
218where
219    F: FnOnce() -> T + Send + 'static,
220    T: Send + 'static,
221{
222    extern "C" fn runnable<T, F>(data: *mut pthread::c_void) -> *mut pthread::c_void
223    where
224        F: FnOnce() -> T + Send + 'static,
225        T: Send + 'static,
226    {
227        // SAFETY: JoinHandle只有在线程退出后才能访问,这里获取可写引用没有任何并发冲突.
228        let ctx = unsafe { &mut *(data as *mut _ as *mut ThrdContext<T, F>) };
229        // 获取Arc, 需要确保线程退出时释放资源.
230        let _guard = ctx.this.take();
231        let routine = ctx.routine.take().unwrap();
232        ctx.output = Some(routine());
233        ptr::null_mut()
234    }
235
236    let mut ctx = Arc::new(ThrdContext::new(routine))?;
237    let cloned = ctx.clone();
238    // SAFETY: cloned并未传递到线程,可以安全的获取可写引用.
239    unsafe { ctx.get_mut_unchecked() }.this = Some(cloned);
240
241    let mut tid: pthread::pthread_t = 0;
242    let errno = unsafe {
243        pthread::pthread_create(
244            &mut tid,
245            attr.as_ptr(),
246            runnable::<T, F> as LibcRunnable,
247            (&*ctx) as *const _ as *mut pthread::c_void,
248        )
249    };
250    if errno == 0 {
251        Ok(JoinHandle::new(tid, ctx))
252    } else {
253        // SAFETY: 线程创建失败,可以安全获取可写引用. 需要释放cloned.
254        unsafe { ctx.get_mut_unchecked() }.this = None;
255        Err(Error::new(errno))
256    }
257}
258
#[cfg(test)]
mod test {
    use super::*;

    /// A freshly created attribute object reports the expected defaults.
    #[test]
    fn test_thrdattr_default() {
        let attr = ThrdAttr::new();
        assert!(!attr.get_detachstate());
        assert_eq!(attr.get_schedpolicy(), ThrdSchedPolicy::SCHED_OTHER);
        assert!(attr.get_inheritsched());
        assert_eq!(attr.get_schedparam().priority, 0);
    }

    /// A joinable thread hands its return value back through `join`.
    #[test]
    fn test_create() {
        let handle = spawn(|| 100);
        assert!(handle.is_ok());
        let retn = handle.unwrap().join();
        assert!(retn.is_ok());
        assert_eq!(retn.unwrap(), 100);
    }

    /// Joining a thread that was created detached is expected to fail.
    ///
    /// NOTE(review): POSIX leaves joining a detached thread undefined; this
    /// test relies on the platform reporting an error — confirm on every
    /// supported target.
    #[test]
    fn test_create_detached() {
        let handle = spawn_with(|| 100, &ThrdAttr::new().set_detachstate(true));
        assert!(handle.is_ok());
        let retn = handle.unwrap().join();
        assert!(retn.is_err());
    }
}
288}