// rs_adk/runner.rs

//! Runner — orchestrates agent execution across Gemini Live sessions.
//!
//! Handles the full lifecycle: connect → run agent → handle transfer → reconnect → repeat.
//! Transfer is complex in Gemini Live because tools are fixed at session setup: changing
//! agents means changing sessions (unlike traditional ADK, where tools can change per call).

7use std::sync::Arc;
8
9use crate::agent::Agent;
10use crate::agent_session::AgentSession;
11use crate::context::InvocationContext;
12use crate::error::AgentError;
13use crate::middleware::MiddlewareChain;
14use crate::plugin::{Plugin, PluginManager};
15use crate::router::AgentRegistry;
16use crate::state::State;
17
18/// Orchestrates agent execution across Gemini Live sessions.
19///
20/// Handles the full lifecycle: connect → run → transfer → reconnect → repeat.
21/// Transfer is complex in Gemini Live because tools are fixed at setup —
22/// changing agents means changing sessions.
23///
24/// # Example
25///
26/// ```ignore
27/// let runner = Runner::new(root_agent);
28///
29/// runner.run(|agent| async move {
30///     let config = SessionConfig::new(&api_key)
31///         .model(GeminiModel::GeminiLive2_5FlashNativeAudio);
32///     // Add agent's tools to config
33///     let session = connect(config, TransportConfig::default()).await?;
34///     Ok(AgentSession::new(session))
35/// }).await?;
36/// ```
37pub struct Runner {
38    root_agent: Arc<dyn Agent>,
39    registry: AgentRegistry,
40    middleware: MiddlewareChain,
41    plugins: PluginManager,
42    state: State,
43}
44
45impl Runner {
46    /// Create a new Runner with a root agent.
47    ///
48    /// Automatically registers the root agent and all sub-agents recursively.
49    pub fn new(root_agent: impl Agent + 'static) -> Self {
50        let agent = Arc::new(root_agent);
51        let mut registry = AgentRegistry::new();
52        Self::register_tree(&mut registry, agent.clone());
53        Self {
54            root_agent: agent,
55            registry,
56            middleware: MiddlewareChain::new(),
57            plugins: PluginManager::new(),
58            state: State::new(),
59        }
60    }
61
62    /// Create a Runner from an already-Arc'd agent.
63    pub fn from_arc(root_agent: Arc<dyn Agent>) -> Self {
64        let mut registry = AgentRegistry::new();
65        Self::register_tree(&mut registry, root_agent.clone());
66        Self {
67            root_agent,
68            registry,
69            middleware: MiddlewareChain::new(),
70            plugins: PluginManager::new(),
71            state: State::new(),
72        }
73    }
74
75    /// Add middleware to the runner (applied to all agent invocations).
76    pub fn with_middleware(mut self, mw: impl crate::middleware::Middleware + 'static) -> Self {
77        self.middleware.add(Arc::new(mw));
78        self
79    }
80
81    /// Add a plugin to the runner.
82    pub fn with_plugin(mut self, plugin: impl Plugin + 'static) -> Self {
83        self.plugins.add(Arc::new(plugin));
84        self
85    }
86
87    /// Set initial state (available to all agents).
88    pub fn with_state(mut self, state: State) -> Self {
89        self.state = state;
90        self
91    }
92
93    /// Manually register an additional agent (useful for cross-tree transfers).
94    pub fn register(&mut self, agent: Arc<dyn Agent>) {
95        self.registry.register(agent);
96    }
97
98    /// Access the agent registry.
99    pub fn registry(&self) -> &AgentRegistry {
100        &self.registry
101    }
102
103    /// Access the root agent.
104    pub fn root_agent(&self) -> &dyn Agent {
105        self.root_agent.as_ref()
106    }
107
108    /// Run the agent lifecycle. Handles transfers automatically.
109    ///
110    /// `connect_fn` is a factory that creates a new AgentSession for a given agent.
111    /// This allows the Runner to reconnect with different configs on agent transfer
112    /// (different tools/instructions → different Gemini Live session).
113    ///
114    /// The Runner will:
115    /// 1. Call `connect_fn` with the current agent
116    /// 2. Run `agent.run_live()` on the resulting session
117    /// 3. If `TransferRequested` is returned, resolve the target agent,
118    ///    disconnect, preserve state, and loop back to step 1
119    /// 4. If the agent completes normally, return Ok(())
120    pub async fn run<F, Fut>(&self, connect_fn: F) -> Result<(), AgentError>
121    where
122        F: Fn(Arc<dyn Agent>) -> Fut + Send + Sync,
123        Fut: std::future::Future<Output = Result<AgentSession, AgentError>> + Send,
124    {
125        let mut current_agent = self.root_agent.clone();
126        let runner_state = self.state.clone();
127
128        // Telemetry
129        crate::telemetry::logging::log_agent_started(
130            current_agent.name(),
131            0, // runner doesn't have tools
132        );
133
134        loop {
135            // Connect with current agent's config
136            let agent_session = connect_fn(current_agent.clone()).await?;
137
138            // Merge runner state into session state
139            agent_session.state().merge(&runner_state);
140
141            // Create invocation context with runner's middleware
142            let mut ctx =
143                InvocationContext::with_middleware(agent_session.clone(), self.middleware.clone());
144
145            // Run before_run plugins
146            self.plugins.run_before_run(&ctx).await;
147
148            // Run the agent
149            match current_agent.run_live(&mut ctx).await {
150                Ok(()) => {
151                    // Run after_run plugins
152                    self.plugins.run_after_run(&ctx).await;
153                    // Agent completed normally — preserve state and return
154                    runner_state.merge(agent_session.state());
155                    break;
156                }
157                Err(AgentError::TransferRequested(target_name)) => {
158                    // Resolve target agent
159                    let target = self
160                        .registry
161                        .resolve(&target_name)
162                        .ok_or_else(|| AgentError::UnknownAgent(target_name.clone()))?;
163
164                    crate::telemetry::logging::log_agent_transfer(
165                        current_agent.name(),
166                        &target_name,
167                    );
168                    crate::telemetry::metrics::record_agent_transfer(
169                        current_agent.name(),
170                        &target_name,
171                    );
172
173                    // Preserve state across transfer
174                    runner_state.merge(agent_session.state());
175
176                    // Disconnect current session
177                    let _ = agent_session.disconnect().await;
178
179                    // Switch to target agent
180                    current_agent = target;
181                    continue;
182                }
183                Err(e) => {
184                    // Other error — preserve state and propagate
185                    runner_state.merge(agent_session.state());
186                    let _ = agent_session.disconnect().await;
187                    return Err(e);
188                }
189            }
190        }
191
192        Ok(())
193    }
194
195    /// Recursively register an agent and all its sub-agents.
196    fn register_tree(registry: &mut AgentRegistry, agent: Arc<dyn Agent>) {
197        registry.register(agent.clone());
198        for sub in agent.sub_agents() {
199            Self::register_tree(registry, sub);
200        }
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;
    use async_trait::async_trait;
    use rs_genai::session::{SessionHandle, SessionPhase, SessionState};
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::sync::{broadcast, mpsc, watch};

    // ----- mock agents -------------------------------------------------------

    /// Completes immediately without touching the session.
    struct NoopAgent {
        name: String,
    }

    #[async_trait]
    impl Agent for NoopAgent {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Ok(())
        }
    }

    /// Shorthand for a `NoopAgent` called `name`.
    fn noop(name: &str) -> NoopAgent {
        NoopAgent {
            name: name.to_owned(),
        }
    }

    /// Always asks the runner to hand off to `target`.
    struct TransferAgent {
        name: String,
        target: String,
    }

    #[async_trait]
    impl Agent for TransferAgent {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let requested = self.target.clone();
            Err(AgentError::TransferRequested(requested))
        }

        fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
            Vec::new()
        }
    }

    /// Asserts that `key` holds the string `expected` in the session state.
    struct StateReaderAgent {
        name: String,
        key: String,
        expected: String,
    }

    #[async_trait]
    impl Agent for StateReaderAgent {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let observed = ctx.state().get::<String>(&self.key);
            assert_eq!(observed.as_deref(), Some(&*self.expected));
            Ok(())
        }
    }

    /// Fails with a non-transfer error.
    struct FailingAgent;

    #[async_trait]
    impl Agent for FailingAgent {
        fn name(&self) -> &str {
            "failing"
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Err(AgentError::Other(String::from("boom")))
        }
    }

    // ----- session fixtures --------------------------------------------------

    /// A `SessionHandle` in the `Active` phase. Its command receiver is dropped
    /// when this function returns.
    fn mock_session_handle() -> SessionHandle {
        let (command_tx, _command_rx) = mpsc::channel(16);
        let (event_tx, _) = broadcast::channel(16);
        let (phase_tx, phase_rx) = watch::channel(SessionPhase::Active);
        let shared_state = Arc::new(SessionState::new(phase_tx));
        SessionHandle::new(command_tx, event_tx, shared_state, phase_rx)
    }

    fn mock_agent_session() -> AgentSession {
        AgentSession::new(mock_session_handle())
    }

    /// Runs `runner` with a connect factory that returns a fresh mock session each call.
    async fn run_with_mock_sessions(runner: &Runner) -> Result<(), AgentError> {
        runner
            .run(|_agent| async { Ok(mock_agent_session()) })
            .await
    }

    // ----- tests -------------------------------------------------------------

    #[tokio::test]
    async fn runner_runs_single_agent() {
        let runner = Runner::new(noop("root"));

        let outcome = run_with_mock_sessions(&runner).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn runner_handles_transfer() {
        // "root" hands off to "target". TransferAgent declares no sub-agents,
        // so the target has to be registered by hand.
        let mut runner = Runner::new(TransferAgent {
            name: "root".into(),
            target: "target".into(),
        });
        runner.register(Arc::new(noop("target")));

        let connects = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&connects);

        let outcome = runner
            .run(move |_agent| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok(mock_agent_session())
                }
            })
            .await;

        assert!(outcome.is_ok());
        // One connection for "root", one for "target".
        assert_eq!(connects.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn runner_preserves_state_across_transfer() {
        // agent_a writes "greeting" and hands off; agent_b asserts it can read it.
        struct SetAndTransferAgent;

        #[async_trait]
        impl Agent for SetAndTransferAgent {
            fn name(&self) -> &str {
                "agent_a"
            }

            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                ctx.state().set("greeting", "hello from A");
                Err(AgentError::TransferRequested("agent_b".to_string()))
            }
        }

        let reader = StateReaderAgent {
            name: "agent_b".into(),
            key: "greeting".into(),
            expected: "hello from A".into(),
        };

        let mut runner = Runner::new(SetAndTransferAgent);
        runner.register(Arc::new(reader));

        assert!(run_with_mock_sessions(&runner).await.is_ok());
    }

    #[tokio::test]
    async fn runner_fails_on_unknown_transfer_target() {
        let runner = Runner::new(TransferAgent {
            name: "root".into(),
            target: "nonexistent".into(),
        });

        match run_with_mock_sessions(&runner).await {
            Err(AgentError::UnknownAgent(missing)) => assert_eq!(missing, "nonexistent"),
            other => panic!("expected UnknownAgent, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn runner_propagates_errors() {
        let runner = Runner::new(FailingAgent);

        match run_with_mock_sessions(&runner).await {
            Err(AgentError::Other(message)) => assert_eq!(message, "boom"),
            other => panic!("expected Other error, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn runner_with_initial_state() {
        struct StateCheckAgent;

        #[async_trait]
        impl Agent for StateCheckAgent {
            fn name(&self) -> &str {
                "checker"
            }

            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                let seeded = ctx.state().get::<String>("initial_key");
                assert_eq!(seeded.as_deref(), Some("initial_value"));
                Ok(())
            }
        }

        let initial = State::new();
        initial.set("initial_key", "initial_value");
        let runner = Runner::new(StateCheckAgent).with_state(initial);

        assert!(run_with_mock_sessions(&runner).await.is_ok());
    }

    #[tokio::test]
    async fn runner_auto_registers_sub_agents() {
        struct ParentAgent;

        #[async_trait]
        impl Agent for ParentAgent {
            fn name(&self) -> &str {
                "parent"
            }

            async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
                Ok(())
            }

            fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
                vec![Arc::new(noop("child_a")), Arc::new(noop("child_b"))]
            }
        }

        let runner = Runner::new(ParentAgent);
        for name in ["parent", "child_a", "child_b"] {
            assert!(runner.registry().resolve(name).is_some());
        }
    }
}