Skip to main content

adk_core/
context.rs

1use crate::identity::{AdkIdentity, AppName, ExecutionIdentity, InvocationId, SessionId, UserId};
2use crate::{Agent, Result, types::Content};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::{BTreeSet, HashMap};
7use std::sync::Arc;
8
9#[async_trait]
10pub trait ReadonlyContext: Send + Sync {
11    fn invocation_id(&self) -> &str;
12    fn agent_name(&self) -> &str;
13    fn user_id(&self) -> &str;
14    fn app_name(&self) -> &str;
15    fn session_id(&self) -> &str;
16    fn branch(&self) -> &str;
17    fn user_content(&self) -> &Content;
18
19    /// Returns the application name as a typed [`AppName`].
20    ///
21    /// Parses the value returned by [`app_name()`](Self::app_name). Returns an
22    /// error if the raw string fails validation (empty, null bytes, or exceeds
23    /// the maximum length).
24    ///
25    /// # Errors
26    ///
27    /// Returns [`AdkError::Config`](crate::AdkError::Config) when the
28    /// underlying string is not a valid identifier.
29    fn try_app_name(&self) -> Result<AppName> {
30        Ok(AppName::try_from(self.app_name())?)
31    }
32
33    /// Returns the user identifier as a typed [`UserId`].
34    ///
35    /// Parses the value returned by [`user_id()`](Self::user_id). Returns an
36    /// error if the raw string fails validation.
37    ///
38    /// # Errors
39    ///
40    /// Returns [`AdkError::Config`](crate::AdkError::Config) when the
41    /// underlying string is not a valid identifier.
42    fn try_user_id(&self) -> Result<UserId> {
43        Ok(UserId::try_from(self.user_id())?)
44    }
45
46    /// Returns the session identifier as a typed [`SessionId`].
47    ///
48    /// Parses the value returned by [`session_id()`](Self::session_id).
49    /// Returns an error if the raw string fails validation.
50    ///
51    /// # Errors
52    ///
53    /// Returns [`AdkError::Config`](crate::AdkError::Config) when the
54    /// underlying string is not a valid identifier.
55    fn try_session_id(&self) -> Result<SessionId> {
56        Ok(SessionId::try_from(self.session_id())?)
57    }
58
59    /// Returns the invocation identifier as a typed [`InvocationId`].
60    ///
61    /// Parses the value returned by [`invocation_id()`](Self::invocation_id).
62    /// Returns an error if the raw string fails validation.
63    ///
64    /// # Errors
65    ///
66    /// Returns [`AdkError::Config`](crate::AdkError::Config) when the
67    /// underlying string is not a valid identifier.
68    fn try_invocation_id(&self) -> Result<InvocationId> {
69        Ok(InvocationId::try_from(self.invocation_id())?)
70    }
71
72    /// Returns the stable session-scoped [`AdkIdentity`] triple.
73    ///
74    /// Combines [`try_app_name()`](Self::try_app_name),
75    /// [`try_user_id()`](Self::try_user_id), and
76    /// [`try_session_id()`](Self::try_session_id) into a single composite
77    /// identity value.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if any of the three constituent identifiers fail
82    /// validation.
83    fn try_identity(&self) -> Result<AdkIdentity> {
84        Ok(AdkIdentity {
85            app_name: self.try_app_name()?,
86            user_id: self.try_user_id()?,
87            session_id: self.try_session_id()?,
88        })
89    }
90
91    /// Returns the full per-invocation [`ExecutionIdentity`].
92    ///
93    /// Combines [`try_identity()`](Self::try_identity) with the invocation,
94    /// branch, and agent name from this context.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if any of the four typed identifiers fail validation.
99    fn try_execution_identity(&self) -> Result<ExecutionIdentity> {
100        Ok(ExecutionIdentity {
101            adk: self.try_identity()?,
102            invocation_id: self.try_invocation_id()?,
103            branch: self.branch().to_string(),
104            agent_name: self.agent_name().to_string(),
105        })
106    }
107}
108
// State management traits

/// Maximum allowed length for state keys, measured in bytes (256).
pub const MAX_STATE_KEY_LEN: usize = 256;

/// Checks whether `key` is acceptable as a state key.
///
/// A key is rejected when it:
/// - is empty,
/// - is longer than [`MAX_STATE_KEY_LEN`] bytes,
/// - contains a path separator (`/` or `\`) or the sequence `..` anywhere,
/// - contains a null byte.
///
/// Rules are evaluated in the order listed; only the first violation is
/// reported.
///
/// # Errors
///
/// Returns a static, human-readable description of the first violated rule.
pub fn validate_state_key(key: &str) -> std::result::Result<(), &'static str> {
    let violation = if key.is_empty() {
        Some("state key must not be empty")
    } else if key.len() > MAX_STATE_KEY_LEN {
        Some("state key exceeds maximum length of 256 bytes")
    } else if key.contains('/') || key.contains('\\') || key.contains("..") {
        Some("state key must not contain path separators or '..'")
    } else if key.contains('\0') {
        Some("state key must not contain null bytes")
    } else {
        None
    };

    violation.map_or(Ok(()), Err)
}
136
/// Mutable key/value state store whose values are JSON [`Value`]s keyed by
/// string.
pub trait State: Send + Sync {
    /// Returns the value stored under `key`, or `None` if there is none.
    fn get(&self, key: &str) -> Option<Value>;
    /// Set a state value. Implementations should call [`validate_state_key`] and
    /// reject invalid keys (e.g., by logging a warning or panicking).
    fn set(&mut self, key: String, value: Value);
    /// Returns every key/value pair as an owned map.
    fn all(&self) -> HashMap<String, Value>;
}
144
/// Read-only counterpart of [`State`]: lookup and enumeration, no mutation.
pub trait ReadonlyState: Send + Sync {
    /// Returns the value stored under `key`, or `None` if there is none.
    fn get(&self, key: &str) -> Option<Value>;
    /// Returns every key/value pair as an owned map.
    fn all(&self) -> HashMap<String, Value>;
}
149
150// Session trait
151pub trait Session: Send + Sync {
152    fn id(&self) -> &str;
153    fn app_name(&self) -> &str;
154    fn user_id(&self) -> &str;
155    fn state(&self) -> &dyn State;
156    /// Returns the conversation history from this session as Content items
157    fn conversation_history(&self) -> Vec<Content>;
158    /// Returns conversation history filtered for a specific agent.
159    ///
160    /// When provided, events authored by other agents (not "user", not the
161    /// named agent, and not function/tool responses) are excluded. This
162    /// prevents a transferred sub-agent from seeing the parent's tool calls
163    /// mapped as "model" role, which would cause the LLM to think work is
164    /// already done.
165    ///
166    /// Default implementation delegates to [`conversation_history`](Self::conversation_history).
167    fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
168        self.conversation_history()
169    }
170    /// Append content to conversation history (for sequential agent support)
171    fn append_to_history(&self, _content: Content) {
172        // Default no-op - implementations can override to track history
173    }
174
175    /// Returns the application name as a typed [`AppName`].
176    ///
177    /// Parses the value returned by [`app_name()`](Self::app_name). Returns an
178    /// error if the raw string fails validation (empty, null bytes, or exceeds
179    /// the maximum length).
180    ///
181    /// # Errors
182    ///
183    /// Returns [`AdkError::Config`](crate::AdkError::Config) when the
184    /// underlying string is not a valid identifier.
185    fn try_app_name(&self) -> Result<AppName> {
186        Ok(AppName::try_from(self.app_name())?)
187    }
188
189    /// Returns the user identifier as a typed [`UserId`].
190    ///
191    /// Parses the value returned by [`user_id()`](Self::user_id). Returns an
192    /// error if the raw string fails validation.
193    ///
194    /// # Errors
195    ///
196    /// Returns [`AdkError::Config`](crate::AdkError::Config) when the
197    /// underlying string is not a valid identifier.
198    fn try_user_id(&self) -> Result<UserId> {
199        Ok(UserId::try_from(self.user_id())?)
200    }
201
202    /// Returns the session identifier as a typed [`SessionId`].
203    ///
204    /// Parses the value returned by [`id()`](Self::id). Returns an error if
205    /// the raw string fails validation.
206    ///
207    /// # Errors
208    ///
209    /// Returns [`AdkError::Config`](crate::AdkError::Config) when the
210    /// underlying string is not a valid identifier.
211    fn try_session_id(&self) -> Result<SessionId> {
212        Ok(SessionId::try_from(self.id())?)
213    }
214
215    /// Returns the stable session-scoped [`AdkIdentity`] triple.
216    ///
217    /// Combines [`try_app_name()`](Self::try_app_name),
218    /// [`try_user_id()`](Self::try_user_id), and
219    /// [`try_session_id()`](Self::try_session_id) into a single composite
220    /// identity value.
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if any of the three constituent identifiers fail
225    /// validation.
226    fn try_identity(&self) -> Result<AdkIdentity> {
227        Ok(AdkIdentity {
228            app_name: self.try_app_name()?,
229            user_id: self.try_user_id()?,
230            session_id: self.try_session_id()?,
231        })
232    }
233}
234
/// Structured metadata about a completed tool execution.
///
/// Available via [`CallbackContext::tool_outcome()`] in after-tool callbacks,
/// plugins, and telemetry hooks. Provides structured access to execution
/// results without requiring JSON error parsing.
///
/// `success` is derived from the Rust `Result` / timeout path, never from
/// inspecting the JSON content the tool produced.
#[derive(Debug, Clone)]
pub struct ToolOutcome {
    /// Name of the tool that was executed.
    pub tool_name: String,
    /// Arguments passed to the tool (JSON value).
    pub tool_args: serde_json::Value,
    /// Whether the tool execution succeeded.
    pub success: bool,
    /// Wall-clock duration of the tool execution.
    pub duration: std::time::Duration,
    /// Error message if the tool failed. `None` on success.
    pub error_message: Option<String>,
    /// Retry attempt number (0 = first attempt, 1 = first retry, etc.).
    /// Always 0 when retries are not configured.
    pub attempt: u32,
}
267
/// Context handed to callbacks: a [`ReadonlyContext`] plus optional artifact
/// access and tool-execution metadata.
#[async_trait]
pub trait CallbackContext: ReadonlyContext {
    /// Artifact service for this context, or `None` when unavailable.
    fn artifacts(&self) -> Option<Arc<dyn Artifacts>>;

    /// Returns structured metadata about the most recent tool execution.
    /// Available in after-tool callbacks and plugin hooks.
    /// Returns `None` when not in a tool execution context.
    fn tool_outcome(&self) -> Option<ToolOutcome> {
        None // default for backward compatibility
    }
}
279
280#[async_trait]
281pub trait InvocationContext: CallbackContext {
282    fn agent(&self) -> Arc<dyn Agent>;
283    fn memory(&self) -> Option<Arc<dyn Memory>>;
284    fn session(&self) -> &dyn Session;
285    fn run_config(&self) -> &RunConfig;
286    fn end_invocation(&self);
287    fn ended(&self) -> bool;
288
289    /// Returns the scopes granted to the current user for this invocation.
290    ///
291    /// When a [`RequestContext`](crate::RequestContext) is present (set by the
292    /// server's auth middleware bridge), this returns the scopes from that
293    /// context. The default returns an empty vec (no scopes granted).
294    fn user_scopes(&self) -> Vec<String> {
295        vec![]
296    }
297
298    /// Returns the request metadata from the auth middleware bridge, if present.
299    ///
300    /// This provides access to custom key-value pairs extracted from the HTTP
301    /// request by the [`RequestContextExtractor`](crate::RequestContext).
302    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
303        HashMap::new()
304    }
305}
306
// Placeholder service traits

/// Storage for named artifacts whose contents are [`Part`](crate::Part)s.
#[async_trait]
pub trait Artifacts: Send + Sync {
    /// Saves `data` under `name`.
    ///
    /// NOTE(review): the meaning of the returned `i64` is not defined here —
    /// presumably a version number; confirm against implementations.
    async fn save(&self, name: &str, data: &crate::Part) -> Result<i64>;
    /// Loads the artifact stored under `name`.
    async fn load(&self, name: &str) -> Result<crate::Part>;
    /// Lists stored artifacts (presumably by name — confirm against
    /// implementations).
    async fn list(&self) -> Result<Vec<String>>;
}
314
/// Searchable memory service.
#[async_trait]
pub trait Memory: Send + Sync {
    /// Searches memory with `query`; matching semantics are defined by the
    /// implementation.
    async fn search(&self, query: &str) -> Result<Vec<MemoryEntry>>;

    /// Verify backend connectivity.
    ///
    /// The default implementation succeeds, which is suitable for in-memory
    /// implementations and adapters without an external dependency.
    async fn health_check(&self) -> Result<()> {
        Ok(())
    }
}
327
/// One entry returned by [`Memory::search`].
#[derive(Debug, Clone)]
pub struct MemoryEntry {
    /// The remembered content.
    pub content: Content,
    /// Author of the content (presumably "user" or an agent name — confirm).
    pub author: String,
}
333
/// Streaming mode for agent responses.
/// Matches ADK Python/Go specification.
///
/// [`SSE`](Self::SSE) is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum StreamingMode {
    /// No streaming; responses delivered as complete units.
    /// Agent collects all chunks internally and yields a single final event.
    None,
    /// Server-Sent Events streaming; one-way streaming from server to client.
    /// Agent yields each chunk as it arrives with stable event ID.
    #[default]
    SSE,
    /// Bidirectional streaming; simultaneous communication in both directions.
    /// Used for realtime audio/video agents.
    Bidi,
}
349
/// Controls which parts of the prior conversation history an LLM agent
/// receives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IncludeContents {
    /// The LLM agent operates solely on its current turn (the latest user
    /// input plus any following agent events).
    None,
    /// Default: the LLM agent receives the relevant conversation history.
    #[default]
    Default,
}
359
/// Decision applied when a tool execution requires human confirmation.
///
/// Serialized in snake_case (`"approve"` / `"deny"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolConfirmationDecision {
    /// Allow the tool call to run.
    Approve,
    /// Refuse the tool call.
    Deny,
}
367
/// Policy defining which tools require human confirmation before execution.
///
/// Variant names are serialized in snake_case (`never`, `always`,
/// `per_tool`). [`Never`](Self::Never) is the default.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ToolConfirmationPolicy {
    /// No tool confirmation is required.
    #[default]
    Never,
    /// Every tool call requires confirmation.
    Always,
    /// Only the listed tool names require confirmation.
    PerTool(BTreeSet<String>),
}
380
381impl ToolConfirmationPolicy {
382    /// Returns true when the given tool name must be confirmed before execution.
383    pub fn requires_confirmation(&self, tool_name: &str) -> bool {
384        match self {
385            Self::Never => false,
386            Self::Always => true,
387            Self::PerTool(tools) => tools.contains(tool_name),
388        }
389    }
390
391    /// Add one tool name to the confirmation policy (converts `Never` to `PerTool`).
392    pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
393        let tool_name = tool_name.into();
394        match &mut self {
395            Self::Never => {
396                let mut tools = BTreeSet::new();
397                tools.insert(tool_name);
398                Self::PerTool(tools)
399            }
400            Self::Always => Self::Always,
401            Self::PerTool(tools) => {
402                tools.insert(tool_name);
403                self
404            }
405        }
406    }
407}
408
/// Payload describing a tool call awaiting human confirmation.
///
/// Serialized with camelCase field names (`toolName`, `functionCallId`,
/// `args`); `functionCallId` is omitted when it is `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfirmationRequest {
    /// Name of the tool awaiting confirmation.
    pub tool_name: String,
    /// Identifier of the pending function call, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call_id: Option<String>,
    /// Arguments of the pending call.
    pub args: Value,
}
418
419#[derive(Debug, Clone)]
420pub struct RunConfig {
421    pub streaming_mode: StreamingMode,
422    /// Optional per-tool confirmation decisions for the current run.
423    /// Keys are tool names.
424    pub tool_confirmation_decisions: HashMap<String, ToolConfirmationDecision>,
425    /// Optional cached content name for automatic prompt caching.
426    /// When set by the runner's cache lifecycle manager, agents should attach
427    /// this name to their `GenerateContentConfig` so the LLM provider can
428    /// reuse cached system instructions and tool definitions.
429    pub cached_content: Option<String>,
430    /// Valid agent names this agent can transfer to (parent, peers, children).
431    /// Set by the runner when invoking agents in a multi-agent tree.
432    /// When non-empty, the `transfer_to_agent` tool is injected and validation
433    /// uses this list instead of only checking `sub_agents`.
434    pub transfer_targets: Vec<String>,
435    /// The name of the parent agent, if this agent was invoked via transfer.
436    /// Used by the agent to apply `disallow_transfer_to_parent` filtering.
437    pub parent_agent: Option<String>,
438}
439
440impl Default for RunConfig {
441    fn default() -> Self {
442        Self {
443            streaming_mode: StreamingMode::SSE,
444            tool_confirmation_decisions: HashMap::new(),
445            cached_content: None,
446            transfer_targets: Vec::new(),
447            parent_agent: None,
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_run_config_default() {
        let config = RunConfig::default();
        assert_eq!(config.streaming_mode, StreamingMode::SSE);
        assert!(config.tool_confirmation_decisions.is_empty());
        // Every optional/collection field starts out empty.
        assert!(config.cached_content.is_none());
        assert!(config.transfer_targets.is_empty());
        assert!(config.parent_agent.is_none());
    }

    #[test]
    fn test_streaming_mode() {
        assert_eq!(StreamingMode::SSE, StreamingMode::SSE);
        assert_ne!(StreamingMode::SSE, StreamingMode::None);
        assert_ne!(StreamingMode::None, StreamingMode::Bidi);
        // RunConfig::default hard-codes SSE; keep the enum default in sync.
        assert_eq!(StreamingMode::default(), StreamingMode::SSE);
    }

    #[test]
    fn test_include_contents_default() {
        assert_eq!(IncludeContents::default(), IncludeContents::Default);
    }

    #[test]
    fn test_tool_confirmation_policy() {
        let policy = ToolConfirmationPolicy::default();
        assert!(!policy.requires_confirmation("search"));

        let policy = policy.with_tool("search");
        assert!(policy.requires_confirmation("search"));
        assert!(!policy.requires_confirmation("write_file"));

        assert!(ToolConfirmationPolicy::Always.requires_confirmation("any_tool"));
    }

    #[test]
    fn test_tool_confirmation_policy_with_tool_transitions() {
        // Never -> PerTool, then accumulates; duplicates are absorbed by the set.
        let policy = ToolConfirmationPolicy::Never
            .with_tool("a")
            .with_tool("b")
            .with_tool("a");
        assert_eq!(
            policy,
            ToolConfirmationPolicy::PerTool(BTreeSet::from(["a".to_string(), "b".to_string()]))
        );

        // Always already covers every tool and must not be narrowed.
        assert_eq!(
            ToolConfirmationPolicy::Always.with_tool("x"),
            ToolConfirmationPolicy::Always
        );
    }

    #[test]
    fn test_tool_confirmation_serde_format() {
        assert_eq!(
            serde_json::to_value(ToolConfirmationDecision::Approve).unwrap(),
            serde_json::json!("approve")
        );
        assert_eq!(
            serde_json::to_value(ToolConfirmationDecision::Deny).unwrap(),
            serde_json::json!("deny")
        );
        assert_eq!(
            serde_json::to_value(ToolConfirmationPolicy::Never).unwrap(),
            serde_json::json!("never")
        );
        let policy = ToolConfirmationPolicy::Never.with_tool("search");
        assert_eq!(
            serde_json::to_value(&policy).unwrap(),
            serde_json::json!({ "per_tool": ["search"] })
        );
    }

    #[test]
    fn test_tool_confirmation_request_serde() {
        let request = ToolConfirmationRequest {
            tool_name: "search".to_string(),
            function_call_id: None,
            args: serde_json::json!({ "q": "rust" }),
        };
        let value = serde_json::to_value(&request).unwrap();
        // camelCase keys; `functionCallId` is skipped when `None`.
        assert_eq!(
            value,
            serde_json::json!({ "toolName": "search", "args": { "q": "rust" } })
        );
        let round_trip: ToolConfirmationRequest = serde_json::from_value(value).unwrap();
        assert_eq!(round_trip, request);
    }

    #[test]
    fn test_validate_state_key_valid() {
        assert!(validate_state_key("user_name").is_ok());
        assert!(validate_state_key("app:config").is_ok());
        assert!(validate_state_key("temp:data").is_ok());
        assert!(validate_state_key("a").is_ok());
        assert!(validate_state_key("a.b").is_ok());
    }

    #[test]
    fn test_validate_state_key_empty() {
        assert_eq!(validate_state_key(""), Err("state key must not be empty"));
    }

    #[test]
    fn test_validate_state_key_too_long() {
        let long_key = "a".repeat(MAX_STATE_KEY_LEN + 1);
        assert!(validate_state_key(&long_key).is_err());
    }

    #[test]
    fn test_validate_state_key_length_boundary() {
        // Exactly at the limit is accepted.
        assert!(validate_state_key(&"a".repeat(MAX_STATE_KEY_LEN)).is_ok());
        // The limit counts bytes, not chars: 129 two-byte chars = 258 bytes.
        assert_eq!(
            validate_state_key(&"\u{e9}".repeat(129)),
            Err("state key exceeds maximum length of 256 bytes")
        );
    }

    #[test]
    fn test_validate_state_key_path_traversal() {
        assert!(validate_state_key("../etc/passwd").is_err());
        assert!(validate_state_key("foo/bar").is_err());
        assert!(validate_state_key("foo\\bar").is_err());
        assert!(validate_state_key("..").is_err());
        // `..` is rejected anywhere in the key, not only as a path component.
        assert_eq!(
            validate_state_key("a..b"),
            Err("state key must not contain path separators or '..'")
        );
    }

    #[test]
    fn test_validate_state_key_null_byte() {
        assert_eq!(
            validate_state_key("foo\0bar"),
            Err("state key must not contain null bytes")
        );
    }
}