use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::stream::{FuturesUnordered, Stream, StreamExt};
use tokio::sync::mpsc;
use crate::error::{AgenticError, Result};
use super::agent::Agent;
use super::output::AgentOutput;
/// Number of agents a [`Batch`] runs at the same time unless
/// [`Batch::concurrency`] overrides it.
const DEFAULT_CONCURRENCY: usize = 1;
/// Builder for running a group of [`Agent`]s with bounded concurrency.
///
/// Call [`Batch::run`] to await every result in submission order, or
/// [`Batch::spawn`] to obtain a [`BatchHandle`] for submitting more agents plus
/// a [`BatchOutputStream`] that yields results as they complete.
pub struct Batch {
/// Maximum number of agents running at once; `Batch::concurrency` clamps it to at least 1.
concurrency: usize,
/// Agents queued before the batch starts; they receive indices `0..len` in insertion order.
agents: Vec<Agent>,
/// Caller-supplied cancellation flag; `Batch::spawn` creates a fresh one when this is `None`.
cancel_signal: Option<Arc<AtomicBool>>,
}
impl Default for Batch {
fn default() -> Self {
Self {
concurrency: DEFAULT_CONCURRENCY,
agents: Vec::new(),
cancel_signal: None,
}
}
}
impl Batch {
    /// Creates an empty batch that runs one agent at a time until
    /// [`Batch::concurrency`] says otherwise.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of agents allowed to run at the same time.
    ///
    /// `0` is treated as `1` so the batch can always make progress.
    #[must_use]
    pub fn concurrency(mut self, n: usize) -> Self {
        self.concurrency = n.max(1);
        self
    }

    /// Appends one agent; preloaded agents receive submission indices in the
    /// order they were added.
    #[must_use]
    pub fn agent(mut self, agent: Agent) -> Self {
        self.agents.push(agent);
        self
    }

    /// Appends every agent yielded by `agents`, preserving iteration order.
    #[must_use]
    pub fn agents<I>(mut self, agents: I) -> Self
    where
        I: IntoIterator<Item = Agent>,
    {
        self.agents.extend(agents);
        self
    }

    /// Shares an externally owned cancellation flag with the batch.
    ///
    /// Once the flag reads `true`, the dispatcher stops accepting new work. The
    /// same flag is handed to every dispatched agent via `Agent::cancel_signal`.
    #[must_use]
    pub fn cancel_signal(mut self, signal: Arc<AtomicBool>) -> Self {
        self.cancel_signal = Some(signal);
        self
    }

    /// Runs every preloaded agent and returns the results in submission order.
    ///
    /// The returned vector always holds exactly one entry per preloaded agent.
    /// An agent that never reported a result is represented by an
    /// `AgenticError::Other` error, for example when cancellation fired before it
    /// was dispatched.
    ///
    /// # Panics
    ///
    /// Panics if polled outside a Tokio runtime (see [`Batch::spawn`]).
    pub async fn run(self) -> Vec<Result<AgentOutput>> {
        let total = self.agents.len();
        let (handle, stream) = self.spawn();
        // Releasing the only handle closes the submission queue, so the output
        // stream ends once the preloaded agents have been processed.
        handle.drain();

        let mut slots: Vec<Option<Result<AgentOutput>>> = (0..total).map(|_| None).collect();
        for (index, result) in stream.collect().await {
            if let Some(slot) = slots.get_mut(index) {
                *slot = Some(result);
            }
        }

        slots
            .into_iter()
            .enumerate()
            .map(|(i, slot)| {
                slot.unwrap_or_else(|| {
                    Err(AgenticError::Other(format!(
                        "batch: missing result for submission index {i}"
                    )))
                })
            })
            .collect()
    }

    /// Starts the batch in the background and returns a handle plus a result stream.
    ///
    /// Preloaded agents are queued with indices `0..n`. [`BatchHandle::submit`]
    /// continues that numbering for agents added later. Results arrive on the
    /// [`BatchOutputStream`] in completion order, tagged with their submission
    /// index. The stream ends once every handle clone is dropped and the accepted
    /// work has finished, or earlier if the batch is cancelled.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the dispatcher is
    /// started with `tokio::spawn`.
    #[must_use = "dropping the output stream discards every agent result"]
    pub fn spawn(self) -> (BatchHandle, BatchOutputStream) {
        let concurrency = self.concurrency;
        let (submit_tx, submit_rx) = mpsc::unbounded_channel::<(usize, Agent)>();
        let (output_tx, output_rx) = mpsc::unbounded_channel::<(usize, Result<AgentOutput>)>();
        // Reuse the caller's flag when present so external cancellation reaches the batch.
        let cancel = self
            .cancel_signal
            .unwrap_or_else(|| Arc::new(AtomicBool::new(false)));
        // Shared with the handle so dynamically submitted agents continue the numbering.
        let counter = Arc::new(AtomicUsize::new(0));
        for agent in self.agents {
            let index = counter.fetch_add(1, Ordering::Relaxed);
            // `submit_rx` is still alive here, so this send cannot fail.
            let _ = submit_tx.send((index, agent));
        }
        let dispatcher_cancel = cancel.clone();
        tokio::spawn(async move {
            dispatch(submit_rx, output_tx, concurrency, dispatcher_cancel).await;
        });
        let handle = BatchHandle {
            sender: submit_tx,
            cancel,
            counter,
        };
        let output = BatchOutputStream { rx: output_rx };
        (handle, output)
    }
}
async fn dispatch(
mut submit_rx: mpsc::UnboundedReceiver<(usize, Agent)>,
output_tx: mpsc::UnboundedSender<(usize, Result<AgentOutput>)>,
concurrency: usize,
cancel: Arc<AtomicBool>,
) {
let mut in_flight: FuturesUnordered<tokio::task::JoinHandle<(usize, Result<AgentOutput>)>> =
FuturesUnordered::new();
let mut closed = false;
loop {
if cancel.load(Ordering::Relaxed) && !closed {
submit_rx.close();
closed = true;
}
tokio::select! {
biased;
Some(join) = in_flight.next(), if !in_flight.is_empty() => {
if let Ok(pair) = join {
let _ = output_tx.send(pair);
}
}
maybe = submit_rx.recv(), if !closed && in_flight.len() < concurrency => {
let Some((index, agent)) = maybe else {
closed = true;
continue;
};
let agent = agent.cancel_signal(cancel.clone());
in_flight.push(tokio::spawn(async move {
(index, agent.run().await)
}));
}
else => return,
}
}
}
/// Cloneable handle to a batch started by [`Batch::spawn`].
///
/// The batch keeps accepting submissions while any clone of the handle is
/// alive. Once every clone has been dropped (or passed to
/// [`BatchHandle::drain`]), the already-submitted work finishes and the output
/// stream ends, unless cancellation stops it first.
#[derive(Clone)]
pub struct BatchHandle {
/// Submission queue feeding the background dispatcher.
sender: mpsc::UnboundedSender<(usize, Agent)>,
/// Cancellation flag shared with the dispatcher and every dispatched agent.
cancel: Arc<AtomicBool>,
/// Next submission index; shared with the preloading loop in `Batch::spawn`.
counter: Arc<AtomicUsize>,
}
impl BatchHandle {
    /// Queues `agent` on the running batch and returns its submission index.
    ///
    /// Indices continue the sequence started by the agents preloaded through
    /// `Batch`, so every index is unique within one batch. If the dispatcher has
    /// stopped accepting work (for example after cancellation), the agent is
    /// dropped without running, but an index is still consumed and returned.
    pub fn submit(&self, agent: Agent) -> usize {
        let slot = self.counter.fetch_add(1, Ordering::Relaxed);
        // A send failure means the queue is closed; it is intentionally ignored.
        let _ = self.sender.send((slot, agent));
        slot
    }

    /// Requests cancellation by raising the flag shared with the dispatcher and
    /// with every agent it has started.
    pub fn cancel(&self) {
        let flag = &self.cancel;
        flag.store(true, Ordering::Relaxed);
    }

    /// Reports whether cancellation has been requested, through this handle or
    /// through the externally supplied flag.
    pub fn is_cancelled(&self) -> bool {
        let flag = &self.cancel;
        flag.load(Ordering::Relaxed)
    }

    /// Gives up this handle without cancelling anything.
    ///
    /// Agents already submitted still run to completion; it is equivalent to
    /// dropping the handle.
    pub fn drain(self) {
        drop(self);
    }
}
/// Stream of `(submission index, result)` pairs produced by a spawned batch.
///
/// Pairs are delivered in completion order, not submission order.
pub struct BatchOutputStream {
/// Receiving end of the dispatcher's output channel.
rx: mpsc::UnboundedReceiver<(usize, Result<AgentOutput>)>,
}
impl Stream for BatchOutputStream {
    type Item = (usize, Result<AgentOutput>);

    /// Polls the underlying channel.
    ///
    /// Yields `None` once the dispatcher has dropped its sender and every
    /// buffered result has been consumed.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.rx.poll_recv(cx)
    }
}
impl BatchOutputStream {
pub async fn collect(self) -> Vec<(usize, Result<AgentOutput>)> {
StreamExt::collect(self).await
}
pub async fn next(&mut self) -> Option<(usize, Result<AgentOutput>)> {
StreamExt::next(self).await
}
}
// Lets `poll_next` reach `rx` through `Pin<&mut Self>` and lets the inherent
// `next` call `StreamExt::next`, both of which need `Self: Unpin`.
// NOTE(review): the only field is a tokio `UnboundedReceiver`, which should
// already be `Unpin`, so this explicit impl is likely redundant — confirm before removing.
impl Unpin for BatchOutputStream {}
#[cfg(test)]
mod tests {
// Agents are backed by `MockProvider`. "Slow" agents call a tool that sleeps,
// which gives each run a controllable duration.
use super::*;
use crate::testutil::{text_response, tool_response, MockProvider};
use crate::tools::{Tool, ToolResult};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
/// Builds an agent named `name` whose provider is `MockProvider::text(text)`.
fn agent_with_response(name: &str, text: &str) -> Agent {
Agent::new()
.name(name)
.model_name("mock")
.identity_prompt("")
.instruction_prompt("go")
.provider(Arc::new(MockProvider::text(text)))
}
/// Builds an agent whose provider is scripted to request the `slow` tool
/// (which sleeps `delay_ms` milliseconds) and then answer with `text`.
fn agent_with_delay(name: &str, delay_ms: u64, text: &str) -> Agent {
let slow_tool = Tool::new("slow", "simulates work")
.schema(serde_json::json!({"type": "object", "properties": {}}))
.handler(move |_, _| {
Box::pin(async move {
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
Ok(ToolResult::success("done"))
})
});
let provider = Arc::new(MockProvider::new(vec![
tool_response("slow", "c1", serde_json::json!({})),
text_response(text),
]));
Agent::new()
.name(name)
.model_name("mock")
.identity_prompt("")
.instruction_prompt("go")
.tool(slow_tool)
.provider(provider)
}
/// A batch with no agents returns an empty result vector.
#[tokio::test]
async fn empty_run_yields_empty_vec() {
let results = Batch::new().concurrency(4).run().await;
assert!(results.is_empty());
}
/// `run` returns one result per agent, ordered like the submissions.
#[tokio::test]
async fn run_returns_results_in_submission_order() {
let results = Batch::new()
.concurrency(4)
.agents(["a", "b", "c"].iter().map(|n| agent_with_response(n, "ok")))
.run()
.await;
assert_eq!(results.len(), 3);
let names: Vec<String> = results
.iter()
.map(|r| r.as_ref().unwrap().name.clone())
.collect();
assert_eq!(names, vec!["a", "b", "c"]);
}
/// A slow first agent still lands in slot 0 even though the fast one finishes first.
#[tokio::test]
async fn run_submission_order_ignores_completion_order() {
let slow = agent_with_delay("slow", 80, "slow");
let fast = agent_with_response("fast", "fast");
let results = Batch::new()
.concurrency(4)
.agent(slow)
.agent(fast)
.run()
.await;
assert_eq!(results[0].as_ref().unwrap().name, "slow");
assert_eq!(results[1].as_ref().unwrap().name, "fast");
}
/// A failing agent yields `Err` in its own slot without affecting its neighbours.
#[tokio::test]
async fn run_surfaces_failures_without_blocking_others() {
// No scripted responses, so this agent's run is expected to fail.
let failing = Agent::new()
.name("fail")
.model_name("mock")
.identity_prompt("")
.instruction_prompt("go")
.provider(Arc::new(MockProvider::new(vec![])));
let results = Batch::new()
.concurrency(2)
.agent(agent_with_response("ok1", "first"))
.agent(failing)
.agent(agent_with_response("ok2", "second"))
.run()
.await;
assert_eq!(results.len(), 3);
assert!(results[0].is_ok());
assert!(results[1].is_err());
assert!(results[2].is_ok());
}
/// Streamed results carry the submission index of the agent that produced them.
#[tokio::test]
async fn stream_yields_submission_indices() {
let (pool, mut stream) = Batch::new()
.concurrency(4)
.agents(["a", "b", "c"].iter().map(|n| agent_with_response(n, "ok")))
.spawn();
drop(pool);
let mut seen: Vec<(usize, String)> = Vec::new();
while let Some((index, result)) = stream.next().await {
seen.push((index, result.unwrap().name));
}
// Completion order is arbitrary, so sort before comparing.
seen.sort_by_key(|(i, _)| *i);
assert_eq!(
seen,
vec![(0, "a".into()), (1, "b".into()), (2, "c".into()),],
);
}
/// `submit` continues numbering after the preloaded agents.
#[tokio::test]
async fn submit_returns_monotonic_indices_continuing_preloaded() {
let (pool, mut stream) = Batch::new()
.concurrency(4)
.agent(agent_with_response("preloaded", "ok"))
.spawn();
let idx_b = pool.submit(agent_with_response("b", "ok"));
let idx_c = pool.submit(agent_with_response("c", "ok"));
assert_eq!(idx_b, 1);
assert_eq!(idx_c, 2);
drop(pool);
let mut seen = Vec::new();
while let Some((i, _)) = stream.next().await {
seen.push(i);
}
seen.sort();
assert_eq!(seen, vec![0, 1, 2]);
}
/// Peak overlap never exceeds the configured cap, yet does reach real parallelism.
#[tokio::test]
async fn concurrency_cap_bounds_parallelism() {
// `running` counts tool calls currently sleeping; `max_concurrent` records the peak.
let running = Arc::new(AtomicUsize::new(0));
let max_concurrent = Arc::new(AtomicUsize::new(0));
let make = |i: usize| {
let r = running.clone();
let m = max_concurrent.clone();
let slow_tool = Tool::new("slow", "work")
.schema(serde_json::json!({"type": "object", "properties": {}}))
.handler(move |_, _| {
let r = r.clone();
let m = m.clone();
Box::pin(async move {
let cur = r.fetch_add(1, Ordering::SeqCst) + 1;
m.fetch_max(cur, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(30)).await;
r.fetch_sub(1, Ordering::SeqCst);
Ok(ToolResult::success("done"))
})
});
Agent::new()
.name(&format!("w{i}"))
.model_name("mock")
.identity_prompt("")
.instruction_prompt("go")
.tool(slow_tool)
.provider(Arc::new(MockProvider::new(vec![
tool_response("slow", "c1", serde_json::json!({})),
text_response("finished"),
])))
};
let results = Batch::new()
.concurrency(3)
.agents((0..10).map(make))
.run()
.await;
assert_eq!(results.len(), 10);
assert!(results.iter().all(|r| r.is_ok()));
let peak = max_concurrent.load(Ordering::SeqCst);
assert!(peak <= 3, "peak concurrency {peak} exceeded cap of 3");
assert!(
peak >= 2,
"peak concurrency {peak} never reached meaningful overlap"
);
}
/// Wall-clock check: running ten sleeping agents concurrently is much faster than one at a time.
#[tokio::test]
async fn concurrency_scales_throughput() {
let start = tokio::time::Instant::now();
let seq = Batch::new()
.concurrency(1)
.agents((0..10).map(|i| agent_with_delay("w", 30, &format!("r{i}"))))
.run()
.await;
let seq_elapsed = start.elapsed();
let start = tokio::time::Instant::now();
let par = Batch::new()
.concurrency(10)
.agents((0..10).map(|i| agent_with_delay("w", 30, &format!("r{i}"))))
.run()
.await;
let par_elapsed = start.elapsed();
assert_eq!(seq.len(), 10);
assert_eq!(par.len(), 10);
// The factor of 3 leaves headroom for scheduler jitter.
assert!(
seq_elapsed > par_elapsed * 3,
"sequential ({seq_elapsed:?}) should dwarf parallel ({par_elapsed:?})",
);
}
/// Smoke test: 500 agents through a concurrency-50 batch all succeed.
#[tokio::test]
async fn high_throughput_smoke() {
let results = Batch::new()
.concurrency(50)
.agents((0..500).map(|i| agent_with_response("w", &format!("r{i}"))))
.run()
.await;
assert_eq!(results.len(), 500);
assert!(results.iter().all(|r| r.is_ok()));
}
/// Agents can be submitted after results start arriving; the stream ends after the handle is dropped.
#[tokio::test]
async fn spawn_accepts_dynamic_submissions() {
let (pool, mut stream) = Batch::new().concurrency(2).spawn();
pool.submit(agent_with_response("a", "first"));
pool.submit(agent_with_response("b", "second"));
let r1 = stream.next().await.expect("first result");
let r2 = stream.next().await.expect("second result");
pool.submit(agent_with_response("c", "third"));
drop(pool);
let r3 = stream.next().await.expect("third result");
assert!(stream.next().await.is_none(), "stream must end after drop");
let mut names: Vec<String> = [r1, r2, r3]
.into_iter()
.map(|(_, r)| r.unwrap().name)
.collect();
names.sort();
assert_eq!(names, vec!["a", "b", "c"]);
}
/// The stream stays open while any clone of the handle is still alive.
#[tokio::test]
async fn spawn_keeps_stream_open_while_any_handle_lives() {
let (pool, mut stream) = Batch::new().concurrency(4).spawn();
let clone = pool.clone();
pool.submit(agent_with_response("a", "done"));
drop(pool);
assert!(stream.next().await.unwrap().1.is_ok());
clone.submit(agent_with_response("b", "done"));
assert!(stream.next().await.unwrap().1.is_ok());
drop(clone);
assert!(stream.next().await.is_none());
}
/// Dropping the handle still delivers every already-submitted agent before the stream ends.
#[tokio::test]
async fn spawn_drops_handle_drains_backlog_and_ends_stream() {
let (pool, mut stream) = Batch::new().concurrency(2).spawn();
pool.submit(agent_with_response("a", "done"));
pool.submit(agent_with_response("b", "done"));
drop(pool);
let mut seen = 0;
while let Some((_, r)) = stream.next().await {
r.unwrap();
seen += 1;
}
assert_eq!(seen, 2);
}
/// `drain` releases the handle, but running agents still complete normally rather than being cancelled.
#[tokio::test]
async fn drain_lets_in_flight_agents_finish_unlike_cancel() {
let (pool, mut stream) = Batch::new().concurrency(2).spawn();
pool.submit(agent_with_delay("a", 30, "done"));
pool.submit(agent_with_delay("b", 30, "done"));
pool.drain();
let mut seen = 0;
while let Some((_, r)) = stream.next().await {
let out = r.unwrap();
assert_eq!(out.status, crate::agent::AgentStatus::Completed);
seen += 1;
}
assert_eq!(seen, 2);
}
/// `cancel` reaches an agent mid-run, which then reports `AgentStatus::Cancelled`.
#[tokio::test]
async fn spawn_cancel_stops_in_flight_agents() {
let (pool, mut stream) = Batch::new().concurrency(2).spawn();
pool.submit(agent_with_delay("slow", 200, "never"));
// Give the dispatcher time to start the agent before cancelling.
tokio::time::sleep(Duration::from_millis(20)).await;
pool.cancel();
let (_, result) = stream.next().await.expect("result after cancel");
let out = result.unwrap();
assert_eq!(out.status, crate::agent::AgentStatus::Cancelled);
assert!(pool.is_cancelled());
drop(pool);
assert!(stream.next().await.is_none());
}
/// Agents added on the builder run after `spawn` without any `submit` call.
#[tokio::test]
async fn preloaded_agents_run_without_explicit_submit() {
let (pool, stream) = Batch::new()
.concurrency(2)
.agents(["a", "b"].iter().map(|n| agent_with_response(n, "ok")))
.spawn();
drop(pool);
let results = stream.collect().await;
assert_eq!(results.len(), 2);
}
}