use crate::{
blocking::RawMutex,
loom::{
cell::UnsafeCell,
sync::{
atomic::{AtomicUsize, Ordering::*},
blocking::{Mutex, MutexGuard},
},
},
spin::Spinlock,
util::{fmt, CachePadded, WakeBatch},
WaitResult,
};
use cordyceps::{
list::{self, List},
Linked,
};
use core::{
cmp,
future::Future,
marker::PhantomPinned,
pin::Pin,
ptr::{self, NonNull},
task::{Context, Poll, Waker},
};
use pin_project::{pin_project, pinned_drop};
#[cfg(test)]
mod tests;
/// An asynchronous counting semaphore.
///
/// The semaphore tracks a number of available permits in an atomic counter.
/// Permits are acquired with [`Semaphore::acquire`] (which may wait) or
/// [`Semaphore::try_acquire`] (which never waits), and are returned when the
/// resulting [`Permit`] is dropped.
///
/// The `Lock` type parameter selects the [`RawMutex`] that guards the wait
/// queue; it defaults to [`Spinlock`].
#[derive(Debug)]
pub struct Semaphore<Lock: RawMutex = Spinlock> {
    /// Number of permits currently available, or the `CLOSED` sentinel
    /// (`usize::MAX`) once the semaphore has been closed.
    permits: CachePadded<AtomicUsize>,
    /// Tasks waiting for permits, plus the `closed` flag, behind `Lock`.
    waiters: Mutex<SemQueue, Lock>,
}
/// A guard holding permits acquired from a [`Semaphore`].
///
/// Produced by [`Semaphore::try_acquire`] and by awaiting
/// [`Semaphore::acquire`]. Dropping the guard returns its permits to the
/// semaphore, unless [`Permit::forget`] was called first.
#[derive(Debug)]
#[must_use = "dropping a `Permit` releases the acquired permits back to the `Semaphore`"]
pub struct Permit<'sem, Lock: RawMutex = Spinlock> {
    /// Number of permits released on drop (set to 0 by `forget`).
    permits: usize,
    /// The semaphore the permits are returned to.
    semaphore: &'sem Semaphore<Lock>,
}
/// Future returned by [`Semaphore::acquire`].
///
/// Completes with a [`Permit`] once all requested permits have been
/// acquired, or with an error if the semaphore is closed. If it is dropped
/// while waiting, its entry is removed from the semaphore's wait queue and any
/// permits already assigned to it are given back.
#[derive(Debug)]
#[pin_project(PinnedDrop)]
#[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
pub struct Acquire<'sem, Lock: RawMutex = Spinlock> {
    /// The semaphore permits are acquired from.
    semaphore: &'sem Semaphore<Lock>,
    /// Set to whether the last poll returned `Pending`; while `true`, `waiter`
    /// may be linked into the semaphore's wait queue.
    queued: bool,
    /// Total number of permits requested.
    permits: usize,
    /// Intrusive wait-queue entry; must stay pinned while it may be linked.
    #[pin]
    waiter: Waiter,
}
/// Error returned by [`Semaphore::try_acquire`] (and `try_acquire_owned`
/// when the `alloc` feature is enabled).
#[derive(Debug, PartialEq, Eq)]
pub enum TryAcquireError {
    /// The semaphore has been closed with [`Semaphore::close`].
    Closed,
    /// Fewer permits are currently available than were requested.
    InsufficientPermits,
}
/// Lock-protected state of a [`Semaphore`]: its wait queue and closed flag.
#[derive(Debug)]
struct SemQueue {
    /// Intrusive list of waiters. New waiters are pushed at the front and
    /// permits are handed out starting from the back, so the oldest waiter is
    /// served first.
    queue: List<Waiter>,
    /// Set by `Semaphore::close`; once `true`, adding permits is a no-op.
    closed: bool,
}
/// A wait-queue entry embedded in an acquire future.
#[derive(Debug)]
#[pin_project]
struct Waiter {
    /// List links and registered waker. Kept in an `UnsafeCell` because the
    /// semaphore mutates it through a raw pointer while the queue is locked.
    #[pin]
    node: UnsafeCell<Node>,
    /// How many permits this waiter still needs.
    remaining_permits: RemainingPermits,
}
/// Atomic count of the permits a waiter still needs; it only decreases as
/// permits are assigned to the waiter.
#[derive(Debug)]
struct RemainingPermits(AtomicUsize);
/// The mutable part of a [`Waiter`], accessed only through its `UnsafeCell`.
#[derive(Debug)]
struct Node {
    /// Links into the semaphore's intrusive wait queue.
    links: list::Links<Waiter>,
    /// Waker of the task that most recently polled the owning future.
    waker: Option<Waker>,
    /// Makes the node (and thus `Waiter`) `!Unpin`, so it cannot be moved
    /// while a pointer to it may be stored in the wait queue.
    _pin: PhantomPinned,
}
impl Semaphore {
    loom_const_fn! {
        /// Returns a new `Semaphore` with `permits` permits available, using a
        /// [`Spinlock`] to guard its wait queue.
        ///
        /// # Panics
        ///
        /// Panics if `permits` is greater than [`Semaphore::MAX_PERMITS`].
        #[must_use]
        pub fn new(permits: usize) -> Self {
            Self::new_with_raw_mutex(permits, Spinlock::new())
        }
    }
}
// One less than `usize::MAX`: `Semaphore::CLOSED` uses `usize::MAX` as the
// closed-semaphore sentinel, so it can never be a valid permit count.
pub(crate) const MAX_PERMITS: usize = usize::MAX - 1;
impl<Lock: RawMutex> Semaphore<Lock> {
    /// The maximum number of permits a [`Semaphore`] may contain.
    ///
    /// This is `usize::MAX - 1`; `usize::MAX` itself is reserved as the
    /// internal sentinel that marks a closed semaphore.
    pub const MAX_PERMITS: usize = MAX_PERMITS;

    /// Sentinel stored in `permits` once [`Semaphore::close`] has run.
    const CLOSED: usize = usize::MAX;

    loom_const_fn! {
        /// Returns a new `Semaphore` with `permits` permits available, using
        /// `lock` as the raw mutex that guards its wait queue.
        ///
        /// # Panics
        ///
        /// Panics if `permits` is greater than [`Semaphore::MAX_PERMITS`].
        pub fn new_with_raw_mutex(permits: usize, lock: Lock) -> Self {
            assert!(
                permits <= Self::MAX_PERMITS,
                "a semaphore may not have more than Semaphore::MAX_PERMITS permits",
            );
            Self {
                permits: CachePadded::new(AtomicUsize::new(permits)),
                waiters: Mutex::new_with_raw_mutex(SemQueue::new(), lock)
            }
        }
    }

    /// Returns the number of permits currently available.
    ///
    /// Returns `0` once the semaphore has been closed. The value is only a
    /// snapshot; other tasks may acquire or release permits concurrently.
    pub fn available_permits(&self) -> usize {
        let permits = self.permits.load(Acquire);
        if permits == Self::CLOSED {
            return 0;
        }
        permits
    }

    /// Returns a future that acquires `permits` permits from this semaphore.
    ///
    /// The future completes with a [`Permit`] once all `permits` have been
    /// obtained, or with an error if the semaphore is closed. Queued waiters
    /// are served oldest-first. A waiter that needs more permits than are
    /// being released is credited with what is available and keeps its place
    /// in the queue, so later waiters do not overtake it.
    pub fn acquire(&self, permits: usize) -> Acquire<'_, Lock> {
        Acquire {
            semaphore: self,
            queued: false,
            permits,
            waiter: Waiter::new(permits),
        }
    }

    /// Adds `permits` permits to the semaphore.
    ///
    /// Permits are first assigned to queued waiters (oldest first), waking any
    /// whose request becomes fully satisfied. Whatever is left after the queue
    /// is drained is added to the semaphore's available count. Does nothing if
    /// `permits` is zero or the semaphore is closed.
    ///
    /// # Panics
    ///
    /// Panics if the semaphore's available count would exceed
    /// [`Semaphore::MAX_PERMITS`].
    #[inline(always)]
    pub fn add_permits(&self, permits: usize) {
        if permits == 0 {
            return;
        }
        self.add_permits_locked(permits, self.waiters.lock());
    }

    /// Tries to acquire `permits` permits without waiting.
    ///
    /// # Errors
    ///
    /// - [`TryAcquireError::Closed`] if the semaphore has been closed.
    /// - [`TryAcquireError::InsufficientPermits`] if fewer than `permits`
    ///   permits are currently available.
    pub fn try_acquire(&self, permits: usize) -> Result<Permit<'_, Lock>, TryAcquireError> {
        trace!(permits, "Semaphore::try_acquire");
        self.try_acquire_inner(permits).map(|_| Permit {
            permits,
            semaphore: self,
        })
    }

    /// Closes the semaphore.
    ///
    /// Every queued waiter is unlinked and woken; when re-polled it observes
    /// the closed state and completes with an error. After closing, all
    /// acquire attempts fail and adding permits is a no-op.
    pub fn close(&self) {
        let mut waiters = self.waiters.lock();
        self.permits.store(Self::CLOSED, Release);
        waiters.closed = true;
        while let Some(waiter) = waiters.queue.pop_back() {
            if let Some(waker) = Waiter::take_waker(waiter, &mut waiters.queue) {
                waker.wake();
            }
        }
    }

    /// Shared implementation of `Acquire::poll` and `AcquireOwned::poll`.
    ///
    /// `permits` is the total request; `queued` says whether `node` may
    /// already be linked into the wait queue from a previous `Pending` poll.
    /// Returns `Ready(Ok(()))` once all permits are held by the caller.
    /// Otherwise it registers `cx`'s waker, links the node if needed, and
    /// returns `Pending`.
    fn poll_acquire(
        &self,
        mut node: Pin<&mut Waiter>,
        permits: usize,
        queued: bool,
        cx: &mut Context<'_>,
    ) -> Poll<WaitResult<()>> {
        trace!(
            waiter = ?fmt::ptr(node.as_mut()),
            permits,
            queued,
            "Semaphore::poll_acquire"
        );
        // Permits taken from the semaphore's counter during this poll.
        let mut acquired_permits = 0;
        let waiter = node.as_mut().project();
        // A queued waiter may already have been credited some permits by
        // `add_permits_locked`; only ask for what it still needs.
        let needed_permits = if queued {
            waiter.remaining_permits.remaining()
        } else {
            permits
        };
        let mut sem_curr = self.permits.load(Relaxed);
        let mut lock = None;
        let mut waiters = loop {
            if sem_curr == Self::CLOSED {
                return crate::closed();
            }
            let available_permits = sem_curr + acquired_permits;
            let mut remaining = 0;
            let mut sem_next = sem_curr;
            let can_acquire = if available_permits >= needed_permits {
                // Enough permits: take exactly what is still needed.
                sem_next -= needed_permits - acquired_permits;
                needed_permits
            } else {
                // Not enough: take everything and remember the shortfall.
                sem_next = 0;
                remaining = (needed_permits - acquired_permits) - sem_curr;
                sem_curr
            };
            if remaining > 0 && lock.is_none() {
                // We will probably have to wait. Lock the queue *before*
                // emptying the counter: `add_permits_locked` only adds
                // permits while holding this lock, so none can slip in
                // between our CAS and our enqueueing.
                lock = Some(self.waiters.lock());
            }
            if let Err(actual) = test_dbg!(self.permits.compare_exchange(
                test_dbg!(sem_curr),
                test_dbg!(sem_next),
                AcqRel,
                Acquire
            )) {
                sem_curr = actual;
                continue;
            }
            acquired_permits += can_acquire;
            if test_dbg!(remaining) == 0 {
                if !queued {
                    trace!(
                        waiter = ?fmt::ptr(node.as_mut()),
                        permits,
                        queued,
                        "Semaphore::poll_acquire -> all permits acquired; done"
                    );
                    return Poll::Ready(Ok(()));
                } else {
                    // Previously queued: finish the bookkeeping under the lock.
                    break lock.unwrap_or_else(|| self.waiters.lock());
                }
            }
            break lock.expect("we should have acquired the lock before trying to wait");
        };
        if waiters.closed {
            trace!(
                waiter = ?fmt::ptr(node.as_mut()),
                permits,
                queued,
                "Semaphore::poll_acquire -> semaphore closed"
            );
            return crate::closed();
        }
        // Credit the waiter with what we took; any surplus is released below.
        if waiter.remaining_permits.add(&mut acquired_permits) {
            trace!(
                waiter = ?fmt::ptr(node.as_mut()),
                permits,
                queued,
                "Semaphore::poll_acquire -> remaining permits acquired; done"
            );
            self.add_permits_locked(acquired_permits, waiters);
            return Poll::Ready(Ok(()));
        }
        debug_assert_eq!(
            acquired_permits, 0,
            "if we are enqueueing a waiter, we must have used all the acquired permits"
        );
        // SAFETY: the pointer is only stored in the wait queue and the waiter
        // is never moved out of its pin; the owning future's drop glue
        // (`drop_acquire`) unlinks it before the memory goes away.
        let node_ptr = unsafe { NonNull::from(Pin::into_inner_unchecked(node)) };
        Waiter::with_node(node_ptr, &mut waiters.queue, |node| {
            // Avoid cloning the waker when the registered one already wakes
            // this task.
            let will_wake = node
                .waker
                .as_ref()
                .map_or(false, |waker| waker.will_wake(cx.waker()));
            if !will_wake {
                node.waker = Some(cx.waker().clone())
            }
        });
        if !queued {
            waiters.queue.push_front(node_ptr);
            trace!(
                waiter = ?node_ptr,
                permits,
                queued,
                "Semaphore::poll_acquire -> enqueued"
            );
        }
        Poll::Pending
    }

    /// Distributes `permits` to queued waiters, then to the semaphore.
    ///
    /// Waiters are served from the back of the queue (oldest first). Wakers
    /// are collected in a `WakeBatch` and woken with the lock released; when a
    /// batch fills up, the lock is re-acquired and distribution continues.
    /// Because the lock is released between batches, `closed` is re-checked
    /// every time it is (re-)acquired. Otherwise a concurrent `close()` could
    /// leave an empty queue behind and the leftover permits would be
    /// `fetch_add`ed onto the `CLOSED` sentinel, overflowing it.
    #[inline(never)]
    fn add_permits_locked<'sem>(
        &'sem self,
        mut permits: usize,
        mut waiters: MutexGuard<'sem, SemQueue, Lock>,
    ) {
        trace!(permits, "Semaphore::add_permits");
        loop {
            // `close()` sets `closed` while holding this lock, so checking it
            // with the lock held guarantees the counter is not `CLOSED` below.
            if waiters.closed {
                trace!(
                    permits,
                    "Semaphore::add_permits -> already closed; doing nothing"
                );
                return;
            }
            if permits == 0 {
                return;
            }
            let mut drained_queue = false;
            let mut batch = WakeBatch::new();
            while batch.can_add_waker() {
                match waiters.queue.back() {
                    Some(waiter) => {
                        // Not enough to satisfy the oldest waiter: it keeps
                        // what it was given and we are out of permits.
                        if !waiter.project_ref().remaining_permits.add(&mut permits) {
                            debug_assert_eq!(permits, 0);
                            break;
                        }
                    }
                    None => {
                        drained_queue = true;
                        break;
                    }
                };
                let waiter = waiters
                    .queue
                    .pop_back()
                    .expect("if `back()` returned `Some`, `pop_back()` will also return `Some`");
                let waker = Waiter::take_waker(waiter, &mut waiters.queue);
                trace!(?waiter, ?waker, permits, "Semaphore::add_permits -> waking");
                if let Some(waker) = waker {
                    batch.add_waker(waker);
                }
            }
            if permits > 0 && drained_queue {
                trace!(
                    permits,
                    "Semaphore::add_permits -> queue drained, assigning remaining permits to semaphore"
                );
                let prev = self.permits.fetch_add(permits, Release);
                // `checked_add` so the check itself cannot overflow.
                assert!(
                    prev.checked_add(permits)
                        .map_or(false, |total| total <= Self::MAX_PERMITS),
                    "semaphore overflow adding {permits} permits to {prev}; max permits: {}",
                    Self::MAX_PERMITS
                );
            }
            drop(waiters);
            batch.wake_all();
            // Done unless the batch filled up with permits still left over.
            if permits == 0 || drained_queue {
                return;
            }
            waiters = self.waiters.lock();
        }
    }

    /// Cleanup for an acquire future that is being dropped.
    ///
    /// If the future never returned `Pending`, it was never linked and nothing
    /// needs to happen. Otherwise the waiter is unlinked from the queue under
    /// the lock, and any permits already credited to it
    /// (`permits - remaining`) are redistributed.
    ///
    /// NOTE(review): the waiter may already have been popped by
    /// `add_permits_locked` or `close`; this relies on `List::remove` treating
    /// an unlinked node as a no-op — confirm against cordyceps' contract.
    fn drop_acquire(&self, waiter: Pin<&mut Waiter>, permits: usize, queued: bool) {
        if !queued {
            return;
        }
        let mut waiters = self.waiters.lock();
        let acquired_permits = permits - waiter.remaining_permits.remaining();
        // SAFETY: the wait queue is locked, and the node is not moved; we only
        // take its address to unlink it.
        unsafe {
            let node = NonNull::from(Pin::into_inner_unchecked(waiter));
            waiters.queue.remove(node)
        };
        if acquired_permits > 0 {
            self.add_permits_locked(acquired_permits, waiters);
        }
    }

    /// Lock-free CAS loop that takes `permits` from the counter if enough are
    /// available, without touching the wait queue.
    fn try_acquire_inner(&self, permits: usize) -> Result<(), TryAcquireError> {
        let mut available = self.permits.load(Relaxed);
        loop {
            match available {
                Self::CLOSED => {
                    trace!(permits, "Semaphore::try_acquire -> closed");
                    return Err(TryAcquireError::Closed);
                }
                available if available < permits => {
                    trace!(
                        permits,
                        available,
                        "Semaphore::try_acquire -> insufficient permits"
                    );
                    return Err(TryAcquireError::InsufficientPermits);
                }
                _ => {}
            }
            let remaining = available - permits;
            match self
                .permits
                .compare_exchange_weak(available, remaining, AcqRel, Acquire)
            {
                Ok(_) => {
                    trace!(permits, remaining, "Semaphore::try_acquire -> acquired");
                    return Ok(());
                }
                Err(actual) => available = actual,
            }
        }
    }
}
impl SemQueue {
#[must_use]
const fn new() -> Self {
Self {
queue: List::new(),
closed: false,
}
}
}
impl<'sem, Lock: RawMutex> Future for Acquire<'sem, Lock> {
type Output = WaitResult<Permit<'sem, Lock>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let poll = this
.semaphore
.poll_acquire(this.waiter, *this.permits, *this.queued, cx)
.map_ok(|_| Permit {
permits: *this.permits,
semaphore: this.semaphore,
});
*this.queued = poll.is_pending();
poll
}
}
#[pinned_drop]
impl<Lock: RawMutex> PinnedDrop for Acquire<'_, Lock> {
    /// Unlinks the waiter (if queued) and returns any partially-acquired
    /// permits to the semaphore.
    fn drop(self: Pin<&mut Self>) {
        let this = self.project();
        trace!(?this.queued, "Acquire::drop");
        let (permits, queued) = (*this.permits, *this.queued);
        this.semaphore.drop_acquire(this.waiter, permits, queued);
    }
}
// SAFETY: NOTE(review) — `Acquire` is presumably `!Sync` only because of the
// `UnsafeCell<Node>` inside `waiter`. This module never touches that cell
// through a shared `&Acquire`: polling and dropping take `Pin<&mut Self>`, and
// the wait queue reaches the node only while holding the semaphore's lock.
// Confirm that the derived `Debug` impl (which can be called via `&Acquire`)
// neither reads the cell's contents nor requires `Semaphore<Lock>: Sync`.
unsafe impl<Lock: RawMutex> Sync for Acquire<'_, Lock> {}
impl<Lock: RawMutex> Permit<'_, Lock> {
    /// Consumes this permit without returning its permits to the semaphore.
    ///
    /// The permits are permanently removed from the semaphore's pool.
    pub fn forget(mut self) {
        // With a zero count, the `add_permits` call in `Drop` releases nothing.
        let Self { permits, .. } = &mut self;
        *permits = 0;
    }

    /// Returns how many permits this guard holds.
    #[inline]
    #[must_use]
    pub fn permits(&self) -> usize {
        let Self { permits, .. } = self;
        *permits
    }
}
impl<Lock: RawMutex> Drop for Permit<'_, Lock> {
    /// Hands the held permits back to the semaphore (a no-op for zero).
    fn drop(&mut self) {
        trace!(?self.permits, "Permit::drop");
        let released = self.permits;
        let semaphore = self.semaphore;
        semaphore.add_permits(released);
    }
}
impl fmt::Display for TryAcquireError {
    /// Writes a short human-readable description, honouring width, fill and
    /// precision flags via `Formatter::pad`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Closed => "semaphore closed",
            Self::InsufficientPermits => "semaphore has insufficient permits",
        };
        f.pad(message)
    }
}
feature! {
    #![feature = "core-error"]
    // `TryAcquireError` implements `core::error::Error` only when the
    // `core-error` feature is enabled. NOTE(review): presumably gated because
    // `core::error::Error` needs a newer toolchain — confirm.
    impl core::error::Error for TryAcquireError {}
}
feature! {
    #![feature = "alloc"]

    use alloc::sync::Arc;

    /// Future returned by [`Semaphore::acquire_owned`].
    ///
    /// Like [`Acquire`], but holds an [`Arc`] clone of the semaphore instead of
    /// a borrow, so the resulting [`OwnedPermit`] has no lifetime parameter.
    #[derive(Debug)]
    #[pin_project(PinnedDrop)]
    #[must_use = "futures do nothing unless `.await`ed or `poll`ed"]
    pub struct AcquireOwned<Lock: RawMutex = Spinlock> {
        semaphore: Arc<Semaphore<Lock>>,
        queued: bool,
        permits: usize,
        #[pin]
        waiter: Waiter,
    }

    /// An owned guard holding permits acquired from an `Arc<Semaphore>`.
    ///
    /// Dropping it returns the permits to the semaphore, unless
    /// [`OwnedPermit::forget`] was called first.
    #[derive(Debug)]
    #[must_use = "dropping an `OwnedPermit` releases the acquired permits back to the `Semaphore`"]
    pub struct OwnedPermit<Lock: RawMutex = Spinlock> {
        permits: usize,
        semaphore: Arc<Semaphore<Lock>>,
    }

    impl<Lock: RawMutex> Semaphore<Lock> {
        /// Like [`Semaphore::acquire`], but the returned future (and permit)
        /// keep the semaphore alive through an [`Arc`] clone.
        pub fn acquire_owned(self: &Arc<Self>, permits: usize) -> AcquireOwned<Lock> {
            AcquireOwned {
                semaphore: self.clone(),
                queued: false,
                permits,
                waiter: Waiter::new(permits),
            }
        }

        /// Like [`Semaphore::try_acquire`], but returns an [`OwnedPermit`]
        /// holding an [`Arc`] clone of the semaphore.
        ///
        /// # Errors
        ///
        /// Same as [`Semaphore::try_acquire`].
        pub fn try_acquire_owned(
            self: &Arc<Self>,
            permits: usize,
        ) -> Result<OwnedPermit<Lock>, TryAcquireError> {
            trace!(permits, "Semaphore::try_acquire_owned");
            self.try_acquire_inner(permits).map(|_| OwnedPermit {
                permits,
                semaphore: self.clone(),
            })
        }
    }

    impl<Lock: RawMutex> Future for AcquireOwned<Lock> {
        type Output = WaitResult<OwnedPermit<Lock>>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.project();
            let poll = this
                .semaphore
                .poll_acquire(this.waiter, *this.permits, *this.queued, cx)
                .map_ok(|_| OwnedPermit {
                    permits: *this.permits,
                    semaphore: this.semaphore.clone(),
                });
            // While pending, the waiter may be linked into the wait queue.
            *this.queued = poll.is_pending();
            poll
        }
    }

    #[pinned_drop]
    impl<Lock: RawMutex> PinnedDrop for AcquireOwned<Lock> {
        // `project()` consumes the pinned reference by value, so `self` does
        // not need to be `mut` (matching `Acquire`'s `PinnedDrop` impl).
        fn drop(self: Pin<&mut Self>) {
            let this = self.project();
            trace!(?this.queued, "AcquireOwned::drop");
            this.semaphore
                .drop_acquire(this.waiter, *this.permits, *this.queued)
        }
    }

    // SAFETY: NOTE(review) — same reasoning as for `Acquire`: nothing in this
    // module touches the waiter's `UnsafeCell` through a shared reference;
    // polling and dropping take `Pin<&mut Self>`, and the queue accesses the
    // node only under the semaphore's lock. Confirm for the derived `Debug`.
    unsafe impl<Lock: RawMutex> Sync for AcquireOwned<Lock> {}

    impl<Lock: RawMutex> OwnedPermit<Lock> {
        /// Consumes this permit without returning its permits to the
        /// semaphore; they are permanently removed from its pool.
        pub fn forget(mut self) {
            self.permits = 0;
        }

        /// Returns how many permits this guard holds.
        #[inline]
        #[must_use]
        pub fn permits(&self) -> usize {
            self.permits
        }
    }

    impl<Lock: RawMutex> Drop for OwnedPermit<Lock> {
        fn drop(&mut self) {
            trace!(?self.permits, "OwnedPermit::drop");
            self.semaphore.add_permits(self.permits);
        }
    }
}
impl Waiter {
    /// Returns an unlinked waiter that still needs `permits` permits and has
    /// no waker registered.
    fn new(permits: usize) -> Self {
        Self {
            node: UnsafeCell::new(Node {
                links: list::Links::new(),
                waker: None,
                _pin: PhantomPinned,
            }),
            remaining_permits: RemainingPermits(AtomicUsize::new(permits)),
        }
    }

    /// Takes the waker registered in the waiter at `this`, leaving `None`.
    #[inline(always)]
    #[cfg_attr(loom, track_caller)]
    fn take_waker(this: NonNull<Self>, list: &mut List<Self>) -> Option<Waker> {
        Self::with_node(this, list, |node| node.waker.take())
    }

    /// Runs `f` with mutable access to the `Node` of the waiter at `this`.
    ///
    /// `_list` is unused. Every caller in this module passes
    /// `&mut waiters.queue` from a locked guard, so requiring it serves as
    /// evidence that the queue lock is held. `this` must point to a live
    /// `Waiter`.
    ///
    /// Only a *shared* reference to the `Waiter` is created, since
    /// `UnsafeCell::with_mut` needs just `&self`. Creating a `&mut Waiter`
    /// here would assert exclusive access to the whole waiter, including
    /// `remaining_permits`, which the owning future may read concurrently
    /// without holding the lock.
    #[inline(always)]
    #[cfg_attr(loom, track_caller)]
    fn with_node<T>(
        this: NonNull<Self>,
        _list: &mut List<Self>,
        f: impl FnOnce(&mut Node) -> T,
    ) -> T {
        // SAFETY: `this` points to a live, pinned `Waiter` and the queue lock
        // is held, so no one else accesses the node's contents concurrently.
        unsafe { this.as_ref().node.with_mut(|node| f(&mut *node)) }
    }
}
// SAFETY: `Handle` is a plain `NonNull<Waiter>`, so the list never owns a
// waiter and converting between handle and pointer is the identity.
// NOTE(review): soundness also relies on every linked `Waiter` staying pinned
// and being unlinked before it is freed (the acquire futures call
// `drop_acquire` from their `PinnedDrop` impls) — confirm for all link paths.
unsafe impl Linked<list::Links<Waiter>> for Waiter {
    type Handle = NonNull<Waiter>;
    // Identity: the handle *is* the pointer.
    fn into_ptr(r: Self::Handle) -> NonNull<Self> {
        r
    }
    // Identity: the handle *is* the pointer.
    unsafe fn from_ptr(ptr: NonNull<Self>) -> Self::Handle {
        ptr
    }
    // Projects to `node.links` with `addr_of!`/`addr_of_mut!`, so no
    // reference to the whole `Waiter` (or to the `Node`) is created; only a
    // shared reference to the `UnsafeCell` is formed to call `with_mut`.
    unsafe fn links(target: NonNull<Self>) -> NonNull<list::Links<Waiter>> {
        let node = ptr::addr_of!((*target.as_ptr()).node);
        (*node).with_mut(|node| {
            let links = ptr::addr_of_mut!((*node).links);
            // `links` is derived from a non-null pointer, so it is non-null.
            NonNull::new_unchecked(links)
        })
    }
}
impl RemainingPermits {
    /// Moves permits from `*permits` into this waiter's outstanding request.
    ///
    /// At most as many permits as the waiter still needs are taken. `*permits`
    /// is reduced by the number actually taken, so any surplus stays with the
    /// caller. Returns `true` if the waiter needs no more permits afterwards.
    #[inline]
    #[cfg_attr(loom, track_caller)]
    fn add(&self, permits: &mut usize) -> bool {
        let mut outstanding = self.0.load(Relaxed);
        let (granted, still_needed) = loop {
            let granted = cmp::min(outstanding, *permits);
            let still_needed = outstanding - granted;
            let exchanged =
                self.0
                    .compare_exchange_weak(outstanding, still_needed, AcqRel, Acquire);
            match exchanged {
                Ok(_) => break (granted, still_needed),
                Err(observed) => outstanding = observed,
            }
        };
        *permits -= granted;
        still_needed == 0
    }

    /// Returns how many permits the waiter still needs.
    #[inline]
    fn remaining(&self) -> usize {
        self.0.load(Acquire)
    }
}