1use std::collections::LinkedList;
2
3use super::*;
4
5struct Waiter {
6 waker: Waker,
7 state: *mut PollState,
8}
9
10unsafe impl Send for Waiter {}
11unsafe impl Sync for Waiter {}
12
13#[derive(Default)]
16pub struct Mutex<T> {
17 map: Map<T, LinkedList<Waiter>>,
18}
19
20impl<T: Eq + Hash + Copy + Unpin> Mutex<T> {
21 #[inline]
22 pub fn new() -> Self {
23 Self {
24 map: new_map(),
25 }
26 }
27
28 #[inline]
29 pub fn is_locked(&self, num: &T) -> bool {
30 self.map.lock().unwrap().contains_key(num)
31 }
32
33 pub fn unlock(&self, num: &T) {
34 if let Some(list) = self.map.lock().unwrap().remove(num) {
35 for waiter in list {
36 unsafe { *waiter.state = PollState::Ready };
38 waiter.waker.wake();
39 }
40 }
41 }
42
43 #[inline]
44 pub fn until_unlocked(&self, num: T) -> UntilUnlocked<T> {
45 UntilUnlocked {
46 num,
47 inner: self,
48 state: PollState::Init,
49 }
50 }
51
52 unsafe fn _wake_next(&self, num: &T) {
54 let mut map = self.map.lock().unwrap_unchecked();
55 let cell = map.get_mut(num);
56 match cell.unwrap_unchecked().pop_front() {
57 Some(waiter) => {
58 *waiter.state = PollState::Ready;
59 waiter.waker.wake();
60 }
61 None => {
62 map.remove(num);
63 }
64 }
65 }
66
67 pub async fn lock(&self, num: T) -> WriteGuard<'_, T> {
68 let mutex_guard = WriteGuard { num, inner: self };
69 WaitForUnlock {
70 num,
71 map: &self.map,
72 state: PollState::Init.into(),
73 }
74 .await;
75 mutex_guard
76 }
77}
78
79pub struct WriteGuard<'a, T: Eq + Hash + Copy + Unpin> {
80 num: T,
81 inner: &'a Mutex<T>,
82}
83
84impl<'a, T: Eq + Hash + Copy + Unpin> Drop for WriteGuard<'a, T> {
88 fn drop(&mut self) {
89 unsafe { self.inner._wake_next(&self.num) };
91 }
92}
93
94struct WaitForUnlock<'a, T> {
95 num: T,
96 map: &'a Map<T, LinkedList<Waiter>>,
97 state: UnsafeCell<PollState>,
98}
99
100impl<'a, T: Eq + Hash + Copy + Unpin> Future for WaitForUnlock<'a, T> {
101 type Output = ();
102 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
103 let this = self.get_mut();
104 ret_fut!(*this.state.get(), {
105 let mut map = this.map.lock().unwrap();
106 match map.get_mut(&this.num) {
107 Some(w) => w.push_back(Waiter {
108 waker: cx.waker().clone(),
109 state: this.state.get(),
110 }),
111 None => {
112 map.insert(this.num, LinkedList::new());
113 return Poll::Ready(());
114 }
115 }
116 });
117 }
118}
119
120pub struct UntilUnlocked<'a, T> {
121 num: T,
122 inner: &'a Mutex<T>,
123 state: PollState,
124}
125
126impl<'a, T: Eq + Hash + Copy + Unpin> Future for UntilUnlocked<'a, T> {
127 type Output = ();
128 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
129 let this = self.get_mut();
130 match this.state {
131 PollState::Init => {
132 let mut map = this.inner.map.lock().unwrap();
133 match map.get_mut(&this.num) {
134 Some(w) => w.push_back(Waiter {
135 waker: cx.waker().clone(),
136 state: &mut this.state,
137 }),
138 None => return Poll::Ready(()),
139 }
140 this.state = PollState::Pending;
141 }
142 PollState::Ready => {
143 unsafe { this.inner._wake_next(&this.num) };
145 return Poll::Ready(());
146 }
147 _ => {}
148 }
149 Poll::Pending
150 }
151}