// rart_rs/futures/semaphore.rs

1use heapless::Deque;
2use core::sync::atomic::AtomicUsize;
3use core::task::Waker;
4use core::future::Future;
5use core::pin::Pin;
6use core::sync::atomic::Ordering;
7use core::task::{Context, Poll};
8use crate::common::arc::Arc;
9use crate::common::ArcMutex;
10use crate::common::blocking_mutex::BlockingMutex;
11use crate::common::result::{Expect, RARTError};
12
13pub struct SemaphoreUnbounded<const TN: usize> {
14    unlocked_count: AtomicUsize,
15    wakers: ArcMutex<Deque<Waker, TN>>,
16}
17
18struct TakerUnbounded<const TN: usize> {
19    semaphore: &'static SemaphoreUnbounded<TN>,
20}
21
22impl<const TN: usize> SemaphoreUnbounded<TN> {
23    pub fn new(initial_count: usize) -> Self {
24        Self {
25            unlocked_count: AtomicUsize::new(initial_count),
26            wakers: Arc::new(BlockingMutex::new(Deque::new())),
27        }
28    }
29
30    pub async fn take(&'static self) {
31        TakerUnbounded { semaphore: &self }.await
32    }
33
34    pub fn give(&self) -> Result<(), RARTError> {
35        self.unlocked_count.fetch_add(1, Ordering::AcqRel);
36        let mut wakers = self.wakers.lock()?;
37        if let Some(waker) = wakers.pop_front() {
38            waker.wake_by_ref();
39        }
40
41        Ok(())
42    }
43}
44
45impl<const TN: usize> Future for TakerUnbounded<TN> {
46    type Output = ();
47
48    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
49        if self.semaphore.unlocked_count.load(Ordering::Acquire) > 0 {
50            self.semaphore.unlocked_count.fetch_sub(1, Ordering::AcqRel);
51            Poll::Ready(())
52        } else {
53            let mut wakers = self.semaphore.wakers.lock().rart_expect("Cannot lock wakers at take poll");
54            wakers.push_back(cx.waker().clone()).rart_expect("Cannot store taker unbounded waker");
55            Poll::Pending
56        }
57    }
58}
59
60pub struct Semaphore<const N: usize, const TN: usize> {
61    unlocked_count: AtomicUsize,
62    max_locks: usize,
63    take_wakers: ArcMutex<Deque<Waker, TN>>,
64    give_wakers: ArcMutex<Deque<Waker, TN>>,
65}
66
67struct Taker<const N: usize, const TN: usize> {
68    semaphore: &'static Semaphore<N, TN>,
69}
70
71struct Giver<const N: usize, const TN: usize> {
72    semaphore: &'static Semaphore<N, TN>,
73}
74
75impl<const N: usize, const TN: usize> Semaphore<N, TN> {
76    pub fn new(initial_count: usize) -> Self {
77        Self {
78            unlocked_count: AtomicUsize::new(initial_count),
79            max_locks: N,
80            take_wakers: Arc::new(BlockingMutex::new(Deque::new())),
81            give_wakers: Arc::new(BlockingMutex::new(Deque::new())),
82        }
83    }
84
85    pub async fn take(&'static self) {
86        Taker { semaphore: &self }.await
87    }
88
89    pub fn give(&self) -> Result<(), RARTError> {
90        if self.unlocked_count.load(Ordering::Acquire) < self.max_locks {
91            self.unlocked_count.fetch_add(1, Ordering::AcqRel);
92            let mut wakers = self.take_wakers.lock()?;
93            if let Some(waker) = wakers.pop_front() {
94                waker.wake_by_ref();
95            }
96        }
97
98        Ok(())
99    }
100
101    pub async fn wait_give(&'static self) {
102        Giver { semaphore: &self }.await
103    }
104}
105
106impl<const N: usize, const TN: usize> Future for Taker<N, TN> {
107    type Output = ();
108
109    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
110        if self.semaphore.unlocked_count.load(Ordering::Acquire) > 0 {
111            if self.semaphore.unlocked_count.load(Ordering::Acquire) == self.semaphore.max_locks {
112                let mut give_wakers = self.semaphore.give_wakers.lock().rart_expect("");
113                if let Some(waker) = give_wakers.pop_front() {
114                    waker.wake_by_ref();
115                }
116            }
117            self.semaphore.unlocked_count.fetch_sub(1, Ordering::AcqRel);
118            Poll::Ready(())
119        } else {
120            let mut take_wakers = self.semaphore.take_wakers.lock().rart_expect("Cannot lock wakers at take poll");
121            take_wakers.push_back(cx.waker().clone()).rart_expect("Cannot store taker waker");
122            Poll::Pending
123        }
124    }
125}
126
127impl<const N: usize, const TN: usize> Future for Giver<N, TN> {
128    type Output = ();
129
130    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
131        if self.semaphore.unlocked_count.load(Ordering::Acquire) < self.semaphore.max_locks {
132            if self.semaphore.unlocked_count.load(Ordering::Acquire) == 0 {
133                let mut take_wakers = self.semaphore.take_wakers.lock().rart_expect("");
134                if let Some(waker) = take_wakers.pop_front() {
135                    waker.wake_by_ref();
136                }
137            }
138            self.semaphore.unlocked_count.fetch_add(1, Ordering::AcqRel);
139            Poll::Ready(())
140        } else {
141            let mut give_wakers = self.semaphore.give_wakers.lock().rart_expect("Cannot lock wakers at give poll");
142            give_wakers.push_back(cx.waker().clone()).rart_expect("Cannot store giver waker");
143            Poll::Pending
144        }
145    }
146}