use crate::error::AcquireError;
use crate::wait_queue::waiter::WaiterHandle;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// Future that acquires permits from a borrowed `RankedSemaphore`.
///
/// Resolves to a `RankedSemaphorePermit` tied to the semaphore borrow, or to
/// an `AcquireError` when the semaphore is closed or the waiter is cancelled.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct Acquire<'a> {
/// Semaphore whose permit counter and wait queue this future operates on.
pub(crate) semaphore: &'a super::RankedSemaphore,
/// Number of permits requested; also the size of the permit handed back.
pub(crate) permits_needed: usize,
/// Priority forwarded to the wait queue when the future has to wait.
pub(crate) priority: isize,
/// `None` until the future has enqueued itself; afterwards the handle
/// returned by the wait queue, which `Drop` uses to cancel the waiter.
pub(crate) waiter_handle: Option<WaiterHandle>,
}
/// Future that acquires permits from a shared (`Arc`) `RankedSemaphore`.
///
/// Resolves to an `OwnedRankedSemaphorePermit` holding its own `Arc` clone of
/// the semaphore, or to an `AcquireError` when the semaphore is closed or the
/// waiter is cancelled.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct AcquireOwned {
/// Shared semaphore whose permit counter and wait queue this future uses.
pub(crate) semaphore: Arc<super::RankedSemaphore>,
/// Number of permits requested; also the size of the permit handed back.
pub(crate) permits_needed: usize,
/// Priority forwarded to the wait queue when the future has to wait.
pub(crate) priority: isize,
/// `None` until the future has enqueued itself; afterwards the handle
/// returned by the wait queue, which `Drop` uses to cancel the waiter.
pub(crate) waiter_handle: Option<WaiterHandle>,
}
impl<'a> Future for Acquire<'a> {
    type Output = Result<super::RankedSemaphorePermit<'a>, AcquireError>;

    /// Drives the acquisition to completion.
    ///
    /// On the first poll the permits are taken straight from the atomic
    /// counter when enough are available; otherwise the future enqueues
    /// itself in the semaphore's wait queue. Later polls only inspect the
    /// waiter state.
    ///
    /// Returns `Ready(Ok(permit))` once the permits were obtained,
    /// `Ready(Err(AcquireError::closed()))` when the semaphore is closed or
    /// the waiter was cancelled, and `Pending` (after storing the current
    /// waker in the waiter state) otherwise.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        use std::sync::atomic::Ordering;

        let this = &mut *self;

        if this.waiter_handle.is_none() {
            // Fast path: subtract the shifted permit count from the counter.
            let n_shifted = this.permits_needed << super::RankedSemaphore::PERMIT_SHIFT;
            let mut curr = this.semaphore.permits.load(Ordering::Relaxed);
            loop {
                if curr & super::RankedSemaphore::CLOSED != 0 {
                    return Poll::Ready(Err(AcquireError::closed()));
                }
                if curr < n_shifted {
                    // Not enough permits right now: fall through and queue.
                    break;
                }
                // `compare_exchange_weak` may fail spuriously or because of
                // a concurrent update that still leaves enough permits, so
                // retry with the value it observed instead of queueing while
                // permits are available.
                match this.semaphore.permits.compare_exchange_weak(
                    curr,
                    curr - n_shifted,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        return Poll::Ready(Ok(super::RankedSemaphorePermit {
                            sem: this.semaphore,
                            permits: this.permits_needed as u32,
                        }));
                    }
                    Err(actual) => curr = actual,
                }
            }

            // Slow path: enqueue under the wait-queue lock, re-checking the
            // closed flag first so a closed semaphore gains no new waiters.
            let mut waiters = this.semaphore.waiters.lock().unwrap();
            if this.semaphore.permits.load(Ordering::Relaxed) & super::RankedSemaphore::CLOSED != 0 {
                return Poll::Ready(Err(AcquireError::closed()));
            }
            this.waiter_handle = Some(waiters.push_waiter(this.permits_needed, this.priority));
        }

        let handle = this
            .waiter_handle
            .as_ref()
            .expect("waiter handle is set once the future is queued");

        if handle.state.is_cancelled() {
            return Poll::Ready(Err(AcquireError::closed()));
        }
        if handle.state.is_notified() {
            // A notification on a closed semaphore still reports closure.
            if this.semaphore.is_closed() {
                return Poll::Ready(Err(AcquireError::closed()));
            }
            return Poll::Ready(Ok(super::RankedSemaphorePermit {
                sem: this.semaphore,
                permits: this.permits_needed as u32,
            }));
        }

        // NOTE(review): the waker is stored after the notification check;
        // confirm the notifier tolerates a notification landing in between.
        handle.state.set_waker(cx.waker().clone());
        Poll::Pending
    }
}
impl<'a> Drop for Acquire<'a> {
    /// Cancels the waiter registered by this future, if there is one.
    ///
    /// A future that never queued itself has nothing to clean up.
    fn drop(&mut self) {
        let handle = match self.waiter_handle.take() {
            Some(handle) => handle,
            None => return,
        };

        // Mark the waiter cancelled before looking at the queue.
        handle.state.cancel();
        if !handle.state.is_waiting() {
            return;
        }

        // Only take the queue lock when the state still reports waiting.
        let mut queue = self.semaphore.waiters.lock().unwrap();
        queue.remove_waiter(&handle.state);
    }
}
impl Future for AcquireOwned {
    type Output = Result<super::OwnedRankedSemaphorePermit, AcquireError>;

    /// Drives the acquisition to completion.
    ///
    /// On the first poll the permits are taken straight from the atomic
    /// counter when enough are available; otherwise the future enqueues
    /// itself in the semaphore's wait queue. Later polls only inspect the
    /// waiter state.
    ///
    /// Returns `Ready(Ok(permit))` once the permits were obtained (the permit
    /// holds its own `Arc` clone of the semaphore),
    /// `Ready(Err(AcquireError::closed()))` when the semaphore is closed or
    /// the waiter was cancelled, and `Pending` (after storing the current
    /// waker in the waiter state) otherwise.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        use std::sync::atomic::Ordering;

        let this = &mut *self;

        if this.waiter_handle.is_none() {
            // Fast path: subtract the shifted permit count from the counter.
            let n_shifted = this.permits_needed << super::RankedSemaphore::PERMIT_SHIFT;
            let mut curr = this.semaphore.permits.load(Ordering::Relaxed);
            loop {
                if curr & super::RankedSemaphore::CLOSED != 0 {
                    return Poll::Ready(Err(AcquireError::closed()));
                }
                if curr < n_shifted {
                    // Not enough permits right now: fall through and queue.
                    break;
                }
                // `compare_exchange_weak` may fail spuriously or because of
                // a concurrent update that still leaves enough permits, so
                // retry with the value it observed instead of queueing while
                // permits are available.
                match this.semaphore.permits.compare_exchange_weak(
                    curr,
                    curr - n_shifted,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        return Poll::Ready(Ok(super::OwnedRankedSemaphorePermit {
                            sem: Arc::clone(&this.semaphore),
                            permits: this.permits_needed as u32,
                        }));
                    }
                    Err(actual) => curr = actual,
                }
            }

            // Slow path: enqueue under the wait-queue lock, re-checking the
            // closed flag first so a closed semaphore gains no new waiters.
            let mut waiters = this.semaphore.waiters.lock().unwrap();
            if this.semaphore.permits.load(Ordering::Relaxed) & super::RankedSemaphore::CLOSED != 0 {
                return Poll::Ready(Err(AcquireError::closed()));
            }
            this.waiter_handle = Some(waiters.push_waiter(this.permits_needed, this.priority));
        }

        let handle = this
            .waiter_handle
            .as_ref()
            .expect("waiter handle is set once the future is queued");

        if handle.state.is_cancelled() {
            return Poll::Ready(Err(AcquireError::closed()));
        }
        if handle.state.is_notified() {
            // A notification on a closed semaphore still reports closure.
            if this.semaphore.is_closed() {
                return Poll::Ready(Err(AcquireError::closed()));
            }
            return Poll::Ready(Ok(super::OwnedRankedSemaphorePermit {
                sem: Arc::clone(&this.semaphore),
                permits: this.permits_needed as u32,
            }));
        }

        // NOTE(review): the waker is stored after the notification check;
        // confirm the notifier tolerates a notification landing in between.
        handle.state.set_waker(cx.waker().clone());
        Poll::Pending
    }
}
impl Drop for AcquireOwned {
    /// Cancels the waiter registered by this future, if there is one.
    ///
    /// A future that never queued itself has nothing to clean up.
    fn drop(&mut self) {
        let handle = match self.waiter_handle.take() {
            Some(handle) => handle,
            None => return,
        };

        // Mark the waiter cancelled before looking at the queue.
        handle.state.cancel();
        if !handle.state.is_waiting() {
            return;
        }

        // Only take the queue lock when the state still reports waiting.
        let mut queue = self.semaphore.waiters.lock().unwrap();
        queue.remove_waiter(&handle.state);
    }
}
use std::fmt;
impl<'a> fmt::Debug for Acquire<'a> {
    /// Formats the request parameters plus a `queued` flag; the semaphore
    /// and the waiter handle themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let is_queued = self.waiter_handle.is_some();

        let mut out = f.debug_struct("Acquire");
        out.field("permits_needed", &self.permits_needed);
        out.field("priority", &self.priority);
        out.field("queued", &is_queued);
        out.finish()
    }
}
impl fmt::Debug for AcquireOwned {
    /// Formats the request parameters plus a `queued` flag; the semaphore
    /// and the waiter handle themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let is_queued = self.waiter_handle.is_some();

        let mut out = f.debug_struct("AcquireOwned");
        out.field("permits_needed", &self.permits_needed);
        out.field("priority", &self.priority);
        out.field("queued", &is_queued);
        out.finish()
    }
}