ai_agents_tools/security/
engine.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Instant;
4
5use parking_lot::RwLock;
6use tracing::debug;
7
8use super::config::*;
9use ai_agents_core::Result;
10
/// Per-tool log of call timestamps, used for sliding-window rate limiting.
#[derive(Debug, Default)]
struct ToolCallTracker {
    /// Instants at which each tool's calls were recorded, oldest first.
    calls: HashMap<String, Vec<Instant>>,
}

impl ToolCallTracker {
    /// How long (in seconds) a recorded call is remembered.
    ///
    /// `ToolSecurityEngine` only queries a 60-second window, so older timestamps
    /// can never be counted again. Without pruning, every allowed call would be
    /// kept until `reset`, growing memory without bound in long-running sessions.
    const RETENTION_SECONDS: u64 = 60;

    /// Records a call to `tool_id` at the current instant, first discarding
    /// that tool's timestamps that have aged out of the retention window.
    fn record_call(&mut self, tool_id: &str) {
        let now = Instant::now();
        let retention = std::time::Duration::from_secs(Self::RETENTION_SECONDS);
        let calls = self.calls.entry(tool_id.to_string()).or_default();
        calls.retain(|t| now.duration_since(*t) < retention);
        calls.push(now);
    }

    /// Counts calls to `tool_id` recorded less than `window_seconds` ago.
    ///
    /// `window_seconds` must not exceed [`Self::RETENTION_SECONDS`]: older calls
    /// may already have been pruned and would be under-counted.
    fn get_calls_in_window(&self, tool_id: &str, window_seconds: u64) -> usize {
        debug_assert!(
            window_seconds <= Self::RETENTION_SECONDS,
            "window exceeds call retention"
        );
        let now = Instant::now();
        let window = std::time::Duration::from_secs(window_seconds);

        self.calls.get(tool_id).map_or(0, |calls| {
            calls
                .iter()
                .filter(|t| now.duration_since(**t) < window)
                .count()
        })
    }

    /// Forgets all recorded calls for every tool.
    fn reset(&mut self) {
        self.calls.clear();
    }
}
43
44#[derive(Debug)]
45pub struct ToolSecurityEngine {
46 config: ToolSecurityConfig,
47 tool_call_tracker: Arc<RwLock<ToolCallTracker>>,
48}
49
50impl ToolSecurityEngine {
51 pub fn new(config: ToolSecurityConfig) -> Self {
52 Self {
53 config,
54 tool_call_tracker: Arc::new(RwLock::new(ToolCallTracker::default())),
55 }
56 }
57
58 pub fn config(&self) -> &ToolSecurityConfig {
59 &self.config
60 }
61
62 pub async fn check_tool_execution(
63 &self,
64 tool_id: &str,
65 args: &serde_json::Value,
66 ) -> Result<SecurityCheckResult> {
67 if !self.config.enabled {
68 return Ok(SecurityCheckResult::Allow);
69 }
70
71 let tool_config = self.config.tools.get(tool_id);
72
73 if let Some(config) = tool_config {
74 if !config.enabled {
75 return Ok(SecurityCheckResult::Block {
76 reason: format!("Tool '{}' is disabled", tool_id),
77 });
78 }
79
80 if let Some(rate_limit) = config.rate_limit {
81 let calls = self
82 .tool_call_tracker
83 .read()
84 .get_calls_in_window(tool_id, 60);
85 if calls >= rate_limit as usize {
86 return Ok(SecurityCheckResult::Block {
87 reason: format!(
88 "Rate limit exceeded for tool '{}': {} calls per minute",
89 tool_id, rate_limit
90 ),
91 });
92 }
93 }
94
95 if let Some(url) = args.get("url").and_then(|u| u.as_str()) {
96 for blocked in &config.blocked_domains {
97 if url.contains(blocked) {
98 return Ok(SecurityCheckResult::Block {
99 reason: format!(
100 "Domain '{}' is blocked for tool '{}'",
101 blocked, tool_id
102 ),
103 });
104 }
105 }
106
107 if !config.allowed_domains.is_empty() {
108 let is_allowed = config.allowed_domains.iter().any(|d| url.contains(d));
109 if !is_allowed {
110 return Ok(SecurityCheckResult::Block {
111 reason: format!(
112 "URL domain not in allowed list for tool '{}'",
113 tool_id
114 ),
115 });
116 }
117 }
118 }
119
120 if let Some(path) = args.get("path").and_then(|p| p.as_str()) {
121 if !config.allowed_paths.is_empty() {
122 let is_allowed = config.allowed_paths.iter().any(|p| path.starts_with(p));
123 if !is_allowed {
124 return Ok(SecurityCheckResult::Block {
125 reason: format!("Path not in allowed list for tool '{}'", tool_id),
126 });
127 }
128 }
129 }
130
131 if config.require_confirmation {
132 let message = config
133 .confirmation_message
134 .clone()
135 .unwrap_or_else(|| format!("Confirm execution of tool '{}'?", tool_id));
136 return Ok(SecurityCheckResult::RequireConfirmation { message });
137 }
138 }
139
140 self.tool_call_tracker.write().record_call(tool_id);
141 debug!(tool_id = %tool_id, "Tool execution allowed");
142
143 Ok(SecurityCheckResult::Allow)
144 }
145
146 pub fn get_tool_timeout(&self, tool_id: &str) -> u64 {
147 self.config
148 .tools
149 .get(tool_id)
150 .and_then(|c| c.timeout_ms)
151 .unwrap_or(self.config.default_timeout_ms)
152 }
153
154 pub fn reset_session(&self) {
155 self.tool_call_tracker.write().reset();
156 }
157}
158
159impl Default for ToolSecurityEngine {
160 fn default() -> Self {
161 Self::new(ToolSecurityConfig::default())
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled engine whose only policy entry is `policy` under `tool_id`.
    fn engine_with_policy(tool_id: &str, policy: ToolPolicyConfig) -> ToolSecurityEngine {
        let mut config = ToolSecurityConfig::default();
        config.enabled = true;
        config.tools.insert(tool_id.to_string(), policy);
        ToolSecurityEngine::new(config)
    }

    /// Runs the security check for `tool_id` with `args` and unwraps the verdict.
    async fn verdict(
        engine: &ToolSecurityEngine,
        tool_id: &str,
        args: serde_json::Value,
    ) -> SecurityCheckResult {
        engine.check_tool_execution(tool_id, &args).await.unwrap()
    }

    #[test]
    fn test_default_engine() {
        assert!(!ToolSecurityEngine::default().config().enabled);
    }

    #[tokio::test]
    async fn test_tool_domain_blocking() {
        let mut policy = ToolPolicyConfig::default();
        policy.blocked_domains = vec!["evil.com".to_string()];
        let engine = engine_with_policy("http", policy);

        let evil = verdict(&engine, "http", serde_json::json!({"url": "https://evil.com/api"})).await;
        assert!(evil.is_blocked());

        let good = verdict(&engine, "http", serde_json::json!({"url": "https://good.com/api"})).await;
        assert!(good.is_allowed());
    }

    #[tokio::test]
    async fn test_tool_allowed_domains() {
        let mut policy = ToolPolicyConfig::default();
        policy.allowed_domains = vec!["api.example.com".to_string()];
        let engine = engine_with_policy("http", policy);

        let listed = verdict(
            &engine,
            "http",
            serde_json::json!({"url": "https://api.example.com/v1"}),
        )
        .await;
        assert!(listed.is_allowed());

        let unlisted = verdict(&engine, "http", serde_json::json!({"url": "https://other.com/api"})).await;
        assert!(unlisted.is_blocked());
    }

    #[tokio::test]
    async fn test_tool_disabled() {
        let mut policy = ToolPolicyConfig::default();
        policy.enabled = false;
        let engine = engine_with_policy("dangerous", policy);

        assert!(verdict(&engine, "dangerous", serde_json::json!({}))
            .await
            .is_blocked());
    }

    #[tokio::test]
    async fn test_tool_confirmation_required() {
        let mut policy = ToolPolicyConfig::default();
        policy.require_confirmation = true;
        policy.confirmation_message = Some("Are you sure?".to_string());
        let engine = engine_with_policy("delete", policy);

        let outcome = verdict(&engine, "delete", serde_json::json!({})).await;
        if let SecurityCheckResult::RequireConfirmation { message } = outcome {
            assert_eq!(message, "Are you sure?");
        } else {
            panic!("Expected RequireConfirmation");
        }
    }

    #[test]
    fn test_get_tool_timeout() {
        let mut slow_policy = ToolPolicyConfig::default();
        slow_policy.timeout_ms = Some(10000);

        let mut config = ToolSecurityConfig::default();
        config.default_timeout_ms = 5000;
        config.tools.insert("slow".to_string(), slow_policy);
        let engine = ToolSecurityEngine::new(config);

        assert_eq!(engine.get_tool_timeout("slow"), 10000);
        assert_eq!(engine.get_tool_timeout("other"), 5000);
    }

    #[tokio::test]
    async fn test_path_restrictions() {
        let mut policy = ToolPolicyConfig::default();
        policy.allowed_paths = vec!["/tmp/".to_string(), "/home/user/".to_string()];
        let engine = engine_with_policy("file_write", policy);

        // (path, expected to be allowed?)
        let cases = [("/tmp/test.txt", true), ("/etc/passwd", false)];
        for (path, expect_allowed) in cases {
            let outcome = verdict(&engine, "file_write", serde_json::json!({"path": path})).await;
            if expect_allowed {
                assert!(outcome.is_allowed());
            } else {
                assert!(outcome.is_blocked());
            }
        }
    }
}