use std::sync::atomic::Ordering::SeqCst;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use more_asserts::debug_assert_le;
use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
/// Largest number of *physical* permits an [`AdjustableSemaphore`] will manage.
///
/// Bounded both by tokio's `Semaphore::MAX_PERMITS` and by `u32::MAX`, because physical
/// acquisition counts are passed to tokio as `u32`.
const PERMIT_LIMIT: u64 = {
    let tokio_cap = Semaphore::MAX_PERMITS as u64;
    let u32_cap = u32::MAX as u64;
    if u32_cap <= tokio_cap {
        u32_cap
    } else {
        tokio_cap
    }
};
/// A semaphore whose total capacity can be raised or lowered at runtime.
///
/// Callers work in *logical* permits. Internally, each physical permit of the wrapped tokio
/// semaphore stands for `basis` logical permits, so that large logical capacities fit within
/// `PERMIT_LIMIT` physical permits. All counters below are physical (units of `basis`).
#[derive(Debug)]
pub struct AdjustableSemaphore {
    // Underlying tokio semaphore; counts physical permits.
    semaphore: Arc<Semaphore>,
    // Current physical capacity; the logical capacity is this value times `basis`.
    total_permits: AtomicU64,
    // Physical permits already subtracted from `total_permits` that could not yet be taken out
    // of `semaphore`; they are retired when permit handles are dropped, or cancelled by a
    // later increment.
    enqueued_permit_decreases: AtomicU64,
    // Lower bound for `total_permits` (physical).
    min_physical_permits: u64,
    // Upper bound for `total_permits` (physical).
    max_physical_permits: u64,
    // Number of logical permits represented by one physical permit.
    basis: u64,
    // Serializes capacity increments/decrements; protects no data of its own.
    adjustment_lock: Mutex<()>,
}
/// Handle for physical permits of an [`AdjustableSemaphore`], released when dropped.
///
/// A handle is either *real* (`permit` is `Some`, backed by tokio permits acquired from the
/// semaphore) or *virtual* (`permit` is `None`, representing capacity added by an increment
/// that is handed to the semaphore only when the handle is dropped).
#[derive(Debug)]
pub struct AdjustableSemaphorePermit {
    /// Backing tokio permit for real handles; `None` for virtual handles.
    permit: Option<OwnedSemaphorePermit>,
    /// Number of physical permits this handle accounts for.
    num_physical_permits: u32,
    /// Semaphore that receives the permits back on drop.
    parent: Arc<AdjustableSemaphore>,
}
impl AdjustableSemaphorePermit {
pub fn num_permits(&self) -> u64 {
self.num_physical_permits as u64 * self.parent.basis
}
pub fn num_physical_permits(&self) -> u32 {
self.num_physical_permits
}
pub fn split(&mut self, n: u64) -> Option<AdjustableSemaphorePermit> {
let physical_n = n.div_ceil(self.parent.basis);
if physical_n > self.num_physical_permits as u64 {
return None;
}
let physical_n = physical_n as u32;
self.num_physical_permits -= physical_n;
if physical_n > 0 {
let permit = self.permit.as_mut().and_then(|p| p.split(physical_n as usize));
Some(AdjustableSemaphorePermit {
permit,
num_physical_permits: physical_n,
parent: self.parent.clone(),
})
} else {
None
}
}
}
impl Drop for AdjustableSemaphorePermit {
    /// Gives this handle's physical permits back to the parent semaphore.
    ///
    /// Pending capacity decreases are paid off first: that many permits are forgotten, i.e.
    /// taken out of circulation. The remainder goes back to the semaphore. For a real handle
    /// this happens by dropping the remaining tokio permit; for a virtual handle, via
    /// `add_permits`.
    fn drop(&mut self) {
        let held = u64::from(self.num_physical_permits);
        let retired = attempt_sub(&self.parent.enqueued_permit_decreases, held, 0);

        match self.permit.take() {
            None => {
                debug_assert_le!(retired, held);
                let released = (held - retired) as usize;
                if released > 0 {
                    self.parent.semaphore.add_permits(released);
                }
            }
            Some(mut backing) => {
                if retired == 0 {
                    // Nothing owed: `backing` drops on return and releases everything.
                    return;
                }
                match backing.split(retired as usize) {
                    Some(paid_off) => paid_off.forget(),
                    None => {
                        debug_assert!(false, "Failed to split permit; mismatch in self.num_permits.")
                    }
                }
                // `backing` now holds only the surplus; dropping it here returns that surplus.
            }
        }
    }
}
impl AdjustableSemaphore {
    /// Creates a semaphore with `initial_permits` logical permits whose capacity can later be
    /// adjusted within the inclusive range `permit_range = (min, max)`.
    ///
    /// Logical permits map onto physical tokio permits in units of `basis`. `basis` is the
    /// smallest power of two for which `max` fits in `PERMIT_LIMIT` physical permits. `min`,
    /// `max` and `initial_permits` are each rounded *up* to whole physical permits. The
    /// initial value is then clamped into the rounded `[min, max]`.
    ///
    /// # Panics
    ///
    /// In debug builds, if `min <= initial_permits <= max` does not hold. In any build, if the
    /// rounded `min` exceeds the rounded `max` (`u64::clamp` panics then).
    pub fn new(initial_permits: u64, permit_range: (u64, u64)) -> Arc<Self> {
        debug_assert!(permit_range.0 <= permit_range.1);
        debug_assert!(permit_range.0 <= initial_permits);
        debug_assert!(initial_permits <= permit_range.1);
        let basis = Self::compute_basis(permit_range.1);
        let min_physical = permit_range.0.div_ceil(basis);
        let max_physical = permit_range.1.div_ceil(basis);
        let initial_physical = initial_permits.div_ceil(basis).clamp(min_physical, max_physical);
        Arc::new(Self {
            semaphore: Arc::new(Semaphore::new(initial_physical as usize)),
            total_permits: AtomicU64::new(initial_physical),
            enqueued_permit_decreases: AtomicU64::new(0),
            min_physical_permits: min_physical,
            max_physical_permits: max_physical,
            basis,
            adjustment_lock: Mutex::new(()),
        })
    }

    /// Current logical capacity (physical capacity times `basis`), saturating at `u64::MAX`.
    ///
    /// This is a relaxed snapshot; concurrent adjustments may change it immediately.
    pub fn total_permits(&self) -> u64 {
        // Saturate: physical * basis may exceed u64::MAX when `max` is close to u64::MAX.
        self.total_permits.load(Ordering::Relaxed).saturating_mul(self.basis)
    }

    /// Logical permits currently free in the underlying semaphore, saturating at `u64::MAX`.
    pub fn available_permits(&self) -> u64 {
        (self.semaphore.available_permits() as u64).saturating_mul(self.basis)
    }

    /// Logical permits currently held by real or virtual handles.
    ///
    /// Computed as `(capacity + pending decreases - free)`: subtraction saturates at zero,
    /// multiplication by `basis` saturates at `u64::MAX`.
    pub fn active_permits(&self) -> u64 {
        let accounted = self.total_permits.load(Ordering::Relaxed)
            + self.enqueued_permit_decreases.load(Ordering::Relaxed);
        accounted
            .saturating_sub(self.semaphore.available_permits() as u64)
            .saturating_mul(self.basis)
    }

    /// Number of logical permits represented by one physical permit.
    pub fn basis(&self) -> u64 {
        self.basis
    }

    /// Acquires a single physical permit (`basis` logical permits).
    ///
    /// # Errors
    ///
    /// See [`AdjustableSemaphore::acquire_many`].
    pub async fn acquire(self: &Arc<Self>) -> Result<AdjustableSemaphorePermit, AcquireError> {
        self.acquire_many(1).await
    }

    /// Waits for and acquires enough physical permits to cover `n` logical permits.
    ///
    /// `n` is rounded up to whole physical permits. The count is then clamped to
    /// `[1, physical capacity at call time]`. As a result, `n == 0` still takes one physical
    /// permit, and a request larger than the capacity takes the whole capacity instead of
    /// waiting forever.
    ///
    /// NOTE(review): the clamp uses the capacity observed at call time. If the capacity is
    /// lowered below the clamped amount while this call is still waiting, the call cannot be
    /// satisfied until the capacity grows again. Because tokio's semaphore is fair (FIFO),
    /// later waiters queue behind it.
    ///
    /// # Errors
    ///
    /// Propagates tokio's [`AcquireError`], which is returned only if the underlying semaphore
    /// has been closed. No code visible here closes it.
    pub async fn acquire_many(self: &Arc<Self>, n: u64) -> Result<AdjustableSemaphorePermit, AcquireError> {
        let physical = self.to_physical_acquire(n);
        let permit = self.semaphore.clone().acquire_many_owned(physical).await?;
        Ok(AdjustableSemaphorePermit {
            permit: Some(permit),
            num_physical_permits: physical,
            parent: self.clone(),
        })
    }

    /// Lowers the capacity by `n` logical permits (rounded up to whole physical permits),
    /// never going below the configured minimum.
    ///
    /// Returns the logical amount actually removed, or `None` if nothing changed.
    pub fn decrement_total_permits(&self, n: u64) -> Option<u64> {
        let lock = self.adjustment_lock.lock().unwrap();
        self.decrement_total_permits_impl(lock, n)
    }

    /// Lowers the capacity toward the logical `target`.
    ///
    /// The decrease is rounded up to whole physical permits. The resulting capacity can
    /// therefore end below `target` by less than one `basis`; it never drops below the
    /// configured minimum. Returns `None` if the capacity is already at or below `target`.
    pub fn decrement_permits_to_target(&self, target: u64) -> Option<u64> {
        let lock = self.adjustment_lock.lock().unwrap();
        let current = self.total_permits();
        if target >= current {
            return None;
        }
        let requested_decrease = current - target;
        self.decrement_total_permits_impl(lock, requested_decrease)
    }

    /// Shared decrement logic; `_lock` keeps `adjustment_lock` held for the whole call.
    fn decrement_total_permits_impl(&self, _lock: MutexGuard<'_, ()>, n: u64) -> Option<u64> {
        let physical_n = n.div_ceil(self.basis);
        if physical_n == 0 {
            return None;
        }
        let removed = attempt_sub(&self.total_permits, physical_n, self.min_physical_permits);
        if removed == 0 {
            return None;
        }
        // Retire whatever is free right now. Only the still-held remainder is enqueued.
        // Enqueueing everything whenever *some* permits were free would leave those free
        // permits in circulation. New acquirers could then exceed the reduced capacity until
        // later drops retire the backlog.
        let retired_now = self.retire_free_permits(removed);
        let still_held = removed - retired_now;
        if still_held > 0 {
            // Retired by `AdjustableSemaphorePermit::drop`, or cancelled by a later increment.
            // NOTE(review): a permit dropped between `retire_free_permits` and this `fetch_add`
            // sees no pending decrease and returns its permits. The decrease is then retired
            // only by a later drop.
            self.enqueued_permit_decreases.fetch_add(still_held, Ordering::SeqCst);
        }
        Some(removed.saturating_mul(self.basis))
    }

    /// Takes up to `limit` currently-free physical permits out of circulation and returns how
    /// many were taken. Only called from `decrement_total_permits_impl`, under
    /// `adjustment_lock`.
    fn retire_free_permits(&self, limit: u64) -> u64 {
        // `limit` is at most the physical capacity, so it already fits in u32; the extra
        // `PERMIT_LIMIT` cap just makes the cast below obviously lossless.
        let free = (self.semaphore.available_permits() as u64).min(limit).min(PERMIT_LIMIT);
        if free == 0 {
            return 0;
        }
        match self.semaphore.try_acquire_many(free as u32) {
            Ok(permit) => {
                permit.forget();
                free
            }
            // Lost a race with a concurrent acquirer: the caller enqueues the whole amount.
            Err(_) => 0,
        }
    }

    /// Raises the capacity by `n` logical permits (rounded up to whole physical permits),
    /// never exceeding the configured maximum.
    ///
    /// The added capacity first cancels pending decreases. The rest is returned as a virtual
    /// permit; its physical permits are handed to the semaphore when it is dropped, minus any
    /// decreases pending at that time. Returns `None` if nothing could be added.
    pub fn increment_total_permits(self: &Arc<Self>, n: u64) -> Option<AdjustableSemaphorePermit> {
        let lock = self.adjustment_lock.lock().unwrap();
        self.increment_total_permits_impl(lock, n)
    }

    /// Raises the capacity toward the logical `target`, rounded up to whole physical permits
    /// and capped at the maximum. Returns `None` if the capacity is already at or above
    /// `target`.
    pub fn increment_permits_to_target(self: &Arc<Self>, target: u64) -> Option<AdjustableSemaphorePermit> {
        let lock = self.adjustment_lock.lock().unwrap();
        let current = self.total_permits();
        if target <= current {
            return None;
        }
        self.increment_total_permits_impl(lock, target - current)
    }

    /// Shared increment logic; `_lock` keeps `adjustment_lock` held for the whole call.
    fn increment_total_permits_impl(
        self: &Arc<Self>,
        _lock: MutexGuard<'_, ()>,
        n: u64,
    ) -> Option<AdjustableSemaphorePermit> {
        let physical_n = n.div_ceil(self.basis);
        if physical_n == 0 {
            return None;
        }
        let added = attempt_add(&self.total_permits, physical_n, self.max_physical_permits);
        if added == 0 {
            return None;
        }
        // Permits still owed to pending decreases simply stay in circulation instead.
        let cancelled = attempt_sub(&self.enqueued_permit_decreases, added, 0);
        // `added` <= max physical permits <= PERMIT_LIMIT <= u32::MAX.
        let to_hold = (added - cancelled) as u32;
        Some(AdjustableSemaphorePermit {
            permit: None,
            num_physical_permits: to_hold,
            parent: self.clone(),
        })
    }

    /// Smallest power-of-two basis for which `max_permits` rounds to at most `PERMIT_LIMIT`
    /// physical permits.
    fn compute_basis(max_permits: u64) -> u64 {
        let mut basis: u64 = 1;
        while max_permits.div_ceil(basis) > PERMIT_LIMIT {
            basis *= 2;
        }
        basis
    }

    /// Converts a logical request into a physical permit count clamped to
    /// `[1, current physical capacity]`.
    fn to_physical_acquire(&self, n: u64) -> u32 {
        let total = self.total_permits.load(Ordering::Relaxed).max(1);
        // `total` <= max physical permits <= PERMIT_LIMIT <= u32::MAX, so the cast is lossless.
        n.div_ceil(self.basis).clamp(1, total) as u32
    }

    /// Test-only constructor with an explicit `basis`. It skips the range assertions of
    /// [`AdjustableSemaphore::new`] and caps the physical maximum at `PERMIT_LIMIT`.
    #[cfg(test)]
    fn with_forced_basis(initial: u64, min: u64, max: u64, basis: u64) -> Arc<Self> {
        assert!(basis > 0, "basis must be greater than zero");
        let min_physical_permits = min.div_ceil(basis);
        let max_physical_permits = max.div_ceil(basis).min(PERMIT_LIMIT);
        let initial_physical = initial.div_ceil(basis).clamp(min_physical_permits, max_physical_permits);
        Arc::new(Self {
            semaphore: Arc::new(Semaphore::new(initial_physical as usize)),
            total_permits: AtomicU64::new(initial_physical),
            enqueued_permit_decreases: AtomicU64::new(0),
            min_physical_permits,
            max_physical_permits,
            basis,
            adjustment_lock: Mutex::new(()),
        })
    }
}
/// Atomically adds up to `n` to `v` without letting it exceed `max_value`.
///
/// Returns the amount actually added. The result is 0, and `v` is left unchanged, when `v`
/// is already at or above `max_value`.
#[inline]
fn attempt_add(v: &AtomicU64, n: u64, max_value: u64) -> u64 {
    let capped = |current: u64| current.saturating_add(n).min(max_value);
    v.fetch_update(SeqCst, SeqCst, |current| (current < max_value).then(|| capped(current)))
        .map_or(0, |previous| capped(previous) - previous)
}
/// Atomically subtracts up to `n` from `v` without letting it drop below `min_value`.
///
/// Returns the amount actually subtracted. The result is 0, and `v` is left unchanged, when
/// `v` is already at or below `min_value`.
#[inline]
fn attempt_sub(v: &AtomicU64, n: u64, min_value: u64) -> u64 {
    let floored = |current: u64| current.saturating_sub(n).max(min_value);
    v.fetch_update(SeqCst, SeqCst, |current| (current > min_value).then(|| floored(current)))
        .map_or(0, |previous| previous - floored(previous))
}
// Unit and stress tests for `AdjustableSemaphore`. Most unit tests run twice, once with
// basis 1 and once with basis 2, to exercise the logical/physical permit conversion.
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use more_asserts::{assert_ge, assert_le};
    use rand::prelude::*;
    use tokio::sync::Barrier;
    use tokio::task::JoinSet;
    use super::*;

    // Increments and decrements (by amount and to a target) stay within [min, max] = [2, 12].
    #[tokio::test]
    async fn test_bounds_and_adjustment() {
        for basis in [1u64, 2] {
            let sem = AdjustableSemaphore::with_forced_basis(6, 2, 12, basis);
            assert_eq!(sem.total_permits(), 6);
            assert!(sem.increment_total_permits(4).is_some());
            assert_eq!(sem.total_permits(), 10);
            // Request beyond the maximum is capped at 12.
            assert!(sem.increment_total_permits(100).is_some());
            assert_eq!(sem.total_permits(), 12);
            // Already at the maximum: nothing to add.
            assert!(sem.increment_total_permits(2).is_none());
            assert_eq!(sem.decrement_total_permits(4), Some(4));
            assert_eq!(sem.total_permits(), 8);
            // Request below the minimum stops at 2, so only 6 are removed.
            assert_eq!(sem.decrement_total_permits(100), Some(6));
            assert_eq!(sem.total_permits(), 2);
            // Already at the minimum: nothing to remove.
            assert!(sem.decrement_total_permits(2).is_none());
            assert!(sem.increment_total_permits(4).is_some());
            assert_eq!(sem.total_permits(), 6);
            assert!(sem.increment_permits_to_target(10).is_some());
            assert_eq!(sem.total_permits(), 10);
            assert!(sem.increment_permits_to_target(10).is_none());
            assert_eq!(sem.decrement_permits_to_target(6), Some(4));
            assert_eq!(sem.total_permits(), 6);
            assert!(sem.decrement_permits_to_target(6).is_none());
            // A target below the minimum stops at the minimum.
            assert_eq!(sem.decrement_permits_to_target(0), Some(4));
            assert_eq!(sem.total_permits(), 2);
            assert!(sem.increment_permits_to_target(12).is_some());
            assert_eq!(sem.total_permits(), 12);
        }
    }

    // Acquiring and dropping real permits moves capacity out of and back into the semaphore.
    #[tokio::test]
    async fn test_acquire_and_release() {
        for basis in [1u64, 2] {
            let sem = AdjustableSemaphore::with_forced_basis(1024, 0, 1024, basis);
            assert_eq!(sem.available_permits(), 1024);
            let p1 = sem.acquire_many(256).await.unwrap();
            assert_eq!(p1.num_permits(), 256);
            assert_eq!(sem.available_permits(), 768);
            let p2 = sem.acquire_many(512).await.unwrap();
            assert_eq!(sem.available_permits(), 256);
            drop(p1);
            assert_eq!(sem.available_permits(), 512);
            drop(p2);
            assert_eq!(sem.available_permits(), 1024);
            {
                // Released when `_p` goes out of scope.
                let _p = sem.acquire_many(1024).await.unwrap();
                assert_eq!(sem.available_permits(), 0);
            }
            assert_eq!(sem.available_permits(), 1024);
            // Requests above the capacity are clamped to the whole capacity instead of hanging.
            let _p = sem.acquire_many(5000).await.unwrap();
            assert_eq!(sem.available_permits(), 0);
        }
    }

    // Decreases that cannot be taken while every permit is held are retired as permits drop.
    #[tokio::test]
    async fn test_enqueued_decrease_resolution() {
        for basis in [1u64, 2] {
            let sem = AdjustableSemaphore::with_forced_basis(4, 2, 6, basis);
            let p1 = sem.acquire_many(2).await.unwrap();
            let p2 = sem.acquire_many(2).await.unwrap();
            assert_eq!(sem.available_permits(), 0);
            assert!(sem.decrement_total_permits(2).is_some());
            assert_eq!(sem.total_permits(), 2);
            // p1's permits pay off the pending decrease, so nothing becomes available.
            drop(p1); assert_eq!(sem.available_permits(), 0);
            drop(p2);
            assert_eq!(sem.available_permits(), 2);
            let sem = AdjustableSemaphore::with_forced_basis(1024, 0, 1024, basis);
            let p = sem.acquire_many(1024).await.unwrap();
            assert!(sem.decrement_total_permits(512).is_some());
            assert_eq!(sem.total_permits(), 512);
            // One drop pays off part of the decrease and releases the rest.
            drop(p);
            assert_eq!(sem.available_permits(), 512);
        }
    }

    // An increment first cancels pending decreases, yielding an empty virtual permit.
    #[tokio::test]
    async fn test_increment_cancels_enqueued() {
        for basis in [1u64, 2] {
            let sem = AdjustableSemaphore::with_forced_basis(4, 0, 10, basis);
            let p1 = sem.acquire_many(2).await.unwrap();
            let p2 = sem.acquire_many(2).await.unwrap();
            assert!(sem.decrement_total_permits(2).is_some());
            assert_eq!(sem.total_permits(), 2);
            let vp = sem.increment_total_permits(2).unwrap();
            assert_eq!(vp.num_permits(), 0);
            assert_eq!(sem.total_permits(), 4);
            drop(vp);
            drop(p1);
            assert_eq!(sem.available_permits(), 2);
            drop(p2);
            assert_eq!(sem.available_permits(), 4);
        }
    }

    // Capacity added by an increment becomes available only when the virtual permit drops.
    #[tokio::test]
    async fn test_virtual_permit() {
        for basis in [1u64, 2] {
            let sem = AdjustableSemaphore::with_forced_basis(4, 0, 20, basis);
            let vp = sem.increment_total_permits(6).unwrap();
            assert_eq!(sem.total_permits(), 10);
            assert_eq!(vp.num_permits(), 6);
            assert_eq!(sem.available_permits(), 4);
            drop(vp);
            assert_eq!(sem.available_permits(), 10);
            let sem = AdjustableSemaphore::with_forced_basis(0, 0, 22, basis);
            let mut permits = Vec::new();
            for i in 0..10u64 {
                assert_eq!(sem.available_permits(), 0);
                assert_eq!(sem.total_permits(), i * 2);
                // The returned virtual permit is dropped immediately, releasing its capacity.
                sem.increment_total_permits(2);
                permits.push(sem.acquire_many(2).await.unwrap());
            }
            for i in 0..10u64 {
                assert_eq!(sem.available_permits(), i * 2);
                permits.pop();
            }
        }
    }

    // Splitting a real permit divides its physical permits; over-large splits are refused.
    #[tokio::test]
    async fn test_permit_split() {
        for basis in [1u64, 2] {
            let sem = AdjustableSemaphore::with_forced_basis(10, 0, 10, basis);
            let mut p = sem.acquire_many(6).await.unwrap();
            let p2 = p.split(2).unwrap();
            assert_eq!(p.num_permits(), 4);
            assert_eq!(p2.num_permits(), 2);
            drop(p2);
            assert_eq!(sem.available_permits(), 6);
            drop(p);
            assert_eq!(sem.available_permits(), 10);
            // Splitting off everything leaves an empty parent handle.
            let mut p = sem.acquire_many(6).await.unwrap();
            let p2 = p.split(6).unwrap();
            assert_eq!(p.num_permits(), 0);
            assert_eq!(p2.num_permits(), 6);
            drop(p);
            assert_eq!(sem.available_permits(), 4);
            drop(p2);
            assert_eq!(sem.available_permits(), 10);
            let mut p = sem.acquire_many(4).await.unwrap();
            assert!(p.split(6).is_none());
            assert_eq!(p.num_permits(), 4);
            drop(p);
        }
    }

    // Splitting a virtual permit yields virtual pieces that each release capacity on drop.
    #[tokio::test]
    async fn test_virtual_permit_split() {
        for basis in [1u64, 2] {
            let sem = AdjustableSemaphore::with_forced_basis(4, 0, 20, basis);
            let mut vp = sem.increment_total_permits(8).unwrap();
            assert_eq!(sem.total_permits(), 12);
            assert_eq!(sem.available_permits(), 4);
            assert_eq!(vp.num_permits(), 8);
            let vp2 = vp.split(2).unwrap();
            assert_eq!(vp.num_permits(), 6);
            assert_eq!(vp2.num_permits(), 2);
            drop(vp2);
            assert_eq!(sem.available_permits(), 6);
            drop(vp);
            assert_eq!(sem.available_permits(), 12);
        }
    }

    // The basis stays 1 up to PERMIT_LIMIT and doubles just above it.
    #[test]
    fn test_basis_computation() {
        assert_eq!(AdjustableSemaphore::new(1024, (0, 1024)).basis(), 1);
        assert_eq!(AdjustableSemaphore::new(PERMIT_LIMIT, (0, PERMIT_LIMIT)).basis(), 1);
        assert_eq!(AdjustableSemaphore::new(PERMIT_LIMIT + 1, (0, PERMIT_LIMIT + 1)).basis(), 2);
    }

    // Logical capacities round up to whole physical permits (1000 -> 4 x 300).
    #[test]
    fn test_forced_basis_rounding() {
        let sem = AdjustableSemaphore::with_forced_basis(1000, 0, 1000, 300);
        assert_eq!(sem.total_permits(), 1200);
        let sem = AdjustableSemaphore::with_forced_basis(900, 0, 900, 300);
        assert_eq!(sem.total_permits(), 900);
    }

    // Acquire, increment, split and decrement all round logical amounts up to the basis.
    #[tokio::test]
    async fn test_rounding_and_physical_permits() {
        let sem = AdjustableSemaphore::with_forced_basis(1024, 0, 1024, 256);
        let p = sem.acquire_many(1).await.unwrap();
        assert_eq!(p.num_permits(), 256);
        assert_eq!(p.num_physical_permits(), 1);
        assert_eq!(sem.available_permits(), 768);
        drop(p);
        let sem = AdjustableSemaphore::with_forced_basis(1000, 0, 1000, 100);
        let p = sem.acquire_many(250).await.unwrap();
        assert_eq!(p.num_permits(), 300);
        assert_eq!(p.num_physical_permits(), 3);
        drop(p);
        let sem = AdjustableSemaphore::with_forced_basis(1024, 0, 2048, 256);
        let vp = sem.increment_total_permits(512).unwrap();
        assert_eq!(vp.num_permits(), 512);
        assert_eq!(vp.num_physical_permits(), 2);
        drop(vp);
        let sem = AdjustableSemaphore::with_forced_basis(500, 0, 500, 100);
        let mut p = sem.acquire_many(500).await.unwrap();
        let p2 = p.split(1).unwrap();
        assert_eq!(p2.num_permits(), 100);
        assert_eq!(p2.num_physical_permits(), 1);
        assert_eq!(p.num_permits(), 400);
        assert_eq!(p.num_physical_permits(), 4);
        drop(p2);
        drop(p);
        // Decrements never go below the (rounded) minimum of 300.
        let sem = AdjustableSemaphore::with_forced_basis(500, 300, 500, 100);
        assert!(sem.decrement_total_permits(300).is_some());
        assert_eq!(sem.total_permits(), 300);
        assert!(sem.decrement_total_permits(1).is_none());
    }

    // A zero-capacity semaphore can be constructed and reports no permits.
    #[test]
    fn test_zero_capacity() {
        let sem = AdjustableSemaphore::new(0, (0, 0));
        assert_eq!(sem.total_permits(), 0);
        assert_eq!(sem.available_permits(), 0);
    }

    // Many tasks randomly adjust the capacity by 1 while acquiring single permits. The
    // capacity must stay within bounds throughout, and all permits must be free at the end.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_concurrent_stress() {
        const TASKS: usize = 50;
        const OPS_PER_TASK: usize = 1000;
        const MIN_PERMITS: u64 = 10;
        const MAX_PERMITS: u64 = 50;
        let sem = AdjustableSemaphore::new(30, (MIN_PERMITS, MAX_PERMITS));
        let mut js = JoinSet::new();
        // All tasks start together to maximize contention.
        let barrier = Arc::new(Barrier::new(TASKS + 1));
        for t in 0..TASKS {
            let sem = sem.clone();
            let mut rng = SmallRng::seed_from_u64(t as u64);
            let barrier = barrier.clone();
            js.spawn(async move {
                barrier.wait().await;
                for _ in 0..OPS_PER_TASK {
                    if rng.random_bool(0.1) {
                        sem.increment_total_permits(1);
                    }
                    if rng.random_bool(0.1) {
                        let _ = sem.decrement_total_permits(1);
                    }
                    let p = sem.acquire().await;
                    tokio::time::sleep(Duration::from_micros(100)).await;
                    drop(p);
                    assert!(sem.total_permits() >= MIN_PERMITS);
                    assert!(sem.total_permits() <= MAX_PERMITS);
                    assert!(sem.available_permits() <= MAX_PERMITS);
                }
            });
        }
        barrier.wait().await;
        js.join_all().await;
        let final_permits = sem.total_permits();
        assert_le!(final_permits, MAX_PERMITS);
        assert_ge!(final_permits, MIN_PERMITS);
        // NOTE(review): this equality assumes every pending decrease was retired by some later
        // drop. A decrement racing with a task's final drop could leave one pending; confirm
        // this cannot flake.
        let avail_permits = sem.available_permits();
        assert_eq!(avail_permits, final_permits);
    }

    // Like `test_concurrent_stress`, but with multi-permit acquisitions and larger adjustments.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_concurrent_stress_acquire_many() {
        const TASKS: usize = 30;
        const OPS_PER_TASK: usize = 500;
        const MIN_PERMITS: u64 = 100;
        const MAX_PERMITS: u64 = 500;
        let sem = AdjustableSemaphore::new(300, (MIN_PERMITS, MAX_PERMITS));
        let mut js = JoinSet::new();
        let barrier = Arc::new(Barrier::new(TASKS + 1));
        for t in 0..TASKS {
            let sem = sem.clone();
            let mut rng = SmallRng::seed_from_u64(t as u64);
            let barrier = barrier.clone();
            js.spawn(async move {
                barrier.wait().await;
                for _ in 0..OPS_PER_TASK {
                    if rng.random_bool(0.05) {
                        sem.increment_total_permits(rng.random_range(1..=10));
                    }
                    if rng.random_bool(0.05) {
                        let _ = sem.decrement_total_permits(rng.random_range(1..=10));
                    }
                    // Requests (<= 50) never exceed MIN_PERMITS, so no request can outgrow the
                    // capacity while it waits.
                    let amount = rng.random_range(1..=50);
                    let p = sem.acquire_many(amount).await;
                    tokio::time::sleep(Duration::from_micros(50)).await;
                    drop(p);
                    assert!(sem.total_permits() >= MIN_PERMITS);
                    assert!(sem.total_permits() <= MAX_PERMITS);
                }
            });
        }
        barrier.wait().await;
        js.join_all().await;
        let final_permits = sem.total_permits();
        assert_le!(final_permits, MAX_PERMITS);
        assert_ge!(final_permits, MIN_PERMITS);
    }
}