1use crate::{Agent, Result, types::Content};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::{BTreeSet, HashMap};
6use std::sync::Arc;
7
8#[async_trait]
9pub trait ReadonlyContext: Send + Sync {
10 fn invocation_id(&self) -> &str;
11 fn agent_name(&self) -> &str;
12 fn user_id(&self) -> &str;
13 fn app_name(&self) -> &str;
14 fn session_id(&self) -> &str;
15 fn branch(&self) -> &str;
16 fn user_content(&self) -> &Content;
17}
18
/// Largest state key, measured in bytes, that [`validate_state_key`] accepts.
pub const MAX_STATE_KEY_LEN: usize = 256;

/// Checks whether `key` is acceptable as a state key.
///
/// The rules are applied in the order below and the first violated rule
/// determines the error message:
///
/// 1. the key must not be empty;
/// 2. the key must be at most [`MAX_STATE_KEY_LEN`] bytes long;
/// 3. the key must not contain `/`, `\` or the sequence `..`;
/// 4. the key must not contain a NUL character.
///
/// # Errors
///
/// Returns a static, human-readable description of the first violated rule.
pub fn validate_state_key(key: &str) -> std::result::Result<(), &'static str> {
    let is_separator = |c: char| c == '/' || c == '\\';

    // `str::len` is a byte count, so the limit applies to the UTF-8 encoding.
    match key.len() {
        0 => Err("state key must not be empty"),
        len if len > MAX_STATE_KEY_LEN => Err("state key exceeds maximum length of 256 bytes"),
        _ if key.contains(is_separator) || key.contains("..") => {
            Err("state key must not contain path separators or '..'")
        }
        _ if key.contains('\0') => Err("state key must not contain null bytes"),
        _ => Ok(()),
    }
}
46
47pub trait State: Send + Sync {
48 fn get(&self, key: &str) -> Option<Value>;
49 fn set(&mut self, key: String, value: Value);
52 fn all(&self) -> HashMap<String, Value>;
53}
54
55pub trait ReadonlyState: Send + Sync {
56 fn get(&self, key: &str) -> Option<Value>;
57 fn all(&self) -> HashMap<String, Value>;
58}
59
60pub trait Session: Send + Sync {
62 fn id(&self) -> &str;
63 fn app_name(&self) -> &str;
64 fn user_id(&self) -> &str;
65 fn state(&self) -> &dyn State;
66 fn conversation_history(&self) -> Vec<Content>;
68 fn append_to_history(&self, _content: Content) {
70 }
72}
73
74#[async_trait]
75pub trait CallbackContext: ReadonlyContext {
76 fn artifacts(&self) -> Option<Arc<dyn Artifacts>>;
77}
78
79#[async_trait]
80pub trait InvocationContext: CallbackContext {
81 fn agent(&self) -> Arc<dyn Agent>;
82 fn memory(&self) -> Option<Arc<dyn Memory>>;
83 fn session(&self) -> &dyn Session;
84 fn run_config(&self) -> &RunConfig;
85 fn end_invocation(&self);
86 fn ended(&self) -> bool;
87}
88
89#[async_trait]
91pub trait Artifacts: Send + Sync {
92 async fn save(&self, name: &str, data: &crate::Part) -> Result<i64>;
93 async fn load(&self, name: &str) -> Result<crate::Part>;
94 async fn list(&self) -> Result<Vec<String>>;
95}
96
97#[async_trait]
98pub trait Memory: Send + Sync {
99 async fn search(&self, query: &str) -> Result<Vec<MemoryEntry>>;
100}
101
102#[derive(Debug, Clone)]
103pub struct MemoryEntry {
104 pub content: Content,
105 pub author: String,
106}
107
/// How responses are streamed during a run.
///
/// [`StreamingMode::SSE`] is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum StreamingMode {
    /// No streaming.
    None,
    /// SSE streaming (presumably server-sent events — confirm); the default.
    #[default]
    SSE,
    /// Bidi streaming (presumably bidirectional — confirm).
    Bidi,
}
123
/// Selects which contents are included.
///
/// [`IncludeContents::Default`] is the default.
// NOTE(review): what exactly is included under each variant is decided by the
// code consuming this enum, which is not visible here — confirm.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum IncludeContents {
    /// Include nothing.
    None,
    /// Use the default inclusion behaviour.
    #[default]
    Default,
}
133
134#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
136#[serde(rename_all = "snake_case")]
137pub enum ToolConfirmationDecision {
138 Approve,
139 Deny,
140}
141
142#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
144#[serde(rename_all = "snake_case")]
145pub enum ToolConfirmationPolicy {
146 #[default]
148 Never,
149 Always,
151 PerTool(BTreeSet<String>),
153}
154
155impl ToolConfirmationPolicy {
156 pub fn requires_confirmation(&self, tool_name: &str) -> bool {
158 match self {
159 Self::Never => false,
160 Self::Always => true,
161 Self::PerTool(tools) => tools.contains(tool_name),
162 }
163 }
164
165 pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
167 let tool_name = tool_name.into();
168 match &mut self {
169 Self::Never => {
170 let mut tools = BTreeSet::new();
171 tools.insert(tool_name);
172 Self::PerTool(tools)
173 }
174 Self::Always => Self::Always,
175 Self::PerTool(tools) => {
176 tools.insert(tool_name);
177 self
178 }
179 }
180 }
181}
182
183#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
185#[serde(rename_all = "camelCase")]
186pub struct ToolConfirmationRequest {
187 pub tool_name: String,
188 #[serde(skip_serializing_if = "Option::is_none")]
189 pub function_call_id: Option<String>,
190 pub args: Value,
191}
192
193#[derive(Debug, Clone)]
194pub struct RunConfig {
195 pub streaming_mode: StreamingMode,
196 pub tool_confirmation_decisions: HashMap<String, ToolConfirmationDecision>,
199}
200
201impl Default for RunConfig {
202 fn default() -> Self {
203 Self { streaming_mode: StreamingMode::SSE, tool_confirmation_decisions: HashMap::new() }
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// The default run configuration streams over SSE with no decisions.
    #[test]
    fn test_run_config_default() {
        let defaults = RunConfig::default();
        assert!(defaults.tool_confirmation_decisions.is_empty());
        assert_eq!(defaults.streaming_mode, StreamingMode::SSE);
    }

    /// Streaming modes compare equal to themselves and differ from each other.
    #[test]
    fn test_streaming_mode() {
        let (none, sse, bidi) = (StreamingMode::None, StreamingMode::SSE, StreamingMode::Bidi);
        assert_eq!(sse, StreamingMode::SSE);
        assert_ne!(sse, none);
        assert_ne!(none, bidi);
    }

    /// `Never` -> `PerTool` via `with_tool`, and `Always` covers any tool.
    #[test]
    fn test_tool_confirmation_policy() {
        let never = ToolConfirmationPolicy::default();
        assert!(!never.requires_confirmation("search"));

        let per_tool = never.with_tool("search");
        assert!(per_tool.requires_confirmation("search"));
        assert!(!per_tool.requires_confirmation("write_file"));

        let always = ToolConfirmationPolicy::Always;
        assert!(always.requires_confirmation("any_tool"));
    }

    /// Ordinary keys, including ones with a `:` prefix, are accepted.
    #[test]
    fn test_validate_state_key_valid() {
        for key in &["user_name", "app:config", "temp:data", "a"] {
            assert!(validate_state_key(key).is_ok());
        }
    }

    /// The empty key is rejected with its dedicated message.
    #[test]
    fn test_validate_state_key_empty() {
        let outcome = validate_state_key("");
        assert_eq!(outcome, Err("state key must not be empty"));
    }

    /// One byte over the limit is rejected.
    #[test]
    fn test_validate_state_key_too_long() {
        let oversized = "a".repeat(MAX_STATE_KEY_LEN + 1);
        assert!(validate_state_key(&oversized).is_err());
    }

    /// Separators and `..` are rejected.
    #[test]
    fn test_validate_state_key_path_traversal() {
        for key in &["../etc/passwd", "foo/bar", "foo\\bar", ".."] {
            assert!(validate_state_key(key).is_err());
        }
    }

    /// Embedded NUL characters are rejected.
    #[test]
    fn test_validate_state_key_null_byte() {
        assert!(validate_state_key("foo\0bar").is_err());
    }
}