use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use std::thread::{self, Thread};
use parking_lot::Mutex;
/// Waiter state: the future has not (yet) been selected by `CapacityGate::release`.
/// This is the value a fresh `AcquireFuture` starts with.
pub(crate) const STATE_WAITING: u8 = 0;
/// Waiter state: `CapacityGate::release` picked this future and woke its task.
pub(crate) const STATE_SUCCESS: u8 = 1;
/// Waiter state: the future was dropped while still waiting, so `release` must
/// skip its queue entry instead of waking it.
pub(crate) const STATE_CANCELLED: u8 = 2;
/// One entry in the gate's wait queue: either a parked OS thread or a pending
/// `AcquireFuture`.
#[derive(Debug)]
enum Waiter {
/// A blocked thread; it is resumed with `Thread::unpark`.
Sync(Thread),
/// A pending `AcquireFuture`.
Async {
/// Waker taken from the most recent poll of the future.
waker: Waker,
/// Address of the future's `state` field. It serves as the entry's identity
/// (poll and Drop compare against it) and is dereferenced by
/// `CapacityGate::release`.
state: *const AtomicU8,
},
}
// SAFETY: `Thread` and `Waker` are already `Send`; only the raw pointer in
// `Waiter::Async` blocks the automatic impl. The pointee is an `AtomicU8`, which
// may be accessed from any thread, and within this file the pointer is only
// dereferenced by code that holds the queue lock.
// NOTE(review): soundness relies on every queued entry being removed before its
// `AcquireFuture` is freed — re-check this whenever poll/Drop/release change.
unsafe impl Send for Waiter {}
impl Waiter {
    /// Consumes the entry and resumes whoever it stands for: the stored waker is
    /// invoked for an async waiter, the thread is unparked for a sync one.
    fn wake(self) {
        match self {
            Waiter::Async { waker, .. } => waker.wake(),
            Waiter::Sync(parked) => parked.unpark(),
        }
    }

    /// Reports whether this entry is an async waiter whose stored waker would wake
    /// the same task as `waker`. Sync waiters always yield `false`.
    fn will_wake(&self, waker: &Waker) -> bool {
        if let Waiter::Async { waker: stored, .. } = self {
            stored.will_wake(waker)
        } else {
            false
        }
    }
}
/// Mutable gate state; every access goes through the mutex in `CapacityGate`.
#[derive(Debug)]
struct GateInternal {
/// Number of permits that are currently free to be taken.
permits: usize,
/// Acquirers waiting for a permit; `release` serves them from the front.
waiters: VecDeque<Waiter>,
}
/// A counting gate that hands out up to `capacity` permits to blocking and
/// async acquirers. Clones share the same underlying state.
pub struct CapacityGate {
/// Permit count the gate was created with; `release` uses it as the upper
/// bound for the free-permit counter when no waiter is woken.
capacity: usize,
/// Shared, lock-protected permit counter and wait queue.
internal: Arc<Mutex<GateInternal>>,
}
impl fmt::Debug for CapacityGate {
    /// Renders the capacity, the number of free permits and the length of the
    /// wait queue (individual waiters are not printed).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Take a consistent snapshot under the lock, then format from the copies.
        let (free_permits, queued) = {
            let guard = self.internal.lock();
            (guard.permits, guard.waiters.len())
        };
        let mut out = f.debug_struct("CapacityGate");
        out.field("capacity", &self.capacity);
        out.field("permits", &free_permits);
        out.field("waiters", &queued);
        out.finish()
    }
}
impl CapacityGate {
pub fn new(capacity: usize) -> Self {
Self {
capacity,
internal: Arc::new(Mutex::new(GateInternal {
permits: capacity,
waiters: VecDeque::new(),
})),
}
}
pub fn capacity(&self) -> usize {
self.capacity
}
pub fn acquire_sync(&self) {
if self.try_acquire() {
return;
}
let mut internal = self.internal.lock();
loop {
if internal.permits > 0 {
internal.permits -= 1;
return;
}
internal.waiters.push_back(Waiter::Sync(thread::current()));
drop(internal);
thread::park();
internal = self.internal.lock();
}
}
pub fn acquire_async(&self) -> AcquireFuture<'_> {
AcquireFuture {
gate: self,
state: AtomicU8::new(STATE_WAITING),
is_registered: false,
_phantom: PhantomPinned,
}
}
pub fn try_acquire(&self) -> bool {
let mut internal = self.internal.lock();
if internal.waiters.is_empty() && internal.permits > 0 {
internal.permits -= 1;
true
} else {
false
}
}
pub fn close(&self) {
let mut internal = self.internal.lock();
while let Some(waiter) = internal.waiters.pop_front() {
waiter.wake();
}
}
pub fn release(&self) {
let mut internal = self.internal.lock();
internal.permits += 1;
while let Some(waiter) = internal.waiters.pop_front() {
match waiter {
Waiter::Sync(thread) => {
thread.unpark();
return;
}
Waiter::Async { waker, state } => {
let state_ref = unsafe { &*state };
if state_ref
.compare_exchange(
STATE_WAITING,
STATE_SUCCESS,
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_ok()
{
waker.wake();
return;
}
}
}
}
internal.permits = internal.permits.min(self.capacity);
}
}
/// Future returned by `CapacityGate::acquire_async`; it resolves to `()` once a
/// permit has been obtained.
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct AcquireFuture<'a> {
/// Gate this future acquires from.
gate: &'a CapacityGate,
/// One of the `STATE_*` constants. While the future is queued, the gate holds
/// a raw pointer to this field.
state: AtomicU8,
/// Set when `poll` queues this future and cleared when `poll` returns
/// `Ready`; `Drop` only touches the queue when it is set.
is_registered: bool,
/// Makes the type `!Unpin` so the address of `state` stays stable once pinned.
_phantom: PhantomPinned,
}
impl<'a> Future for AcquireFuture<'a> {
    type Output = ();

    /// Tries to take a permit; queues the future (or refreshes its waker) when
    /// none is free.
    ///
    /// The permit is always taken from the gate's counter while the lock is held.
    /// A `STATE_SUCCESS` flag set by `release` is treated purely as "you were
    /// woken, try again": completing on the flag alone would leave the counter
    /// untouched and let more than `capacity` holders through.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is moved out of the pinned future; only `is_registered`
        // and the atomic `state` are written in place.
        let this = unsafe { self.get_unchecked_mut() };
        let state_ptr: *const AtomicU8 = &this.state;
        let mut internal = this.gate.internal.lock();

        if internal.permits > 0 {
            internal.permits -= 1;
            if this.is_registered {
                // A spurious re-poll can get here while our entry is still
                // queued. Remove it now: once `is_registered` is cleared, Drop
                // will not, and a leftover entry would dangle after this future
                // is freed.
                internal.waiters.retain(
                    |w| !matches!(w, Waiter::Async { state, .. } if *state == state_ptr),
                );
                this.is_registered = false;
            }
            return Poll::Ready(());
        }

        // No permit available: make sure exactly one live entry exists for us.
        let mut already_queued = false;
        for waiter in internal.waiters.iter_mut() {
            if let Waiter::Async { state, waker } = waiter {
                if *state == state_ptr {
                    // Only replace the waker when the task actually changed.
                    if !waker.will_wake(cx.waker()) {
                        *waker = cx.waker().clone();
                    }
                    already_queued = true;
                    break;
                }
            }
        }

        if !already_queued {
            // First registration, or we were woken (and popped) but the permit
            // was taken by someone else first: re-arm the state and queue again.
            this.state.store(STATE_WAITING, Ordering::SeqCst);
            this.is_registered = true;
            internal.waiters.push_back(Waiter::Async {
                waker: cx.waker().clone(),
                state: state_ptr,
            });
        }

        Poll::Pending
    }
}
impl<'a> Drop for AcquireFuture<'a> {
    /// Detaches a still-pending future from its gate.
    ///
    /// * Still waiting: mark the state `STATE_CANCELLED` and remove the queue
    ///   entry so the gate never dereferences a pointer into freed memory.
    /// * Already selected by `release` (`STATE_SUCCESS`) but dropped before a
    ///   later poll could act on it: pass the wake-up on to the next live waiter,
    ///   otherwise the remaining waiters would sleep although a permit is free.
    fn drop(&mut self) {
        if !self.is_registered {
            // Never queued, or already completed: the gate holds nothing of ours.
            return;
        }

        let own_state: *const AtomicU8 = &self.state;
        let successor = {
            let mut internal = self.gate.internal.lock();
            // `release` only changes the state while holding this lock, so the
            // outcome of the exchange cannot change underneath us.
            let outcome = self.state.compare_exchange(
                STATE_WAITING,
                STATE_CANCELLED,
                Ordering::SeqCst,
                Ordering::SeqCst,
            );
            let picked = match outcome {
                Ok(_) => {
                    internal.waiters.retain(
                        |w| !matches!(w, Waiter::Async { state, .. } if *state == own_state),
                    );
                    None
                }
                Err(STATE_SUCCESS) if internal.permits > 0 => {
                    let mut next = None;
                    while let Some(waiter) = internal.waiters.pop_front() {
                        let live = match &waiter {
                            Waiter::Sync(_) => true,
                            Waiter::Async { state, .. } => {
                                // SAFETY: the entry was still queued and the lock
                                // is held; futures dequeue themselves under this
                                // lock before they are freed.
                                let peer = unsafe { &**state };
                                peer.compare_exchange(
                                    STATE_WAITING,
                                    STATE_SUCCESS,
                                    Ordering::SeqCst,
                                    Ordering::SeqCst,
                                )
                                .is_ok()
                            }
                        };
                        if live {
                            next = Some(waiter);
                            break;
                        }
                    }
                    next
                }
                Err(_) => None,
            };
            picked
        };

        // Wake with the lock released.
        if let Some(waiter) = successor {
            waiter.wake();
        }
    }
}
impl Clone for CapacityGate {
fn clone(&self) -> Self {
Self {
capacity: self.capacity,
internal: self.internal.clone(),
}
}
}
// Unit and integration-style tests for `CapacityGate` and `AcquireFuture`.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
// `capacity()` reports the value passed to `new`.
#[test]
fn gate_new_and_capacity() {
let gate = CapacityGate::new(5);
assert_eq!(gate.capacity(), 5);
}
// Uncontended blocking acquire followed by a release must not hang or panic.
#[test]
fn acquire_sync_release() {
let gate = CapacityGate::new(1);
gate.acquire_sync();
gate.release();
}
// With the only permit taken, a second `acquire_sync` on another thread must
// block until `release` is called.
#[test]
fn acquire_sync_blocks_and_unblocks() {
let gate = Arc::new(CapacityGate::new(1));
gate.acquire_sync();
let gate_clone = gate.clone();
let handle = thread::spawn(move || {
gate_clone.acquire_sync();
});
// Timing-based check: give the spawned thread time to reach the blocking call.
thread::sleep(Duration::from_millis(100));
assert!(!handle.is_finished(), "Thread should have blocked");
gate.release();
handle.join().expect("Thread panicked");
}
// An async acquire on an exhausted gate completes after a release issued from
// another task; the timeout turns a missed wake-up into a test failure.
#[cfg(not(miri))]
#[tokio::test]
async fn acquire_async_waits_and_completes() {
use tokio::time::timeout;
let gate = Arc::new(CapacityGate::new(1));
gate.acquire_sync();
let acquire_fut = gate.acquire_async();
let gate_for_spawn = gate.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(100)).await;
gate_for_spawn.release();
});
timeout(Duration::from_millis(500), acquire_fut)
.await
.expect("Future did not complete after release");
}
// Three threads and three tasks contend for two permits; each acquires, holds
// briefly, releases, and bumps a counter. All six must finish.
#[cfg(not(miri))]
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn mixed_waiters_contention() {
let gate = Arc::new(CapacityGate::new(2));
let mut thread_handles = Vec::new();
let mut task_handles = Vec::new();
let completion_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
for _ in 0..3 {
let gate = gate.clone();
let count = completion_count.clone();
thread_handles.push(thread::spawn(move || {
gate.acquire_sync();
thread::sleep(Duration::from_millis(50));
gate.release();
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}));
}
for _ in 0..3 {
let gate = gate.clone();
let count = completion_count.clone();
task_handles.push(tokio::spawn(async move {
gate.acquire_async().await;
tokio::time::sleep(Duration::from_millis(50)).await;
gate.release();
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}));
}
for handle in task_handles {
handle.await.unwrap();
}
for handle in thread_handles {
handle.join().unwrap();
}
assert_eq!(
completion_count.load(std::sync::atomic::Ordering::Relaxed),
6
);
}
// Dropping a pending `AcquireFuture` must remove its entry from the wait queue.
#[test]
fn test_acquire_async_drop_leak() {
let gate = CapacityGate::new(1);
assert!(gate.try_acquire());
// Builds a waker whose clone/wake/drop hooks all do nothing, so the future
// can be polled by hand without an executor.
fn dummy_waker() -> Waker {
use std::task::{RawWaker, RawWakerVTable};
unsafe fn clone(_: *const ()) -> RawWaker {
RawWaker::new(std::ptr::null(), &VTABLE)
}
unsafe fn wake(_: *const ()) {}
unsafe fn wake_by_ref(_: *const ()) {}
unsafe fn drop_raw(_: *const ()) {}
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_raw);
unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}
let waker = dummy_waker();
let mut cx = Context::from_waker(&waker);
let mut fut = Box::pin(gate.acquire_async());
// The single permit is taken, so the first poll must queue the future.
assert!(fut.as_mut().poll(&mut cx).is_pending());
{
let internal = gate.internal.lock();
assert_eq!(internal.waiters.len(), 1);
}
drop(fut);
let leaked_count = {
let internal = gate.internal.lock();
internal.waiters.len()
};
assert_eq!(
leaked_count, 0,
"Waker leak detected: dropping AcquireFuture left a stale waker in the waiters queue!"
);
}
}