1use std::sync::Arc;
8
9use tokio::sync::Semaphore;
10use tokio::task::JoinSet;
11
12use crate::error::Error;
13use crate::llm::LlmProvider;
14use crate::llm::types::TokenUsage;
15
16use super::AgentOutput;
17use super::AgentRunner;
18
/// Outcome of one task from a [`BatchExecutor`] run.
#[derive(Debug)]
pub struct BatchResult {
    /// Zero-based position of the task in the submitted batch.
    pub index: usize,
    /// The input string the agent was invoked with.
    pub input: String,
    /// The agent's output on success, or the error that ended the task.
    pub result: Result<AgentOutput, Error>,
}
29
/// Tunables for batch execution.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of agent tasks allowed in flight at once (must be >= 1;
    /// enforced by `BatchExecutorBuilder::build`).
    pub max_concurrency: usize,
}
36
37impl Default for BatchConfig {
38 fn default() -> Self {
39 Self {
40 max_concurrency: std::thread::available_parallelism()
41 .map(|n| n.get())
42 .unwrap_or(4),
43 }
44 }
45}
46
/// Runs a batch of inputs through a shared [`AgentRunner`] with bounded
/// concurrency.
pub struct BatchExecutor<P: LlmProvider + 'static> {
    // Arc so each spawned task can hold its own handle to the runner.
    agent: Arc<AgentRunner<P>>,
    config: BatchConfig,
}
55
56impl<P: LlmProvider + 'static> std::fmt::Debug for BatchExecutor<P> {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("BatchExecutor")
59 .field("max_concurrency", &self.config.max_concurrency)
60 .finish()
61 }
62}
63
/// Builder for [`BatchExecutor`]; obtained via [`BatchExecutor::builder`].
pub struct BatchExecutorBuilder<P: LlmProvider + 'static> {
    agent: AgentRunner<P>,
    // None means "use BatchConfig::default()" at build time.
    max_concurrency: Option<usize>,
}
69
70impl<P: LlmProvider + 'static> BatchExecutor<P> {
71 pub fn builder(agent: AgentRunner<P>) -> BatchExecutorBuilder<P> {
73 BatchExecutorBuilder {
74 agent,
75 max_concurrency: None,
76 }
77 }
78
79 pub async fn execute(&self, tasks: Vec<String>) -> Vec<BatchResult> {
83 if tasks.is_empty() {
84 return Vec::new();
85 }
86
87 let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
88 let mut set = JoinSet::new();
89
90 for (index, input) in tasks.into_iter().enumerate() {
91 let agent = Arc::clone(&self.agent);
92 let sem = Arc::clone(&semaphore);
93 set.spawn(async move {
94 let _permit = sem.acquire().await.expect("semaphore closed unexpectedly");
95 let result = agent.execute(&input).await;
96 BatchResult {
97 index,
98 input,
99 result,
100 }
101 });
102 }
103
104 let mut results = Vec::with_capacity(set.len());
105 while let Some(join_result) = set.join_next().await {
106 match join_result {
107 Ok(batch_result) => results.push(batch_result),
108 Err(e) => {
109 tracing::error!("batch task panicked: {e}");
113 }
114 }
115 }
116
117 results.sort_by_key(|r| r.index);
118 results
119 }
120
121 pub async fn execute_ref(&self, tasks: &[&str]) -> Vec<BatchResult> {
123 let owned: Vec<String> = tasks.iter().map(|s| (*s).to_string()).collect();
124 self.execute(owned).await
125 }
126
127 pub fn aggregate_usage(results: &[BatchResult]) -> TokenUsage {
129 let mut total = TokenUsage::default();
130 for r in results {
131 if let Ok(output) = &r.result {
132 total += output.tokens_used;
133 }
134 }
135 total
136 }
137}
138
139impl<P: LlmProvider + 'static> BatchExecutorBuilder<P> {
140 pub fn max_concurrency(mut self, n: usize) -> Self {
142 self.max_concurrency = Some(n);
143 self
144 }
145
146 pub fn build(self) -> Result<BatchExecutor<P>, Error> {
148 let config = match self.max_concurrency {
149 Some(n) => {
150 if n == 0 {
151 return Err(Error::Config(
152 "BatchExecutor max_concurrency must be at least 1".into(),
153 ));
154 }
155 BatchConfig { max_concurrency: n }
156 }
157 None => BatchConfig::default(),
158 };
159 Ok(BatchExecutor {
160 agent: Arc::new(self.agent),
161 config,
162 })
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};
    use crate::llm::types::{CompletionRequest, CompletionResponse, ContentBlock, StopReason};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Provider that records how many completions run concurrently, so tests
    /// can assert the executor's semaphore cap is honored.
    struct ConcurrencyTrackingProvider {
        // Number of completions currently in flight.
        current: Arc<AtomicUsize>,
        // Highest in-flight count observed so far.
        peak: Arc<AtomicUsize>,
        response_text: String,
    }

    impl ConcurrencyTrackingProvider {
        fn new(current: Arc<AtomicUsize>, peak: Arc<AtomicUsize>, response_text: &str) -> Self {
            Self {
                current,
                peak,
                response_text: response_text.to_string(),
            }
        }
    }

    impl LlmProvider for ConcurrencyTrackingProvider {
        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
            let prev = self.current.fetch_add(1, Ordering::SeqCst);
            let concurrent = prev + 1;
            self.peak.fetch_max(concurrent, Ordering::SeqCst);
            // Hold the "request" open long enough for overlap to be observable.
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            self.current.fetch_sub(1, Ordering::SeqCst);

            Ok(CompletionResponse {
                content: vec![ContentBlock::Text {
                    text: self.response_text.clone(),
                }],
                stop_reason: StopReason::EndTurn,
                usage: TokenUsage {
                    input_tokens: 10,
                    output_tokens: 5,
                    ..Default::default()
                },
                model: None,
            })
        }

        fn model_name(&self) -> Option<&str> {
            Some("concurrency-mock")
        }
    }

    #[test]
    fn builder_uses_default_concurrency() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 10, 5,
        )]));
        let agent = make_agent(provider, "test");
        let executor = BatchExecutor::builder(agent).build().unwrap();
        assert!(executor.config.max_concurrency >= 1);
    }

    #[test]
    fn builder_accepts_custom_concurrency() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 10, 5,
        )]));
        let agent = make_agent(provider, "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(8)
            .build()
            .unwrap();
        assert_eq!(executor.config.max_concurrency, 8);
    }

    #[test]
    fn builder_rejects_zero_concurrency() {
        let provider = Arc::new(MockProvider::new(vec![]));
        let agent = make_agent(provider, "test");
        let result = BatchExecutor::builder(agent).max_concurrency(0).build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("at least 1"));
    }

    #[test]
    fn debug_impl() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 10, 5,
        )]));
        let agent = make_agent(provider, "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(3)
            .build()
            .unwrap();
        let debug = format!("{executor:?}");
        assert!(debug.contains("BatchExecutor"));
        assert!(debug.contains("3"));
    }

    #[tokio::test]
    async fn empty_batch_returns_empty_vec() {
        let provider = Arc::new(MockProvider::new(vec![]));
        let agent = make_agent(provider, "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(2)
            .build()
            .unwrap();

        let results = executor.execute(vec![]).await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn single_task_succeeds() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "hello", 100, 50,
        )]));
        let agent = make_agent(provider, "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(2)
            .build()
            .unwrap();

        let results = executor.execute(vec!["task1".to_string()]).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].index, 0);
        assert_eq!(results[0].input, "task1");
        let output = results[0].result.as_ref().unwrap();
        assert_eq!(output.result, "hello");
        assert_eq!(output.tokens_used.input_tokens, 100);
        assert_eq!(output.tokens_used.output_tokens, 50);
    }

    #[tokio::test]
    async fn multiple_tasks_all_succeed() {
        let provider = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("r1", 10, 5),
            MockProvider::text_response("r2", 20, 10),
            MockProvider::text_response("r3", 30, 15),
            MockProvider::text_response("r4", 40, 20),
            MockProvider::text_response("r5", 50, 25),
        ]));
        let agent = make_agent(provider, "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(5)
            .build()
            .unwrap();

        let tasks: Vec<String> = (1..=5).map(|i| format!("task{i}")).collect();
        let results = executor.execute(tasks).await;

        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.result.is_ok(), "task {} failed: {:?}", r.index, r.result);
        }
    }

    #[tokio::test]
    async fn results_ordered_by_index() {
        let provider = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("a", 10, 5),
            MockProvider::text_response("b", 10, 5),
            MockProvider::text_response("c", 10, 5),
        ]));
        let agent = make_agent(provider, "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(3)
            .build()
            .unwrap();

        let tasks = vec!["t0".to_string(), "t1".to_string(), "t2".to_string()];
        let results = executor.execute(tasks).await;

        assert_eq!(results.len(), 3);
        for (i, r) in results.iter().enumerate() {
            assert_eq!(r.index, i);
        }
    }

    #[tokio::test]
    async fn partial_failure_returns_all_results() {
        // Only two canned responses for three tasks: the third task's
        // provider call fails, exercising the partial-failure path.
        let provider = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("ok1", 10, 5),
            MockProvider::text_response("ok2", 20, 10),
        ]));
        let agent = make_agent(provider, "test");
        // Serial execution so responses are consumed in task order.
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(1)
            .build()
            .unwrap();

        let tasks = vec![
            "task0".to_string(),
            "task1".to_string(),
            "task2".to_string(),
        ];
        let results = executor.execute(tasks).await;

        assert_eq!(results.len(), 3);
        assert!(results[0].result.is_ok());
        assert!(results[1].result.is_ok());
        assert!(results[2].result.is_err());
    }

    #[tokio::test]
    async fn concurrency_limit_respected() {
        let current = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let provider = Arc::new(ConcurrencyTrackingProvider::new(
            Arc::clone(&current),
            Arc::clone(&peak),
            "done",
        ));
        let agent = AgentRunner::builder(provider)
            .name("conc-test")
            .system_prompt("test")
            .max_turns(1)
            .build()
            .expect("build agent");

        let executor = BatchExecutor::builder(agent)
            .max_concurrency(2)
            .build()
            .unwrap();

        let tasks: Vec<String> = (0..10).map(|i| format!("task{i}")).collect();
        let results = executor.execute(tasks).await;

        assert_eq!(results.len(), 10);
        let observed_peak = peak.load(Ordering::SeqCst);
        assert!(
            observed_peak <= 2,
            "peak concurrency was {observed_peak}, expected <= 2"
        );
    }

    #[tokio::test]
    async fn aggregate_usage_sums_successes() {
        let provider = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("a", 100, 50),
            MockProvider::text_response("b", 200, 80),
        ]));
        let agent = make_agent(provider, "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(1)
            .build()
            .unwrap();

        let results = executor
            .execute(vec!["t1".to_string(), "t2".to_string()])
            .await;

        let usage = BatchExecutor::<MockProvider>::aggregate_usage(&results);
        assert_eq!(usage.input_tokens, 300);
        assert_eq!(usage.output_tokens, 130);
    }

    #[tokio::test]
    async fn aggregate_usage_ignores_failures() {
        // One response for two tasks: the second fails and must not count.
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 100, 50,
        )]));
        let agent = make_agent(provider, "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(1)
            .build()
            .unwrap();

        let results = executor
            .execute(vec!["t1".to_string(), "t2".to_string()])
            .await;

        let usage = BatchExecutor::<MockProvider>::aggregate_usage(&results);
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
    }

    #[tokio::test]
    async fn execute_ref_convenience() {
        let provider = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("a", 10, 5),
            MockProvider::text_response("b", 10, 5),
        ]));
        let agent = make_agent(provider, "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(2)
            .build()
            .unwrap();

        let results = executor.execute_ref(&["hello", "world"]).await;
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].input, "hello");
        assert_eq!(results[1].input, "world");
    }

    #[test]
    fn aggregate_usage_empty_results() {
        let usage = BatchExecutor::<MockProvider>::aggregate_usage(&[]);
        assert_eq!(usage, TokenUsage::default());
    }
}