use super::WaitNode;
use crate::ThreadEvent;
use core::{
    marker::PhantomData,
    sync::atomic::{fence, spin_loop_hint, AtomicUsize, Ordering},
};

#[cfg(feature = "os")]
pub use self::if_os::*;
#[cfg(feature = "os")]
mod if_os {
    use super::*;
    use crate::OsThreadEvent;

    #[cfg_attr(feature = "nightly", doc(cfg(feature = "os")))]
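    /// A [`RawMutex`] that blocks threads using the OS thread event.
    ///
    /// A minimal usage sketch (assuming the crate is used with the `os`
    /// feature enabled):
    ///
    /// ```ignore
    /// let counter = Mutex::new(0usize);
    /// {
    ///     let mut guard = counter.lock();
    ///     *guard += 1;
    /// }
    /// assert_eq!(*counter.lock(), 1);
    /// ```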
    pub type Mutex<T> = RawMutex<T, OsThreadEvent>;

    #[cfg_attr(feature = "nightly", doc(cfg(feature = "os")))]
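    /// An RAII guard for [`Mutex`]; the lock is released when it is dropped.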
    pub type MutexGuard<'a, T> = RawMutexGuard<'a, T, OsThreadEvent>;
}

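/// A [`lock_api::Mutex`] backed by [`WordLock`], generic over the
/// [`ThreadEvent`] used to block and wake threads.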
pub type RawMutex<T, E> = lock_api::Mutex<WordLock<E>, T>;

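/// An RAII guard for [`RawMutex`]; the lock is released when it is dropped.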
pub type RawMutexGuard<'a, T, E> = lock_api::MutexGuard<'a, WordLock<E>, T>;

const MUTEX_LOCK: usize = 1;
const QUEUE_LOCK: usize = 2;
const QUEUE_MASK: usize = !(QUEUE_LOCK | MUTEX_LOCK);

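/// A word-sized lock in the style of parking_lot's `WordLock`: the lock
/// state and an intrusive queue of waiting threads are packed into a single
/// `usize`, while `E` provides the blocking primitive for each waiter.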
pub struct WordLock<E> {
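    // Bit 0 (MUTEX_LOCK): set while the mutex is held.
    // Bit 1 (QUEUE_LOCK): set while a thread is updating the wait queue.
    // Remaining bits (QUEUE_MASK): head of the wait queue (null if empty),
    // relying on `WaitNode` being aligned to at least 4 bytes.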
    state: AtomicUsize,
    phantom: PhantomData<E>,
}

unsafe impl<E: Send> Send for WordLock<E> {}
unsafe impl<E: Sync> Sync for WordLock<E> {}

unsafe impl<E: ThreadEvent> lock_api::RawMutex for WordLock<E> {
    const INIT: Self = Self {
        state: AtomicUsize::new(0),
        phantom: PhantomData,
    };

    type GuardMarker = lock_api::GuardSend;

    fn try_lock(&self) -> bool {
        self.state
            .compare_exchange_weak(0, MUTEX_LOCK, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
    }

    fn lock(&self) {
        if !self.try_lock() {
            let node = WaitNode::<E>::default();
            self.lock_slow(&node);
        }
    }

    fn unlock(&self) {
        let state = self.state.fetch_sub(MUTEX_LOCK, Ordering::Release);
        if (state & QUEUE_MASK != 0) && (state & QUEUE_LOCK == 0) {
            self.unlock_slow();
        }
    }
}

impl<E: ThreadEvent> WordLock<E> {
    #[cold]
    fn lock_slow(&self, wait_node: &WaitNode<E>) {
        const MAX_SPIN_DOUBLING: usize = 4;

        let mut spin = 0;
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            if state & MUTEX_LOCK == 0 {
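                // The mutex appears unlocked: race with other threads to grab it.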
                match self.state.compare_exchange_weak(
                    state,
                    state | MUTEX_LOCK,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(s) => state = s,
                }
                continue;
            }

            let head = (state & QUEUE_MASK) as *const WaitNode<E>;
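            // If no other thread is queued yet, spin for a bounded number of
            // rounds before blocking, in case the lock is released quickly.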
            if head.is_null() && spin < MAX_SPIN_DOUBLING {
                spin += 1;
                (0..(1 << spin)).for_each(|_| spin_loop_hint());
                state = self.state.load(Ordering::Relaxed);
                continue;
            }

            let head = wait_node.enqueue(head);
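            // Try to publish our node as the new head of the wait queue.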
            if let Err(s) = self.state.compare_exchange_weak(
                state,
                (head as usize) | (state & !QUEUE_MASK),
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                state = s;
                continue;
            }

            if wait_node.wait() {
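                // Woken by a fair unlock: the mutex bit was left set for us,
                // so the lock is already ours.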
                return;
            } else {
                spin = 0;
                wait_node.reset();
                state = self.state.load(Ordering::Relaxed);
            }
        }
    }

    #[cold]
    fn unlock_slow(&self) {
        let mut state = self.state.load(Ordering::Relaxed);
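        // Try to acquire the queue lock so only one thread dequeues waiters.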
        loop {
            if (state & QUEUE_MASK == 0) || (state & QUEUE_LOCK != 0) {
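                // No waiters, or another thread is already waking one.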
                return;
            }

            match self.state.compare_exchange_weak(
                state,
                state | QUEUE_LOCK,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(s) => state = s,
            }
        }

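        // The queue lock is held: dequeue a waiter and wake it.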
        'outer: loop {
            if state & MUTEX_LOCK != 0 {
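                // The mutex was re-acquired by someone else; let that owner
                // handle the wake-up and just release the queue lock.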
                match self.state.compare_exchange_weak(
                    state,
                    state & !QUEUE_LOCK,
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(s) => state = s,
                }
                fence(Ordering::Acquire);
                continue;
            }

            let head = unsafe { &*((state & QUEUE_MASK) as *const WaitNode<E>) };
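            // Dequeue the tail node; `new_tail` is null once the queue is empty.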
            let (new_tail, tail) = head.dequeue();
            if new_tail.is_null() {
                loop {
                    match self.state.compare_exchange_weak(
                        state,
                        state & MUTEX_LOCK,
                        Ordering::Release,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => break,
                        Err(s) => state = s,
                    }

                    if state & QUEUE_MASK != 0 {
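                        // A new waiter enqueued itself in the meantime;
                        // retry with the queue lock still held.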
                        fence(Ordering::Acquire);
                        continue 'outer;
                    }
                }
            } else {
                self.state.fetch_and(!QUEUE_LOCK, Ordering::Release);
            }

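            // Wake the dequeued thread; it still has to re-acquire the lock.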
            tail.notify(false);
            return;
        }
    }
}

unsafe impl<E: ThreadEvent> lock_api::RawMutexFair for WordLock<E> {
    fn unlock_fair(&self) {
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            if (state & QUEUE_MASK == 0) || (state & QUEUE_LOCK != 0) {
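                // Queue empty, or another thread is already waking waiters:
                // just release the mutex itself.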
                match self.state.compare_exchange_weak(
                    state,
                    // Drop only the mutex bit (clearing the whole word would
                    // erase the queue pointer and strand queued waiters) and
                    // publish the critical section with `Release`.
                    state & !MUTEX_LOCK,
                    Ordering::Release,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(s) => state = s,
                }
            } else {
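                // There are waiters: grab the queue lock to hand the lock off.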
                match self.state.compare_exchange_weak(
                    state,
                    state | QUEUE_LOCK,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => break,
                    Err(s) => state = s,
                }
            }
        }

        'outer: loop {
            let head = unsafe { &*((state & QUEUE_MASK) as *const WaitNode<E>) };
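            // Dequeue the tail node; `new_tail` is null once the queue is empty.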
            let (new_tail, tail) = head.dequeue();

            if new_tail.is_null() {
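                // Last waiter: leave only MUTEX_LOCK set so the lock passes
                // directly to the woken thread.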
                loop {
                    match self.state.compare_exchange_weak(
                        state,
                        MUTEX_LOCK,
                        Ordering::Release,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => break,
                        Err(s) => state = s,
                    }

                    if state & QUEUE_MASK != 0 {
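                        // A new waiter enqueued itself in the meantime;
                        // retry with the queue lock still held.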
                        fence(Ordering::Acquire);
                        continue 'outer;
                    }
                }
            } else {
                self.state.fetch_and(!QUEUE_LOCK, Ordering::Release);
            }

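            // Fair handoff: the mutex bit stays set for the woken thread.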
            tail.notify(true);
            return;
        }
    }
}

#[cfg(all(test, feature = "os"))]
#[test]
fn test_mutex() {
    use std::{
        sync::{atomic::AtomicBool, Arc, Barrier},
        thread,
    };
    const NUM_THREADS: usize = 10;
    const NUM_ITERS: usize = 10_000;

    #[derive(Debug)]
    struct Context {
        is_exclusive: AtomicBool,
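        // Total increments; checked against NUM_THREADS * NUM_ITERS below.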
        count: u128,
    }

    let start_barrier = Arc::new(Barrier::new(NUM_THREADS + 1));
    let context = Arc::new(Mutex::new(Context {
        is_exclusive: AtomicBool::new(false),
        count: 0,
    }));

    let workers = (0..NUM_THREADS)
        .map(|_| {
            let context = context.clone();
            let start_barrier = start_barrier.clone();
            thread::spawn(move || {
                start_barrier.wait();
                for _ in 0..NUM_ITERS {
                    let mut ctx = context.lock();
                    assert_eq!(ctx.is_exclusive.swap(true, Ordering::SeqCst), false);
                    ctx.count += 1;
                    ctx.is_exclusive.store(false, Ordering::SeqCst);
                }
            })
        })
        .collect::<Vec<_>>();
    start_barrier.wait();
    workers.into_iter().for_each(|t| t.join().unwrap());
    assert_eq!(
        context.lock().count,
        (NUM_ITERS * NUM_THREADS) as u128
    );
}