use std::sync::atomic::{AtomicUsize, Ordering};
/// Error reported when a [`ConcurrencyGate`] has no free slot.
///
/// The `Display` text is produced from the `#[error(...)]` attribute below and
/// includes the configured limit.
#[derive(Debug, thiserror::Error)]
#[error("max_concurrent ({max_concurrent}) reached")]
pub struct CapacityExceededError {
/// The limit of the gate that rejected the request.
pub max_concurrent: usize,
}
/// Non-blocking limiter that tracks how many slots are currently held against
/// a fixed upper bound, using an atomic counter.
pub struct ConcurrencyGate {
// Upper bound on the number of slots that may be held at once.
max_concurrent: usize,
// Number of slots currently held; incremented on acquire, decremented on release.
active: AtomicUsize,
}
impl ConcurrencyGate {
    /// Creates a gate that allows at most `max_concurrent` slots to be held at once.
    ///
    /// # Panics
    ///
    /// Panics if `max_concurrent` is zero.
    pub fn new(max_concurrent: usize) -> Self {
        assert!(max_concurrent >= 1, "max_concurrent must be >= 1");
        Self {
            max_concurrent,
            active: AtomicUsize::new(0),
        }
    }

    /// Tries to take one slot without blocking.
    ///
    /// On success the caller owns a slot and must hand it back with
    /// [`release`](Self::release).
    ///
    /// # Errors
    ///
    /// Returns [`CapacityExceededError`] when `max_concurrent` slots are already held.
    pub fn acquire(&self) -> Result<(), CapacityExceededError> {
        let limit = self.max_concurrent;
        // `fetch_update` retries the compare-exchange for us; returning `None`
        // from the closure aborts without modifying the counter.
        // `cur + 1` cannot overflow because `cur < limit <= usize::MAX`.
        self.active
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| {
                if cur < limit {
                    Some(cur + 1)
                } else {
                    None
                }
            })
            .map(|_| ())
            .map_err(|_| CapacityExceededError {
                max_concurrent: limit,
            })
    }

    /// Returns one slot previously obtained with [`acquire`](Self::acquire).
    ///
    /// An unmatched call (more releases than acquires) is a caller bug: it trips
    /// a debug assertion in builds with debug assertions enabled. In all builds
    /// the counter is left at zero instead of wrapping around, so a stray
    /// release can no longer make the gate reject every later `acquire`.
    pub fn release(&self) {
        // `checked_sub` yields `None` at zero, which leaves the counter untouched.
        let released = self
            .active
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| cur.checked_sub(1));
        debug_assert!(released.is_ok(), "release called more times than acquire");
    }

    /// Runs `fut` while holding one slot of the gate.
    ///
    /// The slot is taken when the returned future is first polled and is given
    /// back when `fut` finishes or when the returned future is dropped before
    /// completion.
    ///
    /// # Errors
    ///
    /// Returns [`CapacityExceededError`] without polling `fut` if no slot is free.
    pub async fn run<F, T>(&self, fut: F) -> Result<T, CapacityExceededError>
    where
        F: std::future::Future<Output = T>,
    {
        /// Gives the slot back when dropped, covering both completion and cancellation.
        struct ReleaseOnDrop<'a>(&'a ConcurrencyGate);

        impl Drop for ReleaseOnDrop<'_> {
            fn drop(&mut self) {
                self.0.release();
            }
        }

        self.acquire()?;
        let _guard = ReleaseOnDrop(self);
        Ok(fut.await)
    }

    /// Returns the number of slots currently held.
    pub fn active_jobs(&self) -> usize {
        self.active.load(Ordering::Acquire)
    }

    /// Returns the configured maximum number of simultaneously held slots.
    pub fn max(&self) -> usize {
        self.max_concurrent
    }

    /// Returns the fraction of slots in use (`active_jobs / max`).
    pub fn load(&self) -> f64 {
        // `new` rejects a zero limit; this guard is kept so the division below
        // can never produce NaN even if that invariant is bypassed.
        if self.max_concurrent == 0 {
            return 1.0;
        }
        self.active_jobs() as f64 / self.max_concurrent as f64
    }
}