use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use std::thread::{self, Thread};
use parking_lot::Mutex;
/// A queued acquirer: either a blocked OS thread or an async task's waker.
#[derive(Debug)]
enum Waiter {
    /// A thread parked in a synchronous acquire; woken with `Thread::unpark`.
    Sync(Thread),
    /// A pending future; woken through its task's `Waker`.
    Async(Waker),
}

impl Waiter {
    /// Consumes the waiter and wakes the thread or task it represents.
    fn wake(self) {
        match self {
            Self::Async(task) => task.wake(),
            Self::Sync(handle) => handle.unpark(),
        }
    }

    /// Returns whether waking this waiter would wake the same task as `waker`.
    /// A thread waiter never matches a task waker.
    fn will_wake(&self, waker: &Waker) -> bool {
        if let Self::Async(stored) = self {
            stored.will_wake(waker)
        } else {
            false
        }
    }
}
/// Mutable state shared by every handle of a [`CapacityGate`], guarded by its mutex.
///
/// Invariant: `waiters` is non-empty only while `permits == 0`. `release` hands a
/// freed permit straight to the oldest waiter (by removing its queue entry) instead
/// of adding it to `permits`, so a permit never sits idle while someone is queued.
#[derive(Debug)]
struct GateInternal {
    /// Permits available to callers that do not have to queue.
    permits: usize,
    /// Source of unique tickets identifying queued waiters.
    next_ticket: u64,
    /// FIFO queue of blocked acquirers, each tagged with its ticket.
    waiters: VecDeque<(u64, Waiter)>,
}

impl GateInternal {
    /// Takes a permit immediately if one is free and nobody is queued ahead.
    fn try_take(&mut self) -> bool {
        if self.waiters.is_empty() && self.permits > 0 {
            self.permits -= 1;
            true
        } else {
            false
        }
    }

    /// Appends `waiter` to the back of the queue and returns its ticket.
    fn enqueue(&mut self, waiter: Waiter) -> u64 {
        let ticket = self.next_ticket;
        self.next_ticket = self.next_ticket.wrapping_add(1);
        self.waiters.push_back((ticket, waiter));
        ticket
    }

    /// Returns the queue index of the entry carrying `ticket`, or `None` if the
    /// entry has been removed.
    fn position(&self, ticket: u64) -> Option<usize> {
        self.waiters.iter().position(|(t, _)| *t == ticket)
    }
}

/// A counting semaphore with FIFO hand-off.
///
/// Permits can be taken by blocking the current thread
/// ([`CapacityGate::acquire_sync`]), by awaiting a future
/// ([`CapacityGate::acquire_async`]) or without waiting
/// ([`CapacityGate::try_acquire`]). Every permit must be returned with
/// [`CapacityGate::release`].
pub struct CapacityGate {
    capacity: usize,
    internal: Arc<Mutex<GateInternal>>,
}

impl fmt::Debug for CapacityGate {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let internal = self.internal.lock();
        f.debug_struct("CapacityGate")
            .field("capacity", &self.capacity)
            .field("permits", &internal.permits)
            .field("waiters", &internal.waiters.len())
            .finish()
    }
}

impl CapacityGate {
    /// Creates a gate that starts with `capacity` free permits.
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity,
            internal: Arc::new(Mutex::new(GateInternal {
                permits: capacity,
                next_ticket: 0,
                waiters: VecDeque::new(),
            })),
        }
    }

    /// Returns the number of permits the gate was created with.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Parks the current thread until a permit has been acquired.
    ///
    /// If no permit is free, or others are already queued, the thread joins the
    /// FIFO queue (shared with async waiters). It returns only once `release` has
    /// handed it a permit. Spurious returns from `thread::park` put it straight
    /// back to sleep without re-queuing.
    pub fn acquire_sync(&self) {
        let mut internal = self.internal.lock();
        if internal.try_take() {
            return;
        }
        let ticket = internal.enqueue(Waiter::Sync(thread::current()));
        loop {
            drop(internal);
            thread::park();
            internal = self.internal.lock();
            // `release` removes our entry at the moment it hands us its permit. If the
            // entry is still queued, this wake-up was spurious and we keep waiting.
            if internal.position(ticket).is_none() {
                return;
            }
        }
    }

    /// Returns a future that resolves once a permit has been acquired.
    ///
    /// Dropping the future before it completes gives up its place in the queue. If
    /// a permit had already been handed to it, that permit is released again.
    pub fn acquire_async(&self) -> AcquireFuture<'_> {
        AcquireFuture {
            gate: self,
            ticket: None,
        }
    }

    /// Takes a permit without blocking.
    ///
    /// Returns `false` when no permit is free, or when other acquirers are already
    /// queued (so queued waiters are never overtaken).
    pub fn try_acquire(&self) -> bool {
        self.internal.lock().try_take()
    }

    /// Returns a permit to the gate.
    ///
    /// If anyone is queued, the permit goes directly to the oldest waiter, which is
    /// then woken. Otherwise it is added back to the pool, never exceeding
    /// `capacity`; surplus releases are ignored.
    pub fn release(&self) {
        let next = {
            let mut internal = self.internal.lock();
            match internal.waiters.pop_front() {
                // Removing the entry *is* the hand-off; `permits` stays untouched.
                Some((_, waiter)) => Some(waiter),
                None => {
                    internal.permits = internal.permits.saturating_add(1).min(self.capacity);
                    None
                }
            }
        };
        // Wake outside the lock. The mutex is not re-entrant, and a waker may run
        // arbitrary code that touches this gate.
        if let Some(waiter) = next {
            waiter.wake();
        }
    }
}

/// Future returned by [`CapacityGate::acquire_async`]; resolves once it holds a permit.
#[must_use = "futures do nothing unless you .await or poll them"]
pub struct AcquireFuture<'a> {
    gate: &'a CapacityGate,
    /// Queue ticket while waiting, or while holding a handed-off permit that has not
    /// yet been observed. `None` before the first pending poll and after completion.
    ticket: Option<u64>,
}

impl Future for AcquireFuture<'_> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Only a reference and an integer are stored, so the future is `Unpin`.
        let this = self.get_mut();
        let gate = this.gate;
        let mut internal = gate.internal.lock();
        match this.ticket {
            None => {
                if internal.try_take() {
                    return Poll::Ready(());
                }
                this.ticket = Some(internal.enqueue(Waiter::Async(cx.waker().clone())));
                Poll::Pending
            }
            Some(ticket) => match internal.position(ticket) {
                // `release` removed our entry and transferred its permit to us.
                None => {
                    this.ticket = None;
                    Poll::Ready(())
                }
                // Still queued: make sure the stored waker targets the current task.
                Some(index) => {
                    let (_, waiter) = &mut internal.waiters[index];
                    if !waiter.will_wake(cx.waker()) {
                        *waiter = Waiter::Async(cx.waker().clone());
                    }
                    Poll::Pending
                }
            },
        }
    }
}

impl Drop for AcquireFuture<'_> {
    /// Withdraws a pending acquisition. If the permit was already handed over but
    /// never observed by `poll`, it is released so the next waiter gets it.
    fn drop(&mut self) {
        if let Some(ticket) = self.ticket.take() {
            let mut internal = self.gate.internal.lock();
            match internal.position(ticket) {
                Some(index) => {
                    internal.waiters.remove(index);
                }
                None => {
                    drop(internal);
                    self.gate.release();
                }
            }
        }
    }
}
impl Clone for CapacityGate {
fn clone(&self) -> Self {
Self {
capacity: self.capacity,
internal: self.internal.clone(),
}
}
}
// Unit tests. Blocking paths use plain `std::thread`s; async paths run on a
// tokio runtime via `#[tokio::test]`. Several tests use fixed sleeps to let a
// waiter reach the blocked state, so they are timing-based, not deterministic.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use tokio::time::timeout;
// `capacity()` reports the value passed to `new`.
#[test]
fn gate_new_and_capacity() {
let gate = CapacityGate::new(5);
assert_eq!(gate.capacity(), 5);
}
// Uncontended acquire followed by release on a single-permit gate.
#[test]
fn acquire_sync_release() {
let gate = CapacityGate::new(1);
gate.acquire_sync();
gate.release();
}
// With the only permit held by this thread, a second thread must stay blocked
// in `acquire_sync` until `release` is called.
// NOTE(review): the 100 ms sleep is a heuristic. If the spawned thread has not
// even started yet, `is_finished()` is still false, so this test can pass
// without actually observing the thread blocked inside the gate.
#[test]
fn acquire_sync_blocks_and_unblocks() {
let gate = Arc::new(CapacityGate::new(1));
gate.acquire_sync();
let gate_clone = gate.clone();
let handle = thread::spawn(move || {
gate_clone.acquire_sync();
});
thread::sleep(Duration::from_millis(100));
assert!(!handle.is_finished(), "Thread should have blocked");
gate.release();
handle.join().expect("Thread panicked");
}
// An async acquire created while the permit is held must complete once a
// spawned task releases the permit (after ~100 ms; the test allows 500 ms).
#[tokio::test]
async fn acquire_async_waits_and_completes() {
let gate = Arc::new(CapacityGate::new(1));
gate.acquire_sync();
let acquire_fut = gate.acquire_async();
let gate_for_spawn = gate.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(100)).await;
gate_for_spawn.release();
});
timeout(Duration::from_millis(500), acquire_fut)
.await
.expect("Future did not complete after release");
}
// Three blocking threads and three async tasks compete for two permits. Each
// holder keeps its permit for ~50 ms; all six must eventually finish.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn mixed_waiters_contention() {
let gate = Arc::new(CapacityGate::new(2));
let mut thread_handles = Vec::new();
let mut task_handles = Vec::new();
let completion_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
// Blocking contenders: OS threads using `acquire_sync`.
for _ in 0..3 {
let gate = gate.clone();
let count = completion_count.clone();
thread_handles.push(thread::spawn(move || {
gate.acquire_sync();
thread::sleep(Duration::from_millis(50));
gate.release();
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}));
}
// Async contenders: tokio tasks using `acquire_async`.
for _ in 0..3 {
let gate = gate.clone();
let count = completion_count.clone();
task_handles.push(tokio::spawn(async move {
gate.acquire_async().await;
tokio::time::sleep(Duration::from_millis(50)).await;
gate.release();
count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}));
}
// Await the async tasks first. The `join` calls that follow block the thread
// driving this async test body.
for handle in task_handles {
handle.await.unwrap();
}
for handle in thread_handles {
handle.join().unwrap();
}
assert_eq!(
completion_count.load(std::sync::atomic::Ordering::Relaxed),
6
);
}
}