use std::sync::Arc;
use async_trait::async_trait;
use crate::agent::Agent;
use crate::context::InvocationContext;
use crate::error::AgentError;
/// An agent that runs an ordered list of sub-agents repeatedly, for at most
/// `max_iterations` passes.
pub struct LoopAgent {
/// Name reported through `Agent::name`.
name: String,
/// Child agents, run in this order on every pass.
sub_agents: Vec<Arc<dyn Agent>>,
/// Upper bound on the number of passes over `sub_agents`; `0` means the
/// children are never run.
max_iterations: u32,
}
impl LoopAgent {
    /// Creates a loop agent.
    ///
    /// * `name` - converted into the owned name of this agent.
    /// * `sub_agents` - children to run, in order, on every pass.
    /// * `max_iterations` - maximum number of passes over `sub_agents`.
    pub fn new(
        name: impl Into<String>,
        sub_agents: Vec<Arc<dyn Agent>>,
        max_iterations: u32,
    ) -> Self {
        let name = name.into();
        Self {
            name,
            sub_agents,
            max_iterations,
        }
    }
}
#[async_trait]
impl Agent for LoopAgent {
    /// Returns the name given at construction.
    fn name(&self) -> &str {
        &self.name
    }

    /// Runs every sub-agent in order, repeating the whole sequence up to
    /// `max_iterations` times.
    ///
    /// A sub-agent that fails with `AgentError::TransferRequested("__escalate")`
    /// ends the loop early and the call returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Any other error from a sub-agent (including a transfer request to a
    /// different target) stops the loop immediately and is returned unchanged.
    async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
        // Transfer target that a sub-agent uses to break out of the loop.
        const ESCALATE_TARGET: &str = "__escalate";

        // Nothing to run: return now instead of counting through up to
        // `max_iterations` empty passes that never reach an await point.
        if self.sub_agents.is_empty() {
            return Ok(());
        }

        for _ in 0..self.max_iterations {
            for sub in &self.sub_agents {
                match sub.run_live(ctx).await {
                    Ok(()) => {}
                    // Escalation is the loop's normal early exit, not a failure.
                    Err(AgentError::TransferRequested(ref target)) if target == ESCALATE_TARGET => {
                        return Ok(());
                    }
                    Err(e) => return Err(e),
                }
            }
        }
        Ok(())
    }

    /// Returns the child agents as a new `Vec` of cloned `Arc` handles.
    fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
        self.sub_agents.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_session::{AgentSession, NoOpSessionWriter};
    use crate::context::InvocationContext;
    use crate::error::AgentError;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;
    use tokio::sync::broadcast;

    /// Builds an `InvocationContext` over a `NoOpSessionWriter` and a fresh
    /// broadcast channel whose receiver is dropped right away.
    fn test_ctx() -> InvocationContext {
        let (event_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> = Arc::new(NoOpSessionWriter);
        InvocationContext::new(AgentSession::from_writer(writer, event_tx))
    }

    /// Wraps a single child in a `LoopAgent` named "loop".
    fn loop_over(child: Arc<dyn Agent>, max_iterations: u32) -> LoopAgent {
        LoopAgent::new("loop", vec![child], max_iterations)
    }

    /// Test agent that counts its invocations and always succeeds.
    struct CounterAgent {
        agent_name: String,
        counter: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Agent for CounterAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Test agent that succeeds until its `escalate_at`-th invocation, then
    /// requests a transfer to "__escalate" on that and every later call.
    struct EscalateAfterAgent {
        agent_name: String,
        counter: Arc<AtomicU32>,
        escalate_at: u32,
    }

    #[async_trait]
    impl Agent for EscalateAfterAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            // `fetch_add` yields the previous value, so add one for this call.
            let calls_so_far = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            if calls_so_far < self.escalate_at {
                return Ok(());
            }
            Err(AgentError::TransferRequested("__escalate".to_string()))
        }
    }

    /// Test agent that succeeds until its `transfer_at`-th invocation, then
    /// requests a transfer to `target` on that and every later call.
    struct TransferAfterAgent {
        agent_name: String,
        counter: Arc<AtomicU32>,
        transfer_at: u32,
        target: String,
    }

    #[async_trait]
    impl Agent for TransferAfterAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            // `fetch_add` yields the previous value, so add one for this call.
            let calls_so_far = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
            if calls_so_far < self.transfer_at {
                return Ok(());
            }
            Err(AgentError::TransferRequested(self.target.clone()))
        }
    }

    /// With a child that never fails, the child runs once per iteration.
    #[tokio::test]
    async fn loop_runs_max_iterations() {
        let calls = Arc::new(AtomicU32::new(0));
        let child = Arc::new(CounterAgent {
            agent_name: "counter".into(),
            counter: Arc::clone(&calls),
        });
        let loop_agent = loop_over(child, 5);
        let mut ctx = test_ctx();

        loop_agent.run_live(&mut ctx).await.unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 5);
    }

    /// An "__escalate" transfer stops the loop early and yields `Ok`.
    #[tokio::test]
    async fn loop_escalate_breaks_early() {
        let calls = Arc::new(AtomicU32::new(0));
        let child = Arc::new(EscalateAfterAgent {
            agent_name: "escalator".into(),
            counter: Arc::clone(&calls),
            escalate_at: 3,
        });
        let loop_agent = loop_over(child, 10);
        let mut ctx = test_ctx();

        loop_agent.run_live(&mut ctx).await.unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A transfer to any other target is returned to the caller unchanged.
    #[tokio::test]
    async fn loop_propagates_non_escalate_transfer() {
        let calls = Arc::new(AtomicU32::new(0));
        let child = Arc::new(TransferAfterAgent {
            agent_name: "transferer".into(),
            counter: Arc::clone(&calls),
            transfer_at: 2,
            target: "other_agent".into(),
        });
        let loop_agent = loop_over(child, 10);
        let mut ctx = test_ctx();

        match loop_agent.run_live(&mut ctx).await {
            Err(AgentError::TransferRequested(target)) => assert_eq!(target, "other_agent"),
            other => panic!("expected TransferRequested, got {:?}", other),
        }

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// A zero iteration budget never invokes the child.
    #[tokio::test]
    async fn loop_zero_iterations() {
        let calls = Arc::new(AtomicU32::new(0));
        let child = Arc::new(CounterAgent {
            agent_name: "counter".into(),
            counter: Arc::clone(&calls),
        });
        let loop_agent = loop_over(child, 0);
        let mut ctx = test_ctx();

        loop_agent.run_live(&mut ctx).await.unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    /// `sub_agents` exposes the children handed to the constructor.
    #[test]
    fn loop_sub_agents_returns_children() {
        let child = Arc::new(CounterAgent {
            agent_name: "child".into(),
            counter: Arc::new(AtomicU32::new(0)),
        });
        let loop_agent = loop_over(child, 5);

        let children = loop_agent.sub_agents();

        assert_eq!(children.len(), 1);
        assert_eq!(children[0].name(), "child");
    }
}