priority_async_mutex/
lib.rs

1use std::ops::{Deref, DerefMut};
2
3use event_listener::PriorityEventListener;
4use simple_mutex::{Mutex, MutexGuard};
5
6mod event_listener;
7mod pv;
8
9/// An async mutex where the lock operation takes a priority.
10pub struct PriorityMutex<T> {
11    inner: Mutex<T>,
12    listen: PriorityEventListener,
13}
14
15impl<T> PriorityMutex<T> {
16    /// Creates a new priority mutex.
17    pub fn new(t: T) -> Self {
18        Self {
19            inner: Mutex::new(t),
20            listen: PriorityEventListener::new(),
21        }
22    }
23
24    /// Locks the mutex. When the mutex becomes available, lower priorities are woken up first.
25    pub async fn lock(&self, priority: u32) -> PriorityMutexGuard<'_, T> {
26        let guard = loop {
27            if let Some(val) = self.inner.try_lock() {
28                break val;
29            } else {
30                let listener = self.listen.listen(priority);
31                if let Some(val) = self.inner.try_lock() {
32                    break val;
33                }
34                listener.wait().await;
35            }
36        };
37        PriorityMutexGuard {
38            inner: guard,
39            parent: self,
40        }
41    }
42}
43
44pub struct PriorityMutexGuard<'a, T> {
45    inner: MutexGuard<'a, T>,
46    parent: &'a PriorityMutex<T>,
47}
48
49impl<'a, T> Drop for PriorityMutexGuard<'a, T> {
50    fn drop(&mut self) {
51        self.parent.listen.notify_one();
52    }
53}
54
55impl<'a, T> Deref for PriorityMutexGuard<'a, T> {
56    type Target = T;
57
58    fn deref(&self) -> &Self::Target {
59        self.inner.deref()
60    }
61}
62
63impl<'a, T> DerefMut for PriorityMutexGuard<'a, T> {
64    fn deref_mut(&mut self) -> &mut Self::Target {
65        self.inner.deref_mut()
66    }
67}
68
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use crate::PriorityMutex;

    /// Spawns many tasks that each take the lock with a random priority and
    /// increment the shared counter, then checks that no increment was lost.
    #[test]
    fn simple() {
        const TASKS: i32 = 1000;

        let item = Arc::new(PriorityMutex::new(0));
        // Keep the task handles so the test can wait for completion instead of
        // sleeping for a fixed time and hoping everything has finished.
        let tasks: Vec<_> = (0..TASKS)
            .map(|_| {
                let priority = fastrand::u32(0..1000);
                let item = Arc::clone(&item);
                smol::spawn(async move {
                    let mut g = item.lock(priority).await;
                    *g += 1;
                    // Hold the lock across an await point so other tasks queue up.
                    smol::Timer::after(Duration::from_millis(1)).await;
                    eprintln!("incrementing to {} with {priority}", *g);
                })
            })
            .collect();

        let total = smol::block_on(async {
            for task in tasks {
                task.await;
            }
            let guard = item.lock(0).await;
            *guard
        });
        assert_eq!(total, TASKS);
    }
}