adk_agent/workflow/
loop_agent.rs1use adk_core::{
2 AfterAgentCallback, Agent, BeforeAgentCallback, CallbackContext, Content, Event, EventStream,
3 InvocationContext, ReadonlyContext, Result, Session, State,
4};
5use adk_skill::{SelectionPolicy, SkillIndex, load_skill_index};
6use async_stream::stream;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
/// Iteration cap that [`LoopAgent::new`] installs; override it with
/// [`LoopAgent::with_max_iterations`].
pub const DEFAULT_LOOP_MAX_ITERATIONS: u32 = 1_000;

15pub struct LoopAgent {
17 name: String,
18 description: String,
19 sub_agents: Vec<Arc<dyn Agent>>,
20 max_iterations: u32,
21 skills_index: Option<Arc<SkillIndex>>,
22 skill_policy: SelectionPolicy,
23 max_skill_chars: usize,
24 before_callbacks: Arc<Vec<BeforeAgentCallback>>,
25 after_callbacks: Arc<Vec<AfterAgentCallback>>,
26}
27
28impl LoopAgent {
29 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
30 Self {
31 name: name.into(),
32 description: String::new(),
33 sub_agents,
34 max_iterations: DEFAULT_LOOP_MAX_ITERATIONS,
35 skills_index: None,
36 skill_policy: SelectionPolicy::default(),
37 max_skill_chars: 2000,
38 before_callbacks: Arc::new(Vec::new()),
39 after_callbacks: Arc::new(Vec::new()),
40 }
41 }
42
43 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
44 self.description = desc.into();
45 self
46 }
47
48 pub fn with_max_iterations(mut self, max: u32) -> Self {
49 self.max_iterations = max;
50 self
51 }
52
53 pub fn with_skills(mut self, index: SkillIndex) -> Self {
54 self.skills_index = Some(Arc::new(index));
55 self
56 }
57
58 pub fn with_auto_skills(self) -> Result<Self> {
59 self.with_skills_from_root(".")
60 }
61
62 pub fn with_skills_from_root(mut self, root: impl AsRef<std::path::Path>) -> Result<Self> {
63 let index = load_skill_index(root).map_err(|e| adk_core::AdkError::Agent(e.to_string()))?;
64 self.skills_index = Some(Arc::new(index));
65 Ok(self)
66 }
67
68 pub fn with_skill_policy(mut self, policy: SelectionPolicy) -> Self {
69 self.skill_policy = policy;
70 self
71 }
72
73 pub fn with_skill_budget(mut self, max_chars: usize) -> Self {
74 self.max_skill_chars = max_chars;
75 self
76 }
77
78 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
79 Arc::get_mut(&mut self.before_callbacks)
80 .expect("before_callbacks not yet shared")
81 .push(callback);
82 self
83 }
84
85 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
86 Arc::get_mut(&mut self.after_callbacks)
87 .expect("after_callbacks not yet shared")
88 .push(callback);
89 self
90 }
91}
92
93struct HistoryTrackingSession {
94 parent_ctx: Arc<dyn InvocationContext>,
95 history: Arc<RwLock<Vec<Content>>>,
96 state: StateTrackingState,
97}
98
99struct StateTrackingState {
100 values: RwLock<HashMap<String, serde_json::Value>>,
101}
102
103impl StateTrackingState {
104 fn new(parent_ctx: &Arc<dyn InvocationContext>) -> Self {
105 Self { values: RwLock::new(parent_ctx.session().state().all()) }
106 }
107
108 fn apply_delta(&self, delta: &HashMap<String, serde_json::Value>) {
109 if delta.is_empty() {
110 return;
111 }
112
113 let mut values = self.values.write().unwrap_or_else(|e| e.into_inner());
114 for (key, value) in delta {
115 values.insert(key.clone(), value.clone());
116 }
117 }
118}
119
120impl State for StateTrackingState {
121 fn get(&self, key: &str) -> Option<serde_json::Value> {
122 self.values.read().unwrap_or_else(|e| e.into_inner()).get(key).cloned()
123 }
124
125 fn set(&mut self, key: String, value: serde_json::Value) {
126 self.values.write().unwrap_or_else(|e| e.into_inner()).insert(key, value);
127 }
128
129 fn all(&self) -> HashMap<String, serde_json::Value> {
130 self.values.read().unwrap_or_else(|e| e.into_inner()).clone()
131 }
132}
133
134impl HistoryTrackingSession {
135 fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
136 Self {
137 history: Arc::new(RwLock::new(parent_ctx.session().conversation_history())),
138 state: StateTrackingState::new(&parent_ctx),
139 parent_ctx,
140 }
141 }
142
143 fn apply_event(&self, event: &Event) {
144 if let Some(content) = &event.llm_response.content {
145 self.append_to_history(content.clone());
146 }
147 self.state.apply_delta(&event.actions.state_delta);
148 }
149}
150
151impl Session for HistoryTrackingSession {
152 fn id(&self) -> &str {
153 self.parent_ctx.session().id()
154 }
155
156 fn app_name(&self) -> &str {
157 self.parent_ctx.session().app_name()
158 }
159
160 fn user_id(&self) -> &str {
161 self.parent_ctx.session().user_id()
162 }
163
164 fn state(&self) -> &dyn State {
165 &self.state
166 }
167
168 fn conversation_history(&self) -> Vec<Content> {
169 self.history.read().unwrap_or_else(|e| e.into_inner()).clone()
170 }
171
172 fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
173 self.conversation_history()
174 }
175
176 fn append_to_history(&self, content: Content) {
177 self.history.write().unwrap_or_else(|e| e.into_inner()).push(content);
178 }
179}
180
181struct HistoryTrackingContext {
182 parent_ctx: Arc<dyn InvocationContext>,
183 session: HistoryTrackingSession,
184}
185
186impl HistoryTrackingContext {
187 fn new(parent_ctx: Arc<dyn InvocationContext>) -> Self {
188 let session = HistoryTrackingSession::new(parent_ctx.clone());
189 Self { parent_ctx, session }
190 }
191
192 fn apply_event(&self, event: &Event) {
193 self.session.apply_event(event);
194 }
195}
196
197#[async_trait]
198impl adk_core::ReadonlyContext for HistoryTrackingContext {
199 fn invocation_id(&self) -> &str {
200 self.parent_ctx.invocation_id()
201 }
202
203 fn agent_name(&self) -> &str {
204 self.parent_ctx.agent_name()
205 }
206
207 fn user_id(&self) -> &str {
208 self.parent_ctx.user_id()
209 }
210
211 fn app_name(&self) -> &str {
212 self.parent_ctx.app_name()
213 }
214
215 fn session_id(&self) -> &str {
216 self.parent_ctx.session_id()
217 }
218
219 fn branch(&self) -> &str {
220 self.parent_ctx.branch()
221 }
222
223 fn user_content(&self) -> &Content {
224 self.parent_ctx.user_content()
225 }
226}
227
228#[async_trait]
229impl CallbackContext for HistoryTrackingContext {
230 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
231 self.parent_ctx.artifacts()
232 }
233}
234
235#[async_trait]
236impl InvocationContext for HistoryTrackingContext {
237 fn agent(&self) -> Arc<dyn Agent> {
238 self.parent_ctx.agent()
239 }
240
241 fn memory(&self) -> Option<Arc<dyn adk_core::Memory>> {
242 self.parent_ctx.memory()
243 }
244
245 fn session(&self) -> &dyn Session {
246 &self.session
247 }
248
249 fn run_config(&self) -> &adk_core::RunConfig {
250 self.parent_ctx.run_config()
251 }
252
253 fn end_invocation(&self) {
254 self.parent_ctx.end_invocation();
255 }
256
257 fn ended(&self) -> bool {
258 self.parent_ctx.ended()
259 }
260
261 fn user_scopes(&self) -> Vec<String> {
262 self.parent_ctx.user_scopes()
263 }
264
265 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
266 self.parent_ctx.request_metadata()
267 }
268}
269
270#[async_trait]
271impl Agent for LoopAgent {
272 fn name(&self) -> &str {
273 &self.name
274 }
275
276 fn description(&self) -> &str {
277 &self.description
278 }
279
280 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
281 &self.sub_agents
282 }
283
284 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
285 let sub_agents = self.sub_agents.clone();
286 let max_iterations = self.max_iterations;
287 let before_callbacks = self.before_callbacks.clone();
288 let after_callbacks = self.after_callbacks.clone();
289 let agent_name = self.name.clone();
290 let run_ctx = super::skill_context::with_skill_injected_context(
291 ctx,
292 self.skills_index.as_ref(),
293 &self.skill_policy,
294 self.max_skill_chars,
295 );
296 let run_ctx = Arc::new(HistoryTrackingContext::new(run_ctx));
297
298 let s = stream! {
299 use futures::StreamExt;
300
301 for callback in before_callbacks.as_ref() {
303 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
304 Ok(Some(content)) => {
305 let mut early_event = Event::new(run_ctx.invocation_id());
306 early_event.author = agent_name.clone();
307 early_event.llm_response.content = Some(content);
308 yield Ok(early_event);
309
310 for after_cb in after_callbacks.as_ref() {
311 match after_cb(run_ctx.clone() as Arc<dyn CallbackContext>).await {
312 Ok(Some(after_content)) => {
313 let mut after_event = Event::new(run_ctx.invocation_id());
314 after_event.author = agent_name.clone();
315 after_event.llm_response.content = Some(after_content);
316 yield Ok(after_event);
317 return;
318 }
319 Ok(None) => continue,
320 Err(e) => { yield Err(e); return; }
321 }
322 }
323 return;
324 }
325 Ok(None) => continue,
326 Err(e) => { yield Err(e); return; }
327 }
328 }
329
330 let mut remaining = max_iterations;
331
332 loop {
333 let mut should_exit = false;
334
335 for agent in &sub_agents {
336 let mut stream = agent.run(run_ctx.clone() as Arc<dyn InvocationContext>).await?;
337
338 while let Some(result) = stream.next().await {
339 match result {
340 Ok(event) => {
341 run_ctx.apply_event(&event);
342 if event.actions.escalate {
343 should_exit = true;
344 }
345 yield Ok(event);
346 }
347 Err(e) => {
348 yield Err(e);
349 return;
350 }
351 }
352 }
353
354 if should_exit {
355 break;
356 }
357 }
358
359 if should_exit {
360 break;
361 }
362
363 remaining -= 1;
364 if remaining == 0 {
365 break;
366 }
367 }
368
369 for callback in after_callbacks.as_ref() {
371 match callback(run_ctx.clone() as Arc<dyn CallbackContext>).await {
372 Ok(Some(content)) => {
373 let mut after_event = Event::new(run_ctx.invocation_id());
374 after_event.author = agent_name.clone();
375 after_event.llm_response.content = Some(content);
376 yield Ok(after_event);
377 break;
378 }
379 Ok(None) => continue,
380 Err(e) => { yield Err(e); return; }
381 }
382 }
383 };
384
385 Ok(Box::pin(s))
386 }
387}