use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use std::task::{Context, Poll, Waker};
9
/// A reusable async barrier: each call to [`MkAsyncBarrier::wait`] completes
/// only once `n` calls (including itself) have arrived in the same round, after
/// which the barrier resets for the next round.
///
/// Clones share one underlying barrier.
#[derive(Clone)]
pub struct MkAsyncBarrier {
    inner: Arc<BarrierInner>,
}

/// State shared by every clone of an [`MkAsyncBarrier`].
struct BarrierInner {
    /// Number of arrivals that releases a round.
    target: usize,
    /// Round bookkeeping. A single lock (rather than separate atomics) keeps
    /// the arrival count, the round number and the parked wakers consistent.
    state: Mutex<BarrierState>,
}

/// Mutable barrier bookkeeping, only accessed through `BarrierInner::lock_state`.
struct BarrierState {
    /// Arrivals so far in the current round.
    arrived: usize,
    /// Bumped (wrapping) whenever a round is released. A waiter remembers the
    /// value it arrived under, so a new round that starts before the waiter is
    /// re-polled cannot strand it.
    generation: usize,
    /// Wakers of tasks parked in the current round.
    wakers: Vec<Waker>,
}

impl BarrierInner {
    /// Locks the round state. Poisoning is ignored: the state is plain counters
    /// and a waker list, which remain usable even if a holder panicked.
    fn lock_state(&self) -> MutexGuard<'_, BarrierState> {
        self.state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Records one arrival.
    ///
    /// Returns `None` if this arrival completed the round (every parked waiter
    /// has then been woken), or `Some(generation)` naming the round the caller
    /// must wait to see released.
    fn arrive(&self) -> Option<usize> {
        let mut state = self.lock_state();
        state.arrived += 1;
        // `<` rather than `!=`: with a target of 0 every arrival releases at once.
        if state.arrived < self.target {
            return Some(state.generation);
        }
        state.arrived = 0;
        state.generation = state.generation.wrapping_add(1);
        let wakers = std::mem::take(&mut state.wakers);
        // Release the lock before waking, so a waker that polls synchronously
        // cannot deadlock on this non-reentrant mutex.
        drop(state);
        wakers.into_iter().for_each(Waker::wake);
        None
    }
}

impl MkAsyncBarrier {
    /// Creates a barrier that releases each round once `n` tasks have called
    /// [`wait`](Self::wait). With `n` of 0 or 1, `wait` returns immediately.
    pub fn new(n: usize) -> Self {
        Self {
            inner: Arc::new(BarrierInner {
                target: n,
                state: Mutex::new(BarrierState {
                    arrived: 0,
                    generation: 0,
                    wakers: Vec::new(),
                }),
            }),
        }
    }

    /// Waits until `n` tasks (including this one) have reached the barrier.
    ///
    /// The task whose arrival completes the round returns without suspending
    /// and wakes every task parked in that round. Other tasks sleep (no busy
    /// polling) until woken. Dropping the future after its first poll does not
    /// undo the arrival.
    pub async fn wait(&self) {
        if let Some(generation) = self.inner.arrive() {
            WaitBarrier::new(&self.inner, generation).await;
        }
    }

    /// Returns how many tasks have arrived in the current, not-yet-released
    /// round. This is a snapshot and may be stale by the time it is used.
    pub fn waiting(&self) -> usize {
        self.inner.lock_state().arrived
    }
}

/// Future that resolves once the barrier has released round `generation`.
struct WaitBarrier<'a> {
    inner: &'a BarrierInner,
    generation: usize,
}

impl<'a> WaitBarrier<'a> {
    fn new(inner: &'a BarrierInner, generation: usize) -> Self {
        Self { inner, generation }
    }
}

impl Future for WaitBarrier<'_> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = self.inner;
        let generation = self.generation;
        let mut state = inner.lock_state();
        if state.generation != generation {
            return Poll::Ready(());
        }
        // Park until `arrive` releases the round. Registering each distinct
        // waker once keeps repeated polls from growing the list.
        if !state.wakers.iter().any(|w| w.will_wake(cx.waker())) {
            state.wakers.push(cx.waker().clone());
        }
        Poll::Pending
    }
}
84
/// An async counting semaphore.
///
/// Permits are handed out as [`SemaphorePermit`] guards and returned to the
/// pool when the guard is dropped. Acquisition is not FIFO: any caller that
/// finds a free permit may take it.
pub struct MkAsyncSemaphore {
    /// Permits currently free.
    permits: AtomicUsize,
    /// Initial permit count. NOTE(review): recorded but not read by the code shown here.
    max_permits: usize,
    /// Wakers of tasks parked in `acquire` because no permit was free.
    waiters: Mutex<Vec<Waker>>,
}

impl MkAsyncSemaphore {
    /// Creates a semaphore with `permits` permits initially available.
    pub fn new(permits: usize) -> Self {
        Self {
            permits: AtomicUsize::new(permits),
            max_permits: permits,
            waiters: Mutex::new(Vec::new()),
        }
    }

    /// Waits until a permit is available and takes it.
    ///
    /// A task that finds no free permit parks its waker and is woken when a
    /// permit is released, instead of repeatedly rescheduling itself.
    pub async fn acquire(&self) -> SemaphorePermit<'_> {
        Acquire { semaphore: self }.await
    }

    /// Takes a permit without waiting, or returns `None` if none is free.
    pub fn try_acquire(&self) -> Option<SemaphorePermit<'_>> {
        self.permits
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |free| free.checked_sub(1))
            .ok()
            .map(|_| SemaphorePermit { semaphore: self })
    }

    /// Returns the number of free permits. This is a snapshot and may be stale.
    pub fn available(&self) -> usize {
        self.permits.load(Ordering::Relaxed)
    }

    /// Locks the waiter list. Poisoning is ignored: a list of wakers has no
    /// invariant a panicking holder could break.
    fn lock_waiters(&self) -> MutexGuard<'_, Vec<Waker>> {
        self.waiters
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Returns one permit to the pool and wakes every parked acquirer.
    ///
    /// All waiters are woken, not just one. A single woken waiter could be
    /// cancelled before re-polling, which would leave the permit free while the
    /// remaining waiters sleep forever.
    fn release(&self) {
        self.permits.fetch_add(1, Ordering::Release);
        // The guard is a temporary, so the lock is released before any wake runs.
        let wakers = std::mem::take(&mut *self.lock_waiters());
        wakers.into_iter().for_each(Waker::wake);
    }
}

/// Future behind [`MkAsyncSemaphore::acquire`].
struct Acquire<'a> {
    semaphore: &'a MkAsyncSemaphore,
}

impl<'a> Future for Acquire<'a> {
    type Output = SemaphorePermit<'a>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let semaphore = self.semaphore;
        // Fast path: no locking when a permit is free.
        if let Some(permit) = semaphore.try_acquire() {
            return Poll::Ready(permit);
        }
        let mut waiters = semaphore.lock_waiters();
        // Re-check under the lock. `release` bumps the count before taking this
        // lock, so either we see its permit here or it sees our waker below.
        if let Some(permit) = semaphore.try_acquire() {
            return Poll::Ready(permit);
        }
        if !waiters.iter().any(|w| w.will_wake(cx.waker())) {
            waiters.push(cx.waker().clone());
        }
        Poll::Pending
    }
}

/// Guard for one permit of an [`MkAsyncSemaphore`]; dropping it returns the permit.
pub struct SemaphorePermit<'a> {
    semaphore: &'a MkAsyncSemaphore,
}

impl Drop for SemaphorePermit<'_> {
    fn drop(&mut self) {
        self.semaphore.release();
    }
}
156
/// Future that reports `Pending` on its first poll (after asking to be polled
/// again) and `Ready` on every poll after that. The field records whether the
/// yield has already happened.
struct YieldOnce(bool);

impl YieldOnce {
    fn new() -> Self {
        YieldOnce(false)
    }
}

impl Future for YieldOnce {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Flip the flag unconditionally; its previous value decides the outcome.
        let already_yielded = std::mem::replace(&mut self.get_mut().0, true);
        if already_yielded {
            Poll::Ready(())
        } else {
            // Schedule an immediate re-poll so the caller resumes promptly.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}
179
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    /// Waker that ignores wake-ups; these tests drive futures by polling manually.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    #[test]
    fn test_semaphore_sync() {
        let sem = MkAsyncSemaphore::new(3);
        assert_eq!(sem.available(), 3);

        let _p1 = sem.try_acquire().unwrap();
        assert_eq!(sem.available(), 2);

        let _p2 = sem.try_acquire().unwrap();
        let _p3 = sem.try_acquire().unwrap();
        assert_eq!(sem.available(), 0);

        assert!(sem.try_acquire().is_none());
    }

    /// Dropping a permit must hand it back to the pool.
    #[test]
    fn test_semaphore_release() {
        let sem = MkAsyncSemaphore::new(1);
        {
            let _permit = sem.try_acquire().expect("a permit is free");
            assert_eq!(sem.available(), 0);
            assert!(sem.try_acquire().is_none());
        }
        assert_eq!(sem.available(), 1);
        // The temporary permit is dropped at the end of the statement.
        assert!(sem.try_acquire().is_some());
        assert_eq!(sem.available(), 1);
        assert!(MkAsyncSemaphore::new(0).try_acquire().is_none());
    }

    /// An `acquire` blocked on an exhausted semaphore completes once a permit is released.
    #[test]
    fn test_semaphore_acquire_after_release() {
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);

        let sem = MkAsyncSemaphore::new(1);
        let held = sem.try_acquire().unwrap();
        let mut waiting = Box::pin(sem.acquire());
        assert!(waiting.as_mut().poll(&mut cx).is_pending());

        drop(held);
        assert!(matches!(waiting.as_mut().poll(&mut cx), Poll::Ready(_)));
        // The acquired permit was dropped inside `matches!`, so it is free again.
        assert_eq!(sem.available(), 1);
    }
}