adk_core/context.rs
1use crate::identity::{AdkIdentity, AppName, ExecutionIdentity, InvocationId, SessionId, UserId};
2use crate::{AdkError, Agent, Result, types::Content};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::{BTreeSet, HashMap};
7use std::sync::Arc;
8
9#[async_trait]
10pub trait ReadonlyContext: Send + Sync {
11 fn invocation_id(&self) -> &str;
12 fn agent_name(&self) -> &str;
13 fn user_id(&self) -> &str;
14 fn app_name(&self) -> &str;
15 fn session_id(&self) -> &str;
16 fn branch(&self) -> &str;
17 fn user_content(&self) -> &Content;
18
19 /// Returns the application name as a typed [`AppName`].
20 ///
21 /// Parses the value returned by [`app_name()`](Self::app_name). Returns an
22 /// error if the raw string fails validation (empty, null bytes, or exceeds
23 /// the maximum length).
24 ///
25 /// # Errors
26 ///
27 /// Returns an error when the
28 /// underlying string is not a valid identifier.
29 fn try_app_name(&self) -> Result<AppName> {
30 Ok(AppName::try_from(self.app_name())?)
31 }
32
33 /// Returns the user identifier as a typed [`UserId`].
34 ///
35 /// Parses the value returned by [`user_id()`](Self::user_id). Returns an
36 /// error if the raw string fails validation.
37 ///
38 /// # Errors
39 ///
40 /// Returns an error when the
41 /// underlying string is not a valid identifier.
42 fn try_user_id(&self) -> Result<UserId> {
43 Ok(UserId::try_from(self.user_id())?)
44 }
45
46 /// Returns the session identifier as a typed [`SessionId`].
47 ///
48 /// Parses the value returned by [`session_id()`](Self::session_id).
49 /// Returns an error if the raw string fails validation.
50 ///
51 /// # Errors
52 ///
53 /// Returns an error when the
54 /// underlying string is not a valid identifier.
55 fn try_session_id(&self) -> Result<SessionId> {
56 Ok(SessionId::try_from(self.session_id())?)
57 }
58
59 /// Returns the invocation identifier as a typed [`InvocationId`].
60 ///
61 /// Parses the value returned by [`invocation_id()`](Self::invocation_id).
62 /// Returns an error if the raw string fails validation.
63 ///
64 /// # Errors
65 ///
66 /// Returns an error when the
67 /// underlying string is not a valid identifier.
68 fn try_invocation_id(&self) -> Result<InvocationId> {
69 Ok(InvocationId::try_from(self.invocation_id())?)
70 }
71
72 /// Returns the stable session-scoped [`AdkIdentity`] triple.
73 ///
74 /// Combines [`try_app_name()`](Self::try_app_name),
75 /// [`try_user_id()`](Self::try_user_id), and
76 /// [`try_session_id()`](Self::try_session_id) into a single composite
77 /// identity value.
78 ///
79 /// # Errors
80 ///
81 /// Returns an error if any of the three constituent identifiers fail
82 /// validation.
83 fn try_identity(&self) -> Result<AdkIdentity> {
84 Ok(AdkIdentity {
85 app_name: self.try_app_name()?,
86 user_id: self.try_user_id()?,
87 session_id: self.try_session_id()?,
88 })
89 }
90
91 /// Returns the full per-invocation [`ExecutionIdentity`].
92 ///
93 /// Combines [`try_identity()`](Self::try_identity) with the invocation,
94 /// branch, and agent name from this context.
95 ///
96 /// # Errors
97 ///
98 /// Returns an error if any of the four typed identifiers fail validation.
99 fn try_execution_identity(&self) -> Result<ExecutionIdentity> {
100 Ok(ExecutionIdentity {
101 adk: self.try_identity()?,
102 invocation_id: self.try_invocation_id()?,
103 branch: self.branch().to_string(),
104 agent_name: self.agent_name().to_string(),
105 })
106 }
107}
108
109// State management traits
110
/// Upper bound, in bytes, on the length of a state key (256).
pub const MAX_STATE_KEY_LEN: usize = 256;

/// Checks whether `key` is acceptable as a state key.
///
/// Returns `Ok(())` for a safe key, otherwise a static message describing the
/// first rule that was violated. Rules are checked in this order:
///
/// 1. The key must not be empty.
/// 2. The key must not be longer than [`MAX_STATE_KEY_LEN`] bytes.
/// 3. The key must not contain a path separator (`/`, `\`) or `..`.
/// 4. The key must not contain a null byte.
pub fn validate_state_key(key: &str) -> std::result::Result<(), &'static str> {
    // Evaluated lazily so it only runs once the length checks have passed.
    let looks_like_path =
        || key.contains("..") || key.chars().any(|c| matches!(c, '/' | '\\'));

    match key {
        "" => Err("state key must not be empty"),
        k if k.len() > MAX_STATE_KEY_LEN => {
            Err("state key exceeds maximum length of 256 bytes")
        }
        _ if looks_like_path() => Err("state key must not contain path separators or '..'"),
        k if k.contains('\0') => Err("state key must not contain null bytes"),
        _ => Ok(()),
    }
}
136
/// Mutable key-value state whose values are JSON [`Value`]s.
pub trait State: Send + Sync {
    /// Returns the value stored under `key`, if any.
    fn get(&self, key: &str) -> Option<Value>;
    /// Set a state value. Implementations should call [`validate_state_key`] and
    /// reject invalid keys (e.g., by logging a warning or panicking).
    ///
    /// The signature has no error channel, so a rejected key cannot be
    /// reported to the caller through the return value.
    fn set(&mut self, key: String, value: Value);
    /// Returns every entry as an owned map.
    fn all(&self) -> HashMap<String, Value>;
}
144
/// Read-only counterpart of [`State`]: the same lookups, without `set`.
pub trait ReadonlyState: Send + Sync {
    /// Returns the value stored under `key`, if any.
    fn get(&self, key: &str) -> Option<Value>;
    /// Returns every entry as an owned map.
    fn all(&self) -> HashMap<String, Value>;
}
149
150// Session trait
151pub trait Session: Send + Sync {
152 fn id(&self) -> &str;
153 fn app_name(&self) -> &str;
154 fn user_id(&self) -> &str;
155 fn state(&self) -> &dyn State;
156 /// Returns the conversation history from this session as Content items
157 fn conversation_history(&self) -> Vec<Content>;
158 /// Returns conversation history filtered for a specific agent.
159 ///
160 /// When provided, events authored by other agents (not "user", not the
161 /// named agent, and not function/tool responses) are excluded. This
162 /// prevents a transferred sub-agent from seeing the parent's tool calls
163 /// mapped as "model" role, which would cause the LLM to think work is
164 /// already done.
165 ///
166 /// Default implementation delegates to [`conversation_history`](Self::conversation_history).
167 fn conversation_history_for_agent(&self, _agent_name: &str) -> Vec<Content> {
168 self.conversation_history()
169 }
170 /// Append content to conversation history (for sequential agent support)
171 fn append_to_history(&self, _content: Content) {
172 // Default no-op - implementations can override to track history
173 }
174
175 /// Returns the application name as a typed [`AppName`].
176 ///
177 /// Parses the value returned by [`app_name()`](Self::app_name). Returns an
178 /// error if the raw string fails validation (empty, null bytes, or exceeds
179 /// the maximum length).
180 ///
181 /// # Errors
182 ///
183 /// Returns an error when the
184 /// underlying string is not a valid identifier.
185 fn try_app_name(&self) -> Result<AppName> {
186 Ok(AppName::try_from(self.app_name())?)
187 }
188
189 /// Returns the user identifier as a typed [`UserId`].
190 ///
191 /// Parses the value returned by [`user_id()`](Self::user_id). Returns an
192 /// error if the raw string fails validation.
193 ///
194 /// # Errors
195 ///
196 /// Returns an error when the
197 /// underlying string is not a valid identifier.
198 fn try_user_id(&self) -> Result<UserId> {
199 Ok(UserId::try_from(self.user_id())?)
200 }
201
202 /// Returns the session identifier as a typed [`SessionId`].
203 ///
204 /// Parses the value returned by [`id()`](Self::id). Returns an error if
205 /// the raw string fails validation.
206 ///
207 /// # Errors
208 ///
209 /// Returns an error when the
210 /// underlying string is not a valid identifier.
211 fn try_session_id(&self) -> Result<SessionId> {
212 Ok(SessionId::try_from(self.id())?)
213 }
214
215 /// Returns the stable session-scoped [`AdkIdentity`] triple.
216 ///
217 /// Combines [`try_app_name()`](Self::try_app_name),
218 /// [`try_user_id()`](Self::try_user_id), and
219 /// [`try_session_id()`](Self::try_session_id) into a single composite
220 /// identity value.
221 ///
222 /// # Errors
223 ///
224 /// Returns an error if any of the three constituent identifiers fail
225 /// validation.
226 fn try_identity(&self) -> Result<AdkIdentity> {
227 Ok(AdkIdentity {
228 app_name: self.try_app_name()?,
229 user_id: self.try_user_id()?,
230 session_id: self.try_session_id()?,
231 })
232 }
233}
234
/// Structured metadata about a completed tool execution.
///
/// Available via [`CallbackContext::tool_outcome()`] in after-tool callbacks,
/// plugins, and telemetry hooks. Provides structured access to execution
/// results without requiring JSON error parsing.
///
/// # Fields
///
/// - `tool_name` — Name of the tool that was executed.
/// - `tool_args` — Arguments passed to the tool as a JSON value.
/// - `success` — Whether the tool execution succeeded. Derived from the
///   Rust `Result` / timeout path, never from JSON content inspection.
/// - `duration` — Wall-clock duration of the tool execution.
/// - `error_message` — Error message if the tool failed; `None` on success.
/// - `attempt` — Retry attempt number (0 = first attempt, 1 = first retry, etc.).
///   Always 0 when retries are not configured.
#[derive(Debug, Clone)]
pub struct ToolOutcome {
    /// Name of the tool that was executed.
    pub tool_name: String,
    /// Arguments passed to the tool (JSON value).
    pub tool_args: serde_json::Value,
    /// Whether the tool execution succeeded.
    ///
    /// Set by whoever constructs the outcome; nothing in this type ties it
    /// to `error_message`.
    pub success: bool,
    /// Wall-clock duration of the tool execution.
    pub duration: std::time::Duration,
    /// Error message if the tool failed. `None` on success.
    pub error_message: Option<String>,
    /// Retry attempt number (0 = first attempt, 1 = first retry, etc.).
    /// Always 0 when retries are not configured.
    pub attempt: u32,
}
267
/// Context handed to callbacks: a [`ReadonlyContext`] plus optional access to
/// artifacts, tool-call details, and shared state.
///
/// Only [`artifacts`](Self::artifacts) is required; every other method has a
/// default that returns `None`.
#[async_trait]
pub trait CallbackContext: ReadonlyContext {
    /// Returns the artifact service for this context, or `None` when there
    /// is none.
    fn artifacts(&self) -> Option<Arc<dyn Artifacts>>;

    /// Returns structured metadata about the most recent tool execution.
    /// Available in after-tool callbacks and plugin hooks.
    /// Returns `None` when not in a tool execution context.
    fn tool_outcome(&self) -> Option<ToolOutcome> {
        None // default for backward compatibility
    }

    /// Returns the name of the tool about to be executed.
    /// Available in before-tool and after-tool callback contexts.
    ///
    /// The default returns `None`; [`ToolCallbackContext`] overrides it.
    fn tool_name(&self) -> Option<&str> {
        None
    }

    /// Returns the input arguments for the tool about to be executed.
    /// Available in before-tool and after-tool callback contexts.
    ///
    /// The default returns `None`; [`ToolCallbackContext`] overrides it.
    fn tool_input(&self) -> Option<&serde_json::Value> {
        None
    }

    /// Returns the shared state for parallel agent coordination.
    /// Returns `None` when not running inside a `ParallelAgent` with shared state enabled.
    fn shared_state(&self) -> Option<Arc<crate::SharedState>> {
        None
    }
}
297
/// Wraps a [`CallbackContext`] to inject tool name and input for before-tool
/// and after-tool callbacks.
///
/// Used by the agent runtime to provide tool context to `BeforeToolCallback`
/// and `AfterToolCallback` invocations.
///
/// All [`ReadonlyContext`] accessors and the remaining [`CallbackContext`]
/// methods are forwarded to `inner`; only `tool_name()` and `tool_input()`
/// are answered from this wrapper's own fields.
///
/// # Example
///
/// ```rust,ignore
/// let tool_ctx = Arc::new(ToolCallbackContext::new(
///     ctx.clone(),
///     "search".to_string(),
///     serde_json::json!({"query": "hello"}),
/// ));
/// callback(tool_ctx as Arc<dyn CallbackContext>).await;
/// ```
pub struct ToolCallbackContext {
    /// The inner callback context to delegate to.
    pub inner: Arc<dyn CallbackContext>,
    /// The name of the tool being executed.
    pub tool_name: String,
    /// The input arguments for the tool being executed.
    pub tool_input: serde_json::Value,
}
322
323impl ToolCallbackContext {
324 /// Creates a new `ToolCallbackContext` wrapping the given inner context.
325 pub fn new(
326 inner: Arc<dyn CallbackContext>,
327 tool_name: String,
328 tool_input: serde_json::Value,
329 ) -> Self {
330 Self { inner, tool_name, tool_input }
331 }
332}
333
334#[async_trait]
335impl ReadonlyContext for ToolCallbackContext {
336 fn invocation_id(&self) -> &str {
337 self.inner.invocation_id()
338 }
339
340 fn agent_name(&self) -> &str {
341 self.inner.agent_name()
342 }
343
344 fn user_id(&self) -> &str {
345 self.inner.user_id()
346 }
347
348 fn app_name(&self) -> &str {
349 self.inner.app_name()
350 }
351
352 fn session_id(&self) -> &str {
353 self.inner.session_id()
354 }
355
356 fn branch(&self) -> &str {
357 self.inner.branch()
358 }
359
360 fn user_content(&self) -> &Content {
361 self.inner.user_content()
362 }
363}
364
365#[async_trait]
366impl CallbackContext for ToolCallbackContext {
367 fn artifacts(&self) -> Option<Arc<dyn Artifacts>> {
368 self.inner.artifacts()
369 }
370
371 fn tool_outcome(&self) -> Option<ToolOutcome> {
372 self.inner.tool_outcome()
373 }
374
375 fn tool_name(&self) -> Option<&str> {
376 Some(&self.tool_name)
377 }
378
379 fn tool_input(&self) -> Option<&serde_json::Value> {
380 Some(&self.tool_input)
381 }
382
383 fn shared_state(&self) -> Option<Arc<crate::SharedState>> {
384 self.inner.shared_state()
385 }
386}
387
388#[async_trait]
389pub trait InvocationContext: CallbackContext {
390 fn agent(&self) -> Arc<dyn Agent>;
391 fn memory(&self) -> Option<Arc<dyn Memory>>;
392 fn session(&self) -> &dyn Session;
393 fn run_config(&self) -> &RunConfig;
394 fn end_invocation(&self);
395 fn ended(&self) -> bool;
396
397 /// Returns the scopes granted to the current user for this invocation.
398 ///
399 /// When a [`RequestContext`](crate::RequestContext) is present (set by the
400 /// server's auth middleware bridge), this returns the scopes from that
401 /// context. The default returns an empty vec (no scopes granted).
402 fn user_scopes(&self) -> Vec<String> {
403 vec![]
404 }
405
406 /// Returns the request metadata from the auth middleware bridge, if present.
407 ///
408 /// This provides access to custom key-value pairs extracted from the HTTP
409 /// request by the [`RequestContextExtractor`](crate::RequestContext).
410 fn request_metadata(&self) -> HashMap<String, serde_json::Value> {
411 HashMap::new()
412 }
413}
414
415// Placeholder service traits
/// Asynchronous storage for named artifacts, each held as a `crate::Part`.
#[async_trait]
pub trait Artifacts: Send + Sync {
    /// Stores `data` under `name`.
    // NOTE(review): the returned `i64` is presumably a version or revision
    // number — confirm against the implementations.
    async fn save(&self, name: &str, data: &crate::Part) -> Result<i64>;
    /// Loads the artifact stored under `name`.
    async fn load(&self, name: &str) -> Result<crate::Part>;
    /// Lists artifact names.
    async fn list(&self) -> Result<Vec<String>>;
}
422
423#[async_trait]
424pub trait Memory: Send + Sync {
425 async fn search(&self, query: &str) -> Result<Vec<MemoryEntry>>;
426
427 /// Verify backend connectivity.
428 ///
429 /// The default implementation succeeds, which is suitable for in-memory
430 /// implementations and adapters without an external dependency.
431 async fn health_check(&self) -> Result<()> {
432 Ok(())
433 }
434
435 /// Add a single memory entry.
436 ///
437 /// The default implementation returns an "not implemented" error, which is
438 /// suitable for read-only memory backends.
439 async fn add(&self, entry: MemoryEntry) -> Result<()> {
440 let _ = entry;
441 Err(AdkError::memory("add not implemented"))
442 }
443
444 /// Delete entries matching a query. Returns count of deleted entries.
445 ///
446 /// The default implementation returns an "not implemented" error, which is
447 /// suitable for read-only memory backends.
448 async fn delete(&self, query: &str) -> Result<u64> {
449 let _ = query;
450 Err(AdkError::memory("delete not implemented"))
451 }
452}
453
/// A single memory record, as returned by [`Memory::search`] and accepted by
/// [`Memory::add`].
#[derive(Debug, Clone)]
pub struct MemoryEntry {
    /// The stored content.
    pub content: Content,
    /// Author of the content.
    pub author: String,
}
459
/// Streaming mode for agent responses.
/// Matches ADK Python/Go specification.
///
/// The default is [`StreamingMode::SSE`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum StreamingMode {
    /// No streaming; responses delivered as complete units.
    /// Agent collects all chunks internally and yields a single final event.
    None,
    /// Server-Sent Events streaming; one-way streaming from server to client.
    /// Agent yields each chunk as it arrives with stable event ID.
    #[default]
    SSE,
    /// Bidirectional streaming; simultaneous communication in both directions.
    /// Used for realtime audio/video agents.
    Bidi,
}
475
/// Controls which parts of the prior conversation history an LLM agent
/// receives.
///
/// The default is [`IncludeContents::Default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IncludeContents {
    /// The LLM agent operates solely on its current turn (latest user input
    /// plus any following agent events).
    None,
    /// The LLM agent receives the relevant conversation history.
    #[default]
    Default,
}
485
/// Decision applied when a tool execution requires human confirmation.
///
/// Variant names are serialized in snake case (`approve`, `deny`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolConfirmationDecision {
    /// The tool call is approved.
    Approve,
    /// The tool call is denied.
    Deny,
}
493
/// Policy defining which tools require human confirmation before execution.
///
/// The default is [`ToolConfirmationPolicy::Never`]. Variant names are
/// serialized in snake case.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ToolConfirmationPolicy {
    /// No tool confirmation is required.
    #[default]
    Never,
    /// Every tool call requires confirmation.
    Always,
    /// Only the listed tool names require confirmation.
    PerTool(BTreeSet<String>),
}
506
507impl ToolConfirmationPolicy {
508 /// Returns true when the given tool name must be confirmed before execution.
509 pub fn requires_confirmation(&self, tool_name: &str) -> bool {
510 match self {
511 Self::Never => false,
512 Self::Always => true,
513 Self::PerTool(tools) => tools.contains(tool_name),
514 }
515 }
516
517 /// Add one tool name to the confirmation policy (converts `Never` to `PerTool`).
518 pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
519 let tool_name = tool_name.into();
520 match &mut self {
521 Self::Never => {
522 let mut tools = BTreeSet::new();
523 tools.insert(tool_name);
524 Self::PerTool(tools)
525 }
526 Self::Always => Self::Always,
527 Self::PerTool(tools) => {
528 tools.insert(tool_name);
529 self
530 }
531 }
532 }
533}
534
/// Payload describing a tool call awaiting human confirmation.
///
/// Field names are serialized in camel case (`toolName`, `functionCallId`,
/// `args`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfirmationRequest {
    /// Name of the tool awaiting confirmation.
    pub tool_name: String,
    /// Identifier of the originating function call, when there is one.
    /// Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call_id: Option<String>,
    /// Arguments of the pending tool call as a JSON value.
    pub args: Value,
}
544
/// Per-run configuration exposed to agents through
/// [`InvocationContext::run_config`].
///
/// `RunConfig::default()` uses SSE streaming, no confirmation decisions, no
/// cached content, no transfer targets, no parent agent, and `auto_cache`
/// enabled.
#[derive(Debug, Clone)]
pub struct RunConfig {
    /// How responses are streamed for this run.
    pub streaming_mode: StreamingMode,
    /// Optional per-tool confirmation decisions for the current run.
    /// Keys are tool names.
    pub tool_confirmation_decisions: HashMap<String, ToolConfirmationDecision>,
    /// Optional cached content name for automatic prompt caching.
    /// When set by the runner's cache lifecycle manager, agents should attach
    /// this name to their `GenerateContentConfig` so the LLM provider can
    /// reuse cached system instructions and tool definitions.
    pub cached_content: Option<String>,
    /// Valid agent names this agent can transfer to (parent, peers, children).
    /// Set by the runner when invoking agents in a multi-agent tree.
    /// When non-empty, the `transfer_to_agent` tool is injected and validation
    /// uses this list instead of only checking `sub_agents`.
    pub transfer_targets: Vec<String>,
    /// The name of the parent agent, if this agent was invoked via transfer.
    /// Used by the agent to apply `disallow_transfer_to_parent` filtering.
    pub parent_agent: Option<String>,
    /// Enable automatic prompt caching for all providers that support it.
    ///
    /// When `true` (the default), the runner enables provider-level caching:
    /// - Anthropic: sets `prompt_caching = true` on the config
    /// - Bedrock: sets `prompt_caching = Some(BedrockCacheConfig::default())`
    /// - OpenAI / DeepSeek: no action needed (caching is automatic)
    /// - Gemini: handled separately via `ContextCacheConfig`
    pub auto_cache: bool,
}
573
574impl Default for RunConfig {
575 fn default() -> Self {
576 Self {
577 streaming_mode: StreamingMode::SSE,
578 tool_confirmation_decisions: HashMap::new(),
579 cached_content: None,
580 transfer_targets: Vec::new(),
581 parent_agent: None,
582 auto_cache: true,
583 }
584 }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    const PATH_ERR: &str = "state key must not contain path separators or '..'";

    /// Every field of the default `RunConfig`, not just the first two.
    #[test]
    fn test_run_config_default() {
        let config = RunConfig::default();
        assert_eq!(config.streaming_mode, StreamingMode::SSE);
        assert!(config.tool_confirmation_decisions.is_empty());
        assert!(config.cached_content.is_none());
        assert!(config.transfer_targets.is_empty());
        assert!(config.parent_agent.is_none());
        assert!(config.auto_cache);
    }

    #[test]
    fn test_streaming_mode() {
        assert_eq!(StreamingMode::SSE, StreamingMode::SSE);
        assert_ne!(StreamingMode::SSE, StreamingMode::None);
        assert_ne!(StreamingMode::None, StreamingMode::Bidi);
    }

    /// The derived defaults must match the variants marked `#[default]`.
    #[test]
    fn test_enum_defaults() {
        assert_eq!(StreamingMode::default(), StreamingMode::SSE);
        assert_eq!(IncludeContents::default(), IncludeContents::Default);
        assert_eq!(ToolConfirmationPolicy::default(), ToolConfirmationPolicy::Never);
    }

    #[test]
    fn test_tool_confirmation_policy() {
        let policy = ToolConfirmationPolicy::default();
        assert!(!policy.requires_confirmation("search"));

        let policy = policy.with_tool("search");
        assert!(policy.requires_confirmation("search"));
        assert!(!policy.requires_confirmation("write_file"));

        assert!(ToolConfirmationPolicy::Always.requires_confirmation("any_tool"));
    }

    /// Adding to an existing `PerTool` keeps the names already listed.
    #[test]
    fn test_tool_confirmation_policy_accumulates() {
        let policy = ToolConfirmationPolicy::Never.with_tool("search").with_tool("write_file");
        assert!(policy.requires_confirmation("search"));
        assert!(policy.requires_confirmation("write_file"));
        assert!(!policy.requires_confirmation("read_file"));
    }

    /// `Always` already covers every tool, so `with_tool` must not narrow it.
    #[test]
    fn test_tool_confirmation_policy_always_is_unchanged() {
        let policy = ToolConfirmationPolicy::Always.with_tool("search");
        assert_eq!(policy, ToolConfirmationPolicy::Always);
        assert!(policy.requires_confirmation("write_file"));
    }

    #[test]
    fn test_validate_state_key_valid() {
        assert!(validate_state_key("user_name").is_ok());
        assert!(validate_state_key("app:config").is_ok());
        assert!(validate_state_key("temp:data").is_ok());
        assert!(validate_state_key("a").is_ok());
        // A single dot is not a traversal sequence.
        assert!(validate_state_key("a.b").is_ok());
    }

    #[test]
    fn test_validate_state_key_empty() {
        assert_eq!(validate_state_key(""), Err("state key must not be empty"));
    }

    #[test]
    fn test_validate_state_key_too_long() {
        let long_key = "a".repeat(MAX_STATE_KEY_LEN + 1);
        assert!(validate_state_key(&long_key).is_err());
    }

    /// The limit is inclusive: exactly `MAX_STATE_KEY_LEN` bytes is accepted.
    #[test]
    fn test_validate_state_key_max_length_boundary() {
        let max_key = "a".repeat(MAX_STATE_KEY_LEN);
        assert!(validate_state_key(&max_key).is_ok());
    }

    #[test]
    fn test_validate_state_key_path_traversal() {
        assert!(validate_state_key("../etc/passwd").is_err());
        assert!(validate_state_key("foo/bar").is_err());
        assert!(validate_state_key("foo\\bar").is_err());
        assert!(validate_state_key("..").is_err());
        assert_eq!(validate_state_key("foo..bar"), Err(PATH_ERR));
    }

    #[test]
    fn test_validate_state_key_null_byte() {
        assert!(validate_state_key("foo\0bar").is_err());
    }

    /// Checks run in a fixed order: length, then path syntax, then null bytes.
    #[test]
    fn test_validate_state_key_error_precedence() {
        let long_path = "/".repeat(MAX_STATE_KEY_LEN + 1);
        assert_eq!(
            validate_state_key(&long_path),
            Err("state key exceeds maximum length of 256 bytes")
        );
        assert_eq!(validate_state_key("a/\0"), Err(PATH_ERR));
    }
}