rs_adk/agents/
sequential.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use crate::agent::Agent;
8use crate::context::InvocationContext;
9use crate::error::AgentError;
10
11pub struct SequentialAgent {
17 name: String,
18 sub_agents: Vec<Arc<dyn Agent>>,
19}
20
21impl SequentialAgent {
22 pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
24 Self {
25 name: name.into(),
26 sub_agents,
27 }
28 }
29}
30
31#[async_trait]
32impl Agent for SequentialAgent {
33 fn name(&self) -> &str {
34 &self.name
35 }
36
37 async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
38 for sub in &self.sub_agents {
39 sub.run_live(ctx).await?;
40 }
41 Ok(())
42 }
43
44 fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
45 self.sub_agents.clone()
46 }
47}
48
#[cfg(test)]
mod tests {
    //! Unit tests for [`SequentialAgent`]: ordering, short-circuiting on
    //! errors, transfer propagation, the empty case, and child exposure.

    // `super::*` already brings in `Agent`, `AgentError`, `InvocationContext`,
    // `Arc` and `async_trait` from the parent module's imports.
    use super::*;
    use crate::agent_session::{AgentSession, NoOpSessionWriter};
    use tokio::sync::broadcast;

    /// Builds a minimal `InvocationContext` backed by a no-op session writer.
    ///
    /// The broadcast receiver is dropped immediately; none of these tests
    /// observe emitted events.
    fn test_ctx() -> InvocationContext {
        let (event_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> = Arc::new(NoOpSessionWriter);
        let agent_session = AgentSession::from_writer(writer, event_tx);
        InvocationContext::new(agent_session)
    }

    /// Test agent that records its name in `log` and succeeds.
    struct AppendAgent {
        agent_name: String,
        log: Arc<parking_lot::Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl Agent for AppendAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            self.log.lock().push(self.agent_name.clone());
            Ok(())
        }
    }

    /// Test agent that records its name in `log`, then fails with
    /// `AgentError::Other("fail")`.
    struct FailAgent {
        agent_name: String,
        log: Arc<parking_lot::Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl Agent for FailAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            self.log.lock().push(self.agent_name.clone());
            Err(AgentError::Other("fail".to_string()))
        }
    }

    /// Test agent that records its name in `log`, then requests a transfer
    /// to `target`.
    struct TransferAgent {
        agent_name: String,
        target: String,
        log: Arc<parking_lot::Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl Agent for TransferAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            self.log.lock().push(self.agent_name.clone());
            Err(AgentError::TransferRequested(self.target.clone()))
        }
    }

    #[tokio::test]
    async fn sequential_runs_all_in_order() {
        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(AppendAgent {
                agent_name: "a".into(),
                log: log.clone(),
            }),
            Arc::new(AppendAgent {
                agent_name: "b".into(),
                log: log.clone(),
            }),
            Arc::new(AppendAgent {
                agent_name: "c".into(),
                log: log.clone(),
            }),
        ];

        let seq = SequentialAgent::new("seq", agents);
        let mut ctx = test_ctx();
        seq.run_live(&mut ctx).await.unwrap();

        let entries = log.lock().clone();
        assert_eq!(entries, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn sequential_stops_on_error() {
        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(AppendAgent {
                agent_name: "a".into(),
                log: log.clone(),
            }),
            Arc::new(FailAgent {
                agent_name: "b".into(),
                log: log.clone(),
            }),
            Arc::new(AppendAgent {
                agent_name: "c".into(),
                log: log.clone(),
            }),
        ];

        let seq = SequentialAgent::new("seq", agents);
        let mut ctx = test_ctx();
        let result = seq.run_live(&mut ctx).await;

        // Pin the exact error: the failing child's error must surface
        // unchanged, not merely "some error".
        match result {
            Err(AgentError::Other(msg)) => assert_eq!(msg, "fail"),
            other => panic!("expected AgentError::Other, got {:?}", other),
        }
        let entries = log.lock().clone();
        assert_eq!(entries, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn sequential_propagates_transfer() {
        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(AppendAgent {
                agent_name: "a".into(),
                log: log.clone(),
            }),
            Arc::new(TransferAgent {
                agent_name: "b".into(),
                target: "target_agent".into(),
                log: log.clone(),
            }),
            Arc::new(AppendAgent {
                agent_name: "c".into(),
                log: log.clone(),
            }),
        ];

        let seq = SequentialAgent::new("seq", agents);
        let mut ctx = test_ctx();
        let result = seq.run_live(&mut ctx).await;

        match result {
            Err(AgentError::TransferRequested(target)) => {
                assert_eq!(target, "target_agent");
            }
            other => panic!("expected TransferRequested, got {:?}", other),
        }
        let entries = log.lock().clone();
        assert_eq!(entries, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn sequential_empty_succeeds() {
        let seq = SequentialAgent::new("empty", vec![]);
        let mut ctx = test_ctx();
        seq.run_live(&mut ctx).await.unwrap();
    }

    #[test]
    fn sequential_sub_agents_returns_children() {
        let log = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(AppendAgent {
            agent_name: "child".into(),
            log,
        })];

        let seq = SequentialAgent::new("seq", agents);
        let children = seq.sub_agents();
        assert_eq!(children.len(), 1);
        assert_eq!(children[0].name(), "child");
    }
}