use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{Semaphore, SemaphorePermit};
/// A counting semaphore whose capacity can be changed while it is in use.
///
/// It wraps a [`tokio::sync::Semaphore`] and additionally tracks the configured
/// capacity and the number of permits currently handed out, so that
/// `reduce_capacity` / `increase_capacity` can adjust the limit at runtime.
#[derive(Debug)]
pub struct DynamicSemaphore {
/// The tokio semaphore that actually hands out (and blocks on) permits.
inner: Arc<Semaphore>,
/// The currently configured capacity; replaced by `reduce_capacity` and
/// `increase_capacity`, read through `current_capacity`.
max_capacity: AtomicUsize,
/// Number of live `DynamicSemaphorePermit`s: incremented when a permit is
/// handed out and decremented when that permit is dropped.
permits_in_use: AtomicUsize,
}
/// Guard returned by `DynamicSemaphore::acquire` and
/// `DynamicSemaphore::try_acquire`.
///
/// Dropping the guard decrements the owning semaphore's in-use counter and
/// then either returns the inner permit to the pool or forgets it, depending
/// on the capacity at that moment (see the `Drop` implementation).
#[derive(Debug)]
pub struct DynamicSemaphorePermit<'a> {
/// The underlying tokio permit; wrapped in `Option` so that `Drop` can move
/// it out with `take()`.
permit: Option<SemaphorePermit<'a>>,
/// The semaphore this permit was issued by; used on drop to update its
/// bookkeeping.
semaphore: &'a DynamicSemaphore,
}
impl DynamicSemaphore {
    /// Creates a semaphore whose capacity and initial number of free permits
    /// are both `permits`, with no permits in use.
    pub fn new(permits: usize) -> Self {
        Self {
            inner: Arc::new(Semaphore::new(permits)),
            max_capacity: AtomicUsize::new(permits),
            permits_in_use: AtomicUsize::new(0),
        }
    }

    /// Waits until a permit can be handed out without exceeding the current
    /// capacity and returns a guard for it.
    ///
    /// # Errors
    ///
    /// Propagates the [`tokio::sync::AcquireError`] produced by the inner
    /// semaphore (i.e. when it has been closed).
    pub async fn acquire(&self) -> Result<DynamicSemaphorePermit<'_>, tokio::sync::AcquireError> {
        loop {
            if self.permits_in_use.load(Ordering::Acquire) >= self.current_capacity() {
                // At or over capacity: wait on the inner semaphore until some
                // permit becomes available, hand it straight back and re-check.
                // NOTE(review): if the inner semaphore still has free permits
                // while we are over capacity, this retries immediately; it
                // presumably relies on tokio's cooperative scheduling to yield
                // -- confirm this is acceptable for the callers.
                drop(self.inner.acquire().await?);
                continue;
            }

            let permit = self.inner.acquire().await?;
            let new_in_use = self.permits_in_use.fetch_add(1, Ordering::AcqRel) + 1;

            // Re-read the capacity *after* the await: it may have been changed
            // while this task was parked, and a snapshot taken before waiting
            // would let us exceed a freshly reduced limit (or needlessly reject
            // after an increase).
            if new_in_use <= self.current_capacity() {
                return Ok(DynamicSemaphorePermit {
                    permit: Some(permit),
                    semaphore: self,
                });
            }

            // Lost a race with other acquirers or with a capacity reduction:
            // undo the bookkeeping, return the inner permit and try again.
            self.permits_in_use.fetch_sub(1, Ordering::AcqRel);
            drop(permit);
        }
    }

    /// Attempts to obtain a permit without waiting.
    ///
    /// # Errors
    ///
    /// Returns [`tokio::sync::TryAcquireError::NoPermits`] when the in-use
    /// count has reached the current capacity, and otherwise propagates
    /// whatever error the inner semaphore's `try_acquire` reports.
    pub fn try_acquire(
        &self,
    ) -> Result<DynamicSemaphorePermit<'_>, tokio::sync::TryAcquireError> {
        if self.permits_in_use.load(Ordering::Acquire) >= self.current_capacity() {
            return Err(tokio::sync::TryAcquireError::NoPermits);
        }

        let permit = self.inner.try_acquire()?;
        let new_in_use = self.permits_in_use.fetch_add(1, Ordering::AcqRel) + 1;

        // Validate against the capacity as it is *now*, not as it was before
        // the inner permit was taken, so a concurrent reduction is honoured.
        if new_in_use <= self.current_capacity() {
            Ok(DynamicSemaphorePermit {
                permit: Some(permit),
                semaphore: self,
            })
        } else {
            // Another acquirer (or a capacity reduction) got in first.
            self.permits_in_use.fetch_sub(1, Ordering::AcqRel);
            drop(permit);
            Err(tokio::sync::TryAcquireError::NoPermits)
        }
    }

    /// Sets the capacity to `new_capacity` and returns the previous capacity.
    ///
    /// When `new_capacity` is lower than the previous value, free permits in
    /// excess of `new_capacity` are forgotten immediately; permits that are
    /// currently held are dealt with when they are dropped. When
    /// `new_capacity` is not lower, the stored capacity is still replaced but
    /// no permits are added or removed.
    pub fn reduce_capacity(&self, new_capacity: usize) -> usize {
        let old_capacity = self.max_capacity.swap(new_capacity, Ordering::AcqRel);
        if new_capacity < old_capacity {
            // Trim the free pool so that it never exceeds the new capacity.
            let available = self.inner.available_permits();
            let to_forget = available.saturating_sub(new_capacity);
            if to_forget > 0 {
                self.inner.forget_permits(to_forget);
            }
        }
        old_capacity
    }

    /// Sets the capacity to `new_capacity` and returns the previous capacity.
    ///
    /// When `new_capacity` is higher than the previous value, the difference
    /// is added to the inner semaphore as new free permits. When it is not
    /// higher, the stored capacity is still replaced but no permits are added
    /// or removed.
    pub fn increase_capacity(&self, new_capacity: usize) -> usize {
        let old_capacity = self.max_capacity.swap(new_capacity, Ordering::AcqRel);
        if new_capacity > old_capacity {
            self.inner.add_permits(new_capacity - old_capacity);
        }
        old_capacity
    }

    /// Returns the currently configured capacity.
    pub fn current_capacity(&self) -> usize {
        self.max_capacity.load(Ordering::Acquire)
    }

    /// Returns the number of free permits reported by the inner semaphore.
    pub fn available_permits(&self) -> usize {
        self.inner.available_permits()
    }

    /// Closes the inner semaphore.
    pub fn close(&self) {
        self.inner.close();
    }

    /// Returns `true` if the inner semaphore has been closed.
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }

    /// Returns the number of permits currently handed out and not yet dropped.
    pub fn permits_in_use(&self) -> usize {
        self.permits_in_use.load(Ordering::Acquire)
    }
}
impl<'a> Drop for DynamicSemaphorePermit<'a> {
    /// Releases the permit: decrements the owner's in-use counter, then either
    /// gives the inner permit back to the pool or forgets it so that the pool
    /// of free permits does not grow beyond the current capacity.
    fn drop(&mut self) {
        // Nothing to do if the inner permit has already been moved out.
        let inner_permit = match self.permit.take() {
            Some(p) => p,
            None => return,
        };

        let owner = self.semaphore;
        owner.permits_in_use.fetch_sub(1, Ordering::AcqRel);

        let capacity = owner.current_capacity();
        // How many free permits there would be if this one were returned.
        let free_after_release = owner.available_permits() + 1;

        if free_after_release <= capacity {
            // Still within capacity: return the permit to the inner semaphore.
            drop(inner_permit);
        } else {
            // Capacity was reduced while this permit was held: discard it
            // instead of returning it.
            inner_permit.forget();
        }
    }
}
// `DynamicSemaphorePermit` only contains an `Option<SemaphorePermit<'a>>` and a
// `&'a DynamicSemaphore`, both of which are `Send`, so the compiler derives
// `Send` for it automatically and no `unsafe impl` is required. A manual
// `unsafe impl` would keep asserting `Send` even if a non-`Send` field were
// added later; this compile-time check fails loudly in that case instead.
const _: fn() = || {
    fn assert_send<T: Send>() {}
    assert_send::<DynamicSemaphorePermit<'static>>();
};
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Acquiring and dropping permits keeps the free/in-use counters in step.
    #[tokio::test]
    async fn test_basic_acquire_release() {
        let sem = DynamicSemaphore::new(2);
        assert_eq!(sem.available_permits(), 2);
        assert_eq!(sem.current_capacity(), 2);
        assert_eq!(sem.permits_in_use(), 0);

        let first = sem.acquire().await.unwrap();
        assert_eq!(sem.available_permits(), 1);
        assert_eq!(sem.permits_in_use(), 1);

        let second = sem.acquire().await.unwrap();
        assert_eq!(sem.available_permits(), 0);
        assert_eq!(sem.permits_in_use(), 2);

        drop(first);
        assert_eq!(sem.available_permits(), 1);
        assert_eq!(sem.permits_in_use(), 1);

        drop(second);
        assert_eq!(sem.available_permits(), 2);
        assert_eq!(sem.permits_in_use(), 0);
    }

    /// Reducing capacity while every permit is held: the surplus permit is
    /// discarded when it is eventually dropped.
    #[tokio::test]
    async fn test_capacity_reduction() {
        let sem = DynamicSemaphore::new(3);
        let first = sem.acquire().await.unwrap();
        let second = sem.acquire().await.unwrap();
        let third = sem.acquire().await.unwrap();
        assert_eq!(sem.available_permits(), 0);
        assert_eq!(sem.permits_in_use(), 3);

        let previous = sem.reduce_capacity(2);
        assert_eq!(previous, 3);
        assert_eq!(sem.current_capacity(), 2);

        drop(first);
        assert_eq!(sem.available_permits(), 1);
        assert_eq!(sem.permits_in_use(), 2);

        drop(second);
        assert_eq!(sem.available_permits(), 2);
        assert_eq!(sem.permits_in_use(), 1);

        // The third permit would push the free pool past the new capacity, so
        // it is forgotten rather than returned.
        drop(third);
        assert_eq!(sem.available_permits(), 2);
        assert_eq!(sem.permits_in_use(), 0);
    }

    /// Increasing capacity adds the difference as free permits.
    #[tokio::test]
    async fn test_capacity_increase() {
        let sem = DynamicSemaphore::new(2);
        assert_eq!(sem.available_permits(), 2);

        let previous = sem.increase_capacity(5);
        assert_eq!(previous, 2);
        assert_eq!(sem.current_capacity(), 5);
        assert_eq!(sem.available_permits(), 5);
    }

    /// `try_acquire` fails while the single permit is held and succeeds again
    /// once it has been released.
    #[tokio::test]
    async fn test_try_acquire() {
        let sem = DynamicSemaphore::new(1);
        let held = sem.try_acquire().unwrap();
        assert!(sem.try_acquire().is_err());
        drop(held);
        assert!(sem.try_acquire().is_ok());
    }

    /// Closing the semaphore is observable and makes `acquire` fail.
    #[tokio::test]
    async fn test_close() {
        let sem = DynamicSemaphore::new(1);
        assert!(!sem.is_closed());
        sem.close();
        assert!(sem.is_closed());
        assert!(sem.acquire().await.is_err());
    }

    /// While more permits are held than the reduced capacity allows, no new
    /// permit can be obtained even though the inner semaphore has one free.
    #[tokio::test]
    async fn test_over_capacity_acquisition_prevention() {
        let sem = Arc::new(DynamicSemaphore::new(5));
        let first = sem.acquire().await.unwrap();
        let second = sem.acquire().await.unwrap();
        assert_eq!(sem.available_permits(), 3);
        assert_eq!(sem.permits_in_use(), 2);

        sem.reduce_capacity(1);
        assert_eq!(sem.current_capacity(), 1);
        assert_eq!(sem.available_permits(), 1);
        assert_eq!(sem.permits_in_use(), 2);
        assert!(
            sem.try_acquire().is_err(),
            "Should not be able to acquire when over capacity"
        );

        drop(first);
        assert_eq!(sem.available_permits(), 1);
        assert_eq!(sem.permits_in_use(), 1);

        drop(second);
        assert_eq!(sem.available_permits(), 1);
        assert_eq!(sem.permits_in_use(), 0);

        let fresh = sem.try_acquire().unwrap();
        assert_eq!(sem.available_permits(), 0);
        assert_eq!(sem.permits_in_use(), 1);

        drop(fresh);
        assert_eq!(sem.available_permits(), 1);
        assert_eq!(sem.permits_in_use(), 0);
    }

    /// Capacity is reduced while spawned tasks are holding permits; afterwards
    /// the free pool must not exceed the new capacity.
    #[tokio::test]
    async fn test_concurrent_capacity_reduction() {
        let shared = Arc::new(DynamicSemaphore::new(10));

        let workers: Vec<_> = (0..20)
            .map(|_| {
                let sem = shared.clone();
                tokio::spawn(async move {
                    if let Ok(held) = sem.try_acquire() {
                        sleep(Duration::from_millis(50)).await;
                        drop(held);
                    }
                })
            })
            .collect();

        // Give the workers a moment to grab their permits before shrinking.
        sleep(Duration::from_millis(10)).await;
        shared.reduce_capacity(5);

        for worker in workers {
            worker.await.unwrap();
        }

        assert!(shared.available_permits() <= shared.current_capacity());
        assert_eq!(shared.current_capacity(), 5);
    }

    /// Many tasks acquire and release while another task steps the capacity
    /// down from 50 to 1; the counters must be consistent at the end.
    #[tokio::test]
    async fn test_stress_concurrent_operations() {
        let shared = Arc::new(DynamicSemaphore::new(50));

        let workers: Vec<_> = (0..100)
            .map(|_| {
                let sem = shared.clone();
                tokio::spawn(async move {
                    for _ in 0..5 {
                        if let Ok(held) = sem.try_acquire() {
                            tokio::task::yield_now().await;
                            drop(held);
                        }
                        tokio::task::yield_now().await;
                    }
                })
            })
            .collect();

        let shrinking = shared.clone();
        let reducer = tokio::spawn(async move {
            for target in (1..=50).rev() {
                shrinking.reduce_capacity(target);
                tokio::task::yield_now().await;
            }
        });

        for worker in workers {
            worker.await.unwrap();
        }
        reducer.await.unwrap();

        assert!(shared.available_permits() <= shared.current_capacity());
        assert_eq!(shared.current_capacity(), 1);
        assert_eq!(shared.permits_in_use(), 0);
    }

    /// A limiter with three permits held is throttled down to one; only a
    /// single permit is available once all holders have finished.
    #[tokio::test]
    async fn test_feroxbuster_integration_scenario() {
        let limiter = Arc::new(DynamicSemaphore::new(3));
        let first = limiter.acquire().await.unwrap();
        let second = limiter.acquire().await.unwrap();
        let third = limiter.acquire().await.unwrap();
        assert_eq!(limiter.available_permits(), 0);
        assert_eq!(limiter.current_capacity(), 3);

        limiter.reduce_capacity(1);
        assert_eq!(limiter.current_capacity(), 1);
        assert!(limiter.try_acquire().is_err());

        drop(first);
        assert_eq!(limiter.available_permits(), 1);
        drop(second);
        assert_eq!(limiter.available_permits(), 1);
        drop(third);
        assert_eq!(limiter.available_permits(), 1);

        // Held until the end of the test so the final `try_acquire` fails.
        let _replacement = limiter.acquire().await.unwrap();
        assert_eq!(limiter.available_permits(), 0);
        assert!(limiter.try_acquire().is_err());
    }

    /// Zero capacity, reduction to zero while a permit is held, and a large
    /// capacity.
    #[tokio::test]
    async fn test_edge_cases() {
        // A semaphore created with no permits never hands one out.
        let empty = DynamicSemaphore::new(0);
        assert_eq!(empty.current_capacity(), 0);
        assert_eq!(empty.available_permits(), 0);
        assert!(empty.try_acquire().is_err());

        // Shrinking to zero while a permit is outstanding.
        let shrunk = DynamicSemaphore::new(2);
        let held = shrunk.acquire().await.unwrap();
        shrunk.reduce_capacity(0);
        assert_eq!(shrunk.current_capacity(), 0);
        assert!(shrunk.try_acquire().is_err());
        drop(held);
        assert_eq!(shrunk.available_permits(), 0);
        assert!(shrunk.try_acquire().is_err());

        // A large capacity behaves like a small one.
        let large = DynamicSemaphore::new(1000);
        assert_eq!(large.current_capacity(), 1000);
        assert_eq!(large.available_permits(), 1000);
        let held = large.try_acquire().unwrap();
        assert_eq!(large.available_permits(), 999);
        drop(held);
        assert_eq!(large.available_permits(), 1000);
    }
}