use futures_intrusive::sync::{GenericSemaphoreReleaser, Semaphore};
use parking_lot::RawMutex;
use std::{future::Future, sync::Arc};
/// A single permit taken from a [`CapacityGate`] via
/// [`CapacityGate::acquire_owned`].
///
/// Unlike the borrowing guard returned by [`CapacityGate::acquire`], this guard holds
/// an `Arc<CapacityGate>`, so its lifetime is not tied to a borrow of the gate. The
/// permit is handed back to the gate (through [`CapacityGate::release`]) when the
/// guard is dropped.
#[derive(Debug)]
#[must_use = "dropping the guard immediately returns the permit to the gate"]
pub struct OwnedPermitGuard {
    /// Gate the permit was taken from; kept alive for as long as the permit is held.
    gate: Arc<CapacityGate>,
}
impl Drop for OwnedPermitGuard {
fn drop(&mut self) {
self.gate.release();
}
}
/// Async gate that limits how many holders may proceed at once, using a semaphore
/// that starts out with `capacity` permits.
///
/// Permits are taken with `acquire`, `try_acquire`, or `acquire_owned`. They are
/// returned when the corresponding guard is dropped.
#[derive(Debug)]
pub struct CapacityGate {
    // Number of permits the gate was created with; not updated afterwards.
    capacity: usize,
    // Semaphore that tracks the permits currently available.
    semaphore: Semaphore,
}
/// Borrowing guard produced by [`CapacityGate::acquire`] and
/// [`CapacityGate::try_acquire`]. It borrows the gate's semaphore for `'gate` and
/// gives its permit back when dropped.
pub(crate) type PermitGuard<'gate> = GenericSemaphoreReleaser<'gate, RawMutex>;
impl CapacityGate {
    /// Creates a gate that starts with `capacity` available permits.
    ///
    /// The semaphore is created in fair mode (`is_fair = true`), so waiters are
    /// served in the order they started waiting. With a `capacity` of `0`, no
    /// acquisition can succeed until [`release`](Self::release) is called.
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity,
            semaphore: Semaphore::new(true, capacity),
        }
    }

    /// Waits for one permit and returns a guard that borrows this gate.
    ///
    /// The permit goes back to the gate when the returned guard is dropped.
    pub fn acquire(&self) -> impl Future<Output = PermitGuard<'_>> {
        self.semaphore.acquire(1)
    }

    /// Waits for one permit and returns an [`OwnedPermitGuard`] that keeps the gate
    /// alive through its `Arc` instead of borrowing it.
    ///
    /// Dropping the returned guard calls [`release`](Self::release). There is no
    /// await point between acquiring the permit and building the owned guard, so
    /// cancelling this future cannot leak a permit.
    pub fn acquire_owned(self: Arc<Self>) -> impl Future<Output = OwnedPermitGuard> {
        async move {
            // Scope the borrowing releaser so its borrow of `self` ends before
            // `self` is moved into the owned guard below.
            {
                let mut releaser = self.semaphore.acquire(1).await;
                // Responsibility for returning the permit passes to
                // `OwnedPermitGuard`. Disarm the releaser (the documented
                // alternative to `mem::forget`) so its own drop does not
                // return the permit early.
                let disarmed = releaser.disarm();
                debug_assert_eq!(disarmed, 1);
            }
            OwnedPermitGuard { gate: self }
        }
    }

    /// Takes one permit without waiting, or returns `None` if none are available.
    #[must_use = "dropping the guard immediately returns the permit to the gate"]
    pub fn try_acquire(&self) -> Option<PermitGuard<'_>> {
        self.semaphore.try_acquire(1)
    }

    /// Returns one permit to the gate.
    ///
    /// This is the counterpart of [`acquire_owned`](Self::acquire_owned) and is what
    /// [`OwnedPermitGuard`] calls on drop. Calling it without holding a disarmed
    /// permit can raise the number of available permits above
    /// [`capacity`](Self::capacity).
    pub fn release(&self) {
        self.semaphore.release(1);
    }

    /// Returns the number of permits currently available to acquire.
    pub fn get_permits(&self) -> usize {
        self.semaphore.permits()
    }

    /// Returns the number of permits the gate was created with.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    #[tokio::test]
    async fn new_gate_has_correct_initial_permits() {
        let gate = CapacityGate::new(5);
        assert_eq!(gate.capacity(), 5);
        assert_eq!(gate.get_permits(), 5);
    }

    #[tokio::test]
    async fn acquire_and_release_on_drop() {
        let gate = CapacityGate::new(2);
        let p1 = gate.acquire().await;
        let p2 = gate.acquire().await;
        assert_eq!(gate.get_permits(), 0);
        drop(p1);
        assert_eq!(gate.get_permits(), 1);
        drop(p2);
        assert_eq!(gate.get_permits(), 2);
    }

    #[tokio::test]
    async fn try_acquire_respects_capacity() {
        let gate = CapacityGate::new(1);
        let permit = gate.try_acquire().expect("a fresh gate of capacity 1 has a free permit");
        assert!(gate.try_acquire().is_none());
        drop(permit);
        // The temporary guard is dropped at the end of this statement.
        assert!(gate.try_acquire().is_some());
        assert_eq!(gate.get_permits(), 1);
    }

    #[tokio::test]
    async fn acquire_owned_releases_on_drop() {
        let gate = Arc::new(CapacityGate::new(1));
        let owned = Arc::clone(&gate).acquire_owned().await;
        assert_eq!(gate.get_permits(), 0);
        assert!(gate.try_acquire().is_none());
        drop(owned);
        assert_eq!(gate.get_permits(), 1);
    }

    /// Many tasks contend for a small gate. At no point may more than `CAPACITY` of
    /// them hold a permit, and every permit must be back once all tasks finish.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_acquire() {
        const CAPACITY: usize = 4;
        const NUM_TASKS: usize = 64;
        let gate = Arc::new(CapacityGate::new(CAPACITY));
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let handles: Vec<_> = (0..NUM_TASKS)
            .map(|_| {
                let gate = Arc::clone(&gate);
                let in_flight = Arc::clone(&in_flight);
                let peak = Arc::clone(&peak);
                tokio::spawn(async move {
                    let _permit = gate.acquire().await;
                    let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(20)).await;
                    // Leave the critical section before `_permit` is dropped at the
                    // end of this block.
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
        let observed_peak = peak.load(Ordering::SeqCst);
        assert!(observed_peak >= 1);
        assert!(
            observed_peak <= CAPACITY,
            "{observed_peak} tasks held permits at once (capacity {CAPACITY})"
        );
        assert_eq!(gate.get_permits(), CAPACITY);
    }
}