Skip to main content

adk_rs/agents/
sequential_agent.rs

1//! [`SequentialAgent`] — run sub-agents one after another, in order. Each
2//! sub-agent sees the cumulative event history.
3
4use std::sync::Arc;
5
6use async_stream::try_stream;
7use async_trait::async_trait;
8use futures::StreamExt;
9
10use crate::core::{Event, EventStream, InvocationContext, LlmResponse};
11use crate::error::{Error, Result};
12
13use crate::agents::base::BaseAgent;
14
15/// True when [`crate::core::RunConfig::resumability`] enables resume.
16pub(crate) fn is_resumable(ctx: &InvocationContext) -> bool {
17    ctx.run_config
18        .resumability
19        .map(|r| r.is_resumable)
20        .unwrap_or(false)
21}
22
23/// True when a descendant agent paused the invocation (long-running tool,
24/// confirmation, or auth consent pending).
25pub(crate) fn invocation_paused(ctx: &InvocationContext) -> bool {
26    ctx.attributes
27        .lock()
28        .get("invocation.paused")
29        .and_then(serde_json::Value::as_bool)
30        .unwrap_or(false)
31}
32
33/// Latest checkpoint recorded by `author` within this invocation: how many
34/// sub-agents have completed.
35pub(crate) fn completed_sub_agents(ctx: &InvocationContext, author: &str) -> usize {
36    let sess = ctx.session.lock();
37    sess.events
38        .iter()
39        .rev()
40        .find(|e| {
41            e.invocation_id == ctx.invocation_id
42                && e.author == author
43                && e.actions.agent_state.is_some()
44        })
45        .and_then(|e| e.actions.agent_state.as_ref())
46        .and_then(|s| s.get("completed_sub_agents"))
47        .and_then(serde_json::Value::as_u64)
48        .unwrap_or(0) as usize
49}
50
51/// Build a checkpoint event recording that `n` sub-agents completed.
52pub(crate) fn checkpoint_event(author: &str, invocation_id: &str, n: usize) -> Event {
53    let mut e = Event::new(author, LlmResponse::default());
54    e.invocation_id = invocation_id.to_string();
55    e.actions.agent_state = Some(serde_json::json!({ "completed_sub_agents": n }));
56    e
57}
58
59/// Run sub-agents in declared order.
60#[derive(Debug)]
61pub struct SequentialAgent {
62    name: String,
63    description: String,
64    sub_agents: Vec<Arc<dyn BaseAgent>>,
65}
66
67impl SequentialAgent {
68    /// Construct.
69    pub fn new(
70        name: impl Into<String>,
71        description: impl Into<String>,
72        sub_agents: Vec<Arc<dyn BaseAgent>>,
73    ) -> Result<Self> {
74        if sub_agents.is_empty() {
75            return Err(Error::config(
76                "SequentialAgent requires at least one sub_agent",
77            ));
78        }
79        Ok(Self {
80            name: name.into(),
81            description: description.into(),
82            sub_agents,
83        })
84    }
85}
86
87#[async_trait]
88impl BaseAgent for SequentialAgent {
89    fn name(&self) -> &str {
90        &self.name
91    }
92    fn description(&self) -> &str {
93        &self.description
94    }
95    fn sub_agents(&self) -> &[Arc<dyn BaseAgent>] {
96        &self.sub_agents
97    }
98    async fn run(self: Arc<Self>, ctx: Arc<InvocationContext>) -> Result<EventStream<'static>> {
99        let me = self.clone();
100        let stream = try_stream! {
101            let resumable = is_resumable(&ctx);
102            // On an in-place resume (same invocation id), skip sub-agents
103            // that completed before the pause.
104            let start_index = if resumable {
105                completed_sub_agents(&ctx, &me.name)
106            } else {
107                0
108            };
109            for (i, sub) in me.sub_agents.iter().enumerate().skip(start_index) {
110                if ctx.is_cancelled() {
111                    return;
112                }
113                let mut s = Box::pin(sub.clone().run(ctx.clone()).await?);
114                while let Some(ev) = s.next().await {
115                    let ev = ev?;
116                    // If a sub-agent escalates, stop the sequence.
117                    let escalate = ev.actions.escalate == Some(true);
118                    yield ev;
119                    if escalate {
120                        return;
121                    }
122                }
123                // A paused sub-agent (HITL confirmation, auth consent,
124                // long-running tool) suspends the whole pipeline; resume
125                // re-enters at this index thanks to the checkpoint below.
126                if invocation_paused(&ctx) {
127                    return;
128                }
129                if resumable && i + 1 < me.sub_agents.len() {
130                    yield checkpoint_event(&me.name, &ctx.invocation_id, i + 1);
131                }
132            }
133        };
134        Ok(Box::pin(stream))
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::tests_support::{stub_agent, test_ctx};

    /// Construction is synchronous, so this needs no async runtime.
    #[test]
    fn empty_sub_agents_rejected() {
        let err = SequentialAgent::new("seq", "d", vec![]).unwrap_err();
        assert!(err.to_string().contains("at least one sub_agent"));
    }

    #[tokio::test]
    async fn runs_sub_agents_in_declared_order() {
        let a = stub_agent("a", &["a-msg"], false);
        let b = stub_agent("b", &["b-msg"], false);
        let seq = Arc::new(SequentialAgent::new("seq", "", vec![a, b]).unwrap());
        let authors: Vec<_> = seq
            .run(test_ctx())
            .await
            .unwrap()
            .map(|ev| ev.unwrap().author)
            .collect()
            .await;
        assert_eq!(authors, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn stops_after_escalate() {
        let a = stub_agent("a", &["a-msg"], true); // escalates
        let b = stub_agent("b", &["b-msg"], false);
        let seq = Arc::new(SequentialAgent::new("seq", "", vec![a, b]).unwrap());
        let authors: Vec<_> = seq
            .run(test_ctx())
            .await
            .unwrap()
            .map(|ev| ev.unwrap().author)
            .collect()
            .await;
        assert_eq!(authors, vec!["a"], "b should not have run after escalate");
    }
}