Skip to main content

heartbit_core/agent/
batch.rs

//! Batch executor for running the same agent on multiple tasks with controlled concurrency.
//!
//! Unlike [`ParallelAgent`](super::workflow::ParallelAgent) which runs different agents on the
//! same task, `BatchExecutor` runs the **same agent** on different tasks with a concurrency limit
//! via [`tokio::sync::Semaphore`].
6
7use std::sync::Arc;
8
9use tokio::sync::Semaphore;
10use tokio::task::JoinSet;
11
12use crate::error::Error;
13use crate::llm::LlmProvider;
14use crate::llm::types::TokenUsage;
15
16use super::AgentOutput;
17use super::AgentRunner;
18
19/// Result of a single batch item execution.
20#[derive(Debug)]
21pub struct BatchResult {
22    /// Index of this item in the original input batch.
23    pub index: usize,
24    /// The input task that was executed.
25    pub input: String,
26    /// The execution result (Ok with output, or Err).
27    pub result: Result<AgentOutput, Error>,
28}
29
/// Tunable settings for a batch run.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Upper bound on how many agent executions may be in flight at the same time.
    pub max_concurrency: usize,
}
36
37impl Default for BatchConfig {
38    fn default() -> Self {
39        Self {
40            max_concurrency: std::thread::available_parallelism()
41                .map(|n| n.get())
42                .unwrap_or(4),
43        }
44    }
45}
46
47/// Executes multiple tasks through an agent with controlled concurrency.
48///
49/// Unlike `ParallelAgent` which runs different agents on the same task,
50/// `BatchExecutor` runs the SAME agent on different tasks with a concurrency limit.
51pub struct BatchExecutor<P: LlmProvider + 'static> {
52    agent: Arc<AgentRunner<P>>,
53    config: BatchConfig,
54}
55
56impl<P: LlmProvider + 'static> std::fmt::Debug for BatchExecutor<P> {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("BatchExecutor")
59            .field("max_concurrency", &self.config.max_concurrency)
60            .finish()
61    }
62}
63
64/// Builder for [`BatchExecutor`].
65pub struct BatchExecutorBuilder<P: LlmProvider + 'static> {
66    agent: AgentRunner<P>,
67    max_concurrency: Option<usize>,
68}
69
70impl<P: LlmProvider + 'static> BatchExecutor<P> {
71    /// Create a new builder for `BatchExecutor`.
72    pub fn builder(agent: AgentRunner<P>) -> BatchExecutorBuilder<P> {
73        BatchExecutorBuilder {
74            agent,
75            max_concurrency: None,
76        }
77    }
78
79    /// Execute all tasks with controlled concurrency.
80    /// Returns results for ALL tasks (successes and failures).
81    /// Results are sorted by input index.
82    pub async fn execute(&self, tasks: Vec<String>) -> Vec<BatchResult> {
83        if tasks.is_empty() {
84            return Vec::new();
85        }
86
87        let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
88        let mut set = JoinSet::new();
89
90        for (index, input) in tasks.into_iter().enumerate() {
91            let agent = Arc::clone(&self.agent);
92            let sem = Arc::clone(&semaphore);
93            set.spawn(async move {
94                let _permit = sem.acquire().await.expect("semaphore closed unexpectedly");
95                let result = agent.execute(&input).await;
96                BatchResult {
97                    index,
98                    input,
99                    result,
100                }
101            });
102        }
103
104        let mut results = Vec::with_capacity(set.len());
105        while let Some(join_result) = set.join_next().await {
106            match join_result {
107                Ok(batch_result) => results.push(batch_result),
108                Err(e) => {
109                    // JoinSet task panicked — should not happen in normal operation.
110                    // We can't recover the index/input, so we skip it.
111                    // In practice, agent.execute() should not panic.
112                    tracing::error!("batch task panicked: {e}");
113                }
114            }
115        }
116
117        results.sort_by_key(|r| r.index);
118        results
119    }
120
121    /// Convenience: execute with string slice references.
122    pub async fn execute_ref(&self, tasks: &[&str]) -> Vec<BatchResult> {
123        let owned: Vec<String> = tasks.iter().map(|s| (*s).to_string()).collect();
124        self.execute(owned).await
125    }
126
127    /// Returns aggregate token usage across all successful executions.
128    pub fn aggregate_usage(results: &[BatchResult]) -> TokenUsage {
129        let mut total = TokenUsage::default();
130        for r in results {
131            if let Ok(output) = &r.result {
132                total += output.tokens_used;
133            }
134        }
135        total
136    }
137}
138
139impl<P: LlmProvider + 'static> BatchExecutorBuilder<P> {
140    /// Set the maximum number of concurrent agent executions.
141    pub fn max_concurrency(mut self, n: usize) -> Self {
142        self.max_concurrency = Some(n);
143        self
144    }
145
146    /// Build the [`BatchExecutor`].
147    pub fn build(self) -> Result<BatchExecutor<P>, Error> {
148        let config = match self.max_concurrency {
149            Some(n) => {
150                if n == 0 {
151                    return Err(Error::Config(
152                        "BatchExecutor max_concurrency must be at least 1".into(),
153                    ));
154                }
155                BatchConfig { max_concurrency: n }
156            }
157            None => BatchConfig::default(),
158        };
159        Ok(BatchExecutor {
160            agent: Arc::new(self.agent),
161            config,
162        })
163    }
164}
165
// ===========================================================================
// Tests
// ===========================================================================
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};
    use crate::llm::types::{CompletionRequest, CompletionResponse, ContentBlock, StopReason};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test provider that records how many `complete` calls overlap in time.
    struct ConcurrencyTrackingProvider {
        /// Number of `complete` calls currently in progress.
        current: Arc<AtomicUsize>,
        /// Highest value `current` has reached so far.
        peak: Arc<AtomicUsize>,
        /// Text returned in every response.
        response_text: String,
    }

    impl ConcurrencyTrackingProvider {
        fn new(current: Arc<AtomicUsize>, peak: Arc<AtomicUsize>, response_text: &str) -> Self {
            let response_text = response_text.to_owned();
            Self {
                current,
                peak,
                response_text,
            }
        }
    }

    impl LlmProvider for ConcurrencyTrackingProvider {
        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
            // Enter: bump the in-flight counter and fold it into the running peak.
            let in_flight = self.current.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak.fetch_max(in_flight, Ordering::SeqCst);

            // Hold the slot for a while so overlapping calls can be observed.
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;

            // Leave.
            self.current.fetch_sub(1, Ordering::SeqCst);

            let usage = TokenUsage {
                input_tokens: 10,
                output_tokens: 5,
                ..Default::default()
            };
            let content = vec![ContentBlock::Text {
                text: self.response_text.clone(),
            }];
            Ok(CompletionResponse {
                content,
                stop_reason: StopReason::EndTurn,
                usage,
                model: None,
            })
        }

        fn model_name(&self) -> Option<&str> {
            Some("concurrency-mock")
        }
    }

    // -----------------------------------------------------------------------
    // Builder tests
    // -----------------------------------------------------------------------

    #[test]
    fn builder_uses_default_concurrency() {
        let responses = vec![MockProvider::text_response("ok", 10, 5)];
        let agent = make_agent(Arc::new(MockProvider::new(responses)), "test");

        let executor = BatchExecutor::builder(agent).build().unwrap();

        assert_ne!(executor.config.max_concurrency, 0);
    }

    #[test]
    fn builder_accepts_custom_concurrency() {
        let responses = vec![MockProvider::text_response("ok", 10, 5)];
        let agent = make_agent(Arc::new(MockProvider::new(responses)), "test");

        let builder = BatchExecutor::builder(agent).max_concurrency(8);
        let executor = builder.build().unwrap();

        assert_eq!(executor.config.max_concurrency, 8);
    }

    #[test]
    fn builder_rejects_zero_concurrency() {
        let agent = make_agent(Arc::new(MockProvider::new(Vec::new())), "test");

        let outcome = BatchExecutor::builder(agent).max_concurrency(0).build();

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("at least 1"));
    }

    #[test]
    fn debug_impl() {
        let responses = vec![MockProvider::text_response("ok", 10, 5)];
        let agent = make_agent(Arc::new(MockProvider::new(responses)), "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(3)
            .build()
            .unwrap();

        let rendered = format!("{executor:?}");

        assert!(rendered.contains("BatchExecutor"));
        assert!(rendered.contains("3"));
    }

    // -----------------------------------------------------------------------
    // Execution tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn empty_batch_returns_empty_vec() {
        let agent = make_agent(Arc::new(MockProvider::new(Vec::new())), "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(2)
            .build()
            .unwrap();

        let results = executor.execute(Vec::new()).await;

        assert_eq!(results.len(), 0);
    }

    #[tokio::test]
    async fn single_task_succeeds() {
        let responses = vec![MockProvider::text_response("hello", 100, 50)];
        let agent = make_agent(Arc::new(MockProvider::new(responses)), "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(2)
            .build()
            .unwrap();

        let results = executor.execute(vec![String::from("task1")]).await;

        assert_eq!(results.len(), 1);
        let only = &results[0];
        assert_eq!(only.index, 0);
        assert_eq!(only.input, "task1");
        let output = only.result.as_ref().unwrap();
        assert_eq!(output.result, "hello");
        assert_eq!(output.tokens_used.input_tokens, 100);
        assert_eq!(output.tokens_used.output_tokens, 50);
    }

    #[tokio::test]
    async fn multiple_tasks_all_succeed() {
        let responses = vec![
            MockProvider::text_response("r1", 10, 5),
            MockProvider::text_response("r2", 20, 10),
            MockProvider::text_response("r3", 30, 15),
            MockProvider::text_response("r4", 40, 20),
            MockProvider::text_response("r5", 50, 25),
        ];
        let agent = make_agent(Arc::new(MockProvider::new(responses)), "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(5)
            .build()
            .unwrap();

        let mut tasks = Vec::with_capacity(5);
        for i in 1..=5 {
            tasks.push(format!("task{i}"));
        }
        let results = executor.execute(tasks).await;

        assert_eq!(results.len(), 5);
        for r in results.iter() {
            assert!(r.result.is_ok(), "task {} failed: {:?}", r.index, r.result);
        }
    }

    #[tokio::test]
    async fn results_ordered_by_index() {
        let responses = vec![
            MockProvider::text_response("a", 10, 5),
            MockProvider::text_response("b", 10, 5),
            MockProvider::text_response("c", 10, 5),
        ];
        let agent = make_agent(Arc::new(MockProvider::new(responses)), "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(3)
            .build()
            .unwrap();

        let tasks: Vec<String> = ["t0", "t1", "t2"].iter().map(|t| t.to_string()).collect();
        let results = executor.execute(tasks).await;

        assert_eq!(results.len(), 3);
        let indices: Vec<usize> = results.iter().map(|r| r.index).collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn partial_failure_returns_all_results() {
        // Three tasks but only two canned responses, so the last one must fail.
        let responses = vec![
            MockProvider::text_response("ok1", 10, 5),
            MockProvider::text_response("ok2", 20, 10),
        ];
        let agent = make_agent(Arc::new(MockProvider::new(responses)), "test");
        // A limit of 1 serialises the tasks so the mock responses are consumed in order.
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(1)
            .build()
            .unwrap();

        let tasks: Vec<String> = (0..3).map(|i| format!("task{i}")).collect();
        let results = executor.execute(tasks).await;

        assert_eq!(results.len(), 3);
        // Tasks 0 and 1 get a response; task 2 finds the mock exhausted.
        assert!(results[0].result.is_ok());
        assert!(results[1].result.is_ok());
        assert!(results[2].result.is_err());
    }

    #[tokio::test]
    async fn concurrency_limit_respected() {
        let current = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let tracker =
            ConcurrencyTrackingProvider::new(Arc::clone(&current), Arc::clone(&peak), "done");
        let agent = AgentRunner::builder(Arc::new(tracker))
            .name("conc-test")
            .system_prompt("test")
            .max_turns(1)
            .build()
            .expect("build agent");

        let executor = BatchExecutor::builder(agent)
            .max_concurrency(2)
            .build()
            .unwrap();

        let mut tasks = Vec::with_capacity(10);
        for i in 0..10 {
            tasks.push(format!("task{i}"));
        }
        let results = executor.execute(tasks).await;

        assert_eq!(results.len(), 10);
        // The provider must never have seen more than two overlapping calls.
        let observed_peak = peak.load(Ordering::SeqCst);
        assert!(
            observed_peak <= 2,
            "peak concurrency was {observed_peak}, expected <= 2"
        );
    }

    #[tokio::test]
    async fn aggregate_usage_sums_successes() {
        let responses = vec![
            MockProvider::text_response("a", 100, 50),
            MockProvider::text_response("b", 200, 80),
        ];
        let agent = make_agent(Arc::new(MockProvider::new(responses)), "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(1)
            .build()
            .unwrap();

        let tasks = vec![String::from("t1"), String::from("t2")];
        let results = executor.execute(tasks).await;

        let usage = BatchExecutor::<MockProvider>::aggregate_usage(&results);
        assert_eq!(usage.input_tokens, 300);
        assert_eq!(usage.output_tokens, 130);
    }

    #[tokio::test]
    async fn aggregate_usage_ignores_failures() {
        let responses = vec![MockProvider::text_response("ok", 100, 50)];
        let agent = make_agent(Arc::new(MockProvider::new(responses)), "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(1)
            .build()
            .unwrap();

        // Two tasks, one canned response: the second task fails.
        let tasks = vec![String::from("t1"), String::from("t2")];
        let results = executor.execute(tasks).await;

        let usage = BatchExecutor::<MockProvider>::aggregate_usage(&results);
        // Only the successful first task is counted.
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
    }

    #[tokio::test]
    async fn execute_ref_convenience() {
        let responses = vec![
            MockProvider::text_response("a", 10, 5),
            MockProvider::text_response("b", 10, 5),
        ];
        let agent = make_agent(Arc::new(MockProvider::new(responses)), "test");
        let executor = BatchExecutor::builder(agent)
            .max_concurrency(2)
            .build()
            .unwrap();

        let inputs = ["hello", "world"];
        let results = executor.execute_ref(&inputs).await;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].input, "hello");
        assert_eq!(results[1].input, "world");
    }

    #[test]
    fn aggregate_usage_empty_results() {
        let no_results: &[BatchResult] = &[];
        let usage = BatchExecutor::<MockProvider>::aggregate_usage(no_results);
        assert_eq!(usage, TokenUsage::default());
    }
}