// adk_core/context.rs

1use crate::{Agent, Result, types::Content};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::{BTreeSet, HashMap};
6use std::sync::Arc;
7
8#[async_trait]
9pub trait ReadonlyContext: Send + Sync {
10    fn invocation_id(&self) -> &str;
11    fn agent_name(&self) -> &str;
12    fn user_id(&self) -> &str;
13    fn app_name(&self) -> &str;
14    fn session_id(&self) -> &str;
15    fn branch(&self) -> &str;
16    fn user_content(&self) -> &Content;
17}
18
19// State management traits
20
/// Upper bound, in bytes (UTF-8 length, not characters), on the length of a state key.
pub const MAX_STATE_KEY_LEN: usize = 256;

/// Checks whether `key` is acceptable as a state key.
///
/// A key is rejected when it:
/// - is empty;
/// - is longer than [`MAX_STATE_KEY_LEN`] bytes (byte length, not character count);
/// - contains a path separator (`/` or `\`) or the substring `..` anywhere
///   (so `"a..b"` is rejected just like `"../x"`);
/// - contains a NUL (`\0`) character.
///
/// The rules are checked in that order and only the first violation is reported.
///
/// # Errors
///
/// Returns a static, human-readable message describing the first rule the key breaks.
pub fn validate_state_key(key: &str) -> std::result::Result<(), &'static str> {
    let is_separator = |c: char| c == '/' || c == '\\';

    // NOTE: the length message spells out "256"; keep it in step with MAX_STATE_KEY_LEN.
    let violation = if key.is_empty() {
        Some("state key must not be empty")
    } else if key.len() > MAX_STATE_KEY_LEN {
        Some("state key exceeds maximum length of 256 bytes")
    } else if key.contains(is_separator) || key.contains("..") {
        Some("state key must not contain path separators or '..'")
    } else if key.contains('\0') {
        Some("state key must not contain null bytes")
    } else {
        None
    };

    violation.map_or(Ok(()), Err)
}
46
47pub trait State: Send + Sync {
48    fn get(&self, key: &str) -> Option<Value>;
49    /// Set a state value. Implementations should call [`validate_state_key`] and
50    /// reject invalid keys (e.g., by logging a warning or panicking).
51    fn set(&mut self, key: String, value: Value);
52    fn all(&self) -> HashMap<String, Value>;
53}
54
55pub trait ReadonlyState: Send + Sync {
56    fn get(&self, key: &str) -> Option<Value>;
57    fn all(&self) -> HashMap<String, Value>;
58}
59
60// Session trait
61pub trait Session: Send + Sync {
62    fn id(&self) -> &str;
63    fn app_name(&self) -> &str;
64    fn user_id(&self) -> &str;
65    fn state(&self) -> &dyn State;
66    /// Returns the conversation history from this session as Content items
67    fn conversation_history(&self) -> Vec<Content>;
68    /// Append content to conversation history (for sequential agent support)
69    fn append_to_history(&self, _content: Content) {
70        // Default no-op - implementations can override to track history
71    }
72}
73
74#[async_trait]
75pub trait CallbackContext: ReadonlyContext {
76    fn artifacts(&self) -> Option<Arc<dyn Artifacts>>;
77}
78
79#[async_trait]
80pub trait InvocationContext: CallbackContext {
81    fn agent(&self) -> Arc<dyn Agent>;
82    fn memory(&self) -> Option<Arc<dyn Memory>>;
83    fn session(&self) -> &dyn Session;
84    fn run_config(&self) -> &RunConfig;
85    fn end_invocation(&self);
86    fn ended(&self) -> bool;
87}
88
89// Placeholder service traits
90#[async_trait]
91pub trait Artifacts: Send + Sync {
92    async fn save(&self, name: &str, data: &crate::Part) -> Result<i64>;
93    async fn load(&self, name: &str) -> Result<crate::Part>;
94    async fn list(&self) -> Result<Vec<String>>;
95}
96
97#[async_trait]
98pub trait Memory: Send + Sync {
99    async fn search(&self, query: &str) -> Result<Vec<MemoryEntry>>;
100}
101
102#[derive(Debug, Clone)]
103pub struct MemoryEntry {
104    pub content: Content,
105    pub author: String,
106}
107
/// How agent responses are delivered to the caller.
///
/// Intended to mirror the streaming modes of ADK Python/Go. Defaults to
/// [`StreamingMode::SSE`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum StreamingMode {
    /// No streaming: the agent collects every chunk internally and yields a single
    /// final event.
    None,
    /// Server-Sent Events: one-way streaming from server to client; the agent yields
    /// each chunk as it arrives, with a stable event ID.
    #[default]
    SSE,
    /// Bidirectional streaming, with simultaneous communication in both directions
    /// (used for realtime audio/video agents).
    Bidi,
}
123
/// Controls which parts of the prior conversation history are passed to an LLM agent.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum IncludeContents {
    /// The agent sees only its current turn: the latest user input plus any agent
    /// events that follow it.
    None,
    /// The default: the agent receives the relevant conversation history.
    #[default]
    Default,
}
133
134/// Decision applied when a tool execution requires human confirmation.
135#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
136#[serde(rename_all = "snake_case")]
137pub enum ToolConfirmationDecision {
138    Approve,
139    Deny,
140}
141
142/// Policy defining which tools require human confirmation before execution.
143#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
144#[serde(rename_all = "snake_case")]
145pub enum ToolConfirmationPolicy {
146    /// No tool confirmation is required.
147    #[default]
148    Never,
149    /// Every tool call requires confirmation.
150    Always,
151    /// Only the listed tool names require confirmation.
152    PerTool(BTreeSet<String>),
153}
154
155impl ToolConfirmationPolicy {
156    /// Returns true when the given tool name must be confirmed before execution.
157    pub fn requires_confirmation(&self, tool_name: &str) -> bool {
158        match self {
159            Self::Never => false,
160            Self::Always => true,
161            Self::PerTool(tools) => tools.contains(tool_name),
162        }
163    }
164
165    /// Add one tool name to the confirmation policy (converts `Never` to `PerTool`).
166    pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
167        let tool_name = tool_name.into();
168        match &mut self {
169            Self::Never => {
170                let mut tools = BTreeSet::new();
171                tools.insert(tool_name);
172                Self::PerTool(tools)
173            }
174            Self::Always => Self::Always,
175            Self::PerTool(tools) => {
176                tools.insert(tool_name);
177                self
178            }
179        }
180    }
181}
182
183/// Payload describing a tool call awaiting human confirmation.
184#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
185#[serde(rename_all = "camelCase")]
186pub struct ToolConfirmationRequest {
187    pub tool_name: String,
188    #[serde(skip_serializing_if = "Option::is_none")]
189    pub function_call_id: Option<String>,
190    pub args: Value,
191}
192
193#[derive(Debug, Clone)]
194pub struct RunConfig {
195    pub streaming_mode: StreamingMode,
196    /// Optional per-tool confirmation decisions for the current run.
197    /// Keys are tool names.
198    pub tool_confirmation_decisions: HashMap<String, ToolConfirmationDecision>,
199}
200
201impl Default for RunConfig {
202    fn default() -> Self {
203        Self { streaming_mode: StreamingMode::SSE, tool_confirmation_decisions: HashMap::new() }
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Error returned by `validate_state_key` for separators and `..`.
    const PATH_ERR: &str = "state key must not contain path separators or '..'";

    #[test]
    fn test_run_config_default() {
        let config = RunConfig::default();
        assert_eq!(config.streaming_mode, StreamingMode::SSE);
        assert!(config.tool_confirmation_decisions.is_empty());
    }

    #[test]
    fn test_streaming_mode() {
        assert_eq!(StreamingMode::default(), StreamingMode::SSE);
        assert_eq!(StreamingMode::SSE, StreamingMode::SSE);
        assert_ne!(StreamingMode::SSE, StreamingMode::None);
        assert_ne!(StreamingMode::None, StreamingMode::Bidi);
    }

    #[test]
    fn test_include_contents_default() {
        assert_eq!(IncludeContents::default(), IncludeContents::Default);
        assert_ne!(IncludeContents::None, IncludeContents::Default);
    }

    #[test]
    fn test_tool_confirmation_policy() {
        let policy = ToolConfirmationPolicy::default();
        assert!(!policy.requires_confirmation("search"));

        let policy = policy.with_tool("search");
        assert!(policy.requires_confirmation("search"));
        assert!(!policy.requires_confirmation("write_file"));

        assert!(ToolConfirmationPolicy::Always.requires_confirmation("any_tool"));
    }

    #[test]
    fn test_tool_confirmation_policy_with_tool_transitions() {
        // `Always` already covers every tool, so adding one must not narrow it.
        assert_eq!(
            ToolConfirmationPolicy::Always.with_tool("search"),
            ToolConfirmationPolicy::Always
        );

        // `PerTool` accumulates names; re-adding an existing one is a no-op.
        let policy = ToolConfirmationPolicy::Never
            .with_tool("search")
            .with_tool("write_file")
            .with_tool("search");
        let expected = BTreeSet::from(["search".to_string(), "write_file".to_string()]);
        assert_eq!(policy, ToolConfirmationPolicy::PerTool(expected));

        // Tool names are matched exactly (case-sensitive).
        assert!(!policy.requires_confirmation("Search"));
    }

    #[test]
    fn test_tool_confirmation_decision_wire_format() {
        assert_eq!(
            serde_json::to_value(ToolConfirmationDecision::Approve).unwrap(),
            serde_json::json!("approve")
        );
        assert_eq!(
            serde_json::from_value::<ToolConfirmationDecision>(serde_json::json!("deny"))
                .unwrap(),
            ToolConfirmationDecision::Deny
        );
    }

    #[test]
    fn test_tool_confirmation_policy_wire_format() {
        assert_eq!(
            serde_json::to_value(ToolConfirmationPolicy::Never).unwrap(),
            serde_json::json!("never")
        );
        assert_eq!(
            serde_json::to_value(ToolConfirmationPolicy::Never.with_tool("search")).unwrap(),
            serde_json::json!({ "per_tool": ["search"] })
        );
    }

    #[test]
    fn test_tool_confirmation_request_wire_format() {
        let request = ToolConfirmationRequest {
            tool_name: "write_file".to_string(),
            function_call_id: None,
            args: serde_json::json!({ "path": "a.txt" }),
        };
        // `functionCallId` is omitted entirely when absent.
        assert_eq!(
            serde_json::to_value(&request).unwrap(),
            serde_json::json!({ "toolName": "write_file", "args": { "path": "a.txt" } })
        );

        let request =
            ToolConfirmationRequest { function_call_id: Some("call-1".to_string()), ..request };
        assert_eq!(
            serde_json::to_value(&request).unwrap(),
            serde_json::json!({
                "toolName": "write_file",
                "functionCallId": "call-1",
                "args": { "path": "a.txt" }
            })
        );
    }

    #[test]
    fn test_validate_state_key_valid() {
        assert!(validate_state_key("user_name").is_ok());
        assert!(validate_state_key("app:config").is_ok());
        assert!(validate_state_key("temp:data").is_ok());
        assert!(validate_state_key("a").is_ok());
        // A single dot is not a traversal sequence.
        assert!(validate_state_key(".hidden").is_ok());
    }

    #[test]
    fn test_validate_state_key_empty() {
        assert_eq!(validate_state_key(""), Err("state key must not be empty"));
    }

    #[test]
    fn test_validate_state_key_too_long() {
        let long_key = "a".repeat(MAX_STATE_KEY_LEN + 1);
        assert!(validate_state_key(&long_key).is_err());
    }

    #[test]
    fn test_validate_state_key_length_boundary() {
        // Exactly at the limit is accepted.
        assert!(validate_state_key(&"a".repeat(MAX_STATE_KEY_LEN)).is_ok());

        // The limit counts bytes, not characters.
        let multibyte = "é".repeat(MAX_STATE_KEY_LEN / 2 + 1);
        assert!(multibyte.chars().count() < MAX_STATE_KEY_LEN);
        assert_eq!(
            validate_state_key(&multibyte),
            Err("state key exceeds maximum length of 256 bytes")
        );
    }

    #[test]
    fn test_validate_state_key_path_traversal() {
        assert!(validate_state_key("../etc/passwd").is_err());
        assert!(validate_state_key("foo/bar").is_err());
        assert!(validate_state_key("foo\\bar").is_err());
        assert!(validate_state_key("..").is_err());
        // `..` is rejected anywhere in the key, not only as a path component.
        assert_eq!(validate_state_key("a..b"), Err(PATH_ERR));
    }

    #[test]
    fn test_validate_state_key_null_byte() {
        assert!(validate_state_key("foo\0bar").is_err());
        assert_eq!(validate_state_key("foo\0bar"), Err("state key must not contain null bytes"));
        // Separator check runs before the null-byte check.
        assert_eq!(validate_state_key("a/\0"), Err(PATH_ERR));
    }
}