Skip to main content

adk_agent/workflow/
parallel_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
3    InvocationContext, Result, SharedState,
4};
5use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
6use async_stream::stream;
7use async_trait::async_trait;
8use std::sync::Arc;
9
10use super::shared_state_context::SharedStateContext;
11
12/// Parallel agent executes sub-agents concurrently
13pub struct ParallelAgent {
14    name: String,
15    description: String,
16    sub_agents: Vec<Arc<dyn Agent>>,
17    skills_index: Option<Arc<SkillIndex>>,
18    skill_policy: SelectionPolicy,
19    max_skill_chars: usize,
20    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
21    after_callbacks: Arc<Vec<AfterAgentCallback>>,
22    shared_state_enabled: bool,
23}
24
25impl ParallelAgent {
26    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
27        Self {
28            name: name.into(),
29            description: String::new(),
30            sub_agents,
31            skills_index: None,
32            skill_policy: SelectionPolicy::default(),
33            max_skill_chars: 2000,
34            before_callbacks: Arc::new(Vec::new()),
35            after_callbacks: Arc::new(Vec::new()),
36            shared_state_enabled: false,
37        }
38    }
39
40    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
41        self.description = desc.into();
42        self
43    }
44
45    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
46        if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
47            callbacks.push(callback);
48        }
49        self
50    }
51
52    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
53        if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
54            callbacks.push(callback);
55        }
56        self
57    }
58
59    pub fn with_skills(mut self, index: SkillIndex) -> Self {
60        self.skills_index = Some(Arc::new(index));
61        self
62    }
63
64    pub fn with_auto_skills(self) -> Result<Self> {
65        self.with_skills_from_root(".")
66    }
67
68    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
69        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
70        self.skills_index = Some(Arc::new(index));
71        Ok(self)
72    }
73
74    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
75        self.skill_policy = policy;
76        self
77    }
78
79    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
80        self.max_skill_chars = max_chars;
81        self
82    }
83
84    /// Enables shared state coordination for sub-agents.
85    ///
86    /// When enabled, a fresh `SharedState` instance is created for each
87    /// `run()` invocation and injected into each sub-agent's context.
88    /// Sub-agents can then use `ctx.shared_state()` to access the store.
89    pub fn with_shared_state(mut self) -> Self {
90        self.shared_state_enabled = true;
91        self
92    }
93}
94
95#[async_trait]
96impl Agent for ParallelAgent {
97    fn name(&self) -> &str {
98        &self.name
99    }
100
101    fn description(&self) -> &str {
102        &self.description
103    }
104
105    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
106        &self.sub_agents
107    }
108
109    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
110        let sub_agents = self.sub_agents.clone();
111        let run_ctx = super::skill_context::with_skill_injected_context(
112            ctx,
113            self.skills_index.as_ref(),
114            &self.skill_policy,
115            self.max_skill_chars,
116        );
117        let before_callbacks = self.before_callbacks.clone();
118        let after_callbacks = self.after_callbacks.clone();
119        let agent_name = self.name.clone();
120        let invocation_id = run_ctx.invocation_id().to_string();
121        let shared_state_enabled = self.shared_state_enabled;
122
123        let s = stream! {
124            use futures::stream::{FuturesUnordered, StreamExt};
125
126            for callback in before_callbacks.as_ref() {
127                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
128                    Ok(Some(content)) => {
129                        let mut early_event = Event::new(&invocation_id);
130                        early_event.author = agent_name.clone();
131                        early_event.llm_response.content = Some(content);
132                        yield Ok(early_event);
133
134                        for after_callback in after_callbacks.as_ref() {
135                            match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
136                                Ok(Some(after_content)) => {
137                                    let mut after_event = Event::new(&invocation_id);
138                                    after_event.author = agent_name.clone();
139                                    after_event.llm_response.content = Some(after_content);
140                                    yield Ok(after_event);
141                                    return;
142                                }
143                                Ok(None) => continue,
144                                Err(e) => {
145                                    yield Err(e);
146                                    return;
147                                }
148                            }
149                        }
150                        return;
151                    }
152                    Ok(None) => continue,
153                    Err(e) => {
154                        yield Err(e);
155                        return;
156                    }
157                }
158            }
159
160            let mut futures = FuturesUnordered::new();
161
162            // Create shared state if enabled (fresh per run)
163            let shared = if shared_state_enabled {
164                Some(Arc::new(SharedState::new()))
165            } else {
166                None
167            };
168
169            for agent in sub_agents {
170                let ctx: Arc<dyn InvocationContext> = if let Some(ref shared) = shared {
171                    Arc::new(SharedStateContext::new(run_ctx.clone(), shared.clone()))
172                } else {
173                    run_ctx.clone()
174                };
175                futures.push(async move {
176                    agent.run(ctx).await
177                });
178            }
179
180            let mut first_error: Option<adk_core::AdkError> = None;
181
182            while let Some(result) = futures.next().await {
183                match result {
184                    Ok(mut stream) => {
185                        while let Some(event_result) = stream.next().await {
186                            match event_result {
187                                Ok(event) => yield Ok(event),
188                                Err(e) => {
189                                    if first_error.is_none() {
190                                        first_error = Some(e);
191                                    }
192                                    // Continue draining other agents instead of returning
193                                    break;
194                                }
195                            }
196                        }
197                    }
198                    Err(e) => {
199                        if first_error.is_none() {
200                            first_error = Some(e);
201                        }
202                        // Continue draining remaining futures to avoid resource leaks
203                    }
204                }
205            }
206
207            // After all agents complete, propagate the first error if any
208            if let Some(e) = first_error {
209                yield Err(e);
210                return;
211            }
212
213            for callback in after_callbacks.as_ref() {
214                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
215                    Ok(Some(content)) => {
216                        let mut after_event = Event::new(&invocation_id);
217                        after_event.author = agent_name.clone();
218                        after_event.llm_response.content = Some(content);
219                        yield Ok(after_event);
220                        break;
221                    }
222                    Ok(None) => continue,
223                    Err(e) => {
224                        yield Err(e);
225                        return;
226                    }
227                }
228            }
229        };
230
231        Ok(Box::pin(s))
232    }
233}