1use std::sync::Arc;
8
9use crate::agent::Agent;
10use crate::agent_session::AgentSession;
11use crate::context::InvocationContext;
12use crate::error::AgentError;
13use crate::middleware::MiddlewareChain;
14use crate::plugin::{Plugin, PluginManager};
15use crate::router::AgentRegistry;
16use crate::state::State;
17
18pub struct Runner {
38 root_agent: Arc<dyn Agent>,
39 registry: AgentRegistry,
40 middleware: MiddlewareChain,
41 plugins: PluginManager,
42 state: State,
43}
44
45impl Runner {
46 pub fn new(root_agent: impl Agent + 'static) -> Self {
50 let agent = Arc::new(root_agent);
51 let mut registry = AgentRegistry::new();
52 Self::register_tree(&mut registry, agent.clone());
53 Self {
54 root_agent: agent,
55 registry,
56 middleware: MiddlewareChain::new(),
57 plugins: PluginManager::new(),
58 state: State::new(),
59 }
60 }
61
62 pub fn from_arc(root_agent: Arc<dyn Agent>) -> Self {
64 let mut registry = AgentRegistry::new();
65 Self::register_tree(&mut registry, root_agent.clone());
66 Self {
67 root_agent,
68 registry,
69 middleware: MiddlewareChain::new(),
70 plugins: PluginManager::new(),
71 state: State::new(),
72 }
73 }
74
75 pub fn with_middleware(mut self, mw: impl crate::middleware::Middleware + 'static) -> Self {
77 self.middleware.add(Arc::new(mw));
78 self
79 }
80
81 pub fn with_plugin(mut self, plugin: impl Plugin + 'static) -> Self {
83 self.plugins.add(Arc::new(plugin));
84 self
85 }
86
87 pub fn with_state(mut self, state: State) -> Self {
89 self.state = state;
90 self
91 }
92
93 pub fn register(&mut self, agent: Arc<dyn Agent>) {
95 self.registry.register(agent);
96 }
97
98 pub fn registry(&self) -> &AgentRegistry {
100 &self.registry
101 }
102
103 pub fn root_agent(&self) -> &dyn Agent {
105 self.root_agent.as_ref()
106 }
107
108 pub async fn run<F, Fut>(&self, connect_fn: F) -> Result<(), AgentError>
121 where
122 F: Fn(Arc<dyn Agent>) -> Fut + Send + Sync,
123 Fut: std::future::Future<Output = Result<AgentSession, AgentError>> + Send,
124 {
125 let mut current_agent = self.root_agent.clone();
126 let runner_state = self.state.clone();
127
128 crate::telemetry::logging::log_agent_started(
130 current_agent.name(),
131 0, );
133
134 loop {
135 let agent_session = connect_fn(current_agent.clone()).await?;
137
138 agent_session.state().merge(&runner_state);
140
141 let mut ctx =
143 InvocationContext::with_middleware(agent_session.clone(), self.middleware.clone());
144
145 self.plugins.run_before_run(&ctx).await;
147
148 match current_agent.run_live(&mut ctx).await {
150 Ok(()) => {
151 self.plugins.run_after_run(&ctx).await;
153 runner_state.merge(agent_session.state());
155 break;
156 }
157 Err(AgentError::TransferRequested(target_name)) => {
158 let target = self
160 .registry
161 .resolve(&target_name)
162 .ok_or_else(|| AgentError::UnknownAgent(target_name.clone()))?;
163
164 crate::telemetry::logging::log_agent_transfer(
165 current_agent.name(),
166 &target_name,
167 );
168 crate::telemetry::metrics::record_agent_transfer(
169 current_agent.name(),
170 &target_name,
171 );
172
173 runner_state.merge(agent_session.state());
175
176 let _ = agent_session.disconnect().await;
178
179 current_agent = target;
181 continue;
182 }
183 Err(e) => {
184 runner_state.merge(agent_session.state());
186 let _ = agent_session.disconnect().await;
187 return Err(e);
188 }
189 }
190 }
191
192 Ok(())
193 }
194
195 fn register_tree(registry: &mut AgentRegistry, agent: Arc<dyn Agent>) {
197 registry.register(agent.clone());
198 for sub in agent.sub_agents() {
199 Self::register_tree(registry, sub);
200 }
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;
    use async_trait::async_trait;
    use rs_genai::session::{SessionHandle, SessionPhase, SessionState};
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::sync::{broadcast, mpsc, watch};

    /// Agent whose `run_live` succeeds immediately.
    struct NoopAgent {
        name: String,
    }

    #[async_trait]
    impl Agent for NoopAgent {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Ok(())
        }
    }

    /// Agent that always asks to be replaced by `target`.
    struct TransferAgent {
        name: String,
        target: String,
    }

    #[async_trait]
    impl Agent for TransferAgent {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Err(AgentError::TransferRequested(self.target.clone()))
        }

        fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
            Vec::new()
        }
    }

    /// Agent that asserts `key` holds `expected` in the session state.
    struct StateReaderAgent {
        name: String,
        key: String,
        expected: String,
    }

    #[async_trait]
    impl Agent for StateReaderAgent {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let found = ctx.state().get::<String>(&self.key);
            assert_eq!(found.as_deref(), Some(self.expected.as_str()));
            Ok(())
        }
    }

    /// Agent whose `run_live` always fails with `AgentError::Other("boom")`.
    struct FailingAgent;

    #[async_trait]
    impl Agent for FailingAgent {
        fn name(&self) -> &str {
            "failing"
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Err(AgentError::Other(String::from("boom")))
        }
    }

    /// Builds a `NoopAgent` with the given name.
    fn noop(name: &str) -> NoopAgent {
        NoopAgent {
            name: name.to_string(),
        }
    }

    /// Builds a session handle whose channels have no live receivers.
    fn mock_session_handle() -> SessionHandle {
        let (command_tx, _command_rx) = mpsc::channel(16);
        let (event_tx, _) = broadcast::channel(16);
        let (phase_tx, phase_rx) = watch::channel(SessionPhase::Active);
        let session_state = Arc::new(SessionState::new(phase_tx));
        SessionHandle::new(command_tx, event_tx, session_state, phase_rx)
    }

    /// Wraps a mock handle in an `AgentSession`.
    fn mock_agent_session() -> AgentSession {
        AgentSession::new(mock_session_handle())
    }

    /// Connect function that ignores the agent and yields a mock session.
    async fn connect_mock(_agent: Arc<dyn Agent>) -> Result<AgentSession, AgentError> {
        Ok(mock_agent_session())
    }

    #[tokio::test]
    async fn runner_runs_single_agent() {
        let runner = Runner::new(noop("root"));

        let outcome = runner.run(connect_mock).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn runner_handles_transfer() {
        let mut runner = Runner::new(TransferAgent {
            name: "root".to_string(),
            target: "target".to_string(),
        });
        runner.register(Arc::new(noop("target")));

        let connects = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&connects);

        let outcome = runner
            .run(move |_agent| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok(mock_agent_session())
                }
            })
            .await;

        assert!(outcome.is_ok());
        // One connection for the root, one for the transfer target.
        assert_eq!(connects.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn runner_preserves_state_across_transfer() {
        /// Writes `greeting` into the session state, then transfers to `agent_b`.
        struct SetAndTransferAgent;

        #[async_trait]
        impl Agent for SetAndTransferAgent {
            fn name(&self) -> &str {
                "agent_a"
            }

            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                ctx.state().set("greeting", "hello from A");
                Err(AgentError::TransferRequested("agent_b".to_string()))
            }
        }

        let reader = Arc::new(StateReaderAgent {
            name: "agent_b".to_string(),
            key: "greeting".to_string(),
            expected: "hello from A".to_string(),
        });

        let mut runner = Runner::new(SetAndTransferAgent);
        runner.register(reader);

        let outcome = runner.run(connect_mock).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn runner_fails_on_unknown_transfer_target() {
        let runner = Runner::new(TransferAgent {
            name: "root".to_string(),
            target: "nonexistent".to_string(),
        });

        let outcome = runner.run(connect_mock).await;

        if let Err(AgentError::UnknownAgent(name)) = &outcome {
            assert_eq!(name, "nonexistent");
        } else {
            panic!("expected UnknownAgent, got: {:?}", outcome);
        }
    }

    #[tokio::test]
    async fn runner_propagates_errors() {
        let runner = Runner::new(FailingAgent);

        let outcome = runner.run(connect_mock).await;

        if let Err(AgentError::Other(msg)) = &outcome {
            assert_eq!(msg, "boom");
        } else {
            panic!("expected Other error, got: {:?}", outcome);
        }
    }

    #[tokio::test]
    async fn runner_with_initial_state() {
        /// Asserts that the seeded key is visible in the session state.
        struct StateCheckAgent;

        #[async_trait]
        impl Agent for StateCheckAgent {
            fn name(&self) -> &str {
                "checker"
            }

            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                let seeded = ctx.state().get::<String>("initial_key");
                assert_eq!(seeded.as_deref(), Some("initial_value"));
                Ok(())
            }
        }

        let seed = State::new();
        seed.set("initial_key", "initial_value");

        let runner = Runner::new(StateCheckAgent).with_state(seed);

        let outcome = runner.run(connect_mock).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn runner_auto_registers_sub_agents() {
        /// Agent exposing two `NoopAgent` children.
        struct ParentAgent;

        #[async_trait]
        impl Agent for ParentAgent {
            fn name(&self) -> &str {
                "parent"
            }

            async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
                Ok(())
            }

            fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
                vec![Arc::new(noop("child_a")), Arc::new(noop("child_b"))]
            }
        }

        let runner = Runner::new(ParentAgent);
        let registry = runner.registry();

        assert!(registry.resolve("parent").is_some());
        assert!(registry.resolve("child_a").is_some());
        assert!(registry.resolve("child_b").is_some());
    }
}