adk_agent/workflow/
loop_agent.rs1use adk_core::{
2 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
3 InvocationContext, ReadonlyContext, Result, Session, State,
4};
5use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
6use async_stream::stream;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
/// Iteration cap that [`LoopAgent::new`] installs; callers can replace it with
/// [`LoopAgent::with_max_iterations`].
pub const DEFAULT_LOOP_MAX_ITERATIONS: u32 = 1_000;
14
15pub struct LoopAgent {
17 name: String,
18 description: String,
19 sub_agents: Vec<Arc<dyn Agent>>,
20 max_iterations: u32,
21 skills_index: Option<Arc<SkillIndex>>,
22 skill_policy: SelectionPolicy,
23 max_skill_chars: usize,
24 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
25 after_callbacks: Arc<Vec<AfterAgentCallback>>,
26}
27
28impl LoopAgent {
29 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
30 Self {
31 name: name.into(),
32 description: String::new(),
33 sub_agents,
34 max_iterations: DEFAULT_LOOP_MAX_ITERATIONS,
35 skills_index: None,
36 skill_policy: SelectionPolicy::default(),
37 max_skill_chars: 2000,
38 before_callbacks: Arc::new(Vec::new()),
39 after_callbacks: Arc::new(Vec::new()),
40 }
41 }
42
43 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
44 self.description = desc.into();
45 self
46 }
47
48 pub fn with_max_iterations(mut self, max: u32) -> Self {
49 self.max_iterations = max;
50 self
51 }
52
53 pub fn with_skills(mut self, index: SkillIndex) -> Self {
54 self.skills_index = Some(Arc::new(index));
55 self
56 }
57
58 pub fn with_auto_skills(self) -> Result<Self> {
59 self.with_skills_from_root(".")
60 }
61
62 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
63 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
64 self.skills_index = Some(Arc::new(index));
65 Ok(self)
66 }
67
68 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
69 self.skill_policy = policy;
70 self
71 }
72
73 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
74 self.max_skill_chars = max_chars;
75 self
76 }
77
78 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
79 if let Some(callbacks) = Arc::get_mut(&mut self.before_callbacks) {
80 callbacks.push(callback);
81 }
82 self
83 }
84
85 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
86 if let Some(callbacks) = Arc::get_mut(&mut self.after_callbacks) {
87 callbacks.push(callback);
88 }
89 self
90 }
91}
92
93struct HistoryTrackingSession {
94 parent_ctx: Arc<dyn InvocationContext>,
95 history: Arc<RwLock<Vec<Content>>>,
96 state: StateTrackingState,
97}
98
99struct StateTrackingState {
100 values: RwLock<HashMap<String, serde_json::Value>>,
101}
102
103impl StateTrackingState {
104 fn new(parent_ctx: &Arc<dyn InvocationContext>) -> Self {
105 Self { values: RwLock::new(parent_ctx.session().state().all()) }
106 }
107
108 fn apply_delta(&self, delta: &HashMap<String, serde_json::Value>) {
109 if delta.is_empty() {
110 return;
111 }
112
113 let mut values = self.values.write().unwrap_or_else(|e| e.into_inner());
114 for (key, value) in delta {
115 values.insert(key.clone(), value.clone());
116 }
117 }
118}
119
120impl State for StateTrackingState {
121 fn get(&self, key: &str) -> Option<serde_json::Value> {
122 self.values.read().unwrap_or_else(|e| e.into_inner()).get(key).cloned()
123 }
124
125 fn set(&mut self, key: String, value: serde_json::Value) {
126 if let Err(msg) = adk_core::validate_state_key(&key) {
127 tracing::warn!(key = %key, "rejecting invalid state key: {msg}");
128 return;
129 }
130 self.values.write().unwrap_or_else(|e| e.into_inner()).insert(key, value);
131 }
132
133 fn all(&self) -> HashMap<String, serde_json::Value> {
134 self.values.read().unwrap_or_else(|e| e.into_inner()).clone()
135 }
136}
137
138impl HistoryTrackingSession {
139 fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
140 Self {
141 history: Arc::new(RwLock::new(parent_ctx.session().conversation_history())),
142 state: StateTrackingState::new(&parent_ctx),
143 parent_ctx,
144 }
145 }
146
147 fn apply_event(&self, event: &Event) {
148 if let Some(content) = &event.llm_response.content {
149 self.append_to_history(content.clone());
150 }
151 self.state.apply_delta(&event.actions.state_delta);
152 }
153}
154
155impl Session for HistoryTrackingSession {
156 fn id(&self) -> &str {
157 self.parent_ctx.session().id()
158 }
159
160 fn app_name(&self) -> &str {
161 self.parent_ctx.session().app_name()
162 }
163
164 fn user_id(&self) -> &str {
165 self.parent_ctx.session().user_id()
166 }
167
168 fn state(&self) -> &dyn State {
169 &self.state
170 }
171
172 fn conversation_history(&self) -> Vec<Content> {
173 self.history.read().unwrap_or_else(|e| e.into_inner()).clone()
174 }
175
176 fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
177 self.conversation_history()
178 }
179
180 fn append_to_history(&self, content: Content) {
181 self.history.write().unwrap_or_else(|e| e.into_inner()).push(content);
182 }
183}
184
185struct HistoryTrackingContext {
186 parent_ctx: Arc<dyn InvocationContext>,
187 session: HistoryTrackingSession,
188}
189
190impl HistoryTrackingContext {
191 fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
192 let session = HistoryTrackingSession::new(parent_ctx.clone());
193 Self { parent_ctx, session }
194 }
195
196 fn apply_event(&self, event: &Event) {
197 self.session.apply_event(event);
198 }
199}
200
201#[async_trait]
202impl adk_core::ReadonlyContext for HistoryTrackingContext {
203 fn invocation_id(&self) -> &str {
204 self.parent_ctx.invocation_id()
205 }
206
207 fn agent_name(&self) -> &str {
208 self.parent_ctx.agent_name()
209 }
210
211 fn user_id(&self) -> &str {
212 self.parent_ctx.user_id()
213 }
214
215 fn app_name(&self) -> &str {
216 self.parent_ctx.app_name()
217 }
218
219 fn session_id(&self) -> &str {
220 self.parent_ctx.session_id()
221 }
222
223 fn branch(&self) -> &str {
224 self.parent_ctx.branch()
225 }
226
227 fn user_content(&self) -> &Content {
228 self.parent_ctx.user_content()
229 }
230}
231
232#[async_trait]
233impl CallbackContext for HistoryTrackingContext {
234 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
235 self.parent_ctx.artifacts()
236 }
237}
238
239#[async_trait]
240impl InvocationContext for HistoryTrackingContext {
241 fn agent(&self) -> Arc<dyn Agent> {
242 self.parent_ctx.agent()
243 }
244
245 fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
246 self.parent_ctx.memory()
247 }
248
249 fn session(&self) -> &dyn Session {
250 &self.session
251 }
252
253 fn run_config(&self) -> &adk_core::RunConfig {
254 self.parent_ctx.run_config()
255 }
256
257 fn end_invocation(&self) {
258 self.parent_ctx.end_invocation();
259 }
260
261 fn ended(&self) -> bool {
262 self.parent_ctx.ended()
263 }
264
265 fn user_scopes(&self) -> Vec<String> {
266 self.parent_ctx.user_scopes()
267 }
268
269 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
270 self.parent_ctx.request_metadata()
271 }
272}
273
274#[async_trait]
275impl Agent for LoopAgent {
276 fn name(&self) -> &str {
277 &self.name
278 }
279
280 fn description(&self) -> &str {
281 &self.description
282 }
283
284 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
285 &self.sub_agents
286 }
287
288 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
289 let sub_agents = self.sub_agents.clone();
290 let max_iterations = self.max_iterations;
291 let before_callbacks = self.before_callbacks.clone();
292 let after_callbacks = self.after_callbacks.clone();
293 let agent_name = self.name.clone();
294 let run_ctx = super::skill_context::with_skill_injected_context(
295 ctx,
296 self.skills_index.as_ref(),
297 &self.skill_policy,
298 self.max_skill_chars,
299 );
300 let run_ctx = Arc::new(HistoryTrackingContext::new(run_ctx));
301
302 let s = stream! {
303 use futures::StreamExt;
304
305 for callback in before_callbacks.as_ref() {
307 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
308 Ok(Some(content)) => {
309 let mut early_event = Event::new(run_ctx.invocation_id());
310 early_event.author = agent_name.clone();
311 early_event.llm_response.content = Some(content);
312 yield Ok(early_event);
313
314 for after_cb in after_callbacks.as_ref() {
315 match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
316 Ok(Some(after_content)) => {
317 let mut after_event = Event::new(run_ctx.invocation_id());
318 after_event.author = agent_name.clone();
319 after_event.llm_response.content = Some(after_content);
320 yield Ok(after_event);
321 return;
322 }
323 Ok(None) => continue,
324 Err(e) => { yield Err(e); return; }
325 }
326 }
327 return;
328 }
329 Ok(None) => continue,
330 Err(e) => { yield Err(e); return; }
331 }
332 }
333
334 let mut remaining = max_iterations;
335
336 loop {
337 let mut should_exit = false;
338
339 for agent in &sub_agents {
340 let mut stream = agent.run(run_ctx.clone() as Arc<dyn InvocationContext>).await?;
341
342 while let Some(result) = stream.next().await {
343 match result {
344 Ok(event) => {
345 run_ctx.apply_event(&event);
346 if event.actions.escalate {
347 should_exit = true;
348 }
349 yield Ok(event);
350 }
351 Err(e) => {
352 yield Err(e);
353 return;
354 }
355 }
356 }
357
358 if should_exit {
359 break;
360 }
361 }
362
363 if should_exit {
364 break;
365 }
366
367 remaining -= 1;
368 if remaining == 0 {
369 break;
370 }
371 }
372
373 for callback in after_callbacks.as_ref() {
375 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
376 Ok(Some(content)) => {
377 let mut after_event = Event::new(run_ctx.invocation_id());
378 after_event.author = agent_name.clone();
379 after_event.llm_response.content = Some(content);
380 yield Ok(after_event);
381 break;
382 }
383 Ok(None) => continue,
384 Err(e) => { yield Err(e); return; }
385 }
386 }
387 };
388
389 Ok(Box::pin(s))
390 }
391}