adk_rs/agents/
sequential_agent.rs1use std::sync::Arc;
5
6use async_stream::try_stream;
7use async_trait::async_trait;
8use futures::StreamExt;
9
10use crate::core::{Event, EventStream, InvocationContext, LlmResponse};
11use crate::error::{Error, Result};
12
13use crate::agents::base::BaseAgent;
14
15pub(crate) fn is_resumable(ctx: &InvocationContext) -> bool {
17 ctx.run_config
18 .resumability
19 .map(|r| r.is_resumable)
20 .unwrap_or(false)
21}
22
23pub(crate) fn invocation_paused(ctx: &InvocationContext) -> bool {
26 ctx.attributes
27 .lock()
28 .get("invocation.paused")
29 .and_then(serde_json::Value::as_bool)
30 .unwrap_or(false)
31}
32
33pub(crate) fn completed_sub_agents(ctx: &InvocationContext, author: &str) -> usize {
36 let sess = ctx.session.lock();
37 sess.events
38 .iter()
39 .rev()
40 .find(|e| {
41 e.invocation_id == ctx.invocation_id
42 && e.author == author
43 && e.actions.agent_state.is_some()
44 })
45 .and_then(|e| e.actions.agent_state.as_ref())
46 .and_then(|s| s.get("completed_sub_agents"))
47 .and_then(serde_json::Value::as_u64)
48 .unwrap_or(0) as usize
49}
50
51pub(crate) fn checkpoint_event(author: &str, invocation_id: &str, n: usize) -> Event {
53 let mut e = Event::new(author, LlmResponse::default());
54 e.invocation_id = invocation_id.to_string();
55 e.actions.agent_state = Some(serde_json::json!({ "completed_sub_agents": n }));
56 e
57}
58
/// An agent that runs its sub-agents one after another, in the order they
/// were passed to [`SequentialAgent::new`].
#[derive(Debug)]
pub struct SequentialAgent {
    /// Name returned by [`BaseAgent::name`].
    ///
    /// Also used as the author of the checkpoint events this agent yields.
    name: String,
    /// Description returned by [`BaseAgent::description`].
    description: String,
    /// Sub-agents in execution order.
    ///
    /// Never empty: [`SequentialAgent::new`] rejects an empty list.
    sub_agents: Vec<Arc<dyn BaseAgent>>,
}
66
67impl SequentialAgent {
68 pub fn new(
70 name: impl Into<String>,
71 description: impl Into<String>,
72 sub_agents: Vec<Arc<dyn BaseAgent>>,
73 ) -> Result<Self> {
74 if sub_agents.is_empty() {
75 return Err(Error::config(
76 "SequentialAgent requires at least one sub_agent",
77 ));
78 }
79 Ok(Self {
80 name: name.into(),
81 description: description.into(),
82 sub_agents,
83 })
84 }
85}
86
87#[async_trait]
88impl BaseAgent for SequentialAgent {
89 fn name(&self) -> &str {
90 &self.name
91 }
92 fn description(&self) -> &str {
93 &self.description
94 }
95 fn sub_agents(&self) -> &[Arc<dyn BaseAgent>] {
96 &self.sub_agents
97 }
98 async fn run(self: Arc<Self>, ctx: Arc<InvocationContext>) -> Result<EventStream<'static>> {
99 let me = self.clone();
100 let stream = try_stream! {
101 let resumable = is_resumable(&ctx);
102 let start_index = if resumable {
105 completed_sub_agents(&ctx, &me.name)
106 } else {
107 0
108 };
109 for (i, sub) in me.sub_agents.iter().enumerate().skip(start_index) {
110 if ctx.is_cancelled() {
111 return;
112 }
113 let mut s = Box::pin(sub.clone().run(ctx.clone()).await?);
114 while let Some(ev) = s.next().await {
115 let ev = ev?;
116 let escalate = ev.actions.escalate == Some(true);
118 yield ev;
119 if escalate {
120 return;
121 }
122 }
123 if invocation_paused(&ctx) {
127 return;
128 }
129 if resumable && i + 1 < me.sub_agents.len() {
130 yield checkpoint_event(&me.name, &ctx.invocation_id, i + 1);
131 }
132 }
133 };
134 Ok(Box::pin(stream))
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::tests_support::{stub_agent, test_ctx};

    /// Runs `agent` against a fresh test context and returns the author of
    /// every event it yields, in order. Panics on any error.
    async fn run_and_collect_authors(agent: Arc<SequentialAgent>) -> Vec<String> {
        let mut events = agent.run(test_ctx()).await.unwrap();
        let mut seen = Vec::new();
        while let Some(event) = events.next().await {
            seen.push(event.unwrap().author);
        }
        seen
    }

    #[tokio::test]
    async fn empty_sub_agents_rejected() {
        let err = SequentialAgent::new("seq", "d", vec![]).unwrap_err();
        assert!(err.to_string().contains("at least one sub_agent"));
    }

    #[tokio::test]
    async fn runs_sub_agents_in_declared_order() {
        let first = stub_agent("a", &["a-msg"], false);
        let second = stub_agent("b", &["b-msg"], false);
        let seq = Arc::new(SequentialAgent::new("seq", "", vec![first, second]).unwrap());
        let authors = run_and_collect_authors(seq).await;
        assert_eq!(authors, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn stops_after_escalate() {
        let escalating = stub_agent("a", &["a-msg"], true);
        let never_run = stub_agent("b", &["b-msg"], false);
        let seq = Arc::new(SequentialAgent::new("seq", "", vec![escalating, never_run]).unwrap());
        let authors = run_and_collect_authors(seq).await;
        assert_eq!(authors, vec!["a"], "b should not have run after escalate");
    }
}