Skip to main content

rs_adk/
context.rs

1//! InvocationContext — the session state container flowing through agent execution.
2//!
3//! Also provides `CallbackContext` and `ToolContext` wrappers for richer
4//! access patterns in callbacks and tool execution.
5
6use std::sync::Arc;
7
8use tokio::sync::broadcast;
9
10use crate::agent_session::AgentSession;
11use crate::confirmation::ToolConfirmation;
12use crate::events::EventActions;
13use crate::middleware::MiddlewareChain;
14use crate::run_config::RunConfig;
15
/// Events emitted by agents during live execution.
///
/// Wraps [`rs_genai::session::SessionEvent`] (Layer 0) and adds agent-specific
/// events. There are no duplicate variants: wire-level events are delivered
/// only through [`AgentEvent::Session`].
#[derive(Debug, Clone)]
pub enum AgentEvent {
    /// Passthrough of wire-level session events (text, audio, turn lifecycle).
    Session(rs_genai::session::SessionEvent),

    // ── Agent lifecycle ──
    /// An agent has started execution.
    AgentStarted {
        /// Name of the agent that started.
        name: String,
    },
    /// An agent has completed execution.
    AgentCompleted {
        /// Name of the agent that completed.
        name: String,
    },

    // ── Tool execution ──
    /// A tool call has started execution.
    ToolCallStarted {
        /// Tool name.
        name: String,
        /// Tool call arguments.
        args: serde_json::Value,
    },
    /// A tool call completed successfully.
    ToolCallCompleted {
        /// Tool name.
        name: String,
        /// The tool's return value.
        result: serde_json::Value,
        /// How long the tool call took.
        duration: std::time::Duration,
    },
    /// A tool call failed.
    ToolCallFailed {
        /// Tool name.
        name: String,
        /// Error description.
        error: String,
    },
    /// A streaming tool yielded an intermediate value.
    StreamingToolYield {
        /// Tool name.
        name: String,
        /// The yielded value.
        value: serde_json::Value,
    },

    // ── Control flow and state ──
    /// An agent transferred control to another agent.
    AgentTransfer {
        /// Source agent name.
        from: String,
        /// Target agent name.
        to: String,
    },
    /// A state key was changed.
    StateChanged {
        /// The key that was modified.
        key: String,
    },
    /// A loop agent completed an iteration.
    LoopIteration {
        /// The zero-based iteration number.
        iteration: u32,
    },
    /// An agent timed out.
    Timeout,
    /// A route agent selected a branch.
    RouteSelected {
        /// The name of the selected agent.
        agent_name: String,
    },
    /// A fallback agent activated a fallback branch.
    FallbackActivated {
        /// The name of the fallback agent that was activated.
        agent_name: String,
    },
}
94
/// The context object that flows through agent execution.
///
/// Holds everything a running agent needs: the live session, the event bus,
/// the middleware chain, the run configuration, and optional per-invocation
/// services.
///
/// Note: state is accessed via `agent_session.state()`, the single source of
/// truth; there is deliberately no separate state field here.
pub struct InvocationContext {
    /// AgentSession wraps SessionHandle with fan-out + middleware.
    /// Replaces LiveSender — sends go directly through SessionHandle (one hop).
    pub agent_session: AgentSession,

    /// Event bus — agents emit events here, application code subscribes.
    pub event_tx: broadcast::Sender<AgentEvent>,

    /// Middleware chain for lifecycle hooks.
    pub middleware: MiddlewareChain,

    /// Configuration for this run.
    pub run_config: RunConfig,

    /// Session ID for session-aware runs; `None` when the run is not tied to a
    /// stored session.
    pub session_id: Option<String>,

    /// Optional artifact service for this invocation.
    pub artifact_service: Option<Arc<dyn crate::artifacts::ArtifactService>>,
    /// Optional memory service for this invocation.
    pub memory_service: Option<Arc<dyn crate::memory::MemoryService>>,
    /// Optional session service for this invocation.
    pub session_service: Option<Arc<dyn crate::session::SessionService>>,
}
123
124impl InvocationContext {
125    /// Create a new invocation context with an empty middleware chain.
126    pub fn new(agent_session: AgentSession) -> Self {
127        let (event_tx, _) = broadcast::channel(256);
128        Self {
129            agent_session,
130            event_tx,
131            middleware: MiddlewareChain::new(),
132            run_config: RunConfig::default(),
133            session_id: None,
134            artifact_service: None,
135            memory_service: None,
136            session_service: None,
137        }
138    }
139
140    /// Create a new invocation context with a pre-configured middleware chain.
141    pub fn with_middleware(agent_session: AgentSession, middleware: MiddlewareChain) -> Self {
142        let (event_tx, _) = broadcast::channel(256);
143        Self {
144            agent_session,
145            event_tx,
146            middleware,
147            run_config: RunConfig::default(),
148            session_id: None,
149            artifact_service: None,
150            memory_service: None,
151            session_service: None,
152        }
153    }
154
155    /// Emit an event to all subscribers.
156    pub fn emit(&self, event: AgentEvent) {
157        let _ = self.event_tx.send(event);
158    }
159
160    /// Subscribe to agent events.
161    pub fn subscribe(&self) -> broadcast::Receiver<AgentEvent> {
162        self.event_tx.subscribe()
163    }
164
165    /// Convenience: access the state container.
166    pub fn state(&self) -> &crate::state::State {
167        self.agent_session.state()
168    }
169
170    /// Set the artifact service for this invocation.
171    pub fn with_artifact_service(
172        mut self,
173        service: Arc<dyn crate::artifacts::ArtifactService>,
174    ) -> Self {
175        self.artifact_service = Some(service);
176        self
177    }
178
179    /// Set the memory service for this invocation.
180    pub fn with_memory_service(mut self, service: Arc<dyn crate::memory::MemoryService>) -> Self {
181        self.memory_service = Some(service);
182        self
183    }
184
185    /// Set the session service for this invocation.
186    pub fn with_session_service(
187        mut self,
188        service: Arc<dyn crate::session::SessionService>,
189    ) -> Self {
190        self.session_service = Some(service);
191        self
192    }
193}
194
195// ── CallbackContext ────────────────────────────────────────────────────────
196
/// Rich context for callbacks.
///
/// Gives a callback read access to session state (via `state()`) and to the
/// wrapped [`InvocationContext`] — through which the optional artifact,
/// memory, and session services are reached — plus a mutable
/// [`EventActions`] the callback can populate.
pub struct CallbackContext<'a> {
    /// The invocation this callback runs within (borrowed, not owned).
    ctx: &'a InvocationContext,
    /// Event actions that the callback can populate (e.g., state_delta, transfer).
    pub event_actions: EventActions,
}
204
205impl<'a> CallbackContext<'a> {
206    /// Create a new callback context wrapping an invocation context.
207    pub fn new(ctx: &'a InvocationContext) -> Self {
208        Self {
209            ctx,
210            event_actions: EventActions::default(),
211        }
212    }
213
214    /// Access the state container.
215    pub fn state(&self) -> &crate::state::State {
216        self.ctx.state()
217    }
218
219    /// Get the invocation context's session ID, if any.
220    pub fn session_id(&self) -> Option<&str> {
221        self.ctx.session_id.as_deref()
222    }
223
224    /// Access the underlying invocation context.
225    pub fn invocation_context(&self) -> &InvocationContext {
226        self.ctx
227    }
228}
229
230// ── ToolContext ─────────────────────────────────────────────────────────────
231
/// Extended context for tool execution.
///
/// Builds on [`CallbackContext`] and adds the ID of the function call being
/// executed plus an optional user confirmation for that call.
pub struct ToolContext<'a> {
    /// The underlying callback context (provides state, event_actions, etc.).
    pub callback: CallbackContext<'a>,
    /// The ID of the function call being executed, if the caller supplied one.
    pub function_call_id: Option<String>,
    /// User confirmation for this tool call, if applicable.
    pub confirmation: Option<ToolConfirmation>,
}
241
242impl<'a> ToolContext<'a> {
243    /// Create a new tool context.
244    pub fn new(ctx: &'a InvocationContext, function_call_id: Option<String>) -> Self {
245        Self {
246            callback: CallbackContext::new(ctx),
247            function_call_id,
248            confirmation: None,
249        }
250    }
251
252    /// Access the state container.
253    pub fn state(&self) -> &crate::state::State {
254        self.callback.state()
255    }
256
257    /// Set the confirmation for this tool call.
258    pub fn with_confirmation(mut self, confirmation: ToolConfirmation) -> Self {
259        self.confirmation = Some(confirmation);
260        self
261    }
262}
263
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's private imports (`Arc`,
    // `broadcast`, `AgentSession`, `MiddlewareChain`, `ToolConfirmation`).
    use super::*;

    /// Build an `AgentSession` backed by the crate's mock session writer.
    fn make_session() -> AgentSession {
        let (evt_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> =
            Arc::new(crate::test_helpers::MockWriter);
        AgentSession::from_writer(writer, evt_tx)
    }

    /// Build a default `InvocationContext` over a mock session.
    fn make_ctx() -> InvocationContext {
        InvocationContext::new(make_session())
    }

    #[test]
    fn agent_event_is_send_and_clone() {
        fn assert_send_clone<T: Send + Clone>() {}
        assert_send_clone::<AgentEvent>();
    }

    #[test]
    fn invocation_context_has_default_run_config() {
        let ctx = make_ctx();

        assert_eq!(ctx.run_config.max_llm_calls, 500);
        assert!(ctx.session_id.is_none());
    }

    #[test]
    fn with_middleware_uses_same_defaults_as_new() {
        let ctx = InvocationContext::with_middleware(make_session(), MiddlewareChain::new());

        assert_eq!(ctx.run_config.max_llm_calls, 500);
        assert!(ctx.session_id.is_none());
        assert!(ctx.artifact_service.is_none());
        assert!(ctx.memory_service.is_none());
        assert!(ctx.session_service.is_none());
    }

    #[test]
    fn emit_without_subscribers_does_not_panic() {
        let ctx = make_ctx();
        ctx.emit(AgentEvent::Timeout);
    }

    #[test]
    fn subscriber_receives_emitted_event() {
        let ctx = make_ctx();
        let mut rx = ctx.subscribe();

        ctx.emit(AgentEvent::AgentStarted {
            name: "root".to_string(),
        });

        match rx.try_recv() {
            Ok(AgentEvent::AgentStarted { name }) => assert_eq!(name, "root"),
            other => panic!("unexpected receive result: {:?}", other),
        }
    }

    #[test]
    fn callback_context_state_access() {
        let ctx = make_ctx();
        ctx.state().set("key", "value");

        let cb_ctx = CallbackContext::new(&ctx);
        assert_eq!(
            cb_ctx.state().get::<String>("key"),
            Some("value".to_string())
        );
        assert!(cb_ctx.session_id().is_none());
    }

    #[test]
    fn tool_context_wraps_callback_context() {
        let ctx = make_ctx();
        ctx.state().set("x", 42);

        let tool_ctx = ToolContext::new(&ctx, Some("call-1".to_string()));
        assert_eq!(tool_ctx.state().get::<i32>("x"), Some(42));
        assert_eq!(tool_ctx.function_call_id.as_deref(), Some("call-1"));
        assert!(tool_ctx.confirmation.is_none());
    }

    #[test]
    fn tool_context_with_confirmation() {
        let ctx = make_ctx();

        let tool_ctx =
            ToolContext::new(&ctx, None).with_confirmation(ToolConfirmation::confirmed());
        assert!(tool_ctx.confirmation.as_ref().unwrap().confirmed);
    }
}