adk_agent/workflow/
parallel_agent.rs1use adk_core::{
2 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
3 InvocationContext, Result, SharedState,
4};
5use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
6use async_stream::stream;
7use async_trait::async_trait;
8use std::sync::Arc;
9
10use super::shared_state_context::SharedStateContext;
11
12pub struct ParallelAgent {
14 name: String,
15 description: String,
16 sub_agents: Vec<Arc<dyn Agent>>,
17 skills_index: Option<Arc<SkillIndex>>,
18 skill_policy: SelectionPolicy,
19 max_skill_chars: usize,
20 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
21 after_callbacks: Arc<Vec<AfterAgentCallback>>,
22 shared_state_enabled: bool,
23}
24
25impl ParallelAgent {
26 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
27 Self {
28 name: name.into(),
29 description: String::new(),
30 sub_agents,
31 skills_index: None,
32 skill_policy: SelectionPolicy::default(),
33 max_skill_chars: 2000,
34 before_callbacks: Arc::new(Vec::new()),
35 after_callbacks: Arc::new(Vec::new()),
36 shared_state_enabled: false,
37 }
38 }
39
40 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
41 self.description = desc.into();
42 self
43 }
44
45 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
46 if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
47 callbacks.push(callback);
48 }
49 self
50 }
51
52 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
53 if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
54 callbacks.push(callback);
55 }
56 self
57 }
58
59 pub fn with_skills(mut self, index: SkillIndex) -> Self {
60 self.skills_index = Some(Arc::new(index));
61 self
62 }
63
64 pub fn with_auto_skills(self) -> Result<Self> {
65 self.with_skills_from_root(".")
66 }
67
68 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
69 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
70 self.skills_index = Some(Arc::new(index));
71 Ok(self)
72 }
73
74 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
75 self.skill_policy = policy;
76 self
77 }
78
79 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
80 self.max_skill_chars = max_chars;
81 self
82 }
83
84 pub fn with_shared_state(mut self) -> Self {
90 self.shared_state_enabled = true;
91 self
92 }
93}
94
95#[async_trait]
96impl Agent for ParallelAgent {
97 fn name(&self) -> &str {
98 &self.name
99 }
100
101 fn description(&self) -> &str {
102 &self.description
103 }
104
105 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
106 &self.sub_agents
107 }
108
109 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
110 let sub_agents = self.sub_agents.clone();
111 let run_ctx = super::skill_context::with_skill_injected_context(
112 ctx,
113 self.skills_index.as_ref(),
114 &self.skill_policy,
115 self.max_skill_chars,
116 );
117 let before_callbacks = self.before_callbacks.clone();
118 let after_callbacks = self.after_callbacks.clone();
119 let agent_name = self.name.clone();
120 let invocation_id = run_ctx.invocation_id().to_string();
121 let shared_state_enabled = self.shared_state_enabled;
122
123 let s = stream! {
124 use futures::stream::{FuturesUnordered, StreamExt};
125
126 for callback in before_callbacks.as_ref() {
127 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
128 Ok(Some(content)) => {
129 let mut early_event = Event::new(&invocation_id);
130 early_event.author = agent_name.clone();
131 early_event.llm_response.content = Some(content);
132 yield Ok(early_event);
133
134 for after_callback in after_callbacks.as_ref() {
135 match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
136 Ok(Some(after_content)) => {
137 let mut after_event = Event::new(&invocation_id);
138 after_event.author = agent_name.clone();
139 after_event.llm_response.content = Some(after_content);
140 yield Ok(after_event);
141 return;
142 }
143 Ok(None) => continue,
144 Err(e) => {
145 yield Err(e);
146 return;
147 }
148 }
149 }
150 return;
151 }
152 Ok(None) => continue,
153 Err(e) => {
154 yield Err(e);
155 return;
156 }
157 }
158 }
159
160 let mut futures = FuturesUnordered::new();
161
162 let shared = if shared_state_enabled {
164 Some(Arc::new(SharedState::new()))
165 } else {
166 None
167 };
168
169 for agent in sub_agents {
170 let ctx: Arc<dyn InvocationContext> = if let Some(ref shared) = shared {
171 Arc::new(SharedStateContext::new(run_ctx.clone(), shared.clone()))
172 } else {
173 run_ctx.clone()
174 };
175 futures.push(async move {
176 agent.run(ctx).await
177 });
178 }
179
180 let mut first_error: Option<adk_core::AdkError> = None;
181
182 while let Some(result) = futures.next().await {
183 match result {
184 Ok(mut stream) => {
185 while let Some(event_result) = stream.next().await {
186 match event_result {
187 Ok(event) => yield Ok(event),
188 Err(e) => {
189 if first_error.is_none() {
190 first_error = Some(e);
191 }
192 break;
194 }
195 }
196 }
197 }
198 Err(e) => {
199 if first_error.is_none() {
200 first_error = Some(e);
201 }
202 }
204 }
205 }
206
207 if let Some(e) = first_error {
209 yield Err(e);
210 return;
211 }
212
213 for callback in after_callbacks.as_ref() {
214 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
215 Ok(Some(content)) => {
216 let mut after_event = Event::new(&invocation_id);
217 after_event.author = agent_name.clone();
218 after_event.llm_response.content = Some(content);
219 yield Ok(after_event);
220 break;
221 }
222 Ok(None) => continue,
223 Err(e) => {
224 yield Err(e);
225 return;
226 }
227 }
228 }
229 };
230
231 Ok(Box::pin(s))
232 }
233}