memkit_async/
sync.rs

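//! Async synchronization primitives: a reusable barrier ([`MkAsyncBarrier`]) and a
//! counting semaphore ([`MkAsyncSemaphore`]) for coordinating async tasks.
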
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use std::task::{Context, Poll, Waker};

/// Async barrier for synchronizing multiple tasks.
///
/// A barrier created with `new(n)` makes each [`wait`](MkAsyncBarrier::wait) call park the
/// calling task until `n` calls (including its own) have arrived in the current round. The
/// call that completes the round returns at once and wakes every parked task; the barrier
/// then resets, so it can be reused for any number of rounds. Clones share one barrier.
#[derive(Clone)]
pub struct MkAsyncBarrier {
    inner: Arc<BarrierInner>,
}

/// State shared by all clones of a barrier.
struct BarrierInner {
    /// Number of arrivals that completes a round.
    target: usize,
    /// Round bookkeeping; every read and write goes through this lock.
    state: Mutex<BarrierState>,
}

/// Per-round bookkeeping, guarded by `BarrierInner::state`.
struct BarrierState {
    /// Arrivals so far in the current round.
    count: usize,
    /// Incremented whenever a round completes. A parked task remembers the value it
    /// arrived under and is released as soon as it changes -- even if the next round has
    /// already started filling up before the task is polled again.
    generation: u64,
    /// Wakers of tasks parked in the current round.
    wakers: Vec<Waker>,
}

impl BarrierInner {
    /// Lock the round state. Poisoning is ignored: the critical sections below only make
    /// simple field updates, so the state remains consistent after a panic elsewhere.
    fn lock(&self) -> MutexGuard<'_, BarrierState> {
        self.state.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}

impl MkAsyncBarrier {
    /// Create a new barrier for `n` tasks.
    ///
    /// With `n <= 1` every `wait` call completes a round by itself and never parks.
    pub fn new(n: usize) -> Self {
        let state = BarrierState {
            count: 0,
            generation: 0,
            wakers: Vec::new(),
        };
        Self {
            inner: Arc::new(BarrierInner {
                target: n,
                state: Mutex::new(state),
            }),
        }
    }

    /// Wait at the barrier.
    ///
    /// Resolves once `n` tasks have called `wait` in the current round. The task whose
    /// arrival completes the round returns immediately and wakes the others; the rest are
    /// parked (not busy-polled) until then.
    ///
    /// Once the returned future has been polled, dropping it does not withdraw the
    /// arrival: it still counts toward the current round.
    pub async fn wait(&self) {
        // Arrival is recorded synchronously, so no lock guard is held across the await.
        if let Some(generation) = self.arrive() {
            WaitBarrier {
                inner: &self.inner,
                generation,
            }
            .await;
        }
    }

    /// Get the number of tasks currently waiting, i.e. arrivals in the current round.
    pub fn waiting(&self) -> usize {
        self.inner.lock().count
    }

    /// Record one arrival.
    ///
    /// Returns `None` if this arrival completed the round (after waking the parked tasks),
    /// otherwise the generation the caller has to wait out.
    fn arrive(&self) -> Option<u64> {
        let parked = {
            let mut state = self.inner.lock();
            state.count += 1;
            if state.count < self.inner.target {
                return Some(state.generation);
            }
            // Round complete: reset and advance the generation in one critical section so
            // a concurrent arrival lands cleanly in the next round.
            state.count = 0;
            state.generation = state.generation.wrapping_add(1);
            std::mem::take(&mut state.wakers)
        };
        // Wake outside the lock in case a waker polls its task inline.
        for waker in parked {
            waker.wake();
        }
        None
    }
}

/// Future for a non-leader arrival; resolves once the round it joined has completed.
struct WaitBarrier<'a> {
    inner: &'a BarrierInner,
    /// Generation that was current when this task arrived.
    generation: u64,
}

impl Future for WaitBarrier<'_> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.inner.lock();
        if state.generation != self.generation {
            return Poll::Ready(());
        }
        // Register under the same lock the leader uses, so the leader either sees this
        // waker or has already advanced the generation (checked just above).
        if !state.wakers.iter().any(|parked| parked.will_wake(cx.waker())) {
            state.wakers.push(cx.waker().clone());
        }
        Poll::Pending
    }
}

/// Async semaphore for controlling concurrent access.
///
/// Holds a fixed pool of permits. [`MkAsyncSemaphore::acquire`] parks the calling task
/// until a permit is free; [`MkAsyncSemaphore::try_acquire`] never waits. A permit goes
/// back to the pool when its [`SemaphorePermit`] is dropped, which also wakes the oldest
/// parked task.
pub struct MkAsyncSemaphore {
    /// Permits currently available.
    permits: AtomicUsize,
    /// Permit count the semaphore was created with.
    max_permits: usize,
    /// Tasks parked in `acquire`, oldest first.
    waiters: Mutex<WaiterQueue>,
}

/// Parked `acquire` calls, each tagged with an id so a future can find its own entry.
struct WaiterQueue {
    next_id: u64,
    entries: VecDeque<(u64, Waker)>,
}

impl MkAsyncSemaphore {
    /// Create a new semaphore with the given number of permits.
    pub fn new(permits: usize) -> Self {
        Self {
            permits: AtomicUsize::new(permits),
            max_permits: permits,
            waiters: Mutex::new(WaiterQueue {
                next_id: 0,
                entries: VecDeque::new(),
            }),
        }
    }

    /// Acquire a permit, waiting until one is available.
    ///
    /// While no permit is free the task is parked (not busy-polled) and is woken when a
    /// permit is released. Dropping the future before it completes is safe: a wakeup it
    /// had already received is handed on to the next parked task.
    pub async fn acquire(&self) -> SemaphorePermit<'_> {
        AcquireFuture {
            semaphore: self,
            waiter_id: None,
        }
        .await
    }

    /// Try to acquire a permit without waiting.
    pub fn try_acquire(&self) -> Option<SemaphorePermit<'_>> {
        self.take_permit()
            .then(|| SemaphorePermit { semaphore: self })
    }

    /// Get the number of available permits (a snapshot; it may change immediately).
    pub fn available(&self) -> usize {
        self.permits.load(Ordering::Relaxed)
    }

    /// Get the total number of permits this semaphore was created with.
    pub fn max_permits(&self) -> usize {
        self.max_permits
    }

    /// Atomically take one permit if any is available.
    fn take_permit(&self) -> bool {
        self.permits
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |n| n.checked_sub(1))
            .is_ok()
    }

    /// Lock the waiter queue, recovering it if a previous holder panicked.
    fn lock_waiters(&self) -> MutexGuard<'_, WaiterQueue> {
        self.waiters
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Wake the oldest parked task, if any. The lock is released before waking.
    fn wake_next(&self) {
        let next = self.lock_waiters().entries.pop_front();
        if let Some((_, waker)) = next {
            waker.wake();
        }
    }
}

/// Future behind [`MkAsyncSemaphore::acquire`].
struct AcquireFuture<'a> {
    semaphore: &'a MkAsyncSemaphore,
    /// Id of this future's waiter-queue entry, set once it has parked.
    waiter_id: Option<u64>,
}

impl<'a> Future for AcquireFuture<'a> {
    type Output = SemaphorePermit<'a>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The future holds only a reference and an id, so it is `Unpin`.
        let this = self.get_mut();
        let semaphore = this.semaphore;

        // Fast path without locking, only when we have no queue entry to clean up.
        if this.waiter_id.is_none() && semaphore.take_permit() {
            return Poll::Ready(SemaphorePermit { semaphore });
        }

        let mut waiters = semaphore.lock_waiters();
        // Re-check under the lock. A release bumps the count *before* locking the queue,
        // so it is either visible here or will find our waker queued below.
        if semaphore.take_permit() {
            if let Some(id) = this.waiter_id.take() {
                // Leave the queue so a later release does not spend its wakeup on us.
                waiters.entries.retain(|(entry_id, _)| *entry_id != id);
            }
            return Poll::Ready(SemaphorePermit { semaphore });
        }

        let queued_at = this
            .waiter_id
            .and_then(|id| waiters.entries.iter().position(|(entry_id, _)| *entry_id == id));
        match queued_at {
            Some(index) => {
                // Still queued: just keep the stored waker current.
                let waker = &mut waiters.entries[index].1;
                if !waker.will_wake(cx.waker()) {
                    *waker = cx.waker().clone();
                }
            }
            None => {
                // First park, or we were woken but another task took the permit: re-queue.
                let id = waiters.next_id;
                waiters.next_id = id.wrapping_add(1);
                waiters.entries.push_back((id, cx.waker().clone()));
                this.waiter_id = Some(id);
            }
        }
        Poll::Pending
    }
}

impl Drop for AcquireFuture<'_> {
    fn drop(&mut self) {
        let id = match self.waiter_id {
            Some(id) => id,
            None => return,
        };
        let mut waiters = self.semaphore.lock_waiters();
        let queued_at = waiters
            .entries
            .iter()
            .position(|(entry_id, _)| *entry_id == id);
        if let Some(index) = queued_at {
            // Never woken: simply withdraw.
            waiters.entries.remove(index);
            return;
        }
        // Our entry was popped by a release, so we were woken for a permit we will now
        // never take. Hand the wakeup on so that permit does not sit idle.
        if self.semaphore.available() == 0 {
            return;
        }
        let next = waiters.entries.pop_front();
        drop(waiters);
        if let Some((_, waker)) = next {
            waker.wake();
        }
    }
}

/// A permit from a semaphore; returned to the semaphore when dropped.
pub struct SemaphorePermit<'a> {
    semaphore: &'a MkAsyncSemaphore,
}

impl Drop for SemaphorePermit<'_> {
    fn drop(&mut self) {
        // Publish the permit before looking for a waiter (see `AcquireFuture::poll`).
        self.semaphore.permits.fetch_add(1, Ordering::AcqRel);
        self.semaphore.wake_next();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};
    use std::thread::{self, Thread};

    /// Waker that unparks the thread driving the future.
    struct ThreadWaker(Thread);

    impl Wake for ThreadWaker {
        fn wake(self: Arc<Self>) {
            self.0.unpark();
        }
    }

    /// Minimal std-only executor: poll `fut` on the current thread, parking between polls
    /// until its waker fires (spurious unparks just cause an extra poll).
    fn block_on<F: Future>(fut: F) -> F::Output {
        let mut fut = Box::pin(fut);
        let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
        let mut cx = Context::from_waker(&waker);
        loop {
            if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
                return output;
            }
            thread::park();
        }
    }

    #[test]
    fn test_semaphore_sync() {
        let sem = MkAsyncSemaphore::new(3);
        assert_eq!(sem.available(), 3);

        let _p1 = sem.try_acquire().unwrap();
        assert_eq!(sem.available(), 2);

        let _p2 = sem.try_acquire().unwrap();
        let _p3 = sem.try_acquire().unwrap();
        assert_eq!(sem.available(), 0);

        assert!(sem.try_acquire().is_none());
    }

    #[test]
    fn test_semaphore_permit_returned_on_drop() {
        let sem = MkAsyncSemaphore::new(1);
        let permit = sem.try_acquire().unwrap();
        assert!(sem.try_acquire().is_none());
        drop(permit);
        assert_eq!(sem.available(), 1);

        let permit = block_on(sem.acquire());
        assert_eq!(sem.available(), 0);
        drop(permit);
        assert_eq!(sem.available(), 1);
    }

    #[test]
    fn test_semaphore_acquire_waits_for_release() {
        let sem = Arc::new(MkAsyncSemaphore::new(1));
        let held = sem.try_acquire().unwrap();
        let worker = {
            let sem = Arc::clone(&sem);
            thread::spawn(move || drop(block_on(sem.acquire())))
        };
        drop(held);
        worker.join().unwrap();
        assert_eq!(sem.available(), 1);
    }

    #[test]
    fn test_barrier_single_task_does_not_block() {
        block_on(MkAsyncBarrier::new(1).wait());
    }

    #[test]
    fn test_barrier_releases_every_round() {
        const TASKS: usize = 3;
        let barrier = MkAsyncBarrier::new(TASKS);
        for _round in 0..3 {
            let workers: Vec<_> = (0..TASKS)
                .map(|_| {
                    let barrier = barrier.clone();
                    thread::spawn(move || block_on(barrier.wait()))
                })
                .collect();
            for worker in workers {
                worker.join().unwrap();
            }
            assert_eq!(barrier.waiting(), 0);
        }
    }
}