k_lock/
k_lock_mutex.rs

1use std::{
2    cell::UnsafeCell,
3    hint::spin_loop,
4    marker::PhantomData,
5    ops::{Deref, DerefMut},
6    ptr::addr_of,
7    sync::atomic::{AtomicU32, Ordering},
8};
9
10use atomic_wait::wake_one;
11
12use crate::poison::{self, LockResult, TryLockError, TryLockResult};
13
14const UNLOCKED: u32 = 0;
15const LOCKED: u32 = 1;
16const CONTENDED: u32 = 2;
17const EXTRA_CONTENDED: u32 = 3;
18
19/// A mutual exclusion primitive useful for protecting shared data
20///
21/// This mutex will block threads waiting for the lock to become available. The
22/// mutex can be created via a [`new`] constructor. Each mutex has a type parameter
23/// which represents the data that it is protecting. The data can only be accessed
24/// through the RAII guards returned from [`lock`] and [`try_lock`], which
25/// guarantees that the data is only ever accessed when the mutex is locked.
26///
27/// # Difference from `std::sync::Mutex`
28/// This mutex is optimized for brief critical sections. It has short spin cycles
29/// to prevent overwork, and it aggressively wakes multiple waiters per unlock.
30///
31/// If there is concern about possible panic in a critical section, `std::sync::Mutex`
32/// is the appropriate choice.
33///
34/// If critical sections are more than a few nanoseconds long, `std::sync::Mutex`
35/// may be better. As always, profiling and measuring is important.
36///
37/// Much of this mutex implementation and its documentation is adapted with humble
38/// gratitude from the venerable `std::sync::Mutex`.
39///
40/// # Poisoning
41///
42/// The mutex in this module uses the poisoning strategy from `std::sync::Mutex`.
43///
44/// [`new`]: Self::new
45/// [`lock`]: Self::lock
46/// [`try_lock`]: Self::try_lock
47///
48/// # Examples
49///
50/// ```
51/// use std::sync::Arc;
52/// use std::thread;
53/// use std::sync::mpsc::channel;
54/// use k_lock::Mutex;
55///
56/// const N: usize = 10;
57///
58/// // Spawn a few threads to increment a shared variable (non-atomically), and
59/// // let the main thread know once all increments are done.
60/// //
61/// // Here we're using an Arc to share memory among threads, and the data inside
62/// // the Arc is protected with a mutex.
63/// let data = Arc::new(Mutex::new(0));
64///
65/// let (tx, rx) = channel();
66/// for _ in 0..N {
67///     let (data, tx) = (Arc::clone(&data), tx.clone());
68///     thread::spawn(move || {
69///         // The shared state can only be accessed once the lock is held.
70///         // Our non-atomic increment is safe because we're the only thread
71///         // which can access the shared state when the lock is held.
72///         //
73///         // We unwrap() the return value to assert that we are not expecting
74///         // threads to ever fail while holding the lock.
75///         let mut data = data.lock().unwrap();
76///         *data += 1;
77///         if *data == N {
78///             tx.send(()).unwrap();
79///         }
80///         // the lock is unlocked here when `data` goes out of scope.
81///     });
82/// }
83///
84/// rx.recv().unwrap();
85/// ```
86///
87/// To unlock a mutex guard sooner than the end of the enclosing scope,
88/// either create an inner scope or drop the guard manually.
89///
90/// ```
91/// use std::sync::Arc;
92/// use std::thread;
93/// use k_lock::Mutex;
94///
95/// const N: usize = 3;
96///
97/// let data_mutex = Arc::new(Mutex::new(vec![1, 2, 3, 4]));
98/// let res_mutex = Arc::new(Mutex::new(0));
99///
100/// let mut threads = Vec::with_capacity(N);
101/// (0..N).for_each(|_| {
102///     let data_mutex_clone = Arc::clone(&data_mutex);
103///     let res_mutex_clone = Arc::clone(&res_mutex);
104///
105///     threads.push(thread::spawn(move || {
106///         // Here we use a block to limit the lifetime of the lock guard.
107///         let result = {
108///             let mut data = data_mutex_clone.lock().unwrap();
109///             // This is the result of some important and long-ish work.
110///             let result = data.iter().fold(0, |acc, x| acc + x * 2);
111///             data.push(result);
112///             result
113///             // The mutex guard gets dropped here, together with any other values
114///             // created in the critical section.
115///         };
116///         // The guard created here is a temporary dropped at the end of the statement, i.e.
117///         // the lock would not remain being held even if the thread did some additional work.
118///         *res_mutex_clone.lock().unwrap() += result;
119///     }));
120/// });
121///
122/// let mut data = data_mutex.lock().unwrap();
123/// // This is the result of some important and long-ish work.
124/// let result = data.iter().fold(0, |acc, x| acc + x * 2);
125/// data.push(result);
126/// // We drop the `data` explicitly because it's not necessary anymore and the
/// // thread still has work to do. This allows other threads to start working on
128/// // the data immediately, without waiting for the rest of the unrelated work
129/// // to be done here.
130/// //
131/// // It's even more important here than in the threads because we `.join` the
132/// // threads after that. If we had not dropped the mutex guard, a thread could
133/// // be waiting forever for it, causing a deadlock.
134/// // As in the threads, a block could have been used instead of calling the
135/// // `drop` function.
136/// drop(data);
137/// // Here the mutex guard is not assigned to a variable and so, even if the
138/// // scope does not end after this line, the mutex is still released: there is
139/// // no deadlock.
140/// *res_mutex.lock().unwrap() += result;
141///
142/// threads.into_iter().for_each(|thread| {
143///     thread
144///         .join()
145///         .expect("The thread creating or execution failed !")
146/// });
147///
148/// assert_eq!(*res_mutex.lock().unwrap(), 800);
149/// ```
pub struct Mutex<T: ?Sized> {
    /// Lock state word. Holds one of `UNLOCKED`, `LOCKED`, `CONTENDED` or
    /// `EXTRA_CONTENDED`; blocked lockers wait on this word and unlocking wakes them.
    futex: AtomicU32,
    /// Counter bumped by `lock` on every successful acquisition. Spinning lockers
    /// compare successive reads of it to notice that the lock is changing hands.
    lock_epoch: AtomicU32,
    /// Poison state: consulted whenever a guard is created and updated when a
    /// guard is dropped.
    poison: poison::Flag,
    /// The protected value. Only reached through a `MutexGuard` or `get_mut`.
    /// Must stay the last field because `T` may be unsized.
    data: UnsafeCell<T>,
}
156
157impl<T: ?Sized> Mutex<T> {
158    /// Acquires a mutex, blocking the current thread until it is able to do so.
159    ///
160    /// This function will block the local thread until it is available to acquire
161    /// the mutex. Upon returning, the thread is the only thread with the lock
162    /// held. An RAII guard is returned to allow scoped unlock of the lock. When
163    /// the guard goes out of scope, the mutex will be unlocked.
164    ///
165    /// The exact behavior on locking a mutex in the thread which already holds
166    /// the lock is left unspecified. However, this function will not return on
167    /// the second call (it might panic or deadlock, for example).
168    ///
169    /// # Errors
170    ///
171    /// If another user of this mutex panicked while holding the mutex, then
172    /// this call will return an error once the mutex is acquired.
173    ///
174    /// # Panics
175    ///
176    /// This function might panic when called if the lock is already held by
177    /// the current thread. It also might not. Don't try it!
178    ///
179    /// # Examples
180    ///
181    /// ```
182    /// use std::sync::Arc;
183    /// use std::thread;
184    /// use k_lock::Mutex;
185    ///
186    /// let mutex = Arc::new(Mutex::new(0));
187    /// let c_mutex = Arc::clone(&mutex);
188    ///
189    /// thread::spawn(move || {
190    ///     *c_mutex.lock().unwrap() = 10;
191    /// }).join().expect("thread::spawn failed");
192    /// assert_eq!(*mutex.lock().unwrap(), 10);
193    /// ```
194    #[inline]
195    pub fn lock(&self) -> LockResult<MutexGuard<T>> {
196        if self
197            .futex
198            .compare_exchange(UNLOCKED, LOCKED, Ordering::Acquire, Ordering::Relaxed)
199            .is_ok()
200        {
201            self.lock_epoch.fetch_add(1, Ordering::Relaxed);
202            return MutexGuard::new(self);
203        }
204        self.lock_contended()
205    }
206
207    /// Move this out so it does not bloat asm and reduce the likelihood of lock() being inlined.
208    #[cold]
209    #[allow(clippy::comparison_chain)] // I prefer it this way in this case because of the semantic meaning
210    fn lock_contended(&self) -> LockResult<MutexGuard<T>> {
211        loop {
212            let state = self.spin();
213            // when locking under contention you have to stay contended or you may leak wakes
214            let expect = if state < CONTENDED {
215                if self.futex.swap(CONTENDED, Ordering::Acquire) == UNLOCKED {
216                    self.lock_epoch.fetch_add(1, Ordering::Relaxed);
217                    return MutexGuard::new(self);
218                }
219                CONTENDED
220            } else if state == CONTENDED {
221                if self.futex.swap(EXTRA_CONTENDED, Ordering::Acquire) == UNLOCKED {
222                    self.lock_epoch.fetch_add(1, Ordering::Relaxed);
223                    return MutexGuard::new(self);
224                }
225                EXTRA_CONTENDED
226            } else {
227                // we've already promoted to extra contended. We're... extra contended.
228                EXTRA_CONTENDED
229            };
230            atomic_wait::wait(&self.futex, expect);
231        }
232    }
233
234    /// Move this out so it does not bloat asm.
235    #[cold]
236    fn spin(&self) -> u32 {
237        let mut spin = 400;
238        let mut epoch = 0;
239        loop {
240            let v = self.futex.load(Ordering::Relaxed);
241            if v != LOCKED || spin == 0 {
242                break v;
243            }
244            let now = self.lock_epoch.load(Ordering::Relaxed);
245            if now != epoch {
246                // Refresh the spin because this lock is making timely progress.
247                epoch = now;
248                spin = 400;
249                // This might be too aggressive - it drops latency for small critical sections
250                // by keeping this thread out of futex, but yield_now is not free either.
251                // Adding a yield to the spin refresh dropped contended latency by over 25%,
252                // but there may be other heuristics that outperform this.
253                std::thread::yield_now();
254            }
255            spin_loop();
256            spin -= 1;
257        }
258    }
259
260    /// Attempts to acquire this lock.
261    ///
262    /// If the lock could not be acquired at this time, then [`Err`] is returned.
263    /// Otherwise, an RAII guard is returned. The lock will be unlocked when the
264    /// guard is dropped.
265    ///
266    /// This function does not block.
267    ///
268    /// # Errors
269    ///
270    /// If the mutex could not be acquired because it is already locked, then
271    /// this call will return the [`WouldBlock`] error.
272    ///
273    /// [`WouldBlock`]: TryLockError::WouldBlock
274    ///
275    /// # Examples
276    ///
277    /// ```
278    /// use std::sync::Arc;
279    /// use std::thread;
280    /// use k_lock::Mutex;
281    ///
282    /// let mutex = Arc::new(Mutex::new(0));
283    /// let c_mutex = Arc::clone(&mutex);
284    ///
285    /// thread::spawn(move || {
286    ///     let mut lock = c_mutex.try_lock();
287    ///     if let Ok(ref mut mutex) = lock {
288    ///         **mutex = 10;
289    ///     } else {
290    ///         println!("try_lock failed");
291    ///     }
292    /// }).join().expect("thread::spawn failed");
293    /// assert_eq!(*mutex.lock().unwrap(), 10);
294    /// ```
295    #[inline]
296    pub fn try_lock(&self) -> TryLockResult<MutexGuard<'_, T>> {
297        match self
298            .futex
299            .compare_exchange(UNLOCKED, LOCKED, Ordering::Acquire, Ordering::Relaxed)
300        {
301            Ok(_) => Ok(MutexGuard::new(self)?),
302            Err(_) => Err(TryLockError::WouldBlock),
303        }
304    }
305
306    /// Returns a mutable reference to the underlying data.
307    ///
308    /// Since this call borrows the `Mutex` mutably, no actual locking needs to
309    /// take place -- the mutable borrow statically guarantees no locks exist.
310    ///
311    /// # Errors
312    ///
313    /// PoisonError if a thread has previously paniced while holding this mutex.
314    ///
315    /// # Examples
316    ///
317    /// ```
318    /// use k_lock::Mutex;
319    ///
320    /// let mut mutex = Mutex::new(0);
321    /// *mutex.get_mut().unwrap() = 10;
322    /// assert_eq!(*mutex.lock().unwrap(), 10);
323    /// ```
324    pub fn get_mut(&mut self) -> LockResult<&mut T> {
325        let data = self.data.get_mut();
326        poison::map_result(self.poison.borrow(), |()| data)
327    }
328}
329
330impl<T> Mutex<T> {
331    /// Creates a new mutex in an unlocked state ready for use.
332    ///
333    /// # Examples
334    ///
335    /// ```
336    /// use k_lock::Mutex;
337    ///
338    /// let mutex = Mutex::new(0);
339    /// ```
340    #[inline]
341    pub const fn new(data: T) -> Self {
342        Self {
343            data: UnsafeCell::new(data),
344            lock_epoch: AtomicU32::new(0),
345            poison: poison::Flag::new(),
346            futex: AtomicU32::new(UNLOCKED),
347        }
348    }
349}
350
351impl<T> From<T> for Mutex<T> {
352    /// Creates a new mutex in an unlocked state ready for use.
353    /// This is equivalent to [`Mutex::new`].
354    fn from(t: T) -> Self {
355        Mutex::new(t)
356    }
357}
358
359impl<T: ?Sized + Default> Default for Mutex<T> {
360    /// Creates a `Mutex<T>`, with the `Default` value for T.
361    fn default() -> Mutex<T> {
362        Mutex::new(Default::default())
363    }
364}
365
366impl<T: ?Sized + std::fmt::Debug> std::fmt::Debug for Mutex<T> {
367    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        let mut d = f.debug_struct("Mutex");
369        match self.try_lock() {
370            Ok(guard) => {
371                d.field("data", &&*guard);
372            }
373            Err(TryLockError::Poisoned(err)) => {
374                d.field("data", &&**err.get_ref());
375            }
376            Err(TryLockError::WouldBlock) => {
377                d.field("data", &format_args!("<locked>"));
378            }
379        }
380        d.field("poisoned", &self.poison.get());
381        d.finish_non_exhaustive()
382    }
383}
384
// these are the only places where `T: Send` matters; all other
// functionality works fine on a single thread.
//
// SAFETY: the protected value is only reachable through a `MutexGuard` (created
// after the futex word has been acquired) or through `get_mut` (which needs
// `&mut Mutex`), so whichever thread holds the lock has exclusive access to `T`.
// Moving or sharing the mutex therefore only ever hands `T` from one thread to
// another, which is what `T: Send` permits.
unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}

// negative impls are not stable yet...
// impl<T: ?Sized> !Send for MutexGuard<'_, T> {}
//
// SAFETY: a shared `&MutexGuard` only exposes `&T` (through `Deref`), so letting
// several threads share a guard is sound exactly when `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}
393
/// An RAII implementation of a "scoped lock" of a mutex. When this structure is
/// dropped (falls out of scope), the lock will be unlocked.
///
/// The data protected by the mutex can be accessed through this guard via its
/// [`Deref`] and [`DerefMut`] implementations.
///
/// This structure is created by the [`lock`] and [`try_lock`] methods on
/// [`Mutex`].
///
/// [`lock`]: Mutex::lock
/// [`try_lock`]: Mutex::try_lock
#[must_use = "if unused the Mutex will immediately unlock"]
#[clippy::has_significant_drop]
pub struct MutexGuard<'a, T: ?Sized + 'a> {
    /// The mutex this guard unlocks when it is dropped.
    lock: &'a Mutex<T>,
    /// Poison bookkeeping obtained from the mutex's poison flag when the guard
    /// was created; handed back to the flag on drop.
    poison: poison::Guard,
    /// Zero-sized marker that stands in for the unstable `!Send` negative impl:
    /// it embeds `std::sync::MutexGuard` in a `PhantomData` so this guard
    /// inherits that type's auto-trait behaviour.
    _phantom: PhantomUnsend,
}
412
413impl<'a, T: ?Sized> MutexGuard<'a, T> {
414    fn new(lock: &'a Mutex<T>) -> LockResult<Self> {
415        poison::map_result(lock.poison.guard(), |guard| Self {
416            lock,
417            poison: guard,
418            _phantom: PhantomData,
419        })
420    }
421}
422
/// Zero-sized marker used as a field of [`MutexGuard`].
///
/// It is a `PhantomData` over `std::sync::MutexGuard`, so a struct containing it
/// picks up that type's auto traits (notably, the standard guard is not `Send`)
/// without storing anything. This is the stable substitute for a negative
/// `!Send` impl.
pub type PhantomUnsend = PhantomData<std::sync::MutexGuard<'static, ()>>;
424
425impl<T: ?Sized> Deref for MutexGuard<'_, T> {
426    type Target = T;
427
428    #[inline]
429    fn deref(&self) -> &T {
430        unsafe { &*self.lock.data.get() }
431    }
432}
433
434impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
435    #[inline]
436    fn deref_mut(&mut self) -> &mut T {
437        unsafe { &mut *self.lock.data.get() }
438    }
439}
440
441impl<T: ?Sized> Drop for MutexGuard<'_, T> {
442    #[inline]
443    fn drop(&mut self) {
444        self.lock.poison.done(&self.poison);
445        let released = self.lock.futex.swap(UNLOCKED, Ordering::Release);
446        if released == CONTENDED {
447            wake_one(addr_of!(self.lock.futex));
448        } else if released == EXTRA_CONTENDED {
449            wake_one(addr_of!(self.lock.futex));
450            wake_one(addr_of!(self.lock.futex));
451        }
452    }
453}
454
455impl<T: ?Sized + std::fmt::Debug> std::fmt::Debug for MutexGuard<'_, T> {
456    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457        std::fmt::Debug::fmt(&**self, f)
458    }
459}
460
461impl<T: ?Sized + std::fmt::Display> std::fmt::Display for MutexGuard<'_, T> {
462    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
463        (**self).fmt(f)
464    }
465}
466
#[cfg(test)]
mod test {
    use std::{sync::Arc, thread};

    use crate::Mutex;

    /// A panic while the guard is held must leave the mutex poisoned for later lockers.
    #[test]
    fn poisoned() {
        let mutex = Arc::new(Mutex::new(()));

        let worker = {
            let mutex = Arc::clone(&mutex);
            thread::spawn(move || {
                let _guard = mutex.lock().expect("lock must succeed");
                panic!("bail while locked");
            })
        };
        // The join result carries the worker's panic, which is expected here.
        let _ = worker.join();

        // it is poisoned, so locking has to report an error
        assert!(mutex.lock().is_err(), "must not lock");
    }
}