adk_agent/workflow/
loop_agent.rs1use adk_core::{
2 AfterAgentCallback, Agent, BeforeAgentCallback, EventStream, InvocationContext, Result,
3};
4use async_stream::stream;
5use async_trait::async_trait;
6use std::sync::Arc;
7
8pub struct LoopAgent {
10 name: String,
11 description: String,
12 sub_agents: Vec<Arc<dyn Agent>>,
13 max_iterations: Option<u32>,
14 before_callbacks: Vec<BeforeAgentCallback>,
15 after_callbacks: Vec<AfterAgentCallback>,
16}
17
18impl LoopAgent {
19 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
20 Self {
21 name: name.into(),
22 description: String::new(),
23 sub_agents,
24 max_iterations: None,
25 before_callbacks: Vec::new(),
26 after_callbacks: Vec::new(),
27 }
28 }
29
30 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
31 self.description = desc.into();
32 self
33 }
34
35 pub fn with_max_iterations(mut self, max: u32) -> Self {
36 self.max_iterations = Some(max);
37 self
38 }
39
40 pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
41 self.before_callbacks.push(callback);
42 self
43 }
44
45 pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
46 self.after_callbacks.push(callback);
47 self
48 }
49}
50
51#[async_trait]
52impl Agent for LoopAgent {
53 fn name(&self) -> &str {
54 &self.name
55 }
56
57 fn description(&self) -> &str {
58 &self.description
59 }
60
61 fn sub_agents(&self) -> &[Arc<dyn Agent>] {
62 &self.sub_agents
63 }
64
65 async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
66 let sub_agents = self.sub_agents.clone();
67 let max_iterations = self.max_iterations;
68
69 let s = stream! {
70 use futures::StreamExt;
71
72 let mut count = max_iterations;
73
74 loop {
75 let mut should_exit = false;
76
77 for agent in &sub_agents {
78 let mut stream = agent.run(ctx.clone()).await?;
79
80 while let Some(result) = stream.next().await {
81 match result {
82 Ok(event) => {
83 if event.actions.escalate {
84 should_exit = true;
85 }
86 yield Ok(event);
87 }
88 Err(e) => {
89 yield Err(e);
90 return;
91 }
92 }
93 }
94
95 if should_exit {
96 return;
97 }
98 }
99
100 if let Some(ref mut c) = count {
101 *c -= 1;
102 if *c == 0 {
103 return;
104 }
105 }
106 }
107 };
108
109 Ok(Box::pin(s))
110 }
111}