adk_agent/workflow/
loop_agent.rs1use adk_core::{
2 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
3 InvocationContext, ReadonlyContext, Result, Session, State,
4};
5use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
6use async_stream::stream;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
/// Iteration limit that [`LoopAgent::new`] installs by default.
///
/// Override it per agent with [`LoopAgent::with_max_iterations`].
pub const DEFAULT_LOOP_MAX_ITERATIONS: u32 = 1000;
14
15pub struct LoopAgent {
17 name: String,
18 description: String,
19 sub_agents: Vec<Arc<dyn Agent>>,
20 max_iterations: u32,
21 skills_index: Option<Arc<SkillIndex>>,
22 skill_policy: SelectionPolicy,
23 max_skill_chars: usize,
24 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
25 after_callbacks: Arc<Vec<AfterAgentCallback>>,
26}
27
28impl LoopAgent {
29 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
30 Self {
31 name: name.into(),
32 description: String::new(),
33 sub_agents,
34 max_iterations: DEFAULT_LOOP_MAX_ITERATIONS,
35 skills_index: None,
36 skill_policy: SelectionPolicy::default(),
37 max_skill_chars: 2000,
38 before_callbacks: Arc::new(Vec::new()),
39 after_callbacks: Arc::new(Vec::new()),
40 }
41 }
42
43 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
44 self.description = desc.into();
45 self
46 }
47
48 pub fn with_max_iterations(mut self, max: u32) -> Self {
49 self.max_iterations = max;
50 self
51 }
52
53 pub fn with_skills(mut self, index: SkillIndex) -> Self {
54 self.skills_index = Some(Arc::new(index));
55 self
56 }
57
58 pub fn with_auto_skills(self) -> Result<Self> {
59 self.with_skills_from_root(".")
60 }
61
62 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
63 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
64 self.skills_index = Some(Arc::new(index));
65 Ok(self)
66 }
67
68 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
69 self.skill_policy = policy;
70 self
71 }
72
73 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
74 self.max_skill_chars = max_chars;
75 self
76 }
77
78 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
79 if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
80 callbacks.push(callback);
81 }
82 self
83 }
84
85 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
86 if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
87 callbacks.push(callback);
88 }
89 self
90 }
91}
92
93struct HistoryTrackingSession {
94 parent_ctx: Arc<dyn InvocationContext>,
95 history: Arc<RwLock<Vec<Content>>>,
96 state: StateTrackingState,
97}
98
99struct StateTrackingState {
100 values: RwLock<HashMap<String, serde_json::Value>>,
101}
102
103impl StateTrackingState {
104 fn new(parent_ctx: &Arc<dyn InvocationContext>) -> Self {
105 Self { values: RwLock::new(parent_ctx.session().state().all()) }
106 }
107
108 fn apply_delta(&self, delta: &HashMap<String, serde_json::Value>) {
109 if delta.is_empty() {
110 return;
111 }
112
113 let mut values = self.values.write().unwrap_or_else(|e| e.into_inner());
114 for (key, value) in delta {
115 values.insert(key.clone(), value.clone());
116 }
117 }
118}
119
120impl State for StateTrackingState {
121 fn get(&self, key: &str) -> Option<serde_json::Value> {
122 self.values.read().unwrap_or_else(|e| e.into_inner()).get(key).cloned()
123 }
124
125 fn set(&mut self, key: String, value: serde_json::Value) {
126 if let Err(msg) = adk_core::validate_state_key(&key) {
127 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
128 return;
129 }
130 self.values.write().unwrap_or_else(|e| e.into_inner()).insert(key, value);
131 }
132
133 fn all(&self) -> HashMap<String, serde_json::Value> {
134 self.values.read().unwrap_or_else(|e| e.into_inner()).clone()
135 }
136}
137
138impl HistoryTrackingSession {
139 fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
140 Self {
141 history: Arc::new(RwLock::new(parent_ctx.session().conversation_history())),
142 state: StateTrackingState::new(&parent_ctx),
143 parent_ctx,
144 }
145 }
146
147 fn apply_event(&self, event: &Event) {
148 if let Some(content) = &event.llm_response.content {
149 let mut history = self.history.write().unwrap_or_else(|e| e.into_inner());
154
155 if event.llm_response.partial {
156 if let Some(last) = history.last_mut() {
158 if last.role == content.role {
159 for part in &content.parts {
160 if let adk_core::Part::Text { text } = part {
161 if let Some(adk_core::Part::Text { text: existing }) =
163 last.parts.last_mut()
164 {
165 existing.push_str(text);
166 } else {
167 last.parts.push(part.clone());
168 }
169 } else {
170 last.parts.push(part.clone());
171 }
172 }
173 return;
174 }
175 }
176 history.push(content.clone());
178 } else {
179 if let Some(last) = history.last_mut() {
186 if last.role == content.role && !content.parts.is_empty() {
187 for part in &content.parts {
189 if let adk_core::Part::Text { text } = part {
190 if let Some(adk_core::Part::Text { text: existing }) =
191 last.parts.last_mut()
192 {
193 existing.push_str(text);
194 } else {
195 last.parts.push(part.clone());
196 }
197 } else {
198 last.parts.push(part.clone());
199 }
200 }
201 } else if !content.parts.is_empty() {
202 history.push(content.clone());
203 }
204 } else {
205 history.push(content.clone());
206 }
207 }
208 }
209 self.state.apply_delta(&event.actions.state_delta);
210 }
211}
212
213impl Session for HistoryTrackingSession {
214 fn id(&self) -> &str {
215 self.parent_ctx.session().id()
216 }
217
218 fn app_name(&self) -> &str {
219 self.parent_ctx.session().app_name()
220 }
221
222 fn user_id(&self) -> &str {
223 self.parent_ctx.session().user_id()
224 }
225
226 fn state(&self) -> &dyn State {
227 &self.state
228 }
229
230 fn conversation_history(&self) -> Vec<Content> {
231 self.history.read().unwrap_or_else(|e| e.into_inner()).clone()
232 }
233
234 fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
235 self.conversation_history()
236 }
237
238 fn append_to_history(&self, content: Content) {
239 self.history.write().unwrap_or_else(|e| e.into_inner()).push(content);
240 }
241}
242
243struct HistoryTrackingContext {
244 parent_ctx: Arc<dyn InvocationContext>,
245 session: HistoryTrackingSession,
246}
247
248impl HistoryTrackingContext {
249 fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
250 let session = HistoryTrackingSession::new(parent_ctx.clone());
251 Self { parent_ctx, session }
252 }
253
254 fn apply_event(&self, event: &Event) {
255 self.session.apply_event(event);
256 }
257}
258
259#[async_trait]
260impl adk_core::ReadonlyContext for HistoryTrackingContext {
261 fn invocation_id(&self) -> &str {
262 self.parent_ctx.invocation_id()
263 }
264
265 fn agent_name(&self) -> &str {
266 self.parent_ctx.agent_name()
267 }
268
269 fn user_id(&self) -> &str {
270 self.parent_ctx.user_id()
271 }
272
273 fn app_name(&self) -> &str {
274 self.parent_ctx.app_name()
275 }
276
277 fn session_id(&self) -> &str {
278 self.parent_ctx.session_id()
279 }
280
281 fn branch(&self) -> &str {
282 self.parent_ctx.branch()
283 }
284
285 fn user_content(&self) -> &Content {
286 self.parent_ctx.user_content()
287 }
288}
289
290#[async_trait]
291impl CallbackContext for HistoryTrackingContext {
292 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
293 self.parent_ctx.artifacts()
294 }
295}
296
297#[async_trait]
298impl InvocationContext for HistoryTrackingContext {
299 fn agent(&self) -> Arc<dyn Agent> {
300 self.parent_ctx.agent()
301 }
302
303 fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
304 self.parent_ctx.memory()
305 }
306
307 fn session(&self) -> &dyn Session {
308 &self.session
309 }
310
311 fn run_config(&self) -> &adk_core::RunConfig {
312 self.parent_ctx.run_config()
313 }
314
315 fn end_invocation(&self) {
316 self.parent_ctx.end_invocation();
317 }
318
319 fn ended(&self) -> bool {
320 self.parent_ctx.ended()
321 }
322
323 fn user_scopes(&self) -> Vec<String> {
324 self.parent_ctx.user_scopes()
325 }
326
327 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
328 self.parent_ctx.request_metadata()
329 }
330}
331
332#[async_trait]
333impl Agent for LoopAgent {
334 fn name(&self) -> &str {
335 &self.name
336 }
337
338 fn description(&self) -> &str {
339 &self.description
340 }
341
342 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
343 &self.sub_agents
344 }
345
346 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
347 let sub_agents = self.sub_agents.clone();
348 let max_iterations = self.max_iterations;
349 let before_callbacks = self.before_callbacks.clone();
350 let after_callbacks = self.after_callbacks.clone();
351 let agent_name = self.name.clone();
352 let run_ctx = super::skill_context::with_skill_injected_context(
353 ctx,
354 self.skills_index.as_ref(),
355 &self.skill_policy,
356 self.max_skill_chars,
357 );
358 let run_ctx = Arc::new(HistoryTrackingContext::new(run_ctx));
359
360 let s = stream! {
361 use futures::StreamExt;
362
363 for callback in before_callbacks.as_ref() {
365 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
366 Ok(Some(content)) => {
367 let mut early_event = Event::new(run_ctx.invocation_id());
368 early_event.author = agent_name.clone();
369 early_event.llm_response.content = Some(content);
370 yield Ok(early_event);
371
372 for after_cb in after_callbacks.as_ref() {
373 match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
374 Ok(Some(after_content)) => {
375 let mut after_event = Event::new(run_ctx.invocation_id());
376 after_event.author = agent_name.clone();
377 after_event.llm_response.content = Some(after_content);
378 yield Ok(after_event);
379 return;
380 }
381 Ok(None) => continue,
382 Err(e) => { yield Err(e); return; }
383 }
384 }
385 return;
386 }
387 Ok(None) => continue,
388 Err(e) => { yield Err(e); return; }
389 }
390 }
391
392 let mut remaining = max_iterations;
393
394 loop {
395 let mut should_exit = false;
396
397 for agent in &sub_agents {
398 let mut stream = agent.run(run_ctx.clone() as Arc<dyn InvocationContext>).await?;
399
400 while let Some(result) = stream.next().await {
401 match result {
402 Ok(event) => {
403 run_ctx.apply_event(&event);
404 if event.actions.escalate {
405 should_exit = true;
406 }
407 yield Ok(event);
408 }
409 Err(e) => {
410 yield Err(e);
411 return;
412 }
413 }
414 }
415
416 if should_exit {
417 break;
418 }
419 }
420
421 if should_exit {
422 break;
423 }
424
425 remaining -= 1;
426 if remaining == 0 {
427 break;
428 }
429 }
430
431 for callback in after_callbacks.as_ref() {
433 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
434 Ok(Some(content)) => {
435 let mut after_event = Event::new(run_ctx.invocation_id());
436 after_event.author = agent_name.clone();
437 after_event.llm_response.content = Some(content);
438 yield Ok(after_event);
439 break;
440 }
441 Ok(None) => continue,
442 Err(e) => { yield Err(e); return; }
443 }
444 }
445 };
446
447 Ok(Box::pin(s))
448 }
449}