async_sema/lib.rs
1use event_listener::Event;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4
/// Shared state behind every `Semaphore` handle.
#[derive(Debug)]
pub(crate) struct SemaphoreInner {
    /// Number of permits that can currently be taken.
    count: AtomicUsize,
    /// Used by waiting `acquire` calls to sleep until `add_permits` signals them.
    event: Event,
}
10
11impl SemaphoreInner {
12 pub const fn new(n: usize) -> Self {
13 Self {
14 count: AtomicUsize::new(n),
15 event: Event::new(),
16 }
17 }
18
19 pub fn try_acquire(&self, count: usize) -> usize {
20 let mut balance = self.count.load(Ordering::Acquire);
21 loop {
22 if balance == 0 {
23 return 0;
24 }
25 let dest = match balance >= count {
26 true => balance - count,
27 false => 0,
28 };
29
30 match self.count.compare_exchange_weak(
31 balance,
32 dest,
33 Ordering::AcqRel,
34 Ordering::Acquire,
35 ) {
36 Ok(_) => return balance - dest,
37 Err(c) => balance = c,
38 }
39 }
40 }
41
42 pub async fn acquire(&self, count: usize) {
43 let mut listener = None;
44 let mut acquired = 0;
45
46 loop {
47 acquired += self.try_acquire(count - acquired);
48 if count == acquired {
49 return;
50 }
51
52 match listener.take() {
53 None => listener = Some(self.event.listen()),
54 Some(l) => l.await,
55 }
56 }
57 }
58
59 pub fn add_permits(&self, n: usize) {
60 self.count.fetch_add(n, Ordering::AcqRel);
61 self.event.notify(n);
62 }
63}
64
/// A counter for limiting the number of concurrent operations.
///
/// Cloning a `Semaphore` produces another handle to the same shared counter.
/// Acquiring a permit does not return a guard: a permit is only made
/// available again by calling `add_permits`.
#[derive(Debug, Clone)]
pub struct Semaphore {
    // Shared counter and wake-up state, reference-counted across clones.
    inner: Arc<SemaphoreInner>,
}
70
71unsafe impl Send for Semaphore {}
72
73impl Semaphore {
74 /// Creates a new semaphore with a limit of `n` concurrent operations.
75 ///
76 /// # Examples
77 ///
78 /// ```
79 /// use async_sema::Semaphore;
80 ///
81 /// let s = Semaphore::new(5);
82 /// ```
83 pub fn new(n: usize) -> Semaphore {
84 Semaphore {
85 inner: Arc::new(SemaphoreInner::new(n)),
86 }
87 }
88
89 /// Attempts to get a permit for a concurrent operation.
90 ///
91 /// Return whether permit has been acquired
92 ///
93 /// # Examples
94 ///
95 /// ```
96 /// use async_sema::Semaphore;
97 ///
98 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
99 /// let s = Semaphore::new(2);
100 ///
101 /// s.acquire().await;
102 /// s.acquire().await;
103 ///
104 /// assert!(!s.try_acquire());
105 /// s.add_permits(1);
106 /// assert!(s.try_acquire());
107 /// # });
108 /// ```
109 pub fn try_acquire(&self) -> bool {
110 self.inner.try_acquire(1) > 0
111 }
112
113 /// Waits for a permit for a concurrent operation.
114 ///
115 /// # Examples
116 ///
117 /// ```
118 /// use async_sema::Semaphore;
119 ///
120 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
121 /// let s = Semaphore::new(2);
122 ///
123 /// s.acquire().await;
124 /// # });
125 /// ```
126 pub async fn acquire(&self) {
127 self.inner.acquire(1).await
128 }
129
130 /// Waits for multiple permit for a concurrent operation.
131 ///
132 /// # Examples
133 ///
134 /// ```
135 /// use async_sema::Semaphore;
136 ///
137 /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
138 /// let s = Semaphore::new(2);
139 ///
140 /// s.batch_acquire(1).await;
141 /// # });
142 /// ```
143 pub async fn batch_acquire(&self, count: usize) {
144 self.inner.acquire(count).await
145 }
146
147 /// Add permit for a concurrent operations
148 ///
149 /// # Examples
150 ///
151 /// ```
152 /// use async_sema::Semaphore;
153 ///
154 /// let s = Semaphore::new(0);
155 ///
156 /// assert!(!s.try_acquire());
157 /// s.add_permits(1);
158 /// assert!(s.try_acquire());
159 /// ```
160 pub fn add_permits(&self, n: usize) {
161 self.inner.add_permits(n)
162 }
163}