adk_agent/workflow/
parallel_agent.rs1use adk_core::{
2 AfterAgentCallback, Agent, BeforeAgentCallback, EventStream, InvocationContext, Result,
3};
4use async_stream::stream;
5use async_trait::async_trait;
6use std::sync::Arc;
7
8pub struct ParallelAgent {
10 name: String,
11 description: String,
12 sub_agents: Vec<Arc<dyn Agent>>,
13 before_callbacks: Vec<BeforeAgentCallback>,
14 after_callbacks: Vec<AfterAgentCallback>,
15}
16
17impl ParallelAgent {
18 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
19 Self {
20 name: name.into(),
21 description: String::new(),
22 sub_agents,
23 before_callbacks: Vec::new(),
24 after_callbacks: Vec::new(),
25 }
26 }
27
28 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
29 self.description = desc.into();
30 self
31 }
32
33 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
34 self.before_callbacks.push(callback);
35 self
36 }
37
38 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
39 self.after_callbacks.push(callback);
40 self
41 }
42}
43
44#[async_trait]
45impl Agent for ParallelAgent {
46 fn name(&self) -> &str {
47 &self.name
48 }
49
50 fn description(&self) -> &str {
51 &self.description
52 }
53
54 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
55 &self.sub_agents
56 }
57
58 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
59 let sub_agents = self.sub_agents.clone();
60
61 let s = stream! {
62 use futures::stream::{FuturesUnordered, StreamExt};
63
64 let mut futures = FuturesUnordered::new();
65
66 for agent in sub_agents {
67 let ctx = ctx.clone();
68 futures.push(async move {
69 agent.run(ctx).await
70 });
71 }
72
73 while let Some(result) = futures.next().await {
74 match result {
75 Ok(mut stream) => {
76 while let Some(event_result) = stream.next().await {
77 yield event_result;
78 }
79 }
80 Err(e) => {
81 yield Err(e);
82 return;
83 }
84 }
85 }
86 };
87
88 Ok(Box::pin(s))
89 }
90}