Skip to main content

adk_core/
context.rs

1use crate::identity::{AdkIdentity, AppName, ExecutionIdentity, InvocationId, SessionId, UserId};
2use crate::{AdkError, Agent, Result, types::Content};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::{BTreeSet, HashMap};
7use std::sync::Arc;
8
9#[async_trait]
10pub trait ReadonlyContext: Send + Sync {
11    fn invocation_id(&self) -> &str;
12    fn agent_name(&self) -> &str;
13    fn user_id(&self) -> &str;
14    fn app_name(&self) -> &str;
15    fn session_id(&self) -> &str;
16    fn branch(&self) -> &str;
17    fn user_content(&self) -> &Content;
18
19    /// Returns the application name as a typed [`AppName`].
20    ///
21    /// Parses the value returned by [`app_name()`](Self::app_name). Returns an
22    /// error if the raw string fails validation (empty, null bytes, or exceeds
23    /// the maximum length).
24    ///
25    /// # Errors
26    ///
27    /// Returns an error when the
28    /// underlying string is not a valid identifier.
29    fn try_app_name(&self) -> Result<AppName> {
30        Ok(AppName::try_from(self.app_name())?)
31    }
32
33    /// Returns the user identifier as a typed [`UserId`].
34    ///
35    /// Parses the value returned by [`user_id()`](Self::user_id). Returns an
36    /// error if the raw string fails validation.
37    ///
38    /// # Errors
39    ///
40    /// Returns an error when the
41    /// underlying string is not a valid identifier.
42    fn try_user_id(&self) -> Result<UserId> {
43        Ok(UserId::try_from(self.user_id())?)
44    }
45
46    /// Returns the session identifier as a typed [`SessionId`].
47    ///
48    /// Parses the value returned by [`session_id()`](Self::session_id).
49    /// Returns an error if the raw string fails validation.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error when the
54    /// underlying string is not a valid identifier.
55    fn try_session_id(&self) -> Result<SessionId> {
56        Ok(SessionId::try_from(self.session_id())?)
57    }
58
59    /// Returns the invocation identifier as a typed [`InvocationId`].
60    ///
61    /// Parses the value returned by [`invocation_id()`](Self::invocation_id).
62    /// Returns an error if the raw string fails validation.
63    ///
64    /// # Errors
65    ///
66    /// Returns an error when the
67    /// underlying string is not a valid identifier.
68    fn try_invocation_id(&self) -> Result<InvocationId> {
69        Ok(InvocationId::try_from(self.invocation_id())?)
70    }
71
72    /// Returns the stable session-scoped [`AdkIdentity`] triple.
73    ///
74    /// Combines [`try_app_name()`](Self::try_app_name),
75    /// [`try_user_id()`](Self::try_user_id), and
76    /// [`try_session_id()`](Self::try_session_id) into a single composite
77    /// identity value.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if any of the three constituent identifiers fail
82    /// validation.
83    fn try_identity(&self) -> Result<AdkIdentity> {
84        Ok(AdkIdentity {
85            app_name: self.try_app_name()?,
86            user_id: self.try_user_id()?,
87            session_id: self.try_session_id()?,
88        })
89    }
90
91    /// Returns the full per-invocation [`ExecutionIdentity`].
92    ///
93    /// Combines [`try_identity()`](Self::try_identity) with the invocation,
94    /// branch, and agent name from this context.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if any of the four typed identifiers fail validation.
99    fn try_execution_identity(&self) -> Result<ExecutionIdentity> {
100        Ok(ExecutionIdentity {
101            adk: self.try_identity()?,
102            invocation_id: self.try_invocation_id()?,
103            branch: self.branch().to_string(),
104            agent_name: self.agent_name().to_string(),
105        })
106    }
107}
108
109// State management traits
110
/// Largest permitted state key, measured in bytes (not characters).
pub const MAX_STATE_KEY_LEN: usize = 256;

/// Checks that `key` is safe to use as a state key.
///
/// The rules are tried in this order, and the first one broken decides the
/// error message:
///
/// 1. the key is not empty;
/// 2. the key is at most [`MAX_STATE_KEY_LEN`] bytes long;
/// 3. the key contains neither a path separator (`/`, `\`) nor `..`;
/// 4. the key contains no null byte.
///
/// # Errors
///
/// Returns a static description of the first rule the key violates.
pub fn validate_state_key(key: &str) -> std::result::Result<(), &'static str> {
    // NOTE: the length message spells out 256; keep it in sync with
    // MAX_STATE_KEY_LEN if the limit ever changes.
    let violation = if key.is_empty() {
        Some("state key must not be empty")
    } else if key.len() > MAX_STATE_KEY_LEN {
        Some("state key exceeds maximum length of 256 bytes")
    } else if key.contains(|c: char| c == '/' || c == '\\') || key.contains("..") {
        Some("state key must not contain path separators or '..'")
    } else if key.contains('\0') {
        Some("state key must not contain null bytes")
    } else {
        None
    };
    violation.map_or(Ok(()), Err)
}
136
137pub trait State: Send + Sync {
138    fn get(&self, key: &str) -> Option<Value>;
139    /// Set a state value. Implementations should call [`validate_state_key`] and
140    /// reject invalid keys (e.g., by logging a warning or panicking).
141    fn set(&mut self, key: String, value: Value);
142    fn all(&self) -> HashMap<String, Value>;
143}
144
145pub trait ReadonlyState: Send + Sync {
146    fn get(&self, key: &str) -> Option<Value>;
147    fn all(&self) -> HashMap<String, Value>;
148}
149
150// Session trait
151pub trait Session: Send + Sync {
152    fn id(&self) -> &str;
153    fn app_name(&self) -> &str;
154    fn user_id(&self) -> &str;
155    fn state(&self) -> &dyn State;
156    /// Returns the conversation history from this session as Content items
157    fn conversation_history(&self) -> Vec<Content>;
158    /// Returns conversation history filtered for a specific agent.
159    ///
160    /// When provided, events authored by other agents (not "user", not the
161    /// named agent, and not function/tool responses) are excluded. This
162    /// prevents a transferred sub-agent from seeing the parent's tool calls
163    /// mapped as "model" role, which would cause the LLM to think work is
164    /// already done.
165    ///
166    /// Default implementation delegates to [`conversation_history`](Self::conversation_history).
167    fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
168        self.conversation_history()
169    }
170    /// Append content to conversation history (for sequential agent support)
171    fn append_to_history(&self, _content: Content) {
172        // Default no-op - implementations can override to track history
173    }
174
175    /// Returns the application name as a typed [`AppName`].
176    ///
177    /// Parses the value returned by [`app_name()`](Self::app_name). Returns an
178    /// error if the raw string fails validation (empty, null bytes, or exceeds
179    /// the maximum length).
180    ///
181    /// # Errors
182    ///
183    /// Returns an error when the
184    /// underlying string is not a valid identifier.
185    fn try_app_name(&self) -> Result<AppName> {
186        Ok(AppName::try_from(self.app_name())?)
187    }
188
189    /// Returns the user identifier as a typed [`UserId`].
190    ///
191    /// Parses the value returned by [`user_id()`](Self::user_id). Returns an
192    /// error if the raw string fails validation.
193    ///
194    /// # Errors
195    ///
196    /// Returns an error when the
197    /// underlying string is not a valid identifier.
198    fn try_user_id(&self) -> Result<UserId> {
199        Ok(UserId::try_from(self.user_id())?)
200    }
201
202    /// Returns the session identifier as a typed [`SessionId`].
203    ///
204    /// Parses the value returned by [`id()`](Self::id). Returns an error if
205    /// the raw string fails validation.
206    ///
207    /// # Errors
208    ///
209    /// Returns an error when the
210    /// underlying string is not a valid identifier.
211    fn try_session_id(&self) -> Result<SessionId> {
212        Ok(SessionId::try_from(self.id())?)
213    }
214
215    /// Returns the stable session-scoped [`AdkIdentity`] triple.
216    ///
217    /// Combines [`try_app_name()`](Self::try_app_name),
218    /// [`try_user_id()`](Self::try_user_id), and
219    /// [`try_session_id()`](Self::try_session_id) into a single composite
220    /// identity value.
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if any of the three constituent identifiers fail
225    /// validation.
226    fn try_identity(&self) -> Result<AdkIdentity> {
227        Ok(AdkIdentity {
228            app_name: self.try_app_name()?,
229            user_id: self.try_user_id()?,
230            session_id: self.try_session_id()?,
231        })
232    }
233}
234
235/// Structured metadata about a completed tool execution.
236///
237/// Available via [`CallbackContext::tool_outcome()`] in after-tool callbacks,
238/// plugins, and telemetry hooks. Provides structured access to execution
239/// results without requiring JSON error parsing.
240///
241/// # Fields
242///
243/// - `tool_name` — Name of the tool that was executed.
244/// - `tool_args` — Arguments passed to the tool as a JSON value.
245/// - `success` — Whether the tool execution succeeded. Derived from the
246///   Rust `Result` / timeout path, never from JSON content inspection.
247/// - `duration` — Wall-clock duration of the tool execution.
248/// - `error_message` — Error message if the tool failed; `None` on success.
249/// - `attempt` — Retry attempt number (0 = first attempt, 1 = first retry, etc.).
250///   Always 0 when retries are not configured.
251#[derive(Debug, Clone)]
252pub struct ToolOutcome {
253    /// Name of the tool that was executed.
254    pub tool_name: String,
255    /// Arguments passed to the tool (JSON value).
256    pub tool_args: serde_json::Value,
257    /// Whether the tool execution succeeded.
258    pub success: bool,
259    /// Wall-clock duration of the tool execution.
260    pub duration: std::time::Duration,
261    /// Error message if the tool failed. `None` on success.
262    pub error_message: Option<String>,
263    /// Retry attempt number (0 = first attempt, 1 = first retry, etc.).
264    /// Always 0 when retries are not configured.
265    pub attempt: u32,
266}
267
268#[async_trait]
269pub trait CallbackContext: ReadonlyContext {
270    fn artifacts(&self) -> Option<Arc<dyn Artifacts>>;
271
272    /// Returns structured metadata about the most recent tool execution.
273    /// Available in after-tool callbacks and plugin hooks.
274    /// Returns `None` when not in a tool execution context.
275    fn tool_outcome(&self) -> Option<ToolOutcome> {
276        None // default for backward compatibility
277    }
278
279    /// Returns the name of the tool about to be executed.
280    /// Available in before-tool and after-tool callback contexts.
281    fn tool_name(&self) -> Option<&str> {
282        None
283    }
284
285    /// Returns the input arguments for the tool about to be executed.
286    /// Available in before-tool and after-tool callback contexts.
287    fn tool_input(&self) -> Option<&serde_json::Value> {
288        None
289    }
290
291    /// Returns the shared state for parallel agent coordination.
292    /// Returns `None` when not running inside a `ParallelAgent` with shared state enabled.
293    fn shared_state(&self) -> Option<Arc<crate::SharedState>> {
294        None
295    }
296}
297
298/// Wraps a [`CallbackContext`] to inject tool name and input for before-tool
299/// and after-tool callbacks.
300///
301/// Used by the agent runtime to provide tool context to `BeforeToolCallback`
302/// and `AfterToolCallback` invocations.
303///
304/// # Example
305///
306/// ```rust,ignore
307/// let tool_ctx = Arc::new(ToolCallbackContext::new(
308///     ctx.clone(),
309///     "search".to_string(),
310///     serde_json::json!({"query": "hello"}),
311/// ));
312/// callback(tool_ctx as Arc<dyn CallbackContext>).await;
313/// ```
314pub struct ToolCallbackContext {
315    /// The inner callback context to delegate to.
316    pub inner: Arc<dyn CallbackContext>,
317    /// The name of the tool being executed.
318    pub tool_name: String,
319    /// The input arguments for the tool being executed.
320    pub tool_input: serde_json::Value,
321}
322
323impl ToolCallbackContext {
324    /// Creates a new `ToolCallbackContext` wrapping the given inner context.
325    pub fn new(
326        inner: Arc<dyn CallbackContext>,
327        tool_name: String,
328        tool_input: serde_json::Value,
329    ) -> Self {
330        Self { inner, tool_name, tool_input }
331    }
332}
333
334#[async_trait]
335impl ReadonlyContext for ToolCallbackContext {
336    fn invocation_id(&self) -> &str {
337        self.inner.invocation_id()
338    }
339
340    fn agent_name(&self) -> &str {
341        self.inner.agent_name()
342    }
343
344    fn user_id(&self) -> &str {
345        self.inner.user_id()
346    }
347
348    fn app_name(&self) -> &str {
349        self.inner.app_name()
350    }
351
352    fn session_id(&self) -> &str {
353        self.inner.session_id()
354    }
355
356    fn branch(&self) -> &str {
357        self.inner.branch()
358    }
359
360    fn user_content(&self) -> &Content {
361        self.inner.user_content()
362    }
363}
364
365#[async_trait]
366impl CallbackContext for ToolCallbackContext {
367    fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
368        self.inner.artifacts()
369    }
370
371    fn tool_outcome(&self) -> Option<ToolOutcome> {
372        self.inner.tool_outcome()
373    }
374
375    fn tool_name(&self) -> Option<&str> {
376        Some(&self.tool_name)
377    }
378
379    fn tool_input(&self) -> Option<&serde_json::Value> {
380        Some(&self.tool_input)
381    }
382
383    fn shared_state(&self) -> Option<Arc<crate::SharedState>> {
384        self.inner.shared_state()
385    }
386}
387
388#[async_trait]
389pub trait InvocationContext: CallbackContext {
390    fn agent(&self) -> Arc<dyn Agent>;
391    fn memory(&self) -> Option<Arc<dyn Memory>>;
392    fn session(&self) -> &dyn Session;
393    fn run_config(&self) -> &RunConfig;
394    fn end_invocation(&self);
395    fn ended(&self) -> bool;
396
397    /// Returns the scopes granted to the current user for this invocation.
398    ///
399    /// When a [`RequestContext`](crate::RequestContext) is present (set by the
400    /// server's auth middleware bridge), this returns the scopes from that
401    /// context. The default returns an empty vec (no scopes granted).
402    fn user_scopes(&self) -> Vec<String> {
403        vec![]
404    }
405
406    /// Returns the request metadata from the auth middleware bridge, if present.
407    ///
408    /// This provides access to custom key-value pairs extracted from the HTTP
409    /// request by the [`RequestContextExtractor`](crate::RequestContext).
410    fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
411        HashMap::new()
412    }
413
414    /// Retrieve a secret by name from the configured secret provider.
415    ///
416    /// Returns `Ok(Some(value))` when a provider is configured and the secret
417    /// exists, `Ok(None)` when no provider is configured, or an error on
418    /// provider failure. The default returns `Ok(None)`.
419    async fn get_secret(&self, _name: &str) -> Result<Option<String>> {
420        Ok(None)
421    }
422}
423
424// Placeholder service traits
425#[async_trait]
426pub trait Artifacts: Send + Sync {
427    async fn save(&self, name: &str, data: &crate::Part) -> Result<i64>;
428    async fn load(&self, name: &str) -> Result<crate::Part>;
429    async fn list(&self) -> Result<Vec<String>>;
430}
431
432#[async_trait]
433pub trait Memory: Send + Sync {
434    async fn search(&self, query: &str) -> Result<Vec<MemoryEntry>>;
435
436    /// Verify backend connectivity.
437    ///
438    /// The default implementation succeeds, which is suitable for in-memory
439    /// implementations and adapters without an external dependency.
440    async fn health_check(&self) -> Result<()> {
441        Ok(())
442    }
443
444    /// Add a single memory entry.
445    ///
446    /// The default implementation returns an "not implemented" error, which is
447    /// suitable for read-only memory backends.
448    async fn add(&self, entry: MemoryEntry) -> Result<()> {
449        let _ = entry;
450        Err(AdkError::memory("add not implemented"))
451    }
452
453    /// Delete entries matching a query. Returns count of deleted entries.
454    ///
455    /// The default implementation returns an "not implemented" error, which is
456    /// suitable for read-only memory backends.
457    async fn delete(&self, query: &str) -> Result<u64> {
458        let _ = query;
459        Err(AdkError::memory("delete not implemented"))
460    }
461
462    /// Search for memories within a specific project.
463    /// Returns global entries + entries for the given project.
464    /// Default delegates to `search` (global-only results).
465    async fn search_in_project(&self, query: &str, project_id: &str) -> Result<Vec<MemoryEntry>> {
466        let _ = project_id;
467        self.search(query).await
468    }
469
470    /// Add a memory entry scoped to a specific project.
471    /// Default delegates to `add` (global entry).
472    async fn add_to_project(&self, entry: MemoryEntry, project_id: &str) -> Result<()> {
473        let _ = project_id;
474        self.add(entry).await
475    }
476}
477
478/// Trait for retrieving secrets at runtime.
479///
480/// This is the core-level abstraction used by [`ToolContext::get_secret`] and
481/// [`InvocationContext::get_secret`]. Concrete implementations (e.g., AWS
482/// Secrets Manager, Azure Key Vault, GCP Secret Manager) live in `adk-auth`
483/// behind feature flags and implement this trait via the `SecretProvider`
484/// adapter.
485///
486/// # Example
487///
488/// ```rust,ignore
489/// use adk_core::SecretService;
490///
491/// struct EnvSecretService;
492///
493/// #[async_trait::async_trait]
494/// impl SecretService for EnvSecretService {
495///     async fn get_secret(&self, name: &str) -> adk_core::Result<String> {
496///         std::env::var(name).map_err(|_| adk_core::AdkError::not_found(
497///             format!("secret '{name}' not found in environment"),
498///         ))
499///     }
500/// }
501/// ```
502#[async_trait]
503pub trait SecretService: Send + Sync {
504    /// Retrieve a secret value by name.
505    ///
506    /// Returns the secret string on success, or an [`AdkError`] on failure.
507    async fn get_secret(&self, name: &str) -> Result<String>;
508}
509
510#[derive(Debug, Clone)]
511pub struct MemoryEntry {
512    pub content: Content,
513    pub author: String,
514}
515
/// How an agent delivers its responses, following the ADK Python/Go
/// specification.
///
/// `SSE` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum StreamingMode {
    /// No streaming. The agent gathers every chunk internally and emits a
    /// single final event.
    None,
    /// One-way server-to-client streaming via Server-Sent Events. Each chunk
    /// is emitted as it arrives, with a stable event ID.
    #[default]
    SSE,
    /// Two-way streaming in both directions at once, used by realtime
    /// audio/video agents.
    Bidi,
}
531
/// Controls which parts of the earlier conversation history an LLM agent
/// receives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IncludeContents {
    /// Only the current turn: the latest user input and any agent events
    /// that follow it.
    None,
    /// The relevant conversation history (the default).
    #[default]
    Default,
}
541
542/// Decision applied when a tool execution requires human confirmation.
543#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
544#[serde(rename_all = "snake_case")]
545pub enum ToolConfirmationDecision {
546    Approve,
547    Deny,
548}
549
550/// Policy defining which tools require human confirmation before execution.
551#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
552#[serde(rename_all = "snake_case")]
553pub enum ToolConfirmationPolicy {
554    /// No tool confirmation is required.
555    #[default]
556    Never,
557    /// Every tool call requires confirmation.
558    Always,
559    /// Only the listed tool names require confirmation.
560    PerTool(BTreeSet<String>),
561}
562
563impl ToolConfirmationPolicy {
564    /// Returns true when the given tool name must be confirmed before execution.
565    pub fn requires_confirmation(&self, tool_name: &str) -> bool {
566        match self {
567            Self::Never => false,
568            Self::Always => true,
569            Self::PerTool(tools) => tools.contains(tool_name),
570        }
571    }
572
573    /// Add one tool name to the confirmation policy (converts `Never` to `PerTool`).
574    pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
575        let tool_name = tool_name.into();
576        match &mut self {
577            Self::Never => {
578                let mut tools = BTreeSet::new();
579                tools.insert(tool_name);
580                Self::PerTool(tools)
581            }
582            Self::Always => Self::Always,
583            Self::PerTool(tools) => {
584                tools.insert(tool_name);
585                self
586            }
587        }
588    }
589}
590
591/// Payload describing a tool call awaiting human confirmation.
592#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
593#[serde(rename_all = "camelCase")]
594pub struct ToolConfirmationRequest {
595    pub tool_name: String,
596    #[serde(skip_serializing_if = "Option::is_none")]
597    pub function_call_id: Option<String>,
598    pub args: Value,
599}
600
601#[derive(Debug, Clone)]
602pub struct RunConfig {
603    pub streaming_mode: StreamingMode,
604    /// Optional per-tool confirmation decisions for the current run.
605    /// Keys are tool names.
606    pub tool_confirmation_decisions: HashMap<String, ToolConfirmationDecision>,
607    /// Optional cached content name for automatic prompt caching.
608    /// When set by the runner's cache lifecycle manager, agents should attach
609    /// this name to their `GenerateContentConfig` so the LLM provider can
610    /// reuse cached system instructions and tool definitions.
611    pub cached_content: Option<String>,
612    /// Valid agent names this agent can transfer to (parent, peers, children).
613    /// Set by the runner when invoking agents in a multi-agent tree.
614    /// When non-empty, the `transfer_to_agent` tool is injected and validation
615    /// uses this list instead of only checking `sub_agents`.
616    pub transfer_targets: Vec<String>,
617    /// The name of the parent agent, if this agent was invoked via transfer.
618    /// Used by the agent to apply `disallow_transfer_to_parent` filtering.
619    pub parent_agent: Option<String>,
620    /// Enable automatic prompt caching for all providers that support it.
621    ///
622    /// When `true` (the default), the runner enables provider-level caching:
623    /// - Anthropic: sets `prompt_caching = true` on the config
624    /// - Bedrock: sets `prompt_caching = Some(BedrockCacheConfig::default())`
625    /// - OpenAI / DeepSeek: no action needed (caching is automatic)
626    /// - Gemini: handled separately via `ContextCacheConfig`
627    pub auto_cache: bool,
628    /// Maximum number of recent persisted events to load at the start of a run.
629    ///
630    /// `None` preserves the previous behavior and loads the full session
631    /// history. Set this for chat surfaces that already summarize older turns
632    /// and need predictable startup latency.
633    pub history_max_events: Option<usize>,
634    /// Maximum number of tool calls to execute concurrently for parallel/auto
635    /// tool dispatch. `None` allows all eligible tool calls to run together.
636    pub max_tool_concurrency: Option<usize>,
637    /// Whether tracing spans may include full request, response, and tool
638    /// payloads when the `record-payloads` crate feature is enabled.
639    pub record_payloads: bool,
640    /// Maximum serialized bytes recorded for tracing payload fields when full
641    /// payload recording is disabled.
642    pub trace_payload_max_bytes: usize,
643}
644
645impl Default for RunConfig {
646    fn default() -> Self {
647        Self {
648            streaming_mode: StreamingMode::SSE,
649            tool_confirmation_decisions: HashMap::new(),
650            cached_content: None,
651            transfer_targets: Vec::new(),
652            parent_agent: None,
653            auto_cache: true,
654            history_max_events: None,
655            max_tool_concurrency: None,
656            record_payloads: false,
657            trace_payload_max_bytes: 2048,
658        }
659    }
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_run_config_default() {
        let config = RunConfig::default();
        assert_eq!(config.streaming_mode, StreamingMode::SSE);
        assert!(config.tool_confirmation_decisions.is_empty());
        assert_eq!(config.cached_content, None);
        assert!(config.transfer_targets.is_empty());
        assert_eq!(config.parent_agent, None);
        assert!(config.auto_cache);
        assert_eq!(config.history_max_events, None);
        assert_eq!(config.max_tool_concurrency, None);
        assert!(!config.record_payloads);
        assert_eq!(config.trace_payload_max_bytes, 2048);
    }

    #[test]
    fn test_streaming_mode() {
        assert_eq!(StreamingMode::default(), StreamingMode::SSE);
        assert_eq!(StreamingMode::SSE, StreamingMode::SSE);
        assert_ne!(StreamingMode::SSE, StreamingMode::None);
        assert_ne!(StreamingMode::None, StreamingMode::Bidi);
    }

    #[test]
    fn test_include_contents_default() {
        assert_eq!(IncludeContents::default(), IncludeContents::Default);
    }

    #[test]
    fn test_tool_confirmation_policy() {
        let policy = ToolConfirmationPolicy::default();
        assert!(!policy.requires_confirmation("search"));

        let policy = policy.with_tool("search");
        assert!(policy.requires_confirmation("search"));
        assert!(!policy.requires_confirmation("write_file"));

        assert!(ToolConfirmationPolicy::Always.requires_confirmation("any_tool"));
    }

    #[test]
    fn test_tool_confirmation_policy_with_tool_accumulates_and_dedups() {
        let policy = ToolConfirmationPolicy::Never
            .with_tool("search")
            .with_tool("write_file")
            .with_tool("search");
        assert_eq!(
            policy,
            ToolConfirmationPolicy::PerTool(BTreeSet::from([
                "search".to_string(),
                "write_file".to_string(),
            ]))
        );
        assert!(policy.requires_confirmation("write_file"));
        assert!(!policy.requires_confirmation("delete_file"));
    }

    #[test]
    fn test_tool_confirmation_policy_always_absorbs_with_tool() {
        // `Always` already covers every tool, so adding one must not narrow it.
        assert_eq!(
            ToolConfirmationPolicy::Always.with_tool("search"),
            ToolConfirmationPolicy::Always
        );
    }

    #[test]
    fn test_tool_confirmation_wire_format() {
        assert_eq!(
            serde_json::to_value(ToolConfirmationDecision::Approve).unwrap(),
            serde_json::json!("approve")
        );
        assert_eq!(
            serde_json::to_value(ToolConfirmationDecision::Deny).unwrap(),
            serde_json::json!("deny")
        );
        assert_eq!(
            serde_json::to_value(ToolConfirmationPolicy::Never).unwrap(),
            serde_json::json!("never")
        );
        assert_eq!(
            serde_json::to_value(ToolConfirmationPolicy::PerTool(BTreeSet::from([
                "search".to_string()
            ])))
            .unwrap(),
            serde_json::json!({ "per_tool": ["search"] })
        );

        let request = ToolConfirmationRequest {
            tool_name: "search".to_string(),
            function_call_id: None,
            args: serde_json::json!({ "query": "hello" }),
        };
        let encoded = serde_json::to_value(&request).unwrap();
        // `functionCallId` is skipped entirely when absent.
        assert_eq!(
            encoded,
            serde_json::json!({ "toolName": "search", "args": { "query": "hello" } })
        );
        let decoded: ToolConfirmationRequest = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded, request);

        let request = ToolConfirmationRequest {
            function_call_id: Some("call-1".to_string()),
            ..request
        };
        assert_eq!(
            serde_json::to_value(&request).unwrap(),
            serde_json::json!({
                "toolName": "search",
                "functionCallId": "call-1",
                "args": { "query": "hello" }
            })
        );
    }

    #[test]
    fn test_validate_state_key_valid() {
        assert!(validate_state_key("user_name").is_ok());
        assert!(validate_state_key("app:config").is_ok());
        assert!(validate_state_key("temp:data").is_ok());
        assert!(validate_state_key("a").is_ok());
        // A single dot is fine; only `..` is rejected.
        assert!(validate_state_key("a.b").is_ok());
    }

    #[test]
    fn test_validate_state_key_empty() {
        assert_eq!(validate_state_key(""), Err("state key must not be empty"));
    }

    #[test]
    fn test_validate_state_key_too_long() {
        let long_key = "a".repeat(MAX_STATE_KEY_LEN + 1);
        assert_eq!(
            validate_state_key(&long_key),
            Err("state key exceeds maximum length of 256 bytes")
        );
    }

    #[test]
    fn test_validate_state_key_length_boundary_is_in_bytes() {
        assert!(validate_state_key(&"a".repeat(MAX_STATE_KEY_LEN)).is_ok());
        // U+00E9 is two bytes in UTF-8, so the limit is hit at half as many chars.
        assert!(validate_state_key(&"\u{e9}".repeat(MAX_STATE_KEY_LEN / 2)).is_ok());
        assert!(validate_state_key(&"\u{e9}".repeat(MAX_STATE_KEY_LEN / 2 + 1)).is_err());
    }

    #[test]
    fn test_validate_state_key_path_traversal() {
        assert!(validate_state_key("../etc/passwd").is_err());
        assert!(validate_state_key("foo/bar").is_err());
        assert!(validate_state_key("foo\\bar").is_err());
        assert!(validate_state_key("..").is_err());
        assert!(validate_state_key("a..b").is_err());
        assert!(validate_state_key("trailing..").is_err());
    }

    #[test]
    fn test_validate_state_key_null_byte() {
        assert_eq!(
            validate_state_key("foo\0bar"),
            Err("state key must not contain null bytes")
        );
    }
}