use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use std::thread::{self, Thread};
use parking_lot::Mutex;
/// Async waiter state: queued in the gate and not yet granted a permit.
pub(crate) const STATE_WAITING: u8 = 0;
/// Async waiter state: a releaser popped the waiter and transferred one permit to it.
pub(crate) const STATE_SUCCESS: u8 = 1;
/// Async waiter state: the future was dropped while still waiting, so releasers skip it.
pub(crate) const STATE_CANCELLED: u8 = 2;
/// An entry in the gate's FIFO wait queue.
#[derive(Debug)]
enum Waiter {
    /// A thread blocked in a `*_sync` acquire; woken with `Thread::unpark`.
    Sync(Thread),
    /// A pending acquire future.
    Async {
        /// Waker of the task that last polled the future.
        waker: Waker,
        /// Points at the future's own `state` atomic. Releasers CAS it from
        /// `STATE_WAITING` to `STATE_SUCCESS` to hand the future a permit.
        state: *const AtomicU8,
    },
}
// SAFETY: the raw `state` pointer is only dereferenced by releasers while they hold the
// gate mutex. The pointee stays alive for that window: a dropping future either removes
// its entry under the same mutex, or finds the entry already popped (state no longer
// `STATE_WAITING`). The `Pin` + `PhantomPinned` futures never move after registering.
unsafe impl Send for Waiter {}
impl Waiter {
fn wake(self) {
match self {
Waiter::Sync(thread) => thread.unpark(),
Waiter::Async { waker, .. } => waker.wake(),
}
}
fn will_wake(&self, waker: &Waker) -> bool {
match self {
Waiter::Async { waker: self_waker, .. } => self_waker.will_wake(waker),
Waiter::Sync(_) => false,
}
}
}
/// Mutable gate state, always accessed under the gate's mutex.
#[derive(Debug)]
struct GateInternal {
    /// Permits currently free for any acquirer. Clamped to the capacity on release.
    permits: usize,
    /// Permits a releaser set aside for `Waiter::Sync` entries it popped. The woken
    /// threads have not yet claimed them.
    reserved: usize,
    /// FIFO queue of blocked acquirers; releasers pop from the front.
    waiters: VecDeque<Waiter>,
    /// Once set by `close`, every acquire returns immediately.
    is_closed: bool,
}
/// A counting gate that lets at most `capacity` holders through at once.
///
/// Permits can be acquired from threads (`acquire_sync`, `acquire_many_sync`) or from
/// async tasks (`acquire_async`, `acquire_many_async`), and are returned with `release`
/// or `release_many`. Clones share the same underlying state.
pub struct CapacityGate {
    /// Maximum number of permits; also the clamp applied to free permits on release.
    capacity: usize,
    /// Shared state; all clones of the gate point at the same mutex.
    internal: Arc<Mutex<GateInternal>>,
}
impl fmt::Debug for CapacityGate {
    /// Prints a consistent snapshot of the gate's counters, taken under its lock.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let internal = self.internal.lock();
        f.debug_struct("CapacityGate")
            .field("capacity", &self.capacity)
            .field("permits", &internal.permits)
            .field("reserved", &internal.reserved)
            .field("waiters", &internal.waiters.len())
            // Closed gates admit everyone; without this the counters are misleading.
            .field("is_closed", &internal.is_closed)
            .finish()
    }
}
impl CapacityGate {
    /// Creates an open gate holding `capacity` free permits.
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity,
            internal: Arc::new(Mutex::new(GateInternal {
                permits: capacity,
                reserved: 0,
                waiters: VecDeque::new(),
                is_closed: false,
            })),
        }
    }

    /// Returns the capacity the gate was created with.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Blocks the calling thread until it holds one permit.
    ///
    /// Returns immediately (without consuming anything) once the gate is closed.
    pub fn acquire_sync(&self) {
        // Acquiring one permit is exactly the `max == 1` case of the batch API, which
        // yields 1 on every path, so the count can be ignored.
        self.acquire_many_sync(1);
    }

    /// Returns a future that resolves once it holds one permit or the gate is closed.
    pub fn acquire_async(&self) -> AcquireFuture<'_> {
        AcquireFuture {
            gate: self,
            state: AtomicU8::new(STATE_WAITING),
            is_registered: false,
            _phantom: PhantomPinned,
        }
    }

    /// Tries to take one permit without blocking.
    ///
    /// Returns `true` on success, and always `true` once the gate is closed. Returns
    /// `false` when no permit is free or other waiters are already queued.
    pub fn try_acquire(&self) -> bool {
        self.try_acquire_many(1) > 0
    }

    /// Closes the gate: wakes every queued waiter and lets all later acquires through.
    ///
    /// Calling it more than once is a no-op.
    pub fn close(&self) {
        let drained = {
            let mut internal = self.internal.lock();
            if internal.is_closed {
                return;
            }
            internal.is_closed = true;
            std::mem::take(&mut internal.waiters)
        };
        // Wake outside the (non-reentrant) lock, matching `release_many`: a waker that
        // re-polls inline must not deadlock on the gate.
        for waiter in drained {
            waiter.wake();
        }
    }

    /// Returns one permit, handing it directly to the oldest still-live waiter if any.
    pub fn release(&self) {
        // Identical bookkeeping to a batch of one; `release_many` also performs the
        // wake-up after dropping the lock.
        self.release_many(1);
    }

    /// Takes up to `n` free permits without blocking.
    ///
    /// Returns how many were taken: `0` when `n == 0`, when nothing is free, or when
    /// other waiters are queued (no barging); `n` when the gate is closed.
    pub fn try_acquire_many(&self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        let mut internal = self.internal.lock();
        if internal.is_closed {
            return n;
        }
        if internal.waiters.is_empty() && internal.permits > 0 {
            let k = internal.permits.min(n);
            internal.permits -= k;
            k
        } else {
            0
        }
    }

    /// Blocks until at least one permit is held, then takes up to `max` in total.
    ///
    /// Returns the number of permits taken (`1..=max`), `max` if the gate is closed,
    /// and `0` when `max == 0`.
    pub fn acquire_many_sync(&self, max: usize) -> usize {
        if max == 0 {
            return 0;
        }
        let k = self.try_acquire_many(max);
        if k > 0 {
            return k;
        }
        let me = thread::current();
        let mut internal = self.internal.lock();
        // True while this thread has pushed a `Waiter::Sync` entry that a releaser may
        // not have popped yet.
        let mut enqueued = false;
        loop {
            if internal.is_closed {
                return max;
            }
            if enqueued {
                let still_queued = internal
                    .waiters
                    .iter()
                    .any(|w| matches!(w, Waiter::Sync(t) if t.id() == me.id()));
                if still_queued {
                    // `thread::park` may return spuriously. Our entry has not been handed
                    // a permit, so any `reserved` permit belongs to another thread, and
                    // pushing a second entry would later leak a reservation. Keep our
                    // place in the queue and park again.
                    drop(internal);
                    thread::park();
                    internal = self.internal.lock();
                    continue;
                }
                // A releaser popped our entry and set one permit aside in `reserved`.
                enqueued = false;
                if internal.reserved > 0 {
                    internal.reserved -= 1;
                    let extra = internal.permits.min(max - 1);
                    internal.permits -= extra;
                    return 1 + extra;
                }
            }
            if internal.permits > 0 {
                let k = internal.permits.min(max);
                internal.permits -= k;
                return k;
            }
            internal.waiters.push_back(Waiter::Sync(me.clone()));
            enqueued = true;
            drop(internal);
            thread::park();
            internal = self.internal.lock();
        }
    }

    /// Returns a future that resolves to the number of permits taken (`1..=max`), to
    /// `max` if the gate is closed, or to `0` when `max == 0`.
    pub fn acquire_many_async(&self, max: usize) -> AcquireManyFuture<'_> {
        AcquireManyFuture {
            gate: self,
            max,
            state: AtomicU8::new(STATE_WAITING),
            is_registered: false,
            _phantom: PhantomPinned,
        }
    }

    /// Returns `n` permits to the gate.
    ///
    /// Hands one permit each to up to `n` queued waiters, skipping cancelled futures.
    /// Free permits are clamped to `capacity`, and the chosen waiters are woken after
    /// the lock is released.
    pub fn release_many(&self, n: usize) {
        if n == 0 {
            return;
        }
        let mut to_wake: Vec<Waiter> = Vec::new();
        {
            let mut internal = self.internal.lock();
            internal.permits += n;
            while to_wake.len() < n {
                match internal.waiters.pop_front() {
                    None => break,
                    Some(Waiter::Sync(thread)) => {
                        // Park the permit in `reserved`; the thread claims it on wake-up.
                        internal.permits -= 1;
                        internal.reserved += 1;
                        to_wake.push(Waiter::Sync(thread));
                    }
                    Some(Waiter::Async { waker, state }) => {
                        // SAFETY: the future owning `state` is alive while its entry is
                        // queued, and we hold the gate lock (see `Waiter`'s Send impl).
                        let state_ref = unsafe { &*state };
                        if state_ref
                            .compare_exchange(
                                STATE_WAITING,
                                STATE_SUCCESS,
                                Ordering::SeqCst,
                                Ordering::SeqCst,
                            )
                            .is_ok()
                        {
                            internal.permits -= 1;
                            to_wake.push(Waiter::Async { waker, state });
                        }
                        // Otherwise the future was cancelled; drop its entry and move on.
                    }
                }
            }
            internal.permits = internal.permits.min(self.capacity);
        }
        for waiter in to_wake {
            waiter.wake();
        }
    }
}
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct AcquireFuture<'a> {
gate: &'a CapacityGate,
state: AtomicU8,
is_registered: bool,
_phantom: PhantomPinned,
}
impl<'a> Future for AcquireFuture<'a> {
    type Output = ();
    /// Resolves once this future holds a permit (taken directly or handed off by a
    /// release) or the gate is closed; otherwise queues or refreshes its waker.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is moved out of `this`. The gate's queue stores a raw pointer
        // to `this.state`, which stays put because the future is `!Unpin` and pinned.
        let this = unsafe { self.as_mut().get_unchecked_mut() };
        // Lock-free fast path: a releaser already transferred a permit to us.
        if this.state.load(Ordering::Acquire) == STATE_SUCCESS {
            this.is_registered = false;
            return Poll::Ready(());
        }
        let mut internal = this.gate.internal.lock();
        let state_ptr = &this.state as *const AtomicU8;
        if internal.is_closed {
            this.is_registered = false;
            return Poll::Ready(());
        }
        // Re-check under the lock: a handoff may have landed before we acquired it.
        if this.state.load(Ordering::Acquire) == STATE_SUCCESS {
            this.is_registered = false;
            return Poll::Ready(());
        }
        if internal.permits > 0 {
            internal.permits -= 1;
            if this.is_registered {
                internal.waiters.retain(|w| match w {
                    Waiter::Async { state, .. } => *state != state_ptr,
                    _ => true,
                });
                this.is_registered = false;
            }
            return Poll::Ready(());
        }
        let new_waker = cx.waker();
        let position = internal
            .waiters
            .iter()
            .position(|w| matches!(w, Waiter::Async { state, .. } if *state == state_ptr));
        match position {
            Some(idx) => {
                // Already queued: only replace the waker when it would wake a different
                // task, avoiding a clone on every re-poll.
                if !internal.waiters[idx].will_wake(new_waker) {
                    internal.waiters[idx] = Waiter::Async {
                        waker: new_waker.clone(),
                        state: state_ptr,
                    };
                }
            }
            None => {
                this.is_registered = true;
                this.state.store(STATE_WAITING, Ordering::SeqCst);
                internal.waiters.push_back(Waiter::Async {
                    waker: new_waker.clone(),
                    state: state_ptr,
                });
            }
        }
        Poll::Pending
    }
}
impl<'a> Drop for AcquireFuture<'a> {
    fn drop(&mut self) {
        // Futures that never queued (or already completed) own nothing in the gate.
        if !self.is_registered {
            return;
        }
        match self.state.compare_exchange(
            STATE_WAITING,
            STATE_CANCELLED,
            Ordering::SeqCst,
            Ordering::SeqCst,
        ) {
            Ok(_) => {
                // Still waiting: releasers now skip us; also drop our queue entry.
                let mut internal = self.gate.internal.lock();
                let state_ptr = &self.state as *const AtomicU8;
                internal.waiters.retain(|w| match w {
                    Waiter::Async { state, .. } => *state != state_ptr,
                    _ => true,
                });
            }
            Err(STATE_SUCCESS) => {
                // A releaser already popped our entry and transferred one permit to us,
                // but we are dropped (e.g. by a timeout/select) before `poll` saw it.
                // Give it back, or the gate loses that permit forever.
                self.gate.release();
            }
            Err(_) => {}
        }
    }
}
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct AcquireManyFuture<'a> {
gate: &'a CapacityGate,
max: usize,
state: AtomicU8,
is_registered: bool,
_phantom: PhantomPinned,
}
impl<'a> Future for AcquireManyFuture<'a> {
    type Output = usize;
    /// Resolves to the number of permits taken (`1..=max`), `max` if the gate is closed,
    /// or `0` when `max == 0`; otherwise queues or refreshes its waker.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is moved out of `this`. The gate's queue stores a raw pointer
        // to `this.state`, which stays put because the future is `!Unpin` and pinned.
        let this = unsafe { self.as_mut().get_unchecked_mut() };
        if this.max == 0 {
            this.is_registered = false;
            return Poll::Ready(0);
        }
        let mut internal = this.gate.internal.lock();
        let state_ptr = &this.state as *const AtomicU8;
        // A releaser handed us exactly one permit; top up with whatever else is free.
        // Both branches need the lock, so a single check under it suffices.
        if this.state.load(Ordering::Acquire) == STATE_SUCCESS {
            let extra = internal.permits.min(this.max - 1);
            internal.permits -= extra;
            this.is_registered = false;
            return Poll::Ready(1 + extra);
        }
        if internal.is_closed {
            this.is_registered = false;
            return Poll::Ready(this.max);
        }
        if internal.permits > 0 {
            let k = internal.permits.min(this.max);
            internal.permits -= k;
            if this.is_registered {
                internal.waiters.retain(|w| match w {
                    Waiter::Async { state, .. } => *state != state_ptr,
                    _ => true,
                });
                this.is_registered = false;
            }
            return Poll::Ready(k);
        }
        let new_waker = cx.waker();
        let position = internal
            .waiters
            .iter()
            .position(|w| matches!(w, Waiter::Async { state, .. } if *state == state_ptr));
        match position {
            Some(idx) => {
                // Already queued: only replace the waker when it would wake a different
                // task, avoiding a clone on every re-poll.
                if !internal.waiters[idx].will_wake(new_waker) {
                    internal.waiters[idx] = Waiter::Async {
                        waker: new_waker.clone(),
                        state: state_ptr,
                    };
                }
            }
            None => {
                this.is_registered = true;
                this.state.store(STATE_WAITING, Ordering::SeqCst);
                internal.waiters.push_back(Waiter::Async {
                    waker: new_waker.clone(),
                    state: state_ptr,
                });
            }
        }
        Poll::Pending
    }
}
impl<'a> Drop for AcquireManyFuture<'a> {
    fn drop(&mut self) {
        // Futures that never queued (or already completed) own nothing in the gate.
        if !self.is_registered {
            return;
        }
        match self.state.compare_exchange(
            STATE_WAITING,
            STATE_CANCELLED,
            Ordering::SeqCst,
            Ordering::SeqCst,
        ) {
            Ok(_) => {
                // Still waiting: releasers now skip us; also drop our queue entry.
                let mut internal = self.gate.internal.lock();
                let state_ptr = &self.state as *const AtomicU8;
                internal.waiters.retain(|w| match w {
                    Waiter::Async { state, .. } => *state != state_ptr,
                    _ => true,
                });
            }
            Err(STATE_SUCCESS) => {
                // A releaser already transferred one permit to us (releasers hand each
                // async waiter exactly one), but `poll` never observed it. Return it so
                // the gate does not permanently lose capacity.
                self.gate.release();
            }
            Err(_) => {}
        }
    }
}
impl Clone for CapacityGate {
fn clone(&self) -> Self {
Self {
capacity: self.capacity,
internal: self.internal.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a waker whose operations all do nothing, for polling futures by hand.
    fn noop_waker() -> Waker {
        use std::task::{RawWaker, RawWakerVTable};
        unsafe fn clone(_: *const ()) -> RawWaker {
            RawWaker::new(std::ptr::null(), &VTABLE)
        }
        unsafe fn noop(_: *const ()) {}
        static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
        // SAFETY: every vtable entry ignores the data pointer, so null is acceptable.
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
    }

    #[test]
    fn gate_new_and_capacity() {
        let gate = CapacityGate::new(5);
        assert_eq!(gate.capacity(), 5);
    }

    #[test]
    fn acquire_sync_release() {
        let gate = CapacityGate::new(1);
        gate.acquire_sync();
        gate.release();
    }

    #[test]
    fn acquire_sync_blocks_and_unblocks() {
        let gate = Arc::new(CapacityGate::new(1));
        gate.acquire_sync();
        let gate_clone = gate.clone();
        let handle = thread::spawn(move || {
            gate_clone.acquire_sync();
        });
        thread::sleep(Duration::from_millis(100));
        assert!(!handle.is_finished(), "Thread should have blocked");
        gate.release();
        handle.join().expect("Thread panicked");
    }

    #[test]
    fn close_releases_parked_sync_waiter() {
        let gate = Arc::new(CapacityGate::new(0));
        let gate_clone = gate.clone();
        let waiter = thread::spawn(move || gate_clone.acquire_sync());
        thread::sleep(Duration::from_millis(100));
        assert!(!waiter.is_finished(), "waiter should be parked");
        gate.close();
        waiter.join().expect("waiter panicked");
        assert!(gate.try_acquire(), "a closed gate admits everyone");
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn acquire_async_waits_and_completes() {
        use tokio::time::timeout;
        let gate = Arc::new(CapacityGate::new(1));
        gate.acquire_sync();
        let acquire_fut = gate.acquire_async();
        let gate_for_spawn = gate.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            gate_for_spawn.release();
        });
        timeout(Duration::from_millis(500), acquire_fut)
            .await
            .expect("Future did not complete after release");
    }

    #[cfg(not(miri))]
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn mixed_waiters_contention() {
        let gate = Arc::new(CapacityGate::new(2));
        let mut thread_handles = Vec::new();
        let mut task_handles = Vec::new();
        let completion_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        for _ in 0..3 {
            let gate = gate.clone();
            let count = completion_count.clone();
            thread_handles.push(thread::spawn(move || {
                gate.acquire_sync();
                thread::sleep(Duration::from_millis(50));
                gate.release();
                count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            }));
        }
        for _ in 0..3 {
            let gate = gate.clone();
            let count = completion_count.clone();
            task_handles.push(tokio::spawn(async move {
                gate.acquire_async().await;
                tokio::time::sleep(Duration::from_millis(50)).await;
                gate.release();
                count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            }));
        }
        for handle in task_handles {
            handle.await.unwrap();
        }
        for handle in thread_handles {
            handle.join().unwrap();
        }
        assert_eq!(
            completion_count.load(std::sync::atomic::Ordering::Relaxed),
            6
        );
    }

    #[test]
    fn test_acquire_async_drop_leak() {
        let gate = CapacityGate::new(1);
        assert!(gate.try_acquire());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(gate.acquire_async());
        assert!(fut.as_mut().poll(&mut cx).is_pending());
        {
            let internal = gate.internal.lock();
            assert_eq!(internal.waiters.len(), 1);
        }
        drop(fut);
        let leaked_count = {
            let internal = gate.internal.lock();
            internal.waiters.len()
        };
        assert_eq!(
            leaked_count, 0,
            "Waker leak detected: dropping AcquireFuture left a stale waker in the waiters queue!"
        );
    }

    #[test]
    fn try_acquire_many_basic() {
        let gate = CapacityGate::new(5);
        assert_eq!(gate.try_acquire_many(0), 0);
        assert_eq!(gate.try_acquire_many(3), 3);
        assert_eq!(gate.try_acquire_many(5), 2);
        assert_eq!(gate.try_acquire_many(1), 0);
    }

    #[test]
    fn release_many_restores_and_clamps() {
        let gate = CapacityGate::new(5);
        assert_eq!(gate.try_acquire_many(5), 5);
        gate.release_many(5);
        assert_eq!(gate.try_acquire_many(5), 5);
        gate.release_many(10);
        assert_eq!(gate.try_acquire_many(10), 5);
    }

    #[test]
    fn many_apis_on_closed_gate() {
        let gate = CapacityGate::new(2);
        assert_eq!(gate.try_acquire_many(2), 2);
        gate.close();
        assert_eq!(gate.try_acquire_many(7), 7);
        assert_eq!(gate.acquire_many_sync(4), 4);
    }

    #[test]
    fn acquire_many_sync_blocks_and_gets_batch() {
        let gate = Arc::new(CapacityGate::new(4));
        assert_eq!(gate.try_acquire_many(4), 4);
        let gate_clone = gate.clone();
        let handle = thread::spawn(move || gate_clone.acquire_many_sync(3));
        thread::sleep(Duration::from_millis(100));
        assert!(!handle.is_finished(), "Thread should have blocked");
        gate.release_many(3);
        let acquired = handle.join().expect("Thread panicked");
        assert!(
            (1..=3).contains(&acquired),
            "expected 1..=3 permits, got {acquired}"
        );
    }

    #[test]
    fn release_many_wakes_multiple_sync_waiters() {
        let gate = Arc::new(CapacityGate::new(3));
        assert_eq!(gate.try_acquire_many(3), 3);
        let mut handles = Vec::new();
        for _ in 0..3 {
            let gate = gate.clone();
            handles.push(thread::spawn(move || gate.acquire_many_sync(1)));
        }
        thread::sleep(Duration::from_millis(100));
        gate.release_many(3);
        let mut total = 0;
        for h in handles {
            total += h.join().expect("waiter panicked");
        }
        assert_eq!(total, 3);
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn acquire_many_async_waits_and_completes() {
        use tokio::time::timeout;
        let gate = Arc::new(CapacityGate::new(4));
        assert_eq!(gate.try_acquire_many(4), 4);
        let acquire_fut = gate.acquire_many_async(3);
        let gate_for_spawn = gate.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            gate_for_spawn.release_many(3);
        });
        let acquired = timeout(Duration::from_millis(500), acquire_fut)
            .await
            .expect("Future did not complete after release_many");
        assert!(
            (1..=3).contains(&acquired),
            "expected 1..=3 permits, got {acquired}"
        );
    }

    #[test]
    fn release_does_not_steal_inflight_handoff_permit_cap0() {
        let gate = Arc::new(CapacityGate::new(0));
        let gate_clone = gate.clone();
        let waiter = thread::spawn(move || gate_clone.acquire_sync());
        thread::sleep(Duration::from_millis(100));
        assert!(!waiter.is_finished(), "waiter should be parked");
        gate.release();
        gate.release();
        thread::sleep(Duration::from_millis(200));
        assert!(
            waiter.is_finished(),
            "REGRESSION: clamp stole the in-flight handoff permit; waiter re-parked forever"
        );
        waiter.join().unwrap();
    }

    #[test]
    fn release_does_not_steal_inflight_handoff_permit_cap0_many() {
        let gate = Arc::new(CapacityGate::new(0));
        let gate_clone = gate.clone();
        let waiter = thread::spawn(move || gate_clone.acquire_many_sync(4));
        thread::sleep(Duration::from_millis(100));
        assert!(!waiter.is_finished(), "waiter should be parked");
        gate.release();
        gate.release();
        thread::sleep(Duration::from_millis(200));
        assert!(
            waiter.is_finished(),
            "REGRESSION: clamp stole the in-flight handoff permit from acquire_many_sync"
        );
        let acquired = waiter.join().unwrap();
        assert!(acquired >= 1);
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn async_wake_consumes_permit_no_pool_inflation() {
        let gate = Arc::new(CapacityGate::new(1));
        gate.acquire_sync();
        let gate_for_task = gate.clone();
        let task = tokio::spawn(async move { gate_for_task.acquire_async().await });
        tokio::time::sleep(Duration::from_millis(100)).await;
        gate.release();
        task.await.unwrap();
        {
            let internal = gate.internal.lock();
            assert_eq!(
                internal.permits, 0,
                "async wake must consume the permit, not inflate the pool"
            );
            assert_eq!(internal.reserved, 0);
        }
        assert!(!gate.try_acquire());
    }

    #[test]
    fn test_acquire_many_async_drop_leak() {
        let gate = CapacityGate::new(1);
        assert!(gate.try_acquire());
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(gate.acquire_many_async(3));
        assert!(fut.as_mut().poll(&mut cx).is_pending());
        {
            let internal = gate.internal.lock();
            assert_eq!(internal.waiters.len(), 1);
        }
        drop(fut);
        let leaked_count = {
            let internal = gate.internal.lock();
            internal.waiters.len()
        };
        assert_eq!(
            leaked_count, 0,
            "Waker leak: dropping AcquireManyFuture left a stale waker in the waiters queue!"
        );
    }
}