page_lock/
mutex.rs

1use std::collections::LinkedList;
2
3use super::*;
4
5struct Waiter {
6    waker: Waker,
7    state: *mut PollState,
8}
9
10unsafe impl Send for Waiter {}
11unsafe impl Sync for Waiter {}
12
13/// Super-fast asynchronous mutex implementation.
14/// Implementation is based on the Binary [Semaphore](https://en.wikipedia.org/wiki/Semaphore_(programming)) algorithm.
15#[derive(Default)]
16pub struct Mutex<T> {
17    map: Map<T, LinkedList<Waiter>>,
18}
19
20impl<T: Eq + Hash + Copy + Unpin> Mutex<T> {
21    #[inline]
22    pub fn new() -> Self {
23        Self {
24            map: new_map(),
25        }
26    }
27
28    #[inline]
29    pub fn is_locked(&self, num: &T) -> bool {
30        self.map.lock().unwrap().contains_key(num)
31    }
32
33    pub fn unlock(&self, num: &T) {
34        if let Some(list) = self.map.lock().unwrap().remove(num) {
35            for waiter in list {
36                // SAFETY: We have exclusive access to the `state`, so it is safe to mutate it.
37                unsafe { *waiter.state = PollState::Ready };
38                waiter.waker.wake();
39            }
40        }
41    }
42
43    #[inline]
44    pub fn until_unlocked(&self, num: T) -> UntilUnlocked<T> {
45        UntilUnlocked {
46            num,
47            inner: self,
48            state: PollState::Init,
49        }
50    }
51
52    /// SAFETY: Make sure that, `LinkedList<(...)>` is properly initialized in `HashMap`.
53    unsafe fn _wake_next(&self, num: &T) {
54        let mut map = self.map.lock().unwrap_unchecked();
55        let cell = map.get_mut(num);
56        match cell.unwrap_unchecked().pop_front() {
57            Some(waiter) => {
58                *waiter.state = PollState::Ready;
59                waiter.waker.wake();
60            }
61            None => {
62                map.remove(num);
63            }
64        }
65    }
66
67    pub async fn lock(&self, num: T) -> WriteGuard<'_, T> {
68        let mutex_guard = WriteGuard { num, inner: self };
69        WaitForUnlock {
70            num,
71            map: &self.map,
72            state: PollState::Init.into(),
73        }
74        .await;
75        mutex_guard
76    }
77}
78
79pub struct WriteGuard<'a, T: Eq + Hash + Copy + Unpin> {
80    num: T,
81    inner: &'a Mutex<T>,
82}
83
84// impl<T: ?Sized> !Send for MutexGuard<'_, T> {}
85// impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}
86
87impl<'a, T: Eq + Hash + Copy + Unpin> Drop for WriteGuard<'a, T> {
88    fn drop(&mut self) {
89        // SAFETY: LinkedList is properly initialized.
90        unsafe { self.inner._wake_next(&self.num) };
91    }
92}
93
94struct WaitForUnlock<'a, T> {
95    num: T,
96    map: &'a Map<T, LinkedList<Waiter>>,
97    state: UnsafeCell<PollState>,
98}
99
100impl<'a, T: Eq + Hash + Copy + Unpin> Future for WaitForUnlock<'a, T> {
101    type Output = ();
102    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
103        let this = self.get_mut();
104        ret_fut!(*this.state.get(), {
105            let mut map = this.map.lock().unwrap();
106            match map.get_mut(&this.num) {
107                Some(w) => w.push_back(Waiter {
108                    waker: cx.waker().clone(),
109                    state: this.state.get(),
110                }),
111                None => {
112                    map.insert(this.num, LinkedList::new());
113                    return Poll::Ready(());
114                }
115            }
116        });
117    }
118}
119
120pub struct UntilUnlocked<'a, T> {
121    num: T,
122    inner: &'a Mutex<T>,
123    state: PollState,
124}
125
126impl<'a, T: Eq + Hash + Copy + Unpin> Future for UntilUnlocked<'a, T> {
127    type Output = ();
128    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
129        let this = self.get_mut();
130        match this.state {
131            PollState::Init => {
132                let mut map = this.inner.map.lock().unwrap();
133                match map.get_mut(&this.num) {
134                    Some(w) => w.push_back(Waiter {
135                        waker: cx.waker().clone(),
136                        state: &mut this.state,
137                    }),
138                    None => return Poll::Ready(()),
139                }
140                this.state = PollState::Pending;
141            }
142            PollState::Ready => {
143                // SAFETY: LinkedList is properly initialized.
144                unsafe { this.inner._wake_next(&this.num) };
145                return Poll::Ready(());
146            }
147            _ => {}
148        }
149        Poll::Pending
150    }
151}