maitake_sync/semaphore.rs
1//! An asynchronous [counting semaphore].
2//!
3//! A semaphore limits the number of tasks which may execute concurrently. See
4//! the [`Semaphore`] type's documentation for details.
5//!
6//! [counting semaphore]: https://en.wikipedia.org/wiki/Semaphore_(programming)
7use crate::{
8 blocking::RawMutex,
9 loom::{
10 cell::UnsafeCell,
11 sync::{
12 atomic::{AtomicUsize, Ordering::*},
13 blocking::{Mutex, MutexGuard},
14 },
15 },
16 spin::Spinlock,
17 util::{fmt, CachePadded, WakeBatch},
18 WaitResult,
19};
20use cordyceps::{
21 list::{self, List},
22 Linked,
23};
24use core::{
25 cmp,
26 future::Future,
27 marker::PhantomPinned,
28 pin::Pin,
29 ptr::{self, NonNull},
30 task::{Context, Poll, Waker},
31};
32use pin_project::{pin_project, pinned_drop};
33
34#[cfg(test)]
35mod tests;
36
37/// An asynchronous [counting semaphore].
38///
39/// A semaphore is a synchronization primitive that limits the number of tasks
40/// that may run concurrently. It consists of a count of _permits_, which tasks
41/// may [`acquire`] in order to execute in some context. When a task acquires a
42/// permit from the semaphore, the count of permits held by the semaphore is
43/// decreased. When no permits remain in the semaphore, any task that wishes to
44/// acquire a permit must (asynchronously) wait until another task has released
45/// a permit.
46///
47/// The [`Permit`] type is a RAII guard representing one or more permits
48/// acquired from a `Semaphore`. When a [`Permit`] is dropped, the permits it
49/// represents are released back to the `Semaphore`, potentially allowing a
50/// waiting task to acquire them.
51///
52/// # Fairness
53///
54/// This semaphore is _fair_: as permits become available, they are assigned to
55/// waiting tasks in the order that those tasks requested permits (first-in,
56/// first-out). This means that all tasks waiting to acquire permits will
57/// eventually be allowed to progress, and a single task cannot starve the
58/// semaphore of permits (provided that permits are eventually released). The
59/// semaphore remains fair even when a call to `acquire` requests more than one
60/// permit at a time.
61///
62/// # Overriding the blocking mutex
63///
64/// This type uses a [blocking `Mutex`](crate::blocking::Mutex) internally to
65/// synchronize access to its wait list. By default, this is a [`Spinlock`]. To
66/// use an alternative [`RawMutex`] implementation, use the
67/// [`new_with_raw_mutex`](Self::new_with_raw_mutex) constructor. See [the documentation
68/// on overriding mutex
69/// implementations](crate::blocking#overriding-mutex-implementations) for more
70/// details.
71///
72/// Note that this type currently requires that the raw mutex implement
73/// [`RawMutex`] rather than [`mutex_traits::ScopedRawMutex`]!
74///
75/// # Examples
76///
77/// Using a semaphore to limit concurrency:
78///
79/// ```
80/// # use tokio::task;
81/// # #[tokio::main(flavor = "current_thread")]
82/// # async fn test() {
83/// # use std as alloc;
84/// use maitake_sync::Semaphore;
85/// use alloc::sync::Arc;
86///
87/// # let mut tasks = Vec::new();
88/// // Allow 4 tasks to run concurrently at a time.
89/// let semaphore = Arc::new(Semaphore::new(4));
90///
91/// for _ in 0..8 {
92/// // Clone the `Arc` around the semaphore.
93/// let semaphore = semaphore.clone();
94/// # let t =
95/// task::spawn(async move {
96/// // Acquire a permit from the semaphore, returning a RAII guard that
97/// // releases the permit back to the semaphore when dropped.
98/// //
99/// // If all 4 permits have been acquired, the calling task will yield,
100/// // and it will be woken when another task releases a permit.
101/// let _permit = semaphore
102/// .acquire(1)
103/// .await
104/// .expect("semaphore will not be closed");
105///
106/// // do some work...
107/// });
108/// # tasks.push(t);
109/// }
110/// # for task in tasks { task.await.unwrap() };
111/// # }
112/// # test();
113/// ```
114///
115/// A semaphore may also be used to cause a task to run once all of a set of
116/// tasks have completed. If we want some task _B_ to run only after a fixed
117/// number _n_ of tasks _A_ have run, we can have task _B_ try to acquire _n_
118/// permits from a semaphore with 0 permits, and have each task _A_ add one
119/// permit to the semaphore when it completes.
120///
121/// For example:
122///
123/// ```
124/// # use tokio::task;
125/// # #[tokio::main(flavor = "current_thread")]
126/// # async fn test() {
127/// # use std as alloc;
128/// use maitake_sync::Semaphore;
129/// use alloc::sync::Arc;
130///
131/// // How many tasks will we be waiting for the completion of?
132/// const TASKS: usize = 4;
133///
134/// // Create the semaphore with 0 permits.
135/// let semaphore = Arc::new(Semaphore::new(0));
136///
137/// // Spawn the "B" task that will wait for the 4 "A" tasks to complete.
138/// # let b_task =
139/// task::spawn({
140/// let semaphore = semaphore.clone();
141/// async move {
142/// println!("Task B starting...");
143///
144/// // Since the semaphore is created with 0 permits, this will
145/// // wait until all 4 "A" tasks have completed.
146/// let _permit = semaphore
147/// .acquire(TASKS)
148/// .await
149/// .expect("semaphore will not be closed");
150///
151/// // ... do some work ...
152///
153/// println!("Task B done!");
154/// }
155/// });
156///
157/// # let mut tasks = Vec::new();
158/// for i in 0..TASKS {
159/// let semaphore = semaphore.clone();
160/// # let t =
161/// task::spawn(async move {
162/// println!("Task A {i} starting...");
163///
164/// // Add a single permit to the semaphore. Once all 4 tasks have
165/// // completed, the semaphore will have the 4 permits required to
166/// // wake the "B" task.
167/// semaphore.add_permits(1);
168///
169/// // ... do some work ...
170///
171/// println!("Task A {i} done");
172/// });
173/// # tasks.push(t);
174/// }
175///
176/// # for t in tasks { t.await.unwrap() };
177/// # b_task.await.unwrap();
178/// # }
179/// # test();
180/// ```
181///
182/// [counting semaphore]: https://en.wikipedia.org/wiki/Semaphore_(programming)
183/// [`acquire`]: Semaphore::acquire
184#[derive(Debug)]
185pub struct Semaphore<Lock: RawMutex = Spinlock> {
186 /// The number of permits in the semaphore (or [`usize::MAX] if the
187 /// semaphore is closed.
188 permits: CachePadded<AtomicUsize>,
189
190 /// The queue of tasks waiting to acquire permits.
191 ///
192 /// A spinlock (from `mycelium_util`) is used here, in order to support
193 /// `no_std` platforms; when running `loom` tests, a `loom` mutex is used
194 /// instead to simulate the spinlock, because loom doesn't play nice with
195 /// real spinlocks.
196 waiters: Mutex<SemQueue, Lock>,
197}
198
199/// A [RAII guard] representing one or more permits acquired from a
200/// [`Semaphore`].
201///
202/// When the `Permit` is dropped, the permits it represents are released back to
203/// the [`Semaphore`], potentially waking another task.
204///
205/// This type is returned by the [`Semaphore::acquire`] and
206/// [`Semaphore::try_acquire`] methods.
207///
208/// [RAII guard]: https://rust-unofficial.github.io/patterns/patterns/behavioural/RAII.html
209#[derive(Debug)]
210#[must_use = "dropping a `Permit` releases the acquired permits back to the `Semaphore`"]
211pub struct Permit<'sem, Lock: RawMutex = Spinlock> {
212 permits: usize,
213 semaphore: &'sem Semaphore<Lock>,
214}
215
216/// The future returned by the [`Semaphore::acquire`] method.
217///
218/// # Notes
219///
220/// This future is `!Unpin`, as it is unsafe to [`core::mem::forget`] an
221/// `Acquire` future once it has been polled. For instance, the following code
222/// must not compile:
223///
224///```compile_fail
225/// use maitake_sync::semaphore::Acquire;
226///
227/// // Calls to this function should only compile if `T` is `Unpin`.
228/// fn assert_unpin<T: Unpin>() {}
229///
230/// assert_unpin::<Acquire<'_>>();
231/// ```
232#[derive(Debug)]
233#[pin_project(PinnedDrop)]
234#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
235pub struct Acquire<'sem, Lock: RawMutex = Spinlock> {
236 semaphore: &'sem Semaphore<Lock>,
237 queued: bool,
238 permits: usize,
239 #[pin]
240 waiter: Waiter,
241}
242
243/// Errors returned by [`Semaphore::try_acquire`].
244
245#[derive(Debug, PartialEq, Eq)]
246pub enum TryAcquireError {
247 /// The semaphore has been [closed], so additional permits cannot be
248 /// acquired.
249 ///
250 /// [closed]: Semaphore::close
251 Closed,
252 /// The semaphore does not currently have enough permits to satisfy the
253 /// request.
254 InsufficientPermits,
255}
256
257/// The semaphore's queue of waiters. This is the portion of the semaphore's
258/// state stored inside the lock.
259#[derive(Debug)]
260struct SemQueue {
261 /// The linked list of waiters.
262 ///
263 /// # Safety
264 ///
265 /// This is protected by a mutex; the mutex *must* be acquired when
266 /// manipulating the linked list, OR when manipulating waiter nodes that may
267 /// be linked into the list. If a node is known to not be linked, it is safe
268 /// to modify that node (such as by waking the stored [`Waker`]) without
269 /// holding the lock; otherwise, it may be modified through the list, so the
270 /// lock must be held when modifying the
271 /// node.
272 queue: List<Waiter>,
273
274 /// Has the semaphore closed?
275 ///
276 /// This is tracked inside of the locked state to avoid a potential race
277 /// condition where the semaphore closes while trying to lock the wait queue.
278 closed: bool,
279}
280
281#[derive(Debug)]
282#[pin_project]
283struct Waiter {
284 #[pin]
285 node: UnsafeCell<Node>,
286
287 remaining_permits: RemainingPermits,
288}
289
290/// The number of permits needed before this waiter can be woken.
291///
292/// When this value reaches zero, the waiter has acquired all its needed
293/// permits and can be woken. If this value is `usize::max`, then the waiter
294/// has not yet been linked into the semaphore queue.
295#[derive(Debug)]
296struct RemainingPermits(AtomicUsize);
297
298#[derive(Debug)]
299struct Node {
300 links: list::Links<Waiter>,
301 waker: Option<Waker>,
302
303 // This type is !Unpin due to the heuristic from:
304 // <https://github.com/rust-lang/rust/pull/82834>
305 _pin: PhantomPinned,
306}
307
308// === impl Semaphore ===
309
310impl Semaphore {
311 loom_const_fn! {
312 /// Returns a new `Semaphore` with `permits` permits available.
313 ///
314 /// # Panics
315 ///
316 /// If `permits` is less than [`MAX_PERMITS`] ([`usize::MAX`] - 1).
317 ///
318 /// [`MAX_PERMITS`]: Self::MAX_PERMITS
319 #[must_use]
320 pub fn new(permits: usize) -> Self {
321 Self::new_with_raw_mutex(permits, Spinlock::new())
322 }
323 }
324}
325
326// This is factored out as a free constant in this module so that `RwLock` can
327// depend on it without having to specify `Semaphore`'s type parameters. This is
328// a little annoying but whatever.
329pub(crate) const MAX_PERMITS: usize = usize::MAX - 1;
330
331impl<Lock: RawMutex> Semaphore<Lock> {
332 /// The maximum number of permits a `Semaphore` may contain.
333 pub const MAX_PERMITS: usize = MAX_PERMITS;
334
335 const CLOSED: usize = usize::MAX;
336
337 loom_const_fn! {
338 /// Returns a new `Semaphore` with `permits` permits available, using the
339 /// provided [`RawMutex`] implementation.
340 ///
341 /// This constructor allows a [`Semaphore`] to be constructed with any type that
342 /// implements [`RawMutex`] as the underlying raw blocking mutex
343 /// implementation. See [the documentation on overriding mutex
344 /// implementations](crate::blocking#overriding-mutex-implementations)
345 /// for more details.
346 ///
347 /// # Panics
348 ///
349 /// If `permits` is less than [`MAX_PERMITS`] ([`usize::MAX`] - 1).
350 ///
351 /// [`MAX_PERMITS`]: Self::MAX_PERMITS
352 pub fn new_with_raw_mutex(permits: usize, lock: Lock) -> Self {
353 assert!(
354 permits <= Self::MAX_PERMITS,
355 "a semaphore may not have more than Semaphore::MAX_PERMITS permits",
356 );
357 Self {
358 permits: CachePadded::new(AtomicUsize::new(permits)),
359 waiters: Mutex::new_with_raw_mutex(SemQueue::new(), lock)
360 }
361 }
362 }
363
364 /// Returns the number of permits currently available in this semaphore, or
365 /// 0 if the semaphore is [closed].
366 ///
367 /// [closed]: Semaphore::close
368 pub fn available_permits(&self) -> usize {
369 let permits = self.permits.load(Acquire);
370 if permits == Self::CLOSED {
371 return 0;
372 }
373
374 permits
375 }
376
377 /// Acquire `permits` permits from the `Semaphore`, waiting asynchronously
378 /// if there are insufficient permits currently available.
379 ///
380 /// # Returns
381 ///
382 /// - `Ok(`[`Permit`]`)` with the requested number of permits, if the
383 /// permits were acquired.
384 /// - `Err(`[`Closed`]`)` if the semaphore was [closed].
385 ///
386 /// # Cancellation
387 ///
388 /// This method uses a queue to fairly distribute permits in the order they
389 /// were requested. If an [`Acquire`] future is dropped before it completes,
390 /// the task will lose its place in the queue.
391 ///
392 /// [`Closed`]: crate::Closed
393 /// [closed]: Semaphore::close
394 pub fn acquire(&self, permits: usize) -> Acquire<'_, Lock> {
395 Acquire {
396 semaphore: self,
397 queued: false,
398 permits,
399 waiter: Waiter::new(permits),
400 }
401 }
402
403 /// Add `permits` new permits to the semaphore.
404 ///
405 /// This permanently increases the number of permits available in the
406 /// semaphore. The permit count can be permanently *decreased* by calling
407 /// [`acquire`] or [`try_acquire`], and [`forget`]ting the returned [`Permit`].
408 ///
409 /// # Panics
410 ///
411 /// If adding `permits` permits would cause the permit count to overflow
412 /// [`MAX_PERMITS`] ([`usize::MAX`] - 1).
413 ///
414 /// [`acquire`]: Self::acquire
415 /// [`try_acquire`]: Self::try_acquire
416 /// [`forget`]: Permit::forget
417 /// [`MAX_PERMITS`]: Self::MAX_PERMITS
418 #[inline(always)]
419 pub fn add_permits(&self, permits: usize) {
420 if permits == 0 {
421 return;
422 }
423
424 self.add_permits_locked(permits, self.waiters.lock());
425 }
426
427 /// Try to acquire `permits` permits from the `Semaphore`, without waiting
428 /// for additional permits to become available.
429 ///
430 /// # Returns
431 ///
432 /// - `Ok(`[`Permit`]`)` with the requested number of permits, if the
433 /// permits were acquired.
434 /// - `Err(`[`TryAcquireError::Closed`]`)` if the semaphore was [closed].
435 /// - `Err(`[`TryAcquireError::InsufficientPermits`]`)` if the semaphore had
436 /// fewer than `permits` permits available.
437 ///
438 /// [`Closed`]: crate::Closed
439 /// [closed]: Semaphore::close
440 pub fn try_acquire(&self, permits: usize) -> Result<Permit<'_, Lock>, TryAcquireError> {
441 trace!(permits, "Semaphore::try_acquire");
442 self.try_acquire_inner(permits).map(|_| Permit {
443 permits,
444 semaphore: self,
445 })
446 }
447
448 /// Closes the semaphore.
449 ///
450 /// This wakes all tasks currently waiting on the semaphore, and prevents
451 /// new permits from being acquired.
452 pub fn close(&self) {
453 let mut waiters = self.waiters.lock();
454 self.permits.store(Self::CLOSED, Release);
455 waiters.closed = true;
456 while let Some(waiter) = waiters.queue.pop_back() {
457 if let Some(waker) = Waiter::take_waker(waiter, &mut waiters.queue) {
458 waker.wake();
459 }
460 }
461 }
462
463 fn poll_acquire(
464 &self,
465 mut node: Pin<&mut Waiter>,
466 permits: usize,
467 queued: bool,
468 cx: &mut Context<'_>,
469 ) -> Poll<WaitResult<()>> {
470 trace!(
471 waiter = ?fmt::ptr(node.as_mut()),
472 permits,
473 queued,
474 "Semaphore::poll_acquire"
475 );
476 // the total number of permits we've acquired so far.
477 let mut acquired_permits = 0;
478 let waiter = node.as_mut().project();
479
480 // how many permits are currently needed?
481 let needed_permits = if queued {
482 waiter.remaining_permits.remaining()
483 } else {
484 permits
485 };
486
487 // okay, let's try to consume the requested number of permits from the
488 // semaphore.
489 let mut sem_curr = self.permits.load(Relaxed);
490 let mut lock = None;
491 let mut waiters = loop {
492 // semaphore has closed
493 if sem_curr == Self::CLOSED {
494 return crate::closed();
495 }
496
497 // the total number of permits currently available to this waiter
498 // are the number it has acquired so far plus all the permits
499 // in the semaphore.
500 let available_permits = sem_curr + acquired_permits;
501 let mut remaining = 0;
502 let mut sem_next = sem_curr;
503 let can_acquire = if available_permits >= needed_permits {
504 // there are enough permits available to satisfy this request.
505
506 // the semaphore's next state will be the current number of
507 // permits less the amount we have to take from it to satisfy
508 // request.
509 sem_next -= needed_permits - acquired_permits;
510 needed_permits
511 } else {
512 // the number of permits available in the semaphore is less than
513 // number we want to acquire. take all the currently available
514 // permits.
515 sem_next = 0;
516 // how many permits do we still need to acquire?
517 remaining = (needed_permits - acquired_permits) - sem_curr;
518 sem_curr
519 };
520
521 if remaining > 0 && lock.is_none() {
522 // we weren't able to acquire enough permits on this poll, so
523 // the waiter will probably need to be queued, so we must lock
524 // the wait queue.
525 //
526 // this has to happen *before* the CAS that sets the new value
527 // of the semaphore's permits counter. if we subtracted the
528 // permits before acquiring the lock, additional permits might
529 // be added to the semaphore while we were waiting to lock the
530 // wait queue, and we would miss acquiring those permits.
531 // therefore, we lock the queue now.
532 lock = Some(self.waiters.lock());
533 }
534
535 if let Err(actual) = test_dbg!(self.permits.compare_exchange(
536 test_dbg!(sem_curr),
537 test_dbg!(sem_next),
538 AcqRel,
539 Acquire
540 )) {
541 // the semaphore was updated while we were trying to acquire
542 // permits.
543 sem_curr = actual;
544 continue;
545 }
546
547 // okay, we took some permits from the semaphore.
548 acquired_permits += can_acquire;
549 // did we acquire all the permits we needed?
550 if test_dbg!(remaining) == 0 {
551 if !queued {
552 // the wasn't already in the queue, so we won't need to
553 // remove it --- we're done!
554 trace!(
555 waiter = ?fmt::ptr(node.as_mut()),
556 permits,
557 queued,
558 "Semaphore::poll_acquire -> all permits acquired; done"
559 );
560 return Poll::Ready(Ok(()));
561 } else {
562 // we acquired all the permits we needed, but the waiter was
563 // already in the queue, so we need to dequeue it. we may
564 // have already acquired the lock on a previous CAS attempt
565 // that failed, but if not, grab it now.
566 break lock.unwrap_or_else(|| self.waiters.lock());
567 }
568 }
569
570 // we updated the semaphore, and will need to wait to acquire
571 // additional permits.
572 break lock.expect("we should have acquired the lock before trying to wait");
573 };
574
575 if waiters.closed {
576 trace!(
577 waiter = ?fmt::ptr(node.as_mut()),
578 permits,
579 queued,
580 "Semaphore::poll_acquire -> semaphore closed"
581 );
582 return crate::closed();
583 }
584
585 // add permits to the waiter, returning whether we added enough to wake
586 // it.
587 if waiter.remaining_permits.add(&mut acquired_permits) {
588 trace!(
589 waiter = ?fmt::ptr(node.as_mut()),
590 permits,
591 queued,
592 "Semaphore::poll_acquire -> remaining permits acquired; done"
593 );
594 // if there are permits left over after waking the node, give the
595 // remaining permits back to the semaphore, potentially assigning
596 // them to the next waiter in the queue.
597 self.add_permits_locked(acquired_permits, waiters);
598 return Poll::Ready(Ok(()));
599 }
600
601 debug_assert_eq!(
602 acquired_permits, 0,
603 "if we are enqueueing a waiter, we must have used all the acquired permits"
604 );
605
606 // we need to wait --- register the polling task's waker, and enqueue
607 // node.
608 let node_ptr = unsafe { NonNull::from(Pin::into_inner_unchecked(node)) };
609 Waiter::with_node(node_ptr, &mut waiters.queue, |node| {
610 let will_wake = node
611 .waker
612 .as_ref()
613 .map_or(false, |waker| waker.will_wake(cx.waker()));
614 if !will_wake {
615 node.waker = Some(cx.waker().clone())
616 }
617 });
618
619 // if the waiter is not already in the queue, add it now.
620 if !queued {
621 waiters.queue.push_front(node_ptr);
622 trace!(
623 waiter = ?node_ptr,
624 permits,
625 queued,
626 "Semaphore::poll_acquire -> enqueued"
627 );
628 }
629
630 Poll::Pending
631 }
632
633 #[inline(never)]
634 fn add_permits_locked<'sem>(
635 &'sem self,
636 mut permits: usize,
637 mut waiters: MutexGuard<'sem, SemQueue, Lock>,
638 ) {
639 trace!(permits, "Semaphore::add_permits");
640 if waiters.closed {
641 trace!(
642 permits,
643 "Semaphore::add_permits -> already closed; doing nothing"
644 );
645 return;
646 }
647
648 let mut drained_queue = false;
649 while permits > 0 && !drained_queue {
650 let mut batch = WakeBatch::new();
651 while batch.can_add_waker() {
652 // peek the last waiter in the queue to add permits to it; we may not
653 // be popping it from the queue if there are not enough permits to
654 // wake that waiter.
655 match waiters.queue.back() {
656 Some(waiter) => {
657 // try to add enough permits to wake this waiter. if we
658 // can't, break --- we should be out of permits.
659 if !waiter.project_ref().remaining_permits.add(&mut permits) {
660 debug_assert_eq!(permits, 0);
661 break;
662 }
663 }
664 None => {
665 // we've emptied the queue. all done!
666 drained_queue = true;
667 break;
668 }
669 };
670
671 // okay, we added enough permits to wake this waiter.
672 let waiter = waiters
673 .queue
674 .pop_back()
675 .expect("if `back()` returned `Some`, `pop_back()` will also return `Some`");
676 let waker = Waiter::take_waker(waiter, &mut waiters.queue);
677 trace!(?waiter, ?waker, permits, "Semaphore::add_permits -> waking");
678 if let Some(waker) = waker {
679 batch.add_waker(waker);
680 }
681 }
682
683 if permits > 0 && drained_queue {
684 trace!(
685 permits,
686 "Semaphore::add_permits -> queue drained, assigning remaining permits to semaphore"
687 );
688 // we drained the queue, but there are still permits left --- add
689 // them to the semaphore.
690 let prev = self.permits.fetch_add(permits, Release);
691 assert!(
692 prev + permits <= Self::MAX_PERMITS,
693 "semaphore overflow adding {permits} permits to {prev}; max permits: {}",
694 Self::MAX_PERMITS
695 );
696 }
697
698 // wake set is full, drop the lock and wake everyone!
699 drop(waiters);
700 batch.wake_all();
701
702 // reacquire the lock and continue waking
703 waiters = self.waiters.lock();
704 }
705 }
706
707 /// Drop an `Acquire` future.
708 ///
709 /// This is factored out into a method on `Semaphore`, because the same code
710 /// is run when dropping an `Acquire` future or an `AcquireOwned` future.
711 fn drop_acquire(&self, waiter: Pin<&mut Waiter>, permits: usize, queued: bool) {
712 // If the future is completed, there is no node in the wait list, so we
713 // can skip acquiring the lock.
714 if !queued {
715 return;
716 }
717
718 // This is where we ensure safety. The future is being dropped,
719 // which means we must ensure that the waiter entry is no longer stored
720 // in the linked list.
721 let mut waiters = self.waiters.lock();
722
723 let acquired_permits = permits - waiter.remaining_permits.remaining();
724
725 // Safety: we have locked the wait list.
726 unsafe {
727 // remove the entry from the list
728 let node = NonNull::from(Pin::into_inner_unchecked(waiter));
729 waiters.queue.remove(node)
730 };
731
732 if acquired_permits > 0 {
733 self.add_permits_locked(acquired_permits, waiters);
734 }
735 }
736
737 /// Try to acquire permits from the semaphore without waiting.
738 ///
739 /// This method is factored out because it's identical between the
740 /// `try_acquire` and `try_acquire_owned` methods, which behave identically
741 /// but return different permit types.
742 fn try_acquire_inner(&self, permits: usize) -> Result<(), TryAcquireError> {
743 let mut available = self.permits.load(Relaxed);
744 loop {
745 // are there enough permits to satisfy the request?
746 match available {
747 Self::CLOSED => {
748 trace!(permits, "Semaphore::try_acquire -> closed");
749 return Err(TryAcquireError::Closed);
750 }
751 available if available < permits => {
752 trace!(
753 permits,
754 available,
755 "Semaphore::try_acquire -> insufficient permits"
756 );
757 return Err(TryAcquireError::InsufficientPermits);
758 }
759 _ => {}
760 }
761
762 let remaining = available - permits;
763 match self
764 .permits
765 .compare_exchange_weak(available, remaining, AcqRel, Acquire)
766 {
767 Ok(_) => {
768 trace!(permits, remaining, "Semaphore::try_acquire -> acquired");
769 return Ok(());
770 }
771 Err(actual) => available = actual,
772 }
773 }
774 }
775}
776// === impl SemQueue ===
777
778impl SemQueue {
779 #[must_use]
780 const fn new() -> Self {
781 Self {
782 queue: List::new(),
783 closed: false,
784 }
785 }
786}
787
788// === impl Acquire ===
789
790impl<'sem, Lock: RawMutex> Future for Acquire<'sem, Lock> {
791 type Output = WaitResult<Permit<'sem, Lock>>;
792 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
793 let this = self.project();
794 let poll = this
795 .semaphore
796 .poll_acquire(this.waiter, *this.permits, *this.queued, cx)
797 .map_ok(|_| Permit {
798 permits: *this.permits,
799 semaphore: this.semaphore,
800 });
801 *this.queued = poll.is_pending();
802 poll
803 }
804}
805
806#[pinned_drop]
807impl<Lock: RawMutex> PinnedDrop for Acquire<'_, Lock> {
808 fn drop(self: Pin<&mut Self>) {
809 let this = self.project();
810 trace!(?this.queued, "Acquire::drop");
811 this.semaphore
812 .drop_acquire(this.waiter, *this.permits, *this.queued)
813 }
814}
815
816// safety: the `Acquire` future is not automatically `Sync` because the `Waiter`
817// node contains an `UnsafeCell`, which is not `Sync`. this impl is safe because
818// the `Acquire` future will only access this `UnsafeCell` when mutably borrowed
819// (when polling or dropping the future), so the future itself is safe to share
820// immutably between threads.
821unsafe impl<Lock: RawMutex> Sync for Acquire<'_, Lock> {}
822
823// === impl Permit ===
824
825impl<Lock: RawMutex> Permit<'_, Lock> {
826 /// Forget this permit, dropping it *without* returning the number of
827 /// acquired permits to the semaphore.
828 ///
829 /// This permanently decreases the number of permits in the semaphore by
830 /// [`self.permits()`](Self::permits).
831 pub fn forget(mut self) {
832 self.permits = 0;
833 }
834
835 /// Returns the count of semaphore permits owned by this `Permit`.
836 #[inline]
837 #[must_use]
838 pub fn permits(&self) -> usize {
839 self.permits
840 }
841}
842
843impl<Lock: RawMutex> Drop for Permit<'_, Lock> {
844 fn drop(&mut self) {
845 trace!(?self.permits, "Permit::drop");
846 self.semaphore.add_permits(self.permits);
847 }
848}
849
850// === impl TryAcquireError ===
851
852impl fmt::Display for TryAcquireError {
853 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
854 match self {
855 Self::Closed => f.pad("semaphore closed"),
856 Self::InsufficientPermits => f.pad("semaphore has insufficient permits"),
857 }
858 }
859}
860
861feature! {
862 #![feature = "core-error"]
863 impl core::error::Error for TryAcquireError {}
864}
865
866// === Owned variants when `Arc` is available ===
867
868feature! {
869 #![feature = "alloc"]
870
871 use alloc::sync::Arc;
872
873 /// Future returned from [`Semaphore::acquire_owned()`].
874 ///
875 /// This is identical to the [`Acquire`] future, except that it takes an
876 /// [`Arc`] reference to the [`Semaphore`], allowing the returned future to
877 /// live for the `'static` lifetime, and returns an [`OwnedPermit`] (rather
878 /// than a [`Permit`]), which is also valid for the `'static` lifetime.
879 ///
880 /// # Notes
881 ///
882 /// This future is `!Unpin`, as it is unsafe to [`core::mem::forget`] an
883 /// `AcquireOwned` future once it has been polled. For instance, the
884 /// following code must not compile:
885 ///
886 ///```compile_fail
887 /// use maitake_sync::semaphore::AcquireOwned;
888 ///
889 /// // Calls to this function should only compile if `T` is `Unpin`.
890 /// fn assert_unpin<T: Unpin>() {}
891 ///
892 /// assert_unpin::<AcquireOwned<'_>>();
893 /// ```
894 #[derive(Debug)]
895 #[pin_project(PinnedDrop)]
896 #[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
897 pub struct AcquireOwned<Lock: RawMutex = Spinlock> {
898 semaphore: Arc<Semaphore<Lock>>,
899 queued: bool,
900 permits: usize,
901 #[pin]
902 waiter: Waiter,
903 }
904
905 /// An owned [RAII guard] representing one or more permits acquired from a
906 /// [`Semaphore`].
907 ///
908 /// When the `OwnedPermit` is dropped, the permits it represents are
909 /// released back to the [`Semaphore`], potentially waking another task.
910 ///
911 /// This type is identical to the [`Permit`] type, except that it holds an
912 /// [`Arc`] clone of the [`Semaphore`], rather than borrowing it. This
913 /// allows the guard to be valid for the `'static` lifetime.
914 ///
915 /// This type is returned by the [`Semaphore::acquire_owned`] and
916 /// [`Semaphore::try_acquire_owned`] methods.
917 ///
918 /// [RAII guard]: https://rust-unofficial.github.io/patterns/patterns/behavioural/RAII.html
919 #[derive(Debug)]
920 #[must_use = "dropping an `OwnedPermit` releases the acquired permits back to the `Semaphore`"]
921 pub struct OwnedPermit<Lock: RawMutex = Spinlock> {
922 permits: usize,
923 semaphore: Arc<Semaphore<Lock>>,
924 }
925
926 impl<Lock: RawMutex> Semaphore<Lock> {
927 /// Acquire `permits` permits from the `Semaphore`, waiting asynchronously
928 /// if there are insufficient permits currently available, and returning
929 /// an [`OwnedPermit`].
930 ///
931 /// This method behaves identically to [`acquire`], except that it
932 /// requires the `Semaphore` to be wrapped in an [`Arc`], and returns an
933 /// [`OwnedPermit`] which clones the [`Arc`] rather than borrowing the
934 /// semaphore. This allows the returned [`OwnedPermit`] to be valid for
935 /// the `'static` lifetime.
936 ///
937 /// # Returns
938 ///
939 /// - `Ok(`[`OwnedPermit`]`)` with the requested number of permits, if the
940 /// permits were acquired.
941 /// - `Err(`[`Closed`]`)` if the semaphore was [closed].
942 ///
943 /// # Cancellation
944 ///
945 /// This method uses a queue to fairly distribute permits in the order they
946 /// were requested. If an [`AcquireOwned`] future is dropped before it
947 /// completes, the task will lose its place in the queue.
948 ///
949 /// [`acquire`]: Semaphore::acquire
950 /// [`Closed`]: crate::Closed
951 /// [closed]: Semaphore::close
952 pub fn acquire_owned(self: &Arc<Self>, permits: usize) -> AcquireOwned<Lock> {
953 AcquireOwned {
954 semaphore: self.clone(),
955 queued: false,
956 permits,
957 waiter: Waiter::new(permits),
958 }
959 }
960
961 /// Try to acquire `permits` permits from the `Semaphore`, without waiting
962 /// for additional permits to become available, and returning an [`OwnedPermit`].
963 ///
964 /// This method behaves identically to [`try_acquire`], except that it
965 /// requires the `Semaphore` to be wrapped in an [`Arc`], and returns an
966 /// [`OwnedPermit`] which clones the [`Arc`] rather than borrowing the
967 /// semaphore. This allows the returned [`OwnedPermit`] to be valid for
968 /// the `'static` lifetime.
969 ///
970 /// # Returns
971 ///
972 /// - `Ok(`[`OwnedPermit`]`)` with the requested number of permits, if the
973 /// permits were acquired.
974 /// - `Err(`[`TryAcquireError::Closed`]`)` if the semaphore was [closed].
975 /// - `Err(`[`TryAcquireError::InsufficientPermits`]`)` if the semaphore
976 /// had fewer than `permits` permits available.
977 ///
978 ///
979 /// [`try_acquire`]: Semaphore::try_acquire
980 /// [`Closed`]: crate::Closed
981 /// [closed]: Semaphore::close
982 pub fn try_acquire_owned(self: &Arc<Self>, permits: usize) -> Result<OwnedPermit<Lock>, TryAcquireError> {
983 trace!(permits, "Semaphore::try_acquire_owned");
984 self.try_acquire_inner(permits).map(|_| OwnedPermit {
985 permits,
986 semaphore: self.clone(),
987 })
988 }
989 }
990
991 // === impl AcquireOwned ===
992
993 impl<Lock: RawMutex> Future for AcquireOwned<Lock> {
994 type Output = WaitResult<OwnedPermit<Lock>>;
995
996 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
997 let this = self.project();
998 let poll = this
999 .semaphore
1000 .poll_acquire(this.waiter, *this.permits, *this.queued, cx)
1001 .map_ok(|_| OwnedPermit {
1002 permits: *this.permits,
1003 // TODO(eliza): might be nice to not have to bump the
1004 // refcount here...
1005 semaphore: this.semaphore.clone(),
1006 });
1007 *this.queued = poll.is_pending();
1008 poll
1009 }
1010 }
1011
1012 #[pinned_drop]
1013 impl<Lock: RawMutex> PinnedDrop for AcquireOwned<Lock> {
1014 fn drop(mut self: Pin<&mut Self>) {
1015 let this = self.project();
1016 trace!(?this.queued, "AcquireOwned::drop");
1017 this.semaphore
1018 .drop_acquire(this.waiter, *this.permits, *this.queued)
1019 }
1020 }
1021
    // safety: this is safe for the same reasons as the `Sync` impl for the
    // `Acquire` future.
    //
    // A manual impl is needed because `AcquireOwned` stores a `Waiter`, whose
    // `node` lives in an `UnsafeCell`. Like `core::cell::UnsafeCell`, that cell
    // presumably makes the waiter `!Sync` on its own.
    //
    // The node is mutated either by the wait-queue list (through
    // `Linked::links`) or through `Waiter::with_node`, which requires
    // `&mut List`, i.e. the queue's lock.
    //
    // NOTE(review): soundness presumably also relies on a shared
    // `&AcquireOwned` exposing no access to the waiter or to the inner
    // `Arc<Semaphore<Lock>>`. Note also that no `Send`/`Sync` bound is placed
    // on `Lock`. Confirm this matches the reasoning for `Acquire`.
    unsafe impl<Lock: RawMutex> Sync for AcquireOwned<Lock> {}
1025
1026 // === impl OwnedPermit ===
1027
1028 impl<Lock: RawMutex> OwnedPermit<Lock> {
1029 /// Forget this permit, dropping it *without* returning the number of
1030 /// acquired permits to the semaphore.
1031 ///
1032 /// This permanently decreases the number of permits in the semaphore by
1033 /// [`self.permits()`](Self::permits).
1034 pub fn forget(mut self) {
1035 self.permits = 0;
1036 }
1037
1038 /// Returns the count of semaphore permits owned by this `OwnedPermit`.
1039 #[inline]
1040 #[must_use]
1041 pub fn permits(&self) -> usize {
1042 self.permits
1043 }
1044 }
1045
1046 impl<Lock: RawMutex> Drop for OwnedPermit<Lock> {
1047 fn drop(&mut self) {
1048 trace!(?self.permits, "OwnedPermit::drop");
1049 self.semaphore.add_permits(self.permits);
1050 }
1051 }
1052
1053}
1054
1055// === impl Waiter ===
1056
1057impl Waiter {
1058 fn new(permits: usize) -> Self {
1059 Self {
1060 node: UnsafeCell::new(Node {
1061 links: list::Links::new(),
1062 waker: None,
1063 _pin: PhantomPinned,
1064 }),
1065 remaining_permits: RemainingPermits(AtomicUsize::new(permits)),
1066 }
1067 }
1068
1069 #[inline(always)]
1070 #[cfg_attr(loom, track_caller)]
1071 fn take_waker(this: NonNull<Self>, list: &mut List<Self>) -> Option<Waker> {
1072 Self::with_node(this, list, |node| node.waker.take())
1073 }
1074
1075 /// # Safety
1076 ///
1077 /// This is only safe to call while the list is locked. The dummy `_list`
1078 /// parameter ensures this method is only called while holding the lock, so
1079 /// this can be safe.
1080 ///
1081 /// Of course, that must be the *same* list that this waiter is a member of,
1082 /// and currently, there is no way to ensure that...
1083 #[inline(always)]
1084 #[cfg_attr(loom, track_caller)]
1085 fn with_node<T>(
1086 mut this: NonNull<Self>,
1087 _list: &mut List<Self>,
1088 f: impl FnOnce(&mut Node) -> T,
1089 ) -> T {
1090 unsafe {
1091 // safety: this is only called while holding the lock on the queue,
1092 // so it's safe to mutate the waiter.
1093 this.as_mut().node.with_mut(|node| f(&mut *node))
1094 }
1095 }
1096}
1097
// Safety: `Waiter`s are linked into an intrusive `cordyceps` list by raw
// pointer. The handle type is a bare `NonNull`, and `into_ptr`/`from_ptr` are
// identity conversions, so the list never owns, moves, or frees a waiter.
// The waiter's storage lives with whoever created it (e.g. the `waiter` field
// of `AcquireOwned`). The links are stored inside the waiter's
// `UnsafeCell<Node>`.
unsafe impl Linked<list::Links<Waiter>> for Waiter {
    /// Waiters are referenced by raw, non-owning pointers.
    type Handle = NonNull<Waiter>;

    /// The handle is already a `NonNull<Waiter>`; returned unchanged.
    fn into_ptr(r: Self::Handle) -> NonNull<Self> {
        r
    }

    /// Inverse of `into_ptr`: the pointer itself is the handle.
    unsafe fn from_ptr(ptr: NonNull<Self>) -> Self::Handle {
        ptr
    }

    /// Returns a pointer to the `Links` embedded in the waiter's `Node`.
    ///
    /// It is computed by raw-pointer projection, without creating a reference
    /// to the `Waiter` itself.
    unsafe fn links(target: NonNull<Self>) -> NonNull<list::Links<Waiter>> {
        // Safety: using `ptr::addr_of!` avoids creating a temporary
        // reference, which stacked borrows dislikes.
        let node = ptr::addr_of!((*target.as_ptr()).node);
        // NOTE: calling `with_mut` does form a shared reference to the
        // `UnsafeCell` (not to the whole `Waiter`). Mutation of the `Node`
        // happens through the raw pointer it hands out.
        (*node).with_mut(|node| {
            let links = ptr::addr_of_mut!((*node).links);
            // Safety: since the `target` pointer is `NonNull`, we can assume
            // that pointers to its members are also not null, making this use
            // of `new_unchecked` fine.
            NonNull::new_unchecked(links)
        })
    }
}
1122
1123// === impl RemainingPermits ===
1124
1125impl RemainingPermits {
1126 /// Add an acquisition of permits to the waiter, returning whether or not
1127 /// the waiter has acquired enough permits to be woken.
1128 #[inline]
1129 #[cfg_attr(loom, track_caller)]
1130 fn add(&self, permits: &mut usize) -> bool {
1131 let mut curr = self.0.load(Relaxed);
1132 loop {
1133 let taken = cmp::min(curr, *permits);
1134 let remaining = curr - taken;
1135 match self
1136 .0
1137 .compare_exchange_weak(curr, remaining, AcqRel, Acquire)
1138 {
1139 // added the permits to the waiter!
1140 Ok(_) => {
1141 *permits -= taken;
1142 return remaining == 0;
1143 }
1144 Err(actual) => curr = actual,
1145 }
1146 }
1147 }
1148
1149 #[inline]
1150 fn remaining(&self) -> usize {
1151 self.0.load(Acquire)
1152 }
1153}