use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use futures::future::join_all;
use tokio::sync::Semaphore;
/// Fallback cap on concurrently running tool calls, used when
/// `EVERRUNS_ACT_MAX_TOOL_CONCURRENCY` is unset or not a valid `usize`.
pub const DEFAULT_MAX_TOOL_CONCURRENCY: usize = 32;

/// Returns the tool-call concurrency cap from the
/// `EVERRUNS_ACT_MAX_TOOL_CONCURRENCY` environment variable.
///
/// A configured value of `0` is raised to `1`. If the variable is missing,
/// not valid Unicode, or does not parse as a `usize`, this falls back to
/// [`DEFAULT_MAX_TOOL_CONCURRENCY`].
pub fn configured_max_tool_concurrency() -> usize {
    let parsed = match std::env::var("EVERRUNS_ACT_MAX_TOOL_CONCURRENCY") {
        Ok(raw) => raw.parse::<usize>().ok(),
        Err(_) => None,
    };
    match parsed {
        Some(limit) => std::cmp::max(limit, 1),
        None => DEFAULT_MAX_TOOL_CONCURRENCY,
    }
}
/// Tuning knobs for [`schedule`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScheduleConfig {
    /// Upper bound on the number of calls in flight at once when calls run
    /// concurrently. `schedule` treats `0` as `1`.
    pub max_concurrency: usize,
    /// When `true`, every call runs sequentially in index order, ignoring
    /// per-call classes and `max_concurrency`.
    pub serialize_all: bool,
}
impl Default for ScheduleConfig {
fn default() -> Self {
Self {
max_concurrency: configured_max_tool_concurrency(),
serialize_all: false,
}
}
}
pub async fn schedule<R, MkFut, Fut>(
n: usize,
classes: &[Option<String>],
config: ScheduleConfig,
run: MkFut,
) -> Vec<R>
where
MkFut: Fn(usize) -> Fut,
Fut: Future<Output = R>,
{
if n == 0 {
return Vec::new();
}
if config.serialize_all || n == 1 {
let mut results = Vec::with_capacity(n);
for i in 0..n {
results.push(run(i).await);
}
return results;
}
let mut groups: Vec<Vec<usize>> = Vec::new();
let mut class_index: HashMap<&str, usize> = HashMap::new();
for i in 0..n {
match classes
.get(i)
.and_then(|class| class.as_deref())
.filter(|class| !class.is_empty())
{
Some(key) => {
if let Some(&g) = class_index.get(key) {
groups[g].push(i);
} else {
class_index.insert(key, groups.len());
groups.push(vec![i]);
}
}
None => groups.push(vec![i]),
}
}
let semaphore = Arc::new(Semaphore::new(config.max_concurrency.max(1)));
let run = &run;
let group_futures = groups.into_iter().map(|group| {
let semaphore = semaphore.clone();
async move {
let mut out: Vec<(usize, R)> = Vec::with_capacity(group.len());
for idx in group {
let permit = semaphore
.acquire()
.await
.expect("tool scheduler semaphore is never closed");
let result = run(idx).await;
drop(permit);
out.push((idx, result));
}
out
}
});
let mut indexed: Vec<(usize, R)> = join_all(group_futures)
.await
.into_iter()
.flatten()
.collect();
indexed.sort_by_key(|(idx, _)| *idx);
indexed.into_iter().map(|(_, r)| r).collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test probe: counts calls currently in flight, remembers the peak, and
    /// logs `start:<label>` / `end:<label>` events in the order they happen.
    #[derive(Default)]
    struct Tracker {
        in_flight: AtomicUsize,
        max_in_flight: AtomicUsize,
        events: Mutex<Vec<String>>,
    }

    impl Tracker {
        fn enter(&self, label: &str) {
            let now = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            self.max_in_flight.fetch_max(now, Ordering::SeqCst);
            self.events.lock().unwrap().push(format!("start:{label}"));
        }

        fn exit(&self, label: &str) {
            self.in_flight.fetch_sub(1, Ordering::SeqCst);
            self.events.lock().unwrap().push(format!("end:{label}"));
        }
    }

    #[tokio::test]
    async fn preserves_input_order() {
        let classes = vec![None, None, None];
        let results = schedule(
            3,
            &classes,
            ScheduleConfig::default(),
            |i| async move { i * 10 },
        )
        .await;
        assert_eq!(results, vec![0, 10, 20]);
    }

    #[tokio::test]
    async fn distinct_classes_run_concurrently() {
        let tracker = Arc::new(Tracker::default());
        let classes = vec![None, None, None];
        let t = tracker.clone();
        schedule(3, &classes, ScheduleConfig::default(), |i| {
            let t = t.clone();
            async move {
                t.enter(&i.to_string());
                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
                t.exit(&i.to_string());
            }
        })
        .await;
        assert_eq!(tracker.max_in_flight.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn same_class_serializes_in_arrival_order() {
        let tracker = Arc::new(Tracker::default());
        let classes = vec![
            Some("fs".to_string()),
            Some("fs".to_string()),
            Some("fs".to_string()),
        ];
        let t = tracker.clone();
        schedule(3, &classes, ScheduleConfig::default(), |i| {
            let t = t.clone();
            async move {
                t.enter(&i.to_string());
                tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                t.exit(&i.to_string());
            }
        })
        .await;
        assert_eq!(tracker.max_in_flight.load(Ordering::SeqCst), 1);
        let events = tracker.events.lock().unwrap().clone();
        assert_eq!(
            events,
            vec!["start:0", "end:0", "start:1", "end:1", "start:2", "end:2"]
        );
    }

    #[tokio::test]
    async fn mixed_classes_serialize_within_parallelize_across() {
        let tracker = Arc::new(Tracker::default());
        let classes = vec![Some("fs".to_string()), None, Some("fs".to_string()), None];
        let t = tracker.clone();
        schedule(4, &classes, ScheduleConfig::default(), |i| {
            let t = t.clone();
            async move {
                t.enter(&i.to_string());
                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
                t.exit(&i.to_string());
            }
        })
        .await;
        assert_eq!(tracker.max_in_flight.load(Ordering::SeqCst), 3);
        let events = tracker.events.lock().unwrap().clone();
        let position = |event: &str| {
            events
                .iter()
                .position(|e| e == event)
                .unwrap_or_else(|| panic!("missing event {event}"))
        };
        // Calls 0 and 2 share the "fs" class, so 2 must not start until 0 ends.
        assert!(position("end:0") < position("start:2"));
    }

    #[tokio::test]
    async fn global_cap_bounds_concurrency() {
        let tracker = Arc::new(Tracker::default());
        let classes = vec![None; 10];
        let cfg = ScheduleConfig {
            max_concurrency: 2,
            serialize_all: false,
        };
        let t = tracker.clone();
        schedule(10, &classes, cfg, |i| {
            let t = t.clone();
            async move {
                t.enter(&i.to_string());
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                t.exit(&i.to_string());
            }
        })
        .await;
        // Exactly the cap: never exceeded, and actually used (a serial run
        // would also satisfy a `<= 2` check).
        assert_eq!(tracker.max_in_flight.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn shorter_classes_slice_schedules_all_calls() {
        let classes = vec![Some("ws".to_string())];
        let results = schedule(3, &classes, ScheduleConfig::default(), |i| async move {
            i
        })
        .await;
        assert_eq!(results, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn serialize_all_ignores_classes() {
        let tracker = Arc::new(Tracker::default());
        let classes = vec![None, None, None];
        let cfg = ScheduleConfig {
            max_concurrency: 8,
            serialize_all: true,
        };
        let t = tracker.clone();
        let results = schedule(3, &classes, cfg, |i| {
            let t = t.clone();
            async move {
                t.enter(&i.to_string());
                tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                t.exit(&i.to_string());
                i
            }
        })
        .await;
        assert_eq!(results, vec![0, 1, 2]);
        assert_eq!(tracker.max_in_flight.load(Ordering::SeqCst), 1);
    }
}