use std::sync::Arc;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::CancellationToken;
/// Caps how many units of work may hold a permit at the same time.
///
/// Permits come from a single tokio [`Semaphore`] shared behind an [`Arc`]; each permit is an
/// [`OwnedSemaphorePermit`] that returns its slot to the semaphore when dropped.
#[derive(Debug)]
pub struct Scheduler {
    /// Shared semaphore; created with exactly `max` permits.
    permits: Arc<Semaphore>,
    /// Effective concurrency cap chosen by [`Scheduler::new`].
    max: usize,
}

impl Scheduler {
    /// Builds a scheduler that lets at most `parallel` permits be held at once.
    ///
    /// `parallel == 0` means "unbounded": the cap becomes [`Semaphore::MAX_PERMITS`] and a
    /// warning is emitted under the `grex::scheduler` tracing target. Any other request is
    /// clamped to [`Semaphore::MAX_PERMITS`], the largest count a tokio semaphore accepts.
    #[must_use]
    pub fn new(parallel: usize) -> Self {
        let max = match parallel {
            0 => {
                tracing::warn!(
                    target: "grex::scheduler",
                    max_permits = Semaphore::MAX_PERMITS,
                    "scheduler: parallel=0 → unbounded; set --parallel N to cap concurrency"
                );
                Semaphore::MAX_PERMITS
            }
            requested => requested.min(Semaphore::MAX_PERMITS),
        };
        let permits = Arc::new(Semaphore::new(max));
        Self { permits, max }
    }

    /// Returns the concurrency cap fixed at construction time.
    #[must_use]
    pub fn max_parallelism(&self) -> usize {
        self.max
    }

    /// Returns another handle to the scheduler's underlying semaphore.
    ///
    /// The handle shares permits with this scheduler. Do not close it:
    /// [`Scheduler::acquire`] treats a closed semaphore as an invariant violation and panics.
    #[must_use]
    pub fn permits(&self) -> Arc<Semaphore> {
        Arc::clone(&self.permits)
    }

    /// Waits until a permit is free and returns it; dropping the permit frees the slot.
    ///
    /// # Panics
    ///
    /// Panics if the underlying semaphore has been closed (for example through a handle from
    /// [`Scheduler::permits`]). Nothing in this type ever closes it.
    pub async fn acquire(&self) -> OwnedSemaphorePermit {
        let pending = Arc::clone(&self.permits).acquire_owned();
        pending.await.expect("scheduler semaphore never closes")
    }

    /// Like [`Scheduler::acquire`], but stops waiting as soon as `cancel` fires.
    ///
    /// # Errors
    ///
    /// Returns [`Cancelled`] if `cancel` is, or becomes, cancelled before a permit is granted.
    /// The token is polled first, so an already-cancelled token yields `Err(Cancelled)` even
    /// when a permit is free. A closed semaphore is also reported as [`Cancelled`] instead of
    /// panicking.
    pub async fn acquire_cancellable(
        &self,
        cancel: &CancellationToken,
    ) -> Result<OwnedSemaphorePermit, Cancelled> {
        let pending = Arc::clone(&self.permits).acquire_owned();
        tokio::select! {
            // `biased` polls branches top to bottom, so cancellation wins any tie.
            biased;
            () = cancel.cancelled() => Err(Cancelled),
            outcome = pending => outcome.map_err(|_closed| Cancelled),
        }
    }
}
/// Error returned by `Scheduler::acquire_cancellable` when no permit was handed out.
///
/// Produced when the caller's cancellation token fires before a permit becomes available
/// (cancellation wins ties). It is also produced when the scheduler's semaphore has been closed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Cancelled;

// `Display` + `Error` let callers box this (`Box<dyn Error>`) or propagate it with `?` into
// generic error types, which a bare unit struct cannot do.
impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("scheduler permit acquisition was cancelled")
    }
}

impl std::error::Error for Cancelled {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio_util::sync::CancellationToken;

    /// Upper bound that only turns a hang into a test failure. It is deliberately generous so a
    /// loaded CI machine cannot make timing assertions flaky; a correct scheduler resolves long
    /// before this elapses.
    const HANG_GUARD: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn acquire_cancellable_returns_permit_when_not_cancelled() {
        let s = Scheduler::new(2);
        let token = CancellationToken::new();
        let result = s.acquire_cancellable(&token).await;
        assert!(result.is_ok(), "expected Ok(permit) on uncontended scheduler");
    }

    #[tokio::test]
    async fn acquire_cancellable_returns_cancelled_if_token_fires_before_permit() {
        let s = Scheduler::new(1);
        // Hold the only permit so the waiter below can only finish via cancellation.
        let _hold = s.acquire().await;
        let token = CancellationToken::new();
        let cancel_handle = token.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            cancel_handle.cancel();
        });
        let result = tokio::time::timeout(HANG_GUARD, s.acquire_cancellable(&token))
            .await
            .expect("acquire_cancellable must resolve promptly after cancel");
        assert!(matches!(result, Err(Cancelled)), "expected Err(Cancelled)");
    }

    #[tokio::test]
    async fn acquire_cancellable_dropped_future_does_not_leak_permit() {
        const CAP: usize = 4;
        const WAITERS: usize = 100;
        let s = Arc::new(Scheduler::new(CAP));
        // Exhaust every permit so each waiter below has to queue on the semaphore.
        let mut held = Vec::with_capacity(CAP);
        for _ in 0..CAP {
            held.push(s.acquire().await);
        }
        let mut tokens = Vec::with_capacity(WAITERS);
        let mut handles = Vec::with_capacity(WAITERS);
        for _ in 0..WAITERS {
            let token = CancellationToken::new();
            let token_for_task = token.clone();
            let sched = Arc::clone(&s);
            handles.push(tokio::spawn(async move {
                sched.acquire_cancellable(&token_for_task).await.is_err()
            }));
            tokens.push(token);
        }
        // Let the waiters enqueue before cancelling so the dropped-waiter path is exercised;
        // the assertions below hold even if some waiters have not been polled yet.
        tokio::time::sleep(Duration::from_millis(5)).await;
        for token in &tokens {
            token.cancel();
        }
        for h in handles {
            let cancelled = h.await.expect("waiter task panicked");
            // A waiter that wrongly got a permit would drop it at once and leave the final
            // count untouched, so check each outcome explicitly.
            assert!(cancelled, "no waiter may obtain a permit while all permits are held");
        }
        drop(held);
        assert_eq!(s.permits.available_permits(), CAP, "cancelled waiters must not leak permits");
    }

    #[tokio::test]
    async fn acquire_cancellable_cancel_after_success_is_no_op() {
        let s = Scheduler::new(1);
        let token = CancellationToken::new();
        let permit = s.acquire_cancellable(&token).await.expect("permit");
        token.cancel();
        assert_eq!(s.permits.available_permits(), 0, "permit still held");
        drop(permit);
        assert_eq!(s.permits.available_permits(), 1, "permit released on drop");
    }
}