1use super::pthread;
2use core::cell::UnsafeCell;
3use core::marker::PhantomData;
4use core::ops::{Deref, DerefMut};
5
6pub struct Mutex<T: ?Sized> {
8 lock: pthread::pthread_mutex_t,
9 val: UnsafeCell<T>,
10}
11
12unsafe impl<T: Send + ?Sized> Send for Mutex<T> {}
13unsafe impl<T: ?Sized> Sync for Mutex<T> {}
14
15pub struct MutexGuard<'a, T: ?Sized + 'a> {
16 mutex: &'a Mutex<T>,
17 mark: PhantomData<*const T>,
20}
21
22unsafe impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}
23
24impl<T: ?Sized> Drop for Mutex<T> {
25 fn drop(&mut self) {
26 unsafe {
27 pthread::pthread_mutex_destroy(&mut self.lock);
28 }
29 }
30}
31
32impl<T> Mutex<T> {
33 pub const fn new(val: T) -> Self {
34 Self {
35 lock: pthread::PTHREAD_MUTEX_INITIALIZER,
36 val: UnsafeCell::new(val),
37 }
38 }
39}
40
41impl<T: ?Sized> Mutex<T> {
42 fn get_lock(&self) -> *mut pthread::pthread_mutex_t {
43 &self.lock as *const _ as *mut pthread::pthread_mutex_t
44 }
45 pub fn lock(&self) -> MutexGuard<'_, T> {
46 let ret = unsafe { pthread::pthread_mutex_lock(self.get_lock()) };
47 assert_eq!(ret, 0);
48 MutexGuard::new(self)
49 }
50
51 pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
52 let ret = unsafe { pthread::pthread_mutex_trylock(self.get_lock()) };
53 if ret == 0 {
54 Some(MutexGuard::new(self))
55 } else {
56 None
57 }
58 }
59}
60
61impl<'a, T: ?Sized> MutexGuard<'a, T> {
62 fn new(mutex: &'a Mutex<T>) -> Self {
63 Self {
64 mutex,
65 mark: PhantomData,
66 }
67 }
68}
69
70impl<T: ?Sized> Drop for MutexGuard<'_, T> {
71 fn drop(&mut self) {
72 unsafe {
73 pthread::pthread_mutex_unlock(self.mutex.get_lock());
74 }
75 }
76}
77
78impl<T: ?Sized> Deref for MutexGuard<'_, T> {
79 type Target = T;
80 fn deref(&self) -> &Self::Target {
81 unsafe { &*self.mutex.val.get() }
82 }
83}
84
85impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
86 fn deref_mut(&mut self) -> &mut Self::Target {
87 unsafe { &mut *self.mutex.val.get() }
88 }
89}
90
#[cfg(test)]
mod test {
    use crate::*;
    use hipool::*;

    /// The worker thread blocks on the lock until the main thread has written
    /// `100` and released its guard, so the worker must observe the new value.
    #[test]
    fn test_mutex() {
        let shared = Arc::new(Mutex::new(1)).unwrap();
        let for_worker = shared.clone();
        // Acquire before spawning so the worker cannot read the initial value.
        let mut guard = shared.lock();
        let worker = spawn(move || {
            let observed = for_worker.lock();
            *observed
        })
        .unwrap();
        *guard = 100;
        core::mem::drop(guard);
        let seen = worker.join().unwrap();
        assert_eq!(seen, 100);
    }
}