Skip to main content

adk_agent/workflow/
loop_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
3    InvocationContext, ReadonlyContext, Result, Session, State,
4};
5use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
6use async_stream::stream;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
/// Pass limit applied by [`LoopAgent::new`] when no explicit limit is configured.
///
/// Bounding the number of passes keeps a loop whose sub-agents never escalate from
/// running (and consuming resources) forever.
pub const DEFAULT_LOOP_MAX_ITERATIONS: u32 = 1_000;
14
15/// Loop agent executes sub-agents repeatedly for N iterations or until escalation
16pub struct LoopAgent {
17    name: String,
18    description: String,
19    sub_agents: Vec<Arc<dyn Agent>>,
20    max_iterations: u32,
21    skills_index: Option<Arc<SkillIndex>>,
22    skill_policy: SelectionPolicy,
23    max_skill_chars: usize,
24    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
25    after_callbacks: Arc<Vec<AfterAgentCallback>>,
26}
27
28impl LoopAgent {
29    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
30        Self {
31            name: name.into(),
32            description: String::new(),
33            sub_agents,
34            max_iterations: DEFAULT_LOOP_MAX_ITERATIONS,
35            skills_index: None,
36            skill_policy: SelectionPolicy::default(),
37            max_skill_chars: 2000,
38            before_callbacks: Arc::new(Vec::new()),
39            after_callbacks: Arc::new(Vec::new()),
40        }
41    }
42
43    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
44        self.description = desc.into();
45        self
46    }
47
48    pub fn with_max_iterations(mut self, max: u32) -> Self {
49        self.max_iterations = max;
50        self
51    }
52
53    pub fn with_skills(mut self, index: SkillIndex) -> Self {
54        self.skills_index = Some(Arc::new(index));
55        self
56    }
57
58    pub fn with_auto_skills(self) -> Result<Self> {
59        self.with_skills_from_root(".")
60    }
61
62    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
63        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
64        self.skills_index = Some(Arc::new(index));
65        Ok(self)
66    }
67
68    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
69        self.skill_policy = policy;
70        self
71    }
72
73    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
74        self.max_skill_chars = max_chars;
75        self
76    }
77
78    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
79        if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
80            callbacks.push(callback);
81        }
82        self
83    }
84
85    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
86        if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
87            callbacks.push(callback);
88        }
89        self
90    }
91}
92
93struct HistoryTrackingSession {
94    parent_ctx: Arc<dyn InvocationContext>,
95    history: Arc<RwLock<Vec<Content>>>,
96    state: StateTrackingState,
97}
98
99struct StateTrackingState {
100    values: RwLock<HashMap<String, serde_json::Value>>,
101}
102
103impl StateTrackingState {
104    fn new(parent_ctx: &Arc<dyn InvocationContext>) -> Self {
105        Self { values: RwLock::new(parent_ctx.session().state().all()) }
106    }
107
108    fn apply_delta(&self, delta: &HashMap<String, serde_json::Value>) {
109        if delta.is_empty() {
110            return;
111        }
112
113        let mut values = self.values.write().unwrap_or_else(|e| e.into_inner());
114        for (key, value) in delta {
115            values.insert(key.clone(), value.clone());
116        }
117    }
118}
119
120impl State for StateTrackingState {
121    fn get(&self, key: &str) -> Option<serde_json::Value> {
122        self.values.read().unwrap_or_else(|e| e.into_inner()).get(key).cloned()
123    }
124
125    fn set(&mut self, key: String, value: serde_json::Value) {
126        if let Err(msg) = adk_core::validate_state_key(&key) {
127            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
128            return;
129        }
130        self.values.write().unwrap_or_else(|e| e.into_inner()).insert(key, value);
131    }
132
133    fn all(&self) -> HashMap<String, serde_json::Value> {
134        self.values.read().unwrap_or_else(|e| e.into_inner()).clone()
135    }
136}
137
138impl HistoryTrackingSession {
139    fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
140        Self {
141            history: Arc::new(RwLock::new(parent_ctx.session().conversation_history())),
142            state: StateTrackingState::new(&parent_ctx),
143            parent_ctx,
144        }
145    }
146
147    fn apply_event(&self, event: &Event) {
148        if let Some(content) = &event.llm_response.content {
149            // Consolidate streaming chunks: if the last history entry has the
150            // same role, merge text into it instead of creating a new entry.
151            // This prevents N streaming chunks from becoming N separate Content
152            // entries that bloat context for subsequent agents.
153            let mut history = self.history.write().unwrap_or_else(|e| e.into_inner());
154
155            if event.llm_response.partial {
156                // Partial chunk — merge into last entry if same role
157                if let Some(last) = history.last_mut() {
158                    if last.role == content.role {
159                        for part in &content.parts {
160                            if let adk_core::Part::Text { text } = part {
161                                // Append text to the last Text part
162                                if let Some(adk_core::Part::Text { text: existing }) =
163                                    last.parts.last_mut()
164                                {
165                                    existing.push_str(text);
166                                } else {
167                                    last.parts.push(part.clone());
168                                }
169                            } else {
170                                last.parts.push(part.clone());
171                            }
172                        }
173                        return;
174                    }
175                }
176                // No matching last entry — start a new one
177                history.push(content.clone());
178            } else {
179                // Final event (partial=false) — append as-is.
180                // For non-streaming mode this carries the full content.
181                // For streaming mode the accumulated text is already in the
182                // last history entry from partial merges above, so the final
183                // chunk (which may carry the last fragment or be empty) is
184                // merged if same role, or appended if different.
185                if let Some(last) = history.last_mut() {
186                    if last.role == content.role && !content.parts.is_empty() {
187                        // Merge any remaining text from the final chunk
188                        for part in &content.parts {
189                            if let adk_core::Part::Text { text } = part {
190                                if let Some(adk_core::Part::Text { text: existing }) =
191                                    last.parts.last_mut()
192                                {
193                                    existing.push_str(text);
194                                } else {
195                                    last.parts.push(part.clone());
196                                }
197                            } else {
198                                last.parts.push(part.clone());
199                            }
200                        }
201                    } else if !content.parts.is_empty() {
202                        history.push(content.clone());
203                    }
204                } else {
205                    history.push(content.clone());
206                }
207            }
208        }
209        self.state.apply_delta(&event.actions.state_delta);
210    }
211}
212
213impl Session for HistoryTrackingSession {
214    fn id(&self) -> &str {
215        self.parent_ctx.session().id()
216    }
217
218    fn app_name(&self) -> &str {
219        self.parent_ctx.session().app_name()
220    }
221
222    fn user_id(&self) -> &str {
223        self.parent_ctx.session().user_id()
224    }
225
226    fn state(&self) -> &dyn State {
227        &self.state
228    }
229
230    fn conversation_history(&self) -> Vec<Content> {
231        self.history.read().unwrap_or_else(|e| e.into_inner()).clone()
232    }
233
234    fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
235        self.conversation_history()
236    }
237
238    fn append_to_history(&self, content: Content) {
239        self.history.write().unwrap_or_else(|e| e.into_inner()).push(content);
240    }
241}
242
243struct HistoryTrackingContext {
244    parent_ctx: Arc<dyn InvocationContext>,
245    session: HistoryTrackingSession,
246}
247
248impl HistoryTrackingContext {
249    fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
250        let session = HistoryTrackingSession::new(parent_ctx.clone());
251        Self { parent_ctx, session }
252    }
253
254    fn apply_event(&self, event: &Event) {
255        self.session.apply_event(event);
256    }
257}
258
259#[async_trait]
260impl adk_core::ReadonlyContext for HistoryTrackingContext {
261    fn invocation_id(&self) -> &str {
262        self.parent_ctx.invocation_id()
263    }
264
265    fn agent_name(&self) -> &str {
266        self.parent_ctx.agent_name()
267    }
268
269    fn user_id(&self) -> &str {
270        self.parent_ctx.user_id()
271    }
272
273    fn app_name(&self) -> &str {
274        self.parent_ctx.app_name()
275    }
276
277    fn session_id(&self) -> &str {
278        self.parent_ctx.session_id()
279    }
280
281    fn branch(&self) -> &str {
282        self.parent_ctx.branch()
283    }
284
285    fn user_content(&self) -> &Content {
286        self.parent_ctx.user_content()
287    }
288}
289
290#[async_trait]
291impl CallbackContext for HistoryTrackingContext {
292    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
293        self.parent_ctx.artifacts()
294    }
295}
296
297#[async_trait]
298impl InvocationContext for HistoryTrackingContext {
299    fn agent(&self) -> Arc<dyn Agent> {
300        self.parent_ctx.agent()
301    }
302
303    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
304        self.parent_ctx.memory()
305    }
306
307    fn session(&self) -> &dyn Session {
308        &self.session
309    }
310
311    fn run_config(&self) -> &adk_core::RunConfig {
312        self.parent_ctx.run_config()
313    }
314
315    fn end_invocation(&self) {
316        self.parent_ctx.end_invocation();
317    }
318
319    fn ended(&self) -> bool {
320        self.parent_ctx.ended()
321    }
322
323    fn user_scopes(&self) -> Vec<String> {
324        self.parent_ctx.user_scopes()
325    }
326
327    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
328        self.parent_ctx.request_metadata()
329    }
330}
331
332#[async_trait]
333impl Agent for LoopAgent {
334    fn name(&self) -> &str {
335        &self.name
336    }
337
338    fn description(&self) -> &str {
339        &self.description
340    }
341
342    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
343        &self.sub_agents
344    }
345
346    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
347        let sub_agents = self.sub_agents.clone();
348        let max_iterations = self.max_iterations;
349        let before_callbacks = self.before_callbacks.clone();
350        let after_callbacks = self.after_callbacks.clone();
351        let agent_name = self.name.clone();
352        let run_ctx = super::skill_context::with_skill_injected_context(
353            ctx,
354            self.skills_index.as_ref(),
355            &self.skill_policy,
356            self.max_skill_chars,
357        );
358        let run_ctx = Arc::new(HistoryTrackingContext::new(run_ctx));
359
360        let s = stream! {
361            use futures::StreamExt;
362
363            // ===== BEFORE AGENT CALLBACKS =====
364            for callback in before_callbacks.as_ref() {
365                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
366                    Ok(Some(content)) => {
367                        let mut early_event = Event::new(run_ctx.invocation_id());
368                        early_event.author = agent_name.clone();
369                        early_event.llm_response.content = Some(content);
370                        yield Ok(early_event);
371
372                        for after_cb in after_callbacks.as_ref() {
373                            match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
374                                Ok(Some(after_content)) => {
375                                    let mut after_event = Event::new(run_ctx.invocation_id());
376                                    after_event.author = agent_name.clone();
377                                    after_event.llm_response.content = Some(after_content);
378                                    yield Ok(after_event);
379                                    return;
380                                }
381                                Ok(None) => continue,
382                                Err(e) => { yield Err(e); return; }
383                            }
384                        }
385                        return;
386                    }
387                    Ok(None) => continue,
388                    Err(e) => { yield Err(e); return; }
389                }
390            }
391
392            let mut remaining = max_iterations;
393
394            loop {
395                let mut should_exit = false;
396
397                for agent in &sub_agents {
398                    let mut stream = agent.run(run_ctx.clone() as Arc<dyn InvocationContext>).await?;
399
400                    while let Some(result) = stream.next().await {
401                        match result {
402                            Ok(event) => {
403                                run_ctx.apply_event(&event);
404                                if event.actions.escalate {
405                                    should_exit = true;
406                                }
407                                yield Ok(event);
408                            }
409                            Err(e) => {
410                                yield Err(e);
411                                return;
412                            }
413                        }
414                    }
415
416                    if should_exit {
417                        break;
418                    }
419                }
420
421                if should_exit {
422                    break;
423                }
424
425                remaining -= 1;
426                if remaining == 0 {
427                    break;
428                }
429            }
430
431            // ===== AFTER AGENT CALLBACKS =====
432            for callback in after_callbacks.as_ref() {
433                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
434                    Ok(Some(content)) => {
435                        let mut after_event = Event::new(run_ctx.invocation_id());
436                        after_event.author = agent_name.clone();
437                        after_event.llm_response.content = Some(content);
438                        yield Ok(after_event);
439                        break;
440                    }
441                    Ok(None) => continue,
442                    Err(e) => { yield Err(e); return; }
443                }
444            }
445        };
446
447        Ok(Box::pin(s))
448    }
449}