1#![cfg_attr(feature = "no-std", no_std)]
11use core::{
12 cell::UnsafeCell,
13 fmt::{Debug, Display},
14 format_args,
15 mem::MaybeUninit,
16 ops::{Deref, DerefMut},
17 pin::Pin,
18 sync::atomic::{AtomicPtr, Ordering},
19 task::{Context, Poll, Waker},
20};
21
22#[cfg(feature = "no-std")]
23extern crate alloc;
24#[cfg(feature = "no-std")]
25use alloc::boxed::Box;
26#[cfg(not(feature = "no-std"))]
27use std::boxed::Box;
28
/// Internal lock protecting the wait queue, backed by `std::sync::Mutex`.
#[cfg(not(feature = "no-std"))]
#[derive(Default)]
#[repr(transparent)]
struct Mutex<T>(std::sync::Mutex<T>);

/// Internal lock protecting the wait queue, backed by `spin::Mutex` for
/// `no-std` builds.
#[cfg(feature = "no-std")]
#[derive(Default)]
#[repr(transparent)]
struct Mutex<T>(spin::Mutex<T>);

impl<T> Mutex<T> {
    /// Creates an unlocked mutex holding `value`. `const` so it can be used in
    /// `PriorityMutex::new`, which is itself `const`.
    const fn new(value: T) -> Self {
        #[cfg(not(feature = "no-std"))]
        return Self(std::sync::Mutex::new(value));

        #[cfg(feature = "no-std")]
        return Self(spin::Mutex::new(value));
    }

    /// Locks the inner `std` mutex.
    ///
    /// Poisoning is ignored: the spin-based `no-std` backend has no poisoning
    /// at all. Recovering the guard keeps both builds behaving the same. It
    /// also stops one panic in user code that ran under this lock (e.g. a
    /// `Drop` of a queued priority value) from making every later lock
    /// attempt panic.
    #[cfg(not(feature = "no-std"))]
    #[inline(always)]
    fn lock(&self) -> impl Deref<Target = T> + DerefMut + '_ {
        self.0
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Locks the inner spin mutex (spins until available).
    #[cfg(feature = "no-std")]
    #[inline(always)]
    fn lock(&self) -> impl Deref<Target = T> + DerefMut + '_ {
        self.0.lock()
    }
}
60
/// A `Waker` over-aligned so that the low bits of any pointer to it are zero.
///
/// Those bits are reused as `WAITER_FLAG_*` bits in the tagged pointer stored
/// in `PriorityMutexWaiter::waker`. The `evict` feature needs three flag bits
/// (alignment 8); otherwise two suffice (alignment 4).
#[cfg_attr(feature = "evict", repr(align(8)))]
#[cfg_attr(not(feature = "evict"), repr(align(4)))]
struct AlignedWaker(Waker);
69
70const WAITER_FLAG_HAS_LOCK: usize = 1;
72const WAITER_FLAG_CAN_DROP: usize = 2;
81#[cfg(feature = "evict")]
87const WAITER_FLAG_WANTS_EVICT: usize = 4;
88#[cfg(not(feature = "evict"))]
89const WAITER_FLAG_WANTS_EVICT: usize = 0;
90const WAITER_FLAG_MASK: usize =
91 WAITER_FLAG_HAS_LOCK | WAITER_FLAG_CAN_DROP | WAITER_FLAG_WANTS_EVICT;
92const WAITER_PTR_MASK: usize = !WAITER_FLAG_MASK;
93
94#[inline(always)]
95fn get_flag(w: *mut AlignedWaker) -> usize {
96 w as usize & WAITER_FLAG_MASK
97}
98
99impl<P: Ord> PriorityMutexWaiter<P> {
100 #[inline]
101 fn notify(&self) {
102 let ptr = self
103 .waker
104 .fetch_and(WAITER_FLAG_MASK ^ WAITER_FLAG_CAN_DROP, Ordering::AcqRel);
105
106 let waker_ptr = ptr.map_addr(|x| x & WAITER_PTR_MASK);
107 let maybe_waker = (!waker_ptr.is_null()).then(|| unsafe { waker_ptr.read() });
109
110 if ptr as usize & WAITER_FLAG_CAN_DROP != 0 {
114 self.waker.fetch_or(WAITER_FLAG_CAN_DROP, Ordering::AcqRel);
115 }
116
117 if let Some(waker) = maybe_waker {
118 waker.0.wake();
119 }
120 }
121
122 #[inline]
123 fn add_flag(&self, flag: usize) {
124 let recv = self
125 .waker
126 .fetch_or(flag, Ordering::AcqRel)
127 .map_addr(|x| x & WAITER_PTR_MASK);
128
129 if (recv as usize) & WAITER_PTR_MASK != 0 {
133 self.notify();
134 }
135 }
136
137 #[inline]
138 fn start(&self) {
139 self.add_flag(WAITER_FLAG_HAS_LOCK)
140 }
141
142 #[cfg(feature = "evict")]
143 #[inline]
144 fn evict(&self) {
145 self.add_flag(WAITER_FLAG_WANTS_EVICT)
146 }
147
148 #[inline]
149 fn clear_waker(&self, storage: &mut MaybeUninit<AlignedWaker>) -> usize {
151 let ptr = self.waker.fetch_and(WAITER_FLAG_MASK, Ordering::AcqRel);
152 let flags = ptr as usize & WAITER_FLAG_MASK;
153
154 let waker_ptr = ptr as usize & WAITER_PTR_MASK;
156
157 if waker_ptr != 0 {
161 debug_assert!(
162 waker_ptr == storage.as_ptr() as usize,
163 "if a waker exists, it must be ours {:p} {:p}",
164 ptr,
165 storage
166 );
167 unsafe { storage.assume_init_drop() };
168 return flags;
169 }
170
171 if ptr as usize & WAITER_FLAG_CAN_DROP == 0 {
174 while self.waker.load(Ordering::Acquire) as usize & WAITER_FLAG_CAN_DROP == 0 {}
175 }
176
177 flags
178 }
179
180 #[inline]
181 fn wait_for_flag(
182 &self,
183 cx: &mut Context<'_>,
184 waker: &mut MaybeUninit<AlignedWaker>,
185 target: usize,
186 ) -> Poll<()> {
187 if self.clear_waker(waker) & target == target {
188 return Poll::Ready(());
189 }
190
191 waker.write(AlignedWaker(cx.waker().clone()));
192 let existing = self.waker.fetch_or(
194 waker.as_ptr() as usize | WAITER_FLAG_CAN_DROP,
195 Ordering::AcqRel,
196 );
197
198 if get_flag(existing) & target != target {
200 return Poll::Pending;
202 }
203
204 self.clear_waker(waker);
208
209 Poll::Ready(())
210 }
211}
212
213struct WaiterFlagFut<'a, P: Ord, const FLAG: usize> {
214 tracker: &'a PriorityMutexWaiter<P>,
215 waker: MaybeUninit<AlignedWaker>,
216}
217
218impl<'a, P: Ord, const FLAG: usize> WaiterFlagFut<'a, P, FLAG> {
219 fn new(tracker: &'a PriorityMutexWaiter<P>) -> Self {
220 Self {
221 tracker,
222 waker: MaybeUninit::uninit(),
223 }
224 }
225}
226
227impl<'a, P: Ord, const FLAG: usize> Drop for WaiterFlagFut<'a, P, FLAG> {
228 #[inline]
229 fn drop(&mut self) {
230 self.tracker.clear_waker(&mut self.waker);
231 }
232}
233
234impl<'a, P: Ord, const FLAG: usize> Future for WaiterFlagFut<'a, P, FLAG> {
235 type Output = ();
236
237 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238 self.tracker
239 .wait_for_flag(cx, &mut self.as_mut().waker, FLAG)
240 }
241}
242
243struct PriorityMutexWaiter<P: Ord> {
244 priority: P,
245 waker: AtomicPtr<AlignedWaker>,
246 next: UnsafeCell<Option<Pin<Box<Self>>>>,
249 _must_pin: core::marker::PhantomPinned,
250}
251
252unsafe impl<P: Ord + Sync> Sync for PriorityMutexWaiter<P> {}
253
254impl<P: Ord> PriorityMutexWaiter<P> {
255 #[inline]
256 fn next(&self) -> &mut Option<Pin<Box<Self>>> {
257 unsafe { &mut *self.next.get() }
258 }
259
260 #[inline]
261 fn new<'a>(holder: P, has_lock: bool) -> (Pin<Box<Self>>, &'a Self) {
262 let pin = Box::pin(Self {
263 priority: holder,
264 waker: AtomicPtr::new(core::ptr::without_provenance_mut(if has_lock {
265 WAITER_FLAG_HAS_LOCK | WAITER_FLAG_CAN_DROP
266 } else {
267 WAITER_FLAG_CAN_DROP
268 })),
269 next: UnsafeCell::default(),
270 _must_pin: core::marker::PhantomPinned,
271 });
272
273 let ptr = &raw const *pin;
274 (pin, unsafe { &*ptr })
275 }
276}
277
278#[derive(Default)]
279pub struct PriorityMutex<P: Ord, T, const FIFO: bool = false, const LOWEST_FIRST: bool = false> {
290 head: Mutex<Option<Pin<Box<PriorityMutexWaiter<P>>>>>,
300 data: UnsafeCell<T>,
301}
302
/// Deserialises the protected value and wraps it in a fresh, unlocked mutex.
#[cfg(feature = "serde")]
impl<'de, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> serde::Deserialize<'de>
    for PriorityMutex<P, T, FIFO, LOWEST_FIRST>
where
    T: serde::Deserialize<'de>,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        T::deserialize(deserializer).map(Self::new)
    }
}
316
317pub type FIFOPriorityMutex<P, T, const LOWEST_FIRST: bool = false> =
318 PriorityMutex<P, T, true, LOWEST_FIRST>;
319pub type LowestFirstPriorityMutex<P, T, const FIFO: bool = false> = PriorityMutex<P, T, FIFO, true>;
320
321unsafe impl<P: Ord + Sync, T: Sync, const FIFO: bool, const LOWEST_FIRST: bool> Sync
322 for PriorityMutex<P, T, FIFO, LOWEST_FIRST>
323{
324}
325
326pub struct PriorityMutexGuard<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> {
327 mutex: &'a PriorityMutex<P, T, FIFO, LOWEST_FIRST>,
328 node: &'a PriorityMutexWaiter<P>,
329}
330
331impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> Display
332 for PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
333where
334 T: Display,
335{
336 #[inline]
337 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
338 self.deref().fmt(f)
339 }
340}
341
342impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> Debug
343 for PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
344where
345 T: Debug,
346{
347 #[inline]
348 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
349 self.deref().fmt(f)
350 }
351}
352
353impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> Deref
354 for PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
355{
356 type Target = T;
357
358 #[inline]
359 fn deref(&self) -> &Self::Target {
360 unsafe { &*self.mutex.data.get() }
361 }
362}
363
364impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> DerefMut
365 for PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
366{
367 #[inline]
368 fn deref_mut(&mut self) -> &mut Self::Target {
369 unsafe { &mut *self.mutex.data.get() }
370 }
371}
372
#[cfg(feature = "evict")]
impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool>
    PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
{
    /// Returns a future that completes once a caller of
    /// [`PriorityMutex::lock`] that outranks this holder has asked for the
    /// lock. Ranking follows the mutex's `FIFO`/`LOWEST_FIRST` parameters.
    ///
    /// The request is sticky: if it has already happened, the future is ready
    /// on its first poll. The guard is not released automatically; dropping
    /// it is up to the caller.
    #[inline]
    pub fn evicted(this: &mut Self) -> impl Future<Output = ()> {
        let node = &this.node;
        WaiterFlagFut::<'_, P, WAITER_FLAG_WANTS_EVICT>::new(node)
    }
}
389
390impl<'a, P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> Drop
391 for PriorityMutexGuard<'a, P, T, FIFO, LOWEST_FIRST>
392{
393 #[inline]
394 fn drop(&mut self) {
395 self.mutex.dequeue(self.node);
396 }
397}
398
/// Error returned by [`PriorityMutex::try_lock`] when the lock is currently
/// held (its wait queue is non-empty).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct TryLockError;

impl Display for TryLockError {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("lock is already held")
    }
}

impl core::error::Error for TryLockError {}
411
412impl<P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> Debug
413 for PriorityMutex<P, T, FIFO, LOWEST_FIRST>
414where
415 T: Debug,
416 P: Default,
417{
418 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
419 let mut d = f.debug_tuple("PriorityMutex");
420 match self.try_lock(P::default()) {
421 Ok(data) => d.field(&data.deref()),
422 Err(_) => d.field(&format_args!("<locked>")),
423 };
424
425 d.finish()
426 }
427}
428
429impl<P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool>
430 PriorityMutex<P, T, FIFO, LOWEST_FIRST>
431{
432 fn dequeue(&self, item: *const PriorityMutexWaiter<P>) {
433 let mut head = self.head.lock();
434
435 if let Some(mut node) = head.as_ref() {
436 if &raw const **node == item {
437 return {
438 *head = node.next().take();
441
442 if let Some(new_head) = &*head {
443 new_head.start();
444 }
445 };
446 }
447
448 while let Some(next) = node.next() {
449 if &raw const **next == item {
450 *node.next() = next.next().take();
451 return;
452 }
453
454 node = &*next;
455 }
456 }
457 }
458
459 #[inline(always)]
460 fn is_higher_priority(lhs: &P, rhs: &P) -> bool {
462 match lhs.cmp(rhs) {
463 core::cmp::Ordering::Less => LOWEST_FIRST,
464 core::cmp::Ordering::Equal => !FIFO,
465 core::cmp::Ordering::Greater => !LOWEST_FIRST,
466 }
467 }
468
469 pub const fn new(data: T) -> Self {
471 Self {
472 head: Mutex::new(None),
473 data: UnsafeCell::new(data),
474 }
475 }
476
477 pub fn try_lock(
481 &self,
482 priority: P,
483 ) -> Result<PriorityMutexGuard<'_, P, T, FIFO, LOWEST_FIRST>, TryLockError> {
484 let mut queue = self.head.lock();
485
486 if queue.is_some() {
487 return Err(TryLockError);
488 }
489
490 let (node, rf) = PriorityMutexWaiter::new(priority, true);
491 *queue = Some(node);
492
493 Ok(PriorityMutexGuard {
494 mutex: self,
495 node: rf,
496 })
497 }
498
499 pub async fn lock(&self, priority: P) -> PriorityMutexGuard<'_, P, T, FIFO, LOWEST_FIRST> {
507 let guard = {
510 let mut head = self.head.lock();
511
512 let mut node = match head.as_ref() {
513 Some(x) => x,
514 None => {
515 let (new_node, new_ref) = PriorityMutexWaiter::new(priority, true);
517
518 *head = Some(new_node);
519 return PriorityMutexGuard {
520 mutex: self,
521 node: new_ref,
522 };
523 }
524 };
525
526 #[cfg(feature = "evict")]
527 if Self::is_higher_priority(&priority, &node.priority) {
528 node.evict();
530 }
531
532 let (new_node, new_ref) = PriorityMutexWaiter::new(priority, false);
533
534 while let Some(next) = node.next() {
537 if Self::is_higher_priority(&new_ref.priority, &next.priority) {
539 *new_node.next() = node.next().take();
540 break;
541 }
542
543 node = &*next;
544 }
545
546 *node.next() = Some(new_node);
547
548 PriorityMutexGuard {
555 mutex: self,
556 node: new_ref,
557 }
558 };
559
560 WaiterFlagFut::<P, WAITER_FLAG_HAS_LOCK>::new(&guard.node).await;
561 return guard;
562 }
563}
564
565impl<P: Ord, T, const FIFO: bool, const LOWEST_FIRST: bool> From<T>
566 for PriorityMutex<P, T, FIFO, LOWEST_FIRST>
567{
568 #[inline]
569 fn from(value: T) -> Self {
570 Self::new(value)
571 }
572}