// rs_adk/agents/loop_agent.rs

//! LoopAgent — runs sub-agents repeatedly until `max_iterations` is reached or a sub-agent escalates.

3use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use crate::agent::Agent;
8use crate::context::InvocationContext;
9use crate::error::AgentError;
10
11/// Runs sub-agents repeatedly until `max_iterations` is reached or escalation.
12///
13/// Each iteration runs all sub-agents sequentially. To break out of the loop
14/// early, a sub-agent can return `TransferRequested("__escalate")`. Other
15/// transfer requests are propagated as-is, stopping the loop with an error.
16pub struct LoopAgent {
17    name: String,
18    sub_agents: Vec<Arc<dyn Agent>>,
19    max_iterations: u32,
20}
21
22impl LoopAgent {
23    /// Create a new loop agent with the given name, sub-agents, and maximum
24    /// number of iterations.
25    pub fn new(
26        name: impl Into<String>,
27        sub_agents: Vec<Arc<dyn Agent>>,
28        max_iterations: u32,
29    ) -> Self {
30        Self {
31            name: name.into(),
32            sub_agents,
33            max_iterations,
34        }
35    }
36}
37
38#[async_trait]
39impl Agent for LoopAgent {
40    fn name(&self) -> &str {
41        &self.name
42    }
43
44    async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
45        for _iter in 0..self.max_iterations {
46            for sub in &self.sub_agents {
47                match sub.run_live(ctx).await {
48                    Ok(()) => {}
49                    Err(AgentError::TransferRequested(ref target)) if target == "__escalate" => {
50                        return Ok(());
51                    }
52                    Err(e) => return Err(e),
53                }
54            }
55        }
56        Ok(())
57    }
58
59    fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
60        self.sub_agents.clone()
61    }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_session::{AgentSession, NoOpSessionWriter};
    use crate::context::InvocationContext;
    use crate::error::AgentError;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};
    use tokio::sync::broadcast;

    /// Helper: create a test InvocationContext with a no-op session.
    fn test_ctx() -> InvocationContext {
        let (event_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> = Arc::new(NoOpSessionWriter);
        let agent_session = AgentSession::from_writer(writer, event_tx);
        InvocationContext::new(agent_session)
    }

    /// A test agent that increments a counter each time it runs.
    struct CounterAgent {
        agent_name: String,
        counter: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Agent for CounterAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// A test agent that escalates once it has been invoked `escalate_at` times.
    struct EscalateAfterAgent {
        agent_name: String,
        counter: Arc<AtomicU32>,
        escalate_at: u32,
    }

    #[async_trait]
    impl Agent for EscalateAfterAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let count = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            if count >= self.escalate_at {
                // Literal kept on purpose: pins the wire value of the sentinel.
                Err(AgentError::TransferRequested("__escalate".to_string()))
            } else {
                Ok(())
            }
        }
    }

    /// A test agent that returns a non-escalate transfer request after N invocations.
    struct TransferAfterAgent {
        agent_name: String,
        counter: Arc<AtomicU32>,
        transfer_at: u32,
        target: String,
    }

    #[async_trait]
    impl Agent for TransferAfterAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let count = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            if count >= self.transfer_at {
                Err(AgentError::TransferRequested(self.target.clone()))
            } else {
                Ok(())
            }
        }
    }

    /// A test agent that appends its own name to a shared log each time it runs.
    struct RecordingAgent {
        agent_name: String,
        log: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl Agent for RecordingAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            // The guard is a temporary dropped at the end of this statement,
            // so it is never held across an await point.
            self.log.lock().unwrap().push(self.agent_name.clone());
            Ok(())
        }
    }

    #[tokio::test]
    async fn loop_runs_max_iterations() {
        let counter = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(CounterAgent {
            agent_name: "counter".into(),
            counter: counter.clone(),
        })];

        let loop_agent = LoopAgent::new("loop", agents, 5);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn loop_runs_sub_agents_in_order_each_iteration() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(RecordingAgent {
                agent_name: "a".into(),
                log: log.clone(),
            }),
            Arc::new(RecordingAgent {
                agent_name: "b".into(),
                log: log.clone(),
            }),
        ];

        let loop_agent = LoopAgent::new("loop", agents, 2);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        assert_eq!(*log.lock().unwrap(), vec!["a", "b", "a", "b"]);
    }

    #[tokio::test]
    async fn loop_escalate_breaks_early() {
        let counter = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(EscalateAfterAgent {
            agent_name: "escalator".into(),
            counter: counter.clone(),
            escalate_at: 3,
        })];

        let loop_agent = LoopAgent::new("loop", agents, 10);
        let mut ctx = test_ctx();
        // Should return Ok because __escalate is treated as a clean break.
        loop_agent.run_live(&mut ctx).await.unwrap();

        // Agent ran 3 times: iterations 1, 2, 3 (escalated on 3rd).
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn loop_escalate_skips_remaining_sub_agents() {
        let escalator_runs = Arc::new(AtomicU32::new(0));
        let follower_runs = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(EscalateAfterAgent {
                agent_name: "escalator".into(),
                counter: escalator_runs.clone(),
                escalate_at: 2,
            }),
            Arc::new(CounterAgent {
                agent_name: "follower".into(),
                counter: follower_runs.clone(),
            }),
        ];

        let loop_agent = LoopAgent::new("loop", agents, 10);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        // Iteration 1: escalator Ok, follower runs once.
        // Iteration 2: escalator escalates, so follower must not run again.
        assert_eq!(escalator_runs.load(Ordering::SeqCst), 2);
        assert_eq!(follower_runs.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn loop_propagates_non_escalate_transfer() {
        let counter = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(TransferAfterAgent {
            agent_name: "transferer".into(),
            counter: counter.clone(),
            transfer_at: 2,
            target: "other_agent".into(),
        })];

        let loop_agent = LoopAgent::new("loop", agents, 10);
        let mut ctx = test_ctx();
        let result = loop_agent.run_live(&mut ctx).await;

        match result {
            Err(AgentError::TransferRequested(target)) => {
                assert_eq!(target, "other_agent");
            }
            other => panic!("expected TransferRequested, got {:?}", other),
        }

        // Agent ran twice: first time Ok, second time TransferRequested.
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn loop_zero_iterations() {
        let counter = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(CounterAgent {
            agent_name: "counter".into(),
            counter: counter.clone(),
        })];

        let loop_agent = LoopAgent::new("loop", agents, 0);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn loop_with_no_sub_agents_is_ok() {
        let loop_agent = LoopAgent::new("loop", Vec::new(), 3);
        let mut ctx = test_ctx();
        loop_agent.run_live(&mut ctx).await.unwrap();

        assert!(loop_agent.sub_agents().is_empty());
    }

    #[test]
    fn loop_sub_agents_returns_children() {
        let counter = Arc::new(AtomicU32::new(0));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(CounterAgent {
            agent_name: "child".into(),
            counter,
        })];

        let loop_agent = LoopAgent::new("loop", agents, 5);
        assert_eq!(loop_agent.sub_agents().len(), 1);
        assert_eq!(loop_agent.sub_agents()[0].name(), "child");
    }
}