// rs_adk/agents/sequential.rs

//! SequentialAgent — runs sub-agents one after another, stopping at the first error.

use std::fmt;
use std::sync::Arc;

use async_trait::async_trait;

use crate::agent::Agent;
use crate::context::InvocationContext;
use crate::error::AgentError;

11/// Runs sub-agents in sequential order.
12///
13/// Each sub-agent runs to completion before the next starts.
14/// If any sub-agent returns an error (including `TransferRequested`), execution stops
15/// and the error is propagated to the caller.
16pub struct SequentialAgent {
17    name: String,
18    sub_agents: Vec<Arc<dyn Agent>>,
19}
20
21impl SequentialAgent {
22    /// Create a new sequential agent with the given name and ordered sub-agents.
23    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
24        Self {
25            name: name.into(),
26            sub_agents,
27        }
28    }
29}
30
31#[async_trait]
32impl Agent for SequentialAgent {
33    fn name(&self) -> &str {
34        &self.name
35    }
36
37    async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
38        for sub in &self.sub_agents {
39            sub.run_live(ctx).await?;
40        }
41        Ok(())
42    }
43
44    fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
45        self.sub_agents.clone()
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_session::AgentSession;
    use crate::agent_session::NoOpSessionWriter;
    use crate::context::InvocationContext;
    use crate::error::AgentError;
    use std::sync::Arc;
    use tokio::sync::broadcast;

    /// Shared record of agent names, appended in the order the agents ran.
    type RunLog = Arc<parking_lot::Mutex<Vec<String>>>;

    /// Helper: create a test InvocationContext with a no-op session.
    fn test_ctx() -> InvocationContext {
        let (event_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> = Arc::new(NoOpSessionWriter);
        let agent_session = AgentSession::from_writer(writer, event_tx);
        InvocationContext::new(agent_session)
    }

    /// Helper: a fresh, empty run log.
    fn new_log() -> RunLog {
        Arc::new(parking_lot::Mutex::new(Vec::new()))
    }

    /// Helper: snapshot the names recorded so far (the lock is released on return).
    fn recorded(log: &RunLog) -> Vec<String> {
        log.lock().clone()
    }

    /// A test agent that appends its name to a shared Vec when run.
    struct AppendAgent {
        agent_name: String,
        log: RunLog,
    }

    #[async_trait]
    impl Agent for AppendAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            self.log.lock().push(self.agent_name.clone());
            Ok(())
        }
    }

    /// A test agent that records its name, then fails with `AgentError::Other("fail")`.
    struct FailAgent {
        agent_name: String,
        log: RunLog,
    }

    #[async_trait]
    impl Agent for FailAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            self.log.lock().push(self.agent_name.clone());
            Err(AgentError::Other("fail".to_string()))
        }
    }

    /// A test agent that records its name, then returns `TransferRequested(target)`.
    struct TransferAgent {
        agent_name: String,
        target: String,
        log: RunLog,
    }

    #[async_trait]
    impl Agent for TransferAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            self.log.lock().push(self.agent_name.clone());
            Err(AgentError::TransferRequested(self.target.clone()))
        }
    }

    /// Builder: an `AppendAgent` named `name` writing into `log`.
    fn append(name: &str, log: &RunLog) -> Arc<dyn Agent> {
        Arc::new(AppendAgent {
            agent_name: name.into(),
            log: Arc::clone(log),
        })
    }

    /// Builder: a `FailAgent` named `name` writing into `log`.
    fn fail(name: &str, log: &RunLog) -> Arc<dyn Agent> {
        Arc::new(FailAgent {
            agent_name: name.into(),
            log: Arc::clone(log),
        })
    }

    /// Builder: a `TransferAgent` named `name` that requests a transfer to `target`.
    fn transfer(name: &str, target: &str, log: &RunLog) -> Arc<dyn Agent> {
        Arc::new(TransferAgent {
            agent_name: name.into(),
            target: target.into(),
            log: Arc::clone(log),
        })
    }

    #[tokio::test]
    async fn sequential_runs_all_in_order() {
        let log = new_log();
        let seq = SequentialAgent::new(
            "seq",
            vec![append("a", &log), append("b", &log), append("c", &log)],
        );
        let mut ctx = test_ctx();
        seq.run_live(&mut ctx).await.unwrap();

        assert_eq!(recorded(&log), vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn sequential_stops_on_error() {
        let log = new_log();
        let seq = SequentialAgent::new(
            "seq",
            vec![append("a", &log), fail("b", &log), append("c", &log)],
        );
        let mut ctx = test_ctx();
        let result = seq.run_live(&mut ctx).await;

        // Pin the exact error rather than just `is_err()`: the sequence must hand
        // back the failing child's error unchanged, not swallow or replace it.
        match result {
            Err(AgentError::Other(msg)) => assert_eq!(msg, "fail"),
            other => panic!("expected AgentError::Other, got {:?}", other),
        }
        assert_eq!(recorded(&log), vec!["a", "b"]); // c never ran
    }

    #[tokio::test]
    async fn sequential_propagates_transfer() {
        let log = new_log();
        let seq = SequentialAgent::new(
            "seq",
            vec![
                append("a", &log),
                transfer("b", "target_agent", &log),
                append("c", &log),
            ],
        );
        let mut ctx = test_ctx();
        let result = seq.run_live(&mut ctx).await;

        match result {
            Err(AgentError::TransferRequested(target)) => {
                assert_eq!(target, "target_agent");
            }
            other => panic!("expected TransferRequested, got {:?}", other),
        }
        assert_eq!(recorded(&log), vec!["a", "b"]); // c never ran
    }

    #[tokio::test]
    async fn sequential_empty_succeeds() {
        let seq = SequentialAgent::new("empty", vec![]);
        let mut ctx = test_ctx();
        seq.run_live(&mut ctx).await.unwrap();
    }

    #[test]
    fn sequential_sub_agents_returns_children() {
        let log = new_log();
        let seq = SequentialAgent::new("seq", vec![append("child", &log)]);

        let children = seq.sub_agents();
        assert_eq!(children.len(), 1);
        assert_eq!(children[0].name(), "child");
    }
}