hipthread/
mutex.rs

1use super::pthread;
2use core::cell::UnsafeCell;
3use core::marker::PhantomData;
4use core::ops::{Deref, DerefMut};
5
6/// 封装pthread_mutex_t提供互斥锁机制.
7pub struct Mutex<T: ?Sized> {
8    lock: pthread::pthread_mutex_t,
9    val: UnsafeCell<T>,
10}
11
12unsafe impl<T: Send + ?Sized> Send for Mutex<T> {}
13unsafe impl<T: ?Sized> Sync for Mutex<T> {}
14
15pub struct MutexGuard<'a, T: ?Sized + 'a> {
16    mutex: &'a Mutex<T>,
17    // impl !Send for MutexGuard
18    // impl !Sync for MutexGuard
19    mark: PhantomData<*const T>,
20}
21
22unsafe impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}
23
24impl<T: ?Sized> Drop for Mutex<T> {
25    fn drop(&mut self) {
26        unsafe {
27            pthread::pthread_mutex_destroy(&mut self.lock);
28        }
29    }
30}
31
32impl<T> Mutex<T> {
33    pub const fn new(val: T) -> Self {
34        Self {
35            lock: pthread::PTHREAD_MUTEX_INITIALIZER,
36            val: UnsafeCell::new(val),
37        }
38    }
39}
40
41impl<T: ?Sized> Mutex<T> {
42    fn get_lock(&self) -> *mut pthread::pthread_mutex_t {
43        &self.lock as *const _ as *mut pthread::pthread_mutex_t
44    }
45    pub fn lock(&self) -> MutexGuard<'_, T> {
46        let ret = unsafe { pthread::pthread_mutex_lock(self.get_lock()) };
47        assert_eq!(ret, 0);
48        MutexGuard::new(self)
49    }
50
51    pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
52        let ret = unsafe { pthread::pthread_mutex_trylock(self.get_lock()) };
53        if ret == 0 {
54            Some(MutexGuard::new(self))
55        } else {
56            None
57        }
58    }
59}
60
61impl<'a, T: ?Sized> MutexGuard<'a, T> {
62    fn new(mutex: &'a Mutex<T>) -> Self {
63        Self {
64            mutex,
65            mark: PhantomData,
66        }
67    }
68}
69
70impl<T: ?Sized> Drop for MutexGuard<'_, T> {
71    fn drop(&mut self) {
72        unsafe {
73            pthread::pthread_mutex_unlock(self.mutex.get_lock());
74        }
75    }
76}
77
78impl<T: ?Sized> Deref for MutexGuard<'_, T> {
79    type Target = T;
80    fn deref(&self) -> &Self::Target {
81        unsafe { &*self.mutex.val.get() }
82    }
83}
84
85impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
86    fn deref_mut(&mut self) -> &mut Self::Target {
87        unsafe { &mut *self.mutex.val.get() }
88    }
89}
90
#[cfg(test)]
mod test {
    use crate::*;
    use hipool::*;

    /// A value written under the lock on the main thread must be seen by a
    /// spawned thread that locks the same mutex afterwards.
    #[test]
    fn test_mutex() {
        let shared = Arc::new(Mutex::new(1)).unwrap();
        let for_child = shared.clone();
        let child = {
            // Take the lock before spawning, so the child must wait for the
            // write below.
            let mut guard = shared.lock();
            let child = spawn(move || {
                let seen = for_child.lock();
                *seen
            })
            .unwrap();
            *guard = 100;
            child
        }; // `guard` is released here, letting the child proceed.
        let observed = child.join().unwrap();
        assert_eq!(observed, 100);
    }
}
111}