unsync/
semaphore.rs

//! [`Semaphore`] provides an unsynchronized asynchronous semaphore for permit acquisition.
2
3use std::cell::Cell;
4use std::mem::ManuallyDrop;
5
6use crate::wait_list::WaitList;
7
8/// An asynchronous semaphore for permit acquisition.
9///
10/// A semaphore allows limiting access to a shared resource to a certain number
11/// of callers at a time. It is created with a certain number of _permits_ which
12/// can be shared among tasks, and tasks can wait for permits to become
13/// available. Semaphores are commonly used for rate limiting.
14///
15/// This semaphore supports both fair and unfair operations. There are two
16/// aspects of fairness to consider:
17///
18/// 1. Whether a task that wants fewer permits can obtain those permits while a
19///     task that wants more permits waits. The default [`acquire`] method does
20///     not allow this, but [`acquire_unfair`] does.
21/// 2. Whether a task can steal the permits from another task if that second
22///    task has not been scheduled yet (so the permits have been released but
23///    have yet to be acquired). This kind of unfairness improves throughput for
24///    tasks that rapidly release and acquire permits without yielding, so by
25///    default it is used. However if permits are released with
26///    [`release_fair`], they will be directly and fairly handed off to the
27///    first waiter in line, disallowing any [`acquire`] call from
28///    opportunistically taking them.
29///
30/// In comparison to [Tokio's semaphore], this semaphore:
31/// - Is `!Sync` (obviously).
32/// - Does not support closing.
33/// - Tracks the total number of permits as well as the current number of
34///   available permits.
35/// - Does not place a limit on the total number of permits — you can go up to
36///   `usize::MAX`.
37/// - Consistently uses `usize` to count permit numbers instead of using `u32`
38///   sometimes.
39/// - Gives more control over the fairness algorithms used.
40///
41/// [`acquire`]: Self::acquire
42/// [`acquire_unfair`]: Self::acquire_unfair
43/// [`release_fair`]: Permit::release_fair
44/// [Tokio's semaphore]: https://docs.rs/tokio/1/tokio/sync/struct.Semaphore.html
45#[derive(Debug)]
46pub struct Semaphore {
47    /// List of waiters.
48    ///
49    /// Each waiter contains a `usize` that stores the number of permits desired.
50    waiters: WaitList<usize, WakeUp>,
51
52    /// The number of available permits in the semaphore.
53    permits: Cell<usize>,
54
55    /// The total number of permits in the semaphore.
56    total_permits: Cell<usize>,
57}
58
/// Ways in which a waiter can be woken.
#[derive(Debug, Clone, Copy)]
enum WakeUp {
    /// The waiter was fairly handed the number of permits it requested: they
    /// have already been deducted from the semaphore's available count on its
    /// behalf. Produced by `add_permits_fair` and `Permit::release_fair`.
    Fair,
    /// The waiter was notified that the requested number of permits may be
    /// available, but was not given them directly and must acquire them
    /// itself. Produced by `add_permits` and by dropping a `Permit`.
    Unfair,
}
68
69impl Semaphore {
70    /// Create a new semaphore with the given number of permits.
71    #[must_use]
72    pub const fn new(permits: usize) -> Self {
73        Self {
74            waiters: WaitList::new(),
75            permits: Cell::new(permits),
76            total_permits: Cell::new(permits),
77        }
78    }
79
80    /// Retrieve the number of currently available permits.
81    ///
82    /// # Examples
83    ///
84    /// ```
85    /// use unsync::Semaphore;
86    ///
87    /// let semaphore = Semaphore::new(100);
88    /// assert_eq!(semaphore.available_permits(), 100);
89    ///
90    /// semaphore.add_permits(10);
91    /// assert_eq!(semaphore.available_permits(), 110);
92    ///
93    /// let guard = semaphore.try_acquire(20).unwrap();
94    /// assert_eq!(semaphore.available_permits(), 90);
95    /// assert_eq!(guard.permits(), 20);
96    ///
97    /// drop(guard);
98    /// assert_eq!(semaphore.available_permits(), 110);
99    ///
100    /// semaphore.try_acquire(20).unwrap().leak();
101    /// assert_eq!(semaphore.available_permits(), 90);
102    /// ```
103    #[must_use]
104    pub fn available_permits(&self) -> usize {
105        self.permits.get()
106    }
107
108    /// Retrieve the total number of permits, including those currently handed
109    /// out.
110    ///
111    /// # Examples
112    ///
113    /// ```
114    /// use unsync::Semaphore;
115    ///
116    /// let semaphore = Semaphore::new(100);
117    /// assert_eq!(semaphore.total_permits(), 100);
118    ///
119    /// semaphore.add_permits(10);
120    /// assert_eq!(semaphore.total_permits(), 110);
121    ///
122    /// let guard = semaphore.try_acquire(20).unwrap();
123    /// assert_eq!(semaphore.total_permits(), 110);
124    /// assert_eq!(guard.permits(), 20);
125    ///
126    /// drop(guard);
127    /// assert_eq!(semaphore.total_permits(), 110);
128    ///
129    /// semaphore.try_acquire(20).unwrap().leak();
130    /// assert_eq!(semaphore.total_permits(), 90);
131    /// ```
132    #[must_use]
133    pub fn total_permits(&self) -> usize {
134        self.total_permits.get()
135    }
136
137    /// Add additional permits to the semaphore.
138    ///
139    /// # Panics
140    ///
141    /// This function will panic if it would result in more than `usize::MAX` total permits.
142    pub fn add_permits(&self, new_permits: usize) {
143        self.total_permits.set(
144            self.total_permits
145                .get()
146                .checked_add(new_permits)
147                .expect("number of permits overflowed"),
148        );
149
150        self.release_permits(new_permits, WakeUp::Unfair);
151    }
152
153    /// Add new permits to the semaphore, using a fair wakeup algorithm to
154    /// ensure that the new permits won't be taken by any waiter other than the
155    /// one at the front of the queue.
156    ///
157    /// # Examples
158    ///
159    /// ```
160    /// use unsync::Semaphore;
161    ///
162    /// let s = Semaphore::new(3);
163    /// let permit = s.try_acquire(2);
164    /// assert!(s.try_acquire(2).is_none());
165    /// s.add_permits_fair(2);
166    /// assert!(s.try_acquire(2).is_some());
167    /// ```
168    ///
169    /// # Panics
170    ///
171    /// This function will panic if it would result in more than [`usize::MAX`]
172    /// total permits.
173    pub fn add_permits_fair(&self, new_permits: usize) {
174        self.total_permits
175            .set(self.total_permits.get().checked_add(new_permits).unwrap());
176        self.release_permits(new_permits, WakeUp::Fair);
177    }
178
179    /// Attempt to acquire permits from the semaphore immediately.
180    ///
181    /// [`None`] is returned if there are not enough permits available **or** a
182    /// task is currently waiting for a permit.
183    ///
184    /// # Examples
185    ///
186    /// ```
187    /// use unsync::Semaphore;
188    ///
189    /// let s = Semaphore::new(3);
190    /// let permit = s.try_acquire(2);
191    /// assert!(s.try_acquire(2).is_none());
192    /// drop(permit);
193    /// assert!(s.try_acquire(2).is_some());
194    /// ```
195    ///
196    pub fn try_acquire(&self, to_acquire: usize) -> Option<Permit<'_>> {
197        // If a task is already waiting for some permits, we mustn't steal it.
198        if !self.waiters.borrow().is_empty() {
199            return None;
200        }
201
202        self.try_acquire_unfair(to_acquire)
203    }
204
205    /// Attempt to acquire permits from the semaphore immediately, potentially
206    /// unfairly stealing permits from a task that is waiting for permits.
207    ///
208    /// [`None`] is returned if there are not enough permits available, but
209    /// **not** if a task is currently waiting for a permit.
210    ///
211    /// # Examples
212    ///
213    /// ```
214    /// use std::future::Future;
215    /// use std::task::Poll;
216    ///
217    /// use unsync::Semaphore;
218    /// # let cx = &mut unsync::utils::noop_cx();
219    ///
220    /// let semaphore = Semaphore::new(1);
221    ///
222    /// let mut future = Box::pin(semaphore.acquire(2));
223    /// assert!(future.as_mut().poll(cx).is_pending());
224    ///
225    /// assert!(semaphore.try_acquire(1).is_none());
226    /// assert!(semaphore.try_acquire_unfair(1).is_some());
227    ///
228    /// drop(future);
229    ///
230    /// assert!(semaphore.try_acquire(1).is_some());
231    /// ```
232    pub fn try_acquire_unfair(&self, to_acquire: usize) -> Option<Permit<'_>> {
233        let new_permits = self.permits.get().checked_sub(to_acquire)?;
234        self.permits.set(new_permits);
235
236        Some(Permit {
237            semaphore: self,
238            permits: to_acquire,
239        })
240    }
241
242    /// Acquire permits from the semaphore.
243    ///
244    /// # Examples
245    ///
246    /// The following showcases the default use of a permit.
247    ///
248    /// ```
249    /// use unsync::Semaphore;
250    ///
251    /// # #[tokio::main(flavor = "current_thread")] async fn main() {
252    /// let semaphore = Semaphore::new(1);
253    /// let mut permit = semaphore.acquire(1).await;
254    /// assert!(semaphore.try_acquire(1).is_none());
255    /// # }
256    /// ```
257    pub async fn acquire(&self, to_acquire: usize) -> Permit<'_> {
258        loop {
259            if let Some(guard) = self.try_acquire(to_acquire) {
260                break guard;
261            }
262
263            match self.waiters.wait(to_acquire).await {
264                WakeUp::Unfair => continue,
265                WakeUp::Fair => {
266                    return Permit {
267                        semaphore: self,
268                        permits: to_acquire,
269                    };
270                }
271            }
272        }
273    }
274
275    /// Acquire permits from the semaphore, potentially unfairly stealing
276    /// permits from a task that is waiting for permits.
277    ///
278    /// # Examples
279    ///
280    /// ```
281    /// use std::future::Future;
282    ///
283    /// use unsync::Semaphore;
284    ///
285    /// # #[tokio::main(flavor = "current_thread")] async fn main() {
286    /// # let cx = &mut unsync::utils::noop_cx();
287    /// let semaphore = Semaphore::new(1);
288    ///
289    /// {
290    ///     let future = semaphore.acquire(2);
291    ///     tokio::pin!(future);
292    ///     assert!(future.as_mut().poll(cx).is_pending());
293    ///
294    ///     assert!(semaphore.try_acquire(1).is_none());
295    ///     // steal one permit from `future` which is in the process of acquiring permits.
296    ///     let permit = semaphore.acquire_unfair(1).await;
297    ///     drop(permit);
298    /// }
299    ///
300    /// // Since `future` is dropped here we can now acquire more permits fairly.
301    /// assert!(semaphore.try_acquire(1).is_some());
302    /// # }
303    /// ```
304    pub async fn acquire_unfair(&self, to_acquire: usize) -> Permit<'_> {
305        loop {
306            if let Some(guard) = self.try_acquire_unfair(to_acquire) {
307                break guard;
308            }
309
310            match self.waiters.wait(to_acquire).await {
311                WakeUp::Unfair => continue,
312                WakeUp::Fair => {
313                    return Permit {
314                        semaphore: self,
315                        permits: to_acquire,
316                    };
317                }
318            }
319        }
320    }
321
322    fn release_permits(&self, permits: usize, fairness: WakeUp) {
323        let mut permits = self.permits.get() + permits;
324        self.permits.set(permits);
325
326        let mut waiters = self.waiters.borrow();
327
328        while let Some(&wanted_permits) = waiters.head_input() {
329            permits = match permits.checked_sub(wanted_permits) {
330                Some(new_permits) => new_permits,
331                None => break,
332            };
333
334            if let WakeUp::Fair = fairness {
335                self.permits.set(permits);
336            }
337
338            if waiters.wake_one(fairness).is_err() {
339                // Hint that the `None` branch can be optimized away. We know
340                // this is unreachable since we've cheat that an head input
341                // exists above.
342                unreachable!();
343            }
344        }
345    }
346}
347
348/// A RAII guard holding a number of permits obtained from a [`Semaphore`].
349///
350/// # Examples
351///
352/// The following showcases the default use of a permit.
353///
354/// ```
355/// use std::future::Future;
356/// use std::task::Poll;
357///
358/// use unsync::Semaphore;
359/// # let cx = &mut unsync::utils::noop_cx();
360///
361/// let semaphore = Semaphore::new(1);
362///
363/// let initial = semaphore.try_acquire(1).unwrap();
364///
365/// let mut f1 = Box::pin(semaphore.acquire(1));
366/// assert!(f1.as_mut().poll(cx).is_pending());
367///
368/// drop(initial);
369///
370/// let mut f2 = Box::pin(semaphore.acquire(1));
371/// assert!(f2.as_mut().poll(cx).is_ready());
372/// assert!(f1.as_mut().poll(cx).is_ready());
373/// ```
374#[derive(Debug)]
375pub struct Permit<'semaphore> {
376    semaphore: &'semaphore Semaphore,
377    permits: usize,
378}
379
380impl<'semaphore> Permit<'semaphore> {
381    /// Retrieve a shared reference to the semaphore this guard is for.
382    #[must_use]
383    pub fn semaphore(&self) -> &'semaphore Semaphore {
384        self.semaphore
385    }
386
387    /// Get the number of permits this guard holds.
388    ///
389    /// ```
390    /// use unsync::Semaphore;
391    ///
392    /// let semaphore = Semaphore::new(100);
393    ///
394    /// let guard = semaphore.try_acquire(20).unwrap();
395    /// assert_eq!(semaphore.available_permits(), 80);
396    /// assert_eq!(semaphore.total_permits(), 100);
397    /// assert_eq!(guard.permits(), 20);
398    /// drop(guard);
399    ///
400    /// assert_eq!(semaphore.available_permits(), 100);
401    /// assert_eq!(semaphore.total_permits(), 100);
402    /// ```
403    #[must_use]
404    pub fn permits(&self) -> usize {
405        self.permits
406    }
407
408    /// Leak the permits without releasing it to the semaphore.
409    ///
410    /// This reduces the total number of permits in the semaphore.
411    ///
412    /// # Examples
413    ///
414    /// ```
415    /// use unsync::Semaphore;
416    ///
417    /// let semaphore = Semaphore::new(100);
418    /// assert_eq!(semaphore.available_permits(), 100);
419    /// assert_eq!(semaphore.total_permits(), 100);
420    ///
421    /// semaphore.try_acquire(20).unwrap().leak();
422    /// assert_eq!(semaphore.available_permits(), 80);
423    /// assert_eq!(semaphore.total_permits(), 80);
424    /// ```
425    pub fn leak(self) {
426        let this = ManuallyDrop::new(self);
427        let reduced_permits = this.semaphore.total_permits.get() - this.permits;
428        this.semaphore.total_permits.set(reduced_permits);
429    }
430
431    /// Release the permits using a fair wakeup algorithm to ensure that the new
432    /// permits won't be taken by any waiter other than the one at the front of
433    /// the queue.
434    ///
435    /// To release the permits with an unfair wakeup algorithm, simply call the
436    /// [`drop`] on this value or have it fall out of scope.
437    ///
438    /// # Examples
439    ///
440    /// ```
441    /// use std::future::Future;
442    /// use std::task::Poll;
443    ///
444    /// use unsync::Semaphore;
445    /// # let cx = &mut unsync::utils::noop_cx();
446    ///
447    /// let semaphore = Semaphore::new(1);
448    ///
449    /// let initial = semaphore.try_acquire(1).unwrap();
450    ///
451    /// let mut f1 = Box::pin(semaphore.acquire(1));
452    /// assert!(f1.as_mut().poll(cx).is_pending());
453    ///
454    /// initial.release_fair();
455    ///
456    /// let mut f2 = Box::pin(semaphore.acquire(1));
457    /// assert!(f2.as_mut().poll(cx).is_pending());
458    /// assert!(f1.as_mut().poll(cx).is_ready());
459    /// assert!(f2.as_mut().poll(cx).is_ready());
460    /// ```
461    pub fn release_fair(self) {
462        let this = ManuallyDrop::new(self);
463        this.semaphore.release_permits(this.permits, WakeUp::Fair);
464    }
465}
466
467impl Drop for Permit<'_> {
468    fn drop(&mut self) {
469        self.semaphore()
470            .release_permits(self.permits(), WakeUp::Unfair);
471    }
472}