adk_agent/workflow/
parallel_agent.rs1#[cfg(feature = "skills")]
2use crate::skill_shim::load_skill_index;
3use crate::skill_shim::{SelectionPolicy, SkillIndex};
4use adk_core::{
5 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
6 InvocationContext, Result, SharedState,
7};
8use async_stream::stream;
9use async_trait::async_trait;
10use std::sync::Arc;
11
12use super::shared_state_context::SharedStateContext;
13
14pub struct ParallelAgent {
16 name: String,
17 description: String,
18 sub_agents: Vec<Arc<dyn Agent>>,
19 skills_index: Option<Arc<SkillIndex>>,
20 skill_policy: SelectionPolicy,
21 max_skill_chars: usize,
22 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
23 after_callbacks: Arc<Vec<AfterAgentCallback>>,
24 shared_state_enabled: bool,
25}
26
27impl ParallelAgent {
28 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
30 Self {
31 name: name.into(),
32 description: String::new(),
33 sub_agents,
34 skills_index: None,
35 skill_policy: SelectionPolicy::default(),
36 max_skill_chars: 2000,
37 before_callbacks: Arc::new(Vec::new()),
38 after_callbacks: Arc::new(Vec::new()),
39 shared_state_enabled: false,
40 }
41 }
42
43 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
45 self.description = desc.into();
46 self
47 }
48
49 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
51 if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
52 callbacks.push(callback);
53 }
54 self
55 }
56
57 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
59 if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
60 callbacks.push(callback);
61 }
62 self
63 }
64
65 #[cfg(feature = "skills")]
67 pub fn with_skills(mut self, index: SkillIndex) -> Self {
68 self.skills_index = Some(Arc::new(index));
69 self
70 }
71
72 #[cfg(feature = "skills")]
74 pub fn with_auto_skills(self) -> Result<Self> {
75 self.with_skills_from_root(".")
76 }
77
78 #[cfg(feature = "skills")]
80 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
81 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
82 self.skills_index = Some(Arc::new(index));
83 Ok(self)
84 }
85
86 #[cfg(feature = "skills")]
88 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
89 self.skill_policy = policy;
90 self
91 }
92
93 #[cfg(feature = "skills")]
95 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
96 self.max_skill_chars = max_chars;
97 self
98 }
99
100 pub fn with_shared_state(mut self) -> Self {
106 self.shared_state_enabled = true;
107 self
108 }
109}
110
111#[async_trait]
112impl Agent for ParallelAgent {
113 fn name(&self) -> &str {
114 &self.name
115 }
116
117 fn description(&self) -> &str {
118 &self.description
119 }
120
121 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
122 &self.sub_agents
123 }
124
125 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
126 let sub_agents = self.sub_agents.clone();
127 let run_ctx = super::skill_context::with_skill_injected_context(
128 ctx,
129 self.skills_index.as_ref(),
130 &self.skill_policy,
131 self.max_skill_chars,
132 );
133 let before_callbacks = self.before_callbacks.clone();
134 let after_callbacks = self.after_callbacks.clone();
135 let agent_name = self.name.clone();
136 let invocation_id = run_ctx.invocation_id().to_string();
137 let shared_state_enabled = self.shared_state_enabled;
138
139 let s = stream! {
140 use futures::stream::{FuturesUnordered, StreamExt};
141
142 for callback in before_callbacks.as_ref() {
143 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
144 Ok(Some(content)) => {
145 let mut early_event = Event::new(&invocation_id);
146 early_event.author = agent_name.clone();
147 early_event.llm_response.content = Some(content);
148 yield Ok(early_event);
149
150 for after_callback in after_callbacks.as_ref() {
151 match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
152 Ok(Some(after_content)) => {
153 let mut after_event = Event::new(&invocation_id);
154 after_event.author = agent_name.clone();
155 after_event.llm_response.content = Some(after_content);
156 yield Ok(after_event);
157 return;
158 }
159 Ok(None) => continue,
160 Err(e) => {
161 yield Err(e);
162 return;
163 }
164 }
165 }
166 return;
167 }
168 Ok(None) => continue,
169 Err(e) => {
170 yield Err(e);
171 return;
172 }
173 }
174 }
175
176 let mut futures = FuturesUnordered::new();
177
178 let shared = if shared_state_enabled {
180 Some(Arc::new(SharedState::new()))
181 } else {
182 None
183 };
184
185 for agent in sub_agents {
186 let ctx: Arc<dyn InvocationContext> = if let Some(ref shared) = shared {
187 Arc::new(SharedStateContext::new(run_ctx.clone(), shared.clone()))
188 } else {
189 run_ctx.clone()
190 };
191 futures.push(async move {
192 agent.run(ctx).await
193 });
194 }
195
196 let mut first_error: Option<adk_core::AdkError> = None;
197
198 while let Some(result) = futures.next().await {
199 match result {
200 Ok(mut stream) => {
201 while let Some(event_result) = stream.next().await {
202 match event_result {
203 Ok(event) => yield Ok(event),
204 Err(e) => {
205 if first_error.is_none() {
206 first_error = Some(e);
207 }
208 break;
210 }
211 }
212 }
213 }
214 Err(e) => {
215 if first_error.is_none() {
216 first_error = Some(e);
217 }
218 }
220 }
221 }
222
223 if let Some(e) = first_error {
225 yield Err(e);
226 return;
227 }
228
229 for callback in after_callbacks.as_ref() {
230 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
231 Ok(Some(content)) => {
232 let mut after_event = Event::new(&invocation_id);
233 after_event.author = agent_name.clone();
234 after_event.llm_response.content = Some(content);
235 yield Ok(after_event);
236 break;
237 }
238 Ok(None) => continue,
239 Err(e) => {
240 yield Err(e);
241 return;
242 }
243 }
244 }
245 };
246
247 Ok(Box::pin(s))
248 }
249}