async_semaphore/
lib.rs

1//! An async semaphore.
2//!
3//! A semaphore is a synchronization primitive that limits the number of concurrent operations.
4
5#![forbid(unsafe_code)]
6#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
7
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10
11use event_listener::Event;
12
13/// A counter for limiting the number of concurrent operations.
14#[derive(Debug)]
15pub struct Semaphore {
16    count: AtomicUsize,
17    event: Event,
18}
19
20impl Semaphore {
21    /// Creates a new semaphore with a limit of `n` concurrent operations.
22    ///
23    /// # Examples
24    ///
25    /// ```
26    /// use async_semaphore::Semaphore;
27    ///
28    /// let s = Semaphore::new(5);
29    /// ```
30    pub const fn new(n: usize) -> Semaphore {
31        Semaphore {
32            count: AtomicUsize::new(n),
33            event: Event::new(),
34        }
35    }
36
37    /// Attempts to get a permit for a concurrent operation.
38    ///
39    /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, a
40    /// guard is returned that releases the mutex when dropped.
41    ///
42    /// # Examples
43    ///
44    /// ```
45    /// use async_semaphore::Semaphore;
46    ///
47    /// let s = Semaphore::new(2);
48    ///
49    /// let g1 = s.try_acquire().unwrap();
50    /// let g2 = s.try_acquire().unwrap();
51    ///
52    /// assert!(s.try_acquire().is_none());
53    /// drop(g2);
54    /// assert!(s.try_acquire().is_some());
55    /// ```
56    pub fn try_acquire(&self) -> Option<SemaphoreGuard<'_>> {
57        let mut count = self.count.load(Ordering::Acquire);
58        loop {
59            if count == 0 {
60                return None;
61            }
62
63            match self.count.compare_exchange_weak(
64                count,
65                count - 1,
66                Ordering::AcqRel,
67                Ordering::Acquire,
68            ) {
69                Ok(_) => return Some(SemaphoreGuard(self)),
70                Err(c) => count = c,
71            }
72        }
73    }
74
75    /// Waits for a permit for a concurrent operation.
76    ///
77    /// Returns a guard that releases the permit when dropped.
78    ///
79    /// # Examples
80    ///
81    /// ```
82    /// # futures_lite::future::block_on(async {
83    /// use async_semaphore::Semaphore;
84    ///
85    /// let s = Semaphore::new(2);
86    /// let guard = s.acquire().await;
87    /// # });
88    /// ```
89    pub async fn acquire(&self) -> SemaphoreGuard<'_> {
90        let mut listener = None;
91
92        loop {
93            if let Some(guard) = self.try_acquire() {
94                return guard;
95            }
96
97            match listener.take() {
98                None => listener = Some(self.event.listen()),
99                Some(l) => l.await,
100            }
101        }
102    }
103}
104
105impl Semaphore {
106    /// Attempts to get an owned permit for a concurrent operation.
107    ///
108    /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, an
109    /// owned guard is returned that releases the mutex when dropped.
110    ///
111    /// # Examples
112    ///
113    /// ```
114    /// use async_semaphore::Semaphore;
115    /// use std::sync::Arc;
116    ///
117    /// let s = Arc::new(Semaphore::new(2));
118    ///
119    /// let g1 = s.try_acquire_arc().unwrap();
120    /// let g2 = s.try_acquire_arc().unwrap();
121    ///
122    /// assert!(s.try_acquire_arc().is_none());
123    /// drop(g2);
124    /// assert!(s.try_acquire_arc().is_some());
125    /// ```
126    pub fn try_acquire_arc(self: &Arc<Self>) -> Option<SemaphoreGuardArc> {
127        let mut count = self.count.load(Ordering::Acquire);
128        loop {
129            if count == 0 {
130                return None;
131            }
132
133            match self.count.compare_exchange_weak(
134                count,
135                count - 1,
136                Ordering::AcqRel,
137                Ordering::Acquire,
138            ) {
139                Ok(_) => return Some(SemaphoreGuardArc(self.clone())),
140                Err(c) => count = c,
141            }
142        }
143    }
144
145    /// Waits for an owned permit for a concurrent operation.
146    ///
147    /// Returns a guard that releases the permit when dropped.
148    ///
149    /// # Examples
150    ///
151    /// ```
152    /// # futures_lite::future::block_on(async {
153    /// use async_semaphore::Semaphore;
154    /// use std::sync::Arc;
155    ///
156    /// let s = Arc::new(Semaphore::new(2));
157    /// let guard = s.acquire_arc().await;
158    /// # });
159    /// ```
160    pub async fn acquire_arc(self: &Arc<Self>) -> SemaphoreGuardArc {
161        let mut listener = None;
162
163        loop {
164            if let Some(guard) = self.try_acquire_arc() {
165                return guard;
166            }
167
168            match listener.take() {
169                None => listener = Some(self.event.listen()),
170                Some(l) => l.await,
171            }
172        }
173    }
174}
175
176/// A guard that releases the acquired permit.
177#[derive(Debug)]
178pub struct SemaphoreGuard<'a>(&'a Semaphore);
179
180impl Drop for SemaphoreGuard<'_> {
181    fn drop(&mut self) {
182        self.0.count.fetch_add(1, Ordering::AcqRel);
183        self.0.event.notify(1);
184    }
185}
186
187/// An owned guard that releases the acquired permit.
188#[derive(Debug)]
189pub struct SemaphoreGuardArc(Arc<Semaphore>);
190
191impl Drop for SemaphoreGuardArc {
192    fn drop(&mut self) {
193        self.0.count.fetch_add(1, Ordering::AcqRel);
194        self.0.event.notify(1);
195    }
196}