// Source: async_priority_lock/lib.rs

1//! # Async Priority Lock
2//! Primitive for priority-based locking of resources.
3//! Locks are granted in order of priority, with the option for lock holders to allow for eviction
4//! (via `evict` flag).
5//!
6//! Note that this is a "California eviction" - the holder is requested to release the lock, but
7//! it's up to the holder of the guard to listen (via .evicted()) and release the lock.
8//!
9//! For no-std environments, the `no-std` feature can be enabled (although alloc is still needed).
10#![cfg_attr(feature = "no-std", no_std)]
11use core::{
12    cell::UnsafeCell,
13    fmt::{Debug, Display},
14    format_args,
15    mem::MaybeUninit,
16    ops::{Deref, DerefMut},
17    pin::Pin,
18    sync::atomic::{AtomicPtr, Ordering},
19    task::{Context, Poll, Waker},
20};
21
22#[cfg(feature = "no-std")]
23extern crate alloc;
24#[cfg(feature = "no-std")]
25use alloc::boxed::Box;
26#[cfg(not(feature = "no-std"))]
27use std::boxed::Box;
28
// Internal mutex shim: `std::sync::Mutex` normally, `spin::Mutex` under the
// `no-std` feature, behind one tiny API so the rest of the crate is agnostic
// to the feature flag.
#[cfg(not(feature = "no-std"))]
#[derive(Default)]
#[repr(transparent)]
struct Mutex<T>(std::sync::Mutex<T>);

#[cfg(feature = "no-std")]
#[derive(Default)]
#[repr(transparent)]
struct Mutex<T>(spin::Mutex<T>);
38
impl<T> Mutex<T> {
    /// Create a new unlocked mutex (`const` so `PriorityMutex::new` can be `const`).
    const fn new(value: T) -> Self {
        #[cfg(not(feature = "no-std"))]
        return Self(std::sync::Mutex::new(value));

        #[cfg(feature = "no-std")]
        return Self(spin::Mutex::new(value));
    }

    /// Lock and return a guard. On std, poisoning is treated as an unrecoverable
    /// bug (`unwrap`); the spin variant has no poisoning concept.
    #[cfg(not(feature = "no-std"))]
    #[inline(always)]
    fn lock(&self) -> impl Deref<Target = T> + DerefMut {
        self.0.lock().unwrap()
    }

    #[cfg(feature = "no-std")]
    #[inline(always)]
    fn lock(&self) -> impl Deref<Target = T> + DerefMut {
        self.0.lock()
    }
}
60
// The `Waker` storage is over-aligned so that the low bits of its address are
// guaranteed zero; those bits carry the WAITER_FLAG_* tags inside the AtomicPtr.
// ensure at least 3 free bits
#[cfg(feature = "evict")]
#[repr(align(8))]
struct AlignedWaker(Waker);
// ensure at least 2 free bits (only 2 are needed if evict flag isn't set)
#[cfg(not(feature = "evict"))]
#[repr(align(4))]
struct AlignedWaker(Waker);

/// If true, the waiter has exclusive access to the guarded resource
const WAITER_FLAG_HAS_LOCK: usize = 1;
/// This bit is a special bit used to prevent accidentally corrupting the memory of the ptr itself
/// while it's being read.  If set to 1, it is not being read; if set to 0 then it was just read
/// and may still be.
/// We clear this every time we look at the pointer, then set it again if it was set previously.
/// If the data part of this pointer is null we must wait for it to be set back to 1 before
/// modifying the waker itself.
/// This flag can never be false when a waker is present, thus we can take and drop without checking
/// this bit when we have ownership of the waker's storage.
const WAITER_FLAG_CAN_DROP: usize = 2;
/// This flag indicates that a higher priority waiter is queued.  Thus, if possible the current
/// holder should exit early.
/// It is possible for both HAS_LOCK and WANTS_EVICT to be set before the waiter is ever polled.
/// In this case, we still return the lock and it is up to the owner of the guard to release the
/// lock if that is desired.
#[cfg(feature = "evict")]
const WAITER_FLAG_WANTS_EVICT: usize = 4;
// With eviction compiled out the bit is 0 so all flag arithmetic still works.
#[cfg(not(feature = "evict"))]
const WAITER_FLAG_WANTS_EVICT: usize = 0;
/// All tag bits that may live in the low bits of the waker pointer.
const WAITER_FLAG_MASK: usize =
    WAITER_FLAG_HAS_LOCK | WAITER_FLAG_CAN_DROP | WAITER_FLAG_WANTS_EVICT;
/// The address part of the tagged waker pointer.
const WAITER_PTR_MASK: usize = !WAITER_FLAG_MASK;
93
94#[inline(always)]
95fn get_flag(w: *mut AlignedWaker) -> usize {
96    w as usize & WAITER_FLAG_MASK
97}
98
impl<P: Ord> PriorityMutexWaiter<P> {
    /// Take the registered waker (if any) out of its storage and wake it.
    ///
    /// A single RMW clears the pointer bits *and* CAN_DROP (keeping HAS_LOCK /
    /// WANTS_EVICT), which gives this thread exclusive ownership of the `Waker`
    /// value while it is `read()` out; CAN_DROP is restored afterwards so the
    /// storage owner may reclaim the slot.
    ///
    /// NOTE(review): `AtomicPtr::fetch_and` / `fetch_or` taking `usize` are the
    /// unstable strict-provenance atomic-ptr APIs — presumably this crate builds
    /// on nightly; confirm the intended toolchain.
    #[inline]
    fn notify(&self) {
        // keep HAS_LOCK / WANTS_EVICT; clear pointer bits and CAN_DROP in one RMW
        let ptr = self
            .waker
            .fetch_and(WAITER_FLAG_MASK ^ WAITER_FLAG_CAN_DROP, Ordering::AcqRel);

        let waker_ptr = ptr.map_addr(|x| x & WAITER_PTR_MASK);
        // if ptr isn't null, read it first (we own the Waker value: CAN_DROP is clear)
        let maybe_waker = (!waker_ptr.is_null()).then(|| unsafe { waker_ptr.read() });

        // now that we've read the ptr (if it was set) we can restore ownership
        // note that it is impossible to have a state where both the waker ptr and the drop flag
        // are unset - we always set the drop flag when we set the address part of the ptr.
        if ptr as usize & WAITER_FLAG_CAN_DROP != 0 {
            self.waker.fetch_or(WAITER_FLAG_CAN_DROP, Ordering::AcqRel);
        }

        // wake only after ownership of the slot has been handed back
        if let Some(waker) = maybe_waker {
            waker.0.wake();
        }
    }

    /// OR `flag` into the state word and wake the waiter if one is registered.
    #[inline]
    fn add_flag(&self, flag: usize) {
        let recv = self
            .waker
            .fetch_or(flag, Ordering::AcqRel)
            .map_addr(|x| x & WAITER_PTR_MASK);

        // no need to notify if waker is not set - if it's not set, either we already took the
        // waker from another tag change, or it hasn't been polled yet. either way when we set
        // the waker we'll check for the bits, so no races here
        if (recv as usize) & WAITER_PTR_MASK != 0 {
            self.notify();
        }
    }

    /// Grant the lock to this waiter (caller holds the list lock).
    #[inline]
    fn start(&self) {
        self.add_flag(WAITER_FLAG_HAS_LOCK)
    }

    /// Request that this (current) holder release the lock — the best-effort
    /// "California eviction" described in the module docs.
    #[cfg(feature = "evict")]
    #[inline]
    fn evict(&self) {
        self.add_flag(WAITER_FLAG_WANTS_EVICT)
    }

    #[inline]
    /// Clear waiter (must be called by owner of waker location).
    /// Returns the flag bits observed at the moment of clearing.
    fn clear_waker(&self, storage: &mut MaybeUninit<AlignedWaker>) -> usize {
        let ptr = self.waker.fetch_and(WAITER_FLAG_MASK, Ordering::AcqRel);
        let flags = ptr as usize & WAITER_FLAG_MASK;

        // safe to inspect the address before checking CAN_DROP: the ptr simply won't
        // be set unless this storage still holds a live waker we own
        let waker_ptr = ptr as usize & WAITER_PTR_MASK;

        // we took the data, which means that we have exclusive access to modify
        // (we don't need to touch the can drop flag as we have a mut ref to the storage itself, so
        // the storage can't be dropped)
        if waker_ptr != 0 {
            debug_assert!(
                waker_ptr == storage.as_ptr() as usize,
                "if a waker exists, it must be ours {:p} {:p}",
                ptr,
                storage
            );
            unsafe { storage.assume_init_drop() };
            return flags;
        }

        // there isn't a waker anymore, but we may still be reading it. if so, we must wait before
        // we can change the value.
        // NOTE(review): bare spin without `core::hint::spin_loop()` backoff — the
        // window is only a few instructions inside `notify`, but worth confirming.
        if ptr as usize & WAITER_FLAG_CAN_DROP == 0 {
            while self.waker.load(Ordering::Acquire) as usize & WAITER_FLAG_CAN_DROP == 0 {}
        }

        flags
    }

    /// Poll helper: Ready once all bits of `target` are set; otherwise registers
    /// `cx`'s waker into `waker` (whose address is published via the AtomicPtr)
    /// and returns Pending.
    #[inline]
    fn wait_for_flag(
        &self,
        cx: &mut Context<'_>,
        waker: &mut MaybeUninit<AlignedWaker>,
        target: usize,
    ) -> Poll<()> {
        // drop any stale registration, and early-out if the target bits are already set
        if self.clear_waker(waker) & target == target {
            return Poll::Ready(());
        }

        waker.write(AlignedWaker(cx.waker().clone()));
        // Ok, set the notifier (pointer bits are zero after clear_waker, so OR installs it)
        let existing = self.waker.fetch_or(
            waker.as_ptr() as usize | WAITER_FLAG_CAN_DROP,
            Ordering::AcqRel,
        );

        // we set it, but it's possible we updated state in meantime. So we might be ready.
        if get_flag(existing) & target != target {
            // if not ready (likely case) return pending and leave waker in place
            return Poll::Pending;
        }

        // ok, so looks like we *did* reach the target state
        // so now we need to remove the waker (if still there - we could have been notified again)
        // and return
        self.clear_waker(waker);

        Poll::Ready(())
    }
}
212
213struct WaiterFlagFut<'a, P: Ord, const FLAG: usize> {
214    tracker: &'a PriorityMutexWaiter<P>,
215    waker: MaybeUninit<AlignedWaker>,
216}
217
218impl<'a, P: Ord, const FLAG: usize> WaiterFlagFut<'a, P, FLAG> {
219    fn new(tracker: &'a PriorityMutexWaiter<P>) -> Self {
220        Self {
221            tracker,
222            waker: MaybeUninit::uninit(),
223        }
224    }
225}
226
227impl<'a, P: Ord, const FLAG: usize> Drop for WaiterFlagFut<'a, P, FLAG> {
228    #[inline]
229    fn drop(&mut self) {
230        self.tracker.clear_waker(&mut self.waker);
231    }
232}
233
234impl<'a, P: Ord, const FLAG: usize> Future for WaiterFlagFut<'a, P, FLAG> {
235    type Output = ();
236
237    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238        self.tracker
239            .wait_for_flag(cx, &mut self.as_mut().waker, FLAG)
240    }
241}
242
/// Intrusive queue node: one per lock request.  Nodes form a singly linked list
/// headed by `PriorityMutex::head`; the head node is the current lock holder.
struct PriorityMutexWaiter<P: Ord> {
    priority: P,
    // Tagged pointer: low bits are WAITER_FLAG_* tags, the rest is the address of
    // the currently-registered `AlignedWaker` storage (0 when none is registered).
    waker: AtomicPtr<AlignedWaker>,
    // semantics:
    // can only be read or written to when lock on list is held
    next: UnsafeCell<Option<Pin<Box<Self>>>>,
    _must_pin: core::marker::PhantomPinned,
}

// SAFETY: `waker` is atomic, and `next` is only touched while the list mutex is
// held (see field comment).
// NOTE(review): nodes are allocated by one thread and dropped wherever the guard
// is dropped, so a `P: Send` bound arguably belongs here too — confirm.
unsafe impl<P: Ord + Sync> Sync for PriorityMutexWaiter<P> {}
253
impl<P: Ord> PriorityMutexWaiter<P> {
    /// Access the `next` link.
    ///
    /// Unchecked contract: the caller must hold the list lock.  This hands out a
    /// `&mut` from `&self` via the `UnsafeCell`, so overlapping calls alias —
    /// callers keep the borrows short-lived (see `dequeue` / `lock`).
    #[inline]
    fn next(&self) -> &mut Option<Pin<Box<Self>>> {
        unsafe { &mut *self.next.get() }
    }

    /// Allocate a pinned node plus a raw-derived reference to it.
    ///
    /// The `'a` lifetime is unbound: the reference is only valid while the boxed
    /// node is kept alive inside the mutex's queue.  Guards hold the reference
    /// and remove (thus drop) the node in their own `Drop`.
    /// `has_lock` pre-sets HAS_LOCK for a node created as the immediate holder.
    #[inline]
    fn new<'a>(holder: P, has_lock: bool) -> (Pin<Box<Self>>, &'a Self) {
        let pin = Box::pin(Self {
            priority: holder,
            // start with CAN_DROP set: no waker is registered, nothing is being read
            waker: AtomicPtr::new(core::ptr::without_provenance_mut(if has_lock {
                WAITER_FLAG_HAS_LOCK | WAITER_FLAG_CAN_DROP
            } else {
                WAITER_FLAG_CAN_DROP
            })),
            next: UnsafeCell::default(),
            _must_pin: core::marker::PhantomPinned,
        });

        // derive the reference from a raw pointer so it isn't borrow-tied to `pin`
        let ptr = &raw const *pin;
        (pin, unsafe { &*ptr })
    }
}
277
#[derive(Default)]
/// A mutex that distributes access by priority as opposed to just fifo / whoever gets it first.
/// If FIFO isn't set, the current behavior is LIFO — however this may not always be the case.
/// Having FIFO = false means the order of items with the same priority doesn't matter (instead,
/// they will be queued in whichever order is fastest — currently, LIFO).
///
/// If this is non-desirable, the type alias FIFOPriorityMutex can be used, or the FIFO const
/// param can be set manually to true.
///
/// By default, a higher P is higher priority, but this can be reversed via setting the
/// LOWEST_FIRST const arg to true (or by using the LowestFirstPriorityMutex alias type).
pub struct PriorityMutex<P: Ord, T, const FIFO: bool = false, const LOWEST_FIRST: bool = false> {
    // PERF: Could later optimize this via using a pre-allocated block to avoid
    // having to hit the allocator every time we enqueue a waiter (probably index by
    // offset instead of pointer), though it'd likely be a bit less memory efficient.
    // Current thought process on how I'd do this is to have a pre-allocated (and expandable)
    // slice/array of (Pin<Box<Waiter>>, next: usize) where next being usize::MAX indicates an
    // empty entry, with an associated idx for head.  This would probably not reduce the scale
    // of its allocation.
    // For now not doing this as the complexity isn't needed (in the real world, it's probably
    // unlikely to have many waiters for a single lock).
    //
    // Invariant: the head node is the current lock holder; the rest of the list is
    // kept sorted by descending effective priority.
    head: Mutex<Option<Pin<Box<PriorityMutexWaiter<P>>>>>,
    data: UnsafeCell<T>,
}
302
#[cfg(feature = "serde")]
impl<'de, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> serde::Deserialize<'de>
    for PriorityMutex<P, T, FIFO, LOWEST_FIRST>
where
    T: serde::Deserialize<'de>,
{
    /// Deserialize the inner `T` and wrap it in a fresh, unlocked mutex.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        T::deserialize(deserializer).map(Self::new)
    }
}
316
/// `PriorityMutex` that preserves FIFO order among equal-priority waiters.
pub type FIFOPriorityMutex<P, T, const LOWEST_FIRST: bool = false> =
    PriorityMutex<P, T, true, LOWEST_FIRST>;
/// `PriorityMutex` where a *lower* `P` means higher priority.
pub type LowestFirstPriorityMutex<P, T, const FIFO: bool = false> = PriorityMutex<P, T, FIFO, true>;
320
321unsafe impl<P: Ord + Sync, T: Sync, const FIFO: bool, const LOWEST_FIRST: bool> Sync
322    for PriorityMutex<P, T, FIFO, LOWEST_FIRST>
323{
324}
325
/// RAII guard granting exclusive access to the protected value; dropping it
/// releases the lock and hands it to the highest-priority waiter.
pub struct PriorityMutexGuard<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> {
    mutex: &'a PriorityMutex<P, T, FIFO, LOWEST_FIRST>,
    // Our queue node; the list owns its allocation, we unlink it in `Drop`.
    node: &'a PriorityMutexWaiter<P>,
}
330
331impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> Display
332    for PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
333where
334    T: Display,
335{
336    #[inline]
337    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
338        self.deref().fmt(f)
339    }
340}
341
342impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> Debug
343    for PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
344where
345    T: Debug,
346{
347    #[inline]
348    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
349        self.deref().fmt(f)
350    }
351}
352
impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> Deref
    for PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
{
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: guards are only handed to callers once HAS_LOCK resolved (see
        // `lock`/`try_lock`), and only one node holds the lock at a time, so no
        // other reference to the data can exist while this one does.
        unsafe { &*self.mutex.data.get() }
    }
}
363
impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> DerefMut
    for PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
{
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same exclusivity argument as `deref`, plus `&mut self` ensures
        // no outstanding shared borrow through this guard.
        unsafe { &mut *self.mutex.data.get() }
    }
}
372
#[cfg(feature = "evict")]
impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool>
    PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
{
    /// Returns a future which resolves when/if another, higher priority holder attempts to acquire
    /// the lock.
    ///
    /// Note: this is an associated method to avoid collision with `T`.  Invoke via
    /// `PriorityMutexGuard::evicted(&mut self)`.
    ///
    /// Cancel safety: this function is cancel safe
    #[inline]
    pub fn evicted(this: &mut Self) -> impl Future<Output = ()> {
        // `&this.node` is `&&PriorityMutexWaiter`; deref coercion collapses it.
        WaiterFlagFut::<'_, P, WAITER_FLAG_WANTS_EVICT>::new(&this.node)
    }
}
389
impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> Drop
    for PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
{
    #[inline]
    fn drop(&mut self) {
        // Releasing the lock == unlinking our node; when we were the head,
        // `dequeue` also grants the lock to the next waiter.
        self.mutex.dequeue(self.node);
    }
}
398
/// Opaque marker type returned when `try_lock` finds the lock already held.
#[derive(Debug)]
pub struct TryLockError;

impl Display for TryLockError {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("lock is already held")
    }
}

impl core::error::Error for TryLockError {}
411
412impl<P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> Debug
413    for PriorityMutex<P, T, FIFO, LOWEST_FIRST>
414where
415    T: Debug,
416    P: Default,
417{
418    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
419        let mut d = f.debug_tuple("PriorityMutex");
420        match self.try_lock(P::default()) {
421            Ok(data) => d.field(&data.deref()),
422            Err(_) => d.field(&format_args!("<locked>")),
423        };
424
425        d.finish()
426    }
427}
428
impl<P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool>
    PriorityMutex<P, T, FIFO, LOWEST_FIRST>
{
    /// Remove `item` from the waiter list (runs on guard drop and on cancelled
    /// lock attempts).  If `item` is the head — the current lock holder — the
    /// lock is handed to the next waiter via `start()`.
    ///
    /// Identification is by node address; nodes are pinned boxed allocations, so
    /// the address is stable for the node's lifetime.
    fn dequeue(&self, item: *const PriorityMutexWaiter<P>) {
        let mut head = self.head.lock();

        if let Some(mut node) = head.as_ref() {
            if &raw const **node == item {
                return {
                    // by using a return block here we get rust to be a bit smarter and
                    // realize we don't actually have a ref to node by the time we
                    // modify head...
                    *head = node.next().take();

                    // wake the new holder, if any (queue is sorted, so the next
                    // node is the highest-priority remaining waiter)
                    if let Some(new_head) = &*head {
                        new_head.start();
                    }
                };
            }

            // not the head: walk the list and unlink the matching node
            // (dropping its box); a waiting node never had the lock, so no wakeup
            // is needed
            while let Some(next) = node.next() {
                if &raw const **next == item {
                    *node.next() = next.next().take();
                    return;
                }

                node = &*next;
            }
        }
    }

    #[inline(always)]
    /// returns true if lhs is higher priority than rhs
    /// (Equal maps to `!FIFO`: FIFO queues insert *after* equal-priority peers,
    /// non-FIFO queues insert before them, i.e. LIFO)
    fn is_higher_priority(lhs: &P, rhs: &P) -> bool {
        match lhs.cmp(rhs) {
            core::cmp::Ordering::Less => LOWEST_FIRST,
            core::cmp::Ordering::Equal => !FIFO,
            core::cmp::Ordering::Greater => !LOWEST_FIRST,
        }
    }

    /// Create a new mutex
    pub const fn new(data: T) -> Self {
        Self {
            head: Mutex::new(None),
            data: UnsafeCell::new(data),
        }
    }

    /// Try to acquire the lock without blocking or requesting eviction of the current holder.
    /// Priority will be stored in guard; higher priority requesters will try to evict the returned
    /// guard if the `evict` flag is enabled.
    pub fn try_lock(
        &self,
        priority: P,
    ) -> Result<PriorityMutexGuard<'_, P, T, FIFO, LOWEST_FIRST>, TryLockError> {
        let mut queue = self.head.lock();

        // any head node at all means the lock is held
        if queue.is_some() {
            return Err(TryLockError);
        }

        // empty queue: install ourselves as head with HAS_LOCK pre-set
        let (node, rf) = PriorityMutexWaiter::new(priority, true);
        *queue = Some(node);

        Ok(PriorityMutexGuard {
            mutex: self,
            node: rf,
        })
    }

    /// Acquire exclusive access to the locked resource, waiting until after higher priority
    /// requesters acquire and release the lock.
    ///
    /// If the `evict` feature is enabled, this will also notify the current holder to request it
    /// to release the lock if the current holder is lower priority.
    ///
    /// Cancel safety: this function is cancel safe.
    pub async fn lock(&self, priority: P) -> PriorityMutexGuard<'_, P, T, FIFO, LOWEST_FIRST> {
        // rust makes us write this backwards... for some reason the compiler refuses to believe
        // that head is not held over awaits even when we explicitly drop it...
        let guard = {
            let mut head = self.head.lock();

            let mut node = match head.as_ref() {
                Some(x) => x,
                None => {
                    // there's no head node (ie no holder) so we return without doing any waiting
                    let (new_node, new_ref) = PriorityMutexWaiter::new(priority, true);

                    *head = Some(new_node);
                    return PriorityMutexGuard {
                        mutex: self,
                        node: new_ref,
                    };
                }
            };

            #[cfg(feature = "evict")]
            if Self::is_higher_priority(&priority, &node.priority) {
                // if requesting priority is higher than head, request stop
                node.evict();
            }

            let (new_node, new_ref) = PriorityMutexWaiter::new(priority, false);

            // we still need to iterate through children - as the head isn't always the highest
            // priority
            while let Some(next) = node.next() {
                // splice in before the first strictly-lower-priority waiter
                // (order of exec if priority is the same is fifo/lifo per const param)
                if Self::is_higher_priority(&new_ref.priority, &next.priority) {
                    *new_node.next() = node.next().take();
                    break;
                }

                node = &*next;
            }

            *node.next() = Some(new_node);

            // rust gets mad if we inline await here even if we explicitly drop head...
            // so instead we have this contrived return sub-block to make it even more clear that
            // head isn't in scope when we await
            //
            // note: it is important that the guard is created before the subsequent await, as this
            // fn could be cancelled during the await (the guard's Drop dequeues the node)
            PriorityMutexGuard {
                mutex: self,
                node: new_ref,
            }
        };

        // park until a dequeue of the previous head grants us HAS_LOCK
        WaiterFlagFut::<P, WAITER_FLAG_HAS_LOCK>::new(&guard.node).await;
        return guard;
    }
}
564
565impl<P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> From<T>
566    for PriorityMutex<P, T, FIFO, LOWEST_FIRST>
567{
568    #[inline]
569    fn from(value: T) -> Self {
570        Self::new(value)
571    }
572}