rs_adk/agents/
loop_agent.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use crate::agent::Agent;
8use crate::context::InvocationContext;
9use crate::error::AgentError;
10
11pub struct LoopAgent {
17 name: String,
18 sub_agents: Vec<Arc<dyn Agent>>,
19 max_iterations: u32,
20}
21
22impl LoopAgent {
23 pub fn new(
26 name: impl Into<String>,
27 sub_agents: Vec<Arc<dyn Agent>>,
28 max_iterations: u32,
29 ) -> Self {
30 Self {
31 name: name.into(),
32 sub_agents,
33 max_iterations,
34 }
35 }
36}
37
38#[async_trait]
39impl Agent for LoopAgent {
40 fn name(&self) -> &str {
41 &self.name
42 }
43
44 async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
45 for _iter in 0..self.max_iterations {
46 for sub in &self.sub_agents {
47 match sub.run_live(ctx).await {
48 Ok(()) => {}
49 Err(AgentError::TransferRequested(ref target)) if target == "__escalate" => {
50 return Ok(());
51 }
52 Err(e) => return Err(e),
53 }
54 }
55 }
56 Ok(())
57 }
58
59 fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
60 self.sub_agents.clone()
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_session::{AgentSession, NoOpSessionWriter};
    use crate::context::InvocationContext;
    use crate::error::AgentError;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};
    use tokio::sync::broadcast;

    /// Builds a minimal invocation context backed by a no-op session writer.
    fn test_ctx() -> InvocationContext {
        let (event_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> = Arc::new(NoOpSessionWriter);
        let agent_session = AgentSession::from_writer(writer, event_tx);
        InvocationContext::new(agent_session)
    }

    /// Sub-agent that increments a shared counter every time it runs.
    struct CounterAgent {
        agent_name: String,
        counter: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Agent for CounterAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Sub-agent that counts its runs and requests `__escalate` once the
    /// count reaches `escalate_at`.
    struct EscalateAfterAgent {
        agent_name: String,
        counter: Arc<AtomicU32>,
        escalate_at: u32,
    }

    #[async_trait]
    impl Agent for EscalateAfterAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let count = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            if count >= self.escalate_at {
                Err(AgentError::TransferRequested("__escalate".to_string()))
            } else {
                Ok(())
            }
        }
    }

    /// Sub-agent that counts its runs and requests a transfer to `target`
    /// once the count reaches `transfer_at`.
    struct TransferAfterAgent {
        agent_name: String,
        counter: Arc<AtomicU32>,
        transfer_at: u32,
        target: String,
    }

    #[async_trait]
    impl Agent for TransferAfterAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let count = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            if count >= self.transfer_at {
                Err(AgentError::TransferRequested(self.target.clone()))
            } else {
                Ok(())
            }
        }
    }

    /// Sub-agent that appends its own name to a shared log each time it runs,
    /// so tests can check execution order across sub-agents and passes.
    struct RecordingAgent {
        agent_name: String,
        log: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl Agent for RecordingAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            // The guard is dropped at the end of this statement; nothing is
            // awaited while it is held.
            self.log.lock().unwrap().push(self.agent_name.clone());
            Ok(())
        }
    }

    #[tokio::test]
    async fn loop_runs_max_iterations() {
        let counter = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(CounterAgent {
            agent_name: "counter".into(),
            counter: counter.clone(),
        })];

        let loop_agent = LoopAgent::new("loop", agents, 5);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn loop_escalate_breaks_early() {
        let counter = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(EscalateAfterAgent {
            agent_name: "escalator".into(),
            counter: counter.clone(),
            escalate_at: 3,
        })];

        let loop_agent = LoopAgent::new("loop", agents, 10);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        // Escalation on the third run stops the loop well before 10 passes.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn loop_propagates_non_escalate_transfer() {
        let counter = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(TransferAfterAgent {
            agent_name: "transferer".into(),
            counter: counter.clone(),
            transfer_at: 2,
            target: "other_agent".into(),
        })];

        let loop_agent = LoopAgent::new("loop", agents, 10);
        let mut ctx = test_ctx();
        let result = loop_agent.run_live(&mut ctx).await;

        match result {
            Err(AgentError::TransferRequested(target)) => {
                assert_eq!(target, "other_agent");
            }
            other => panic!("expected TransferRequested, got {:?}", other),
        }

        // The loop stops at the transfer instead of continuing.
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn loop_zero_iterations() {
        let counter = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(CounterAgent {
            agent_name: "counter".into(),
            counter: counter.clone(),
        })];

        let loop_agent = LoopAgent::new("loop", agents, 0);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn loop_runs_sub_agents_in_order_each_iteration() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(RecordingAgent {
                agent_name: "a".into(),
                log: log.clone(),
            }),
            Arc::new(RecordingAgent {
                agent_name: "b".into(),
                log: log.clone(),
            }),
        ];

        let loop_agent = LoopAgent::new("loop", agents, 2);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        assert_eq!(*log.lock().unwrap(), vec!["a", "b", "a", "b"]);
    }

    #[tokio::test]
    async fn loop_escalate_skips_remaining_siblings() {
        let escalator_runs = Arc::new(AtomicU32::new(0));
        let sibling_runs = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(EscalateAfterAgent {
                agent_name: "escalator".into(),
                counter: escalator_runs.clone(),
                escalate_at: 2,
            }),
            Arc::new(CounterAgent {
                agent_name: "sibling".into(),
                counter: sibling_runs.clone(),
            }),
        ];

        let loop_agent = LoopAgent::new("loop", agents, 10);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        // Pass 1 runs both; in pass 2 the escalator escalates before the
        // sibling gets a turn.
        assert_eq!(escalator_runs.load(Ordering::SeqCst), 2);
        assert_eq!(sibling_runs.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn loop_without_sub_agents_is_ok() {
        let loop_agent = LoopAgent::new("loop", Vec::new(), 3);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        assert!(loop_agent.sub_agents().is_empty());
    }

    #[test]
    fn loop_sub_agents_returns_children() {
        let counter = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(CounterAgent {
            agent_name: "child".into(),
            counter,
        })];

        let loop_agent = LoopAgent::new("loop", agents, 5);
        assert_eq!(loop_agent.sub_agents().len(), 1);
        assert_eq!(loop_agent.sub_agents()[0].name(), "child");
    }
}