Skip to main content

enact_core/policy/
tool_policy.rs

//! Tool policy: per-tool permissions, trust levels, and a block list for tool invocations.
2
3use super::{PolicyAction, PolicyContext, PolicyDecision, PolicyEvaluator};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7/// Tool trust level
8#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
9pub enum ToolTrustLevel {
10    /// Untrusted - requires explicit approval for each invocation
11    Untrusted = 0,
12    /// Low - limited permissions, sandboxed
13    Low = 1,
14    /// Medium - standard permissions
15    #[default]
16    Medium = 2,
17    /// High - elevated permissions
18    High = 3,
19    /// System - full system access (dangerous)
20    System = 4,
21}
22
23/// Tool permissions
24#[derive(Debug, Clone, Serialize, Deserialize, Default)]
25pub struct ToolPermissions {
26    /// Can access network
27    pub network_access: bool,
28    /// Can access filesystem
29    pub filesystem_access: bool,
30    /// Can access filesystem (write)
31    pub filesystem_write: bool,
32    /// Can access environment variables
33    pub env_access: bool,
34    /// Can execute subprocesses
35    pub subprocess_access: bool,
36    /// Can access PII data
37    pub pii_access: bool,
38    /// Allowed domains (if network_access is true)
39    pub allowed_domains: HashSet<String>,
40    /// Allowed paths (if filesystem_access is true)
41    pub allowed_paths: HashSet<String>,
42}
43
44impl ToolPermissions {
45    /// Create permissions for a sandboxed tool
46    pub fn sandboxed() -> Self {
47        Self {
48            network_access: false,
49            filesystem_access: false,
50            filesystem_write: false,
51            env_access: false,
52            subprocess_access: false,
53            pii_access: false,
54            allowed_domains: HashSet::new(),
55            allowed_paths: HashSet::new(),
56        }
57    }
58
59    /// Create permissions for a network tool
60    pub fn network_only() -> Self {
61        Self {
62            network_access: true,
63            filesystem_access: false,
64            filesystem_write: false,
65            env_access: false,
66            subprocess_access: false,
67            pii_access: false,
68            allowed_domains: HashSet::new(),
69            allowed_paths: HashSet::new(),
70        }
71    }
72
73    /// Create full permissions (dangerous)
74    pub fn full() -> Self {
75        Self {
76            network_access: true,
77            filesystem_access: true,
78            filesystem_write: true,
79            env_access: true,
80            subprocess_access: true,
81            pii_access: true,
82            allowed_domains: HashSet::new(),
83            allowed_paths: HashSet::new(),
84        }
85    }
86
87    /// Add an allowed domain
88    pub fn allow_domain(mut self, domain: impl Into<String>) -> Self {
89        self.allowed_domains.insert(domain.into());
90        self
91    }
92
93    /// Add an allowed path
94    pub fn allow_path(mut self, path: impl Into<String>) -> Self {
95        self.allowed_paths.insert(path.into());
96        self
97    }
98}
99
100/// Tool policy
101#[derive(Debug, Clone)]
102pub struct ToolPolicy {
103    /// Default permissions for all tools
104    pub default_permissions: ToolPermissions,
105    /// Tool-specific permissions (overrides default)
106    pub tool_permissions: std::collections::HashMap<String, ToolPermissions>,
107    /// Tool trust levels
108    pub tool_trust: std::collections::HashMap<String, ToolTrustLevel>,
109    /// Minimum trust level required
110    pub min_trust_level: ToolTrustLevel,
111    /// Blocked tools
112    pub blocked_tools: HashSet<String>,
113}
114
115impl Default for ToolPolicy {
116    fn default() -> Self {
117        Self {
118            default_permissions: ToolPermissions::sandboxed(),
119            tool_permissions: std::collections::HashMap::new(),
120            tool_trust: std::collections::HashMap::new(),
121            min_trust_level: ToolTrustLevel::Low,
122            blocked_tools: HashSet::new(),
123        }
124    }
125}
126
127impl ToolPolicy {
128    /// Create a new tool policy
129    pub fn new() -> Self {
130        Self::default()
131    }
132
133    /// Set default permissions
134    pub fn with_default_permissions(mut self, perms: ToolPermissions) -> Self {
135        self.default_permissions = perms;
136        self
137    }
138
139    /// Set permissions for a specific tool
140    pub fn set_tool_permissions(mut self, tool: impl Into<String>, perms: ToolPermissions) -> Self {
141        self.tool_permissions.insert(tool.into(), perms);
142        self
143    }
144
145    /// Set trust level for a specific tool
146    pub fn set_tool_trust(mut self, tool: impl Into<String>, level: ToolTrustLevel) -> Self {
147        self.tool_trust.insert(tool.into(), level);
148        self
149    }
150
151    /// Block a tool
152    pub fn block_tool(mut self, tool: impl Into<String>) -> Self {
153        self.blocked_tools.insert(tool.into());
154        self
155    }
156
157    /// Get permissions for a tool
158    pub fn get_permissions(&self, tool_name: &str) -> &ToolPermissions {
159        self.tool_permissions
160            .get(tool_name)
161            .unwrap_or(&self.default_permissions)
162    }
163
164    /// Get trust level for a tool
165    pub fn get_trust_level(&self, tool_name: &str) -> ToolTrustLevel {
166        self.tool_trust
167            .get(tool_name)
168            .copied()
169            .unwrap_or(ToolTrustLevel::Medium)
170    }
171}
172
173impl PolicyEvaluator for ToolPolicy {
174    fn evaluate(&self, context: &PolicyContext) -> PolicyDecision {
175        match &context.action {
176            PolicyAction::InvokeTool { tool_name } => {
177                // Check if tool is blocked
178                if self.blocked_tools.contains(tool_name) {
179                    return PolicyDecision::Deny {
180                        reason: format!("Tool '{}' is blocked by policy", tool_name),
181                    };
182                }
183
184                // Check trust level
185                let trust = self.get_trust_level(tool_name);
186                if trust < self.min_trust_level {
187                    return PolicyDecision::Deny {
188                        reason: format!(
189                            "Tool '{}' trust level {:?} is below minimum {:?}",
190                            tool_name, trust, self.min_trust_level
191                        ),
192                    };
193                }
194
195                PolicyDecision::Allow
196            }
197            _ => PolicyDecision::Allow,
198        }
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a context that invokes `tool_name`, with no tenant, user, or metadata.
    fn invoke_tool_context(tool_name: &str) -> PolicyContext {
        PolicyContext {
            tenant_id: None,
            user_id: None,
            action: PolicyAction::InvokeTool {
                tool_name: tool_name.to_string(),
            },
            metadata: HashMap::new(),
        }
    }

    /// Extract the reason from a `Deny` decision; panics on any other decision.
    fn deny_reason(decision: PolicyDecision) -> String {
        match decision {
            PolicyDecision::Deny { reason } => reason,
            _ => panic!("expected a Deny decision"),
        }
    }

    // ============ ToolTrustLevel Tests ============

    #[test]
    fn test_tool_trust_level_ordering() {
        assert!(ToolTrustLevel::Untrusted < ToolTrustLevel::Low);
        assert!(ToolTrustLevel::Low < ToolTrustLevel::Medium);
        assert!(ToolTrustLevel::Medium < ToolTrustLevel::High);
        assert!(ToolTrustLevel::High < ToolTrustLevel::System);
    }

    #[test]
    fn test_tool_trust_level_default() {
        assert_eq!(ToolTrustLevel::default(), ToolTrustLevel::Medium);
    }

    // ============ ToolPermissions Tests ============

    #[test]
    fn test_tool_permissions_default() {
        let perms = ToolPermissions::default();
        assert!(!perms.network_access);
        assert!(!perms.filesystem_access);
        assert!(!perms.filesystem_write);
        assert!(!perms.env_access);
        assert!(!perms.subprocess_access);
        assert!(!perms.pii_access);
    }

    #[test]
    fn test_tool_permissions_sandboxed() {
        let perms = ToolPermissions::sandboxed();
        assert!(!perms.network_access);
        assert!(!perms.filesystem_access);
        assert!(!perms.filesystem_write);
        assert!(!perms.env_access);
        assert!(!perms.subprocess_access);
        assert!(!perms.pii_access);
        assert!(perms.allowed_domains.is_empty());
        assert!(perms.allowed_paths.is_empty());
    }

    #[test]
    fn test_tool_permissions_network_only() {
        let perms = ToolPermissions::network_only();
        assert!(perms.network_access);
        assert!(!perms.filesystem_access);
        assert!(!perms.subprocess_access);
    }

    #[test]
    fn test_tool_permissions_full() {
        let perms = ToolPermissions::full();
        assert!(perms.network_access);
        assert!(perms.filesystem_access);
        assert!(perms.filesystem_write);
        assert!(perms.env_access);
        assert!(perms.subprocess_access);
        assert!(perms.pii_access);
    }

    #[test]
    fn test_tool_permissions_allow_domain() {
        let perms = ToolPermissions::network_only()
            .allow_domain("api.example.com")
            .allow_domain("cdn.example.com");

        assert!(perms.allowed_domains.contains("api.example.com"));
        assert!(perms.allowed_domains.contains("cdn.example.com"));
        assert_eq!(perms.allowed_domains.len(), 2);
    }

    #[test]
    fn test_tool_permissions_allow_path() {
        let perms = ToolPermissions::sandboxed()
            .allow_path("/tmp")
            .allow_path("/var/data");

        assert!(perms.allowed_paths.contains("/tmp"));
        assert!(perms.allowed_paths.contains("/var/data"));
    }

    // ============ ToolPolicy Tests ============

    #[test]
    fn test_tool_policy_default() {
        let policy = ToolPolicy::default();
        assert_eq!(policy.min_trust_level, ToolTrustLevel::Low);
        assert!(policy.blocked_tools.is_empty());
    }

    #[test]
    fn test_tool_policy_with_default_permissions() {
        let policy = ToolPolicy::new().with_default_permissions(ToolPermissions::network_only());
        assert!(policy.default_permissions.network_access);
    }

    #[test]
    fn test_tool_policy_set_tool_permissions() {
        let policy =
            ToolPolicy::new().set_tool_permissions("web_search", ToolPermissions::network_only());

        let perms = policy.get_permissions("web_search");
        assert!(perms.network_access);

        // Other tools get default
        let default_perms = policy.get_permissions("other_tool");
        assert!(!default_perms.network_access);
    }

    #[test]
    fn test_tool_policy_set_tool_trust() {
        let policy = ToolPolicy::new()
            .set_tool_trust("trusted_tool", ToolTrustLevel::High)
            .set_tool_trust("untrusted_tool", ToolTrustLevel::Untrusted);

        assert_eq!(policy.get_trust_level("trusted_tool"), ToolTrustLevel::High);
        assert_eq!(
            policy.get_trust_level("untrusted_tool"),
            ToolTrustLevel::Untrusted
        );
        assert_eq!(
            policy.get_trust_level("unknown_tool"),
            ToolTrustLevel::Medium
        ); // default
    }

    #[test]
    fn test_tool_policy_block_tool() {
        let policy = ToolPolicy::new()
            .block_tool("dangerous_tool")
            .block_tool("another_dangerous");

        assert!(policy.blocked_tools.contains("dangerous_tool"));
        assert!(policy.blocked_tools.contains("another_dangerous"));
    }

    // ============ PolicyEvaluator Tests ============

    #[test]
    fn test_tool_policy_evaluate_allowed() {
        let policy = ToolPolicy::new();
        let decision = policy.evaluate(&invoke_tool_context("safe_tool"));
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_tool_policy_evaluate_blocked() {
        let policy = ToolPolicy::new().block_tool("blocked_tool");
        let decision = policy.evaluate(&invoke_tool_context("blocked_tool"));
        assert!(decision.is_denied());
    }

    #[test]
    fn test_tool_policy_evaluate_blocked_reason() {
        let policy = ToolPolicy::new().block_tool("blocked_tool");
        let reason = deny_reason(policy.evaluate(&invoke_tool_context("blocked_tool")));
        assert_eq!(reason, "Tool 'blocked_tool' is blocked by policy");
    }

    #[test]
    fn test_tool_policy_evaluate_block_overrides_trust() {
        // The block list is checked before trust, so even System trust is denied.
        let policy = ToolPolicy::new()
            .set_tool_trust("root_shell", ToolTrustLevel::System)
            .block_tool("root_shell");
        let reason = deny_reason(policy.evaluate(&invoke_tool_context("root_shell")));
        assert_eq!(reason, "Tool 'root_shell' is blocked by policy");
    }

    #[test]
    fn test_tool_policy_evaluate_trust_level_denied() {
        let mut policy = ToolPolicy::new();
        policy.min_trust_level = ToolTrustLevel::High;
        let policy = policy.set_tool_trust("low_trust", ToolTrustLevel::Low);

        let decision = policy.evaluate(&invoke_tool_context("low_trust"));
        assert!(decision.is_denied());
    }

    #[test]
    fn test_tool_policy_evaluate_trust_denial_reason() {
        let mut policy = ToolPolicy::new();
        policy.min_trust_level = ToolTrustLevel::High;
        let policy = policy.set_tool_trust("low_trust", ToolTrustLevel::Low);

        let reason = deny_reason(policy.evaluate(&invoke_tool_context("low_trust")));
        assert_eq!(reason, "Tool 'low_trust' trust level Low is below minimum High");
    }

    #[test]
    fn test_tool_policy_evaluate_trust_equal_to_minimum_allowed() {
        // The comparison is "below minimum", so an exact match passes.
        let mut policy = ToolPolicy::new();
        policy.min_trust_level = ToolTrustLevel::High;
        let policy = policy.set_tool_trust("exact", ToolTrustLevel::High);

        assert!(policy.evaluate(&invoke_tool_context("exact")).is_allowed());
    }

    #[test]
    fn test_tool_policy_default_denies_untrusted_tool() {
        let policy = ToolPolicy::new().set_tool_trust("sketchy", ToolTrustLevel::Untrusted);
        assert!(policy.evaluate(&invoke_tool_context("sketchy")).is_denied());
    }

    #[test]
    fn test_tool_policy_evaluate_unregistered_tool_uses_medium_trust() {
        let mut policy = ToolPolicy::new();

        policy.min_trust_level = ToolTrustLevel::Medium;
        assert!(policy
            .evaluate(&invoke_tool_context("unregistered"))
            .is_allowed());

        policy.min_trust_level = ToolTrustLevel::High;
        assert!(policy
            .evaluate(&invoke_tool_context("unregistered"))
            .is_denied());
    }

    #[test]
    fn test_tool_policy_evaluate_non_tool_action_allowed() {
        let policy = ToolPolicy::new().block_tool("some_tool");
        let context = PolicyContext {
            tenant_id: None,
            user_id: None,
            action: PolicyAction::LlmCall {
                model: "gpt-4".to_string(),
            },
            metadata: HashMap::new(),
        };

        // Non-tool actions are allowed by ToolPolicy
        let decision = policy.evaluate(&context);
        assert!(decision.is_allowed());
    }
}