Skip to main content

adk_agent/workflow/
parallel_agent.rs

1#[cfg(feature = "skills")]
2use crate::skill_shim::load_skill_index;
3use crate::skill_shim::{SelectionPolicy, SkillIndex};
4use adk_core::{
5    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
6    InvocationContext, Result, SharedState,
7};
8use async_stream::stream;
9use async_trait::async_trait;
10use std::sync::Arc;
11
12use super::shared_state_context::SharedStateContext;
13
14/// Parallel agent executes sub-agents concurrently
15pub struct ParallelAgent {
16    name: String,
17    description: String,
18    sub_agents: Vec<Arc<dyn Agent>>,
19    skills_index: Option<Arc<SkillIndex>>,
20    skill_policy: SelectionPolicy,
21    max_skill_chars: usize,
22    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
23    after_callbacks: Arc<Vec<AfterAgentCallback>>,
24    shared_state_enabled: bool,
25}
26
27impl ParallelAgent {
28    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
29        Self {
30            name: name.into(),
31            description: String::new(),
32            sub_agents,
33            skills_index: None,
34            skill_policy: SelectionPolicy::default(),
35            max_skill_chars: 2000,
36            before_callbacks: Arc::new(Vec::new()),
37            after_callbacks: Arc::new(Vec::new()),
38            shared_state_enabled: false,
39        }
40    }
41
42    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
43        self.description = desc.into();
44        self
45    }
46
47    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
48        if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
49            callbacks.push(callback);
50        }
51        self
52    }
53
54    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
55        if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
56            callbacks.push(callback);
57        }
58        self
59    }
60
61    #[cfg(feature = "skills")]
62    pub fn with_skills(mut self, index: SkillIndex) -> Self {
63        self.skills_index = Some(Arc::new(index));
64        self
65    }
66
67    #[cfg(feature = "skills")]
68    pub fn with_auto_skills(self) -> Result<Self> {
69        self.with_skills_from_root(".")
70    }
71
72    #[cfg(feature = "skills")]
73    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
74        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
75        self.skills_index = Some(Arc::new(index));
76        Ok(self)
77    }
78
79    #[cfg(feature = "skills")]
80    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
81        self.skill_policy = policy;
82        self
83    }
84
85    #[cfg(feature = "skills")]
86    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
87        self.max_skill_chars = max_chars;
88        self
89    }
90
91    /// Enables shared state coordination for sub-agents.
92    ///
93    /// When enabled, a fresh `SharedState` instance is created for each
94    /// `run()` invocation and injected into each sub-agent's context.
95    /// Sub-agents can then use `ctx.shared_state()` to access the store.
96    pub fn with_shared_state(mut self) -> Self {
97        self.shared_state_enabled = true;
98        self
99    }
100}
101
102#[async_trait]
103impl Agent for ParallelAgent {
104    fn name(&self) -> &str {
105        &self.name
106    }
107
108    fn description(&self) -> &str {
109        &self.description
110    }
111
112    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
113        &self.sub_agents
114    }
115
116    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
117        let sub_agents = self.sub_agents.clone();
118        let run_ctx = super::skill_context::with_skill_injected_context(
119            ctx,
120            self.skills_index.as_ref(),
121            &self.skill_policy,
122            self.max_skill_chars,
123        );
124        let before_callbacks = self.before_callbacks.clone();
125        let after_callbacks = self.after_callbacks.clone();
126        let agent_name = self.name.clone();
127        let invocation_id = run_ctx.invocation_id().to_string();
128        let shared_state_enabled = self.shared_state_enabled;
129
130        let s = stream! {
131            use futures::stream::{FuturesUnordered, StreamExt};
132
133            for callback in before_callbacks.as_ref() {
134                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
135                    Ok(Some(content)) => {
136                        let mut early_event = Event::new(&invocation_id);
137                        early_event.author = agent_name.clone();
138                        early_event.llm_response.content = Some(content);
139                        yield Ok(early_event);
140
141                        for after_callback in after_callbacks.as_ref() {
142                            match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
143                                Ok(Some(after_content)) => {
144                                    let mut after_event = Event::new(&invocation_id);
145                                    after_event.author = agent_name.clone();
146                                    after_event.llm_response.content = Some(after_content);
147                                    yield Ok(after_event);
148                                    return;
149                                }
150                                Ok(None) => continue,
151                                Err(e) => {
152                                    yield Err(e);
153                                    return;
154                                }
155                            }
156                        }
157                        return;
158                    }
159                    Ok(None) => continue,
160                    Err(e) => {
161                        yield Err(e);
162                        return;
163                    }
164                }
165            }
166
167            let mut futures = FuturesUnordered::new();
168
169            // Create shared state if enabled (fresh per run)
170            let shared = if shared_state_enabled {
171                Some(Arc::new(SharedState::new()))
172            } else {
173                None
174            };
175
176            for agent in sub_agents {
177                let ctx: Arc<dyn InvocationContext> = if let Some(ref shared) = shared {
178                    Arc::new(SharedStateContext::new(run_ctx.clone(), shared.clone()))
179                } else {
180                    run_ctx.clone()
181                };
182                futures.push(async move {
183                    agent.run(ctx).await
184                });
185            }
186
187            let mut first_error: Option<adk_core::AdkError> = None;
188
189            while let Some(result) = futures.next().await {
190                match result {
191                    Ok(mut stream) => {
192                        while let Some(event_result) = stream.next().await {
193                            match event_result {
194                                Ok(event) => yield Ok(event),
195                                Err(e) => {
196                                    if first_error.is_none() {
197                                        first_error = Some(e);
198                                    }
199                                    // Continue draining other agents instead of returning
200                                    break;
201                                }
202                            }
203                        }
204                    }
205                    Err(e) => {
206                        if first_error.is_none() {
207                            first_error = Some(e);
208                        }
209                        // Continue draining remaining futures to avoid resource leaks
210                    }
211                }
212            }
213
214            // After all agents complete, propagate the first error if any
215            if let Some(e) = first_error {
216                yield Err(e);
217                return;
218            }
219
220            for callback in after_callbacks.as_ref() {
221                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
222                    Ok(Some(content)) => {
223                        let mut after_event = Event::new(&invocation_id);
224                        after_event.author = agent_name.clone();
225                        after_event.llm_response.content = Some(content);
226                        yield Ok(after_event);
227                        break;
228                    }
229                    Ok(None) => continue,
230                    Err(e) => {
231                        yield Err(e);
232                        return;
233                    }
234                }
235            }
236        };
237
238        Ok(Box::pin(s))
239    }
240}