Skip to main content

rs_adk/agents/
parallel.rs

1//! ParallelAgent — runs sub-agents concurrently.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use crate::agent::Agent;
8use crate::context::InvocationContext;
9use crate::error::AgentError;
10
11/// Runs sub-agents concurrently.
12///
13/// All sub-agents run in parallel via `tokio::spawn`. Each gets a fresh
14/// `InvocationContext` wrapping the same underlying `AgentSession` (shared state
15/// and session). All must complete before `ParallelAgent` returns.
16///
17/// If any sub-agent fails, its error is returned (first error wins when
18/// iterating over join handles in order).
19pub struct ParallelAgent {
20    name: String,
21    sub_agents: Vec<Arc<dyn Agent>>,
22}
23
24impl ParallelAgent {
25    /// Create a new parallel agent with the given name and sub-agents.
26    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
27        Self {
28            name: name.into(),
29            sub_agents,
30        }
31    }
32}
33
34#[async_trait]
35impl Agent for ParallelAgent {
36    fn name(&self) -> &str {
37        &self.name
38    }
39
40    async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
41        let mut handles = Vec::new();
42
43        for sub in &self.sub_agents {
44            let sub = sub.clone();
45            let agent_session = ctx.agent_session.clone();
46            let event_tx = ctx.event_tx.clone();
47            let middleware = ctx.middleware.clone();
48
49            handles.push(tokio::spawn(async move {
50                let mut branch_ctx = InvocationContext {
51                    agent_session,
52                    event_tx,
53                    middleware,
54                    run_config: crate::run_config::RunConfig::default(),
55                    session_id: None,
56                    artifact_service: None,
57                    memory_service: None,
58                    session_service: None,
59                };
60                sub.run_live(&mut branch_ctx).await
61            }));
62        }
63
64        for handle in handles {
65            handle
66                .await
67                .map_err(|e| AgentError::Other(format!("Join error: {}", e)))??;
68        }
69
70        Ok(())
71    }
72
73    fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
74        self.sub_agents.clone()
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_session::{AgentSession, NoOpSessionWriter};
    use crate::context::InvocationContext;
    use crate::error::AgentError;
    use std::sync::Arc;
    use tokio::sync::broadcast;

    /// Helper: create a test InvocationContext with a no-op session.
    fn test_ctx() -> InvocationContext {
        let (event_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> = Arc::new(NoOpSessionWriter);
        let agent_session = AgentSession::from_writer(writer, event_tx);
        InvocationContext::new(agent_session)
    }

    /// A test agent that sets a key in shared state.
    struct StateSetAgent {
        agent_name: String,
        key: String,
        value: String,
    }

    #[async_trait]
    impl Agent for StateSetAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            ctx.state().set(&self.key, &self.value);
            Ok(())
        }
    }

    /// A test agent that always fails with `AgentError::Other(message)`.
    struct FailAgent {
        agent_name: String,
        message: String,
    }

    #[async_trait]
    impl Agent for FailAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Err(AgentError::Other(self.message.clone()))
        }
    }

    /// A test agent whose task panics, exercising the join-error path.
    struct PanicAgent;

    #[async_trait]
    impl Agent for PanicAgent {
        fn name(&self) -> &str {
            "panic"
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            panic!("sub-agent panicked on purpose");
        }
    }

    #[tokio::test]
    async fn parallel_runs_all() {
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(StateSetAgent {
                agent_name: "a".into(),
                key: "key_a".into(),
                value: "val_a".into(),
            }),
            Arc::new(StateSetAgent {
                agent_name: "b".into(),
                key: "key_b".into(),
                value: "val_b".into(),
            }),
            Arc::new(StateSetAgent {
                agent_name: "c".into(),
                key: "key_c".into(),
                value: "val_c".into(),
            }),
        ];

        let par = ParallelAgent::new("par", agents);
        let mut ctx = test_ctx();
        par.run_live(&mut ctx).await.unwrap();

        // All three keys should be set via the shared AgentSession state.
        assert_eq!(
            ctx.state().get::<String>("key_a"),
            Some("val_a".to_string())
        );
        assert_eq!(
            ctx.state().get::<String>("key_b"),
            Some("val_b".to_string())
        );
        assert_eq!(
            ctx.state().get::<String>("key_c"),
            Some("val_c".to_string())
        );
    }

    #[tokio::test]
    async fn parallel_fails_if_any_fails() {
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(StateSetAgent {
                agent_name: "a".into(),
                key: "key_a".into(),
                value: "val_a".into(),
            }),
            Arc::new(FailAgent {
                agent_name: "b".into(),
                message: "parallel fail".into(),
            }),
            Arc::new(StateSetAgent {
                agent_name: "c".into(),
                key: "key_c".into(),
                value: "val_c".into(),
            }),
        ];

        let par = ParallelAgent::new("par", agents);
        let mut ctx = test_ctx();
        let result = par.run_live(&mut ctx).await;

        // The failing sub-agent's own error must be surfaced, not a generic one.
        match result {
            Err(AgentError::Other(msg)) => assert_eq!(msg, "parallel fail"),
            other => panic!("expected the sub-agent's error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn parallel_reports_earliest_failure_in_order() {
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(StateSetAgent {
                agent_name: "a".into(),
                key: "key_a".into(),
                value: "val_a".into(),
            }),
            Arc::new(FailAgent {
                agent_name: "b".into(),
                message: "first failure".into(),
            }),
            Arc::new(FailAgent {
                agent_name: "c".into(),
                message: "second failure".into(),
            }),
        ];

        let par = ParallelAgent::new("par", agents);
        let mut ctx = test_ctx();
        let result = par.run_live(&mut ctx).await;

        match result {
            Err(AgentError::Other(msg)) => assert_eq!(msg, "first failure"),
            other => panic!("expected the earliest failure, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn parallel_reports_panicking_sub_agent_as_join_error() {
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(PanicAgent)];

        let par = ParallelAgent::new("par", agents);
        let mut ctx = test_ctx();
        let result = par.run_live(&mut ctx).await;

        match result {
            Err(AgentError::Other(msg)) => {
                assert!(msg.starts_with("Join error"), "unexpected message: {}", msg)
            }
            other => panic!("expected a join error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn parallel_empty_succeeds() {
        let par = ParallelAgent::new("empty", vec![]);
        let mut ctx = test_ctx();
        par.run_live(&mut ctx).await.unwrap();
    }

    #[test]
    fn parallel_sub_agents_returns_children() {
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(StateSetAgent {
            agent_name: "child".into(),
            key: "k".into(),
            value: "v".into(),
        })];

        let par = ParallelAgent::new("par", agents);
        assert_eq!(par.sub_agents().len(), 1);
        assert_eq!(par.sub_agents()[0].name(), "child");
    }
}