use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
/// Multiset of in-flight sequence numbers that can report its smallest member.
///
/// Only strictly positive sequence numbers are tracked: `add` and `remove`
/// ignore `0` and negatives. Duplicates are kept and removed one at a time.
#[derive(Debug, Default)]
pub struct InflightSeqs {
    /// Tracked sequence numbers, kept in ascending order.
    xs: Vec<i64>,
}

impl InflightSeqs {
    /// Returns an empty tracker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Starts tracking `seq`; non-positive values are ignored.
    ///
    /// Insertion shifts later elements to keep the vector sorted.
    pub fn add(&mut self, seq: i64) {
        if seq > 0 {
            let slot = self.xs.partition_point(|&held| held < seq);
            self.xs.insert(slot, seq);
        }
    }

    /// Stops tracking one occurrence of `seq`; absent or non-positive values
    /// are ignored.
    pub fn remove(&mut self, seq: i64) {
        let slot = (seq > 0)
            .then(|| self.xs.binary_search(&seq).ok())
            .flatten();
        if let Some(slot) = slot {
            self.xs.remove(slot);
        }
    }

    /// Smallest tracked sequence number, or `0` when nothing is tracked.
    ///
    /// `0` is unambiguous because only positive values are ever stored.
    pub fn min(&self) -> i64 {
        self.xs.first().map_or(0, |&lowest| lowest)
    }

    /// Number of tracked entries, counting duplicates.
    pub fn len(&self) -> usize {
        self.xs.len()
    }

    /// `true` when nothing is tracked.
    pub fn is_empty(&self) -> bool {
        self.xs.is_empty()
    }
}
/// Runs units of work on a fixed pool of tokio worker tasks, executing units
/// that share a key one at a time while letting distinct keys proceed
/// concurrently.
///
/// Built with `Scheduler::new`, which also returns the receiver on which each
/// unit's result is delivered.
pub struct Scheduler<K: Eq + Hash + Clone + Send + 'static, W: Send + 'static> {
    /// State shared with the worker tasks (the per-key backlogs).
    inner: Arc<SchedulerInner<K, W>>,
    /// Hands the first unit of a newly active key to the worker pool.
    feeder_tx: mpsc::UnboundedSender<(K, W)>,
    /// Worker task handles: aborted when dropped, awaited by `shutdown`.
    handles: Vec<JoinHandle<()>>,
}
struct SchedulerInner<K, W>
where
K: Eq + Hash + Clone + Send + 'static,
W: Send + 'static,
{
key_queue_cap: usize,
active: Mutex<HashMap<K, Vec<W>>>,
}
#[derive(Debug, PartialEq, Eq)]
pub enum AddOutcome<W> {
Queued,
Dropped(W),
ShuttingDown(W),
}
impl<K, W> Scheduler<K, W>
where
K: Eq + Hash + Clone + Send + 'static,
W: Send + 'static,
{
pub fn new<F, Fut, R>(
workers: usize,
key_queue_cap: usize,
result_capacity: usize,
run: F,
) -> (Self, mpsc::Receiver<R>)
where
F: Fn(W) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = R> + Send + 'static,
R: Send + 'static,
{
let workers = workers.max(1);
let (feeder_tx, feeder_rx) = mpsc::unbounded_channel::<(K, W)>();
let (result_tx, result_rx) = mpsc::channel(result_capacity.max(1));
let inner = Arc::new(SchedulerInner {
key_queue_cap,
active: Mutex::new(HashMap::new()),
});
let run = Arc::new(run);
let shared_rx = Arc::new(Mutex::new(feeder_rx));
let mut handles = Vec::with_capacity(workers);
for _ in 0..workers {
let inner = Arc::clone(&inner);
let run = Arc::clone(&run);
let result_tx = result_tx.clone();
let shared_rx = Arc::clone(&shared_rx);
handles.push(tokio::spawn(async move {
worker_loop(inner, run, result_tx, shared_rx).await;
}));
}
(
Scheduler {
inner,
feeder_tx,
handles,
},
result_rx,
)
}
pub async fn add_work(&self, key: K, work: W) -> AddOutcome<W> {
if self.feeder_tx.is_closed() {
return AddOutcome::ShuttingDown(work);
}
let mut active = self.inner.active.lock().await;
if let Some(queue) = active.get_mut(&key) {
if self.inner.key_queue_cap > 0 && queue.len() >= self.inner.key_queue_cap {
let dropped = queue.remove(0);
queue.push(work);
return AddOutcome::Dropped(dropped);
}
queue.push(work);
return AddOutcome::Queued;
}
active.insert(key.clone(), Vec::new());
match self.feeder_tx.send((key.clone(), work)) {
Ok(()) => {
drop(active);
AddOutcome::Queued
}
Err(mpsc::error::SendError((_, work))) => {
active.remove(&key);
AddOutcome::ShuttingDown(work)
}
}
}
pub async fn shutdown(mut self) {
let (dead_tx, _) = mpsc::unbounded_channel();
self.feeder_tx = dead_tx;
for handle in std::mem::take(&mut self.handles) {
let _ = handle.await;
}
}
pub fn abort(mut self) {
for handle in std::mem::take(&mut self.handles) {
handle.abort();
}
}
}
impl<K, W> Drop for Scheduler<K, W>
where
K: Eq + Hash + Clone + Send + 'static,
W: Send + 'static,
{
fn drop(&mut self) {
for handle in &self.handles {
handle.abort();
}
}
}
async fn worker_loop<K, W, F, Fut, R>(
inner: Arc<SchedulerInner<K, W>>,
run: Arc<F>,
result_tx: mpsc::Sender<R>,
shared_rx: Arc<Mutex<mpsc::UnboundedReceiver<(K, W)>>>,
) where
K: Eq + Hash + Clone + Send + 'static,
W: Send + 'static,
F: Fn(W) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = R> + Send + 'static,
R: Send + 'static,
{
loop {
let next = {
let mut rx = shared_rx.lock().await;
rx.recv().await
};
let Some((key, work)) = next else {
return;
};
run_chain(&inner, &run, &result_tx, key, work).await;
}
}
async fn run_chain<K, W, F, Fut, R>(
inner: &Arc<SchedulerInner<K, W>>,
run: &Arc<F>,
result_tx: &mpsc::Sender<R>,
key: K,
mut work: W,
) where
K: Eq + Hash + Clone + Send + 'static,
W: Send + 'static,
F: Fn(W) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = R> + Send + 'static,
R: Send + 'static,
{
loop {
let result = run(work).await;
if result_tx.send(result).await.is_err() {
inner.active.lock().await.remove(&key);
return;
}
let mut active = inner.active.lock().await;
match active.get_mut(&key) {
Some(queue) if !queue.is_empty() => {
work = queue.remove(0);
}
_ => {
active.remove(&key);
return;
}
}
}
}
#[cfg(test)]
// Tests fail by panicking, so unwrap/expect/panic are intentional here.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Same-key units must yield results in submission order even though their
    /// run times vary and four workers are available.
    #[tokio::test]
    async fn preserves_same_key_order() {
        let (sched, mut results) = Scheduler::<u8, usize>::new(4, 0, 256, |w| async move {
            tokio::time::sleep(Duration::from_millis((w % 3) as u64)).await;
            w
        });
        for i in 0..100usize {
            assert!(matches!(sched.add_work(7, i).await, AddOutcome::Queued));
        }
        let mut got = Vec::new();
        for _ in 0..100 {
            got.push(results.recv().await.unwrap());
        }
        let expected: Vec<usize> = (0..100).collect();
        assert_eq!(
            got, expected,
            "same-key work must complete in arrival order"
        );
        sched.abort();
    }

    /// Eight distinct keys on eight workers must all be running at once: the
    /// barrier only releases when every unit has entered `run`.
    #[tokio::test]
    async fn runs_different_keys_concurrently() {
        let barrier = Arc::new(tokio::sync::Barrier::new(8));
        let entered = Arc::new(AtomicUsize::new(0));
        let entered2 = Arc::clone(&entered);
        let barrier2 = Arc::clone(&barrier);
        let (sched, mut results) = Scheduler::<u8, u8>::new(8, 0, 64, move |w| {
            let barrier = Arc::clone(&barrier2);
            let entered = Arc::clone(&entered2);
            async move {
                entered.fetch_add(1, Ordering::SeqCst);
                barrier.wait().await;
                w
            }
        });
        for key in 0..8u8 {
            assert!(matches!(sched.add_work(key, key).await, AddOutcome::Queued));
        }
        for _ in 0..8 {
            tokio::time::timeout(Duration::from_secs(2), results.recv())
                .await
                .expect("all keys should run concurrently")
                .unwrap();
        }
        assert_eq!(entered.load(Ordering::SeqCst), 8);
        sched.abort();
    }

    /// With a backlog cap of 1, a third unit evicts the waiting second unit,
    /// which must never run.
    #[tokio::test]
    async fn drops_oldest_when_per_key_queue_full() {
        let release = Arc::new(tokio::sync::Notify::new());
        let release2 = Arc::clone(&release);
        let started = Arc::new(tokio::sync::Notify::new());
        let started2 = Arc::clone(&started);
        let (sched, mut results) = Scheduler::<u8, usize>::new(1, 1, 64, move |w| {
            let release = Arc::clone(&release2);
            let started = Arc::clone(&started2);
            async move {
                if w == 0 {
                    started.notify_one();
                    release.notified().await;
                }
                w
            }
        });
        assert!(matches!(sched.add_work(1, 0).await, AddOutcome::Queued));
        started.notified().await;
        assert!(matches!(sched.add_work(1, 1).await, AddOutcome::Queued));
        assert_eq!(sched.add_work(1, 2).await, AddOutcome::Dropped(1));
        release.notify_one();
        let mut got = Vec::new();
        for _ in 0..2 {
            got.push(results.recv().await.unwrap());
        }
        assert_eq!(got, vec![0, 2], "dropped unit 1 must not run; 0 then 2 do");
        sched.abort();
    }

    /// The tracker reports the smallest positive value and ignores 0/negatives.
    #[test]
    fn inflight_tracks_min_and_ignores_nonpositive() {
        let mut inflight = InflightSeqs::new();
        inflight.add(30);
        inflight.add(10);
        inflight.add(20);
        inflight.add(0);
        inflight.add(-5);
        assert_eq!(inflight.len(), 3);
        assert_eq!(inflight.min(), 10);
        inflight.remove(10);
        assert_eq!(inflight.min(), 20);
        inflight.remove(20);
        inflight.remove(30);
        assert!(inflight.is_empty());
        assert_eq!(inflight.min(), 0);
    }

    /// Duplicate sequence numbers are tracked individually and removed one at a
    /// time; removing an absent value is a no-op.
    #[test]
    fn inflight_removes_one_duplicate_at_a_time() {
        let mut inflight = InflightSeqs::new();
        inflight.add(5);
        inflight.add(5);
        inflight.add(9);
        inflight.remove(5);
        assert_eq!(inflight.len(), 2);
        assert_eq!(inflight.min(), 5);
        inflight.remove(5);
        assert_eq!(inflight.min(), 9);
        inflight.remove(42);
        assert_eq!(inflight.len(), 1);
    }

    /// With the result receiver dropped, workers' sends fail immediately, so
    /// shutdown must complete promptly.
    #[tokio::test]
    async fn shutdown_does_not_hang() {
        let (sched, results) = Scheduler::<u8, u8>::new(2, 0, 8, |w| async move { w });
        assert_eq!(sched.add_work(1, 1).await, AddOutcome::Queued);
        assert_eq!(sched.add_work(2, 2).await, AddOutcome::Queued);
        drop(results);
        tokio::time::timeout(Duration::from_secs(2), sched.shutdown())
            .await
            .expect("shutdown must not hang");
    }

    /// Graceful shutdown still runs every accepted unit (including the key's
    /// backlog) while the receiver is alive, then closes the result channel.
    #[tokio::test]
    async fn shutdown_drains_accepted_work() {
        let (sched, mut results) = Scheduler::<u8, u8>::new(2, 0, 8, |w| async move { w });
        for w in 1..=3u8 {
            assert_eq!(sched.add_work(1, w).await, AddOutcome::Queued);
        }
        tokio::time::timeout(Duration::from_secs(2), sched.shutdown())
            .await
            .expect("shutdown must not hang");
        let mut got = Vec::new();
        while let Some(r) = results.recv().await {
            got.push(r);
        }
        assert_eq!(got, vec![1, 2, 3], "accepted work must run before shutdown completes");
    }
}