// adk_agent/workflow/loop_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
3    InvocationContext, ReadonlyContext, Result, Session, State,
4};
5use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
6use async_stream::stream;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
/// Iteration cap used by [`LoopAgent`] when no explicit limit is configured.
///
/// Bounding the loop keeps sub-agents that never escalate from running forever and
/// consuming unbounded resources.
pub const DEFAULT_LOOP_MAX_ITERATIONS: u32 = 1_000;

15/// Loop agent executes sub-agents repeatedly for N iterations or until escalation
16pub struct LoopAgent {
17    name: String,
18    description: String,
19    sub_agents: Vec<Arc<dyn Agent>>,
20    max_iterations: u32,
21    skills_index: Option<Arc<SkillIndex>>,
22    skill_policy: SelectionPolicy,
23    max_skill_chars: usize,
24    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
25    after_callbacks: Arc<Vec<AfterAgentCallback>>,
26}
27
28impl LoopAgent {
29    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
30        Self {
31            name: name.into(),
32            description: String::new(),
33            sub_agents,
34            max_iterations: DEFAULT_LOOP_MAX_ITERATIONS,
35            skills_index: None,
36            skill_policy: SelectionPolicy::default(),
37            max_skill_chars: 2000,
38            before_callbacks: Arc::new(Vec::new()),
39            after_callbacks: Arc::new(Vec::new()),
40        }
41    }
42
43    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
44        self.description = desc.into();
45        self
46    }
47
48    pub fn with_max_iterations(mut self, max: u32) -> Self {
49        self.max_iterations = max;
50        self
51    }
52
53    pub fn with_skills(mut self, index: SkillIndex) -> Self {
54        self.skills_index = Some(Arc::new(index));
55        self
56    }
57
58    pub fn with_auto_skills(self) -> Result<Self> {
59        self.with_skills_from_root(".")
60    }
61
62    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
63        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
64        self.skills_index = Some(Arc::new(index));
65        Ok(self)
66    }
67
68    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
69        self.skill_policy = policy;
70        self
71    }
72
73    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
74        self.max_skill_chars = max_chars;
75        self
76    }
77
78    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
79        if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
80            callbacks.push(callback);
81        }
82        self
83    }
84
85    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
86        if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
87            callbacks.push(callback);
88        }
89        self
90    }
91}
92
93struct HistoryTrackingSession {
94    parent_ctx: Arc<dyn InvocationContext>,
95    history: Arc<RwLock<Vec<Content>>>,
96    state: StateTrackingState,
97}
98
99struct StateTrackingState {
100    values: RwLock<HashMap<String, serde_json::Value>>,
101}
102
103impl StateTrackingState {
104    fn new(parent_ctx: &Arc<dyn InvocationContext>) -> Self {
105        Self { values: RwLock::new(parent_ctx.session().state().all()) }
106    }
107
108    fn apply_delta(&self, delta: &HashMap<String, serde_json::Value>) {
109        if delta.is_empty() {
110            return;
111        }
112
113        let mut values = self.values.write().unwrap_or_else(|e| e.into_inner());
114        for (key, value) in delta {
115            values.insert(key.clone(), value.clone());
116        }
117    }
118}
119
120impl State for StateTrackingState {
121    fn get(&self, key: &str) -> Option<serde_json::Value> {
122        self.values.read().unwrap_or_else(|e| e.into_inner()).get(key).cloned()
123    }
124
125    fn set(&mut self, key: String, value: serde_json::Value) {
126        if let Err(msg) = adk_core::validate_state_key(&key) {
127            tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
128            return;
129        }
130        self.values.write().unwrap_or_else(|e| e.into_inner()).insert(key, value);
131    }
132
133    fn all(&self) -> HashMap<String, serde_json::Value> {
134        self.values.read().unwrap_or_else(|e| e.into_inner()).clone()
135    }
136}
137
138impl HistoryTrackingSession {
139    fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
140        Self {
141            history: Arc::new(RwLock::new(parent_ctx.session().conversation_history())),
142            state: StateTrackingState::new(&parent_ctx),
143            parent_ctx,
144        }
145    }
146
147    fn apply_event(&self, event: &Event) {
148        if let Some(content) = &event.llm_response.content {
149            self.append_to_history(content.clone());
150        }
151        self.state.apply_delta(&event.actions.state_delta);
152    }
153}
154
155impl Session for HistoryTrackingSession {
156    fn id(&self) -> &str {
157        self.parent_ctx.session().id()
158    }
159
160    fn app_name(&self) -> &str {
161        self.parent_ctx.session().app_name()
162    }
163
164    fn user_id(&self) -> &str {
165        self.parent_ctx.session().user_id()
166    }
167
168    fn state(&self) -> &dyn State {
169        &self.state
170    }
171
172    fn conversation_history(&self) -> Vec<Content> {
173        self.history.read().unwrap_or_else(|e| e.into_inner()).clone()
174    }
175
176    fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
177        self.conversation_history()
178    }
179
180    fn append_to_history(&self, content: Content) {
181        self.history.write().unwrap_or_else(|e| e.into_inner()).push(content);
182    }
183}
184
185struct HistoryTrackingContext {
186    parent_ctx: Arc<dyn InvocationContext>,
187    session: HistoryTrackingSession,
188}
189
190impl HistoryTrackingContext {
191    fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
192        let session = HistoryTrackingSession::new(parent_ctx.clone());
193        Self { parent_ctx, session }
194    }
195
196    fn apply_event(&self, event: &Event) {
197        self.session.apply_event(event);
198    }
199}
200
201#[async_trait]
202impl adk_core::ReadonlyContext for HistoryTrackingContext {
203    fn invocation_id(&self) -> &str {
204        self.parent_ctx.invocation_id()
205    }
206
207    fn agent_name(&self) -> &str {
208        self.parent_ctx.agent_name()
209    }
210
211    fn user_id(&self) -> &str {
212        self.parent_ctx.user_id()
213    }
214
215    fn app_name(&self) -> &str {
216        self.parent_ctx.app_name()
217    }
218
219    fn session_id(&self) -> &str {
220        self.parent_ctx.session_id()
221    }
222
223    fn branch(&self) -> &str {
224        self.parent_ctx.branch()
225    }
226
227    fn user_content(&self) -> &Content {
228        self.parent_ctx.user_content()
229    }
230}
231
232#[async_trait]
233impl CallbackContext for HistoryTrackingContext {
234    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
235        self.parent_ctx.artifacts()
236    }
237}
238
239#[async_trait]
240impl InvocationContext for HistoryTrackingContext {
241    fn agent(&self) -> Arc<dyn Agent> {
242        self.parent_ctx.agent()
243    }
244
245    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
246        self.parent_ctx.memory()
247    }
248
249    fn session(&self) -> &dyn Session {
250        &self.session
251    }
252
253    fn run_config(&self) -> &adk_core::RunConfig {
254        self.parent_ctx.run_config()
255    }
256
257    fn end_invocation(&self) {
258        self.parent_ctx.end_invocation();
259    }
260
261    fn ended(&self) -> bool {
262        self.parent_ctx.ended()
263    }
264
265    fn user_scopes(&self) -> Vec<String> {
266        self.parent_ctx.user_scopes()
267    }
268
269    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
270        self.parent_ctx.request_metadata()
271    }
272}
273
274#[async_trait]
275impl Agent for LoopAgent {
276    fn name(&self) -> &str {
277        &self.name
278    }
279
280    fn description(&self) -> &str {
281        &self.description
282    }
283
284    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
285        &self.sub_agents
286    }
287
288    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
289        let sub_agents = self.sub_agents.clone();
290        let max_iterations = self.max_iterations;
291        let before_callbacks = self.before_callbacks.clone();
292        let after_callbacks = self.after_callbacks.clone();
293        let agent_name = self.name.clone();
294        let run_ctx = super::skill_context::with_skill_injected_context(
295            ctx,
296            self.skills_index.as_ref(),
297            &self.skill_policy,
298            self.max_skill_chars,
299        );
300        let run_ctx = Arc::new(HistoryTrackingContext::new(run_ctx));
301
302        let s = stream! {
303            use futures::StreamExt;
304
305            // ===== BEFORE AGENT CALLBACKS =====
306            for callback in before_callbacks.as_ref() {
307                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
308                    Ok(Some(content)) => {
309                        let mut early_event = Event::new(run_ctx.invocation_id());
310                        early_event.author = agent_name.clone();
311                        early_event.llm_response.content = Some(content);
312                        yield Ok(early_event);
313
314                        for after_cb in after_callbacks.as_ref() {
315                            match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
316                                Ok(Some(after_content)) => {
317                                    let mut after_event = Event::new(run_ctx.invocation_id());
318                                    after_event.author = agent_name.clone();
319                                    after_event.llm_response.content = Some(after_content);
320                                    yield Ok(after_event);
321                                    return;
322                                }
323                                Ok(None) => continue,
324                                Err(e) => { yield Err(e); return; }
325                            }
326                        }
327                        return;
328                    }
329                    Ok(None) => continue,
330                    Err(e) => { yield Err(e); return; }
331                }
332            }
333
334            let mut remaining = max_iterations;
335
336            loop {
337                let mut should_exit = false;
338
339                for agent in &sub_agents {
340                    let mut stream = agent.run(run_ctx.clone() as Arc<dyn InvocationContext>).await?;
341
342                    while let Some(result) = stream.next().await {
343                        match result {
344                            Ok(event) => {
345                                run_ctx.apply_event(&event);
346                                if event.actions.escalate {
347                                    should_exit = true;
348                                }
349                                yield Ok(event);
350                            }
351                            Err(e) => {
352                                yield Err(e);
353                                return;
354                            }
355                        }
356                    }
357
358                    if should_exit {
359                        break;
360                    }
361                }
362
363                if should_exit {
364                    break;
365                }
366
367                remaining -= 1;
368                if remaining == 0 {
369                    break;
370                }
371            }
372
373            // ===== AFTER AGENT CALLBACKS =====
374            for callback in after_callbacks.as_ref() {
375                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
376                    Ok(Some(content)) => {
377                        let mut after_event = Event::new(run_ctx.invocation_id());
378                        after_event.author = agent_name.clone();
379                        after_event.llm_response.content = Some(content);
380                        yield Ok(after_event);
381                        break;
382                    }
383                    Ok(None) => continue,
384                    Err(e) => { yield Err(e); return; }
385                }
386            }
387        };
388
389        Ok(Box::pin(s))
390    }
391}