// adk_agent/workflow/parallel_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, EventStream, InvocationContext, Result,
3};
4use async_stream::stream;
5use async_trait::async_trait;
6use std::sync::Arc;
7
8/// Parallel agent executes sub-agents concurrently
9pub struct ParallelAgent {
10    name: String,
11    description: String,
12    sub_agents: Vec<Arc<dyn Agent>>,
13    before_callbacks: Vec<BeforeAgentCallback>,
14    after_callbacks: Vec<AfterAgentCallback>,
15}
16
17impl ParallelAgent {
18    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
19        Self {
20            name: name.into(),
21            description: String::new(),
22            sub_agents,
23            before_callbacks: Vec::new(),
24            after_callbacks: Vec::new(),
25        }
26    }
27
28    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
29        self.description = desc.into();
30        self
31    }
32
33    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
34        self.before_callbacks.push(callback);
35        self
36    }
37
38    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
39        self.after_callbacks.push(callback);
40        self
41    }
42}
43
44#[async_trait]
45impl Agent for ParallelAgent {
46    fn name(&self) -> &str {
47        &self.name
48    }
49
50    fn description(&self) -> &str {
51        &self.description
52    }
53
54    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
55        &self.sub_agents
56    }
57
58    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
59        let sub_agents = self.sub_agents.clone();
60
61        let s = stream! {
62            use futures::stream::{FuturesUnordered, StreamExt};
63
64            let mut futures = FuturesUnordered::new();
65
66            for agent in sub_agents {
67                let ctx = ctx.clone();
68                futures.push(async move {
69                    agent.run(ctx).await
70                });
71            }
72
73            while let Some(result) = futures.next().await {
74                match result {
75                    Ok(mut stream) => {
76                        while let Some(event_result) = stream.next().await {
77                            yield event_result;
78                        }
79                    }
80                    Err(e) => {
81                        yield Err(e);
82                        return;
83                    }
84                }
85            }
86        };
87
88        Ok(Box::pin(s))
89    }
90}