use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
/// Limits how many recovery operations may run concurrently.
///
/// Cloning is cheap and every clone shares the same underlying semaphore, so
/// permits acquired through any clone count against the same limit.
#[derive(Clone, Debug)]
pub struct RecoveryPermitter {
    /// Shared semaphore that hands out the permits.
    sem: Arc<Semaphore>,
    /// Permit limit the semaphore was created with; used to derive `in_flight`.
    capacity: usize,
}
impl RecoveryPermitter {
pub fn new(max_concurrent: usize) -> Self {
assert!(max_concurrent >= 1, "max_concurrent must be ≥ 1");
Self { sem: Arc::new(Semaphore::new(max_concurrent)), capacity: max_concurrent }
}
pub fn capacity(&self) -> usize {
self.capacity
}
pub fn available(&self) -> usize {
self.sem.available_permits()
}
pub fn in_flight(&self) -> usize {
self.capacity - self.available()
}
pub async fn acquire(&self) -> Option<OwnedSemaphorePermit> {
self.sem.clone().acquire_owned().await.ok()
}
pub fn try_acquire(&self) -> Result<OwnedSemaphorePermit, TryAcquireError> {
self.sem.clone().try_acquire_owned()
}
pub fn close(&self) {
self.sem.close();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn capacity_bounds_concurrent_acquires() {
        let p = RecoveryPermitter::new(2);
        assert_eq!(p.capacity(), 2);
        assert_eq!(p.available(), 2);

        let permit_a = p.acquire().await.unwrap();
        let permit_b = p.acquire().await.unwrap();
        assert_eq!(p.available(), 0);
        assert_eq!(p.in_flight(), 2);

        // A third acquirer must block while both permits are held.
        let p2 = p.clone();
        let h = tokio::spawn(async move { p2.acquire().await });
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(!h.is_finished());

        // Releasing one permit lets the waiting task proceed.
        drop(permit_a);
        let permit_c = h.await.unwrap().unwrap();
        assert_eq!(p.in_flight(), 2);

        drop(permit_b);
        drop(permit_c);
        assert_eq!(p.in_flight(), 0);
    }

    #[tokio::test]
    async fn try_acquire_returns_immediately() {
        let p = RecoveryPermitter::new(1);
        let _held = p.try_acquire().unwrap();
        // Pin the specific error: exhausted, not closed.
        assert!(matches!(p.try_acquire(), Err(TryAcquireError::NoPermits)));
    }

    #[tokio::test]
    async fn close_returns_none_for_pending() {
        let p = RecoveryPermitter::new(1);
        let _held = p.acquire().await.unwrap();

        let p2 = p.clone();
        let h = tokio::spawn(async move { p2.acquire().await });
        tokio::time::sleep(Duration::from_millis(10)).await;

        p.close();
        let r = h.await.unwrap();
        assert!(r.is_none());
        // After closing, the non-blocking path must report `Closed`, not `NoPermits`.
        assert!(matches!(p.try_acquire(), Err(TryAcquireError::Closed)));
    }

    #[test]
    #[should_panic(expected = "max_concurrent must be ≥ 1")]
    fn zero_capacity_panics() {
        let _ = RecoveryPermitter::new(0);
    }
}