use std::{
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
use tokio::sync::Semaphore;
/// Admission gate that pairs a concurrency-limiting semaphore with a queue-depth counter.
///
/// Permits are handed out by `try_acquire` (non-blocking) and `acquire_timeout`
/// (waits up to a deadline); both refuse up front once the counter has reached
/// `max_queue_depth`.
pub struct AdmissionController {
/// Shared semaphore; `new` sizes it to `max_concurrent` permits.
semaphore: Arc<Semaphore>,
/// Counter consulted before admitting work; modified by `try_acquire` and `acquire_timeout`.
queue_depth: AtomicU64,
/// Threshold for `queue_depth`: at or above this value the acquire methods return `None`.
max_queue_depth: u64,
}
impl AdmissionController {
pub fn new(max_concurrent: usize, max_queue_depth: u64) -> Self {
Self {
semaphore: Arc::new(Semaphore::new(max_concurrent)),
queue_depth: AtomicU64::new(0),
max_queue_depth,
}
}
pub fn try_acquire(&self) -> Option<AdmissionPermit<'_>> {
let current_depth = self.queue_depth.load(Ordering::Relaxed);
if current_depth >= self.max_queue_depth {
return None;
}
if let Ok(permit) = self.semaphore.clone().try_acquire_owned() {
Some(AdmissionPermit {
_permit: permit,
_phantom: std::marker::PhantomData,
})
} else {
self.queue_depth.fetch_add(1, Ordering::Relaxed);
None
}
}
pub async fn acquire_timeout(&self, timeout: Duration) -> Option<AdmissionPermit<'_>> {
let current_depth = self.queue_depth.load(Ordering::Relaxed);
if current_depth >= self.max_queue_depth {
return None;
}
self.queue_depth.fetch_add(1, Ordering::Relaxed);
let result = tokio::time::timeout(timeout, self.semaphore.clone().acquire_owned()).await;
self.queue_depth.fetch_sub(1, Ordering::Relaxed);
if let Ok(Ok(permit)) = result {
Some(AdmissionPermit {
_permit: permit,
_phantom: std::marker::PhantomData,
})
} else {
None
}
}
pub fn queue_depth(&self) -> u64 {
self.queue_depth.load(Ordering::Relaxed)
}
}
/// Handle returned by `AdmissionController::try_acquire` and
/// `AdmissionController::acquire_timeout` for an admitted request.
///
/// The lifetime `'a` is the borrow of the controller the permit was obtained from.
pub struct AdmissionPermit<'a> {
/// Owned semaphore permit held for as long as this value is alive.
_permit: tokio::sync::OwnedSemaphorePermit,
/// Zero-sized marker carrying the `'a` borrow of the controller; stores no data.
_phantom: std::marker::PhantomData<&'a AdmissionController>,
}
impl Drop for AdmissionPermit<'_> {
// Intentionally empty: this type does no bookkeeping of its own on drop.
// Rust drops the `_permit` field right after this method returns, and that
// field's own drop is what makes the semaphore slot available again.
fn drop(&mut self) {
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A controller with spare permits and queue headroom admits immediately.
    #[test]
    fn allows_when_below_capacity() {
        let ac = AdmissionController::new(10, 100);
        let permit = ac.try_acquire();
        assert!(permit.is_some(), "must allow when below capacity");
    }

    /// With every permit held, a further non-blocking attempt is refused.
    #[test]
    fn rejects_when_semaphore_exhausted() {
        let ac = AdmissionController::new(2, 10);
        let _p1 = ac.try_acquire().expect("1st permit");
        let _p2 = ac.try_acquire().expect("2nd permit");
        assert!(ac.try_acquire().is_none(), "must reject when semaphore exhausted");
    }

    /// Dropping a permit makes its slot available to the next caller.
    #[test]
    fn releases_on_permit_drop() {
        let ac = AdmissionController::new(1, 10);
        {
            let _p = ac.try_acquire().expect("must succeed");
            assert!(ac.try_acquire().is_none(), "at capacity: must reject");
        }
        assert!(ac.try_acquire().is_some(), "after permit drop, must allow new request");
    }

    /// A failed `try_acquire` bumps the queue-depth counter by one.
    #[test]
    fn queue_depth_tracked_on_semaphore_exhaustion() {
        let ac = AdmissionController::new(1, 10);
        let _p = ac.try_acquire().expect("first permit");
        assert_eq!(ac.queue_depth(), 0, "no queueing yet");
        assert!(ac.try_acquire().is_none());
        assert_eq!(ac.queue_depth(), 1, "queue_depth must be 1 after one failed acquire");
    }

    /// A zero queue limit refuses everything, even with free permits.
    #[test]
    fn zero_max_queue_depth_rejects_all() {
        let ac = AdmissionController::new(10, 0);
        assert!(ac.try_acquire().is_none(), "max_queue_depth=0 must reject all requests");
    }

    /// `acquire_timeout` returns a permit straight away when one is free.
    #[tokio::test]
    async fn acquire_timeout_succeeds_when_available() {
        let ac = AdmissionController::new(5, 10);
        let permit = ac.acquire_timeout(Duration::from_millis(100)).await;
        assert!(permit.is_some(), "must succeed when permits available");
    }

    /// A zero queue limit makes `acquire_timeout` refuse without waiting for a permit.
    #[tokio::test]
    async fn acquire_timeout_rejects_when_queue_full() {
        let ac = AdmissionController::new(1, 0);
        let permit = ac.acquire_timeout(Duration::from_millis(10)).await;
        assert!(permit.is_none(), "must reject when max_queue_depth=0");
    }

    /// When the deadline passes with no permit, the result is `None` and the
    /// queue-depth counter is back at zero.
    #[tokio::test]
    async fn acquire_timeout_returns_none_on_expiry() {
        let ac = AdmissionController::new(1, 10);
        let _p = ac.try_acquire().expect("first permit");
        let permit = ac.acquire_timeout(Duration::from_millis(10)).await;
        assert!(permit.is_none(), "must return None when timeout elapses");
        assert_eq!(ac.queue_depth(), 0, "queue_depth must be 0 after timeout cleanup");
    }

    /// A waiter that starts while the only permit is held gets it once the holder
    /// lets go before the deadline.
    #[tokio::test]
    async fn acquire_timeout_succeeds_when_permit_freed_in_time() {
        let ac = AdmissionController::new(1, 10);
        let held = ac.try_acquire().expect("first permit");
        // Release from a concurrently polled future so `acquire_timeout` really has to
        // wait; dropping the permit before calling it would only test the fast path.
        let release = async move {
            tokio::task::yield_now().await;
            drop(held);
        };
        let (result, ()) = tokio::join!(ac.acquire_timeout(Duration::from_secs(1)), release);
        assert!(result.is_some(), "must succeed when permit freed before timeout");
    }
}