qp/
sync.rs

1//! Synchronization primitives for use in asynchronous contexts.
2use crossbeam_queue::SegQueue;
3use crossbeam_utils::Backoff;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::task::{Context, Poll, Waker};
8
9/// Counting semaphore performing asynchronous permit acquisition.
10pub struct Semaphore {
11    permits: AtomicUsize,
12    waiters: SegQueue<Waker>,
13}
14
15impl Semaphore {
16    /// Creates a new semaphore with the initial number of permits.
17    ///
18    /// # Examples
19    ///
20    /// ```
21    /// # use qp::sync::Semaphore;
22    /// let binary_semaphore = Semaphore::new(1);
23    /// ```
24    pub const fn new(permits: usize) -> Self {
25        debug_assert!(permits >= 1);
26        Self {
27            permits: AtomicUsize::new(permits),
28            waiters: SegQueue::new(),
29        }
30    }
31
32    /// Acquires a permit from the semaphore.
33    ///
34    /// # Examples
35    ///
36    /// ```
37    /// # use qp::sync::Semaphore;
38    /// # #[tokio::main]
39    /// # async fn main() {
40    /// let binary_semaphore = Semaphore::new(1);
41    /// assert_eq!(binary_semaphore.available_permits(), 1);
42    /// let permit = binary_semaphore.acquire().await;
43    /// assert_eq!(binary_semaphore.available_permits(), 0);
44    /// drop(permit);
45    /// assert_eq!(binary_semaphore.available_permits(), 1);
46    /// # }
47    /// ```
48    pub async fn acquire(&self) -> SemaphorePermit<'_> {
49        Acquire::new(self).await
50    }
51
52    /// Returns the current number of available permits.
53    ///
54    /// # Examples
55    ///
56    /// ```
57    /// # use qp::sync::Semaphore;
58    /// # #[tokio::main]
59    /// # async fn main() {
60    /// let binary_semaphore = Semaphore::new(1);
61    /// assert_eq!(binary_semaphore.available_permits(), 1);
62    /// let permit = binary_semaphore.acquire().await;
63    /// assert_eq!(binary_semaphore.available_permits(), 0);
64    /// # }
65    /// ```
66    pub fn available_permits(&self) -> usize {
67        self.permits.load(Ordering::Acquire)
68    }
69
70    /// Tries to acquire a permit from the semaphore if there is one available.
71    ///
72    /// Returns `None` immediately if there are no idle resources available in the pool.
73    ///
74    /// # Examples
75    ///
76    /// ```
77    /// # use qp::sync::Semaphore;
78    /// # #[tokio::main]
79    /// # async fn main() {
80    /// let binary_semaphore = Semaphore::new(1);
81    /// let permit1 = binary_semaphore.try_acquire();
82    /// assert!(permit1.is_some());
83    /// let permit2 = binary_semaphore.try_acquire();
84    /// assert!(permit2.is_none());
85    /// drop(permit1);
86    /// let permit3 = binary_semaphore.try_acquire();
87    /// assert!(permit3.is_some());
88    /// # }
89    pub fn try_acquire(&self) -> Option<SemaphorePermit> {
90        let backoff = Backoff::new();
91        let mut permits = self.permits.load(Ordering::Relaxed);
92        loop {
93            if permits == 0 {
94                return None;
95            }
96            match self.permits.compare_exchange_weak(
97                permits,
98                permits - 1,
99                Ordering::Acquire,
100                Ordering::Relaxed,
101            ) {
102                Ok(_) => return Some(SemaphorePermit::new(self)),
103                Err(changed) => permits = changed,
104            }
105            backoff.spin();
106        }
107    }
108}
109
110/// A permit from the semaphore.
111///
112/// This type is created by the [`Semaphore::acquire`] method and related methods.
113pub struct SemaphorePermit<'a> {
114    semaphore: &'a Semaphore,
115}
116
117impl Drop for SemaphorePermit<'_> {
118    fn drop(&mut self) {
119        self.semaphore.permits.fetch_add(1, Ordering::Release);
120        if let Some(waker) = self.semaphore.waiters.pop() {
121            waker.wake();
122        }
123    }
124}
125
126impl<'a> SemaphorePermit<'a> {
127    const fn new(semaphore: &'a Semaphore) -> Self {
128        Self { semaphore }
129    }
130}
131
132struct Acquire<'a> {
133    semaphore: &'a Semaphore,
134    waiting: AtomicBool,
135}
136
137impl<'a> Future for Acquire<'a> {
138    type Output = SemaphorePermit<'a>;
139
140    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141        match self.semaphore.try_acquire() {
142            Some(permit) => Poll::Ready(permit),
143            None => {
144                if self
145                    .waiting
146                    .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
147                    .is_ok()
148                {
149                    self.semaphore.waiters.push(cx.waker().clone());
150                }
151                Poll::Pending
152            }
153        }
154    }
155}
156
157impl<'a> Acquire<'a> {
158    const fn new(semaphore: &'a Semaphore) -> Self {
159        Self {
160            semaphore,
161            waiting: AtomicBool::new(false),
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::Duration;

    /// An `acquire` that is abandoned by a timeout must not keep a later waiter from
    /// getting the permit, and must not leave its waker behind in the queue.
    #[tokio::test]
    async fn test_abort_acquire() {
        const TICK: Duration = Duration::from_millis(1);

        let sem = Arc::new(Semaphore::new(1));

        // Spawns a task that waits for a permit and gives up after `limit`.
        let spawn_waiter = |limit: Duration| {
            let sem = Arc::clone(&sem);
            tokio::spawn(tokio::time::timeout(limit, async move {
                let _ = sem.acquire().await;
            }))
        };

        // Hold the only permit so that both tasks have to wait.
        let permit = sem.try_acquire().unwrap();

        // Task A gives up after 1ms. Task B starts 1ms later and gives up 2ms after
        // that, i.e. 3ms after the start of the test.
        let a = spawn_waiter(TICK);
        tokio::time::sleep(TICK).await;
        let b = spawn_waiter(TICK * 2);
        tokio::time::sleep(TICK).await;

        // Release the grabbed permit; task B is expected to grab it.
        drop(permit);

        // Task A has timed out by now, task B got the permit in time.
        let a_timed_out = a.await.unwrap().is_err();
        assert!(a_timed_out);
        let b_acquired = b.await.unwrap().is_ok();
        assert!(b_acquired);

        // Exposes leaked wakers: nothing may remain in the wait queue.
        assert!(sem.waiters.is_empty());
    }
}
207}