Skip to main content

rs_adk/text/
parallel.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::TextAgent;
6use crate::error::AgentError;
7use crate::state::State;
8
9/// Runs text agents concurrently. All branches share state. Results are
10/// collected and joined with newlines.
11pub struct ParallelTextAgent {
12    name: String,
13    branches: Vec<Arc<dyn TextAgent>>,
14}
15
16impl ParallelTextAgent {
17    /// Create a new parallel agent that runs branches concurrently.
18    pub fn new(name: impl Into<String>, branches: Vec<Arc<dyn TextAgent>>) -> Self {
19        Self {
20            name: name.into(),
21            branches,
22        }
23    }
24}
25
26#[async_trait]
27impl TextAgent for ParallelTextAgent {
28    fn name(&self) -> &str {
29        &self.name
30    }
31
32    async fn run(&self, state: &State) -> Result<String, AgentError> {
33        let mut handles = Vec::with_capacity(self.branches.len());
34
35        for branch in &self.branches {
36            let branch = branch.clone();
37            let state = state.clone();
38            handles.push(tokio::spawn(async move { branch.run(&state).await }));
39        }
40
41        let mut results = Vec::with_capacity(handles.len());
42        for handle in handles {
43            let result = handle
44                .await
45                .map_err(|e| AgentError::Other(format!("Join error: {e}")))?;
46            results.push(result?);
47        }
48
49        let combined = results.join("\n");
50        state.set("output", &combined);
51        Ok(combined)
52    }
53}