adk_agent/workflow/
loop_agent.rs1#[cfg(feature = "skills")]
2use crate::skill_shim::load_skill_index;
3use crate::skill_shim::{SelectionPolicy, SkillIndex};
4use adk_core::{
5 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
6 InvocationContext, ReadonlyContext, Result, Session, State,
7};
8use async_stream::stream;
9use async_trait::async_trait;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12
/// Pass limit a [`LoopAgent`] starts with unless [`LoopAgent::with_max_iterations`] changes it.
pub const DEFAULT_LOOP_MAX_ITERATIONS: u32 = 1_000;
16
17pub struct LoopAgent {
19 name: String,
20 description: String,
21 sub_agents: Vec<Arc<dyn Agent>>,
22 max_iterations: u32,
23 skills_index: Option<Arc<SkillIndex>>,
24 skill_policy: SelectionPolicy,
25 max_skill_chars: usize,
26 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
27 after_callbacks: Arc<Vec<AfterAgentCallback>>,
28}
29
30impl LoopAgent {
31 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
33 Self {
34 name: name.into(),
35 description: String::new(),
36 sub_agents,
37 max_iterations: DEFAULT_LOOP_MAX_ITERATIONS,
38 skills_index: None,
39 skill_policy: SelectionPolicy::default(),
40 max_skill_chars: 2000,
41 before_callbacks: Arc::new(Vec::new()),
42 after_callbacks: Arc::new(Vec::new()),
43 }
44 }
45
46 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
48 self.description = desc.into();
49 self
50 }
51
52 pub fn with_max_iterations(mut self, max: u32) -> Self {
54 self.max_iterations = max;
55 self
56 }
57
58 #[cfg(feature = "skills")]
60 pub fn with_skills(mut self, index: SkillIndex) -> Self {
61 self.skills_index = Some(Arc::new(index));
62 self
63 }
64
65 #[cfg(feature = "skills")]
67 pub fn with_auto_skills(self) -> Result<Self> {
68 self.with_skills_from_root(".")
69 }
70
71 #[cfg(feature = "skills")]
73 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
74 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
75 self.skills_index = Some(Arc::new(index));
76 Ok(self)
77 }
78
79 #[cfg(feature = "skills")]
81 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
82 self.skill_policy = policy;
83 self
84 }
85
86 #[cfg(feature = "skills")]
88 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
89 self.max_skill_chars = max_chars;
90 self
91 }
92
93 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
95 if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
96 callbacks.push(callback);
97 }
98 self
99 }
100
101 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
103 if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
104 callbacks.push(callback);
105 }
106 self
107 }
108}
109
110struct HistoryTrackingSession {
111 parent_ctx: Arc<dyn InvocationContext>,
112 history: Arc<RwLock<Vec<Content>>>,
113 state: StateTrackingState,
114}
115
116struct StateTrackingState {
117 values: RwLock<HashMap<String, serde_json::Value>>,
118}
119
120impl StateTrackingState {
121 fn new(parent_ctx: &Arc<dyn InvocationContext>) -> Self {
122 Self { values: RwLock::new(parent_ctx.session().state().all()) }
123 }
124
125 fn apply_delta(&self, delta: &HashMap<String, serde_json::Value>) {
126 if delta.is_empty() {
127 return;
128 }
129
130 let mut values = self.values.write().unwrap_or_else(|e| e.into_inner());
131 for (key, value) in delta {
132 values.insert(key.clone(), value.clone());
133 }
134 }
135}
136
137impl State for StateTrackingState {
138 fn get(&self, key: &str) -> Option<serde_json::Value> {
139 self.values.read().unwrap_or_else(|e| e.into_inner()).get(key).cloned()
140 }
141
142 fn set(&mut self, key: String, value: serde_json::Value) {
143 if let Err(msg) = adk_core::validate_state_key(&key) {
144 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
145 return;
146 }
147 self.values.write().unwrap_or_else(|e| e.into_inner()).insert(key, value);
148 }
149
150 fn all(&self) -> HashMap<String, serde_json::Value> {
151 self.values.read().unwrap_or_else(|e| e.into_inner()).clone()
152 }
153}
154
155impl HistoryTrackingSession {
156 fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
157 Self {
158 history: Arc::new(RwLock::new(parent_ctx.session().conversation_history())),
159 state: StateTrackingState::new(&parent_ctx),
160 parent_ctx,
161 }
162 }
163
164 fn apply_event(&self, event: &Event) {
165 if let Some(content) = &event.llm_response.content {
166 let mut history = self.history.write().unwrap_or_else(|e| e.into_inner());
171
172 if event.llm_response.partial {
173 if let Some(last) = history.last_mut() {
175 if last.role == content.role {
176 for part in &content.parts {
177 if let adk_core::Part::Text { text } = part {
178 if let Some(adk_core::Part::Text { text: existing }) =
180 last.parts.last_mut()
181 {
182 existing.push_str(text);
183 } else {
184 last.parts.push(part.clone());
185 }
186 } else {
187 last.parts.push(part.clone());
188 }
189 }
190 return;
191 }
192 }
193 history.push(content.clone());
195 } else {
196 if let Some(last) = history.last_mut() {
203 if last.role == content.role && !content.parts.is_empty() {
204 for part in &content.parts {
206 if let adk_core::Part::Text { text } = part {
207 if let Some(adk_core::Part::Text { text: existing }) =
208 last.parts.last_mut()
209 {
210 existing.push_str(text);
211 } else {
212 last.parts.push(part.clone());
213 }
214 } else {
215 last.parts.push(part.clone());
216 }
217 }
218 } else if !content.parts.is_empty() {
219 history.push(content.clone());
220 }
221 } else {
222 history.push(content.clone());
223 }
224 }
225 }
226 self.state.apply_delta(&event.actions.state_delta);
227 }
228}
229
230impl Session for HistoryTrackingSession {
231 fn id(&self) -> &str {
232 self.parent_ctx.session().id()
233 }
234
235 fn app_name(&self) -> &str {
236 self.parent_ctx.session().app_name()
237 }
238
239 fn user_id(&self) -> &str {
240 self.parent_ctx.session().user_id()
241 }
242
243 fn state(&self) -> &dyn State {
244 &self.state
245 }
246
247 fn conversation_history(&self) -> Vec<Content> {
248 self.history.read().unwrap_or_else(|e| e.into_inner()).clone()
249 }
250
251 fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
252 self.conversation_history()
253 }
254
255 fn append_to_history(&self, content: Content) {
256 self.history.write().unwrap_or_else(|e| e.into_inner()).push(content);
257 }
258}
259
260struct HistoryTrackingContext {
261 parent_ctx: Arc<dyn InvocationContext>,
262 session: HistoryTrackingSession,
263}
264
265impl HistoryTrackingContext {
266 fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
267 let session = HistoryTrackingSession::new(parent_ctx.clone());
268 Self { parent_ctx, session }
269 }
270
271 fn apply_event(&self, event: &Event) {
272 self.session.apply_event(event);
273 }
274}
275
276#[async_trait]
277impl adk_core::ReadonlyContext for HistoryTrackingContext {
278 fn invocation_id(&self) -> &str {
279 self.parent_ctx.invocation_id()
280 }
281
282 fn agent_name(&self) -> &str {
283 self.parent_ctx.agent_name()
284 }
285
286 fn user_id(&self) -> &str {
287 self.parent_ctx.user_id()
288 }
289
290 fn app_name(&self) -> &str {
291 self.parent_ctx.app_name()
292 }
293
294 fn session_id(&self) -> &str {
295 self.parent_ctx.session_id()
296 }
297
298 fn branch(&self) -> &str {
299 self.parent_ctx.branch()
300 }
301
302 fn user_content(&self) -> &Content {
303 self.parent_ctx.user_content()
304 }
305}
306
307#[async_trait]
308impl CallbackContext for HistoryTrackingContext {
309 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
310 self.parent_ctx.artifacts()
311 }
312}
313
314#[async_trait]
315impl InvocationContext for HistoryTrackingContext {
316 fn agent(&self) -> Arc<dyn Agent> {
317 self.parent_ctx.agent()
318 }
319
320 fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
321 self.parent_ctx.memory()
322 }
323
324 fn session(&self) -> &dyn Session {
325 &self.session
326 }
327
328 fn run_config(&self) -> &adk_core::RunConfig {
329 self.parent_ctx.run_config()
330 }
331
332 fn end_invocation(&self) {
333 self.parent_ctx.end_invocation();
334 }
335
336 fn ended(&self) -> bool {
337 self.parent_ctx.ended()
338 }
339
340 fn user_scopes(&self) -> Vec<String> {
341 self.parent_ctx.user_scopes()
342 }
343
344 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
345 self.parent_ctx.request_metadata()
346 }
347}
348
349#[async_trait]
350impl Agent for LoopAgent {
351 fn name(&self) -> &str {
352 &self.name
353 }
354
355 fn description(&self) -> &str {
356 &self.description
357 }
358
359 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
360 &self.sub_agents
361 }
362
363 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
364 let sub_agents = self.sub_agents.clone();
365 let max_iterations = self.max_iterations;
366 let before_callbacks = self.before_callbacks.clone();
367 let after_callbacks = self.after_callbacks.clone();
368 let agent_name = self.name.clone();
369 let run_ctx = super::skill_context::with_skill_injected_context(
370 ctx,
371 self.skills_index.as_ref(),
372 &self.skill_policy,
373 self.max_skill_chars,
374 );
375 let run_ctx = Arc::new(HistoryTrackingContext::new(run_ctx));
376
377 let s = stream! {
378 use futures::StreamExt;
379
380 for callback in before_callbacks.as_ref() {
382 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
383 Ok(Some(content)) => {
384 let mut early_event = Event::new(run_ctx.invocation_id());
385 early_event.author = agent_name.clone();
386 early_event.llm_response.content = Some(content);
387 yield Ok(early_event);
388
389 for after_cb in after_callbacks.as_ref() {
390 match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
391 Ok(Some(after_content)) => {
392 let mut after_event = Event::new(run_ctx.invocation_id());
393 after_event.author = agent_name.clone();
394 after_event.llm_response.content = Some(after_content);
395 yield Ok(after_event);
396 return;
397 }
398 Ok(None) => continue,
399 Err(e) => { yield Err(e); return; }
400 }
401 }
402 return;
403 }
404 Ok(None) => continue,
405 Err(e) => { yield Err(e); return; }
406 }
407 }
408
409 let mut remaining = max_iterations;
410
411 loop {
412 let mut should_exit = false;
413
414 for agent in &sub_agents {
415 let mut stream = agent.run(run_ctx.clone() as Arc<dyn InvocationContext>).await?;
416
417 while let Some(result) = stream.next().await {
418 match result {
419 Ok(event) => {
420 run_ctx.apply_event(&event);
421 if event.actions.escalate {
422 should_exit = true;
423 }
424 yield Ok(event);
425 }
426 Err(e) => {
427 yield Err(e);
428 return;
429 }
430 }
431 }
432
433 if should_exit {
434 break;
435 }
436 }
437
438 if should_exit {
439 break;
440 }
441
442 remaining -= 1;
443 if remaining == 0 {
444 break;
445 }
446 }
447
448 for callback in after_callbacks.as_ref() {
450 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
451 Ok(Some(content)) => {
452 let mut after_event = Event::new(run_ctx.invocation_id());
453 after_event.author = agent_name.clone();
454 after_event.llm_response.content = Some(content);
455 yield Ok(after_event);
456 break;
457 }
458 Ok(None) => continue,
459 Err(e) => { yield Err(e); return; }
460 }
461 }
462 };
463
464 Ok(Box::pin(s))
465 }
466}