rs_adk/agents/
parallel.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use crate::agent::Agent;
8use crate::context::InvocationContext;
9use crate::error::AgentError;
10
11pub struct ParallelAgent {
20 name: String,
21 sub_agents: Vec<Arc<dyn Agent>>,
22}
23
24impl ParallelAgent {
25 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
27 Self {
28 name: name.into(),
29 sub_agents,
30 }
31 }
32}
33
34#[async_trait]
35impl Agent for ParallelAgent {
36 fn name(&self) -> &str {
37 &self.name
38 }
39
40 async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
41 let mut handles = Vec::new();
42
43 for sub in &self.sub_agents {
44 let sub = sub.clone();
45 let agent_session = ctx.agent_session.clone();
46 let event_tx = ctx.event_tx.clone();
47 let middleware = ctx.middleware.clone();
48
49 handles.push(tokio::spawn(async move {
50 let mut branch_ctx = InvocationContext {
51 agent_session,
52 event_tx,
53 middleware,
54 run_config: crate::run_config::RunConfig::default(),
55 session_id: None,
56 artifact_service: None,
57 memory_service: None,
58 session_service: None,
59 };
60 sub.run_live(&mut branch_ctx).await
61 }));
62 }
63
64 for handle in handles {
65 handle
66 .await
67 .map_err(|e| AgentError::Other(format!("Join error: {}", e)))??;
68 }
69
70 Ok(())
71 }
72
73 fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
74 self.sub_agents.clone()
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_session::{AgentSession, NoOpSessionWriter};
    use crate::context::InvocationContext;
    use crate::error::AgentError;
    use std::sync::Arc;
    use tokio::sync::broadcast;

    /// Builds an invocation context on top of a no-op session writer.
    fn test_ctx() -> InvocationContext {
        let (tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> = Arc::new(NoOpSessionWriter);
        InvocationContext::new(AgentSession::from_writer(writer, tx))
    }

    /// Test agent that writes a single key/value pair into the context state.
    struct StateSetAgent {
        agent_name: String,
        key: String,
        value: String,
    }

    #[async_trait]
    impl Agent for StateSetAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            ctx.state().set(&self.key, &self.value);
            Ok(())
        }
    }

    /// Test agent whose `run_live` always returns an error.
    struct FailAgent {
        agent_name: String,
    }

    #[async_trait]
    impl Agent for FailAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Err(AgentError::Other("parallel fail".to_string()))
        }
    }

    /// Wraps a [`StateSetAgent`] in the trait-object form `ParallelAgent::new` expects.
    fn state_setter(name: &str, key: &str, value: &str) -> Arc<dyn Agent> {
        Arc::new(StateSetAgent {
            agent_name: name.to_owned(),
            key: key.to_owned(),
            value: value.to_owned(),
        })
    }

    #[tokio::test]
    async fn parallel_runs_all() {
        let children = vec![
            state_setter("a", "key_a", "val_a"),
            state_setter("b", "key_b", "val_b"),
            state_setter("c", "key_c", "val_c"),
        ];

        let par = ParallelAgent::new("par", children);
        let mut ctx = test_ctx();
        par.run_live(&mut ctx).await.unwrap();

        // Every branch must have left its value in the state.
        for (key, expected) in [("key_a", "val_a"), ("key_b", "val_b"), ("key_c", "val_c")] {
            assert_eq!(ctx.state().get::<String>(key), Some(expected.to_string()));
        }
    }

    #[tokio::test]
    async fn parallel_fails_if_any_fails() {
        let failing: Arc<dyn Agent> = Arc::new(FailAgent {
            agent_name: "b".into(),
        });
        let children = vec![
            state_setter("a", "key_a", "val_a"),
            failing,
            state_setter("c", "key_c", "val_c"),
        ];

        let par = ParallelAgent::new("par", children);
        let mut ctx = test_ctx();

        assert!(par.run_live(&mut ctx).await.is_err());
    }

    #[tokio::test]
    async fn parallel_empty_succeeds() {
        let par = ParallelAgent::new("empty", Vec::new());
        let mut ctx = test_ctx();
        par.run_live(&mut ctx).await.unwrap();
    }

    #[test]
    fn parallel_sub_agents_returns_children() {
        let par = ParallelAgent::new("par", vec![state_setter("child", "k", "v")]);

        let children = par.sub_agents();
        assert_eq!(children.len(), 1);
        assert_eq!(children[0].name(), "child");
    }
}