yaar_lock/sync/mutex.rs

use super::WaitNode;
use crate::ThreadEvent;
use core::{
    marker::PhantomData,
    sync::atomic::{fence, spin_loop_hint, AtomicUsize, Ordering},
};

#[cfg(feature = "os")]
pub use self::if_os::*;
#[cfg(feature = "os")]
mod if_os {
    use super::*;
    use crate::OsThreadEvent;

    /// A [`WordLock`] backed by [`OsThreadEvent`].
    #[cfg_attr(feature = "nightly", doc(cfg(feature = "os")))]
    pub type Mutex<T> = RawMutex<T, OsThreadEvent>;

    /// A [`RawMutexGuard`] for [`Mutex`].
    #[cfg_attr(feature = "nightly", doc(cfg(feature = "os")))]
    pub type MutexGuard<'a, T> = RawMutexGuard<'a, T, OsThreadEvent>;
}

/// A mutual exclusion primitive useful for protecting shared data, using
/// [`ThreadEvent`] for thread blocking.
pub type RawMutex<T, E> = lock_api::Mutex<WordLock<E>, T>;

/// An RAII implementation of a "scoped lock" of a [`RawMutex`].
/// When this structure is dropped (falls out of scope), the lock will be
/// unlocked.
///
/// The data protected by the mutex can be accessed through this guard via its
/// `Deref` and `DerefMut` implementations.
pub type RawMutexGuard<'a, T, E> = lock_api::MutexGuard<'a, WordLock<E>, T>;
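
// Illustrative usage sketch, assuming these aliases are re-exported as
// `yaar_lock::sync::{Mutex, MutexGuard}` with the "os" feature enabled; the API
// is the usual `lock_api` one:
//
//     use yaar_lock::sync::Mutex;
//
//     let counter = Mutex::new(0usize);
//     {
//         let mut guard = counter.lock(); // parks on the ThreadEvent if contended
//         *guard += 1;                    // data is reached through Deref/DerefMut
//     }                                   // guard dropped here => mutex unlocked
//     assert_eq!(*counter.lock(), 1);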

const MUTEX_LOCK: usize = 1;
const QUEUE_LOCK: usize = 2;
const QUEUE_MASK: usize = !(QUEUE_LOCK | MUTEX_LOCK);
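
// A sketch of the `state` word layout implied by the constants above: bit 0 is
// the mutex lock, bit 1 locks the wait queue, and the remaining bits hold the
// address of the head `WaitNode`. This relies on `WaitNode` being aligned to at
// least 4 bytes so its low two bits are always zero, e.g.:
//
//     let state = (head as usize) | QUEUE_LOCK | MUTEX_LOCK; // hypothetical packed value
//     let head = (state & QUEUE_MASK) as *const WaitNode<E>; // recover the queue head
//     let mutex_held = state & MUTEX_LOCK != 0;
//     let queue_locked = state & QUEUE_LOCK != 0;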

/// [`lock_api::RawMutex`] implementation of parking_lot's [`WordLock`].
///
/// [`WordLock`]: https://github.com/Amanieu/parking_lot/blob/master/core/src/word_lock.rs
pub struct WordLock<E> {
    state: AtomicUsize,
    phantom: PhantomData<E>,
}

unsafe impl<E: Send> Send for WordLock<E> {}
unsafe impl<E: Sync> Sync for WordLock<E> {}

unsafe impl<E: ThreadEvent> lock_api::RawMutex for WordLock<E> {
    const INIT: Self = Self {
        state: AtomicUsize::new(0),
        phantom: PhantomData,
    };

    type GuardMarker = lock_api::GuardSend;

    fn try_lock(&self) -> bool {
        self.state
            .compare_exchange_weak(0, MUTEX_LOCK, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
    }

    fn lock(&self) {
        if !self.try_lock() {
            let node = WaitNode::<E>::default();
            self.lock_slow(&node);
        }
    }

    fn unlock(&self) {
        let state = self.state.fetch_sub(MUTEX_LOCK, Ordering::Release);
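        // Only take the slow path when there are parked waiters (QUEUE_MASK bits
        // set) and no other thread already holds the queue lock to wake one up.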
        if (state & QUEUE_MASK != 0) && (state & QUEUE_LOCK == 0) {
            self.unlock_slow();
        }
    }
}

impl<E: ThreadEvent> WordLock<E> {
    #[cold]
    fn lock_slow(&self, wait_node: &WaitNode<E>) {
        const MAX_SPIN_DOUBLING: usize = 4;
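        // `spin` doubles the busy-wait on each retry (2, 4, 8, then 16
        // `spin_loop_hint` calls), so a thread spins at most 4 rounds before it
        // enqueues itself and parks on its `ThreadEvent`.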

        let mut spin = 0;
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            // try to acquire the mutex if it's unlocked
            if state & MUTEX_LOCK == 0 {
                match self.state.compare_exchange_weak(
                    state,
                    state | MUTEX_LOCK,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(s) => state = s,
                }
                continue;
            }

            // spin if there are no waiting nodes and we haven't spun too much
            let head = (state & QUEUE_MASK) as *const WaitNode<E>;
            if head.is_null() && spin < MAX_SPIN_DOUBLING {
                spin += 1;
                (0..(1 << spin)).for_each(|_| spin_loop_hint());
                state = self.state.load(Ordering::Relaxed);
                continue;
            }

            // try to enqueue our node onto the wait queue
            let head = wait_node.enqueue(head);
            if let Err(s) = self.state.compare_exchange_weak(
                state,
                (head as usize) | (state & !QUEUE_MASK),
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                state = s;
                continue;
            }

            // wait to be signaled by an unlocking thread
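            // `wait()` appears to return true when a fair unlock handed the mutex
            // directly to this node (see `notify(true)` in `unlock_fair`), and
            // false when the thread was only woken up to retry the acquire loop.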
            if wait_node.wait() {
                return;
            } else {
                spin = 0;
                wait_node.reset();
                state = self.state.load(Ordering::Relaxed);
            }
        }
    }

    #[cold]
    fn unlock_slow(&self) {
        // acquire the queue lock in order to dequeue a node
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            // give up if there are no nodes to dequeue or the queue is already locked.
            if (state & QUEUE_MASK == 0) || (state & QUEUE_LOCK != 0) {
                return;
            }

            // Try to lock the queue using an Acquire barrier on success
            // so that the WaitNode writes are visible. See below.
            match self.state.compare_exchange_weak(
                state,
                state | QUEUE_LOCK,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(s) => state = s,
            }
        }

        // An Acquire barrier is required when looping back with a new state
        // since it will be dereferenced and read from as the head of the queue,
        // and updates to its fields need to be visible from the Release store in
        // `lock_slow()`.
        'outer: loop {
            // If the mutex is locked again, let the new lock holder dequeue the
            // node when it unlocks. Relaxed is fine on success since we aren't
            // making any memory writes visible.
            if state & MUTEX_LOCK != 0 {
                match self.state.compare_exchange_weak(
                    state,
                    state & !QUEUE_LOCK,
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(s) => state = s,
                }
                fence(Ordering::Acquire);
                continue;
            }

            // The head is safe to deref since it was confirmed to be non-null
            // when the queue was locked above.
            let head = unsafe { &*((state & QUEUE_MASK) as *const WaitNode<E>) };
            let (new_tail, tail) = head.dequeue();
            if new_tail.is_null() {
                loop {
                    // unlock the queue while zeroing the head, since the tail was the last node
                    match self.state.compare_exchange_weak(
                        state,
                        state & MUTEX_LOCK,
                        Ordering::Release,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => break,
                        Err(s) => state = s,
                    }

                    // re-process the queue if a new node came in
                    if state & QUEUE_MASK != 0 {
                        fence(Ordering::Acquire);
                        continue 'outer;
                    }
                }
            } else {
                self.state.fetch_and(!QUEUE_LOCK, Ordering::Release);
            }

            // wake up the dequeued tail
            tail.notify(false);
            return;
        }
    }
}

unsafe impl<E: ThreadEvent> lock_api::RawMutexFair for WordLock<E> {
    fn unlock_fair(&self) {
        let mut state = self.state.load(Ordering::Relaxed);
        loop {
            // There aren't any nodes to dequeue or the queue is already locked:
            // try to unlock the mutex normally without dequeuing a node.
            if (state & QUEUE_MASK == 0) || (state & QUEUE_LOCK != 0) {
                // Clear only the mutex bit, preserving the queue pointer and the
                // queue-lock bit; Release publishes the critical section's writes.
                match self.state.compare_exchange_weak(
                    state,
                    state & !MUTEX_LOCK,
                    Ordering::Release,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(s) => state = s,
                }
            // The queue is unlocked and there's a node to remove:
            // try to lock the queue in order to dequeue & wake the node.
            } else {
                match self.state.compare_exchange_weak(
                    state,
                    state | QUEUE_LOCK,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => break,
                    Err(s) => state = s,
                }
            }
        }

        'outer: loop {
            // The head is safe to deref since it was confirmed non-null when the
            // queue was locked above.
            let head = unsafe { &*((state & QUEUE_MASK) as *const WaitNode<E>) };
            let (new_tail, tail) = head.dequeue();

            // update the state to dequeue with a Release ordering which
            // publishes the writes done by `.dequeue()` to other threads.
            if new_tail.is_null() {
                loop {
                    // unlock the queue while zeroing the head, since the tail was the last node.
                    match self.state.compare_exchange_weak(
                        state,
                        MUTEX_LOCK,
                        Ordering::Release,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => break,
                        Err(s) => state = s,
                    }

                    // Re-process the queue if a new node came in.
                    // See `unlock_slow()` for the reasoning behind the Acquire fence.
                    if state & QUEUE_MASK != 0 {
                        fence(Ordering::Acquire);
                        continue 'outer;
                    }
                }
            } else {
                self.state.fetch_and(!QUEUE_LOCK, Ordering::Release);
            }

            // wake up the node with the mutex still locked (direct handoff)
            tail.notify(true);
            return;
        }
    }

    // TODO: bump()
}
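
// Fair unlocking is normally reached through the `lock_api` guard rather than by
// calling `unlock_fair` directly. A minimal sketch (assuming the "os"-feature
// `Mutex` alias above is in scope):
//
//     let m = Mutex::new(0u32);
//     let guard = m.lock();
//     // hands the mutex directly to the dequeued waiter instead of releasing it
//     // and letting threads race to re-acquire it
//     lock_api::MutexGuard::unlock_fair(guard);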

#[cfg(all(test, feature = "os"))]
#[test]
fn test_mutex() {
    use std::{
        sync::{atomic::AtomicBool, Arc, Barrier},
        thread,
    };
    const NUM_THREADS: usize = 10;
    const NUM_ITERS: usize = 10_000;

    #[derive(Debug)]
    struct Context {
        /// Used to check that the critical section really is accessed by only
        /// one thread at a time.
        is_exclusive: AtomicBool,
        /// Counter which is verified after running.
        /// u128 since most CPUs cannot update it with a single instruction.
        count: u128,
    }

    let start_barrier = Arc::new(Barrier::new(NUM_THREADS + 1));
    let context = Arc::new(Mutex::new(Context {
        is_exclusive: AtomicBool::new(false),
        count: 0,
    }));

    let workers = (0..NUM_THREADS)
        .map(|_| {
            let context = context.clone();
            let start_barrier = start_barrier.clone();
            thread::spawn(move || {
                start_barrier.wait();
                for _ in 0..NUM_ITERS {
                    let mut ctx = context.lock();
                    assert_eq!(ctx.is_exclusive.swap(true, Ordering::SeqCst), false);
                    ctx.count += 1;
                    ctx.is_exclusive.store(false, Ordering::SeqCst);
                }
            })
        })
        .collect::<Vec<_>>();
    start_barrier.wait();
    workers.into_iter().for_each(|t| t.join().unwrap());
    assert_eq!(
        context.lock().count,
        (NUM_ITERS * NUM_THREADS) as u128
    );
}