shared_lock/
lock.rs

1#[cfg(doc)]
2use crate::locked::Locked;
3use {
4    crate::execution_unit::execution_unit_id,
5    opera::{PhantomNotSend, PhantomNotSync},
6    parking_lot::{
7        RawMutex,
8        lock_api::{RawMutex as RawMutexTrait, RawMutexFair, RawMutexTimed},
9    },
10    run_on_drop::on_drop,
11    static_assertions::assert_not_impl_any,
12    std::{
13        cell::Cell,
14        fmt::{Debug, Formatter},
15        mem::{self, ManuallyDrop},
16        ptr,
17        sync::{
18            Arc,
19            atomic::{AtomicUsize, Ordering::Relaxed},
20        },
21        time::{Duration, Instant},
22    },
23};
24
25#[cfg(test)]
26mod tests;
27
28/// A re-entrant lock that can be used to protect multiple objects.
29///
30/// Locking this lock automatically gives access to all objects protected by it.
31///
32/// # Example
33///
34/// ```
35/// use shared_lock::Lock;
36///
37/// let lock = Lock::default();
38/// let locked1 = lock.wrap(1);
39/// let locked2 = lock.wrap(2);
40/// let guard = &lock.lock();
41/// assert_eq!(*locked1.get(guard), 1);
42/// assert_eq!(*locked2.get(guard), 2);
43/// ```
44#[derive(Clone, Default)]
45pub struct Lock {
46    shared: Arc<Shared>,
47}
48
49struct Shared {
50    // We enforce the following invariants:
51    // 1. if tickets > 0, then raw_mutex is locked
52    // 2. if execution_unit_id != 0, then the mutex is locked and the execution unit with
53    //    the id execution_unit_id locked it
54    // We say that the current execution unit owns a ticket if tickets > 0 and
55    // execution_unit_id is the execution_unit_id of the current execution unit.
56    raw_mutex: RawMutex,
57    // Mutations of this field are protected by the raw_mutex.
58    execution_unit_id: AtomicUsize,
59    // This field is protected by the raw_mutex.
60    tickets: Cell<u64>,
61}
62
63/// An acquired lock guard.
64///
65/// This object is created by calling [`Lock::lock`] or one of the other locking
66/// functions.
67///
68/// A [`Guard`] can be used to access [`Locked`] data by calling [`Locked::get`].
69///
70/// Each [`Guard`] represents a ticket of the [`Lock`] it was created from. A thread can
71/// have any number of tickets and while a thread holds at least 1 ticket of a [`Lock`],
72/// no other thread can acquire a ticket from that [`Lock`].
73///
74/// A [`Guard`] can temporarily give up its ticket by calling [`Guard::unlocked`] or
75/// [`Guard::unlocked_fair`]. The ticket will be restored before those functions return.
76/// The [`Guard`] is inaccessible while the function is running.
77///
78/// Dropping the [`Guard`] or calling [`Guard::unlock_fair`] consumes the ticket.
79///
80/// The [`Guard`] can be passed to [`mem::forget`] to leak the ticket without leaking any
81/// memory.
82///
83/// # Example
84///
85/// ```
86/// use shared_lock::Lock;
87///
88/// let lock = Lock::default();
89/// let _guard = lock.lock();
90/// ```
91pub struct Guard<'a> {
92    lock: &'a Lock,
93    _phantom_not_send: PhantomNotSend,
94    _phantom_not_sync: PhantomNotSync,
95}
96
97unsafe impl Send for Lock {}
98
99unsafe impl Sync for Lock {}
100
101assert_not_impl_any!(Guard<'_>: Sync, Send);
102
103impl Default for Shared {
104    fn default() -> Self {
105        Self {
106            raw_mutex: RawMutex::INIT,
107            execution_unit_id: AtomicUsize::new(0),
108            tickets: Cell::new(0),
109        }
110    }
111}
112
/// Re-entrant fast path shared by all locking functions.
///
/// If the current execution unit already holds the lock, a new ticket is added, bound to
/// `$guard`, and the enclosing function returns `$ret`. Otherwise execution falls
/// through to the caller's slow path.
macro_rules! maybe_lock_fast {
    ($slf:expr, $guard:ident, $ret:expr) => {
        if $slf.shared.execution_unit_id.load(Relaxed) == execution_unit_id() {
            // SAFETY: - The check above shows that execution_unit_id holds the ID of
            //           the current execution unit.
            //         - By the invariants, the current execution unit therefore holds
            //           the mutex.
            //         - Consequently no other execution unit may modify the field.
            let $guard = unsafe { $slf.add_ticket() };
            return $ret;
        }
    };
}
128
impl Lock {
    /// Forcibly consumes a ticket.
    ///
    /// If this causes the number of tickets to go to 0, the underlying mutex is
    /// unlocked using the regular (non-fair) protocol.
    ///
    /// This can be used to consume the ticket of a [`Guard`] that was passed to
    /// [`mem::forget`].
    ///
    /// # Safety
    ///
    /// - The current thread must own a ticket.
    /// - The invariant that each [`Guard`] owns a ticket must be upheld whenever a
    ///   [`Guard`] is used or dropped.
    ///
    /// # Example
    ///
    /// ```
    /// use std::mem;
    /// use shared_lock::Lock;
    ///
    /// let lock = Lock::default();
    /// let guard = lock.lock();
    /// assert!(lock.is_locked());
    /// mem::forget(guard);
    /// // SAFETY: This consumes the ticket from the guard.
    /// unsafe {
    ///     lock.force_unlock();
    /// }
    /// assert!(!lock.is_locked());
    /// ```
    #[inline]
    pub unsafe fn force_unlock(&self) {
        // SAFETY: - The calling thread owning a ticket means that execution_unit_id is
        //           the ID of the current execution unit and tickets > 0.
        unsafe {
            self.force_unlock_::<false>();
        }
    }

    /// Forcibly consumes a ticket.
    ///
    /// If this causes the number of tickets to go to 0, the underlying mutex will be
    /// unlocked fairly.
    ///
    /// This can be used to consume the ticket of a [`Guard`] that was passed to
    /// [`mem::forget`].
    ///
    /// # Safety
    ///
    /// - The current thread must own a ticket.
    /// - The invariant that each [`Guard`] owns a ticket must be upheld whenever a
    ///   [`Guard`] is used or dropped.
    ///
    /// # Example
    ///
    /// ```
    /// use std::mem;
    /// use shared_lock::Lock;
    ///
    /// let lock = Lock::default();
    /// let guard = lock.lock();
    /// assert!(lock.is_locked());
    /// mem::forget(guard);
    /// // SAFETY: This consumes the ticket from the guard.
    /// unsafe {
    ///     lock.force_unlock_fair();
    /// }
    /// assert!(!lock.is_locked());
    /// ```
    #[inline]
    pub unsafe fn force_unlock_fair(&self) {
        // SAFETY: - The calling thread owning a ticket means that execution_unit_id is
        //           the ID of the current execution unit and tickets > 0.
        unsafe {
            self.force_unlock_::<true>();
        }
    }

    /// Consumes one ticket and, if it was the last one, releases the mutex.
    ///
    /// `FAIR` selects whether the mutex is released with `unlock_fair` (`true`) or
    /// `unlock` (`false`).
    ///
    /// # Safety
    ///
    /// - execution_unit_id must be the ID of the calling execution unit.
    /// - tickets must be > 0.
    #[inline]
    unsafe fn force_unlock_<const FAIR: bool>(&self) {
        let shared = &*self.shared;
        // SAFETY: - By the safety requirements of this function, execution_unit_id is the
        //           ID of the current execution unit.
        //         - By the invariants, the current execution unit is holding the lock.
        //         - Therefore we are allowed to access this field.
        let guards = shared.tickets.get();
        debug_assert!(guards > 0);
        // SAFETY: - Ditto.
        shared.tickets.set(guards - 1);
        if guards == 1 {
            // SAFETY: - We've just set tickets to 0.
            //         - The execution_unit_id requirement is forwarded to the caller of
            //           this function.
            unsafe {
                self.force_unlock_slow::<FAIR>();
            }
        }
    }

    /// Clears the owner ID and releases the underlying mutex after the last ticket has
    /// been consumed.
    ///
    /// # Safety
    ///
    /// - tickets must be 0
    /// - execution_unit_id must be the execution unit ID of the current execution unit
    #[cold]
    #[inline]
    unsafe fn force_unlock_slow<const FAIR: bool>(&self) {
        debug_assert_eq!(
            self.shared.execution_unit_id.load(Relaxed),
            execution_unit_id(),
        );
        // SAFETY: - By the safety requirements of this function, execution_unit_id is
        //           the ID of the calling execution unit.
        //         - By the invariants, the current execution unit is holding the mutex.
        //         - Therefore we are allowed to access this field.
        debug_assert_eq!(self.shared.tickets.get(), 0);
        // The owner ID is cleared while the mutex is still held, because mutations of
        // execution_unit_id must be protected by the mutex. Once the mutex is unlocked,
        // another execution unit may lock it and store its own ID.
        self.shared.execution_unit_id.store(0, Relaxed);
        // SAFETY: - By the safety requirements of this function, the number of tickets is
        //           0.
        //         - As discussed above, the current execution unit is holding the mutex.
        unsafe {
            if FAIR {
                self.shared.raw_mutex.unlock_fair();
            } else {
                self.shared.raw_mutex.unlock();
            }
        }
    }

    /// Returns whether this lock is currently locked by any thread.
    ///
    /// # Example
    ///
    /// ```
    /// use shared_lock::Lock;
    ///
    /// let lock = Lock::default();
    /// assert!(!lock.is_locked());
    /// let _guard = lock.lock();
    /// assert!(lock.is_locked());
    /// ```
    #[inline]
    pub fn is_locked(&self) -> bool {
        self.shared.raw_mutex.is_locked()
    }

    /// Returns whether the guard was created from this lock.
    ///
    /// Two locks are considered the same if one is a clone of the other.
    ///
    /// # Example
    ///
    /// ```
    /// use shared_lock::Lock;
    ///
    /// let lock1 = Lock::default();
    /// let lock2 = Lock::default();
    ///
    /// let guard1 = &lock1.lock();
    /// let guard2 = &lock2.lock();
    ///
    /// assert!(lock1.is_locked_by(guard1));
    /// assert!(!lock1.is_locked_by(guard2));
    /// ```
    #[inline]
    pub fn is_locked_by(&self, guard: &Guard<'_>) -> bool {
        self == guard.lock
    }

    /// Returns whether the current thread is holding the lock.
    ///
    /// # Example
    ///
    /// ```
    /// use std::thread;
    /// use shared_lock::Lock;
    ///
    /// let lock = Lock::default();
    /// let _guard = lock.lock();
    /// assert!(lock.is_locked_by_current_thread());
    ///
    /// thread::scope(|scope| {
    ///     let handle = scope.spawn(|| {
    ///         assert!(!lock.is_locked_by_current_thread());
    ///     });
    ///     handle.join().unwrap();
    /// });
    /// ```
    #[inline]
    pub fn is_locked_by_current_thread(&self) -> bool {
        let shared = &*self.shared;
        shared.execution_unit_id.load(Relaxed) == execution_unit_id()
    }

    /// Acquires this lock.
    ///
    /// If the lock is held by another thread, then this function will block until it is
    /// able to acquire the lock. If the current thread has already acquired the lock, the
    /// function returns immediately with an additional ticket.
    ///
    /// # Example
    ///
    /// ```
    /// use shared_lock::Lock;
    ///
    /// let lock = Lock::default();
    /// let _guard = lock.lock();
    /// ```
    #[inline]
    pub fn lock(&self) -> Guard<'_> {
        maybe_lock_fast!(self, guard, guard);
        self.lock_slow()
    }

    /// Slow path of [`Lock::lock`]: blocks on the mutex and then takes the first ticket.
    // NOTE(review): unlike the other `*_slow` helpers, which use `#[inline]`, this one is
    // `#[inline(always)]` while also `#[cold]`; confirm the combination is intentional.
    #[cold]
    #[inline(always)]
    fn lock_slow(&self) -> Guard<'_> {
        self.shared.raw_mutex.lock();
        // SAFETY: - We've just locked the mutex.
        unsafe { self.add_ticket_after_lock() }
    }

    /// Records the current execution unit as the owner and creates the first ticket.
    ///
    /// # Safety
    ///
    /// - The current execution unit must just have succeeded in locking the mutex.
    #[inline]
    unsafe fn add_ticket_after_lock(&self) -> Guard<'_> {
        let shared = &*self.shared;
        // SAFETY: - By the requirements of this function, we've just locked the mutex.
        //         - Therefore
        //           - we are allowed to mutate this field.
        //           - setting the execution_unit_id to the ID of the current execution
        //             unit upholds the invariant.
        shared.execution_unit_id.store(execution_unit_id(), Relaxed);
        // SAFETY: - We have just set execution_unit_id to the ID of the current execution
        //           unit.
        //         - Since we're holding the lock, no other thread is allowed to modify
        //           the field.
        unsafe { self.add_ticket() }
    }

    /// Increments the ticket count and returns a guard owning the new ticket.
    ///
    /// If the count is already `u64::MAX`, this function never returns: it spins
    /// forever rather than overflowing the counter.
    ///
    /// # Safety
    ///
    /// - execution_unit_id must be the ID of the current execution unit.
    #[inline]
    unsafe fn add_ticket(&self) -> Guard<'_> {
        let shared = &*self.shared;
        // SAFETY: - By the requirements of this function, execution_unit_id is the ID of
        //           the current execution unit.
        //         - By the invariants, this means that the current execution unit is
        //           holding the mutex.
        let guards = shared.tickets.get();
        if guards == u64::MAX {
            // Diverges without touching the counter, so `guards + 1` below can never
            // overflow.
            #[cold]
            fn never() -> ! {
                #[allow(clippy::empty_loop)]
                loop {}
            }
            never();
        }
        // SAFETY: - Ditto regarding the ability to access tickets.
        //         - Since the current execution unit is holding the mutex, setting
        //           tickets to guards + 1 > 0 upholds the invariant.
        shared.tickets.set(guards + 1);
        // SAFETY: - We've just added a ticket and we're assigning ownership of that
        //           ticket to the new Guard.
        //         - Therefore the requirement that the invariant is upheld is unaffected
        //           by this call.
        unsafe { self.make_guard_unchecked_() }
    }

    /// Creates a new [`Guard`] without checking if the lock is held.
    ///
    /// This does not add a ticket; the returned guard takes ownership of an existing one.
    ///
    /// # Safety
    ///
    /// - The invariant that each [`Guard`] owns a ticket must be upheld whenever a
    ///   [`Guard`] is used or dropped.
    ///
    /// # Example
    ///
    /// ```
    /// use std::mem;
    /// use shared_lock::Lock;
    ///
    /// let lock = Lock::default();
    /// mem::forget(lock.lock());
    /// // SAFETY: This recovers the guard we just forgot.
    /// let _guard = unsafe {
    ///     lock.make_guard_unchecked()
    /// };
    /// ```
    #[inline]
    pub unsafe fn make_guard_unchecked(&self) -> Guard<'_> {
        // SAFETY: The requirement is forwarded to the caller.
        unsafe { self.make_guard_unchecked_() }
    }

    /// Constructs a [`Guard`] referring to this lock without touching the lock state.
    ///
    /// # Safety
    ///
    /// - The invariant that each [`Guard`] owns a ticket must be upheld whenever a
    ///   [`Guard`] is used or dropped.
    #[inline]
    unsafe fn make_guard_unchecked_(&self) -> Guard<'_> {
        Guard {
            lock: self,
            _phantom_not_send: Default::default(),
            _phantom_not_sync: Default::default(),
        }
    }

    /// Attempts to acquire this lock.
    ///
    /// If the lock cannot be acquired at this time, `None` is returned. Otherwise a guard
    /// is returned whose ticket is consumed when the guard is dropped. The underlying
    /// mutex is unlocked once the last ticket is consumed.
    ///
    /// If the current thread already holds the lock, this always succeeds.
    ///
    /// This function does not block.
    ///
    /// # Example
    ///
    /// ```
    /// use std::thread;
    /// use shared_lock::Lock;
    ///
    /// let lock = Lock::default();
    /// let _guard = lock.lock();
    /// // The same thread can lock the lock again.
    /// assert!(lock.try_lock().is_some());
    ///
    /// thread::scope(|scope| {
    ///     let join_handle = scope.spawn(|| {
    ///         // Another thread cannot lock the lock.
    ///         assert!(lock.try_lock().is_none());
    ///     });
    ///     join_handle.join().unwrap();
    /// });
    /// ```
    #[inline]
    pub fn try_lock(&self) -> Option<Guard<'_>> {
        maybe_lock_fast!(self, guard, Some(guard));
        self.try_lock_slow()
    }

    /// Slow path of [`Lock::try_lock`]: a single non-blocking attempt on the mutex.
    #[cold]
    #[inline]
    fn try_lock_slow(&self) -> Option<Guard<'_>> {
        self.shared.raw_mutex.try_lock().then(|| {
            // SAFETY: - We've just locked the mutex.
            unsafe { self.add_ticket_after_lock() }
        })
    }

    /// Attempts to acquire this lock until a timeout has expired.
    ///
    /// If the lock cannot be acquired before the timeout expires, `None` is returned.
    /// Otherwise a guard is returned whose ticket is consumed when the guard is dropped.
    /// The underlying mutex is unlocked once the last ticket is consumed.
    ///
    /// If the current thread already holds the lock, this returns immediately without
    /// waiting.
    ///
    /// # Example
    ///
    /// ```
    /// use std::thread;
    /// use std::time::Duration;
    /// use shared_lock::Lock;
    ///
    /// let timeout = Duration::from_millis(200);
    /// let lock = Lock::default();
    /// let _guard = lock.lock();
    ///
    /// thread::scope(|scope| {
    ///     let join_handle = scope.spawn(|| {
    ///         let guard = lock.try_lock_for(timeout);
    ///         assert!(guard.is_none());
    ///     });
    ///     join_handle.join().unwrap();
    /// });
    /// ```
    #[inline]
    pub fn try_lock_for(&self, duration: Duration) -> Option<Guard<'_>> {
        maybe_lock_fast!(self, guard, Some(guard));
        self.try_lock_for_slow(duration)
    }

    /// Slow path of [`Lock::try_lock_for`]: waits on the mutex for at most `duration`.
    #[cold]
    #[inline]
    fn try_lock_for_slow(&self, duration: Duration) -> Option<Guard<'_>> {
        self.shared.raw_mutex.try_lock_for(duration).then(|| {
            // SAFETY: - We've just locked the mutex.
            unsafe { self.add_ticket_after_lock() }
        })
    }

    /// Attempts to acquire this lock until a timeout is reached.
    ///
    /// If the lock cannot be acquired before the timeout expires, `None` is returned.
    /// Otherwise a guard is returned whose ticket is consumed when the guard is dropped.
    /// The underlying mutex is unlocked once the last ticket is consumed.
    ///
    /// If the current thread already holds the lock, this returns immediately without
    /// waiting.
    ///
    /// # Example
    ///
    /// ```
    /// use std::thread;
    /// use std::time::{Duration, Instant};
    /// use shared_lock::Lock;
    ///
    /// let timeout = Instant::now() + Duration::from_millis(200);
    /// let lock = Lock::default();
    /// let _guard = lock.lock();
    ///
    /// thread::scope(|scope| {
    ///     let join_handle = scope.spawn(|| {
    ///         let guard = lock.try_lock_until(timeout);
    ///         assert!(guard.is_none());
    ///     });
    ///     join_handle.join().unwrap();
    /// });
    /// ```
    #[inline]
    pub fn try_lock_until(&self, instant: Instant) -> Option<Guard<'_>> {
        maybe_lock_fast!(self, guard, Some(guard));
        self.try_lock_until_slow(instant)
    }

    /// Slow path of [`Lock::try_lock_until`]: waits on the mutex until `instant`.
    #[cold]
    #[inline]
    fn try_lock_until_slow(&self, instant: Instant) -> Option<Guard<'_>> {
        self.shared.raw_mutex.try_lock_until(instant).then(|| {
            // SAFETY: - We've just locked the mutex.
            unsafe { self.add_ticket_after_lock() }
        })
    }

    /// Returns the address of the shared lock state.
    ///
    /// All clones of a lock share this state, so the address identifies the lock.
    #[inline]
    pub(crate) fn addr(&self) -> *const u8 {
        let addr: *const Shared = &*self.shared;
        addr.cast()
    }
}
567
568impl Debug for Lock {
569    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
570        f.debug_struct("Lock")
571            .field("id", &self.addr())
572            .finish_non_exhaustive()
573    }
574}
575
576impl PartialEq for Lock {
577    #[inline]
578    fn eq(&self, other: &Self) -> bool {
579        ptr::eq::<Shared>(&*self.shared, &*other.shared)
580    }
581}
582
583impl Eq for Lock {}
584
585impl Guard<'_> {
586    /// Unlocks this guard.
587    ///
588    /// If this causes the number of tickets to drop to 0, the underlying mutex is
589    /// unlocked using a fair protocol.
590    ///
591    /// # Example
592    ///
593    /// ```
594    /// use shared_lock::Lock;
595    ///
596    /// let lock = Lock::default();
597    /// let guard = lock.lock();
598    /// guard.unlock_fair();
599    /// ```
600    #[inline]
601    pub fn unlock_fair(self) {
602        let slf = ManuallyDrop::new(self);
603        // SAFETY: - By the invariants, this guard owns a ticket which means that
604        //           the execution_unit_id is the ID of the current execution unit.
605        //         - Since we've wrapped self in ManuallyDrop, we know that it won't be
606        //           used after this.
607        unsafe {
608            slf.lock.force_unlock_fair();
609        }
610    }
611
612    /// Unlocks this guard, runs a function, and then re-acquires the guard.
613    ///
614    /// If another guard exists, then other threads will not be able to acquire the lock
615    /// even while the function is running.
616    ///
617    /// # Example
618    ///
619    /// ```
620    /// use shared_lock::Lock;
621    ///
622    /// let lock = Lock::default();
623    /// let locked = lock.wrap(1);
624    /// let mut guard = lock.lock();
625    /// assert_eq!(*locked.get(&guard), 1);
626    /// guard.unlocked(|| {
627    ///     assert!(!lock.is_locked());
628    /// });
629    /// assert_eq!(*locked.get(&guard), 1);
630    /// ```
631    #[inline]
632    pub fn unlocked<T>(&mut self, f: impl FnOnce() -> T) -> T {
633        self.unlocked_::<_, false>(f)
634    }
635
636    /// Unlocks this guard fairly, runs a function, and then re-acquires the guard.
637    ///
638    /// If another guard exists, then other threads will not be able to acquire the lock
639    /// even while the function is running.
640    ///
641    /// # Example
642    ///
643    /// ```
644    /// use shared_lock::Lock;
645    ///
646    /// let lock = Lock::default();
647    /// let locked = lock.wrap(1);
648    /// let mut guard = lock.lock();
649    /// assert_eq!(*locked.get(&guard), 1);
650    /// guard.unlocked_fair(|| {
651    ///     assert!(!lock.is_locked());
652    /// });
653    /// assert_eq!(*locked.get(&guard), 1);
654    /// ```
655    #[inline]
656    pub fn unlocked_fair<T>(&mut self, f: impl FnOnce() -> T) -> T {
657        self.unlocked_::<_, true>(f)
658    }
659
660    #[inline]
661    fn unlocked_<T, const FAIR: bool>(&mut self, f: impl FnOnce() -> T) -> T {
662        // SAFETY: - Since we have have a mutable reference, this guard cannot current be
663        //           used to access any locked data (since that borrows the guard).
664        //         - Any two guards of the same lock are interchangeable.
665        //         - This unlock operation morally consumes the guard. On drop, we restore
666        //           it by acquiring a new guard and then forgetting it.
667        unsafe {
668            self.lock.force_unlock_::<FAIR>();
669        }
670        let _lock = on_drop(|| {
671            let guard = self.lock.lock();
672            mem::forget(guard);
673        });
674        f()
675    }
676}
677
678impl Drop for Guard<'_> {
679    #[inline]
680    fn drop(&mut self) {
681        // SAFETY: - By the invariants, this guard owns a ticket which means that
682        //           the execution_unit_id is the ID of the current execution unit.
683        unsafe {
684            self.lock.force_unlock_::<false>();
685        }
686    }
687}
688
689impl Debug for Guard<'_> {
690    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
691        f.debug_struct("Guard")
692            .field("lock_id", &self.lock.addr())
693            .finish_non_exhaustive()
694    }
695}