adk_agent/workflow/
loop_agent.rs

1use adk_core::{
2    AfterAgentCallback, Agent, BeforeAgentCallback, EventStream, InvocationContext, Result,
3};
4use async_stream::stream;
5use async_trait::async_trait;
6use std::sync::Arc;
7
8/// Loop agent executes sub-agents repeatedly for N iterations or until escalation
9pub struct LoopAgent {
10    name: String,
11    description: String,
12    sub_agents: Vec<Arc<dyn Agent>>,
13    max_iterations: Option<u32>,
14    before_callbacks: Vec<BeforeAgentCallback>,
15    after_callbacks: Vec<AfterAgentCallback>,
16}
17
18impl LoopAgent {
19    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
20        Self {
21            name: name.into(),
22            description: String::new(),
23            sub_agents,
24            max_iterations: None,
25            before_callbacks: Vec::new(),
26            after_callbacks: Vec::new(),
27        }
28    }
29
30    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
31        self.description = desc.into();
32        self
33    }
34
35    pub fn with_max_iterations(mut self, max: u32) -> Self {
36        self.max_iterations = Some(max);
37        self
38    }
39
40    pub fn before_callback(mut self, callback: BeforeAgentCallback) -> Self {
41        self.before_callbacks.push(callback);
42        self
43    }
44
45    pub fn after_callback(mut self, callback: AfterAgentCallback) -> Self {
46        self.after_callbacks.push(callback);
47        self
48    }
49}
50
51#[async_trait]
52impl Agent for LoopAgent {
53    fn name(&self) -> &str {
54        &self.name
55    }
56
57    fn description(&self) -> &str {
58        &self.description
59    }
60
61    fn sub_agents(&self) -> &[Arc<dyn Agent>] {
62        &self.sub_agents
63    }
64
65    async fn run(&self, ctx: Arc<dyn InvocationContext>) -> Result<EventStream> {
66        let sub_agents = self.sub_agents.clone();
67        let max_iterations = self.max_iterations;
68
69        let s = stream! {
70            use futures::StreamExt;
71
72            let mut count = max_iterations;
73
74            loop {
75                let mut should_exit = false;
76
77                for agent in &sub_agents {
78                    let mut stream = agent.run(ctx.clone()).await?;
79
80                    while let Some(result) = stream.next().await {
81                        match result {
82                            Ok(event) => {
83                                if event.actions.escalate {
84                                    should_exit = true;
85                                }
86                                yield Ok(event);
87                            }
88                            Err(e) => {
89                                yield Err(e);
90                                return;
91                            }
92                        }
93                    }
94
95                    if should_exit {
96                        return;
97                    }
98                }
99
100                if let Some(ref mut c) = count {
101                    *c -= 1;
102                    if *c == 0 {
103                        return;
104                    }
105                }
106            }
107        };
108
109        Ok(Box::pin(s))
110    }
111}