use parking_lot::Mutex;
use pin_project::{pin_project, pinned_drop};
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
/// An async, FIFO-fair semaphore whose permit total may fall below the number
/// of permits currently handed out.
///
/// When [`remove_permits`](DeficitSemaphore::remove_permits) lowers the total
/// below the outstanding count, the semaphore runs a "deficit": no new permits
/// are granted until enough [`Permit`]s have been dropped to bring the
/// outstanding count back under the total.
pub struct DeficitSemaphore(Mutex<State>);
impl DeficitSemaphore {
    /// Creates a semaphore with `permits` permits available.
    ///
    /// The semaphore is returned in an `Arc` because
    /// [`acquire`](DeficitSemaphore::acquire) consumes an `Arc<Self>`.
    pub fn new(permits: usize) -> Arc<DeficitSemaphore> {
        let initial = State {
            total_permits: permits,
            outstanding_permits: 0,
            head: None,
            tail: None,
        };
        Arc::new(Self(Mutex::new(initial)))
    }

    /// Raises the permit total by `permits`, waking the first waiter if that
    /// makes a permit available.
    ///
    /// # Panics
    ///
    /// Panics if the new total would overflow `usize`.
    pub fn add_permits(&self, permits: usize) {
        let mut guard = self.0.lock();
        let raised = guard.total_permits.checked_add(permits);
        guard.total_permits = raised.expect("permit count overflow");
        guard.maybe_wake();
    }

    /// Lowers the permit total by `permits`.
    ///
    /// Permits that are already handed out are unaffected, so the total may
    /// drop below the outstanding count. New acquisitions then wait until
    /// enough permits are returned.
    ///
    /// # Panics
    ///
    /// Panics if `permits` exceeds the current total.
    pub fn remove_permits(&self, permits: usize) {
        let mut guard = self.0.lock();
        let lowered = guard.total_permits.checked_sub(permits);
        guard.total_permits = lowered.expect("permit count underflow");
    }

    /// Returns a future that resolves to a [`Permit`] once one can be granted.
    ///
    /// Waiters are served in the order in which they are first polled.
    pub fn acquire(self: Arc<Self>) -> Acquire {
        Acquire {
            semaphore: self,
            node: None,
            _p: PhantomPinned,
        }
    }
}
/// Mutex-protected bookkeeping for a [`DeficitSemaphore`].
struct State {
    /// Number of permits the semaphore currently allows.
    total_permits: usize,
    /// Number of live [`Permit`]s. This may exceed `total_permits` after
    /// `DeficitSemaphore::remove_permits`.
    outstanding_permits: usize,
    /// First entry of the intrusive FIFO wait queue, if any.
    head: Option<NonNull<Node>>,
    /// Last entry of the intrusive FIFO wait queue, if any.
    tail: Option<NonNull<Node>>,
}
impl State {
    /// Returns `true` when another permit can be handed out right now.
    fn has_permits(&self) -> bool {
        self.outstanding_permits < self.total_permits
    }

    /// Wakes the waiter at the head of the queue if a permit is available.
    ///
    /// The waker is taken out of the node. Repeated calls therefore wake the
    /// head at most once, until it stores a new waker by being polled again.
    fn maybe_wake(&mut self) {
        if !self.has_permits() {
            return;
        }
        if let Some(mut first) = self.head {
            // SAFETY: queued nodes stay alive at a fixed address until they
            // unlink themselves. `State` is only reachable through the
            // semaphore's mutex, so no one else is touching the queue.
            let taken = unsafe { first.as_mut().waker.take() };
            if let Some(waker) = taken {
                waker.wake();
            }
        }
    }
}
/// An entry in the semaphore's intrusive, doubly linked FIFO wait queue.
///
/// A `Node` lives inline in a pinned `Acquire` future. The queue links nodes
/// through raw pointers stored in `State` and in neighbouring nodes. Every
/// field is only read or written while the semaphore's mutex is held.
struct Node {
    /// Waker of the waiting task. `None` after `State::maybe_wake` has taken it.
    waker: Option<Waker>,
    /// The entry queued after this one, or `None` if this is the tail.
    next: Option<NonNull<Node>>,
    /// The entry queued before this one, or `None` if this is the head.
    prev: Option<NonNull<Node>>,
    /// Makes `Node`, and therefore `Option<Node>`, `!Unpin`.
    ///
    /// Neighbouring nodes and `State` hold raw pointers into this node and
    /// write through them. Meanwhile the owning `Acquire` reaches the node
    /// through `&mut` references produced by pin projection.
    ///
    /// A `&mut` to an `Unpin` type asserts uniqueness, which would invalidate
    /// those raw pointers: undefined behaviour under Rust's aliasing model,
    /// and Miri reports it. `&mut` references to `!Unpin` types are exempt
    /// from that uniqueness assertion.
    _pin: PhantomPinned,
}
impl Node {
    /// Unlinks this node from the queue in `state`.
    ///
    /// Its neighbours (or `state.head` / `state.tail`) are re-linked around
    /// it, and its own links are cleared.
    ///
    /// # Safety
    ///
    /// - `self` must currently be linked into the queue described by `state`.
    /// - Every node in that queue must still be alive at its recorded address.
    /// - The caller must hold the semaphore's mutex.
    unsafe fn dequeue(&mut self, state: &mut State) {
        let self_ptr = NonNull::from(&mut *self);
        // Point the predecessor (or the queue head) past this node.
        match &mut self.prev {
            Some(prev) => {
                debug_assert_eq!(prev.as_mut().next, Some(self_ptr));
                prev.as_mut().next = self.next;
            }
            None => {
                debug_assert_eq!(state.head, Some(self_ptr));
                state.head = self.next;
            }
        }
        // Point the successor (or the queue tail) back past this node.
        match &mut self.next {
            Some(next) => {
                debug_assert_eq!(next.as_mut().prev, Some(self_ptr));
                next.as_mut().prev = self.prev;
            }
            None => {
                debug_assert_eq!(state.tail, Some(self_ptr));
                state.tail = self.prev;
            }
        }
        self.prev = None;
        self.next = None;
    }

    /// Appends this node at the tail of the queue in `state`.
    ///
    /// # Safety
    ///
    /// - `self` must not already be linked into any queue.
    /// - `self` must stay at its current address until it is removed with
    ///   `Node::dequeue`.
    /// - The caller must hold the semaphore's mutex.
    unsafe fn push_back(&mut self, state: &mut State) {
        debug_assert_eq!(self.next, None);
        debug_assert_eq!(self.prev, None);
        match &mut state.tail {
            Some(tail) => {
                debug_assert_eq!(tail.as_mut().next, None);
                tail.as_mut().next = Some(self.into());
                self.prev = state.tail;
            }
            None => state.head = Some(self.into()),
        }
        state.tail = Some(self.into());
    }
}
// SAFETY: `State` holds raw pointers into `Node`s owned by `Acquire` futures.
// It is only reachable through the `Mutex` inside `DeficitSemaphore`, and those
// pointers are only dereferenced while that mutex is held.
unsafe impl Sync for State {}
unsafe impl Send for State {}
/// Future returned by [`DeficitSemaphore::acquire`]; resolves to a [`Permit`].
///
/// If the future cannot be satisfied immediately, it enqueues an entry stored
/// inline in itself into the semaphore's FIFO wait queue. The future is
/// `!Unpin` so that entry keeps a fixed address while queued. Dropping the
/// future removes the entry again.
#[pin_project(PinnedDrop)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Acquire {
    semaphore: Arc<DeficitSemaphore>,
    /// This future's queue entry.
    ///
    /// `Some` while the future is linked into the wait queue. It is unlinked
    /// and reset to `None` before `poll` returns `Ready` from the queue.
    node: Option<Node>,
    #[pin]
    _p: PhantomPinned,
}
// SAFETY: the raw pointers inside `node` are only dereferenced while the
// semaphore's mutex is held. `Acquire` exposes no `&self` methods.
unsafe impl Sync for Acquire {}
unsafe impl Send for Acquire {}
#[pinned_drop]
impl PinnedDrop for Acquire {
    // Removes this future from the wait queue if it is still enqueued.
    //
    // If it was the head, it may already have consumed the wake-up meant for
    // the head. Forward that wake-up to the new head when permits are
    // available.
    fn drop(self: Pin<&mut Self>) {
        let this = self.project();
        if let Some(node) = this.node {
            let mut state = this.semaphore.0.lock();
            let is_head = node.prev.is_none();
            // SAFETY: `node` is `Some`, so it is linked into this semaphore's
            // queue. It is pinned inside `self`, and we hold the mutex.
            unsafe {
                node.dequeue(&mut state);
            }
            if is_head {
                state.maybe_wake();
            }
        }
    }
}
impl Future for Acquire {
    type Output = Permit;

    /// Tries to take a permit, enqueueing this future in FIFO order if it has
    /// to wait.
    ///
    /// A permit is granted straight away only when nobody is queued.
    /// Otherwise only the head of the queue may take one. After taking a
    /// permit from the queue, the next waiter is woken if permits remain.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let mut state = this.semaphore.0.lock();

        // Fast path: nobody is queued, so taking a free permit overtakes no one.
        if state.head.is_none() && state.has_permits() {
            state.outstanding_permits += 1;
            return Poll::Ready(Permit {
                semaphore: this.semaphore.clone(),
            });
        }

        let node = match this.node {
            Some(node) => node,
            None => {
                *this.node = Some(Node {
                    waker: None,
                    next: None,
                    prev: None,
                    _pin: PhantomPinned,
                });
                let node = this.node.as_mut().unwrap();
                // SAFETY: the node was just created, so it is not linked
                // anywhere. It lives inside this pinned future, which unlinks
                // it on drop, and we hold the mutex.
                unsafe {
                    node.push_back(&mut state);
                }
                node
            }
        };

        if node.prev.is_none() && state.has_permits() {
            // SAFETY: `node` is `Some`, so it is linked into this semaphore's
            // queue, and we hold the mutex.
            unsafe {
                node.dequeue(&mut state);
            }
            *this.node = None;
            state.outstanding_permits += 1;
            let permit = Permit {
                semaphore: this.semaphore.clone(),
            };
            // Pass the wake-up along in case more permits are available.
            state.maybe_wake();
            Poll::Ready(permit)
        } else {
            // Store (or refresh) the waker; skip the clone if it is unchanged.
            let up_to_date = matches!(&node.waker, Some(w) if w.will_wake(cx.waker()));
            if !up_to_date {
                node.waker = Some(cx.waker().clone());
            }
            Poll::Pending
        }
    }
}
pub struct Permit {
semaphore: Arc<DeficitSemaphore>,
}
impl Drop for Permit {
fn drop(&mut self) {
let mut state = self.semaphore.0.lock();
state.outstanding_permits -= 1;
state.maybe_wake();
}
}
#[cfg(test)]
mod test {
    use super::*;
    use futures::pin_mut;
    use futures_test::task;

    /// Unwraps `Poll::Ready`, panicking at the caller's location on `Pending`.
    #[track_caller]
    fn assert_ready<T>(poll: Poll<T>) -> T {
        match poll {
            Poll::Ready(value) => value,
            Poll::Pending => panic!("expected ready"),
        }
    }

    /// Panics at the caller's location unless `poll` is `Poll::Pending`.
    #[track_caller]
    fn assert_pending<T>(poll: Poll<T>) {
        match poll {
            Poll::Pending => {}
            Poll::Ready(_) => panic!("expected pending"),
        }
    }

    #[test]
    fn uncontended_acquire() {
        let semaphore = DeficitSemaphore::new(2);
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        let _permit = assert_ready(fut.poll(&mut task::panic_context()));
        let fut = semaphore.acquire();
        pin_mut!(fut);
        let _permit = assert_ready(fut.poll(&mut task::panic_context()));
    }

    #[test]
    fn single_queued_acquire() {
        let (waker, count) = task::new_count_waker();
        let mut count_cx = Context::from_waker(&waker);
        let semaphore = DeficitSemaphore::new(1);
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        let permit = assert_ready(fut.poll(&mut task::panic_context()));
        let fut = semaphore.acquire();
        pin_mut!(fut);
        assert_pending(fut.as_mut().poll(&mut count_cx));
        assert_eq!(count, 0);
        drop(permit);
        assert_eq!(count, 1);
        assert_ready(fut.poll(&mut count_cx));
    }

    #[test]
    fn acquires_are_ordered() {
        let (waker1, count1) = task::new_count_waker();
        let mut count_cx1 = Context::from_waker(&waker1);
        let (waker2, count2) = task::new_count_waker();
        let mut count_cx2 = Context::from_waker(&waker2);
        let semaphore = DeficitSemaphore::new(1);
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        let permit = assert_ready(fut.poll(&mut task::panic_context()));
        let fut1 = semaphore.clone().acquire();
        pin_mut!(fut1);
        assert_pending(fut1.as_mut().poll(&mut count_cx1));
        let fut2 = semaphore.acquire();
        pin_mut!(fut2);
        assert_pending(fut2.as_mut().poll(&mut count_cx2));
        assert_eq!(count1, 0);
        assert_eq!(count2, 0);
        drop(permit);
        assert_eq!(count1, 1);
        assert_eq!(count2, 0);
        assert_pending(fut2.as_mut().poll(&mut count_cx2));
        // fut1 is polled with its own context, not fut2's.
        let permit = assert_ready(fut1.poll(&mut count_cx1));
        assert_pending(fut2.as_mut().poll(&mut count_cx2));
        assert_eq!(count1, 1);
        assert_eq!(count2, 0);
        drop(permit);
        assert_eq!(count1, 1);
        assert_eq!(count2, 1);
        assert_ready(fut2.poll(&mut count_cx2));
    }

    #[test]
    fn wakes_chain_on_acquire() {
        let (waker1, count1) = task::new_count_waker();
        let mut count_cx1 = Context::from_waker(&waker1);
        let (waker2, count2) = task::new_count_waker();
        let mut count_cx2 = Context::from_waker(&waker2);
        let semaphore = DeficitSemaphore::new(2);
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        let permit1 = assert_ready(fut.poll(&mut task::panic_context()));
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        let permit2 = assert_ready(fut.poll(&mut task::panic_context()));
        let fut1 = semaphore.clone().acquire();
        pin_mut!(fut1);
        assert_pending(fut1.as_mut().poll(&mut count_cx1));
        let fut2 = semaphore.acquire();
        pin_mut!(fut2);
        assert_pending(fut2.as_mut().poll(&mut count_cx2));
        assert_eq!(count1, 0);
        assert_eq!(count2, 0);
        drop(permit1);
        drop(permit2);
        assert_eq!(count1, 1);
        assert_eq!(count2, 0);
        let _permit = assert_ready(fut1.poll(&mut count_cx1));
        assert_eq!(count1, 1);
        assert_eq!(count2, 1);
        assert_ready(fut2.poll(&mut count_cx2));
    }

    #[test]
    fn early_head_drop() {
        let (waker1, count1) = task::new_count_waker();
        let mut count_cx1 = Context::from_waker(&waker1);
        let (waker2, count2) = task::new_count_waker();
        let mut count_cx2 = Context::from_waker(&waker2);
        let semaphore = DeficitSemaphore::new(1);
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        let permit = assert_ready(fut.poll(&mut task::panic_context()));
        let mut fut1 = Box::pin(semaphore.clone().acquire());
        assert_pending(fut1.as_mut().poll(&mut count_cx1));
        let fut2 = semaphore.acquire();
        pin_mut!(fut2);
        assert_pending(fut2.as_mut().poll(&mut count_cx2));
        assert_eq!(count1, 0);
        assert_eq!(count2, 0);
        drop(permit);
        assert_eq!(count1, 1);
        assert_eq!(count2, 0);
        drop(fut1);
        assert_eq!(count1, 1);
        assert_eq!(count2, 1);
        assert_ready(fut2.poll(&mut count_cx2));
    }

    #[test]
    fn early_middle_drop() {
        let (waker1, count1) = task::new_count_waker();
        let mut count_cx1 = Context::from_waker(&waker1);
        let (waker2, count2) = task::new_count_waker();
        let mut count_cx2 = Context::from_waker(&waker2);
        let (waker3, count3) = task::new_count_waker();
        let mut count_cx3 = Context::from_waker(&waker3);
        let semaphore = DeficitSemaphore::new(1);
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        let permit = assert_ready(fut.poll(&mut task::panic_context()));
        let fut1 = semaphore.clone().acquire();
        pin_mut!(fut1);
        assert_pending(fut1.as_mut().poll(&mut count_cx1));
        let mut fut2 = Box::pin(semaphore.clone().acquire());
        assert_pending(fut2.as_mut().poll(&mut count_cx2));
        let fut3 = semaphore.acquire();
        pin_mut!(fut3);
        assert_pending(fut3.as_mut().poll(&mut count_cx3));
        drop(fut2);
        assert_eq!(count1, 0);
        assert_eq!(count2, 0);
        assert_eq!(count3, 0);
        drop(permit);
        assert_eq!(count1, 1);
        assert_eq!(count2, 0);
        assert_eq!(count3, 0);
        assert_ready(fut1.poll(&mut count_cx1));
        assert_eq!(count1, 1);
        assert_eq!(count2, 0);
        assert_eq!(count3, 1);
        assert_ready(fut3.poll(&mut count_cx3));
    }

    #[test]
    fn early_tail_drop() {
        let (waker1, count1) = task::new_count_waker();
        let mut count_cx1 = Context::from_waker(&waker1);
        let (waker2, count2) = task::new_count_waker();
        let mut count_cx2 = Context::from_waker(&waker2);
        let semaphore = DeficitSemaphore::new(1);
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        let permit = assert_ready(fut.poll(&mut task::panic_context()));
        let fut1 = semaphore.clone().acquire();
        pin_mut!(fut1);
        assert_pending(fut1.as_mut().poll(&mut count_cx1));
        let mut fut2 = Box::pin(semaphore.acquire());
        assert_pending(fut2.as_mut().poll(&mut count_cx2));
        drop(fut2);
        assert_eq!(count1, 0);
        assert_eq!(count2, 0);
        drop(permit);
        assert_eq!(count1, 1);
        assert_eq!(count2, 0);
        assert_ready(fut1.poll(&mut count_cx1));
        assert_eq!(count1, 1);
        assert_eq!(count2, 0);
    }

    #[test]
    fn add_permits() {
        let (waker, count) = task::new_count_waker();
        let mut count_cx = Context::from_waker(&waker);
        let semaphore = DeficitSemaphore::new(0);
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        assert_pending(fut.as_mut().poll(&mut count_cx));
        assert_eq!(count, 0);
        semaphore.add_permits(1);
        assert_eq!(count, 1);
        assert_ready(fut.poll(&mut count_cx));
    }

    #[test]
    fn remove_permits() {
        let semaphore = DeficitSemaphore::new(1);
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        let permit = assert_ready(fut.poll(&mut task::panic_context()));
        let fut = semaphore.clone().acquire();
        pin_mut!(fut);
        assert_pending(fut.as_mut().poll(&mut task::panic_context()));
        semaphore.remove_permits(1);
        drop(permit);
        assert_pending(fut.poll(&mut task::panic_context()));
    }

    #[tokio::test]
    async fn stress_test() {
        let semaphore = DeficitSemaphore::new(10);
        let mut handles = vec![];
        for _ in 0..100 {
            let handle = tokio::spawn({
                let semaphore = semaphore.clone();
                async move {
                    for _ in 0..1000 {
                        let _permit = semaphore.clone().acquire().await;
                        tokio::task::yield_now().await;
                    }
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await.unwrap();
        }
    }
}