Skip to main content

adk_agent/workflow/
loop_agent.rs

1#[cfg(feature = "skills")]
2use crate::skill_shim::load_skill_index;
3use crate::skill_shim::{SelectionPolicy, SkillIndex};
4use adk_core::{
5    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
6    InvocationContext, ReadonlyContext, Result, Session, State,
7};
8use async_stream::stream;
9use async_trait::async_trait;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12
/// Iteration cap applied to a `LoopAgent` built without an explicit
/// `with_max_iterations` call, so a loop whose sub-agents never escalate
/// still terminates.
pub const DEFAULT_LOOP_MAX_ITERATIONS: u32 = 1_000;
16
17/// Loop agent executes sub-agents repeatedly for N iterations or until escalation
18pub struct LoopAgent {
19    name: String,
20    description: String,
21    sub_agents: Vec<Arc<dyn Agent>>,
22    max_iterations: u32,
23    skills_index: Option<Arc<SkillIndex>>,
24    skill_policy: SelectionPolicy,
25    max_skill_chars: usize,
26    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
27    after_callbacks: Arc<Vec<AfterAgentCallback>>,
28}
29
30impl LoopAgent {
31    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
32        Self {
33            name: name.into(),
34            description: String::new(),
35            sub_agents,
36            max_iterations: DEFAULT_LOOP_MAX_ITERATIONS,
37            skills_index: None,
38            skill_policy: SelectionPolicy::default(),
39            max_skill_chars: 2000,
40            before_callbacks: Arc::new(Vec::new()),
41            after_callbacks: Arc::new(Vec::new()),
42        }
43    }
44
45    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
46        self.description = desc.into();
47        self
48    }
49
50    pub fn with_max_iterations(mut self, max: u32) -> Self {
51        self.max_iterations = max;
52        self
53    }
54
55    #[cfg(feature = "skills")]
56    pub fn with_skills(mut self, index: SkillIndex) -> Self {
57        self.skills_index = Some(Arc::new(index));
58        self
59    }
60
61    #[cfg(feature = "skills")]
62    pub fn with_auto_skills(self) -> Result<Self> {
63        self.with_skills_from_root(".")
64    }
65
66    #[cfg(feature = "skills")]
67    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
68        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
69        self.skills_index = Some(Arc::new(index));
70        Ok(self)
71    }
72
73    #[cfg(feature = "skills")]
74    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
75        self.skill_policy = policy;
76        self
77    }
78
79    #[cfg(feature = "skills")]
80    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
81        self.max_skill_chars = max_chars;
82        self
83    }
84
85    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
86        if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
87            callbacks.push(callback);
88        }
89        self
90    }
91
92    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
93        if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
94            callbacks.push(callback);
95        }
96        self
97    }
98}
99
100struct HistoryTrackingSession {
101    parent_ctx: Arc<dyn InvocationContext>,
102    history: Arc<RwLock<Vec<Content>>>,
103    state: StateTrackingState,
104}
105
106struct StateTrackingState {
107    values: RwLock<HashMap<String, serde_json::Value>>,
108}
109
110impl StateTrackingState {
111    fn new(parent_ctx: &Arc<dyn InvocationContext>) -> Self {
112        Self { values: RwLock::new(parent_ctx.session().state().all()) }
113    }
114
115    fn apply_delta(&self, delta: &HashMap<String, serde_json::Value>) {
116        if delta.is_empty() {
117            return;
118        }
119
120        let mut values = self.values.write().unwrap_or_else(|e| e.into_inner());
121        for (key, value) in delta {
122            values.insert(key.clone(), value.clone());
123        }
124    }
125}
126
127impl State for StateTrackingState {
128    fn get(&self, key: &str) -> Option<serde_json::Value> {
129        self.values.read().unwrap_or_else(|e| e.into_inner()).get(key).cloned()
130    }
131
132    fn set(&mut self, key: String, value: serde_json::Value) {
133        if let Err(msg) = adk_core::validate_state_key(&key) {
134            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
135            return;
136        }
137        self.values.write().unwrap_or_else(|e| e.into_inner()).insert(key, value);
138    }
139
140    fn all(&self) -> HashMap<String, serde_json::Value> {
141        self.values.read().unwrap_or_else(|e| e.into_inner()).clone()
142    }
143}
144
145impl HistoryTrackingSession {
146    fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
147        Self {
148            history: Arc::new(RwLock::new(parent_ctx.session().conversation_history())),
149            state: StateTrackingState::new(&parent_ctx),
150            parent_ctx,
151        }
152    }
153
154    fn apply_event(&self, event: &Event) {
155        if let Some(content) = &event.llm_response.content {
156            // Consolidate streaming chunks: if the last history entry has the
157            // same role, merge text into it instead of creating a new entry.
158            // This prevents N streaming chunks from becoming N separate Content
159            // entries that bloat context for subsequent agents.
160            let mut history = self.history.write().unwrap_or_else(|e| e.into_inner());
161
162            if event.llm_response.partial {
163                // Partial chunk — merge into last entry if same role
164                if let Some(last) = history.last_mut() {
165                    if last.role == content.role {
166                        for part in &content.parts {
167                            if let adk_core::Part::Text { text } = part {
168                                // Append text to the last Text part
169                                if let Some(adk_core::Part::Text { text: existing }) =
170                                    last.parts.last_mut()
171                                {
172                                    existing.push_str(text);
173                                } else {
174                                    last.parts.push(part.clone());
175                                }
176                            } else {
177                                last.parts.push(part.clone());
178                            }
179                        }
180                        return;
181                    }
182                }
183                // No matching last entry — start a new one
184                history.push(content.clone());
185            } else {
186                // Final event (partial=false) — append as-is.
187                // For non-streaming mode this carries the full content.
188                // For streaming mode the accumulated text is already in the
189                // last history entry from partial merges above, so the final
190                // chunk (which may carry the last fragment or be empty) is
191                // merged if same role, or appended if different.
192                if let Some(last) = history.last_mut() {
193                    if last.role == content.role && !content.parts.is_empty() {
194                        // Merge any remaining text from the final chunk
195                        for part in &content.parts {
196                            if let adk_core::Part::Text { text } = part {
197                                if let Some(adk_core::Part::Text { text: existing }) =
198                                    last.parts.last_mut()
199                                {
200                                    existing.push_str(text);
201                                } else {
202                                    last.parts.push(part.clone());
203                                }
204                            } else {
205                                last.parts.push(part.clone());
206                            }
207                        }
208                    } else if !content.parts.is_empty() {
209                        history.push(content.clone());
210                    }
211                } else {
212                    history.push(content.clone());
213                }
214            }
215        }
216        self.state.apply_delta(&event.actions.state_delta);
217    }
218}
219
220impl Session for HistoryTrackingSession {
221    fn id(&self) -> &str {
222        self.parent_ctx.session().id()
223    }
224
225    fn app_name(&self) -> &str {
226        self.parent_ctx.session().app_name()
227    }
228
229    fn user_id(&self) -> &str {
230        self.parent_ctx.session().user_id()
231    }
232
233    fn state(&self) -> &dyn State {
234        &self.state
235    }
236
237    fn conversation_history(&self) -> Vec<Content> {
238        self.history.read().unwrap_or_else(|e| e.into_inner()).clone()
239    }
240
241    fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
242        self.conversation_history()
243    }
244
245    fn append_to_history(&self, content: Content) {
246        self.history.write().unwrap_or_else(|e| e.into_inner()).push(content);
247    }
248}
249
250struct HistoryTrackingContext {
251    parent_ctx: Arc<dyn InvocationContext>,
252    session: HistoryTrackingSession,
253}
254
255impl HistoryTrackingContext {
256    fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
257        let session = HistoryTrackingSession::new(parent_ctx.clone());
258        Self { parent_ctx, session }
259    }
260
261    fn apply_event(&self, event: &Event) {
262        self.session.apply_event(event);
263    }
264}
265
266#[async_trait]
267impl adk_core::ReadonlyContext for HistoryTrackingContext {
268    fn invocation_id(&self) -> &str {
269        self.parent_ctx.invocation_id()
270    }
271
272    fn agent_name(&self) -> &str {
273        self.parent_ctx.agent_name()
274    }
275
276    fn user_id(&self) -> &str {
277        self.parent_ctx.user_id()
278    }
279
280    fn app_name(&self) -> &str {
281        self.parent_ctx.app_name()
282    }
283
284    fn session_id(&self) -> &str {
285        self.parent_ctx.session_id()
286    }
287
288    fn branch(&self) -> &str {
289        self.parent_ctx.branch()
290    }
291
292    fn user_content(&self) -> &Content {
293        self.parent_ctx.user_content()
294    }
295}
296
297#[async_trait]
298impl CallbackContext for HistoryTrackingContext {
299    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
300        self.parent_ctx.artifacts()
301    }
302}
303
304#[async_trait]
305impl InvocationContext for HistoryTrackingContext {
306    fn agent(&self) -> Arc<dyn Agent> {
307        self.parent_ctx.agent()
308    }
309
310    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
311        self.parent_ctx.memory()
312    }
313
314    fn session(&self) -> &dyn Session {
315        &self.session
316    }
317
318    fn run_config(&self) -> &adk_core::RunConfig {
319        self.parent_ctx.run_config()
320    }
321
322    fn end_invocation(&self) {
323        self.parent_ctx.end_invocation();
324    }
325
326    fn ended(&self) -> bool {
327        self.parent_ctx.ended()
328    }
329
330    fn user_scopes(&self) -> Vec<String> {
331        self.parent_ctx.user_scopes()
332    }
333
334    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
335        self.parent_ctx.request_metadata()
336    }
337}
338
339#[async_trait]
340impl Agent for LoopAgent {
341    fn name(&self) -> &str {
342        &self.name
343    }
344
345    fn description(&self) -> &str {
346        &self.description
347    }
348
349    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
350        &self.sub_agents
351    }
352
353    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
354        let sub_agents = self.sub_agents.clone();
355        let max_iterations = self.max_iterations;
356        let before_callbacks = self.before_callbacks.clone();
357        let after_callbacks = self.after_callbacks.clone();
358        let agent_name = self.name.clone();
359        let run_ctx = super::skill_context::with_skill_injected_context(
360            ctx,
361            self.skills_index.as_ref(),
362            &self.skill_policy,
363            self.max_skill_chars,
364        );
365        let run_ctx = Arc::new(HistoryTrackingContext::new(run_ctx));
366
367        let s = stream! {
368            use futures::StreamExt;
369
370            // ===== BEFORE AGENT CALLBACKS =====
371            for callback in before_callbacks.as_ref() {
372                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
373                    Ok(Some(content)) => {
374                        let mut early_event = Event::new(run_ctx.invocation_id());
375                        early_event.author = agent_name.clone();
376                        early_event.llm_response.content = Some(content);
377                        yield Ok(early_event);
378
379                        for after_cb in after_callbacks.as_ref() {
380                            match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
381                                Ok(Some(after_content)) => {
382                                    let mut after_event = Event::new(run_ctx.invocation_id());
383                                    after_event.author = agent_name.clone();
384                                    after_event.llm_response.content = Some(after_content);
385                                    yield Ok(after_event);
386                                    return;
387                                }
388                                Ok(None) => continue,
389                                Err(e) => { yield Err(e); return; }
390                            }
391                        }
392                        return;
393                    }
394                    Ok(None) => continue,
395                    Err(e) => { yield Err(e); return; }
396                }
397            }
398
399            let mut remaining = max_iterations;
400
401            loop {
402                let mut should_exit = false;
403
404                for agent in &sub_agents {
405                    let mut stream = agent.run(run_ctx.clone() as Arc<dyn InvocationContext>).await?;
406
407                    while let Some(result) = stream.next().await {
408                        match result {
409                            Ok(event) => {
410                                run_ctx.apply_event(&event);
411                                if event.actions.escalate {
412                                    should_exit = true;
413                                }
414                                yield Ok(event);
415                            }
416                            Err(e) => {
417                                yield Err(e);
418                                return;
419                            }
420                        }
421                    }
422
423                    if should_exit {
424                        break;
425                    }
426                }
427
428                if should_exit {
429                    break;
430                }
431
432                remaining -= 1;
433                if remaining == 0 {
434                    break;
435                }
436            }
437
438            // ===== AFTER AGENT CALLBACKS =====
439            for callback in after_callbacks.as_ref() {
440                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
441                    Ok(Some(content)) => {
442                        let mut after_event = Event::new(run_ctx.invocation_id());
443                        after_event.author = agent_name.clone();
444                        after_event.llm_response.content = Some(content);
445                        yield Ok(after_event);
446                        break;
447                    }
448                    Ok(None) => continue,
449                    Err(e) => { yield Err(e); return; }
450                }
451            }
452        };
453
454        Ok(Box::pin(s))
455    }
456}