adk_agent/workflow/
parallel_agent.rs1#[cfg(feature = "skills")]
2use crate::skill_shim::load_skill_index;
3use crate::skill_shim::{SelectionPolicy, SkillIndex};
4use adk_core::{
5 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
6 InvocationContext, Result, SharedState,
7};
8use async_stream::stream;
9use async_trait::async_trait;
10use std::sync::Arc;
11
12use super::shared_state_context::SharedStateContext;
13
14pub struct ParallelAgent {
16 name: String,
17 description: String,
18 sub_agents: Vec<Arc<dyn Agent>>,
19 skills_index: Option<Arc<SkillIndex>>,
20 skill_policy: SelectionPolicy,
21 max_skill_chars: usize,
22 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
23 after_callbacks: Arc<Vec<AfterAgentCallback>>,
24 shared_state_enabled: bool,
25}
26
27impl ParallelAgent {
28 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
29 Self {
30 name: name.into(),
31 description: String::new(),
32 sub_agents,
33 skills_index: None,
34 skill_policy: SelectionPolicy::default(),
35 max_skill_chars: 2000,
36 before_callbacks: Arc::new(Vec::new()),
37 after_callbacks: Arc::new(Vec::new()),
38 shared_state_enabled: false,
39 }
40 }
41
42 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
43 self.description = desc.into();
44 self
45 }
46
47 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
48 if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
49 callbacks.push(callback);
50 }
51 self
52 }
53
54 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
55 if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
56 callbacks.push(callback);
57 }
58 self
59 }
60
61 #[cfg(feature = "skills")]
62 pub fn with_skills(mut self, index: SkillIndex) -> Self {
63 self.skills_index = Some(Arc::new(index));
64 self
65 }
66
67 #[cfg(feature = "skills")]
68 pub fn with_auto_skills(self) -> Result<Self> {
69 self.with_skills_from_root(".")
70 }
71
72 #[cfg(feature = "skills")]
73 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
74 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
75 self.skills_index = Some(Arc::new(index));
76 Ok(self)
77 }
78
79 #[cfg(feature = "skills")]
80 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
81 self.skill_policy = policy;
82 self
83 }
84
85 #[cfg(feature = "skills")]
86 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
87 self.max_skill_chars = max_chars;
88 self
89 }
90
91 pub fn with_shared_state(mut self) -> Self {
97 self.shared_state_enabled = true;
98 self
99 }
100}
101
102#[async_trait]
103impl Agent for ParallelAgent {
104 fn name(&self) -> &str {
105 &self.name
106 }
107
108 fn description(&self) -> &str {
109 &self.description
110 }
111
112 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
113 &self.sub_agents
114 }
115
116 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
117 let sub_agents = self.sub_agents.clone();
118 let run_ctx = super::skill_context::with_skill_injected_context(
119 ctx,
120 self.skills_index.as_ref(),
121 &self.skill_policy,
122 self.max_skill_chars,
123 );
124 let before_callbacks = self.before_callbacks.clone();
125 let after_callbacks = self.after_callbacks.clone();
126 let agent_name = self.name.clone();
127 let invocation_id = run_ctx.invocation_id().to_string();
128 let shared_state_enabled = self.shared_state_enabled;
129
130 let s = stream! {
131 use futures::stream::{FuturesUnordered, StreamExt};
132
133 for callback in before_callbacks.as_ref() {
134 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
135 Ok(Some(content)) => {
136 let mut early_event = Event::new(&invocation_id);
137 early_event.author = agent_name.clone();
138 early_event.llm_response.content = Some(content);
139 yield Ok(early_event);
140
141 for after_callback in after_callbacks.as_ref() {
142 match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
143 Ok(Some(after_content)) => {
144 let mut after_event = Event::new(&invocation_id);
145 after_event.author = agent_name.clone();
146 after_event.llm_response.content = Some(after_content);
147 yield Ok(after_event);
148 return;
149 }
150 Ok(None) => continue,
151 Err(e) => {
152 yield Err(e);
153 return;
154 }
155 }
156 }
157 return;
158 }
159 Ok(None) => continue,
160 Err(e) => {
161 yield Err(e);
162 return;
163 }
164 }
165 }
166
167 let mut futures = FuturesUnordered::new();
168
169 let shared = if shared_state_enabled {
171 Some(Arc::new(SharedState::new()))
172 } else {
173 None
174 };
175
176 for agent in sub_agents {
177 let ctx: Arc<dyn InvocationContext> = if let Some(ref shared) = shared {
178 Arc::new(SharedStateContext::new(run_ctx.clone(), shared.clone()))
179 } else {
180 run_ctx.clone()
181 };
182 futures.push(async move {
183 agent.run(ctx).await
184 });
185 }
186
187 let mut first_error: Option<adk_core::AdkError> = None;
188
189 while let Some(result) = futures.next().await {
190 match result {
191 Ok(mut stream) => {
192 while let Some(event_result) = stream.next().await {
193 match event_result {
194 Ok(event) => yield Ok(event),
195 Err(e) => {
196 if first_error.is_none() {
197 first_error = Some(e);
198 }
199 break;
201 }
202 }
203 }
204 }
205 Err(e) => {
206 if first_error.is_none() {
207 first_error = Some(e);
208 }
209 }
211 }
212 }
213
214 if let Some(e) = first_error {
216 yield Err(e);
217 return;
218 }
219
220 for callback in after_callbacks.as_ref() {
221 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
222 Ok(Some(content)) => {
223 let mut after_event = Event::new(&invocation_id);
224 after_event.author = agent_name.clone();
225 after_event.llm_response.content = Some(content);
226 yield Ok(after_event);
227 break;
228 }
229 Ok(None) => continue,
230 Err(e) => {
231 yield Err(e);
232 return;
233 }
234 }
235 }
236 };
237
238 Ok(Box::pin(s))
239 }
240}