use std::future::Future;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::{Semaphore, watch};
use tokio::task::JoinSet;
/// Limits applied by a [`JobDispatcher`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JobDispatcherConfig {
    /// Maximum number of jobs allowed to run at the same time.
    ///
    /// A value of `0` means no job can ever acquire a slot, so every spawn attempt is
    /// rejected.
    pub max_concurrent_jobs: usize,
    /// Largest result size a caller may declare via
    /// [`JobDispatcher::try_spawn_with_budget`]; larger declarations are rejected.
    pub max_result_bytes: u64,
}
/// Result of asking a [`JobDispatcher`] to start a job.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DispatchOutcome {
    /// The job was handed to the runtime.
    Spawned,
    /// The job was not run: either no concurrency slot was free, or its declared
    /// result size exceeded the configured limit.
    OverBudget,
}
/// Runs async jobs on Tokio with a cap on how many may run at once.
///
/// Concurrency is bounded by a semaphore sized from
/// [`JobDispatcherConfig::max_concurrent_jobs`]. Each job is handed a
/// `watch::Receiver<bool>` that is set to `true` when shutdown is requested.
// NOTE(review): field declaration order is also drop order; keep `tasks` where it is
// unless the drop sequence has been reviewed.
pub struct JobDispatcher {
/// Limits supplied at construction.
config: JobDispatcherConfig,
/// One permit per concurrently running job.
semaphore: Arc<Semaphore>,
/// Publishes `true` to request shutdown.
shutdown_tx: watch::Sender<bool>,
/// Receiver cloned into each job; holding it here also keeps the channel open.
shutdown_rx: watch::Receiver<bool>,
/// Number of spawned jobs still counted as running; exposed via `in_flight()`.
in_flight: Arc<AtomicUsize>,
/// Handles of spawned jobs. A std `Mutex` is used; it must never be held across `.await`.
tasks: Mutex<JoinSet<()>>,
}
impl JobDispatcher {
    /// Creates a dispatcher that admits at most `config.max_concurrent_jobs` jobs at once.
    ///
    /// # Panics
    ///
    /// Panics if `config.max_concurrent_jobs` exceeds Tokio's `Semaphore::MAX_PERMITS`.
    pub fn new(config: JobDispatcherConfig) -> Self {
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_jobs));
        Self {
            config,
            semaphore,
            shutdown_tx,
            shutdown_rx,
            in_flight: Arc::new(AtomicUsize::new(0)),
            tasks: Mutex::new(JoinSet::new()),
        }
    }

    /// Spawns `job` on the current Tokio runtime if a concurrency slot is free.
    ///
    /// `job` receives a clone of the dispatcher's shutdown receiver. The `Result` it
    /// produces is discarded.
    ///
    /// Returns [`DispatchOutcome::OverBudget`] without running `job` when every slot is
    /// taken; otherwise returns [`DispatchOutcome::Spawned`].
    ///
    /// NOTE(review): jobs are still accepted after [`JobDispatcher::shutdown_and_drain`];
    /// they receive a receiver whose current value is already `true`.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (propagated from `JoinSet::spawn`).
    pub fn try_spawn<F, Fut>(&self, job: F) -> DispatchOutcome
    where
        F: FnOnce(watch::Receiver<bool>) -> Fut + Send + 'static,
        Fut: Future<Output = crate::Result<()>> + Send + 'static,
    {
        /// Decrements the shared in-flight counter when dropped. The job's future owns
        /// it, so the decrement happens whether the job completes, panics, or is aborted.
        struct InFlightGuard(Arc<AtomicUsize>);

        impl Drop for InFlightGuard {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::Relaxed);
            }
        }

        let Ok(permit) = Arc::clone(&self.semaphore).try_acquire_owned() else {
            return DispatchOutcome::OverBudget;
        };
        let shutdown_rx = self.shutdown_rx.clone();
        self.in_flight.fetch_add(1, Ordering::Relaxed);
        let in_flight_guard = InFlightGuard(Arc::clone(&self.in_flight));

        let mut tasks = self.tasks.lock().unwrap_or_else(|p| p.into_inner());
        // A JoinSet keeps every finished task until it is joined. Reap them here so a
        // long-lived dispatcher does not retain one entry per completed job forever.
        while tasks.try_join_next().is_some() {}
        tasks.spawn(async move {
            // Both are owned by the future and released when it is dropped. Locals drop
            // in reverse order, so the permit is freed before the counter is decremented.
            let _in_flight = in_flight_guard;
            let _permit = permit;
            // The job's own result is intentionally not surfaced by the dispatcher.
            let _ = job(shutdown_rx).await;
        });
        DispatchOutcome::Spawned
    }

    /// Like [`JobDispatcher::try_spawn`], but first rejects the job with
    /// [`DispatchOutcome::OverBudget`] if `declared_result_bytes` exceeds
    /// `config.max_result_bytes`.
    ///
    /// Only the declared size is checked; nothing here enforces it against the job's
    /// actual output.
    pub fn try_spawn_with_budget<F, Fut>(
        &self,
        declared_result_bytes: u64,
        job: F,
    ) -> DispatchOutcome
    where
        F: FnOnce(watch::Receiver<bool>) -> Fut + Send + 'static,
        Fut: Future<Output = crate::Result<()>> + Send + 'static,
    {
        if declared_result_bytes > self.config.max_result_bytes {
            return DispatchOutcome::OverBudget;
        }
        self.try_spawn(job)
    }

    /// Returns the number of spawned jobs that have not yet finished, panicked, or been
    /// aborted.
    ///
    /// The value is a relaxed snapshot and may already be stale under concurrent spawns.
    pub fn in_flight(&self) -> usize {
        self.in_flight.load(Ordering::Relaxed)
    }

    /// Publishes the shutdown signal, aborts every task spawned so far, and waits until
    /// each of them has stopped.
    ///
    /// Tasks are aborted right after the signal is sent, so jobs get little or no chance
    /// to react to it cooperatively. Jobs spawned concurrently with this call land in a
    /// fresh set and are neither aborted nor awaited here.
    ///
    /// NOTE(review): if a graceful drain is intended, this would need to wait (ideally
    /// with a timeout) before aborting. Confirm the desired semantics.
    pub async fn shutdown_and_drain(&self) {
        // `send` only fails when no receiver exists; `self.shutdown_rx` keeps one alive.
        let _ = self.shutdown_tx.send(true);
        // Swap the set out under the lock so the std mutex is never held across `.await`.
        let mut tasks = {
            let mut guard = self.tasks.lock().unwrap_or_else(|p| p.into_inner());
            std::mem::take(&mut *guard)
        };
        tasks.abort_all();
        // Join results, including cancellation and panic errors, are discarded.
        while tasks.join_next().await.is_some() {}
        // No counter reset. Each task's guard decremented `in_flight` when its future
        // was dropped, and a reset here would clobber jobs spawned in the meantime.
    }

    /// Returns a new receiver for the shutdown signal.
    pub fn shutdown_receiver(&self) -> watch::Receiver<bool> {
        self.shutdown_rx.clone()
    }
}