// adk_agent/workflow/loop_agent.rs
#[cfg(feature = "skills")]
use crate::skill_shim::load_skill_index;
use crate::skill_shim::{SelectionPolicy, SkillIndex};
use adk_core::{
    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
    InvocationContext, ReadonlyContext, Result, Session, State,
};
use async_stream::stream;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Default cap on how many full passes a [`LoopAgent`] makes over its sub-agents
/// when no explicit limit is configured.
pub const DEFAULT_LOOP_MAX_ITERATIONS: u32 = 1_000;

17pub struct LoopAgent {
19 name: String,
20 description: String,
21 sub_agents: Vec<Arc<dyn Agent>>,
22 max_iterations: u32,
23 skills_index: Option<Arc<SkillIndex>>,
24 skill_policy: SelectionPolicy,
25 max_skill_chars: usize,
26 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
27 after_callbacks: Arc<Vec<AfterAgentCallback>>,
28}
29
30impl LoopAgent {
31 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
32 Self {
33 name: name.into(),
34 description: String::new(),
35 sub_agents,
36 max_iterations: DEFAULT_LOOP_MAX_ITERATIONS,
37 skills_index: None,
38 skill_policy: SelectionPolicy::default(),
39 max_skill_chars: 2000,
40 before_callbacks: Arc::new(Vec::new()),
41 after_callbacks: Arc::new(Vec::new()),
42 }
43 }
44
45 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
46 self.description = desc.into();
47 self
48 }
49
50 pub fn with_max_iterations(mut self, max: u32) -> Self {
51 self.max_iterations = max;
52 self
53 }
54
55 #[cfg(feature = "skills")]
56 pub fn with_skills(mut self, index: SkillIndex) -> Self {
57 self.skills_index = Some(Arc::new(index));
58 self
59 }
60
61 #[cfg(feature = "skills")]
62 pub fn with_auto_skills(self) -> Result<Self> {
63 self.with_skills_from_root(".")
64 }
65
66 #[cfg(feature = "skills")]
67 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
68 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
69 self.skills_index = Some(Arc::new(index));
70 Ok(self)
71 }
72
73 #[cfg(feature = "skills")]
74 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
75 self.skill_policy = policy;
76 self
77 }
78
79 #[cfg(feature = "skills")]
80 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
81 self.max_skill_chars = max_chars;
82 self
83 }
84
85 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
86 if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
87 callbacks.push(callback);
88 }
89 self
90 }
91
92 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
93 if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
94 callbacks.push(callback);
95 }
96 self
97 }
98}
99
100struct HistoryTrackingSession {
101 parent_ctx: Arc<dyn InvocationContext>,
102 history: Arc<RwLock<Vec<Content>>>,
103 state: StateTrackingState,
104}
105
106struct StateTrackingState {
107 values: RwLock<HashMap<String, serde_json::Value>>,
108}
109
110impl StateTrackingState {
111 fn new(parent_ctx: &Arc<dyn InvocationContext>) -> Self {
112 Self { values: RwLock::new(parent_ctx.session().state().all()) }
113 }
114
115 fn apply_delta(&self, delta: &HashMap<String, serde_json::Value>) {
116 if delta.is_empty() {
117 return;
118 }
119
120 let mut values = self.values.write().unwrap_or_else(|e| e.into_inner());
121 for (key, value) in delta {
122 values.insert(key.clone(), value.clone());
123 }
124 }
125}
126
127impl State for StateTrackingState {
128 fn get(&self, key: &str) -> Option<serde_json::Value> {
129 self.values.read().unwrap_or_else(|e| e.into_inner()).get(key).cloned()
130 }
131
132 fn set(&mut self, key: String, value: serde_json::Value) {
133 if let Err(msg) = adk_core::validate_state_key(&key) {
134 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
135 return;
136 }
137 self.values.write().unwrap_or_else(|e| e.into_inner()).insert(key, value);
138 }
139
140 fn all(&self) -> HashMap<String, serde_json::Value> {
141 self.values.read().unwrap_or_else(|e| e.into_inner()).clone()
142 }
143}
144
145impl HistoryTrackingSession {
146 fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
147 Self {
148 history: Arc::new(RwLock::new(parent_ctx.session().conversation_history())),
149 state: StateTrackingState::new(&parent_ctx),
150 parent_ctx,
151 }
152 }
153
154 fn apply_event(&self, event: &Event) {
155 if let Some(content) = &event.llm_response.content {
156 let mut history = self.history.write().unwrap_or_else(|e| e.into_inner());
161
162 if event.llm_response.partial {
163 if let Some(last) = history.last_mut() {
165 if last.role == content.role {
166 for part in &content.parts {
167 if let adk_core::Part::Text { text } = part {
168 if let Some(adk_core::Part::Text { text: existing }) =
170 last.parts.last_mut()
171 {
172 existing.push_str(text);
173 } else {
174 last.parts.push(part.clone());
175 }
176 } else {
177 last.parts.push(part.clone());
178 }
179 }
180 return;
181 }
182 }
183 history.push(content.clone());
185 } else {
186 if let Some(last) = history.last_mut() {
193 if last.role == content.role && !content.parts.is_empty() {
194 for part in &content.parts {
196 if let adk_core::Part::Text { text } = part {
197 if let Some(adk_core::Part::Text { text: existing }) =
198 last.parts.last_mut()
199 {
200 existing.push_str(text);
201 } else {
202 last.parts.push(part.clone());
203 }
204 } else {
205 last.parts.push(part.clone());
206 }
207 }
208 } else if !content.parts.is_empty() {
209 history.push(content.clone());
210 }
211 } else {
212 history.push(content.clone());
213 }
214 }
215 }
216 self.state.apply_delta(&event.actions.state_delta);
217 }
218}
219
220impl Session for HistoryTrackingSession {
221 fn id(&self) -> &str {
222 self.parent_ctx.session().id()
223 }
224
225 fn app_name(&self) -> &str {
226 self.parent_ctx.session().app_name()
227 }
228
229 fn user_id(&self) -> &str {
230 self.parent_ctx.session().user_id()
231 }
232
233 fn state(&self) -> &dyn State {
234 &self.state
235 }
236
237 fn conversation_history(&self) -> Vec<Content> {
238 self.history.read().unwrap_or_else(|e| e.into_inner()).clone()
239 }
240
241 fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
242 self.conversation_history()
243 }
244
245 fn append_to_history(&self, content: Content) {
246 self.history.write().unwrap_or_else(|e| e.into_inner()).push(content);
247 }
248}
249
250struct HistoryTrackingContext {
251 parent_ctx: Arc<dyn InvocationContext>,
252 session: HistoryTrackingSession,
253}
254
255impl HistoryTrackingContext {
256 fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
257 let session = HistoryTrackingSession::new(parent_ctx.clone());
258 Self { parent_ctx, session }
259 }
260
261 fn apply_event(&self, event: &Event) {
262 self.session.apply_event(event);
263 }
264}
265
266#[async_trait]
267impl adk_core::ReadonlyContext for HistoryTrackingContext {
268 fn invocation_id(&self) -> &str {
269 self.parent_ctx.invocation_id()
270 }
271
272 fn agent_name(&self) -> &str {
273 self.parent_ctx.agent_name()
274 }
275
276 fn user_id(&self) -> &str {
277 self.parent_ctx.user_id()
278 }
279
280 fn app_name(&self) -> &str {
281 self.parent_ctx.app_name()
282 }
283
284 fn session_id(&self) -> &str {
285 self.parent_ctx.session_id()
286 }
287
288 fn branch(&self) -> &str {
289 self.parent_ctx.branch()
290 }
291
292 fn user_content(&self) -> &Content {
293 self.parent_ctx.user_content()
294 }
295}
296
297#[async_trait]
298impl CallbackContext for HistoryTrackingContext {
299 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
300 self.parent_ctx.artifacts()
301 }
302}
303
304#[async_trait]
305impl InvocationContext for HistoryTrackingContext {
306 fn agent(&self) -> Arc<dyn Agent> {
307 self.parent_ctx.agent()
308 }
309
310 fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
311 self.parent_ctx.memory()
312 }
313
314 fn session(&self) -> &dyn Session {
315 &self.session
316 }
317
318 fn run_config(&self) -> &adk_core::RunConfig {
319 self.parent_ctx.run_config()
320 }
321
322 fn end_invocation(&self) {
323 self.parent_ctx.end_invocation();
324 }
325
326 fn ended(&self) -> bool {
327 self.parent_ctx.ended()
328 }
329
330 fn user_scopes(&self) -> Vec<String> {
331 self.parent_ctx.user_scopes()
332 }
333
334 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
335 self.parent_ctx.request_metadata()
336 }
337}
338
339#[async_trait]
340impl Agent for LoopAgent {
341 fn name(&self) -> &str {
342 &self.name
343 }
344
345 fn description(&self) -> &str {
346 &self.description
347 }
348
349 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
350 &self.sub_agents
351 }
352
353 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
354 let sub_agents = self.sub_agents.clone();
355 let max_iterations = self.max_iterations;
356 let before_callbacks = self.before_callbacks.clone();
357 let after_callbacks = self.after_callbacks.clone();
358 let agent_name = self.name.clone();
359 let run_ctx = super::skill_context::with_skill_injected_context(
360 ctx,
361 self.skills_index.as_ref(),
362 &self.skill_policy,
363 self.max_skill_chars,
364 );
365 let run_ctx = Arc::new(HistoryTrackingContext::new(run_ctx));
366
367 let s = stream! {
368 use futures::StreamExt;
369
370 for callback in before_callbacks.as_ref() {
372 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
373 Ok(Some(content)) => {
374 let mut early_event = Event::new(run_ctx.invocation_id());
375 early_event.author = agent_name.clone();
376 early_event.llm_response.content = Some(content);
377 yield Ok(early_event);
378
379 for after_cb in after_callbacks.as_ref() {
380 match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
381 Ok(Some(after_content)) => {
382 let mut after_event = Event::new(run_ctx.invocation_id());
383 after_event.author = agent_name.clone();
384 after_event.llm_response.content = Some(after_content);
385 yield Ok(after_event);
386 return;
387 }
388 Ok(None) => continue,
389 Err(e) => { yield Err(e); return; }
390 }
391 }
392 return;
393 }
394 Ok(None) => continue,
395 Err(e) => { yield Err(e); return; }
396 }
397 }
398
399 let mut remaining = max_iterations;
400
401 loop {
402 let mut should_exit = false;
403
404 for agent in &sub_agents {
405 let mut stream = agent.run(run_ctx.clone() as Arc<dyn InvocationContext>).await?;
406
407 while let Some(result) = stream.next().await {
408 match result {
409 Ok(event) => {
410 run_ctx.apply_event(&event);
411 if event.actions.escalate {
412 should_exit = true;
413 }
414 yield Ok(event);
415 }
416 Err(e) => {
417 yield Err(e);
418 return;
419 }
420 }
421 }
422
423 if should_exit {
424 break;
425 }
426 }
427
428 if should_exit {
429 break;
430 }
431
432 remaining -= 1;
433 if remaining == 0 {
434 break;
435 }
436 }
437
438 for callback in after_callbacks.as_ref() {
440 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
441 Ok(Some(content)) => {
442 let mut after_event = Event::new(run_ctx.invocation_id());
443 after_event.author = agent_name.clone();
444 after_event.llm_response.content = Some(content);
445 yield Ok(after_event);
446 break;
447 }
448 Ok(None) => continue,
449 Err(e) => { yield Err(e); return; }
450 }
451 }
452 };
453
454 Ok(Box::pin(s))
455 }
456}