async_sema/
lib.rs

1use event_listener::Event;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4
5#[derive(Debug)]
6pub(crate) struct SemaphoreInner {
7    count: AtomicUsize,
8    event: Event,
9}
10
11impl SemaphoreInner {
12    pub const fn new(n: usize) -> Self {
13        Self {
14            count: AtomicUsize::new(n),
15            event: Event::new(),
16        }
17    }
18
19    pub fn try_acquire(&self, count: usize) -> usize {
20        let mut balance = self.count.load(Ordering::Acquire);
21        loop {
22            if balance == 0 {
23                return 0;
24            }
25            let dest = match balance >= count {
26                true => balance - count,
27                false => 0,
28            };
29
30            match self.count.compare_exchange_weak(
31                balance,
32                dest,
33                Ordering::AcqRel,
34                Ordering::Acquire,
35            ) {
36                Ok(_) => return balance - dest,
37                Err(c) => balance = c,
38            }
39        }
40    }
41
42    pub async fn acquire(&self, count: usize) {
43        let mut listener = None;
44        let mut acquired = 0;
45
46        loop {
47            acquired += self.try_acquire(count - acquired);
48            if count == acquired {
49                return;
50            }
51
52            match listener.take() {
53                None => listener = Some(self.event.listen()),
54                Some(l) => l.await,
55            }
56        }
57    }
58
59    pub fn add_permits(&self, n: usize) {
60        self.count.fetch_add(n, Ordering::AcqRel);
61        self.event.notify(n);
62    }
63}
64
65/// A counter for limiting the number of concurrent operations.
66#[derive(Debug, Clone)]
67pub struct Semaphore {
68    inner: Arc<SemaphoreInner>,
69}
70
71unsafe impl Send for Semaphore {}
72
73impl Semaphore {
74    /// Creates a new semaphore with a limit of `n` concurrent operations.
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// use async_sema::Semaphore;
80    ///
81    /// let s = Semaphore::new(5);
82    /// ```
83    pub fn new(n: usize) -> Semaphore {
84        Semaphore {
85            inner: Arc::new(SemaphoreInner::new(n)),
86        }
87    }
88
89    /// Attempts to get a permit for a concurrent operation.
90    ///
91    /// Return whether permit has been acquired
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use async_sema::Semaphore;
97    ///
98    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
99    /// let s = Semaphore::new(2);
100    ///
101    /// s.acquire().await;
102    /// s.acquire().await;
103    ///
104    /// assert!(!s.try_acquire());
105    /// s.add_permits(1);
106    /// assert!(s.try_acquire());
107    /// # });
108    /// ```
109    pub fn try_acquire(&self) -> bool {
110        self.inner.try_acquire(1) > 0
111    }
112
113    /// Waits for a permit for a concurrent operation.
114    ///
115    /// # Examples
116    ///
117    /// ```
118    /// use async_sema::Semaphore;
119    ///
120    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
121    /// let s = Semaphore::new(2);
122    ///
123    /// s.acquire().await;
124    /// # });
125    /// ```
126    pub async fn acquire(&self) {
127        self.inner.acquire(1).await
128    }
129
130    /// Waits for multiple permit for a concurrent operation.
131    ///
132    /// # Examples
133    ///
134    /// ```
135    /// use async_sema::Semaphore;
136    ///
137    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
138    /// let s = Semaphore::new(2);
139    ///
140    /// s.batch_acquire(1).await;
141    /// # });
142    /// ```
143    pub async fn batch_acquire(&self, count: usize) {
144        self.inner.acquire(count).await
145    }
146
147    /// Add permit for a concurrent operations
148    ///
149    /// # Examples
150    ///
151    /// ```
152    /// use async_sema::Semaphore;
153    ///
154    /// let s = Semaphore::new(0);
155    ///
156    /// assert!(!s.try_acquire());
157    /// s.add_permits(1);
158    /// assert!(s.try_acquire());
159    /// ```
160    pub fn add_permits(&self, n: usize) {
161        self.inner.add_permits(n)
162    }
163}