Skip to main content

adk_agent/workflow/
loop_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
3    InvocationContext, ReadonlyContext, Result, Session, State,
4};
5use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
6use async_stream::stream;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
/// Iteration cap a [`LoopAgent`] uses unless [`LoopAgent::with_max_iterations`]
/// overrides it.
///
/// Bounding the loop keeps a sub-agent that never escalates from consuming
/// unbounded resources.
pub const DEFAULT_LOOP_MAX_ITERATIONS: u32 = 1_000;
14
15/// Loop agent executes sub-agents repeatedly for N iterations or until escalation
16pub struct LoopAgent {
17    name: String,
18    description: String,
19    sub_agents: Vec<Arc<dyn Agent>>,
20    max_iterations: u32,
21    skills_index: Option<Arc<SkillIndex>>,
22    skill_policy: SelectionPolicy,
23    max_skill_chars: usize,
24    before_callbacks: Arc<Vec<BeforeAgentCallback>>,
25    after_callbacks: Arc<Vec<AfterAgentCallback>>,
26}
27
28impl LoopAgent {
29    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
30        Self {
31            name: name.into(),
32            description: String::new(),
33            sub_agents,
34            max_iterations: DEFAULT_LOOP_MAX_ITERATIONS,
35            skills_index: None,
36            skill_policy: SelectionPolicy::default(),
37            max_skill_chars: 2000,
38            before_callbacks: Arc::new(Vec::new()),
39            after_callbacks: Arc::new(Vec::new()),
40        }
41    }
42
43    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
44        self.description = desc.into();
45        self
46    }
47
48    pub fn with_max_iterations(mut self, max: u32) -> Self {
49        self.max_iterations = max;
50        self
51    }
52
53    pub fn with_skills(mut self, index: SkillIndex) -> Self {
54        self.skills_index = Some(Arc::new(index));
55        self
56    }
57
58    pub fn with_auto_skills(self) -> Result<Self> {
59        self.with_skills_from_root(".")
60    }
61
62    pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
63        let index = load_skill_index(root).map_err(|e| adk_core::AdkError::Agent(e.to_string()))?;
64        self.skills_index = Some(Arc::new(index));
65        Ok(self)
66    }
67
68    pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
69        self.skill_policy = policy;
70        self
71    }
72
73    pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
74        self.max_skill_chars = max_chars;
75        self
76    }
77
78    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
79        Arc::get_mut(&mut self.before_callbacks)
80            .expect("before_callbacks not yet shared")
81            .push(callback);
82        self
83    }
84
85    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
86        Arc::get_mut(&mut self.after_callbacks)
87            .expect("after_callbacks not yet shared")
88            .push(callback);
89        self
90    }
91}
92
93struct HistoryTrackingSession {
94    parent_ctx: Arc<dyn InvocationContext>,
95    history: Arc<RwLock<Vec<Content>>>,
96    state: StateTrackingState,
97}
98
99struct StateTrackingState {
100    values: RwLock<HashMap<String, serde_json::Value>>,
101}
102
103impl StateTrackingState {
104    fn new(parent_ctx: &Arc<dyn InvocationContext>) -> Self {
105        Self { values: RwLock::new(parent_ctx.session().state().all()) }
106    }
107
108    fn apply_delta(&self, delta: &HashMap<String, serde_json::Value>) {
109        if delta.is_empty() {
110            return;
111        }
112
113        let mut values = self.values.write().unwrap_or_else(|e| e.into_inner());
114        for (key, value) in delta {
115            values.insert(key.clone(), value.clone());
116        }
117    }
118}
119
120impl State for StateTrackingState {
121    fn get(&self, key: &str) -> Option<serde_json::Value> {
122        self.values.read().unwrap_or_else(|e| e.into_inner()).get(key).cloned()
123    }
124
125    fn set(&mut self, key: String, value: serde_json::Value) {
126        self.values.write().unwrap_or_else(|e| e.into_inner()).insert(key, value);
127    }
128
129    fn all(&self) -> HashMap<String, serde_json::Value> {
130        self.values.read().unwrap_or_else(|e| e.into_inner()).clone()
131    }
132}
133
134impl HistoryTrackingSession {
135    fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
136        Self {
137            history: Arc::new(RwLock::new(parent_ctx.session().conversation_history())),
138            state: StateTrackingState::new(&parent_ctx),
139            parent_ctx,
140        }
141    }
142
143    fn apply_event(&self, event: &Event) {
144        if let Some(content) = &event.llm_response.content {
145            self.append_to_history(content.clone());
146        }
147        self.state.apply_delta(&event.actions.state_delta);
148    }
149}
150
151impl Session for HistoryTrackingSession {
152    fn id(&self) -> &str {
153        self.parent_ctx.session().id()
154    }
155
156    fn app_name(&self) -> &str {
157        self.parent_ctx.session().app_name()
158    }
159
160    fn user_id(&self) -> &str {
161        self.parent_ctx.session().user_id()
162    }
163
164    fn state(&self) -> &dyn State {
165        &self.state
166    }
167
168    fn conversation_history(&self) -> Vec<Content> {
169        self.history.read().unwrap_or_else(|e| e.into_inner()).clone()
170    }
171
172    fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
173        self.conversation_history()
174    }
175
176    fn append_to_history(&self, content: Content) {
177        self.history.write().unwrap_or_else(|e| e.into_inner()).push(content);
178    }
179}
180
181struct HistoryTrackingContext {
182    parent_ctx: Arc<dyn InvocationContext>,
183    session: HistoryTrackingSession,
184}
185
186impl HistoryTrackingContext {
187    fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
188        let session = HistoryTrackingSession::new(parent_ctx.clone());
189        Self { parent_ctx, session }
190    }
191
192    fn apply_event(&self, event: &Event) {
193        self.session.apply_event(event);
194    }
195}
196
197#[async_trait]
198impl adk_core::ReadonlyContext for HistoryTrackingContext {
199    fn invocation_id(&self) -> &str {
200        self.parent_ctx.invocation_id()
201    }
202
203    fn agent_name(&self) -> &str {
204        self.parent_ctx.agent_name()
205    }
206
207    fn user_id(&self) -> &str {
208        self.parent_ctx.user_id()
209    }
210
211    fn app_name(&self) -> &str {
212        self.parent_ctx.app_name()
213    }
214
215    fn session_id(&self) -> &str {
216        self.parent_ctx.session_id()
217    }
218
219    fn branch(&self) -> &str {
220        self.parent_ctx.branch()
221    }
222
223    fn user_content(&self) -> &Content {
224        self.parent_ctx.user_content()
225    }
226}
227
228#[async_trait]
229impl CallbackContext for HistoryTrackingContext {
230    fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
231        self.parent_ctx.artifacts()
232    }
233}
234
235#[async_trait]
236impl InvocationContext for HistoryTrackingContext {
237    fn agent(&self) -> Arc<dyn Agent> {
238        self.parent_ctx.agent()
239    }
240
241    fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
242        self.parent_ctx.memory()
243    }
244
245    fn session(&self) -> &dyn Session {
246        &self.session
247    }
248
249    fn run_config(&self) -> &adk_core::RunConfig {
250        self.parent_ctx.run_config()
251    }
252
253    fn end_invocation(&self) {
254        self.parent_ctx.end_invocation();
255    }
256
257    fn ended(&self) -> bool {
258        self.parent_ctx.ended()
259    }
260
261    fn user_scopes(&self) -> Vec<String> {
262        self.parent_ctx.user_scopes()
263    }
264
265    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
266        self.parent_ctx.request_metadata()
267    }
268}
269
270#[async_trait]
271impl Agent for LoopAgent {
272    fn name(&self) -> &str {
273        &self.name
274    }
275
276    fn description(&self) -> &str {
277        &self.description
278    }
279
280    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
281        &self.sub_agents
282    }
283
284    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
285        let sub_agents = self.sub_agents.clone();
286        let max_iterations = self.max_iterations;
287        let before_callbacks = self.before_callbacks.clone();
288        let after_callbacks = self.after_callbacks.clone();
289        let agent_name = self.name.clone();
290        let run_ctx = super::skill_context::with_skill_injected_context(
291            ctx,
292            self.skills_index.as_ref(),
293            &self.skill_policy,
294            self.max_skill_chars,
295        );
296        let run_ctx = Arc::new(HistoryTrackingContext::new(run_ctx));
297
298        let s = stream! {
299            use futures::StreamExt;
300
301            // ===== BEFORE AGENT CALLBACKS =====
302            for callback in before_callbacks.as_ref() {
303                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
304                    Ok(Some(content)) => {
305                        let mut early_event = Event::new(run_ctx.invocation_id());
306                        early_event.author = agent_name.clone();
307                        early_event.llm_response.content = Some(content);
308                        yield Ok(early_event);
309
310                        for after_cb in after_callbacks.as_ref() {
311                            match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
312                                Ok(Some(after_content)) => {
313                                    let mut after_event = Event::new(run_ctx.invocation_id());
314                                    after_event.author = agent_name.clone();
315                                    after_event.llm_response.content = Some(after_content);
316                                    yield Ok(after_event);
317                                    return;
318                                }
319                                Ok(None) => continue,
320                                Err(e) => { yield Err(e); return; }
321                            }
322                        }
323                        return;
324                    }
325                    Ok(None) => continue,
326                    Err(e) => { yield Err(e); return; }
327                }
328            }
329
330            let mut remaining = max_iterations;
331
332            loop {
333                let mut should_exit = false;
334
335                for agent in &sub_agents {
336                    let mut stream = agent.run(run_ctx.clone() as Arc<dyn InvocationContext>).await?;
337
338                    while let Some(result) = stream.next().await {
339                        match result {
340                            Ok(event) => {
341                                run_ctx.apply_event(&event);
342                                if event.actions.escalate {
343                                    should_exit = true;
344                                }
345                                yield Ok(event);
346                            }
347                            Err(e) => {
348                                yield Err(e);
349                                return;
350                            }
351                        }
352                    }
353
354                    if should_exit {
355                        break;
356                    }
357                }
358
359                if should_exit {
360                    break;
361                }
362
363                remaining -= 1;
364                if remaining == 0 {
365                    break;
366                }
367            }
368
369            // ===== AFTER AGENT CALLBACKS =====
370            for callback in after_callbacks.as_ref() {
371                match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
372                    Ok(Some(content)) => {
373                        let mut after_event = Event::new(run_ctx.invocation_id());
374                        after_event.author = agent_name.clone();
375                        after_event.llm_response.content = Some(content);
376                        yield Ok(after_event);
377                        break;
378                    }
379                    Ok(None) => continue,
380                    Err(e) => { yield Err(e); return; }
381                }
382            }
383        };
384
385        Ok(Box::pin(s))
386    }
387}