Skip to main content

adk_agent/workflow/
parallel_agent.rs

1#[cfg(feature = "skills")]
2use crate::skill_shim::load_skill_index;
3use crate::skill_shim::{SelectionPolicy, SkillIndex};
4use adk_core::{
5    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
6    InvocationContext, Result, SharedState,
7};
8use async_stream::stream;
9use async_trait::async_trait;
10use std::sync::Arc;
11
12use super::shared_state_context::SharedStateContext;
13
14/// Parallel agent executes sub-agents concurrently
15pub struct ParallelAgent {
16    name: String,
17    description: String,
18    sub_agents: Vec<Arc<dyn Agent>>,
19    skills_index: Option<Arc<SkillIndex>>,
20    skill_policy: SelectionPolicy,
21    max_skill_chars: usize,
22    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
23    after_callbacks: Arc<Vec<AfterAgentCallback>>,
24    shared_state_enabled: bool,
25}
26
27impl ParallelAgent {
28    /// Create a new parallel agent with the given name and sub-agents.
29    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
30        Self {
31            name: name.into(),
32            description: String::new(),
33            sub_agents,
34            skills_index: None,
35            skill_policy: SelectionPolicy::default(),
36            max_skill_chars: 2000,
37            before_callbacks: Arc::new(Vec::new()),
38            after_callbacks: Arc::new(Vec::new()),
39            shared_state_enabled: false,
40        }
41    }
42
43    /// Set the agent description.
44    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
45        self.description = desc.into();
46        self
47    }
48
49    /// Add a before-agent callback.
50    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
51        if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
52            callbacks.push(callback);
53        }
54        self
55    }
56
57    /// Add an after-agent callback.
58    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
59        if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
60            callbacks.push(callback);
61        }
62        self
63    }
64
65    /// Set a preloaded skills index for this agent.
66    #[cfg(feature = "skills")]
67    pub fn with_skills(mut self, index: SkillIndex) -> Self {
68        self.skills_index = Some(Arc::new(index));
69        self
70    }
71
72    /// Auto-load skills from `.skills/` in the current working directory.
73    #[cfg(feature = "skills")]
74    pub fn with_auto_skills(self) -> Result<Self> {
75        self.with_skills_from_root(".")
76    }
77
78    /// Auto-load skills from `.skills/` under a custom root directory.
79    #[cfg(feature = "skills")]
80    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
81        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
82        self.skills_index = Some(Arc::new(index));
83        Ok(self)
84    }
85
86    /// Customize skill selection behavior.
87    #[cfg(feature = "skills")]
88    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
89        self.skill_policy = policy;
90        self
91    }
92
93    /// Limit injected skill content length.
94    #[cfg(feature = "skills")]
95    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
96        self.max_skill_chars = max_chars;
97        self
98    }
99
100    /// Enables shared state coordination for sub-agents.
101    ///
102    /// When enabled, a fresh `SharedState` instance is created for each
103    /// `run()` invocation and injected into each sub-agent's context.
104    /// Sub-agents can then use `ctx.shared_state()` to access the store.
105    pub fn with_shared_state(mut self) -> Self {
106        self.shared_state_enabled = true;
107        self
108    }
109}
110
111#[async_trait]
112impl Agent for ParallelAgent {
113    fn name(&self) -> &str {
114        &self.name
115    }
116
117    fn description(&self) -> &str {
118        &self.description
119    }
120
121    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
122        &self.sub_agents
123    }
124
125    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
126        let sub_agents = self.sub_agents.clone();
127        let run_ctx = super::skill_context::with_skill_injected_context(
128            ctx,
129            self.skills_index.as_ref(),
130            &self.skill_policy,
131            self.max_skill_chars,
132        );
133        let before_callbacks = self.before_callbacks.clone();
134        let after_callbacks = self.after_callbacks.clone();
135        let agent_name = self.name.clone();
136        let invocation_id = run_ctx.invocation_id().to_string();
137        let shared_state_enabled = self.shared_state_enabled;
138
139        let s = stream! {
140            use futures::stream::{FuturesUnordered, StreamExt};
141
142            for callback in before_callbacks.as_ref() {
143                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
144                    Ok(Some(content)) => {
145                        let mut early_event = Event::new(&invocation_id);
146                        early_event.author = agent_name.clone();
147                        early_event.llm_response.content = Some(content);
148                        yield Ok(early_event);
149
150                        for after_callback in after_callbacks.as_ref() {
151                            match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
152                                Ok(Some(after_content)) => {
153                                    let mut after_event = Event::new(&invocation_id);
154                                    after_event.author = agent_name.clone();
155                                    after_event.llm_response.content = Some(after_content);
156                                    yield Ok(after_event);
157                                    return;
158                                }
159                                Ok(None) => continue,
160                                Err(e) => {
161                                    yield Err(e);
162                                    return;
163                                }
164                            }
165                        }
166                        return;
167                    }
168                    Ok(None) => continue,
169                    Err(e) => {
170                        yield Err(e);
171                        return;
172                    }
173                }
174            }
175
176            let mut futures = FuturesUnordered::new();
177
178            // Create shared state if enabled (fresh per run)
179            let shared = if shared_state_enabled {
180                Some(Arc::new(SharedState::new()))
181            } else {
182                None
183            };
184
185            for agent in sub_agents {
186                let ctx: Arc<dyn InvocationContext> = if let Some(ref shared) = shared {
187                    Arc::new(SharedStateContext::new(run_ctx.clone(), shared.clone()))
188                } else {
189                    run_ctx.clone()
190                };
191                futures.push(async move {
192                    agent.run(ctx).await
193                });
194            }
195
196            let mut first_error: Option<adk_core::AdkError> = None;
197
198            while let Some(result) = futures.next().await {
199                match result {
200                    Ok(mut stream) => {
201                        while let Some(event_result) = stream.next().await {
202                            match event_result {
203                                Ok(event) => yield Ok(event),
204                                Err(e) => {
205                                    if first_error.is_none() {
206                                        first_error = Some(e);
207                                    }
208                                    // Continue draining other agents instead of returning
209                                    break;
210                                }
211                            }
212                        }
213                    }
214                    Err(e) => {
215                        if first_error.is_none() {
216                            first_error = Some(e);
217                        }
218                        // Continue draining remaining futures to avoid resource leaks
219                    }
220                }
221            }
222
223            // After all agents complete, propagate the first error if any
224            if let Some(e) = first_error {
225                yield Err(e);
226                return;
227            }
228
229            for callback in after_callbacks.as_ref() {
230                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
231                    Ok(Some(content)) => {
232                        let mut after_event = Event::new(&invocation_id);
233                        after_event.author = agent_name.clone();
234                        after_event.llm_response.content = Some(content);
235                        yield Ok(after_event);
236                        break;
237                    }
238                    Ok(None) => continue,
239                    Err(e) => {
240                        yield Err(e);
241                        return;
242                    }
243                }
244            }
245        };
246
247        Ok(Box::pin(s))
248    }
249}