Skip to main content

adk_agent/workflow/
parallel_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
3    InvocationContext, Result,
4};
5use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
6use async_stream::stream;
7use async_trait::async_trait;
8use std::sync::Arc;
9
10/// Parallel agent executes sub-agents concurrently
11pub struct ParallelAgent {
12    name: String,
13    description: String,
14    sub_agents: Vec<Arc<dyn Agent>>,
15    skills_index: Option<Arc<SkillIndex>>,
16    skill_policy: SelectionPolicy,
17    max_skill_chars: usize,
18    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
19    after_callbacks: Arc<Vec<AfterAgentCallback>>,
20}
21
22impl ParallelAgent {
23    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
24        Self {
25            name: name.into(),
26            description: String::new(),
27            sub_agents,
28            skills_index: None,
29            skill_policy: SelectionPolicy::default(),
30            max_skill_chars: 2000,
31            before_callbacks: Arc::new(Vec::new()),
32            after_callbacks: Arc::new(Vec::new()),
33        }
34    }
35
36    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
37        self.description = desc.into();
38        self
39    }
40
41    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
42        Arc::get_mut(&mut self.before_callbacks)
43            .expect("before_callbacks not yet shared")
44            .push(callback);
45        self
46    }
47
48    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
49        Arc::get_mut(&mut self.after_callbacks)
50            .expect("after_callbacks not yet shared")
51            .push(callback);
52        self
53    }
54
55    pub fn with_skills(mut self, index: SkillIndex) -> Self {
56        self.skills_index = Some(Arc::new(index));
57        self
58    }
59
60    pub fn with_auto_skills(self) -> Result<Self> {
61        self.with_skills_from_root(".")
62    }
63
64    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
65        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::Agent(e.to_string()))?;
66        self.skills_index = Some(Arc::new(index));
67        Ok(self)
68    }
69
70    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
71        self.skill_policy = policy;
72        self
73    }
74
75    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
76        self.max_skill_chars = max_chars;
77        self
78    }
79}
80
81#[async_trait]
82impl Agent for ParallelAgent {
83    fn name(&self) -> &str {
84        &self.name
85    }
86
87    fn description(&self) -> &str {
88        &self.description
89    }
90
91    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
92        &self.sub_agents
93    }
94
95    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
96        let sub_agents = self.sub_agents.clone();
97        let run_ctx = super::skill_context::with_skill_injected_context(
98            ctx,
99            self.skills_index.as_ref(),
100            &self.skill_policy,
101            self.max_skill_chars,
102        );
103        let before_callbacks = self.before_callbacks.clone();
104        let after_callbacks = self.after_callbacks.clone();
105        let agent_name = self.name.clone();
106        let invocation_id = run_ctx.invocation_id().to_string();
107
108        let s = stream! {
109            use futures::stream::{FuturesUnordered, StreamExt};
110
111            for callback in before_callbacks.as_ref() {
112                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
113                    Ok(Some(content)) => {
114                        let mut early_event = Event::new(&invocation_id);
115                        early_event.author = agent_name.clone();
116                        early_event.llm_response.content = Some(content);
117                        yield Ok(early_event);
118
119                        for after_callback in after_callbacks.as_ref() {
120                            match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
121                                Ok(Some(after_content)) => {
122                                    let mut after_event = Event::new(&invocation_id);
123                                    after_event.author = agent_name.clone();
124                                    after_event.llm_response.content = Some(after_content);
125                                    yield Ok(after_event);
126                                    return;
127                                }
128                                Ok(None) => continue,
129                                Err(e) => {
130                                    yield Err(e);
131                                    return;
132                                }
133                            }
134                        }
135                        return;
136                    }
137                    Ok(None) => continue,
138                    Err(e) => {
139                        yield Err(e);
140                        return;
141                    }
142                }
143            }
144
145            let mut futures = FuturesUnordered::new();
146
147            for agent in sub_agents {
148                let ctx = run_ctx.clone();
149                futures.push(async move {
150                    agent.run(ctx).await
151                });
152            }
153
154            let mut first_error: Option<adk_core::AdkError> = None;
155
156            while let Some(result) = futures.next().await {
157                match result {
158                    Ok(mut stream) => {
159                        while let Some(event_result) = stream.next().await {
160                            match event_result {
161                                Ok(event) => yield Ok(event),
162                                Err(e) => {
163                                    if first_error.is_none() {
164                                        first_error = Some(e);
165                                    }
166                                    // Continue draining other agents instead of returning
167                                    break;
168                                }
169                            }
170                        }
171                    }
172                    Err(e) => {
173                        if first_error.is_none() {
174                            first_error = Some(e);
175                        }
176                        // Continue draining remaining futures to avoid resource leaks
177                    }
178                }
179            }
180
181            // After all agents complete, propagate the first error if any
182            if let Some(e) = first_error {
183                yield Err(e);
184                return;
185            }
186
187            for callback in after_callbacks.as_ref() {
188                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
189                    Ok(Some(content)) => {
190                        let mut after_event = Event::new(&invocation_id);
191                        after_event.author = agent_name.clone();
192                        after_event.llm_response.content = Some(content);
193                        yield Ok(after_event);
194                        break;
195                    }
196                    Ok(None) => continue,
197                    Err(e) => {
198                        yield Err(e);
199                        return;
200                    }
201                }
202            }
203        };
204
205        Ok(Box::pin(s))
206    }
207}