use std::sync::Arc;
use crate::agent::Agent;
use crate::agent_session::AgentSession;
use crate::context::InvocationContext;
use crate::error::AgentError;
use crate::middleware::MiddlewareChain;
use crate::plugin::{Plugin, PluginManager};
use crate::router::AgentRegistry;
use crate::state::State;
/// Drives a root agent and handles agent-to-agent transfers.
///
/// Built with [`Runner::new`] or [`Runner::from_arc`], optionally customised
/// with the `with_*` builder methods, and executed with [`Runner::run`].
pub struct Runner {
/// Agent that `run` starts with.
root_agent: Arc<dyn Agent>,
/// Agents that can be resolved by name when a transfer is requested.
/// Seeded from the root agent and its `sub_agents()` tree; extended via `register`.
registry: AgentRegistry,
/// Middleware chain cloned into every `InvocationContext` created by `run`.
middleware: MiddlewareChain,
/// Plugins whose before-run / after-run hooks are invoked around each agent run.
plugins: PluginManager,
/// State merged into each session before its agent runs, and updated from
/// the session's state once that agent returns.
state: State,
}
impl Runner {
    /// Creates a runner whose root agent is `root_agent`.
    ///
    /// The agent is wrapped in an `Arc` and handed to [`Runner::from_arc`], so
    /// the root agent and every agent reachable through `sub_agents()` are
    /// registered for transfer resolution.
    pub fn new(root_agent: impl Agent + 'static) -> Self {
        Self::from_arc(Arc::new(root_agent))
    }

    /// Creates a runner from an already shared root agent.
    ///
    /// Registers the root agent and, recursively, all of its sub-agents.
    /// Middleware, plugins and state start out as their `new()` values.
    pub fn from_arc(root_agent: Arc<dyn Agent>) -> Self {
        let mut registry = AgentRegistry::new();
        Self::register_tree(&mut registry, Arc::clone(&root_agent));
        Self {
            root_agent,
            registry,
            middleware: MiddlewareChain::new(),
            plugins: PluginManager::new(),
            state: State::new(),
        }
    }

    /// Builder: adds `mw` to the middleware chain passed to every invocation context.
    pub fn with_middleware(mut self, mw: impl crate::middleware::Middleware + 'static) -> Self {
        self.middleware.add(Arc::new(mw));
        self
    }

    /// Builder: adds `plugin` to the plugin manager used around each agent run.
    pub fn with_plugin(mut self, plugin: impl Plugin + 'static) -> Self {
        self.plugins.add(Arc::new(plugin));
        self
    }

    /// Builder: replaces the runner's state with `state`.
    pub fn with_state(mut self, state: State) -> Self {
        self.state = state;
        self
    }

    /// Registers an additional agent so it can be resolved as a transfer target.
    ///
    /// Only `agent` itself is registered; its sub-agents are not walked here.
    pub fn register(&mut self, agent: Arc<dyn Agent>) {
        self.registry.register(agent);
    }

    /// Returns the registry used to resolve transfer targets.
    pub fn registry(&self) -> &AgentRegistry {
        &self.registry
    }

    /// Returns the agent that `run` starts with.
    pub fn root_agent(&self) -> &dyn Agent {
        self.root_agent.as_ref()
    }

    /// Runs the root agent, following transfers until an agent finishes or fails.
    ///
    /// For every agent in the chain this:
    /// 1. obtains a session from `connect_fn`,
    /// 2. merges a clone of the runner's state into the session state,
    /// 3. builds an `InvocationContext` carrying the middleware chain,
    /// 4. runs the plugins' before-run hooks and then the agent's `run_live`.
    ///
    /// Outcomes of `run_live`:
    /// - `Ok(())`: after-run hooks run, the session state is merged back, and
    ///   the method returns `Ok(())`.
    /// - `Err(AgentError::TransferRequested(name))`: `name` is resolved in the
    ///   registry, the session state is merged back, the session is
    ///   disconnected, and the loop continues with the target agent on a
    ///   fresh session.
    /// - any other error: the session state is merged back, the session is
    ///   disconnected, and the error is returned.
    ///
    /// # Errors
    ///
    /// - Whatever `connect_fn` returns on failure.
    /// - `AgentError::UnknownAgent` when a requested transfer target is not
    ///   registered. The current session is still cleaned up first.
    /// - Any non-transfer error produced by an agent's `run_live`.
    pub async fn run<F, Fut>(&self, connect_fn: F) -> Result<(), AgentError>
    where
        F: Fn(Arc<dyn Agent>) -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Result<AgentSession, AgentError>> + Send,
    {
        let mut current_agent = Arc::clone(&self.root_agent);
        let runner_state = self.state.clone();

        // The meaning of the second argument is defined by the telemetry
        // module and is not visible from this file.
        crate::telemetry::logging::log_agent_started(current_agent.name(), 0);

        loop {
            let agent_session = connect_fn(Arc::clone(&current_agent)).await?;

            // Make state accumulated so far visible to the agent about to run.
            agent_session.state().merge(&runner_state);

            let mut ctx =
                InvocationContext::with_middleware(agent_session.clone(), self.middleware.clone());
            self.plugins.run_before_run(&ctx).await;

            match current_agent.run_live(&mut ctx).await {
                Ok(()) => {
                    self.plugins.run_after_run(&ctx).await;
                    runner_state.merge(agent_session.state());
                    // NOTE(review): the success path does not call `disconnect()`;
                    // presumably the session is finished or torn down on drop — confirm.
                    break;
                }
                Err(AgentError::TransferRequested(target_name)) => {
                    let target = match self.registry.resolve(&target_name) {
                        Some(target) => target,
                        None => {
                            // Clean up exactly like the generic error arm so an
                            // unknown target neither drops state changes nor
                            // leaves the session connected.
                            runner_state.merge(agent_session.state());
                            let _ = agent_session.disconnect().await;
                            return Err(AgentError::UnknownAgent(target_name));
                        }
                    };

                    crate::telemetry::logging::log_agent_transfer(
                        current_agent.name(),
                        &target_name,
                    );
                    crate::telemetry::metrics::record_agent_transfer(
                        current_agent.name(),
                        &target_name,
                    );

                    // Carry state over to the next agent, then release this session.
                    runner_state.merge(agent_session.state());
                    // Best-effort: a failed disconnect must not abort the transfer.
                    let _ = agent_session.disconnect().await;
                    current_agent = target;
                }
                Err(e) => {
                    runner_state.merge(agent_session.state());
                    // Best-effort: report the agent's error, not the disconnect's.
                    let _ = agent_session.disconnect().await;
                    return Err(e);
                }
            }
        }

        Ok(())
    }

    /// Registers `agent` and then, depth-first, every agent returned by its
    /// `sub_agents()`.
    fn register_tree(registry: &mut AgentRegistry, agent: Arc<dyn Agent>) {
        registry.register(Arc::clone(&agent));
        for sub in agent.sub_agents() {
            Self::register_tree(registry, sub);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;
    use async_trait::async_trait;
    use rs_genai::session::{SessionHandle, SessionPhase, SessionState};
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::sync::{broadcast, mpsc, watch};

    // ----- test agents ------------------------------------------------------

    /// Agent whose `run_live` succeeds immediately.
    struct NoopAgent {
        name: String,
    }

    #[async_trait]
    impl Agent for NoopAgent {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Ok(())
        }
    }

    /// Agent that always requests a transfer to `target`.
    struct TransferAgent {
        name: String,
        target: String,
    }

    #[async_trait]
    impl Agent for TransferAgent {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let destination = self.target.clone();
            Err(AgentError::TransferRequested(destination))
        }

        fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
            Vec::new()
        }
    }

    /// Agent that asserts the context state holds `expected` under `key`.
    struct StateReaderAgent {
        name: String,
        key: String,
        expected: String,
    }

    #[async_trait]
    impl Agent for StateReaderAgent {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let stored = ctx.state().get::<String>(&self.key);
            assert_eq!(stored.as_deref(), Some(self.expected.as_str()));
            Ok(())
        }
    }

    /// Agent that always fails with `AgentError::Other("boom")`.
    struct FailingAgent;

    #[async_trait]
    impl Agent for FailingAgent {
        fn name(&self) -> &str {
            "failing"
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Err(AgentError::Other(String::from("boom")))
        }
    }

    // ----- fixtures ---------------------------------------------------------

    /// Builds a `NoopAgent` with the given name.
    fn noop(name: &str) -> NoopAgent {
        NoopAgent {
            name: name.to_owned(),
        }
    }

    /// Builds a `SessionHandle` over fresh in-memory channels, starting in
    /// `SessionPhase::Active`.
    fn mock_session_handle() -> SessionHandle {
        let (command_tx, _command_rx) = mpsc::channel(16);
        let (event_tx, _) = broadcast::channel(16);
        let (phase_tx, phase_rx) = watch::channel(SessionPhase::Active);
        let shared_state = Arc::new(SessionState::new(phase_tx));
        SessionHandle::new(command_tx, event_tx, shared_state, phase_rx)
    }

    /// Wraps a mock session handle in an `AgentSession`.
    fn mock_agent_session() -> AgentSession {
        let handle = mock_session_handle();
        AgentSession::new(handle)
    }

    /// Connector that ignores the agent and always yields a fresh mock session.
    ///
    /// Written as a plain `fn` returning an `async` block (rather than an
    /// `async fn`) so the agent argument is not captured by the future.
    fn connect_mock(
        _agent: Arc<dyn Agent>,
    ) -> impl std::future::Future<Output = Result<AgentSession, AgentError>> + Send {
        async { Ok(mock_agent_session()) }
    }

    // ----- tests ------------------------------------------------------------

    /// A single agent that returns `Ok` makes the runner return `Ok`.
    #[tokio::test]
    async fn runner_runs_single_agent() {
        let runner = Runner::new(noop("root"));

        let outcome = runner.run(connect_mock).await;

        assert!(outcome.is_ok());
    }

    /// A transfer connects a second session for the target agent.
    #[tokio::test]
    async fn runner_handles_transfer() {
        let mut runner = Runner::new(TransferAgent {
            name: String::from("root"),
            target: String::from("target"),
        });
        runner.register(Arc::new(noop("target")));

        let connect_count = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&connect_count);

        let outcome = runner
            .run(move |_agent| {
                let calls = Arc::clone(&counter);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok(mock_agent_session())
                }
            })
            .await;

        assert!(outcome.is_ok());
        // One connection for the root agent, one for the transfer target.
        assert_eq!(connect_count.load(Ordering::SeqCst), 2);
    }

    /// State written by the first agent is visible to the transfer target,
    /// even though each agent gets its own session.
    #[tokio::test]
    async fn runner_preserves_state_across_transfer() {
        struct SetAndTransferAgent;

        #[async_trait]
        impl Agent for SetAndTransferAgent {
            fn name(&self) -> &str {
                "agent_a"
            }

            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                ctx.state().set("greeting", "hello from A");
                Err(AgentError::TransferRequested(String::from("agent_b")))
            }
        }

        let reader = Arc::new(StateReaderAgent {
            name: String::from("agent_b"),
            key: String::from("greeting"),
            expected: String::from("hello from A"),
        });

        let mut runner = Runner::new(SetAndTransferAgent);
        runner.register(reader);

        let outcome = runner.run(connect_mock).await;

        assert!(outcome.is_ok());
    }

    /// Transferring to an unregistered name yields `UnknownAgent` with that name.
    #[tokio::test]
    async fn runner_fails_on_unknown_transfer_target() {
        let runner = Runner::new(TransferAgent {
            name: String::from("root"),
            target: String::from("nonexistent"),
        });

        let outcome = runner.run(connect_mock).await;

        match outcome {
            Err(AgentError::UnknownAgent(missing)) => assert_eq!(missing, "nonexistent"),
            unexpected => panic!("expected UnknownAgent, got: {:?}", unexpected),
        }
    }

    /// A non-transfer error from the agent is returned unchanged.
    #[tokio::test]
    async fn runner_propagates_errors() {
        let runner = Runner::new(FailingAgent);

        let outcome = runner.run(connect_mock).await;

        match outcome {
            Err(AgentError::Other(message)) => assert_eq!(message, "boom"),
            unexpected => panic!("expected Other error, got: {:?}", unexpected),
        }
    }

    /// State supplied through `with_state` is visible inside the agent run.
    #[tokio::test]
    async fn runner_with_initial_state() {
        struct StateCheckAgent;

        #[async_trait]
        impl Agent for StateCheckAgent {
            fn name(&self) -> &str {
                "checker"
            }

            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                let stored = ctx.state().get::<String>("initial_key");
                assert_eq!(stored.as_deref(), Some("initial_value"));
                Ok(())
            }
        }

        let seeded = State::new();
        seeded.set("initial_key", "initial_value");

        let runner = Runner::new(StateCheckAgent).with_state(seeded);
        let outcome = runner.run(connect_mock).await;

        assert!(outcome.is_ok());
    }

    /// Constructing a runner registers the root agent and its sub-agents.
    #[tokio::test]
    async fn runner_auto_registers_sub_agents() {
        struct ParentAgent;

        #[async_trait]
        impl Agent for ParentAgent {
            fn name(&self) -> &str {
                "parent"
            }

            async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
                Ok(())
            }

            fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
                vec![Arc::new(noop("child_a")), Arc::new(noop("child_b"))]
            }
        }

        let runner = Runner::new(ParentAgent);

        for name in ["parent", "child_a", "child_b"] {
            assert!(runner.registry().resolve(name).is_some());
        }
    }
}