1use super::{PolicyAction, PolicyContext, PolicyDecision, PolicyEvaluator};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
9pub enum ToolTrustLevel {
10 Untrusted = 0,
12 Low = 1,
14 #[default]
16 Medium = 2,
17 High = 3,
19 System = 4,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize, Default)]
25pub struct ToolPermissions {
26 pub network_access: bool,
28 pub filesystem_access: bool,
30 pub filesystem_write: bool,
32 pub env_access: bool,
34 pub subprocess_access: bool,
36 pub pii_access: bool,
38 pub allowed_domains: HashSet<String>,
40 pub allowed_paths: HashSet<String>,
42}
43
44impl ToolPermissions {
45 pub fn sandboxed() -> Self {
47 Self {
48 network_access: false,
49 filesystem_access: false,
50 filesystem_write: false,
51 env_access: false,
52 subprocess_access: false,
53 pii_access: false,
54 allowed_domains: HashSet::new(),
55 allowed_paths: HashSet::new(),
56 }
57 }
58
59 pub fn network_only() -> Self {
61 Self {
62 network_access: true,
63 filesystem_access: false,
64 filesystem_write: false,
65 env_access: false,
66 subprocess_access: false,
67 pii_access: false,
68 allowed_domains: HashSet::new(),
69 allowed_paths: HashSet::new(),
70 }
71 }
72
73 pub fn full() -> Self {
75 Self {
76 network_access: true,
77 filesystem_access: true,
78 filesystem_write: true,
79 env_access: true,
80 subprocess_access: true,
81 pii_access: true,
82 allowed_domains: HashSet::new(),
83 allowed_paths: HashSet::new(),
84 }
85 }
86
87 pub fn allow_domain(mut self, domain: impl Into<String>) -> Self {
89 self.allowed_domains.insert(domain.into());
90 self
91 }
92
93 pub fn allow_path(mut self, path: impl Into<String>) -> Self {
95 self.allowed_paths.insert(path.into());
96 self
97 }
98}
99
/// Policy controlling which tools may be invoked.
///
/// The `PolicyEvaluator` implementation for this type only consults `blocked_tools`
/// and `min_trust_level`. The permission fields are stored for lookup via
/// [`ToolPolicy::get_permissions`] and are not checked during evaluation.
#[derive(Debug, Clone)]
pub struct ToolPolicy {
    /// Permissions returned by `get_permissions` for tools with no entry in `tool_permissions`.
    pub default_permissions: ToolPermissions,
    /// Per-tool permission overrides, keyed by tool name.
    pub tool_permissions: std::collections::HashMap<String, ToolPermissions>,
    /// Per-tool trust levels, keyed by tool name. `get_trust_level` treats unlisted tools as `Medium`.
    pub tool_trust: std::collections::HashMap<String, ToolTrustLevel>,
    /// Tool invocations whose trust level is strictly below this value are denied.
    pub min_trust_level: ToolTrustLevel,
    /// Tool names that are always denied, checked before the trust level.
    pub blocked_tools: HashSet<String>,
}
114
115impl Default for ToolPolicy {
116 fn default() -> Self {
117 Self {
118 default_permissions: ToolPermissions::sandboxed(),
119 tool_permissions: std::collections::HashMap::new(),
120 tool_trust: std::collections::HashMap::new(),
121 min_trust_level: ToolTrustLevel::Low,
122 blocked_tools: HashSet::new(),
123 }
124 }
125}
126
127impl ToolPolicy {
128 pub fn new() -> Self {
130 Self::default()
131 }
132
133 pub fn with_default_permissions(mut self, perms: ToolPermissions) -> Self {
135 self.default_permissions = perms;
136 self
137 }
138
139 pub fn set_tool_permissions(mut self, tool: impl Into<String>, perms: ToolPermissions) -> Self {
141 self.tool_permissions.insert(tool.into(), perms);
142 self
143 }
144
145 pub fn set_tool_trust(mut self, tool: impl Into<String>, level: ToolTrustLevel) -> Self {
147 self.tool_trust.insert(tool.into(), level);
148 self
149 }
150
151 pub fn block_tool(mut self, tool: impl Into<String>) -> Self {
153 self.blocked_tools.insert(tool.into());
154 self
155 }
156
157 pub fn get_permissions(&self, tool_name: &str) -> &ToolPermissions {
159 self.tool_permissions
160 .get(tool_name)
161 .unwrap_or(&self.default_permissions)
162 }
163
164 pub fn get_trust_level(&self, tool_name: &str) -> ToolTrustLevel {
166 self.tool_trust
167 .get(tool_name)
168 .copied()
169 .unwrap_or(ToolTrustLevel::Medium)
170 }
171}
172
173impl PolicyEvaluator for ToolPolicy {
174 fn evaluate(&self, context: &PolicyContext) -> PolicyDecision {
175 match &context.action {
176 PolicyAction::InvokeTool { tool_name } => {
177 if self.blocked_tools.contains(tool_name) {
179 return PolicyDecision::Deny {
180 reason: format!("Tool '{}' is blocked by policy", tool_name),
181 };
182 }
183
184 let trust = self.get_trust_level(tool_name);
186 if trust < self.min_trust_level {
187 return PolicyDecision::Deny {
188 reason: format!(
189 "Tool '{}' trust level {:?} is below minimum {:?}",
190 tool_name, trust, self.min_trust_level
191 ),
192 };
193 }
194
195 PolicyDecision::Allow
196 }
197 _ => PolicyDecision::Allow,
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a context for invoking `tool_name` with no tenant, user or metadata.
    fn tool_context(tool_name: &str) -> PolicyContext {
        PolicyContext {
            tenant_id: None,
            user_id: None,
            action: PolicyAction::InvokeTool {
                tool_name: tool_name.to_string(),
            },
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_tool_trust_level_ordering() {
        assert!(ToolTrustLevel::Untrusted < ToolTrustLevel::Low);
        assert!(ToolTrustLevel::Low < ToolTrustLevel::Medium);
        assert!(ToolTrustLevel::Medium < ToolTrustLevel::High);
        assert!(ToolTrustLevel::High < ToolTrustLevel::System);
    }

    #[test]
    fn test_tool_trust_level_discriminants() {
        assert_eq!(ToolTrustLevel::Untrusted as u8, 0);
        assert_eq!(ToolTrustLevel::Medium as u8, 2);
        assert_eq!(ToolTrustLevel::System as u8, 4);
    }

    #[test]
    fn test_tool_trust_level_default() {
        assert_eq!(ToolTrustLevel::default(), ToolTrustLevel::Medium);
    }

    #[test]
    fn test_tool_permissions_default() {
        let perms = ToolPermissions::default();
        assert!(!perms.network_access);
        assert!(!perms.filesystem_access);
        assert!(!perms.filesystem_write);
        assert!(!perms.env_access);
        assert!(!perms.subprocess_access);
        assert!(!perms.pii_access);
        assert!(perms.allowed_domains.is_empty());
        assert!(perms.allowed_paths.is_empty());
    }

    #[test]
    fn test_tool_permissions_sandboxed() {
        let perms = ToolPermissions::sandboxed();
        assert!(!perms.network_access);
        assert!(!perms.filesystem_access);
        assert!(!perms.filesystem_write);
        assert!(!perms.env_access);
        assert!(!perms.subprocess_access);
        assert!(!perms.pii_access);
        assert!(perms.allowed_domains.is_empty());
        assert!(perms.allowed_paths.is_empty());
    }

    #[test]
    fn test_tool_permissions_network_only() {
        let perms = ToolPermissions::network_only();
        assert!(perms.network_access);
        assert!(!perms.filesystem_access);
        assert!(!perms.filesystem_write);
        assert!(!perms.env_access);
        assert!(!perms.subprocess_access);
        assert!(!perms.pii_access);
        assert!(perms.allowed_domains.is_empty());
        assert!(perms.allowed_paths.is_empty());
    }

    #[test]
    fn test_tool_permissions_full() {
        let perms = ToolPermissions::full();
        assert!(perms.network_access);
        assert!(perms.filesystem_access);
        assert!(perms.filesystem_write);
        assert!(perms.env_access);
        assert!(perms.subprocess_access);
        assert!(perms.pii_access);
        assert!(perms.allowed_domains.is_empty());
        assert!(perms.allowed_paths.is_empty());
    }

    #[test]
    fn test_tool_permissions_allow_domain() {
        let perms = ToolPermissions::network_only()
            .allow_domain("api.example.com")
            .allow_domain("cdn.example.com");

        assert!(perms.allowed_domains.contains("api.example.com"));
        assert!(perms.allowed_domains.contains("cdn.example.com"));
        assert_eq!(perms.allowed_domains.len(), 2);
    }

    #[test]
    fn test_tool_permissions_allow_domain_deduplicates() {
        let perms = ToolPermissions::network_only()
            .allow_domain("api.example.com")
            .allow_domain("api.example.com");

        assert_eq!(perms.allowed_domains.len(), 1);
    }

    #[test]
    fn test_tool_permissions_allow_path() {
        let perms = ToolPermissions::sandboxed()
            .allow_path("/tmp")
            .allow_path("/var/data");

        assert!(perms.allowed_paths.contains("/tmp"));
        assert!(perms.allowed_paths.contains("/var/data"));
        assert_eq!(perms.allowed_paths.len(), 2);
    }

    #[test]
    fn test_tool_policy_default() {
        let policy = ToolPolicy::default();
        assert_eq!(policy.min_trust_level, ToolTrustLevel::Low);
        assert!(policy.blocked_tools.is_empty());
        assert!(policy.tool_permissions.is_empty());
        assert!(policy.tool_trust.is_empty());
        assert!(!policy.default_permissions.network_access);
    }

    #[test]
    fn test_tool_policy_with_default_permissions() {
        let policy = ToolPolicy::new().with_default_permissions(ToolPermissions::network_only());
        assert!(policy.default_permissions.network_access);
    }

    #[test]
    fn test_tool_policy_set_tool_permissions() {
        let policy =
            ToolPolicy::new().set_tool_permissions("web_search", ToolPermissions::network_only());

        let perms = policy.get_permissions("web_search");
        assert!(perms.network_access);

        // Tools without an explicit entry fall back to the (sandboxed) defaults.
        let default_perms = policy.get_permissions("other_tool");
        assert!(!default_perms.network_access);
    }

    #[test]
    fn test_tool_policy_set_tool_trust() {
        let policy = ToolPolicy::new()
            .set_tool_trust("trusted_tool", ToolTrustLevel::High)
            .set_tool_trust("untrusted_tool", ToolTrustLevel::Untrusted);

        assert_eq!(policy.get_trust_level("trusted_tool"), ToolTrustLevel::High);
        assert_eq!(
            policy.get_trust_level("untrusted_tool"),
            ToolTrustLevel::Untrusted
        );
        assert_eq!(
            policy.get_trust_level("unknown_tool"),
            ToolTrustLevel::Medium
        );
    }

    #[test]
    fn test_tool_policy_block_tool() {
        let policy = ToolPolicy::new()
            .block_tool("dangerous_tool")
            .block_tool("another_dangerous");

        assert!(policy.blocked_tools.contains("dangerous_tool"));
        assert!(policy.blocked_tools.contains("another_dangerous"));
    }

    #[test]
    fn test_tool_policy_evaluate_allowed() {
        let policy = ToolPolicy::new();
        let decision = policy.evaluate(&tool_context("safe_tool"));
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_tool_policy_evaluate_blocked() {
        let policy = ToolPolicy::new().block_tool("blocked_tool");
        let decision = policy.evaluate(&tool_context("blocked_tool"));
        assert!(decision.is_denied());

        match decision {
            PolicyDecision::Deny { reason } => {
                assert_eq!(reason, "Tool 'blocked_tool' is blocked by policy");
            }
            _ => panic!("expected a Deny decision"),
        }
    }

    #[test]
    fn test_tool_policy_evaluate_block_overrides_high_trust() {
        let policy = ToolPolicy::new()
            .set_tool_trust("sys_tool", ToolTrustLevel::System)
            .block_tool("sys_tool");

        let decision = policy.evaluate(&tool_context("sys_tool"));
        assert!(decision.is_denied());
    }

    #[test]
    fn test_tool_policy_evaluate_trust_level_denied() {
        let mut policy = ToolPolicy::new();
        policy.min_trust_level = ToolTrustLevel::High;
        let policy = policy.set_tool_trust("low_trust", ToolTrustLevel::Low);

        let decision = policy.evaluate(&tool_context("low_trust"));
        assert!(decision.is_denied());
    }

    #[test]
    fn test_tool_policy_evaluate_trust_at_minimum_allowed() {
        // The comparison is strict: a tool exactly at the minimum is allowed.
        let policy = ToolPolicy::new().set_tool_trust("edge_tool", ToolTrustLevel::Low);

        let decision = policy.evaluate(&tool_context("edge_tool"));
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_tool_policy_evaluate_untrusted_denied_by_default() {
        let policy = ToolPolicy::new().set_tool_trust("sketchy", ToolTrustLevel::Untrusted);

        let decision = policy.evaluate(&tool_context("sketchy"));
        assert!(decision.is_denied());
    }

    #[test]
    fn test_tool_policy_evaluate_unknown_tool_below_high_minimum_denied() {
        // Unknown tools are treated as Medium, which is below a High minimum.
        let mut policy = ToolPolicy::new();
        policy.min_trust_level = ToolTrustLevel::High;

        let decision = policy.evaluate(&tool_context("unknown_tool"));
        assert!(decision.is_denied());
    }

    #[test]
    fn test_tool_policy_evaluate_non_tool_action_allowed() {
        let policy = ToolPolicy::new().block_tool("some_tool");
        let context = PolicyContext {
            tenant_id: None,
            user_id: None,
            action: PolicyAction::LlmCall {
                model: "gpt-4".to_string(),
            },
            metadata: HashMap::new(),
        };

        // Only tool invocations are governed by this policy.
        let decision = policy.evaluate(&context);
        assert!(decision.is_allowed());
    }
}