shared_lock/lock.rs
1#[cfg(doc)]
2use crate::locked::Locked;
3use {
4 crate::execution_unit::execution_unit_id,
5 opera::{PhantomNotSend, PhantomNotSync},
6 parking_lot::{
7 RawMutex,
8 lock_api::{RawMutex as RawMutexTrait, RawMutexFair, RawMutexTimed},
9 },
10 run_on_drop::on_drop,
11 static_assertions::assert_not_impl_any,
12 std::{
13 cell::Cell,
14 fmt::{Debug, Formatter},
15 mem::{self, ManuallyDrop},
16 ptr,
17 sync::{
18 Arc,
19 atomic::{AtomicUsize, Ordering::Relaxed},
20 },
21 time::{Duration, Instant},
22 },
23};
24
25#[cfg(test)]
26mod tests;
27
28/// A re-entrant lock that can be used to protect multiple objects.
29///
30/// Locking this lock automatically gives access to all objects protected by it.
31///
32/// # Example
33///
34/// ```
35/// use shared_lock::Lock;
36///
37/// let lock = Lock::default();
38/// let locked1 = lock.wrap(1);
39/// let locked2 = lock.wrap(2);
40/// let guard = &lock.lock();
41/// assert_eq!(*locked1.get(guard), 1);
42/// assert_eq!(*locked2.get(guard), 2);
43/// ```
44#[derive(Clone, Default)]
45pub struct Lock {
46 shared: Arc<Shared>,
47}
48
49struct Shared {
50 // We enforce the following invariants:
51 // 1. if tickets > 0, then raw_mutex is locked
52 // 2. if execution_unit_id != 0, then the mutex is locked and the execution unit with
53 // the id execution_unit_id locked it
54 // We say that the current execution unit owns a ticket if tickets > 0 and
55 // execution_unit_id is the execution_unit_id of the current execution unit.
56 raw_mutex: RawMutex,
57 // Mutations of this field are protected by the raw_mutex.
58 execution_unit_id: AtomicUsize,
59 // This field is protected by the raw_mutex.
60 tickets: Cell<u64>,
61}
62
63/// An acquired lock guard.
64///
65/// This object is created by calling [`Lock::lock`] or one of the other locking
66/// functions.
67///
68/// A [`Guard`] can be used to access [`Locked`] data by calling [`Locked::get`].
69///
70/// Each [`Guard`] represents a ticket of the [`Lock`] it was created from. A thread can
71/// have any number of tickets and while a thread holds at least 1 ticket of a [`Lock`],
72/// no other thread can acquire a ticket from that [`Lock`].
73///
74/// A [`Guard`] can temporarily give up its ticket by calling [`Guard::unlocked`] or
75/// [`Guard::unlocked_fair`]. The ticket will be restored before those functions return.
76/// The [`Guard`] is inaccessible while the function is running.
77///
78/// Dropping the [`Guard`] or calling [`Guard::unlock_fair`] consumes the ticket.
79///
80/// The [`Guard`] can be passed to [`mem::forget`] to leak the ticket without leaking any
81/// memory.
82///
83/// # Example
84///
85/// ```
86/// use shared_lock::Lock;
87///
88/// let lock = Lock::default();
89/// let _guard = lock.lock();
90/// ```
91pub struct Guard<'a> {
92 lock: &'a Lock,
93 _phantom_not_send: PhantomNotSend,
94 _phantom_not_sync: PhantomNotSync,
95}
96
97unsafe impl Send for Lock {}
98
99unsafe impl Sync for Lock {}
100
101assert_not_impl_any!(Guard<'_>: Sync, Send);
102
103impl Default for Shared {
104 fn default() -> Self {
105 Self {
106 raw_mutex: RawMutex::INIT,
107 execution_unit_id: AtomicUsize::new(0),
108 tickets: Cell::new(0),
109 }
110 }
111}
112
/// Re-entrant fast path shared by the locking functions.
///
/// If the current execution unit is already recorded as the owner of the mutex of
/// `$slf`, this adds a ticket, binds the resulting guard to `$guard`, and returns
/// `$ret` from the *enclosing* function. Otherwise it does nothing and execution
/// continues after the macro invocation.
macro_rules! maybe_lock_fast {
    ($slf:expr, $guard:ident, $ret:expr) => {
        let owner = $slf.shared.execution_unit_id.load(Relaxed);
        if owner == execution_unit_id() {
            // SAFETY: - The recorded owner is the current execution unit, as just
            //           checked.
            //         - By the invariants, the current execution unit therefore holds
            //           the mutex.
            //         - Hence no other execution unit may modify the field.
            let $guard = unsafe { $slf.add_ticket() };
            return $ret;
        }
    };
}
128
129impl Lock {
130 /// Forcibly consumes a ticket.
131 ///
132 /// This can be used to consume the ticket of a [`Guard`] that was passed to
133 /// [`mem::forget`].
134 ///
135 /// # Safety
136 ///
137 /// - The current thread must own a ticket.
138 /// - The invariant that each [`Guard`] owns a ticket must be upheld whenever a
139 /// [`Guard`] is used or dropped.
140 ///
141 /// # Example
142 ///
143 /// ```
144 /// use std::mem;
145 /// use shared_lock::Lock;
146 ///
147 /// let lock = Lock::default();
148 /// let guard = lock.lock();
149 /// assert!(lock.is_locked());
150 /// mem::forget(guard);
151 /// // SAFETY: This consumes the ticket from the guard.
152 /// unsafe {
153 /// lock.force_unlock();
154 /// }
155 /// assert!(!lock.is_locked());
156 /// ```
157 #[inline]
158 pub unsafe fn force_unlock(&self) {
159 // SAFETY: - The calling thread owning a ticket means that execution_unit_id is
160 // the ID of the current execution unit and tickets > 0.
161 unsafe {
162 self.force_unlock_::<false>();
163 }
164 }
165
166 /// Forcibly consumes a ticket.
167 ///
168 /// If this causes the number tickets to go to 0, the underlying mutex will be
169 /// unlocked fairly.
170 ///
171 /// This can be used to consume the ticket of a [`Guard`] that was passed to
172 /// [`mem::forget`].
173 ///
174 /// # Safety
175 ///
176 /// - The current thread must own a ticket.
177 /// - The invariant that each [`Guard`] owns a ticket must be upheld whenever a
178 /// [`Guard`] is used or dropped.
179 ///
180 /// # Example
181 ///
182 /// ```
183 /// use std::mem;
184 /// use shared_lock::Lock;
185 ///
186 /// let lock = Lock::default();
187 /// let guard = lock.lock();
188 /// assert!(lock.is_locked());
189 /// mem::forget(guard);
190 /// // SAFETY: This consumes the ticket from the guard.
191 /// unsafe {
192 /// lock.force_unlock_fair();
193 /// }
194 /// assert!(!lock.is_locked());
195 /// ```
196 #[inline]
197 pub unsafe fn force_unlock_fair(&self) {
198 // SAFETY: - The calling thread owning a ticket means that execution_unit_id is
199 // the ID of the current execution unit and tickets > 0.
200 unsafe {
201 self.force_unlock_::<true>();
202 }
203 }
204
205 /// # Safety
206 ///
207 /// - execution_unit_id must be the ID of the calling execution unit.
208 /// - tickets must be > 0.
209 #[inline]
210 unsafe fn force_unlock_<const FAIR: bool>(&self) {
211 let shared = &*self.shared;
212 // SAFETY: - By the safety requirements of this function, execution_unit_id is the
213 // ID of the current execution unit.
214 // - By the invariants, the current execution unit is holding the lock.
215 // - Therefore we are allowed to access this field.
216 let guards = shared.tickets.get();
217 debug_assert!(guards > 0);
218 // SAFETY: - Dito.
219 shared.tickets.set(guards - 1);
220 if guards == 1 {
221 // SAFETY: - We've just set tickets to 0.
222 // - The execution_unit_id requirements is forwarded to the caller of
223 // this function.
224 unsafe {
225 self.force_unlock_slow::<FAIR>();
226 }
227 }
228 }
229
230 /// # Safety
231 ///
232 /// - tickets must be 0
233 /// - execution_unit_id must be the execution unit ID of the current execution unit
234 #[cold]
235 #[inline]
236 unsafe fn force_unlock_slow<const FAIR: bool>(&self) {
237 debug_assert_eq!(
238 self.shared.execution_unit_id.load(Relaxed),
239 execution_unit_id(),
240 );
241 // SAFETY: - By the safety requirements of this function, execution_unit_id is
242 // the execution unit of the calling execution unit.
243 // - By the invariants, the current execution unit is holding the mutex.
244 // - Therefore we are allowed to access this field.
245 debug_assert_eq!(self.shared.tickets.get(), 0);
246 self.shared.execution_unit_id.store(0, Relaxed);
247 // SAFETY: - By the safety requirements of this function, the number of tickets is
248 // 0.
249 // - As discussed above, the current execution unit is holding the mutex.
250 unsafe {
251 if FAIR {
252 self.shared.raw_mutex.unlock_fair();
253 } else {
254 self.shared.raw_mutex.unlock();
255 }
256 }
257 }
258
259 /// Returns whether this lock is locked.
260 ///
261 /// # Example
262 ///
263 /// ```
264 /// use shared_lock::Lock;
265 ///
266 /// let lock = Lock::default();
267 /// assert!(!lock.is_locked());
268 /// let _guard = lock.lock();
269 /// assert!(lock.is_locked());
270 /// ```
271 #[inline]
272 pub fn is_locked(&self) -> bool {
273 self.shared.raw_mutex.is_locked()
274 }
275
276 /// Returns whether this lock is locked by the guard.
277 ///
278 /// # Example
279 ///
280 /// ```
281 /// use shared_lock::Lock;
282 ///
283 /// let lock1 = Lock::default();
284 /// let lock2 = Lock::default();
285 ///
286 /// let guard1 = &lock1.lock();
287 /// let guard2 = &lock2.lock();
288 ///
289 /// assert!(lock1.is_locked_by(&guard1));
290 /// assert!(!lock1.is_locked_by(&guard2));
291 /// ```
292 #[inline]
293 pub fn is_locked_by(&self, guard: &Guard<'_>) -> bool {
294 self == guard.lock
295 }
296
297 /// Returns whether the current thread is holding the lock.
298 ///
299 /// # Example
300 ///
301 /// ```
302 /// use std::thread;
303 /// use shared_lock::Lock;
304 ///
305 /// let lock = Lock::default();
306 /// let _guard = lock.lock();
307 /// assert!(lock.is_locked_by_current_thread());
308 ///
309 /// thread::scope(|scope| {
310 /// let handle = scope.spawn(|| {
311 /// assert!(!lock.is_locked_by_current_thread());
312 /// });
313 /// handle.join().unwrap();
314 /// });
315 /// ```
316 #[inline]
317 pub fn is_locked_by_current_thread(&self) -> bool {
318 let shared = &*self.shared;
319 shared.execution_unit_id.load(Relaxed) == execution_unit_id()
320 }
321
322 /// Acquires this lock.
323 ///
324 /// If the lock is held by another thread, then this function will block until it is
325 /// able to acquire the lock. If the current thread has already acquired the lock, the
326 /// function returns immediately.
327 ///
328 /// # Example
329 ///
330 /// ```
331 /// use shared_lock::Lock;
332 ///
333 /// let lock = Lock::default();
334 /// let _guard = lock.lock();
335 /// ```
336 #[inline]
337 pub fn lock(&self) -> Guard<'_> {
338 maybe_lock_fast!(self, guard, guard);
339 self.lock_slow()
340 }
341
342 #[cold]
343 #[inline(always)]
344 fn lock_slow(&self) -> Guard<'_> {
345 self.shared.raw_mutex.lock();
346 // SAFETY: - We've just locked the mutex.
347 unsafe { self.add_ticket_after_lock() }
348 }
349
350 /// # Safety
351 ///
352 /// - The current execution unit must just have succeeded in locking the mutex.
353 #[inline]
354 unsafe fn add_ticket_after_lock(&self) -> Guard<'_> {
355 let shared = &*self.shared;
356 // SAFETY: - By the requirements of this function, we've just locked the mutex.
357 // - Therefore
358 // - we are allowed to mutate this field.
359 // - setting the execution_unit_id to the ID of the current execution
360 // unit upholds the invariant.
361 shared.execution_unit_id.store(execution_unit_id(), Relaxed);
362 // SAFETY: - We have just set execution_unit_id to the ID of the current execution
363 // unit.
364 // - Since we're holding the lock, no other thread is allowed to modify
365 // the field.
366 unsafe { self.add_ticket() }
367 }
368
369 /// # Safety
370 ///
371 /// - execution_unit_id must be the ID of the current execution unit.
372 #[inline]
373 unsafe fn add_ticket(&self) -> Guard<'_> {
374 let shared = &*self.shared;
375 // SAFETY: - By the requirements of this function, execution_unit_id is the ID of
376 // the current execution unit.
377 // - By the invariants, this means that the current execution unit it
378 // holding the mutex.
379 let guards = shared.tickets.get();
380 if guards == u64::MAX {
381 #[cold]
382 fn never() -> ! {
383 #[allow(clippy::empty_loop)]
384 loop {}
385 }
386 never();
387 }
388 // SAFETY: - Dito regarding the ability to access tickets.
389 // - Since the current execution unit is holding the mutex, setting
390 // tickets to guards + 1 > 0 upholds the invariant.
391 shared.tickets.set(guards + 1);
392 // SAFETY: - We've just added a ticket and we're assigning ownership of that
393 // ticket to the new Guard.
394 // - Therefore the requirements that the invariant is upheld is unaffected
395 // by this call.
396 unsafe { self.make_guard_unchecked_() }
397 }
398
399 /// Creates a new [`Guard`] without checking if the lock is held.
400 ///
401 /// # Safety
402 ///
403 /// - The invariant that each [`Guard`] owns a ticket must be upheld whenever a
404 /// [`Guard`] is used or dropped.
405 ///
406 /// # Example
407 ///
408 /// ```
409 /// use std::mem;
410 /// use shared_lock::Lock;
411 ///
412 /// let lock = Lock::default();
413 /// mem::forget(lock.lock());
414 /// // SAFETY: This recovers the guard we just forgot.
415 /// let _guard = unsafe {
416 /// lock.make_guard_unchecked()
417 /// };
418 /// ```
419 #[inline]
420 pub unsafe fn make_guard_unchecked(&self) -> Guard<'_> {
421 // SAFETY: The requirement is forwarded to the caller.
422 unsafe { self.make_guard_unchecked_() }
423 }
424
425 /// # Safety
426 ///
427 /// - The invariant that each [`Guard`] owns a ticket must be upheld whenever a
428 /// [`Guard`] is used or dropped.
429 #[inline]
430 unsafe fn make_guard_unchecked_(&self) -> Guard<'_> {
431 Guard {
432 lock: self,
433 _phantom_not_send: Default::default(),
434 _phantom_not_sync: Default::default(),
435 }
436 }
437
438 /// Attempts to acquire this lock.
439 ///
440 /// If the lock cannot be acquired at this time, `None` is returned. Otherwise a guard
441 /// is returned and the lock will be unlocked when the guard is dropped.
442 ///
443 /// This function does not block.
444 ///
445 /// # Example
446 ///
447 /// ```
448 /// use std::thread;
449 /// use std::time::{Duration, Instant};
450 /// use shared_lock::Lock;
451 ///
452 /// let timeout = Duration::from_millis(200);
453 /// let lock = Lock::default();
454 /// let _guard = lock.lock();
455 /// // The same thread can lock the lock again.
456 /// assert!(lock.try_lock().is_some());
457 ///
458 /// thread::scope(|scope| {
459 /// let join_handle = scope.spawn(|| {
460 /// // Another thread cannot lock the lock.
461 /// assert!(lock.try_lock().is_none());
462 /// });
463 /// join_handle.join().unwrap();
464 /// });
465 /// ```
466 #[inline]
467 pub fn try_lock(&self) -> Option<Guard<'_>> {
468 maybe_lock_fast!(self, guard, Some(guard));
469 self.try_lock_slow()
470 }
471
472 #[cold]
473 #[inline]
474 fn try_lock_slow(&self) -> Option<Guard<'_>> {
475 self.shared.raw_mutex.try_lock().then(|| {
476 // SAFETY: - We've just locked the mutex.
477 unsafe { self.add_ticket_after_lock() }
478 })
479 }
480
481 /// Attempts to acquire this lock until a timeout has expired.
482 ///
483 /// If the lock cannot be acquired before the timeout expires, `None` is returned.
484 /// Otherwise a guard is returned and the lock will be unlocked when the guard is
485 /// dropped.
486 ///
487 /// # Example
488 ///
489 /// ```
490 /// use std::thread;
491 /// use std::time::{Duration, Instant};
492 /// use shared_lock::Lock;
493 ///
494 /// let timeout = Duration::from_millis(200);
495 /// let lock = Lock::default();
496 /// let _guard = lock.lock();
497 ///
498 /// thread::scope(|scope| {
499 /// let join_handle = scope.spawn(|| {
500 /// let guard = lock.try_lock_for(timeout);
501 /// assert!(guard.is_none());
502 /// });
503 /// join_handle.join().unwrap();
504 /// });
505 /// ```
506 #[inline]
507 pub fn try_lock_for(&self, duration: Duration) -> Option<Guard<'_>> {
508 maybe_lock_fast!(self, guard, Some(guard));
509 self.try_lock_for_slow(duration)
510 }
511
512 #[cold]
513 #[inline]
514 fn try_lock_for_slow(&self, duration: Duration) -> Option<Guard<'_>> {
515 self.shared.raw_mutex.try_lock_for(duration).then(|| {
516 // SAFETY: - We've just locked the mutex.
517 unsafe { self.add_ticket_after_lock() }
518 })
519 }
520
521 /// Attempts to acquire this lock until a timeout is reached.
522 ///
523 /// If the lock cannot be acquired before the timeout expires, `None` is returned.
524 /// Otherwise a guard is returned and the lock will be unlocked when the guard is
525 /// dropped.
526 ///
527 /// # Example
528 ///
529 /// ```
530 /// use std::thread;
531 /// use std::time::{Duration, Instant};
532 /// use shared_lock::Lock;
533 ///
534 /// let timeout = Instant::now() + Duration::from_millis(200);
535 /// let lock = Lock::default();
536 /// let _guard = lock.lock();
537 ///
538 /// thread::scope(|scope| {
539 /// let join_handle = scope.spawn(|| {
540 /// let guard = lock.try_lock_until(timeout);
541 /// assert!(guard.is_none());
542 /// });
543 /// join_handle.join().unwrap();
544 /// });
545 /// ```
546 #[inline]
547 pub fn try_lock_until(&self, instant: Instant) -> Option<Guard<'_>> {
548 maybe_lock_fast!(self, guard, Some(guard));
549 self.try_lock_until_slow(instant)
550 }
551
552 #[cold]
553 #[inline]
554 fn try_lock_until_slow(&self, instant: Instant) -> Option<Guard<'_>> {
555 self.shared.raw_mutex.try_lock_until(instant).then(|| {
556 // SAFETY: - We've just locked the mutex.
557 unsafe { self.add_ticket_after_lock() }
558 })
559 }
560
561 #[inline]
562 pub(crate) fn addr(&self) -> *const u8 {
563 let addr: *const Shared = &*self.shared;
564 addr.cast()
565 }
566}
567
568impl Debug for Lock {
569 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
570 f.debug_struct("Lock")
571 .field("id", &self.addr())
572 .finish_non_exhaustive()
573 }
574}
575
576impl PartialEq for Lock {
577 #[inline]
578 fn eq(&self, other: &Self) -> bool {
579 ptr::eq::<Shared>(&*self.shared, &*other.shared)
580 }
581}
582
583impl Eq for Lock {}
584
585impl Guard<'_> {
586 /// Unlocks this guard.
587 ///
588 /// If this causes the number of tickets to drop to 0, the underlying mutex is
589 /// unlocked using a fair protocol.
590 ///
591 /// # Example
592 ///
593 /// ```
594 /// use shared_lock::Lock;
595 ///
596 /// let lock = Lock::default();
597 /// let guard = lock.lock();
598 /// guard.unlock_fair();
599 /// ```
600 #[inline]
601 pub fn unlock_fair(self) {
602 let slf = ManuallyDrop::new(self);
603 // SAFETY: - By the invariants, this guard owns a ticket which means that
604 // the execution_unit_id is the ID of the current execution unit.
605 // - Since we've wrapped self in ManuallyDrop, we know that it won't be
606 // used after this.
607 unsafe {
608 slf.lock.force_unlock_fair();
609 }
610 }
611
612 /// Unlocks this guard, runs a function, and then re-acquires the guard.
613 ///
614 /// If another guard exists, then other threads will not be able to acquire the lock
615 /// even while the function is running.
616 ///
617 /// # Example
618 ///
619 /// ```
620 /// use shared_lock::Lock;
621 ///
622 /// let lock = Lock::default();
623 /// let locked = lock.wrap(1);
624 /// let mut guard = lock.lock();
625 /// assert_eq!(*locked.get(&guard), 1);
626 /// guard.unlocked(|| {
627 /// assert!(!lock.is_locked());
628 /// });
629 /// assert_eq!(*locked.get(&guard), 1);
630 /// ```
631 #[inline]
632 pub fn unlocked<T>(&mut self, f: impl FnOnce() -> T) -> T {
633 self.unlocked_::<_, false>(f)
634 }
635
636 /// Unlocks this guard fairly, runs a function, and then re-acquires the guard.
637 ///
638 /// If another guard exists, then other threads will not be able to acquire the lock
639 /// even while the function is running.
640 ///
641 /// # Example
642 ///
643 /// ```
644 /// use shared_lock::Lock;
645 ///
646 /// let lock = Lock::default();
647 /// let locked = lock.wrap(1);
648 /// let mut guard = lock.lock();
649 /// assert_eq!(*locked.get(&guard), 1);
650 /// guard.unlocked_fair(|| {
651 /// assert!(!lock.is_locked());
652 /// });
653 /// assert_eq!(*locked.get(&guard), 1);
654 /// ```
655 #[inline]
656 pub fn unlocked_fair<T>(&mut self, f: impl FnOnce() -> T) -> T {
657 self.unlocked_::<_, true>(f)
658 }
659
660 #[inline]
661 fn unlocked_<T, const FAIR: bool>(&mut self, f: impl FnOnce() -> T) -> T {
662 // SAFETY: - Since we have have a mutable reference, this guard cannot current be
663 // used to access any locked data (since that borrows the guard).
664 // - Any two guards of the same lock are interchangeable.
665 // - This unlock operation morally consumes the guard. On drop, we restore
666 // it by acquiring a new guard and then forgetting it.
667 unsafe {
668 self.lock.force_unlock_::<FAIR>();
669 }
670 let _lock = on_drop(|| {
671 let guard = self.lock.lock();
672 mem::forget(guard);
673 });
674 f()
675 }
676}
677
678impl Drop for Guard<'_> {
679 #[inline]
680 fn drop(&mut self) {
681 // SAFETY: - By the invariants, this guard owns a ticket which means that
682 // the execution_unit_id is the ID of the current execution unit.
683 unsafe {
684 self.lock.force_unlock_::<false>();
685 }
686 }
687}
688
689impl Debug for Guard<'_> {
690 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
691 f.debug_struct("Guard")
692 .field("lock_id", &self.lock.addr())
693 .finish_non_exhaustive()
694 }
695}