#![allow(unused)]
use core::future::Future;
use std::{
cell::{RefCell, UnsafeCell},
cmp, fmt,
marker::PhantomPinned,
pin::Pin,
ptr::NonNull,
task::{Context, Poll, Waker},
};
use crate::{
linked_list::{self, LinkedList},
wake_list::WakeList,
};
/// Core state of the semaphore: the queue of pending acquirers plus the
/// packed permit counter. Both live in `RefCell`s, so the type is meant for
/// use from a single thread.
pub(crate) struct Inner {
/// Pending acquirers together with the `closed` marker of the wait list.
waiters: RefCell<Waitlist>,
/// Packed word: the available permit count is stored shifted left by
/// `Inner::PERMIT_SHIFT`; the lowest bit is the `Inner::CLOSED` flag.
permits: RefCell<usize>,
}
/// Queue of waiters that could not be satisfied immediately.
struct Waitlist {
/// Intrusive list of waiter nodes. New waiters are pushed at the front and
/// permits are handed out from the back.
queue: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
/// Set to `true` by `Inner::close`.
closed: bool,
}
/// Error returned by the non-blocking `try_acquire*` methods.
///
/// Equality on this enum is total, so `Eq` is derived alongside `PartialEq`.
#[derive(Debug, PartialEq, Eq)]
pub enum TryAcquireError {
    /// The semaphore has been closed and cannot hand out permits any more.
    Closed,
    /// The semaphore is open but does not currently hold enough permits.
    NoPermits,
}
/// Error produced by the asynchronous acquire futures when the semaphore has
/// been closed. The private unit field keeps construction inside this module.
#[derive(Debug)]
pub struct AcquireError(());
/// Future that resolves once the requested number of permits has been
/// obtained from an `Inner`, or with an error once the semaphore is closed.
pub(crate) struct Acquire<'a> {
/// Wait-list node embedded in the future; it is linked into the
/// semaphore's queue by raw pointer while the future is pending.
node: Waiter,
/// Semaphore the permits are taken from.
semaphore: &'a Inner,
/// Total number of permits requested.
num_permits: u32,
/// Set to `true` once a poll returned `Pending` and reset to `false` when
/// the future completes successfully.
queued: bool,
}
/// One entry of the semaphore's intrusive wait list.
struct Waiter {
/// Number of permits this waiter still needs; reaches zero once satisfied.
state: RefCell<usize>,
/// Waker of the task to notify when the waiter is satisfied or the
/// semaphore is closed.
waker: UnsafeCell<Option<Waker>>,
/// Link pointers used by the intrusive linked list.
pointers: linked_list::Pointers<Waiter>,
/// Makes the type `!Unpin`: the list stores raw pointers to the node.
_p: PhantomPinned,
}
impl Waiter {
fn new(num_permits: u32) -> Self {
Waiter {
waker: UnsafeCell::new(None),
state: RefCell::new(num_permits as usize),
pointers: linked_list::Pointers::new(),
_p: PhantomPinned,
}
}
fn assign_permits(&self, n: &mut usize) -> bool {
let mut curr = self.state.borrow_mut();
let assign = cmp::min(*curr, *n);
*curr -= assign;
*n -= assign;
*curr == 0
}
}
// Glue that lets `Waiter` nodes be stored in the intrusive `LinkedList`.
// Handles are plain `NonNull<Waiter>` pointers, so the list does not own the
// nodes it links.
// NOTE(review): the exact contract of `linked_list::Link` is defined outside
// this file — confirm these methods meet it.
unsafe impl linked_list::Link for Waiter {
type Handle = NonNull<Waiter>;
type Target = Waiter;
// A handle already is the raw pointer, so it is returned unchanged.
fn as_raw(handle: &Self::Handle) -> NonNull<Waiter> {
*handle
}
// Inverse of `as_raw`: the raw pointer is used directly as the handle.
unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
ptr
}
// Returns a pointer to the node's embedded link pointers. This goes through
// a temporary `&mut Waiter`, so the caller must pass a pointer to a live
// node.
unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
NonNull::from(&mut target.as_mut().pointers)
}
}
impl Inner {
    /// Largest permit count accepted by `new`, `try_acquire` and a single
    /// `add_permits` call.
    pub(crate) const MAX_PERMITS: usize = std::usize::MAX >> 3;
    /// Flag stored in the lowest bit of `permits` once the semaphore is closed.
    const CLOSED: usize = 1;
    /// Number of low bits of `permits` reserved for flags; the permit count
    /// is stored above them.
    const PERMIT_SHIFT: usize = 1;

    /// Creates an open semaphore holding `permits` permits.
    ///
    /// The requested count is masked with `MAX_PERMITS`, so bits above that
    /// limit are silently discarded rather than rejected.
    pub(crate) const fn new(mut permits: usize) -> Self {
        permits &= Self::MAX_PERMITS;
        Self {
            permits: RefCell::new(permits << Self::PERMIT_SHIFT),
            waiters: RefCell::new(Waitlist {
                queue: LinkedList::new(),
                closed: false,
            }),
        }
    }

    /// Returns the number of permits currently stored in the counter.
    pub(crate) fn available_permits(&self) -> usize {
        *self.permits.borrow() >> Self::PERMIT_SHIFT
    }

    /// Returns `added` permits to the semaphore, handing them to queued
    /// waiters first. Releasing zero permits is a no-op.
    pub(crate) fn release(&self, added: usize) {
        if added == 0 {
            return;
        }
        self.add_permits(added);
    }

    /// Closes the semaphore and wakes every queued waiter.
    ///
    /// After this call `try_acquire` reports `Closed` and `poll_acquire`
    /// resolves with an error.
    pub(crate) fn close(&self) {
        *self.permits.borrow_mut() |= Self::CLOSED;
        self.waiters.borrow_mut().closed = true;
        loop {
            // The temporary `RefMut` is released at the end of this
            // statement, so no borrow of the wait list is held while the
            // waker runs. A waker that re-enters the semaphore (or drops a
            // queued `Acquire`) would otherwise hit a `BorrowMutError` panic.
            let popped = self.waiters.borrow_mut().queue.pop_back();
            let mut waiter = match popped {
                Some(waiter) => waiter,
                None => break,
            };
            // SAFETY: the node was linked in the wait list until the pop
            // above. Linked nodes are embedded in a live `Acquire`, whose
            // `Drop` unlinks the node before its memory goes away.
            let waker = unsafe { (*waiter.as_mut().waker.get()).take() };
            if let Some(waker) = waker {
                waker.wake();
            }
        }
    }

    /// Returns `true` once `close` has been called.
    pub(crate) fn is_closed(&self) -> bool {
        *self.permits.borrow() & Self::CLOSED != 0
    }

    /// Attempts to take `num_permits` permits without waiting.
    ///
    /// # Errors
    ///
    /// * `TryAcquireError::Closed` if the semaphore has been closed.
    /// * `TryAcquireError::NoPermits` if fewer permits are available.
    ///
    /// # Panics
    ///
    /// Panics if `num_permits` exceeds `MAX_PERMITS`.
    pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> {
        assert!(
            num_permits as usize <= Self::MAX_PERMITS,
            "a semaphore may not have more than MAX_PERMITS permits ({})",
            Self::MAX_PERMITS
        );
        // Compare and subtract in the shifted representation of the counter.
        let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT;
        let mut curr = self.permits.borrow_mut();
        if *curr & Self::CLOSED == Self::CLOSED {
            return Err(TryAcquireError::Closed);
        }
        if *curr < num_permits {
            return Err(TryAcquireError::NoPermits);
        }
        *curr -= num_permits;
        Ok(())
    }

    /// Returns a future that resolves once `num_permits` permits were taken.
    pub(crate) fn acquire(&self, num_permits: u32) -> Acquire<'_> {
        Acquire::new(self, num_permits)
    }

    /// Distributes `rem` permits: queued waiters are served from the back of
    /// the queue, and whatever is left once the queue is empty is added to
    /// the counter.
    ///
    /// Wakers are collected in batches and invoked only after the borrow of
    /// the wait list has been released.
    ///
    /// # Panics
    ///
    /// Panics if more than `MAX_PERMITS` permits would be added to the
    /// counter in one call.
    fn add_permits(&self, mut rem: usize) {
        let mut wakers = WakeList::new();
        let mut is_empty = false;
        while rem > 0 {
            {
                let mut waiters = self.waiters.borrow_mut();
                'inner: while wakers.can_push() {
                    match waiters.queue.last() {
                        Some(waiter) => {
                            // A partially served waiter stays queued; it has
                            // consumed everything that was left in `rem`.
                            if !waiter.assign_permits(&mut rem) {
                                break 'inner;
                            }
                        }
                        None => {
                            is_empty = true;
                            break 'inner;
                        }
                    };
                    // The waiter at the back is fully satisfied: unlink it
                    // and remember its waker.
                    let mut waiter = waiters.queue.pop_back().unwrap();
                    // SAFETY: the node was linked until the pop above, so the
                    // `Acquire` that embeds it has not been dropped.
                    if let Some(waker) = unsafe { (*waiter.as_mut().waker.get()).take() } {
                        wakers.push(waker);
                    }
                }
                if rem > 0 && is_empty {
                    let permits = rem;
                    assert!(
                        permits <= Self::MAX_PERMITS,
                        "cannot add more than MAX_PERMITS permits ({})",
                        Self::MAX_PERMITS
                    );
                    *self.permits.borrow_mut() += rem << Self::PERMIT_SHIFT;
                    rem = 0;
                }
                // The wait-list borrow ends here, before any waker runs, so
                // a waker that re-enters the semaphore cannot trigger a
                // `BorrowMutError` panic.
            }
            wakers.wake_all();
        }
        assert_eq!(rem, 0);
    }

    /// Polls an acquisition of `num_permits` permits for the waiter `node`.
    ///
    /// `queued` tells whether an earlier poll of this waiter already returned
    /// `Pending`; in that case only the permits still recorded as missing in
    /// the node are needed.
    ///
    /// Returns `Ready(Ok(()))` once the waiter holds all its permits,
    /// `Ready(Err(_))` if the semaphore is closed, and `Pending` after
    /// registering the task's waker otherwise.
    fn poll_acquire(
        &self,
        cx: &mut Context<'_>,
        num_permits: u32,
        node: Pin<&mut Waiter>,
        queued: bool,
    ) -> Poll<Result<(), AcquireError>> {
        let needed = if queued {
            *node.state.borrow() << Self::PERMIT_SHIFT
        } else {
            (num_permits as usize) << Self::PERMIT_SHIFT
        };
        let mut curr = self.permits.borrow_mut();
        if *curr & Self::CLOSED > 0 {
            return Poll::Ready(Err(AcquireError::closed()));
        }
        // Fast path: a waiter that was never queued takes its permits
        // straight from the counter.
        if *curr >= needed && !queued {
            *curr -= needed;
            return Poll::Ready(Ok(()));
        }
        // Slow path: drain every available permit and offer it to the node.
        // The closed bit is known to be clear here, so writing 0 loses no
        // flag.
        let mut permits = *curr >> Self::PERMIT_SHIFT;
        *curr = 0;
        // `add_permits` below borrows the counter again.
        drop(curr);
        if node.assign_permits(&mut permits) {
            // Give back whatever the node did not need.
            self.add_permits(permits);
            return Poll::Ready(Ok(()));
        }
        // SAFETY: `Inner` is confined to one thread by its `RefCell` fields
        // and no other reference to this waker cell is live in this function.
        let waker = unsafe { &mut *node.waker.get() };
        // Skip the clone when the stored waker already wakes the same task.
        if waker
            .as_ref()
            .map(|waker| !waker.will_wake(cx.waker()))
            .unwrap_or(true)
        {
            *waker = Some(cx.waker().clone());
        }
        if !queued {
            // SAFETY: the node is only turned into a raw pointer and never
            // moved. `Acquire::drop` unlinks it before its memory is released.
            let node = unsafe {
                let node = Pin::into_inner_unchecked(node) as *mut _;
                NonNull::new_unchecked(node)
            };
            self.waiters.borrow_mut().queue.push_front(node);
        }
        Poll::Pending
    }
}
impl fmt::Debug for Inner {
    /// Formats the semaphore as `Semaphore { permits: <available> }`.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let available = self.available_permits();
        let mut out = fmt.debug_struct("Semaphore");
        out.field("permits", &available);
        out.finish()
    }
}
impl Future for Acquire<'_> {
    type Output = Result<(), AcquireError>;

    /// Drives the acquisition and keeps the `queued` flag up to date:
    /// it is set after a `Pending` poll and cleared after a successful one.
    /// An error result leaves the flag untouched.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let (node, semaphore, needed, queued) = self.project();
        let outcome = semaphore.poll_acquire(cx, needed, node, *queued);
        match outcome {
            Poll::Ready(Ok(())) => {
                *queued = false;
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => {
                *queued = true;
                Poll::Pending
            }
        }
    }
}
impl<'a> Acquire<'a> {
    /// Builds a not-yet-queued future requesting `num_permits` permits from
    /// `semaphore`.
    fn new(semaphore: &'a Inner, num_permits: u32) -> Self {
        let node = Waiter::new(num_permits);
        Self {
            node,
            semaphore,
            num_permits,
            queued: false,
        }
    }

    /// Splits the pinned future into its parts: the waiter node stays
    /// pinned, the other fields are handed out as plain values/references.
    fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Inner, u32, &mut bool) {
        fn is_unpin<T: Unpin>() {}
        // Compile-time proof that every field returned unpinned is `Unpin`.
        is_unpin::<&Inner>();
        is_unpin::<&mut bool>();
        is_unpin::<u32>();

        // SAFETY: nothing is moved out of `self`. `node` is re-pinned right
        // away and the remaining fields were checked to be `Unpin` above.
        let this = unsafe { self.get_unchecked_mut() };
        // SAFETY: `this.node` is a field of a pinned value and is only ever
        // exposed through this `Pin`.
        let node = unsafe { Pin::new_unchecked(&mut this.node) };
        (node, this.semaphore, this.num_permits, &mut this.queued)
    }
}
impl Drop for Acquire<'_> {
    /// Unlinks the waiter from the semaphore's queue and gives back every
    /// permit the waiter had already collected.
    fn drop(&mut self) {
        // A future that was never queued (or completed successfully) has no
        // node in the wait list and holds no partially assigned permits.
        if !self.queued {
            return;
        }
        {
            let mut waiters = self.semaphore.waiters.borrow_mut();
            let node = NonNull::from(&mut self.node);
            // SAFETY: `node` points at `self.node`, which is alive for the
            // whole call.
            // NOTE(review): `queued` can still be `true` after the node was
            // popped by `add_permits`/`close`, so this relies on
            // `LinkedList::remove` tolerating an unlinked node — confirm
            // against `linked_list`.
            unsafe { waiters.queue.remove(node) };
            // The wait-list borrow must end here: `add_permits` below takes
            // its own mutable borrow of the same `RefCell`, and overlapping
            // the two panics with "already borrowed".
        }
        let acquired_permits = self.num_permits as usize - *self.node.state.borrow();
        if acquired_permits > 0 {
            self.semaphore.add_permits(acquired_permits);
        }
    }
}
impl AcquireError {
fn closed() -> AcquireError {
AcquireError(())
}
}
impl fmt::Display for AcquireError {
    /// Writes the fixed message `semaphore closed`.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str("semaphore closed")
    }
}
// Marks `AcquireError` as a standard error type; the trait's default methods
// are used unchanged.
impl std::error::Error for AcquireError {}
impl TryAcquireError {
    /// Returns `true` if the error was caused by a closed semaphore.
    #[allow(dead_code)]
    pub(crate) fn is_closed(&self) -> bool {
        match self {
            TryAcquireError::Closed => true,
            TryAcquireError::NoPermits => false,
        }
    }

    /// Returns `true` if the error was caused by a lack of permits.
    #[allow(dead_code)]
    pub(crate) fn is_no_permits(&self) -> bool {
        match self {
            TryAcquireError::NoPermits => true,
            TryAcquireError::Closed => false,
        }
    }
}
impl fmt::Display for TryAcquireError {
    /// Writes a short human-readable description of the failure.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            TryAcquireError::Closed => "semaphore closed",
            TryAcquireError::NoPermits => "no permits available",
        };
        fmt.write_str(message)
    }
}
// Marks `TryAcquireError` as a standard error type; the trait's default
// methods are used unchanged.
impl std::error::Error for TryAcquireError {}
/// Public counting semaphore; a thin wrapper around `Inner` that hands out
/// permit guards returning their permits on drop.
#[derive(Debug)]
pub struct Semaphore(Inner);
/// Guard for permits borrowed from a `Semaphore`; dropping it adds the
/// permits back to the semaphore.
#[must_use]
#[derive(Debug)]
pub struct SemaphorePermit<'a> {
/// Semaphore the permits are returned to.
sem: &'a Semaphore,
/// Number of permits held by this guard.
permits: u32,
}
/// Guard for permits taken from an `Rc<Semaphore>`; it keeps the semaphore
/// alive and adds the permits back when dropped.
#[must_use]
#[derive(Debug)]
pub struct OwnedSemaphorePermit {
/// Shared handle to the semaphore the permits are returned to.
sem: std::rc::Rc<Semaphore>,
/// Number of permits held by this guard.
permits: u32,
}
/// Future returned by `Semaphore::acquire` / `Semaphore::acquire_many`.
/// Fields: the inner acquire future, the semaphore the resulting permit will
/// refer to, and the number of permits requested.
pub struct AcquireResult<'a>(Acquire<'a>, &'a Semaphore, u32);
impl<'a> Future for AcquireResult<'a> {
    type Output = Result<SemaphorePermit<'a>, AcquireError>;

    /// Polls the inner acquire future and wraps a successful acquisition in
    /// a `SemaphorePermit` guard.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let (sem, permits) = (self.1, self.2);
        // SAFETY: the inner future is only projected, never moved out of the
        // pinned wrapper.
        let acquire = unsafe { self.map_unchecked_mut(|me| &mut me.0) };
        match acquire.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok(())) => Poll::Ready(Ok(SemaphorePermit { sem, permits })),
        }
    }
}
impl Semaphore {
    /// Creates a semaphore holding `permits` permits.
    pub const fn new(permits: usize) -> Self {
        Self(Inner::new(permits))
    }

    /// Returns the number of permits currently available.
    pub fn available_permits(&self) -> usize {
        self.0.available_permits()
    }

    /// Adds `n` permits to the semaphore.
    pub fn add_permits(&self, n: usize) {
        self.0.release(n);
    }

    /// Returns a future resolving to a guard for one permit.
    pub fn acquire(&self) -> AcquireResult<'_> {
        self.acquire_many(1)
    }

    /// Returns a future resolving to a guard for `n` permits.
    pub fn acquire_many(&self, n: u32) -> AcquireResult<'_> {
        AcquireResult(self.0.acquire(n), self, n)
    }

    /// Tries to take one permit without waiting.
    pub fn try_acquire(&self) -> Result<SemaphorePermit<'_>, TryAcquireError> {
        self.try_acquire_many(1)
    }

    /// Tries to take `n` permits without waiting.
    pub fn try_acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, TryAcquireError> {
        self.0.try_acquire(n)?;
        Ok(SemaphorePermit {
            sem: self,
            permits: n,
        })
    }

    /// Waits for one permit and returns a guard that owns the semaphore
    /// handle.
    pub async fn acquire_owned(
        self: std::rc::Rc<Self>,
    ) -> Result<OwnedSemaphorePermit, AcquireError> {
        let permits = 1;
        self.0.acquire(permits).await?;
        Ok(OwnedSemaphorePermit { sem: self, permits })
    }

    /// Waits for `n` permits and returns a guard that owns the semaphore
    /// handle.
    pub async fn acquire_many_owned(
        self: std::rc::Rc<Self>,
        n: u32,
    ) -> Result<OwnedSemaphorePermit, AcquireError> {
        self.0.acquire(n).await?;
        Ok(OwnedSemaphorePermit {
            sem: self,
            permits: n,
        })
    }

    /// Tries to take one permit without waiting, returning an owning guard.
    pub fn try_acquire_owned(
        self: std::rc::Rc<Self>,
    ) -> Result<OwnedSemaphorePermit, TryAcquireError> {
        self.try_acquire_many_owned(1)
    }

    /// Tries to take `n` permits without waiting, returning an owning guard.
    pub fn try_acquire_many_owned(
        self: std::rc::Rc<Self>,
        n: u32,
    ) -> Result<OwnedSemaphorePermit, TryAcquireError> {
        self.0.try_acquire(n)?;
        Ok(OwnedSemaphorePermit {
            sem: self,
            permits: n,
        })
    }

    /// Closes the semaphore and wakes all pending acquirers.
    pub fn close(&self) {
        self.0.close();
    }

    /// Returns `true` if the semaphore has been closed.
    pub fn is_closed(&self) -> bool {
        self.0.is_closed()
    }
}
impl<'a> SemaphorePermit<'a> {
    /// Consumes the guard without giving its permits back: the count is
    /// zeroed first, so the guard's destructor returns nothing.
    pub fn forget(self) {
        let mut permit = self;
        permit.permits = 0;
    }
}
impl OwnedSemaphorePermit {
    /// Consumes the guard without giving its permits back: the count is
    /// zeroed first, so the guard's destructor returns nothing.
    pub fn forget(self) {
        let mut permit = self;
        permit.permits = 0;
    }
}
impl Drop for SemaphorePermit<'_> {
    /// Returns the held permits to the semaphore.
    fn drop(&mut self) {
        let returned = self.permits as usize;
        self.sem.add_permits(returned);
    }
}
impl Drop for OwnedSemaphorePermit {
    /// Returns the held permits to the semaphore.
    fn drop(&mut self) {
        let returned = self.permits as usize;
        self.sem.add_permits(returned);
    }
}
// Smoke tests for the semaphore, driven by the monoio test runtime.
#[cfg(test)]
mod tests {
use super::{Inner, Semaphore};
// Ten permits can be acquired one by one without waiting.
#[monoio::test]
async fn inner_works() {
let s = Inner::new(10);
for _ in 0..10 {
s.acquire(1).await.unwrap();
}
}
// A spawned task acquiring from an empty semaphore completes once permits
// are released by the spawning task.
#[monoio::test]
async fn inner_release_after_acquire() {
let s = std::rc::Rc::new(Inner::new(0));
let s_move = s.clone();
let join = monoio::spawn(async move {
let _ = s_move.acquire(1).await.unwrap();
let _ = s_move.acquire(1).await.unwrap();
});
s.release(2);
join.await;
}
// A permit added through the public API can be acquired afterwards.
#[monoio::test]
async fn it_works() {
let s = Semaphore::new(0);
s.add_permits(1);
let _ = s.acquire().await.unwrap();
}
}