1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::TextAgent;
6use crate::error::AgentError;
7use crate::state::State;
8
9pub struct ParallelTextAgent {
12 name: String,
13 branches: Vec<Arc<dyn TextAgent>>,
14}
15
16impl ParallelTextAgent {
17 pub fn new(name: impl Into<String>, branches: Vec<Arc<dyn TextAgent>>) -> Self {
19 Self {
20 name: name.into(),
21 branches,
22 }
23 }
24}
25
26#[async_trait]
27impl TextAgent for ParallelTextAgent {
28 fn name(&self) -> &str {
29 &self.name
30 }
31
32 async fn run(&self, state: &State) -> Result<String, AgentError> {
33 let mut handles = Vec::with_capacity(self.branches.len());
34
35 for branch in &self.branches {
36 let branch = branch.clone();
37 let state = state.clone();
38 handles.push(tokio::spawn(async move { branch.run(&state).await }));
39 }
40
41 let mut results = Vec::with_capacity(handles.len());
42 for handle in handles {
43 let result = handle
44 .await
45 .map_err(|e| AgentError::Other(format!("Join error: {e}")))?;
46 results.push(result?);
47 }
48
49 let combined = results.join("\n");
50 state.set("output", &combined);
51 Ok(combined)
52 }
53}