Skip to main content

adk_agent/workflow/
loop_agent.rs

1#[cfg(feature = "skills")]
2use crate::skill_shim::load_skill_index;
3use crate::skill_shim::{SelectionPolicy, SkillIndex};
4use adk_core::{
5    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
6    InvocationContext, ReadonlyContext, Result, Session, State,
7};
8use async_stream::stream;
9use async_trait::async_trait;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12
/// Iteration cap a [`LoopAgent`] uses until one is set explicitly with
/// [`LoopAgent::with_max_iterations`].
///
/// Bounds a loop whose sub-agents never escalate, so it cannot keep consuming
/// resources forever.
pub const DEFAULT_LOOP_MAX_ITERATIONS: u32 = 1_000;
16
17/// Loop agent executes sub-agents repeatedly for N iterations or until escalation
18pub struct LoopAgent {
19    name: String,
20    description: String,
21    sub_agents: Vec<Arc<dyn Agent>>,
22    max_iterations: u32,
23    skills_index: Option<Arc<SkillIndex>>,
24    skill_policy: SelectionPolicy,
25    max_skill_chars: usize,
26    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
27    after_callbacks: Arc<Vec<AfterAgentCallback>>,
28}
29
30impl LoopAgent {
31    /// Create a new loop agent with the given name and sub-agents.
32    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
33        Self {
34            name: name.into(),
35            description: String::new(),
36            sub_agents,
37            max_iterations: DEFAULT_LOOP_MAX_ITERATIONS,
38            skills_index: None,
39            skill_policy: SelectionPolicy::default(),
40            max_skill_chars: 2000,
41            before_callbacks: Arc::new(Vec::new()),
42            after_callbacks: Arc::new(Vec::new()),
43        }
44    }
45
46    /// Set the agent description.
47    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
48        self.description = desc.into();
49        self
50    }
51
52    /// Set the maximum number of loop iterations.
53    pub fn with_max_iterations(mut self, max: u32) -> Self {
54        self.max_iterations = max;
55        self
56    }
57
58    /// Set a preloaded skills index for this agent.
59    #[cfg(feature = "skills")]
60    pub fn with_skills(mut self, index: SkillIndex) -> Self {
61        self.skills_index = Some(Arc::new(index));
62        self
63    }
64
65    /// Auto-load skills from `.skills/` in the current working directory.
66    #[cfg(feature = "skills")]
67    pub fn with_auto_skills(self) -> Result<Self> {
68        self.with_skills_from_root(".")
69    }
70
71    /// Auto-load skills from `.skills/` under a custom root directory.
72    #[cfg(feature = "skills")]
73    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
74        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
75        self.skills_index = Some(Arc::new(index));
76        Ok(self)
77    }
78
79    /// Customize skill selection behavior.
80    #[cfg(feature = "skills")]
81    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
82        self.skill_policy = policy;
83        self
84    }
85
86    /// Limit injected skill content length.
87    #[cfg(feature = "skills")]
88    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
89        self.max_skill_chars = max_chars;
90        self
91    }
92
93    /// Add a before-agent callback.
94    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
95        if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
96            callbacks.push(callback);
97        }
98        self
99    }
100
101    /// Add an after-agent callback.
102    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
103        if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
104            callbacks.push(callback);
105        }
106        self
107    }
108}
109
110struct HistoryTrackingSession {
111    parent_ctx: Arc<dyn InvocationContext>,
112    history: Arc<RwLock<Vec<Content>>>,
113    state: StateTrackingState,
114}
115
116struct StateTrackingState {
117    values: RwLock<HashMap<String, serde_json::Value>>,
118}
119
120impl StateTrackingState {
121    fn new(parent_ctx: &Arc<dyn InvocationContext>) -> Self {
122        Self { values: RwLock::new(parent_ctx.session().state().all()) }
123    }
124
125    fn apply_delta(&self, delta: &HashMap<String, serde_json::Value>) {
126        if delta.is_empty() {
127            return;
128        }
129
130        let mut values = self.values.write().unwrap_or_else(|e| e.into_inner());
131        for (key, value) in delta {
132            values.insert(key.clone(), value.clone());
133        }
134    }
135}
136
137impl State for StateTrackingState {
138    fn get(&self, key: &str) -> Option<serde_json::Value> {
139        self.values.read().unwrap_or_else(|e| e.into_inner()).get(key).cloned()
140    }
141
142    fn set(&mut self, key: String, value: serde_json::Value) {
143        if let Err(msg) = adk_core::validate_state_key(&key) {
144            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
145            return;
146        }
147        self.values.write().unwrap_or_else(|e| e.into_inner()).insert(key, value);
148    }
149
150    fn all(&self) -> HashMap<String, serde_json::Value> {
151        self.values.read().unwrap_or_else(|e| e.into_inner()).clone()
152    }
153}
154
155impl HistoryTrackingSession {
156    fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
157        Self {
158            history: Arc::new(RwLock::new(parent_ctx.session().conversation_history())),
159            state: StateTrackingState::new(&parent_ctx),
160            parent_ctx,
161        }
162    }
163
164    fn apply_event(&self, event: &Event) {
165        if let Some(content) = &event.llm_response.content {
166            // Consolidate streaming chunks: if the last history entry has the
167            // same role, merge text into it instead of creating a new entry.
168            // This prevents N streaming chunks from becoming N separate Content
169            // entries that bloat context for subsequent agents.
170            let mut history = self.history.write().unwrap_or_else(|e| e.into_inner());
171
172            if event.llm_response.partial {
173                // Partial chunk — merge into last entry if same role
174                if let Some(last) = history.last_mut() {
175                    if last.role == content.role {
176                        for part in &content.parts {
177                            if let adk_core::Part::Text { text } = part {
178                                // Append text to the last Text part
179                                if let Some(adk_core::Part::Text { text: existing }) =
180                                    last.parts.last_mut()
181                                {
182                                    existing.push_str(text);
183                                } else {
184                                    last.parts.push(part.clone());
185                                }
186                            } else {
187                                last.parts.push(part.clone());
188                            }
189                        }
190                        return;
191                    }
192                }
193                // No matching last entry — start a new one
194                history.push(content.clone());
195            } else {
196                // Final event (partial=false) — append as-is.
197                // For non-streaming mode this carries the full content.
198                // For streaming mode the accumulated text is already in the
199                // last history entry from partial merges above, so the final
200                // chunk (which may carry the last fragment or be empty) is
201                // merged if same role, or appended if different.
202                if let Some(last) = history.last_mut() {
203                    if last.role == content.role && !content.parts.is_empty() {
204                        // Merge any remaining text from the final chunk
205                        for part in &content.parts {
206                            if let adk_core::Part::Text { text } = part {
207                                if let Some(adk_core::Part::Text { text: existing }) =
208                                    last.parts.last_mut()
209                                {
210                                    existing.push_str(text);
211                                } else {
212                                    last.parts.push(part.clone());
213                                }
214                            } else {
215                                last.parts.push(part.clone());
216                            }
217                        }
218                    } else if !content.parts.is_empty() {
219                        history.push(content.clone());
220                    }
221                } else {
222                    history.push(content.clone());
223                }
224            }
225        }
226        self.state.apply_delta(&event.actions.state_delta);
227    }
228}
229
230impl Session for HistoryTrackingSession {
231    fn id(&self) -> &str {
232        self.parent_ctx.session().id()
233    }
234
235    fn app_name(&self) -> &str {
236        self.parent_ctx.session().app_name()
237    }
238
239    fn user_id(&self) -> &str {
240        self.parent_ctx.session().user_id()
241    }
242
243    fn state(&self) -> &dyn State {
244        &self.state
245    }
246
247    fn conversation_history(&self) -> Vec<Content> {
248        self.history.read().unwrap_or_else(|e| e.into_inner()).clone()
249    }
250
251    fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
252        self.conversation_history()
253    }
254
255    fn append_to_history(&self, content: Content) {
256        self.history.write().unwrap_or_else(|e| e.into_inner()).push(content);
257    }
258}
259
260struct HistoryTrackingContext {
261    parent_ctx: Arc<dyn InvocationContext>,
262    session: HistoryTrackingSession,
263}
264
265impl HistoryTrackingContext {
266    fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
267        let session = HistoryTrackingSession::new(parent_ctx.clone());
268        Self { parent_ctx, session }
269    }
270
271    fn apply_event(&self, event: &Event) {
272        self.session.apply_event(event);
273    }
274}
275
276#[async_trait]
277impl adk_core::ReadonlyContext for HistoryTrackingContext {
278    fn invocation_id(&self) -> &str {
279        self.parent_ctx.invocation_id()
280    }
281
282    fn agent_name(&self) -> &str {
283        self.parent_ctx.agent_name()
284    }
285
286    fn user_id(&self) -> &str {
287        self.parent_ctx.user_id()
288    }
289
290    fn app_name(&self) -> &str {
291        self.parent_ctx.app_name()
292    }
293
294    fn session_id(&self) -> &str {
295        self.parent_ctx.session_id()
296    }
297
298    fn branch(&self) -> &str {
299        self.parent_ctx.branch()
300    }
301
302    fn user_content(&self) -> &Content {
303        self.parent_ctx.user_content()
304    }
305}
306
307#[async_trait]
308impl CallbackContext for HistoryTrackingContext {
309    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
310        self.parent_ctx.artifacts()
311    }
312}
313
314#[async_trait]
315impl InvocationContext for HistoryTrackingContext {
316    fn agent(&self) -> Arc<dyn Agent> {
317        self.parent_ctx.agent()
318    }
319
320    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
321        self.parent_ctx.memory()
322    }
323
324    fn session(&self) -> &dyn Session {
325        &self.session
326    }
327
328    fn run_config(&self) -> &adk_core::RunConfig {
329        self.parent_ctx.run_config()
330    }
331
332    fn end_invocation(&self) {
333        self.parent_ctx.end_invocation();
334    }
335
336    fn ended(&self) -> bool {
337        self.parent_ctx.ended()
338    }
339
340    fn user_scopes(&self) -> Vec<String> {
341        self.parent_ctx.user_scopes()
342    }
343
344    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
345        self.parent_ctx.request_metadata()
346    }
347}
348
349#[async_trait]
350impl Agent for LoopAgent {
351    fn name(&self) -> &str {
352        &self.name
353    }
354
355    fn description(&self) -> &str {
356        &self.description
357    }
358
359    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
360        &self.sub_agents
361    }
362
363    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
364        let sub_agents = self.sub_agents.clone();
365        let max_iterations = self.max_iterations;
366        let before_callbacks = self.before_callbacks.clone();
367        let after_callbacks = self.after_callbacks.clone();
368        let agent_name = self.name.clone();
369        let run_ctx = super::skill_context::with_skill_injected_context(
370            ctx,
371            self.skills_index.as_ref(),
372            &self.skill_policy,
373            self.max_skill_chars,
374        );
375        let run_ctx = Arc::new(HistoryTrackingContext::new(run_ctx));
376
377        let s = stream! {
378            use futures::StreamExt;
379
380            // ===== BEFORE AGENT CALLBACKS =====
381            for callback in before_callbacks.as_ref() {
382                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
383                    Ok(Some(content)) => {
384                        let mut early_event = Event::new(run_ctx.invocation_id());
385                        early_event.author = agent_name.clone();
386                        early_event.llm_response.content = Some(content);
387                        yield Ok(early_event);
388
389                        for after_cb in after_callbacks.as_ref() {
390                            match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
391                                Ok(Some(after_content)) => {
392                                    let mut after_event = Event::new(run_ctx.invocation_id());
393                                    after_event.author = agent_name.clone();
394                                    after_event.llm_response.content = Some(after_content);
395                                    yield Ok(after_event);
396                                    return;
397                                }
398                                Ok(None) => continue,
399                                Err(e) => { yield Err(e); return; }
400                            }
401                        }
402                        return;
403                    }
404                    Ok(None) => continue,
405                    Err(e) => { yield Err(e); return; }
406                }
407            }
408
409            let mut remaining = max_iterations;
410
411            loop {
412                let mut should_exit = false;
413
414                for agent in &sub_agents {
415                    let mut stream = agent.run(run_ctx.clone() as Arc<dyn InvocationContext>).await?;
416
417                    while let Some(result) = stream.next().await {
418                        match result {
419                            Ok(event) => {
420                                run_ctx.apply_event(&event);
421                                if event.actions.escalate {
422                                    should_exit = true;
423                                }
424                                yield Ok(event);
425                            }
426                            Err(e) => {
427                                yield Err(e);
428                                return;
429                            }
430                        }
431                    }
432
433                    if should_exit {
434                        break;
435                    }
436                }
437
438                if should_exit {
439                    break;
440                }
441
442                remaining -= 1;
443                if remaining == 0 {
444                    break;
445                }
446            }
447
448            // ===== AFTER AGENT CALLBACKS =====
449            for callback in after_callbacks.as_ref() {
450                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
451                    Ok(Some(content)) => {
452                        let mut after_event = Event::new(run_ctx.invocation_id());
453                        after_event.author = agent_name.clone();
454                        after_event.llm_response.content = Some(content);
455                        yield Ok(after_event);
456                        break;
457                    }
458                    Ok(None) => continue,
459                    Err(e) => { yield Err(e); return; }
460                }
461            }
462        };
463
464        Ok(Box::pin(s))
465    }
466}