Skip to main content

adk_core/
context.rs

1use crate::identity::{AdkIdentity, AppName, ExecutionIdentity, InvocationId, SessionId, UserId};
2use crate::{AdkError, Agent, Result, types::Content};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::{BTreeSet, HashMap};
7use std::sync::Arc;
8
9#[async_trait]
10pub trait ReadonlyContext: Send + Sync {
11    fn invocation_id(&self) -> &str;
12    fn agent_name(&self) -> &str;
13    fn user_id(&self) -> &str;
14    fn app_name(&self) -> &str;
15    fn session_id(&self) -> &str;
16    fn branch(&self) -> &str;
17    fn user_content(&self) -> &Content;
18
19    /// Returns the application name as a typed [`AppName`].
20    ///
21    /// Parses the value returned by [`app_name()`](Self::app_name). Returns an
22    /// error if the raw string fails validation (empty, null bytes, or exceeds
23    /// the maximum length).
24    ///
25    /// # Errors
26    ///
27    /// Returns an error when the
28    /// underlying string is not a valid identifier.
29    fn try_app_name(&self) -> Result<AppName> {
30        Ok(AppName::try_from(self.app_name())?)
31    }
32
33    /// Returns the user identifier as a typed [`UserId`].
34    ///
35    /// Parses the value returned by [`user_id()`](Self::user_id). Returns an
36    /// error if the raw string fails validation.
37    ///
38    /// # Errors
39    ///
40    /// Returns an error when the
41    /// underlying string is not a valid identifier.
42    fn try_user_id(&self) -> Result<UserId> {
43        Ok(UserId::try_from(self.user_id())?)
44    }
45
46    /// Returns the session identifier as a typed [`SessionId`].
47    ///
48    /// Parses the value returned by [`session_id()`](Self::session_id).
49    /// Returns an error if the raw string fails validation.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error when the
54    /// underlying string is not a valid identifier.
55    fn try_session_id(&self) -> Result<SessionId> {
56        Ok(SessionId::try_from(self.session_id())?)
57    }
58
59    /// Returns the invocation identifier as a typed [`InvocationId`].
60    ///
61    /// Parses the value returned by [`invocation_id()`](Self::invocation_id).
62    /// Returns an error if the raw string fails validation.
63    ///
64    /// # Errors
65    ///
66    /// Returns an error when the
67    /// underlying string is not a valid identifier.
68    fn try_invocation_id(&self) -> Result<InvocationId> {
69        Ok(InvocationId::try_from(self.invocation_id())?)
70    }
71
72    /// Returns the stable session-scoped [`AdkIdentity`] triple.
73    ///
74    /// Combines [`try_app_name()`](Self::try_app_name),
75    /// [`try_user_id()`](Self::try_user_id), and
76    /// [`try_session_id()`](Self::try_session_id) into a single composite
77    /// identity value.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if any of the three constituent identifiers fail
82    /// validation.
83    fn try_identity(&self) -> Result<AdkIdentity> {
84        Ok(AdkIdentity {
85            app_name: self.try_app_name()?,
86            user_id: self.try_user_id()?,
87            session_id: self.try_session_id()?,
88        })
89    }
90
91    /// Returns the full per-invocation [`ExecutionIdentity`].
92    ///
93    /// Combines [`try_identity()`](Self::try_identity) with the invocation,
94    /// branch, and agent name from this context.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if any of the four typed identifiers fail validation.
99    fn try_execution_identity(&self) -> Result<ExecutionIdentity> {
100        Ok(ExecutionIdentity {
101            adk: self.try_identity()?,
102            invocation_id: self.try_invocation_id()?,
103            branch: self.branch().to_string(),
104            agent_name: self.agent_name().to_string(),
105        })
106    }
107}
108
109// State management traits
110
/// Maximum allowed length for state keys, in bytes (currently 256).
pub const MAX_STATE_KEY_LEN: usize = 256;

// The length error in `validate_state_key` must be a `&'static str`, so it
// spells out "256 bytes" literally. Break the build if the constant changes
// without the message being updated.
const _: () = assert!(
    MAX_STATE_KEY_LEN == 256,
    "MAX_STATE_KEY_LEN changed: update the length error message in validate_state_key"
);

/// Validates a state key.
///
/// Rules, checked in this order:
/// - Must not be empty
/// - Must not exceed [`MAX_STATE_KEY_LEN`] bytes
/// - Must not contain path separators (`/`, `\`) or `..` anywhere
/// - Must not contain null bytes
///
/// # Errors
///
/// Returns a static description of the first rule the key violates.
pub fn validate_state_key(key: &str) -> std::result::Result<(), &'static str> {
    if key.is_empty() {
        return Err("state key must not be empty");
    }
    // `len()` is the UTF-8 byte length, matching the documented byte limit.
    if key.len() > MAX_STATE_KEY_LEN {
        return Err("state key exceeds maximum length of 256 bytes");
    }
    if key.contains('/') || key.contains('\\') || key.contains("..") {
        return Err("state key must not contain path separators or '..'");
    }
    if key.contains('\0') {
        return Err("state key must not contain null bytes");
    }
    Ok(())
}
136
137pub trait State: Send + Sync {
138    fn get(&self, key: &str) -> Option<Value>;
139    /// Set a state value. Implementations should call [`validate_state_key`] and
140    /// reject invalid keys (e.g., by logging a warning or panicking).
141    fn set(&mut self, key: String, value: Value);
142    fn all(&self) -> HashMap<String, Value>;
143}
144
145pub trait ReadonlyState: Send + Sync {
146    fn get(&self, key: &str) -> Option<Value>;
147    fn all(&self) -> HashMap<String, Value>;
148}
149
150// Session trait
151pub trait Session: Send + Sync {
152    fn id(&self) -> &str;
153    fn app_name(&self) -> &str;
154    fn user_id(&self) -> &str;
155    fn state(&self) -> &dyn State;
156    /// Returns the conversation history from this session as Content items
157    fn conversation_history(&self) -> Vec<Content>;
158    /// Returns conversation history filtered for a specific agent.
159    ///
160    /// When provided, events authored by other agents (not "user", not the
161    /// named agent, and not function/tool responses) are excluded. This
162    /// prevents a transferred sub-agent from seeing the parent's tool calls
163    /// mapped as "model" role, which would cause the LLM to think work is
164    /// already done.
165    ///
166    /// Default implementation delegates to [`conversation_history`](Self::conversation_history).
167    fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
168        self.conversation_history()
169    }
170    /// Append content to conversation history (for sequential agent support)
171    fn append_to_history(&self, _content: Content) {
172        // Default no-op - implementations can override to track history
173    }
174
175    /// Returns the application name as a typed [`AppName`].
176    ///
177    /// Parses the value returned by [`app_name()`](Self::app_name). Returns an
178    /// error if the raw string fails validation (empty, null bytes, or exceeds
179    /// the maximum length).
180    ///
181    /// # Errors
182    ///
183    /// Returns an error when the
184    /// underlying string is not a valid identifier.
185    fn try_app_name(&self) -> Result<AppName> {
186        Ok(AppName::try_from(self.app_name())?)
187    }
188
189    /// Returns the user identifier as a typed [`UserId`].
190    ///
191    /// Parses the value returned by [`user_id()`](Self::user_id). Returns an
192    /// error if the raw string fails validation.
193    ///
194    /// # Errors
195    ///
196    /// Returns an error when the
197    /// underlying string is not a valid identifier.
198    fn try_user_id(&self) -> Result<UserId> {
199        Ok(UserId::try_from(self.user_id())?)
200    }
201
202    /// Returns the session identifier as a typed [`SessionId`].
203    ///
204    /// Parses the value returned by [`id()`](Self::id). Returns an error if
205    /// the raw string fails validation.
206    ///
207    /// # Errors
208    ///
209    /// Returns an error when the
210    /// underlying string is not a valid identifier.
211    fn try_session_id(&self) -> Result<SessionId> {
212        Ok(SessionId::try_from(self.id())?)
213    }
214
215    /// Returns the stable session-scoped [`AdkIdentity`] triple.
216    ///
217    /// Combines [`try_app_name()`](Self::try_app_name),
218    /// [`try_user_id()`](Self::try_user_id), and
219    /// [`try_session_id()`](Self::try_session_id) into a single composite
220    /// identity value.
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if any of the three constituent identifiers fail
225    /// validation.
226    fn try_identity(&self) -> Result<AdkIdentity> {
227        Ok(AdkIdentity {
228            app_name: self.try_app_name()?,
229            user_id: self.try_user_id()?,
230            session_id: self.try_session_id()?,
231        })
232    }
233}
234
235/// Structured metadata about a completed tool execution.
236///
237/// Available via [`CallbackContext::tool_outcome()`] in after-tool callbacks,
238/// plugins, and telemetry hooks. Provides structured access to execution
239/// results without requiring JSON error parsing.
240///
241/// # Fields
242///
243/// - `tool_name` — Name of the tool that was executed.
244/// - `tool_args` — Arguments passed to the tool as a JSON value.
245/// - `success` — Whether the tool execution succeeded. Derived from the
246///   Rust `Result` / timeout path, never from JSON content inspection.
247/// - `duration` — Wall-clock duration of the tool execution.
248/// - `error_message` — Error message if the tool failed; `None` on success.
249/// - `attempt` — Retry attempt number (0 = first attempt, 1 = first retry, etc.).
250///   Always 0 when retries are not configured.
251#[derive(Debug, Clone)]
252pub struct ToolOutcome {
253    /// Name of the tool that was executed.
254    pub tool_name: String,
255    /// Arguments passed to the tool (JSON value).
256    pub tool_args: serde_json::Value,
257    /// Whether the tool execution succeeded.
258    pub success: bool,
259    /// Wall-clock duration of the tool execution.
260    pub duration: std::time::Duration,
261    /// Error message if the tool failed. `None` on success.
262    pub error_message: Option<String>,
263    /// Retry attempt number (0 = first attempt, 1 = first retry, etc.).
264    /// Always 0 when retries are not configured.
265    pub attempt: u32,
266}
267
268#[async_trait]
269pub trait CallbackContext: ReadonlyContext {
270    fn artifacts(&self) -> Option<Arc<dyn Artifacts>>;
271
272    /// Returns structured metadata about the most recent tool execution.
273    /// Available in after-tool callbacks and plugin hooks.
274    /// Returns `None` when not in a tool execution context.
275    fn tool_outcome(&self) -> Option<ToolOutcome> {
276        None // default for backward compatibility
277    }
278
279    /// Returns the name of the tool about to be executed.
280    /// Available in before-tool and after-tool callback contexts.
281    fn tool_name(&self) -> Option<&str> {
282        None
283    }
284
285    /// Returns the input arguments for the tool about to be executed.
286    /// Available in before-tool and after-tool callback contexts.
287    fn tool_input(&self) -> Option<&serde_json::Value> {
288        None
289    }
290
291    /// Returns the shared state for parallel agent coordination.
292    /// Returns `None` when not running inside a `ParallelAgent` with shared state enabled.
293    fn shared_state(&self) -> Option<Arc<crate::SharedState>> {
294        None
295    }
296}
297
298/// Wraps a [`CallbackContext`] to inject tool name and input for before-tool
299/// and after-tool callbacks.
300///
301/// Used by the agent runtime to provide tool context to `BeforeToolCallback`
302/// and `AfterToolCallback` invocations.
303///
304/// # Example
305///
306/// ```rust,ignore
307/// let tool_ctx = Arc::new(ToolCallbackContext::new(
308///     ctx.clone(),
309///     "search".to_string(),
310///     serde_json::json!({"query": "hello"}),
311/// ));
312/// callback(tool_ctx as Arc<dyn CallbackContext>).await;
313/// ```
314pub struct ToolCallbackContext {
315    /// The inner callback context to delegate to.
316    pub inner: Arc<dyn CallbackContext>,
317    /// The name of the tool being executed.
318    pub tool_name: String,
319    /// The input arguments for the tool being executed.
320    pub tool_input: serde_json::Value,
321}
322
323impl ToolCallbackContext {
324    /// Creates a new `ToolCallbackContext` wrapping the given inner context.
325    pub fn new(
326        inner: Arc<dyn CallbackContext>,
327        tool_name: String,
328        tool_input: serde_json::Value,
329    ) -> Self {
330        Self { inner, tool_name, tool_input }
331    }
332}
333
334#[async_trait]
335impl ReadonlyContext for ToolCallbackContext {
336    fn invocation_id(&self) -> &str {
337        self.inner.invocation_id()
338    }
339
340    fn agent_name(&self) -> &str {
341        self.inner.agent_name()
342    }
343
344    fn user_id(&self) -> &str {
345        self.inner.user_id()
346    }
347
348    fn app_name(&self) -> &str {
349        self.inner.app_name()
350    }
351
352    fn session_id(&self) -> &str {
353        self.inner.session_id()
354    }
355
356    fn branch(&self) -> &str {
357        self.inner.branch()
358    }
359
360    fn user_content(&self) -> &Content {
361        self.inner.user_content()
362    }
363}
364
365#[async_trait]
366impl CallbackContext for ToolCallbackContext {
367    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
368        self.inner.artifacts()
369    }
370
371    fn tool_outcome(&self) -> Option<ToolOutcome> {
372        self.inner.tool_outcome()
373    }
374
375    fn tool_name(&self) -> Option<&str> {
376        Some(&self.tool_name)
377    }
378
379    fn tool_input(&self) -> Option<&serde_json::Value> {
380        Some(&self.tool_input)
381    }
382
383    fn shared_state(&self) -> Option<Arc<crate::SharedState>> {
384        self.inner.shared_state()
385    }
386}
387
388#[async_trait]
389pub trait InvocationContext: CallbackContext {
390    fn agent(&self) -> Arc<dyn Agent>;
391    fn memory(&self) -> Option<Arc<dyn Memory>>;
392    fn session(&self) -> &dyn Session;
393    fn run_config(&self) -> &RunConfig;
394    fn end_invocation(&self);
395    fn ended(&self) -> bool;
396
397    /// Returns the scopes granted to the current user for this invocation.
398    ///
399    /// When a [`RequestContext`](crate::RequestContext) is present (set by the
400    /// server's auth middleware bridge), this returns the scopes from that
401    /// context. The default returns an empty vec (no scopes granted).
402    fn user_scopes(&self) -> Vec<String> {
403        vec![]
404    }
405
406    /// Returns the request metadata from the auth middleware bridge, if present.
407    ///
408    /// This provides access to custom key-value pairs extracted from the HTTP
409    /// request by the [`RequestContextExtractor`](crate::RequestContext).
410    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
411        HashMap::new()
412    }
413}
414
415// Placeholder service traits
416#[async_trait]
417pub trait Artifacts: Send + Sync {
418    async fn save(&self, name: &str, data: &crate::Part) -> Result<i64>;
419    async fn load(&self, name: &str) -> Result<crate::Part>;
420    async fn list(&self) -> Result<Vec<String>>;
421}
422
423#[async_trait]
424pub trait Memory: Send + Sync {
425    async fn search(&self, query: &str) -> Result<Vec<MemoryEntry>>;
426
427    /// Verify backend connectivity.
428    ///
429    /// The default implementation succeeds, which is suitable for in-memory
430    /// implementations and adapters without an external dependency.
431    async fn health_check(&self) -> Result<()> {
432        Ok(())
433    }
434
435    /// Add a single memory entry.
436    ///
437    /// The default implementation returns an "not implemented" error, which is
438    /// suitable for read-only memory backends.
439    async fn add(&self, entry: MemoryEntry) -> Result<()> {
440        let _ = entry;
441        Err(AdkError::memory("add not implemented"))
442    }
443
444    /// Delete entries matching a query. Returns count of deleted entries.
445    ///
446    /// The default implementation returns an "not implemented" error, which is
447    /// suitable for read-only memory backends.
448    async fn delete(&self, query: &str) -> Result<u64> {
449        let _ = query;
450        Err(AdkError::memory("delete not implemented"))
451    }
452}
453
454#[derive(Debug, Clone)]
455pub struct MemoryEntry {
456    pub content: Content,
457    pub author: String,
458}
459
/// How an agent delivers its responses.
///
/// Mirrors the streaming modes of the ADK Python/Go specification.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum StreamingMode {
    /// Non-streaming: the agent gathers every chunk internally and yields a
    /// single final event.
    None,
    /// One-way Server-Sent Events from server to client (the default).
    ///
    /// Each chunk is yielded as it arrives, with a stable event ID.
    #[default]
    SSE,
    /// Two-way streaming in both directions at once, used by realtime
    /// audio/video agents.
    Bidi,
}
475
/// Controls which parts of the prior conversation history an LLM agent receives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IncludeContents {
    /// The LLM agent operates solely on its current turn: the latest user
    /// input plus any agent events that follow it.
    None,
    /// The default: the LLM agent receives the relevant conversation history.
    #[default]
    Default,
}
485
486/// Decision applied when a tool execution requires human confirmation.
487#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
488#[serde(rename_all = "snake_case")]
489pub enum ToolConfirmationDecision {
490    Approve,
491    Deny,
492}
493
494/// Policy defining which tools require human confirmation before execution.
495#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
496#[serde(rename_all = "snake_case")]
497pub enum ToolConfirmationPolicy {
498    /// No tool confirmation is required.
499    #[default]
500    Never,
501    /// Every tool call requires confirmation.
502    Always,
503    /// Only the listed tool names require confirmation.
504    PerTool(BTreeSet<String>),
505}
506
507impl ToolConfirmationPolicy {
508    /// Returns true when the given tool name must be confirmed before execution.
509    pub fn requires_confirmation(&self, tool_name: &str) -> bool {
510        match self {
511            Self::Never => false,
512            Self::Always => true,
513            Self::PerTool(tools) => tools.contains(tool_name),
514        }
515    }
516
517    /// Add one tool name to the confirmation policy (converts `Never` to `PerTool`).
518    pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
519        let tool_name = tool_name.into();
520        match &mut self {
521            Self::Never => {
522                let mut tools = BTreeSet::new();
523                tools.insert(tool_name);
524                Self::PerTool(tools)
525            }
526            Self::Always => Self::Always,
527            Self::PerTool(tools) => {
528                tools.insert(tool_name);
529                self
530            }
531        }
532    }
533}
534
535/// Payload describing a tool call awaiting human confirmation.
536#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
537#[serde(rename_all = "camelCase")]
538pub struct ToolConfirmationRequest {
539    pub tool_name: String,
540    #[serde(skip_serializing_if = "Option::is_none")]
541    pub function_call_id: Option<String>,
542    pub args: Value,
543}
544
545#[derive(Debug, Clone)]
546pub struct RunConfig {
547    pub streaming_mode: StreamingMode,
548    /// Optional per-tool confirmation decisions for the current run.
549    /// Keys are tool names.
550    pub tool_confirmation_decisions: HashMap<String, ToolConfirmationDecision>,
551    /// Optional cached content name for automatic prompt caching.
552    /// When set by the runner's cache lifecycle manager, agents should attach
553    /// this name to their `GenerateContentConfig` so the LLM provider can
554    /// reuse cached system instructions and tool definitions.
555    pub cached_content: Option<String>,
556    /// Valid agent names this agent can transfer to (parent, peers, children).
557    /// Set by the runner when invoking agents in a multi-agent tree.
558    /// When non-empty, the `transfer_to_agent` tool is injected and validation
559    /// uses this list instead of only checking `sub_agents`.
560    pub transfer_targets: Vec<String>,
561    /// The name of the parent agent, if this agent was invoked via transfer.
562    /// Used by the agent to apply `disallow_transfer_to_parent` filtering.
563    pub parent_agent: Option<String>,
564    /// Enable automatic prompt caching for all providers that support it.
565    ///
566    /// When `true` (the default), the runner enables provider-level caching:
567    /// - Anthropic: sets `prompt_caching = true` on the config
568    /// - Bedrock: sets `prompt_caching = Some(BedrockCacheConfig::default())`
569    /// - OpenAI / DeepSeek: no action needed (caching is automatic)
570    /// - Gemini: handled separately via `ContextCacheConfig`
571    pub auto_cache: bool,
572}
573
574impl Default for RunConfig {
575    fn default() -> Self {
576        Self {
577            streaming_mode: StreamingMode::SSE,
578            tool_confirmation_decisions: HashMap::new(),
579            cached_content: None,
580            transfer_targets: Vec::new(),
581            parent_agent: None,
582            auto_cache: true,
583        }
584    }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_run_config_default() {
        let config = RunConfig::default();
        assert_eq!(config.streaming_mode, StreamingMode::SSE);
        assert!(config.tool_confirmation_decisions.is_empty());
        assert!(config.cached_content.is_none());
        assert!(config.transfer_targets.is_empty());
        assert!(config.parent_agent.is_none());
        // Automatic prompt caching is documented as on by default.
        assert!(config.auto_cache);
    }

    #[test]
    fn test_streaming_mode() {
        assert_eq!(StreamingMode::default(), StreamingMode::SSE);
        assert_eq!(StreamingMode::SSE, StreamingMode::SSE);
        assert_ne!(StreamingMode::SSE, StreamingMode::None);
        assert_ne!(StreamingMode::None, StreamingMode::Bidi);
    }

    #[test]
    fn test_include_contents_default() {
        assert_eq!(IncludeContents::default(), IncludeContents::Default);
        assert_ne!(IncludeContents::None, IncludeContents::Default);
    }

    #[test]
    fn test_tool_confirmation_policy() {
        let policy = ToolConfirmationPolicy::default();
        assert!(!policy.requires_confirmation("search"));

        let policy = policy.with_tool("search");
        assert!(policy.requires_confirmation("search"));
        assert!(!policy.requires_confirmation("write_file"));

        assert!(ToolConfirmationPolicy::Always.requires_confirmation("any_tool"));
    }

    #[test]
    fn test_tool_confirmation_policy_accumulates_tools() {
        let policy = ToolConfirmationPolicy::Never
            .with_tool("search")
            .with_tool("write_file");
        assert!(policy.requires_confirmation("search"));
        assert!(policy.requires_confirmation("write_file"));
        assert!(!policy.requires_confirmation("delete"));
    }

    #[test]
    fn test_tool_confirmation_policy_always_absorbs_with_tool() {
        // `Always` already covers every tool, so adding one must not narrow it.
        let policy = ToolConfirmationPolicy::Always.with_tool("search");
        assert_eq!(policy, ToolConfirmationPolicy::Always);
        assert!(policy.requires_confirmation("unlisted"));
    }

    #[test]
    fn test_validate_state_key_valid() {
        assert!(validate_state_key("user_name").is_ok());
        assert!(validate_state_key("app:config").is_ok());
        assert!(validate_state_key("temp:data").is_ok());
        assert!(validate_state_key("a").is_ok());
        assert!(validate_state_key("a.b").is_ok());
    }

    #[test]
    fn test_validate_state_key_empty() {
        assert_eq!(validate_state_key(""), Err("state key must not be empty"));
    }

    #[test]
    fn test_validate_state_key_max_length_boundary() {
        let max_key = "a".repeat(MAX_STATE_KEY_LEN);
        assert!(validate_state_key(&max_key).is_ok());
    }

    #[test]
    fn test_validate_state_key_too_long() {
        let long_key = "a".repeat(MAX_STATE_KEY_LEN + 1);
        assert!(validate_state_key(&long_key).is_err());
    }

    #[test]
    fn test_validate_state_key_path_traversal() {
        assert!(validate_state_key("../etc/passwd").is_err());
        assert!(validate_state_key("foo/bar").is_err());
        assert!(validate_state_key("foo\\bar").is_err());
        assert!(validate_state_key("..").is_err());
        assert!(validate_state_key("a..b").is_err());
    }

    #[test]
    fn test_validate_state_key_null_byte() {
        assert!(validate_state_key("foo\0bar").is_err());
    }
}