use async_trait::async_trait;
use parking_lot::{Condvar, Mutex};
use std::sync::Arc;
/// Admission control measured in bytes.
///
/// An implementation hands out a [`Permit`] for each request. Whatever the
/// permit reserves is released when the permit is dropped.
#[async_trait]
pub trait BytePermits: Send + Sync {
    /// Reserves `n_bytes`, as interpreted by the implementation, waiting if the
    /// implementation requires it, and returns the permit holding the reservation.
    async fn acquire(&self, n_bytes: usize) -> Permit;
}
/// Mutable budget state shared by a [`SemaphorePermits`] and the permits it issues.
struct SemInner {
    /// Bytes not currently held by any outstanding permit.
    available: usize,
    /// Total budget. A single request larger than this is clamped to it.
    max_bytes: usize,
}
/// A byte reservation obtained from [`BytePermits::acquire`].
///
/// The reservation lasts exactly as long as this value. Dropping it hands the
/// bytes back, so a permit that is discarded right away limits nothing.
#[must_use = "dropping a Permit immediately releases its byte reservation"]
pub struct Permit {
    /// `Some` from construction until `Drop` takes it to release the reservation.
    inner: Option<PermitInner>,
}
/// What a [`Permit`] actually holds.
enum PermitInner {
    /// Bytes deducted from a semaphore budget: the shared state to return
    /// them to, and how many bytes were deducted.
    ByteSem(Arc<(Mutex<SemInner>, Condvar)>, usize),
    /// Holds nothing, so releasing it does nothing.
    NoOp,
}
impl Drop for Permit {
    /// Returns a semaphore-backed reservation to its budget and wakes waiters.
    ///
    /// A no-op permit, or one whose state was already taken, carries nothing,
    /// so dropping it does no work.
    fn drop(&mut self) {
        if let Some(PermitInner::ByteSem(shared, bytes)) = self.inner.take() {
            let (mutex, condvar) = &*shared;
            let mut state = mutex.lock();
            state.available += bytes;
            // Waiters may each need a different amount, so wake all of them and
            // let each one re-check whether the freed bytes are enough.
            condvar.notify_all();
        }
    }
}
impl Permit {
    /// Builds a permit that reserves nothing and releases nothing.
    pub(crate) const fn noop() -> Self {
        let inner = Some(PermitInner::NoOp);
        Self { inner }
    }

    /// Wraps `bytes` that the caller has already deducted from `shared`.
    /// The bytes are added back to the budget when the permit is dropped.
    fn byte_sem(shared: Arc<(Mutex<SemInner>, Condvar)>, bytes: usize) -> Self {
        let held = PermitInner::ByteSem(shared, bytes);
        Self { inner: Some(held) }
    }
}
/// A [`BytePermits`] implementation that never limits anything.
///
/// Every request, whatever its size, is granted at once with a permit that
/// holds nothing.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoOpPermits;
#[async_trait]
impl BytePermits for NoOpPermits {
async fn acquire(&self, _n_bytes: usize) -> Permit {
Permit::noop()
}
}
/// A [`BytePermits`] implementation backed by a byte-counting semaphore.
///
/// Cloning shares the underlying budget: all clones draw from the same pool.
#[derive(Clone)]
pub struct SemaphorePermits {
    /// Budget state, plus the condition variable that waiting acquirers block on.
    inner: Arc<(Mutex<SemInner>, Condvar)>,
}
impl SemaphorePermits {
    /// Creates a limiter that allows at most `max_bytes` bytes to be reserved
    /// at once. The whole budget is available immediately.
    #[must_use]
    pub fn new(max_bytes: usize) -> Self {
        let state = SemInner {
            available: max_bytes,
            max_bytes,
        };
        let inner = Arc::new((Mutex::new(state), Condvar::new()));
        Self { inner }
    }
}
#[async_trait]
impl BytePermits for SemaphorePermits {
async fn acquire(&self, n_bytes: usize) -> Permit {
if n_bytes == 0 {
return Permit::noop();
}
let inner = self.inner.clone();
let actual = compio::runtime::spawn_blocking(move || {
let (mutex, condvar) = &*inner;
let mut guard = mutex.lock();
let claim = n_bytes.min(guard.max_bytes);
while guard.available < claim {
condvar.wait(&mut guard);
}
guard.available -= claim;
claim
})
.await;
Permit::byte_sem(self.inner.clone(), actual)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads how many bytes the semaphore currently has unreserved.
    fn available(permits: &SemaphorePermits) -> usize {
        permits.inner.0.lock().available
    }

    #[test]
    fn noop_permits_always_succeed() {
        let permits = NoOpPermits;
        let rt = compio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let _p1 = permits.acquire(1024).await;
            let _p2 = permits.acquire(1_000_000).await;
        });
    }

    #[test]
    fn semaphore_permits_enforce_limit() {
        let permits = SemaphorePermits::new(1024);
        let rt = compio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let p1 = permits.acquire(1024).await;
            assert_eq!(available(&permits), 0);
            drop(p1);
            assert_eq!(available(&permits), 1024);
            let _p2 = permits.acquire(512).await;
            assert_eq!(available(&permits), 512);
            let _p3 = permits.acquire(512).await;
            assert_eq!(available(&permits), 0);
        });
    }

    #[test]
    fn semaphore_permits_release_on_drop() {
        let permits = SemaphorePermits::new(1000);
        let rt = compio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            {
                let _p1 = permits.acquire(500).await;
                let _p2 = permits.acquire(500).await;
                assert_eq!(available(&permits), 0);
            }
            assert_eq!(available(&permits), 1000);
            let _p3 = permits.acquire(1000).await;
            assert_eq!(available(&permits), 0);
        });
    }

    #[test]
    fn semaphore_permits_oversized_acquire_does_not_deadlock() {
        let permits = SemaphorePermits::new(1024);
        let rt = compio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let permit = permits.acquire(2048).await;
            // The oversized request is clamped to the whole budget.
            assert_eq!(available(&permits), 0);
            drop(permit);
            assert_eq!(available(&permits), 1024);
            let _p = permits.acquire(1024).await;
        });
    }

    #[test]
    fn semaphore_permits_single_atomic_acquire() {
        let permits = SemaphorePermits::new(1024 * 1024);
        let rt = compio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let permit = permits.acquire(512 * 1024).await;
            assert_eq!(available(&permits), 512 * 1024);
            drop(permit);
            assert_eq!(available(&permits), 1024 * 1024);
        });
    }

    #[test]
    fn semaphore_permits_zero_bytes_is_noop() {
        let permits = SemaphorePermits::new(16);
        let rt = compio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let _p = permits.acquire(0).await;
            assert_eq!(available(&permits), 16);
        });
    }
}