Skip to main content

ai_agents_tools/security/
engine.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Instant;
4
5use parking_lot::RwLock;
6use tracing::debug;
7
8use super::config::*;
9use ai_agents_core::Result;
10
/// Upper bound on how long a call timestamp is kept by [`ToolCallTracker`].
///
/// Without pruning, every recorded call would be retained for the lifetime of
/// the session, so memory use and the cost of each window query would grow
/// without bound. Windows longer than this horizon will under-count.
const CALL_HISTORY_RETENTION: std::time::Duration = std::time::Duration::from_secs(60 * 60);

/// Per-tool log of call timestamps used for sliding-window rate limiting.
#[derive(Debug, Default)]
struct ToolCallTracker {
    /// Call timestamps keyed by tool id, appended in recording order.
    calls: HashMap<String, Vec<Instant>>,
}

impl ToolCallTracker {
    /// Records a call to `tool_id` at the current instant.
    ///
    /// Timestamps for that tool older than [`CALL_HISTORY_RETENTION`] are
    /// discarded first, keeping the per-tool history bounded.
    fn record_call(&mut self, tool_id: &str) {
        let now = Instant::now();
        let history = self.calls.entry(tool_id.to_string()).or_default();
        history.retain(|&t| now.duration_since(t) < CALL_HISTORY_RETENTION);
        history.push(now);
    }

    /// Counts calls to `tool_id` recorded strictly less than `window_seconds`
    /// ago; unknown tools count as zero.
    fn get_calls_in_window(&self, tool_id: &str, window_seconds: u64) -> usize {
        let now = Instant::now();
        let window = std::time::Duration::from_secs(window_seconds);

        self.calls
            .get(tool_id)
            .map(|calls| {
                calls
                    .iter()
                    .filter(|t| now.duration_since(**t) < window)
                    .count()
            })
            .unwrap_or(0)
    }

    /// Forgets every recorded call for every tool.
    fn reset(&mut self) {
        self.calls.clear();
    }
}
43
44#[derive(Debug)]
45pub struct ToolSecurityEngine {
46    config: ToolSecurityConfig,
47    tool_call_tracker: Arc<RwLock<ToolCallTracker>>,
48}
49
50impl ToolSecurityEngine {
51    pub fn new(config: ToolSecurityConfig) -> Self {
52        Self {
53            config,
54            tool_call_tracker: Arc::new(RwLock::new(ToolCallTracker::default())),
55        }
56    }
57
58    pub fn config(&self) -> &ToolSecurityConfig {
59        &self.config
60    }
61
62    pub async fn check_tool_execution(
63        &self,
64        tool_id: &str,
65        args: &serde_json::Value,
66    ) -> Result<SecurityCheckResult> {
67        if !self.config.enabled {
68            return Ok(SecurityCheckResult::Allow);
69        }
70
71        let tool_config = self.config.tools.get(tool_id);
72
73        if let Some(config) = tool_config {
74            if !config.enabled {
75                return Ok(SecurityCheckResult::Block {
76                    reason: format!("Tool '{}' is disabled", tool_id),
77                });
78            }
79
80            if let Some(rate_limit) = config.rate_limit {
81                let calls = self
82                    .tool_call_tracker
83                    .read()
84                    .get_calls_in_window(tool_id, 60);
85                if calls >= rate_limit as usize {
86                    return Ok(SecurityCheckResult::Block {
87                        reason: format!(
88                            "Rate limit exceeded for tool '{}': {} calls per minute",
89                            tool_id, rate_limit
90                        ),
91                    });
92                }
93            }
94
95            if let Some(url) = args.get("url").and_then(|u| u.as_str()) {
96                for blocked in &config.blocked_domains {
97                    if url.contains(blocked) {
98                        return Ok(SecurityCheckResult::Block {
99                            reason: format!(
100                                "Domain '{}' is blocked for tool '{}'",
101                                blocked, tool_id
102                            ),
103                        });
104                    }
105                }
106
107                if !config.allowed_domains.is_empty() {
108                    let is_allowed = config.allowed_domains.iter().any(|d| url.contains(d));
109                    if !is_allowed {
110                        return Ok(SecurityCheckResult::Block {
111                            reason: format!(
112                                "URL domain not in allowed list for tool '{}'",
113                                tool_id
114                            ),
115                        });
116                    }
117                }
118            }
119
120            if let Some(path) = args.get("path").and_then(|p| p.as_str()) {
121                if !config.allowed_paths.is_empty() {
122                    let is_allowed = config.allowed_paths.iter().any(|p| path.starts_with(p));
123                    if !is_allowed {
124                        return Ok(SecurityCheckResult::Block {
125                            reason: format!("Path not in allowed list for tool '{}'", tool_id),
126                        });
127                    }
128                }
129            }
130
131            if config.require_confirmation {
132                let message = config
133                    .confirmation_message
134                    .clone()
135                    .unwrap_or_else(|| format!("Confirm execution of tool '{}'?", tool_id));
136                return Ok(SecurityCheckResult::RequireConfirmation { message });
137            }
138        }
139
140        self.tool_call_tracker.write().record_call(tool_id);
141        debug!(tool_id = %tool_id, "Tool execution allowed");
142
143        Ok(SecurityCheckResult::Allow)
144    }
145
146    pub fn get_tool_timeout(&self, tool_id: &str) -> u64 {
147        self.config
148            .tools
149            .get(tool_id)
150            .and_then(|c| c.timeout_ms)
151            .unwrap_or(self.config.default_timeout_ms)
152    }
153
154    pub fn reset_session(&self) {
155        self.tool_call_tracker.write().reset();
156    }
157}
158
159impl Default for ToolSecurityEngine {
160    fn default() -> Self {
161        Self::new(ToolSecurityConfig::default())
162    }
163}
164
// Unit tests for `ToolSecurityEngine` policy evaluation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_engine() {
        let engine = ToolSecurityEngine::default();
        assert!(!engine.config().enabled);
    }

    #[tokio::test]
    async fn test_engine_disabled_allows_everything() {
        // The default config has security disabled, so even a disabled tool passes.
        let mut config = ToolSecurityConfig::default();
        let mut tool_config = ToolPolicyConfig::default();
        tool_config.enabled = false;
        config.tools.insert("dangerous".to_string(), tool_config);

        let engine = ToolSecurityEngine::new(config);

        let result = engine
            .check_tool_execution("dangerous", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(result.is_allowed());
    }

    #[tokio::test]
    async fn test_unconfigured_tool_allowed() {
        let mut config = ToolSecurityConfig::default();
        config.enabled = true;

        let engine = ToolSecurityEngine::new(config);

        let args = serde_json::json!({"path": "/tmp/x.txt"});
        let result = engine
            .check_tool_execution("unconfigured_tool_for_test", &args)
            .await
            .unwrap();
        assert!(result.is_allowed());
    }

    #[tokio::test]
    async fn test_tool_domain_blocking() {
        let mut config = ToolSecurityConfig::default();
        config.enabled = true;

        let mut http_config = ToolPolicyConfig::default();
        http_config.blocked_domains = vec!["evil.com".to_string()];
        config.tools.insert("http".to_string(), http_config);

        let engine = ToolSecurityEngine::new(config);

        let args = serde_json::json!({"url": "https://evil.com/api"});
        let result = engine.check_tool_execution("http", &args).await.unwrap();
        assert!(result.is_blocked());

        let args = serde_json::json!({"url": "https://good.com/api"});
        let result = engine.check_tool_execution("http", &args).await.unwrap();
        assert!(result.is_allowed());
    }

    #[tokio::test]
    async fn test_tool_allowed_domains() {
        let mut config = ToolSecurityConfig::default();
        config.enabled = true;

        let mut http_config = ToolPolicyConfig::default();
        http_config.allowed_domains = vec!["api.example.com".to_string()];
        config.tools.insert("http".to_string(), http_config);

        let engine = ToolSecurityEngine::new(config);

        let args = serde_json::json!({"url": "https://api.example.com/v1"});
        let result = engine.check_tool_execution("http", &args).await.unwrap();
        assert!(result.is_allowed());

        let args = serde_json::json!({"url": "https://other.com/api"});
        let result = engine.check_tool_execution("http", &args).await.unwrap();
        assert!(result.is_blocked());
    }

    #[tokio::test]
    async fn test_tool_disabled() {
        let mut config = ToolSecurityConfig::default();
        config.enabled = true;

        let mut tool_config = ToolPolicyConfig::default();
        tool_config.enabled = false;
        config.tools.insert("dangerous".to_string(), tool_config);

        let engine = ToolSecurityEngine::new(config);

        let result = engine
            .check_tool_execution("dangerous", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(result.is_blocked());
    }

    #[tokio::test]
    async fn test_tool_confirmation_required() {
        let mut config = ToolSecurityConfig::default();
        config.enabled = true;

        let mut tool_config = ToolPolicyConfig::default();
        tool_config.require_confirmation = true;
        tool_config.confirmation_message = Some("Are you sure?".to_string());
        config.tools.insert("delete".to_string(), tool_config);

        let engine = ToolSecurityEngine::new(config);

        let result = engine
            .check_tool_execution("delete", &serde_json::json!({}))
            .await
            .unwrap();

        match result {
            SecurityCheckResult::RequireConfirmation { message } => {
                assert_eq!(message, "Are you sure?");
            }
            _ => panic!("Expected RequireConfirmation"),
        }
    }

    #[tokio::test]
    async fn test_rate_limit_and_session_reset() {
        let mut config = ToolSecurityConfig::default();
        config.enabled = true;

        let mut tool_config = ToolPolicyConfig::default();
        tool_config.rate_limit = Some(2);
        config.tools.insert("search".to_string(), tool_config);

        let engine = ToolSecurityEngine::new(config);
        let args = serde_json::json!({});

        for _ in 0..2 {
            let result = engine.check_tool_execution("search", &args).await.unwrap();
            assert!(result.is_allowed());
        }

        // The third call inside the window exceeds the limit of two.
        let result = engine.check_tool_execution("search", &args).await.unwrap();
        assert!(result.is_blocked());

        // Resetting the session clears the call history.
        engine.reset_session();
        let result = engine.check_tool_execution("search", &args).await.unwrap();
        assert!(result.is_allowed());
    }

    #[test]
    fn test_get_tool_timeout() {
        let mut config = ToolSecurityConfig::default();
        config.default_timeout_ms = 5000;

        let mut tool_config = ToolPolicyConfig::default();
        tool_config.timeout_ms = Some(10000);
        config.tools.insert("slow".to_string(), tool_config);

        let engine = ToolSecurityEngine::new(config);

        assert_eq!(engine.get_tool_timeout("slow"), 10000);
        assert_eq!(engine.get_tool_timeout("other"), 5000);
    }

    #[tokio::test]
    async fn test_path_restrictions() {
        let mut config = ToolSecurityConfig::default();
        config.enabled = true;

        let mut tool_config = ToolPolicyConfig::default();
        tool_config.allowed_paths = vec!["/tmp/".to_string(), "/home/user/".to_string()];
        config.tools.insert("file_write".to_string(), tool_config);

        let engine = ToolSecurityEngine::new(config);

        let args = serde_json::json!({"path": "/tmp/test.txt"});
        let result = engine
            .check_tool_execution("file_write", &args)
            .await
            .unwrap();
        assert!(result.is_allowed());

        let args = serde_json::json!({"path": "/home/user/notes.txt"});
        let result = engine
            .check_tool_execution("file_write", &args)
            .await
            .unwrap();
        assert!(result.is_allowed());

        let args = serde_json::json!({"path": "/etc/passwd"});
        let result = engine
            .check_tool_execution("file_write", &args)
            .await
            .unwrap();
        assert!(result.is_blocked());
    }
}