adk_agent/workflow/
parallel_agent.rs1use adk_core::{
2 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Event, EventStream,
3 InvocationContext, Result,
4};
5use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
6use async_stream::stream;
7use async_trait::async_trait;
8use std::sync::Arc;
9
10pub struct ParallelAgent {
12 name: String,
13 description: String,
14 sub_agents: Vec<Arc<dyn Agent>>,
15 skills_index: Option<Arc<SkillIndex>>,
16 skill_policy: SelectionPolicy,
17 max_skill_chars: usize,
18 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
19 after_callbacks: Arc<Vec<AfterAgentCallback>>,
20}
21
22impl ParallelAgent {
23 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
24 Self {
25 name: name.into(),
26 description: String::new(),
27 sub_agents,
28 skills_index: None,
29 skill_policy: SelectionPolicy::default(),
30 max_skill_chars: 2000,
31 before_callbacks: Arc::new(Vec::new()),
32 after_callbacks: Arc::new(Vec::new()),
33 }
34 }
35
36 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
37 self.description = desc.into();
38 self
39 }
40
41 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
42 Arc::get_mut(&mut self.before_callbacks)
43 .expect("before_callbacks not yet shared")
44 .push(callback);
45 self
46 }
47
48 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
49 Arc::get_mut(&mut self.after_callbacks)
50 .expect("after_callbacks not yet shared")
51 .push(callback);
52 self
53 }
54
55 pub fn with_skills(mut self, index: SkillIndex) -> Self {
56 self.skills_index = Some(Arc::new(index));
57 self
58 }
59
60 pub fn with_auto_skills(self) -> Result<Self> {
61 self.with_skills_from_root(".")
62 }
63
64 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
65 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::Agent(e.to_string()))?;
66 self.skills_index = Some(Arc::new(index));
67 Ok(self)
68 }
69
70 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
71 self.skill_policy = policy;
72 self
73 }
74
75 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
76 self.max_skill_chars = max_chars;
77 self
78 }
79}
80
81#[async_trait]
82impl Agent for ParallelAgent {
83 fn name(&self) -> &str {
84 &self.name
85 }
86
87 fn description(&self) -> &str {
88 &self.description
89 }
90
91 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
92 &self.sub_agents
93 }
94
95 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
96 let sub_agents = self.sub_agents.clone();
97 let run_ctx = super::skill_context::with_skill_injected_context(
98 ctx,
99 self.skills_index.as_ref(),
100 &self.skill_policy,
101 self.max_skill_chars,
102 );
103 let before_callbacks = self.before_callbacks.clone();
104 let after_callbacks = self.after_callbacks.clone();
105 let agent_name = self.name.clone();
106 let invocation_id = run_ctx.invocation_id().to_string();
107
108 let s = stream! {
109 use futures::stream::{FuturesUnordered, StreamExt};
110
111 for callback in before_callbacks.as_ref() {
112 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
113 Ok(Some(content)) => {
114 let mut early_event = Event::new(&invocation_id);
115 early_event.author = agent_name.clone();
116 early_event.llm_response.content = Some(content);
117 yield Ok(early_event);
118
119 for after_callback in after_callbacks.as_ref() {
120 match after_callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
121 Ok(Some(after_content)) => {
122 let mut after_event = Event::new(&invocation_id);
123 after_event.author = agent_name.clone();
124 after_event.llm_response.content = Some(after_content);
125 yield Ok(after_event);
126 return;
127 }
128 Ok(None) => continue,
129 Err(e) => {
130 yield Err(e);
131 return;
132 }
133 }
134 }
135 return;
136 }
137 Ok(None) => continue,
138 Err(e) => {
139 yield Err(e);
140 return;
141 }
142 }
143 }
144
145 let mut futures = FuturesUnordered::new();
146
147 for agent in sub_agents {
148 let ctx = run_ctx.clone();
149 futures.push(async move {
150 agent.run(ctx).await
151 });
152 }
153
154 let mut first_error: Option<adk_core::AdkError> = None;
155
156 while let Some(result) = futures.next().await {
157 match result {
158 Ok(mut stream) => {
159 while let Some(event_result) = stream.next().await {
160 match event_result {
161 Ok(event) => yield Ok(event),
162 Err(e) => {
163 if first_error.is_none() {
164 first_error = Some(e);
165 }
166 break;
168 }
169 }
170 }
171 }
172 Err(e) => {
173 if first_error.is_none() {
174 first_error = Some(e);
175 }
176 }
178 }
179 }
180
181 if let Some(e) = first_error {
183 yield Err(e);
184 return;
185 }
186
187 for callback in after_callbacks.as_ref() {
188 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
189 Ok(Some(content)) => {
190 let mut after_event = Event::new(&invocation_id);
191 after_event.author = agent_name.clone();
192 after_event.llm_response.content = Some(content);
193 yield Ok(after_event);
194 break;
195 }
196 Ok(None) => continue,
197 Err(e) => {
198 yield Err(e);
199 return;
200 }
201 }
202 }
203 };
204
205 Ok(Box::pin(s))
206 }
207}