rart_rs/futures/mutex.rs

1use heapless::Deque;
2use core::future::Future;
3use core::pin::Pin;
4use core::task::{Context, Poll, Waker};
5use crate::common::arc::Arc;
6use core::sync::atomic::{AtomicBool, Ordering};
7use core::cell::UnsafeCell;
8use core::ops::{Deref, DerefMut};
9use crate::common::ArcMutex;
10use crate::common::blocking_mutex::BlockingMutex;
11use crate::common::result::Expect;
12use crate::common::waker::waker_into_rart_waker;
13
14pub struct Mutex<T, const TN: usize> {
15    data: UnsafeCell<T>,
16    is_unlocked: AtomicBool,
17    wakers: ArcMutex<Deque<Waker, TN>>,
18}
19
20impl<T, const TN: usize> Mutex<T, TN> {
21    pub fn new(data: T) -> Self {
22        Self {
23            data: UnsafeCell::new(data),
24            is_unlocked: AtomicBool::new(true),
25            wakers: Arc::new(BlockingMutex::new(Deque::new())),
26        }
27    }
28
29    pub async fn lock(&'static self) -> MutexGuard<T, TN> {
30        MutexLocker { mutex: &self }.await
31    }
32
33    fn unlock(&self) {
34        self.is_unlocked.store(true, Ordering::Release)
35    }
36}
37
38pub struct MutexLocker<T: 'static, const TN: usize> {
39    mutex: &'static Mutex<T, TN>,
40}
41
42impl<T: 'static, const TN: usize> Future for MutexLocker<T, TN> {
43    type Output = MutexGuard<T, TN>;
44
45    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
46        let mut wakers = self.mutex.wakers.lock()
47            .rart_expect("Cannot lock mutex wakers");
48
49        if wakers.len() > 0 {
50            let next_waker = wakers.pop_front().rart_expect("Cannot pop a waker");
51            let next_waker_id = waker_into_rart_waker(next_waker.clone()).id();
52
53            let task_id = waker_into_rart_waker(cx.waker().clone()).id();
54
55            if next_waker_id == task_id {
56                Poll::Ready(MutexGuard { mutex: self.mutex })
57            } else {
58                wakers.push_front(next_waker.clone()).rart_expect("Cannot restore next_waker");
59                wakers.push_back(cx.waker().clone()).rart_expect("Cannot store mutex waker");
60                Poll::Pending
61            }
62        } else if self.mutex.is_unlocked.compare_exchange(true, false,
63                                                          Ordering::AcqRel,
64                                                          Ordering::Relaxed).is_ok() {
65            Poll::Ready(MutexGuard { mutex: self.mutex })
66        } else {
67            wakers.push_back(cx.waker().clone()).rart_expect("Cannot store mutex waker");
68            Poll::Pending
69        }
70    }
71}
72
73pub struct MutexGuard<T: 'static, const TN: usize> {
74    mutex: &'static Mutex<T, TN>,
75}
76
77impl<T: 'static, const TN: usize> MutexGuard<T, TN> {
78    pub fn unlock(self) {
79        drop(self)
80    }
81}
82
83impl<T: 'static, const TN: usize> Drop for MutexGuard<T, TN> {
84    fn drop(&mut self) {
85        let mutex = &*self.mutex;
86        mutex.unlock();
87
88        let mut wakers = mutex.wakers.lock()
89            .rart_expect("Cannot lock mutex wakers");
90        if let Some(next_waker) = wakers.pop_front() {
91            wakers.push_front(next_waker.clone()).rart_expect("Cannot push front the next_waker");
92            next_waker.wake_by_ref();
93        }
94    }
95}
96
97impl<T: 'static, const TN: usize> Deref for MutexGuard<T, TN> {
98    type Target = T;
99
100    fn deref(&self) -> &Self::Target {
101        // TODO Explain why this is safe
102        unsafe {
103            let mutex = &*self.mutex;
104            let data_ptr = &*mutex.data.get();
105            data_ptr
106        }
107    }
108}
109
110impl<T: 'static, const TN: usize> DerefMut for MutexGuard<T, TN> {
111    fn deref_mut(&mut self) -> &mut Self::Target {
112        // TODO Explain why this is safe
113        unsafe {
114            let mutex = &*self.mutex;
115            let data_ptr = &mut *mutex.data.get();
116            data_ptr
117        }
118    }
119}