1use serde::{Deserialize, Serialize};
24use std::path::Path;
25use thiserror::Error;
26#[allow(dead_code)]
29#[derive(Error, Debug)]
30pub enum PolicyError {
31 #[error("Policy denied: {0}")]
32 Denied(String),
33
34 #[error("Invalid policy configuration: {0}")]
35 InvalidConfig(String),
36}
37
/// Outcome of evaluating a tool call against a security policy.
// NOTE(review): `dead_code` is allowed, presumably because not every use site exists yet.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Decision {
    /// The call may proceed.
    Allow,
    /// The call is rejected; the payload is a human-readable reason.
    Deny(String),
}
49
50#[allow(dead_code)]
51impl Decision {
52 pub fn is_allowed(&self) -> bool {
53 matches!(self, Decision::Allow)
54 }
55}
56
57#[allow(dead_code)]
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ShellPolicy {
61 #[serde(default)]
63 pub deny_all: bool,
64 #[serde(default)]
66 pub allowed_commands: Vec<String>,
67 #[serde(default)]
69 pub denied_commands: Vec<String>,
70 #[serde(default = "default_shell_timeout")]
72 pub max_timeout_secs: u64,
73 #[serde(default)]
75 pub allow_write_commands: bool,
76}
77
78impl Default for ShellPolicy {
79 fn default() -> Self {
80 Self {
81 deny_all: false,
82 allowed_commands: vec![
83 "echo".to_string(),
84 "cat".to_string(),
85 "ls".to_string(),
86 "head".to_string(),
87 "tail".to_string(),
88 "wc".to_string(),
89 "grep".to_string(),
90 "find".to_string(),
91 "sort".to_string(),
92 "uniq".to_string(),
93 "cut".to_string(),
94 "which".to_string(),
95 "pwd".to_string(),
96 "date".to_string(),
97 "whoami".to_string(),
98 "uname".to_string(),
99 "env".to_string(),
100 "printenv".to_string(),
101 "git".to_string(),
102 "cargo".to_string(),
103 "rustc".to_string(),
104 "python3".to_string(),
105 "node".to_string(),
106 ],
107 denied_commands: vec![
108 "rm -rf /".to_string(),
109 "mkfs".to_string(),
110 "dd".to_string(),
111 "shutdown".to_string(),
112 "reboot".to_string(),
113 "halt".to_string(),
114 "poweroff".to_string(),
115 ],
116 max_timeout_secs: 60,
117 allow_write_commands: false,
118 }
119 }
120}
121
122#[allow(dead_code)]
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct PathPolicy {
126 #[serde(default)]
128 pub allowed_read_paths: Vec<String>,
129 #[serde(default)]
131 pub allowed_write_paths: Vec<String>,
132 #[serde(default)]
134 pub denied_paths: Vec<String>,
135 #[serde(default = "default_max_read_bytes")]
137 pub max_read_bytes: usize,
138 #[serde(default = "default_max_write_bytes")]
140 pub max_write_bytes: usize,
141}
142
143impl Default for PathPolicy {
144 fn default() -> Self {
145 Self {
146 allowed_read_paths: vec![
147 "/tmp".to_string(),
148 "/var/tmp".to_string(),
149 "/home".to_string(),
150 "/workspace".to_string(),
151 ".".to_string(),
152 ],
153 allowed_write_paths: vec![
154 "/tmp".to_string(),
155 "/var/tmp".to_string(),
156 "/workspace".to_string(),
157 ".".to_string(),
158 ],
159 denied_paths: vec![
160 "/etc/shadow".to_string(),
161 "/etc/sudoers".to_string(),
162 "/etc/ssh".to_string(),
163 "/root".to_string(),
164 ],
165 max_read_bytes: 65536,
166 max_write_bytes: 1048576,
167 }
168 }
169}
170
171#[allow(dead_code)]
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct NetworkPolicy {
175 #[serde(default)]
177 pub deny_all: bool,
178 #[serde(default)]
180 pub allowed_hosts: Vec<String>,
181 #[serde(default)]
183 pub denied_hosts: Vec<String>,
184 #[serde(default = "default_true")]
186 pub allow_localhost: bool,
187 #[serde(default)]
189 pub allow_private_networks: bool,
190}
191
192impl Default for NetworkPolicy {
193 fn default() -> Self {
194 Self {
195 deny_all: false,
196 allowed_hosts: vec![
197 "github.com".to_string(),
198 "raw.githubusercontent.com".to_string(),
199 "docs.rs".to_string(),
200 "crates.io".to_string(),
201 "api.github.com".to_string(),
202 "google.com".to_string(),
203 "wikipedia.org".to_string(),
204 "stackoverflow.com".to_string(),
205 "rust-lang.org".to_string(),
206 ],
207 denied_hosts: vec![],
208 allow_localhost: true,
209 allow_private_networks: false,
210 }
211 }
212}
213
214#[allow(dead_code)]
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct SecurityPolicy {
218 #[serde(default)]
220 pub shell: ShellPolicy,
221 #[serde(default)]
223 pub path: PathPolicy,
224 #[serde(default)]
226 pub network: NetworkPolicy,
227 #[serde(default)]
229 pub require_approval_all: bool,
230 #[serde(default)]
232 pub require_approval_for: Vec<String>,
233}
234
235impl Default for SecurityPolicy {
236 fn default() -> Self {
237 Self {
238 shell: ShellPolicy::default(),
239 path: PathPolicy::default(),
240 network: NetworkPolicy::default(),
241 require_approval_all: false,
242 require_approval_for: vec!["shell_exec".to_string(), "write_file".to_string()],
243 }
244 }
245}
246
247#[allow(dead_code)]
251pub struct PolicyEngine {
252 policy: SecurityPolicy,
253}
254
255#[allow(dead_code)]
256impl PolicyEngine {
257 pub fn new(policy: SecurityPolicy) -> Self {
259 Self { policy }
260 }
261
262 pub fn default_secure() -> Self {
264 Self {
265 policy: SecurityPolicy::default(),
266 }
267 }
268
269 pub fn permissive() -> Self {
271 Self {
272 policy: SecurityPolicy {
273 require_approval_all: false,
274 require_approval_for: vec![],
275 shell: ShellPolicy {
276 deny_all: false,
277 allowed_commands: vec!["*".to_string()], denied_commands: vec![],
279 max_timeout_secs: 300,
280 allow_write_commands: true,
281 },
282 path: PathPolicy {
283 allowed_read_paths: vec!["/".to_string()],
284 allowed_write_paths: vec!["/tmp".to_string(), "/workspace".to_string()],
285 denied_paths: vec![],
286 max_read_bytes: 1048576,
287 max_write_bytes: 10485760,
288 },
289 network: NetworkPolicy {
290 deny_all: false,
291 allowed_hosts: vec!["*".to_string()],
292 denied_hosts: vec![],
293 allow_localhost: true,
294 allow_private_networks: true,
295 },
296 },
297 }
298 }
299
300 pub fn check_tool_call(&self, tool_name: &str, args: &serde_json::Value) -> Decision {
302 match tool_name {
303 "shell_exec" => self.check_shell_command(args),
304 "read_file" | "write_file" => self.check_file_operation(tool_name, args),
305 "web_fetch" => self.check_network_request(args),
306 _ => Decision::Allow, }
308 }
309
310 pub fn requires_approval(&self, tool_name: &str) -> bool {
312 if self.policy.require_approval_all {
313 return true;
314 }
315 self.policy
316 .require_approval_for
317 .contains(&tool_name.to_string())
318 }
319
320 pub fn policy(&self) -> &SecurityPolicy {
322 &self.policy
323 }
324
325 fn check_shell_command(&self, args: &serde_json::Value) -> Decision {
328 let policy = &self.policy.shell;
329
330 if policy.deny_all {
331 return Decision::Deny("All shell commands are denied by policy".to_string());
332 }
333
334 let command = args.get("command").and_then(|v| v.as_str()).unwrap_or("");
335
336 if command.is_empty() {
337 return Decision::Deny("Empty command".to_string());
338 }
339
340 for denied in &policy.denied_commands {
342 if command.contains(denied) {
343 return Decision::Deny(format!("Command contains denied pattern: '{}'", denied));
344 }
345 }
346
347 if let Some(timeout) = args.get("timeout_secs").and_then(|v| v.as_u64()) {
349 if timeout > policy.max_timeout_secs {
350 return Decision::Deny(format!(
351 "Timeout {}s exceeds maximum {}s",
352 timeout, policy.max_timeout_secs
353 ));
354 }
355 }
356
357 let first_word = command.split_whitespace().next().unwrap_or("");
359 let is_allowed = policy.allowed_commands.iter().any(|a| {
360 if a == "*" {
361 return true; }
363 first_word == a || command.starts_with(a)
364 });
365
366 if !is_allowed {
367 return Decision::Deny(format!(
368 "Command '{}' is not in the allowed list",
369 first_word
370 ));
371 }
372
373 Decision::Allow
374 }
375
376 fn check_file_operation(&self, tool_name: &str, args: &serde_json::Value) -> Decision {
377 let policy = &self.policy.path;
378 let path = args.get("path").and_then(|v| v.as_str()).unwrap_or("");
379
380 if path.is_empty() {
381 return Decision::Deny("Empty path".to_string());
382 }
383
384 let abs_path = if Path::new(path).is_absolute() {
386 path.to_string()
387 } else {
388 match std::env::current_dir() {
389 Ok(cwd) => cwd.join(path).to_string_lossy().to_string(),
390 Err(_) => path.to_string(),
391 }
392 };
393
394 for denied in &policy.denied_paths {
396 if abs_path.starts_with(denied) || abs_path.contains(denied) {
397 return Decision::Deny(format!("Path '{}' is denied", path));
398 }
399 }
400
401 let allowed_paths = match tool_name {
403 "write_file" => &policy.allowed_write_paths,
404 _ => &policy.allowed_read_paths,
405 };
406
407 let is_allowed = allowed_paths.iter().any(|a| {
408 if a == "/" || a == "*" {
409 return true; }
411 abs_path.starts_with(a)
412 });
413
414 if !is_allowed {
415 return Decision::Deny(format!(
416 "Path '{}' is not in the allowed {} paths",
417 path,
418 if tool_name == "write_file" {
419 "write"
420 } else {
421 "read"
422 }
423 ));
424 }
425
426 if tool_name == "write_file" {
428 if let Some(content) = args.get("content").and_then(|v| v.as_str()) {
429 if content.len() > policy.max_write_bytes {
430 return Decision::Deny(format!(
431 "Write size {} exceeds maximum {} bytes",
432 content.len(),
433 policy.max_write_bytes
434 ));
435 }
436 }
437 }
438
439 Decision::Allow
440 }
441
442 fn check_network_request(&self, args: &serde_json::Value) -> Decision {
443 let policy = &self.policy.network;
444
445 if policy.deny_all {
446 return Decision::Deny("All network requests are denied by policy".to_string());
447 }
448
449 let url = args.get("url").and_then(|v| v.as_str()).unwrap_or("");
450
451 if url.is_empty() {
452 return Decision::Deny("Empty URL".to_string());
453 }
454
455 let parsed = match url::Url::parse(url) {
457 Ok(u) => u,
458 Err(e) => {
459 return Decision::Deny(format!("Invalid URL: {}", e));
460 }
461 };
462
463 let host = match parsed.host_str() {
464 Some(h) => h.to_string(),
465 None => return Decision::Deny("URL has no host".to_string()),
466 };
467
468 if is_localhost(&host) {
470 if !policy.allow_localhost {
471 return Decision::Deny("Localhost connections are denied by policy".to_string());
472 }
473 return Decision::Allow;
474 }
475
476 if is_private_ip(&host) && !policy.allow_private_networks {
478 return Decision::Deny("Private network connections are denied by policy".to_string());
479 }
480
481 for denied in &policy.denied_hosts {
483 if host == *denied || host.ends_with(&format!(".{}", denied)) {
484 return Decision::Deny(format!("Host '{}' is denied", host));
485 }
486 }
487
488 let is_allowed = policy.allowed_hosts.iter().any(|a| {
490 if a == "*" {
491 return true; }
493 host == *a || host.ends_with(&format!(".{}", a))
494 });
495
496 if !is_allowed {
497 return Decision::Deny(format!("Host '{}' is not in the allowed hosts list", host));
498 }
499
500 Decision::Allow
501 }
502}
503
/// Result of scanning content with an injection detector.
// NOTE(review): `dead_code` is allowed, presumably because not every use site exists yet.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InjectionVerdict {
    /// No heuristic fired.
    Clean,
    /// A heuristic fired; the payload describes the rule that matched.
    Suspicious(String),
}
515
/// Heuristic scanner for prompt-injection markers in text.
///
/// Build one with `new` (all checks on), `permissive` (built-in checks off) or
/// `Default`, then refine it with the `with_*` builder methods.
// NOTE(review): `dead_code` is allowed, presumably because not every use site exists yet.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct InjectionDetector {
    /// Scan for override, role-switch, extraction and jailbreak phrases.
    check_instruction_boundary: bool,
    /// Validate tool-call argument JSON, code-fence balance and overall length.
    check_output_schema: bool,
    /// Extra caller-supplied patterns, matched case-insensitively.
    custom_patterns: Vec<String>,
}
534
535#[allow(dead_code)]
536impl InjectionDetector {
537 pub fn new() -> Self {
539 Self {
540 check_instruction_boundary: true,
541 check_output_schema: true,
542 custom_patterns: Vec::new(),
543 }
544 }
545
546 pub fn permissive() -> Self {
548 Self {
549 check_instruction_boundary: false,
550 check_output_schema: false,
551 custom_patterns: Vec::new(),
552 }
553 }
554
555 pub fn with_instruction_boundary(mut self, enabled: bool) -> Self {
557 self.check_instruction_boundary = enabled;
558 self
559 }
560
561 pub fn with_output_schema(mut self, enabled: bool) -> Self {
563 self.check_output_schema = enabled;
564 self
565 }
566
567 pub fn with_custom_pattern(mut self, pattern: &str) -> Self {
569 self.custom_patterns.push(pattern.to_string());
570 self
571 }
572
573 pub fn check(&self, content: &str) -> InjectionVerdict {
578 if self.check_instruction_boundary {
580 if let Some(reason) = self.check_instruction_boundary_violations(content) {
581 return InjectionVerdict::Suspicious(reason);
582 }
583 }
584
585 if self.check_output_schema {
587 if let Some(reason) = self.check_output_schema_violations(content) {
588 return InjectionVerdict::Suspicious(reason);
589 }
590 }
591
592 InjectionVerdict::Clean
593 }
594
595 fn check_instruction_boundary_violations(&self, content: &str) -> Option<String> {
598 let content_lower = content.to_lowercase();
599
600 let injection_patterns = [
602 ("ignore_previous", "Attempt to ignore previous instructions"),
604 ("ignore above", "Attempt to ignore previous instructions"),
605 (
606 "ignore all previous",
607 "Attempt to ignore previous instructions",
608 ),
609 (
610 "ignore the above",
611 "Attempt to ignore previous instructions",
612 ),
613 (
614 "ignore your previous",
615 "Attempt to ignore previous instructions",
616 ),
617 (
618 "ignore all instructions",
619 "Attempt to ignore all instructions",
620 ),
621 ("forget all previous", "Attempt to forget instructions"),
622 ("forget your instructions", "Attempt to forget instructions"),
623 ("disregard previous", "Attempt to disregard instructions"),
624 ("disregard all", "Attempt to disregard all instructions"),
625 ("new system prompt", "Attempt to override system prompt"),
627 ("override system", "Attempt to override system prompt"),
628 ("you are now", "Suspicious role-switch pattern"),
629 ("from now on you are", "Suspicious role-switch pattern"),
630 ("act as a", "Suspicious role-switch pattern"),
631 ("pretend to be", "Suspicious role-switch pattern"),
632 ("---begin override---", "Delimiter-based injection attempt"),
634 ("<system>", "Delimiter-based injection attempt"),
635 ("</system>", "Delimiter-based injection attempt"),
636 ("<override>", "Delimiter-based injection attempt"),
637 ("</override>", "Delimiter-based injection attempt"),
638 ("output your", "Attempt to extract system prompt"),
640 ("print your", "Attempt to extract system prompt"),
641 ("reveal your", "Attempt to extract system prompt"),
642 ("show your prompt", "Attempt to extract system prompt"),
643 ("show your instructions", "Attempt to extract system prompt"),
644 (
645 "what are your instructions",
646 "Attempt to extract system prompt",
647 ),
648 (
649 "what is your system prompt",
650 "Attempt to extract system prompt",
651 ),
652 ("dan ", "Potential DAN jailbreak pattern"),
654 ("do anything now", "Potential jailbreak pattern"),
655 ("you have been released", "Potential jailbreak pattern"),
656 ("you are free", "Potential jailbreak pattern"),
657 ("no restrictions", "Potential jailbreak pattern"),
658 ("no rules", "Potential jailbreak pattern"),
659 ("no limitations", "Potential jailbreak pattern"),
660 ("no filtering", "Potential jailbreak pattern"),
661 ("no censorship", "Potential jailbreak pattern"),
662 ("base64", "Potential token smuggling"),
664 ("rot13", "Potential obfuscation attempt"),
665 ("caesar cipher", "Potential obfuscation attempt"),
666 ("encoded message", "Potential obfuscation attempt"),
667 ("decode this", "Potential obfuscation attempt"),
668 ("this is a test", "Suspicious meta-instruction pattern"),
670 (
671 "this is a security test",
672 "Suspicious meta-instruction pattern",
673 ),
674 ("this is a prompt", "Suspicious meta-instruction pattern"),
675 ("the user is lying", "Suspicious meta-instruction pattern"),
676 ("the user is testing", "Suspicious meta-instruction pattern"),
677 ("you must obey", "Suspicious imperative pattern"),
678 ("you will obey", "Suspicious imperative pattern"),
679 ("you are required", "Suspicious imperative pattern"),
680 ("you must respond", "Suspicious imperative pattern"),
681 ("respond with exactly", "Suspicious imperative pattern"),
682 ("say exactly", "Suspicious imperative pattern"),
683 ("repeat exactly", "Suspicious imperative pattern"),
684 ("repeat after me", "Suspicious imperative pattern"),
685 ("repeat the words", "Suspicious imperative pattern"),
686 ];
687
688 for (pattern, reason) in &injection_patterns {
689 if content_lower.contains(pattern) {
690 return Some(format!("{}: '{}'", reason, pattern));
691 }
692 }
693
694 for pattern in &self.custom_patterns {
696 if content_lower.contains(&pattern.to_lowercase()) {
697 return Some(format!("Custom pattern matched: '{}'", pattern));
698 }
699 }
700
701 None
702 }
703
704 fn check_output_schema_violations(&self, content: &str) -> Option<String> {
707 if content.contains("TOOL_CALL:") {
709 for line in content.lines() {
711 let trimmed = line.trim();
712 if let Some(args_str) = trimmed.strip_prefix("ARGS:") {
713 let args_str = args_str.trim();
714 if !args_str.is_empty()
715 && serde_json::from_str::<serde_json::Value>(args_str).is_err()
716 {
717 return Some(format!(
718 "Invalid JSON in tool call arguments: '{}'",
719 args_str
720 ));
721 }
722 }
723 }
724 }
725
726 let open_blocks = content.matches("```").count();
728 #[allow(clippy::manual_is_multiple_of)]
729 if open_blocks % 2 != 0 {
730 return Some("Unbalanced code block delimiters".to_string());
731 }
732
733 if content.len() > 100_000 {
735 return Some(format!(
736 "Response too long ({} chars), possible smuggling attempt",
737 content.len()
738 ));
739 }
740
741 None
742 }
743}
744
745impl Default for InjectionDetector {
746 fn default() -> Self {
747 Self::new()
748 }
749}
750
/// Serde fallback for `ShellPolicy::max_timeout_secs`: one minute.
fn default_shell_timeout() -> u64 {
    const ONE_MINUTE_SECS: u64 = 60;
    ONE_MINUTE_SECS
}
756
/// Serde fallback for `PathPolicy::max_read_bytes`: 64 KiB.
fn default_max_read_bytes() -> usize {
    64 * 1024
}
760
/// Serde fallback for `PathPolicy::max_write_bytes`: 1 MiB.
fn default_max_write_bytes() -> usize {
    1 << 20
}
764
/// Serde fallback for boolean fields that default to enabled
/// (used by `NetworkPolicy::allow_localhost`).
fn default_true() -> bool {
    let enabled = true;
    enabled
}
768
/// Returns `true` if `host` names the local machine.
///
/// Accepted forms:
/// - `localhost` and `*.localhost` (reserved for loopback by RFC 6761).
/// - Any IPv4/IPv6 loopback address and the unspecified address (`0.0.0.0` / `::`).
/// - IPv4-mapped IPv6 forms of those addresses.
///
/// IPv6 literals may be wrapped in brackets (as `url::Url::host_str` returns them),
/// and a trailing root dot is ignored. Hosts are parsed rather than prefix-matched,
/// so a domain such as `127.example.com` is *not* treated as localhost.
fn is_localhost(host: &str) -> bool {
    let host = host.trim_end_matches('.');
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    let lower = host.to_ascii_lowercase();
    if lower == "localhost" || lower.ends_with(".localhost") {
        return true;
    }

    match host.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(v4)) => v4.is_loopback() || v4.is_unspecified(),
        Ok(std::net::IpAddr::V6(v6)) => {
            v6.is_loopback()
                || v6.is_unspecified()
                || v6
                    .to_ipv4_mapped()
                    .is_some_and(|v4| v4.is_loopback() || v4.is_unspecified())
        }
        Err(_) => false,
    }
}
776
/// Returns `true` if `host` is an IP literal in a non-public range.
///
/// Covered ranges:
/// - IPv4: RFC 1918 private ranges, link-local `169.254.0.0/16` (which includes
///   cloud metadata endpoints) and carrier-grade NAT `100.64.0.0/10`.
/// - IPv6: unique-local `fc00::/7` and link-local `fe80::/10`.
/// - IPv4-mapped IPv6 forms of the IPv4 ranges above.
///
/// Brackets around IPv6 literals and a trailing root dot are ignored. Domain names
/// return `false` (no DNS resolution happens here), so `10.example.com` is no
/// longer misclassified.
fn is_private_ip(host: &str) -> bool {
    fn is_private_v4(ip: std::net::Ipv4Addr) -> bool {
        let [first, second, ..] = ip.octets();
        let is_shared_cgnat = first == 100 && (second & 0xc0) == 64;
        ip.is_private() || ip.is_link_local() || is_shared_cgnat
    }

    let host = host.trim_end_matches('.');
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    match host.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(v4)) => is_private_v4(v4),
        Ok(std::net::IpAddr::V6(v6)) => match v6.to_ipv4_mapped() {
            Some(v4) => is_private_v4(v4),
            None => {
                let leading = v6.segments()[0];
                let unique_local = (leading & 0xfe00) == 0xfc00;
                let link_local = (leading & 0xffc0) == 0xfe80;
                unique_local || link_local
            }
        },
        Err(_) => false,
    }
}
798
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Engine built from `SecurityPolicy::default()`.
    fn secure() -> PolicyEngine {
        PolicyEngine::default_secure()
    }

    /// Whether `engine` lets `args` through the shell check.
    fn shell_ok(engine: &PolicyEngine, args: Value) -> bool {
        engine.check_shell_command(&args).is_allowed()
    }

    /// Whether `engine` lets `args` through the file check for `tool`.
    fn file_ok(engine: &PolicyEngine, tool: &str, args: Value) -> bool {
        engine.check_file_operation(tool, &args).is_allowed()
    }

    /// Whether `engine` lets `args` through the network check.
    fn network_ok(engine: &PolicyEngine, args: Value) -> bool {
        engine.check_network_request(&args).is_allowed()
    }

    /// Whether `engine` lets the full tool call through.
    fn tool_ok(engine: &PolicyEngine, tool: &str, args: Value) -> bool {
        engine.check_tool_call(tool, &args).is_allowed()
    }

    #[test]
    fn test_decision_allow() {
        assert!(Decision::Allow.is_allowed());
    }

    #[test]
    fn test_decision_deny() {
        assert!(!Decision::Deny("test".to_string()).is_allowed());
    }

    #[test]
    fn test_default_policy_denies_unknown_command() {
        assert!(!shell_ok(&secure(), json!({"command": "sudo rm -rf /"})));
    }

    #[test]
    fn test_default_policy_allows_echo() {
        assert!(shell_ok(&secure(), json!({"command": "echo hello"})));
    }

    #[test]
    fn test_default_policy_allows_ls() {
        assert!(shell_ok(&secure(), json!({"command": "ls -la"})));
    }

    #[test]
    fn test_default_policy_denies_shutdown() {
        assert!(!shell_ok(&secure(), json!({"command": "shutdown -h now"})));
    }

    #[test]
    fn test_default_policy_denies_rm_rf_root() {
        assert!(!shell_ok(&secure(), json!({"command": "rm -rf /"})));
    }

    #[test]
    fn test_deny_all_shell() {
        let engine = PolicyEngine::new(SecurityPolicy {
            shell: ShellPolicy {
                deny_all: true,
                ..ShellPolicy::default()
            },
            ..SecurityPolicy::default()
        });
        assert!(!shell_ok(&engine, json!({"command": "echo hello"})));
    }

    #[test]
    fn test_timeout_exceeded() {
        let args = json!({"command": "echo hello", "timeout_secs": 999});
        assert!(!shell_ok(&secure(), args));
    }

    #[test]
    fn test_empty_command() {
        assert!(!shell_ok(&secure(), json!({"command": ""})));
    }

    #[test]
    fn test_permissive_allows_all() {
        let engine = PolicyEngine::permissive();
        assert!(shell_ok(&engine, json!({"command": "curl https://example.com"})));
    }

    #[test]
    fn test_path_read_allowed() {
        assert!(file_ok(&secure(), "read_file", json!({"path": "/tmp/test.txt"})));
    }

    #[test]
    fn test_path_write_allowed() {
        let args = json!({"path": "/tmp/test.txt", "content": "data"});
        assert!(file_ok(&secure(), "write_file", args));
    }

    #[test]
    fn test_path_denied() {
        assert!(!file_ok(&secure(), "read_file", json!({"path": "/etc/shadow"})));
    }

    #[test]
    fn test_path_denied_write() {
        let args = json!({"path": "/etc/shadow", "content": "data"});
        assert!(!file_ok(&secure(), "write_file", args));
    }

    #[test]
    fn test_empty_path() {
        assert!(!file_ok(&secure(), "read_file", json!({"path": ""})));
    }

    #[test]
    fn test_network_allowed_host() {
        let args = json!({"url": "https://github.com/egkristi/RavenClaws"});
        assert!(network_ok(&secure(), args));
    }

    #[test]
    fn test_network_denied_host() {
        assert!(!network_ok(&secure(), json!({"url": "https://evil.com/malware"})));
    }

    #[test]
    fn test_network_localhost_allowed() {
        let args = json!({"url": "http://localhost:11434/api/chat"});
        assert!(network_ok(&secure(), args));
    }

    #[test]
    fn test_network_deny_all() {
        let engine = PolicyEngine::new(SecurityPolicy {
            network: NetworkPolicy {
                deny_all: true,
                ..NetworkPolicy::default()
            },
            ..SecurityPolicy::default()
        });
        assert!(!network_ok(&engine, json!({"url": "https://github.com"})));
    }

    #[test]
    fn test_network_empty_url() {
        assert!(!network_ok(&secure(), json!({"url": ""})));
    }

    #[test]
    fn test_network_invalid_url() {
        assert!(!network_ok(&secure(), json!({"url": "not-a-url"})));
    }

    #[test]
    fn test_requires_approval_default() {
        let engine = secure();
        for tool in ["shell_exec", "write_file"] {
            assert!(engine.requires_approval(tool), "{tool} should need approval");
        }
        for tool in ["read_file", "web_fetch"] {
            assert!(!engine.requires_approval(tool), "{tool} should not need approval");
        }
    }

    #[test]
    fn test_requires_approval_all() {
        let engine = PolicyEngine::new(SecurityPolicy {
            require_approval_all: true,
            ..SecurityPolicy::default()
        });
        for tool in ["shell_exec", "read_file", "web_fetch"] {
            assert!(engine.requires_approval(tool), "{tool} should need approval");
        }
    }

    #[test]
    fn test_check_tool_call_shell() {
        assert!(tool_ok(&secure(), "shell_exec", json!({"command": "echo hello"})));
    }

    #[test]
    fn test_check_tool_call_read_file() {
        assert!(tool_ok(&secure(), "read_file", json!({"path": "/tmp/test.txt"})));
    }

    #[test]
    fn test_check_tool_call_web_fetch() {
        assert!(tool_ok(&secure(), "web_fetch", json!({"url": "https://github.com"})));
    }

    #[test]
    fn test_check_tool_call_unknown() {
        assert!(tool_ok(&secure(), "unknown_tool", json!({})));
    }

    #[test]
    fn test_policy_error_denied() {
        let err = PolicyError::Denied("test".to_string());
        assert_eq!(err.to_string(), "Policy denied: test");
    }

    #[test]
    fn test_policy_error_invalid_config() {
        let err = PolicyError::InvalidConfig("bad config".to_string());
        assert_eq!(err.to_string(), "Invalid policy configuration: bad config");
    }

    #[test]
    fn test_is_localhost() {
        for host in ["localhost", "127.0.0.1", "::1", "0.0.0.0", "127.0.0.2"] {
            assert!(is_localhost(host), "{host} should be localhost");
        }
        assert!(!is_localhost("example.com"));
    }

    #[test]
    fn test_is_private_ip() {
        for host in ["10.0.0.1", "192.168.1.1", "172.16.0.1"] {
            assert!(is_private_ip(host), "{host} should be private");
        }
        for host in ["8.8.8.8", "example.com"] {
            assert!(!is_private_ip(host), "{host} should not be private");
        }
    }

    #[test]
    fn test_shell_policy_default() {
        let policy = ShellPolicy::default();
        assert!(!policy.deny_all);
        assert!(policy.allowed_commands.iter().any(|c| c == "echo"));
        assert!(policy.denied_commands.iter().any(|c| c == "rm -rf /"));
    }

    #[test]
    fn test_path_policy_default() {
        let policy = PathPolicy::default();
        assert!(policy.allowed_read_paths.iter().any(|p| p == "/tmp"));
        assert!(policy.allowed_write_paths.iter().any(|p| p == "/tmp"));
        assert!(policy.denied_paths.iter().any(|p| p == "/etc/shadow"));
    }

    #[test]
    fn test_network_policy_default() {
        let policy = NetworkPolicy::default();
        assert!(!policy.deny_all);
        assert!(policy.allow_localhost);
        assert!(!policy.allow_private_networks);
    }

    #[test]
    fn test_security_policy_default() {
        let policy = SecurityPolicy::default();
        assert!(!policy.require_approval_all);
        assert!(policy.require_approval_for.iter().any(|t| t == "shell_exec"));
    }

    #[test]
    fn test_permissive_policy() {
        let engine = PolicyEngine::permissive();
        let policy = engine.policy();
        assert!(policy.shell.allowed_commands.iter().any(|c| c == "*"));
        assert!(policy.network.allowed_hosts.iter().any(|h| h == "*"));
        assert!(policy.network.allow_private_networks);
    }
}