Skip to main content

ravenclaws/
policy.rs

1//! RavenClaws
2//!
3//! Every tool call is checked against the policy before execution.
4//! The policy defines allow-lists for commands, paths, hosts, and
5//! network targets. By default, everything is denied unless explicitly allowed.
6//!
7//! # Architecture
8//!
9//! ```text
10//! ToolCall
11//!   │
12//!   ▼
13//! PolicyEngine::check()
14//!   │
15//!   ├── ShellPolicy  → command allow-list, flag restrictions
16//!   ├── PathPolicy   → read/write path allow-lists
17//!   ├── NetworkPolicy→ host/URL allow-list
18//!   └── GeneralPolicy→ category-based rules
19//!   │
20//!   ▼
//! Allowed / Denied (with reason)
//! ```
22
23use serde::{Deserialize, Serialize};
24use std::path::Path;
25use thiserror::Error;
26// ── Error types ────────────────────────────────────────────────────────────
27
28#[allow(dead_code)]
29#[derive(Error, Debug)]
30pub enum PolicyError {
31    #[error("Policy denied: {0}")]
32    Denied(String),
33
34    #[error("Invalid policy configuration: {0}")]
35    InvalidConfig(String),
36}
37
38// ── Policy types ───────────────────────────────────────────────────────────
39
/// Outcome of evaluating a single operation against the policy.
#[derive(Clone, Debug, PartialEq)]
#[allow(dead_code)]
pub enum Decision {
    /// The operation may proceed.
    Allow,
    /// The operation is rejected; the payload explains why.
    Deny(String),
}

#[allow(dead_code)]
impl Decision {
    /// Returns `true` for [`Decision::Allow`] and `false` for any denial.
    pub fn is_allowed(&self) -> bool {
        match self {
            Decision::Allow => true,
            Decision::Deny(_) => false,
        }
    }
}
56
57/// Shell command policy
58#[allow(dead_code)]
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ShellPolicy {
61    /// If true, all shell commands are denied (default: false)
62    #[serde(default)]
63    pub deny_all: bool,
64    /// List of allowed command prefixes (e.g., ["echo", "ls", "cat", "git"])
65    #[serde(default)]
66    pub allowed_commands: Vec<String>,
67    /// List of denied command prefixes (takes precedence over allowed)
68    #[serde(default)]
69    pub denied_commands: Vec<String>,
70    /// Maximum command timeout in seconds
71    #[serde(default = "default_shell_timeout")]
72    pub max_timeout_secs: u64,
73    /// If true, allow commands that write to disk (install, rm, etc.)
74    #[serde(default)]
75    pub allow_write_commands: bool,
76}
77
78impl Default for ShellPolicy {
79    fn default() -> Self {
80        Self {
81            deny_all: false,
82            allowed_commands: vec![
83                "echo".to_string(),
84                "cat".to_string(),
85                "ls".to_string(),
86                "head".to_string(),
87                "tail".to_string(),
88                "wc".to_string(),
89                "grep".to_string(),
90                "find".to_string(),
91                "sort".to_string(),
92                "uniq".to_string(),
93                "cut".to_string(),
94                "which".to_string(),
95                "pwd".to_string(),
96                "date".to_string(),
97                "whoami".to_string(),
98                "uname".to_string(),
99                "env".to_string(),
100                "printenv".to_string(),
101                "git".to_string(),
102                "cargo".to_string(),
103                "rustc".to_string(),
104                "python3".to_string(),
105                "node".to_string(),
106            ],
107            denied_commands: vec![
108                "rm -rf /".to_string(),
109                "mkfs".to_string(),
110                "dd".to_string(),
111                "shutdown".to_string(),
112                "reboot".to_string(),
113                "halt".to_string(),
114                "poweroff".to_string(),
115            ],
116            max_timeout_secs: 60,
117            allow_write_commands: false,
118        }
119    }
120}
121
122/// Filesystem path policy
123#[allow(dead_code)]
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct PathPolicy {
126    /// List of allowed read path prefixes
127    #[serde(default)]
128    pub allowed_read_paths: Vec<String>,
129    /// List of allowed write path prefixes
130    #[serde(default)]
131    pub allowed_write_paths: Vec<String>,
132    /// List of denied path prefixes (takes precedence)
133    #[serde(default)]
134    pub denied_paths: Vec<String>,
135    /// Maximum file read size in bytes
136    #[serde(default = "default_max_read_bytes")]
137    pub max_read_bytes: usize,
138    /// Maximum file write size in bytes
139    #[serde(default = "default_max_write_bytes")]
140    pub max_write_bytes: usize,
141}
142
143impl Default for PathPolicy {
144    fn default() -> Self {
145        Self {
146            allowed_read_paths: vec![
147                "/tmp".to_string(),
148                "/var/tmp".to_string(),
149                "/home".to_string(),
150                "/workspace".to_string(),
151                ".".to_string(),
152            ],
153            allowed_write_paths: vec![
154                "/tmp".to_string(),
155                "/var/tmp".to_string(),
156                "/workspace".to_string(),
157                ".".to_string(),
158            ],
159            denied_paths: vec![
160                "/etc/shadow".to_string(),
161                "/etc/sudoers".to_string(),
162                "/etc/ssh".to_string(),
163                "/root".to_string(),
164            ],
165            max_read_bytes: 65536,
166            max_write_bytes: 1048576,
167        }
168    }
169}
170
171/// Network policy
172#[allow(dead_code)]
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct NetworkPolicy {
175    /// If true, all network access is denied
176    #[serde(default)]
177    pub deny_all: bool,
178    /// List of allowed hostname suffixes (e.g., ["github.com", "docs.rs"])
179    #[serde(default)]
180    pub allowed_hosts: Vec<String>,
181    /// List of denied hostname suffixes (takes precedence)
182    #[serde(default)]
183    pub denied_hosts: Vec<String>,
184    /// If true, allow connections to localhost/127.0.0.1
185    #[serde(default = "default_true")]
186    pub allow_localhost: bool,
187    /// If true, allow connections to private IP ranges
188    #[serde(default)]
189    pub allow_private_networks: bool,
190}
191
192impl Default for NetworkPolicy {
193    fn default() -> Self {
194        Self {
195            deny_all: false,
196            allowed_hosts: vec![
197                "github.com".to_string(),
198                "raw.githubusercontent.com".to_string(),
199                "docs.rs".to_string(),
200                "crates.io".to_string(),
201                "api.github.com".to_string(),
202                "google.com".to_string(),
203                "wikipedia.org".to_string(),
204                "stackoverflow.com".to_string(),
205                "rust-lang.org".to_string(),
206            ],
207            denied_hosts: vec![],
208            allow_localhost: true,
209            allow_private_networks: false,
210        }
211    }
212}
213
214/// Complete security policy
215#[allow(dead_code)]
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct SecurityPolicy {
218    /// Shell command policy
219    #[serde(default)]
220    pub shell: ShellPolicy,
221    /// Filesystem path policy
222    #[serde(default)]
223    pub path: PathPolicy,
224    /// Network policy
225    #[serde(default)]
226    pub network: NetworkPolicy,
227    /// If true, all tool calls require human approval
228    #[serde(default)]
229    pub require_approval_all: bool,
230    /// List of tool names that require approval
231    #[serde(default)]
232    pub require_approval_for: Vec<String>,
233}
234
235impl Default for SecurityPolicy {
236    fn default() -> Self {
237        Self {
238            shell: ShellPolicy::default(),
239            path: PathPolicy::default(),
240            network: NetworkPolicy::default(),
241            require_approval_all: false,
242            require_approval_for: vec!["shell_exec".to_string(), "write_file".to_string()],
243        }
244    }
245}
246
247// ── Policy engine ──────────────────────────────────────────────────────────
248
249/// The policy engine — checks tool calls against the security policy
250#[allow(dead_code)]
251pub struct PolicyEngine {
252    policy: SecurityPolicy,
253}
254
255#[allow(dead_code)]
256impl PolicyEngine {
257    /// Create a new policy engine with the given policy
258    pub fn new(policy: SecurityPolicy) -> Self {
259        Self { policy }
260    }
261
262    /// Create a policy engine with default (secure) settings
263    pub fn default_secure() -> Self {
264        Self {
265            policy: SecurityPolicy::default(),
266        }
267    }
268
269    /// Create a permissive policy engine (for development)
270    pub fn permissive() -> Self {
271        Self {
272            policy: SecurityPolicy {
273                require_approval_all: false,
274                require_approval_for: vec![],
275                shell: ShellPolicy {
276                    deny_all: false,
277                    allowed_commands: vec!["*".to_string()], // allow all
278                    denied_commands: vec![],
279                    max_timeout_secs: 300,
280                    allow_write_commands: true,
281                },
282                path: PathPolicy {
283                    allowed_read_paths: vec!["/".to_string()],
284                    allowed_write_paths: vec!["/tmp".to_string(), "/workspace".to_string()],
285                    denied_paths: vec![],
286                    max_read_bytes: 1048576,
287                    max_write_bytes: 10485760,
288                },
289                network: NetworkPolicy {
290                    deny_all: false,
291                    allowed_hosts: vec!["*".to_string()],
292                    denied_hosts: vec![],
293                    allow_localhost: true,
294                    allow_private_networks: true,
295                },
296            },
297        }
298    }
299
300    /// Check if a tool call is allowed
301    pub fn check_tool_call(&self, tool_name: &str, args: &serde_json::Value) -> Decision {
302        match tool_name {
303            "shell_exec" => self.check_shell_command(args),
304            "read_file" | "write_file" => self.check_file_operation(tool_name, args),
305            "web_fetch" => self.check_network_request(args),
306            _ => Decision::Allow, // Unknown tools are allowed by default (they'll be checked by the tool registry)
307        }
308    }
309
310    /// Check if a tool requires human approval
311    pub fn requires_approval(&self, tool_name: &str) -> bool {
312        if self.policy.require_approval_all {
313            return true;
314        }
315        self.policy
316            .require_approval_for
317            .contains(&tool_name.to_string())
318    }
319
320    /// Get the policy configuration
321    pub fn policy(&self) -> &SecurityPolicy {
322        &self.policy
323    }
324
325    // ── Internal check methods ──────────────────────────────────────────
326
327    fn check_shell_command(&self, args: &serde_json::Value) -> Decision {
328        let policy = &self.policy.shell;
329
330        if policy.deny_all {
331            return Decision::Deny("All shell commands are denied by policy".to_string());
332        }
333
334        let command = args.get("command").and_then(|v| v.as_str()).unwrap_or("");
335
336        if command.is_empty() {
337            return Decision::Deny("Empty command".to_string());
338        }
339
340        // Check denied commands first
341        for denied in &policy.denied_commands {
342            if command.contains(denied) {
343                return Decision::Deny(format!("Command contains denied pattern: '{}'", denied));
344            }
345        }
346
347        // Check timeout
348        if let Some(timeout) = args.get("timeout_secs").and_then(|v| v.as_u64()) {
349            if timeout > policy.max_timeout_secs {
350                return Decision::Deny(format!(
351                    "Timeout {}s exceeds maximum {}s",
352                    timeout, policy.max_timeout_secs
353                ));
354            }
355        }
356
357        // Check allowed commands
358        let first_word = command.split_whitespace().next().unwrap_or("");
359        let is_allowed = policy.allowed_commands.iter().any(|a| {
360            if a == "*" {
361                return true; // wildcard — allow all
362            }
363            first_word == a || command.starts_with(a)
364        });
365
366        if !is_allowed {
367            return Decision::Deny(format!(
368                "Command '{}' is not in the allowed list",
369                first_word
370            ));
371        }
372
373        Decision::Allow
374    }
375
376    fn check_file_operation(&self, tool_name: &str, args: &serde_json::Value) -> Decision {
377        let policy = &self.policy.path;
378        let path = args.get("path").and_then(|v| v.as_str()).unwrap_or("");
379
380        if path.is_empty() {
381            return Decision::Deny("Empty path".to_string());
382        }
383
384        // Resolve to absolute path
385        let abs_path = if Path::new(path).is_absolute() {
386            path.to_string()
387        } else {
388            match std::env::current_dir() {
389                Ok(cwd) => cwd.join(path).to_string_lossy().to_string(),
390                Err(_) => path.to_string(),
391            }
392        };
393
394        // Check denied paths
395        for denied in &policy.denied_paths {
396            if abs_path.starts_with(denied) || abs_path.contains(denied) {
397                return Decision::Deny(format!("Path '{}' is denied", path));
398            }
399        }
400
401        // Check allowed paths based on operation type
402        let allowed_paths = match tool_name {
403            "write_file" => &policy.allowed_write_paths,
404            _ => &policy.allowed_read_paths,
405        };
406
407        let is_allowed = allowed_paths.iter().any(|a| {
408            if a == "/" || a == "*" {
409                return true; // wildcard
410            }
411            abs_path.starts_with(a)
412        });
413
414        if !is_allowed {
415            return Decision::Deny(format!(
416                "Path '{}' is not in the allowed {} paths",
417                path,
418                if tool_name == "write_file" {
419                    "write"
420                } else {
421                    "read"
422                }
423            ));
424        }
425
426        // Check size limits for write operations
427        if tool_name == "write_file" {
428            if let Some(content) = args.get("content").and_then(|v| v.as_str()) {
429                if content.len() > policy.max_write_bytes {
430                    return Decision::Deny(format!(
431                        "Write size {} exceeds maximum {} bytes",
432                        content.len(),
433                        policy.max_write_bytes
434                    ));
435                }
436            }
437        }
438
439        Decision::Allow
440    }
441
442    fn check_network_request(&self, args: &serde_json::Value) -> Decision {
443        let policy = &self.policy.network;
444
445        if policy.deny_all {
446            return Decision::Deny("All network requests are denied by policy".to_string());
447        }
448
449        let url = args.get("url").and_then(|v| v.as_str()).unwrap_or("");
450
451        if url.is_empty() {
452            return Decision::Deny("Empty URL".to_string());
453        }
454
455        // Parse the URL
456        let parsed = match url::Url::parse(url) {
457            Ok(u) => u,
458            Err(e) => {
459                return Decision::Deny(format!("Invalid URL: {}", e));
460            }
461        };
462
463        let host = match parsed.host_str() {
464            Some(h) => h.to_string(),
465            None => return Decision::Deny("URL has no host".to_string()),
466        };
467
468        // Check localhost
469        if is_localhost(&host) {
470            if !policy.allow_localhost {
471                return Decision::Deny("Localhost connections are denied by policy".to_string());
472            }
473            return Decision::Allow;
474        }
475
476        // Check private networks
477        if is_private_ip(&host) && !policy.allow_private_networks {
478            return Decision::Deny("Private network connections are denied by policy".to_string());
479        }
480
481        // Check denied hosts
482        for denied in &policy.denied_hosts {
483            if host == *denied || host.ends_with(&format!(".{}", denied)) {
484                return Decision::Deny(format!("Host '{}' is denied", host));
485            }
486        }
487
488        // Check allowed hosts
489        let is_allowed = policy.allowed_hosts.iter().any(|a| {
490            if a == "*" {
491                return true; // wildcard
492            }
493            host == *a || host.ends_with(&format!(".{}", a))
494        });
495
496        if !is_allowed {
497            return Decision::Deny(format!("Host '{}' is not in the allowed hosts list", host));
498        }
499
500        Decision::Allow
501    }
502}
503
504// ── Prompt-injection defense ───────────────────────────────────────────────
505
/// Verdict returned by [`InjectionDetector::check`] for a piece of LLM output.
#[derive(Clone, Debug, PartialEq)]
#[allow(dead_code)]
pub enum InjectionVerdict {
    /// No check raised a finding.
    Clean,
    /// A check raised a finding; the payload is the reason.
    Suspicious(String),
}
515
/// Heuristic scanner for prompt-injection attempts in LLM responses.
///
/// It combines two independent checks, each of which can be switched off:
/// 1. **Instruction-boundary enforcement** — looks for phrases suggesting an
///    attempt to override system instructions or produce unauthorized output.
/// 2. **Output schema validation** — verifies that structured output (tool
///    calls, JSON responses) has the expected form.
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct InjectionDetector {
    /// Enables the instruction-boundary scan.
    check_instruction_boundary: bool,
    /// Enables the output schema validation.
    check_output_schema: bool,
    /// Extra phrases, supplied by the caller, that are flagged as injection attempts.
    custom_patterns: Vec<String>,
}
534
535#[allow(dead_code)]
536impl InjectionDetector {
537    /// Create a new injection detector with default settings
538    pub fn new() -> Self {
539        Self {
540            check_instruction_boundary: true,
541            check_output_schema: true,
542            custom_patterns: Vec::new(),
543        }
544    }
545
546    /// Create a detector with all checks disabled (permissive)
547    pub fn permissive() -> Self {
548        Self {
549            check_instruction_boundary: false,
550            check_output_schema: false,
551            custom_patterns: Vec::new(),
552        }
553    }
554
555    /// Enable or disable instruction-boundary checking
556    pub fn with_instruction_boundary(mut self, enabled: bool) -> Self {
557        self.check_instruction_boundary = enabled;
558        self
559    }
560
561    /// Enable or disable output schema checking
562    pub fn with_output_schema(mut self, enabled: bool) -> Self {
563        self.check_output_schema = enabled;
564        self
565    }
566
567    /// Add a custom injection pattern
568    pub fn with_custom_pattern(mut self, pattern: &str) -> Self {
569        self.custom_patterns.push(pattern.to_string());
570        self
571    }
572
573    /// Check LLM response content for injection attempts.
574    ///
575    /// Returns `InjectionVerdict::Clean` if the output appears safe,
576    /// or `InjectionVerdict::Suspicious(reason)` if injection is detected.
577    pub fn check(&self, content: &str) -> InjectionVerdict {
578        // Check instruction-boundary violations
579        if self.check_instruction_boundary {
580            if let Some(reason) = self.check_instruction_boundary_violations(content) {
581                return InjectionVerdict::Suspicious(reason);
582            }
583        }
584
585        // Check output schema
586        if self.check_output_schema {
587            if let Some(reason) = self.check_output_schema_violations(content) {
588                return InjectionVerdict::Suspicious(reason);
589            }
590        }
591
592        InjectionVerdict::Clean
593    }
594
595    /// Check for instruction-boundary violations — patterns where the LLM
596    /// attempts to override its system prompt or produce unauthorized output.
597    fn check_instruction_boundary_violations(&self, content: &str) -> Option<String> {
598        let content_lower = content.to_lowercase();
599
600        // Known prompt-injection / jailbreak patterns
601        let injection_patterns = [
602            // Direct instruction override attempts
603            ("ignore_previous", "Attempt to ignore previous instructions"),
604            ("ignore above", "Attempt to ignore previous instructions"),
605            (
606                "ignore all previous",
607                "Attempt to ignore previous instructions",
608            ),
609            (
610                "ignore the above",
611                "Attempt to ignore previous instructions",
612            ),
613            (
614                "ignore your previous",
615                "Attempt to ignore previous instructions",
616            ),
617            (
618                "ignore all instructions",
619                "Attempt to ignore all instructions",
620            ),
621            ("forget all previous", "Attempt to forget instructions"),
622            ("forget your instructions", "Attempt to forget instructions"),
623            ("disregard previous", "Attempt to disregard instructions"),
624            ("disregard all", "Attempt to disregard all instructions"),
625            // System prompt override
626            ("new system prompt", "Attempt to override system prompt"),
627            ("override system", "Attempt to override system prompt"),
628            ("you are now", "Suspicious role-switch pattern"),
629            ("from now on you are", "Suspicious role-switch pattern"),
630            ("act as a", "Suspicious role-switch pattern"),
631            ("pretend to be", "Suspicious role-switch pattern"),
632            // Delimiter-based injection
633            ("---begin override---", "Delimiter-based injection attempt"),
634            ("<system>", "Delimiter-based injection attempt"),
635            ("</system>", "Delimiter-based injection attempt"),
636            ("<override>", "Delimiter-based injection attempt"),
637            ("</override>", "Delimiter-based injection attempt"),
638            // Output manipulation
639            ("output your", "Attempt to extract system prompt"),
640            ("print your", "Attempt to extract system prompt"),
641            ("reveal your", "Attempt to extract system prompt"),
642            ("show your prompt", "Attempt to extract system prompt"),
643            ("show your instructions", "Attempt to extract system prompt"),
644            (
645                "what are your instructions",
646                "Attempt to extract system prompt",
647            ),
648            (
649                "what is your system prompt",
650                "Attempt to extract system prompt",
651            ),
652            // Role-playing jailbreaks
653            ("dan ", "Potential DAN jailbreak pattern"),
654            ("do anything now", "Potential jailbreak pattern"),
655            ("you have been released", "Potential jailbreak pattern"),
656            ("you are free", "Potential jailbreak pattern"),
657            ("no restrictions", "Potential jailbreak pattern"),
658            ("no rules", "Potential jailbreak pattern"),
659            ("no limitations", "Potential jailbreak pattern"),
660            ("no filtering", "Potential jailbreak pattern"),
661            ("no censorship", "Potential jailbreak pattern"),
662            // Token smuggling
663            ("base64", "Potential token smuggling"),
664            ("rot13", "Potential obfuscation attempt"),
665            ("caesar cipher", "Potential obfuscation attempt"),
666            ("encoded message", "Potential obfuscation attempt"),
667            ("decode this", "Potential obfuscation attempt"),
668            // Meta-instruction attacks
669            ("this is a test", "Suspicious meta-instruction pattern"),
670            (
671                "this is a security test",
672                "Suspicious meta-instruction pattern",
673            ),
674            ("this is a prompt", "Suspicious meta-instruction pattern"),
675            ("the user is lying", "Suspicious meta-instruction pattern"),
676            ("the user is testing", "Suspicious meta-instruction pattern"),
677            ("you must obey", "Suspicious imperative pattern"),
678            ("you will obey", "Suspicious imperative pattern"),
679            ("you are required", "Suspicious imperative pattern"),
680            ("you must respond", "Suspicious imperative pattern"),
681            ("respond with exactly", "Suspicious imperative pattern"),
682            ("say exactly", "Suspicious imperative pattern"),
683            ("repeat exactly", "Suspicious imperative pattern"),
684            ("repeat after me", "Suspicious imperative pattern"),
685            ("repeat the words", "Suspicious imperative pattern"),
686        ];
687
688        for (pattern, reason) in &injection_patterns {
689            if content_lower.contains(pattern) {
690                return Some(format!("{}: '{}'", reason, pattern));
691            }
692        }
693
694        // Check custom patterns
695        for pattern in &self.custom_patterns {
696            if content_lower.contains(&pattern.to_lowercase()) {
697                return Some(format!("Custom pattern matched: '{}'", pattern));
698            }
699        }
700
701        None
702    }
703
704    /// Check for output schema violations — ensures structured output
705    /// conforms to expected format.
706    fn check_output_schema_violations(&self, content: &str) -> Option<String> {
707        // If the content contains a TOOL_CALL: marker, validate the JSON args
708        if content.contains("TOOL_CALL:") {
709            // Find all ARGS: lines and validate JSON
710            for line in content.lines() {
711                let trimmed = line.trim();
712                if let Some(args_str) = trimmed.strip_prefix("ARGS:") {
713                    let args_str = args_str.trim();
714                    if !args_str.is_empty()
715                        && serde_json::from_str::<serde_json::Value>(args_str).is_err()
716                    {
717                        return Some(format!(
718                            "Invalid JSON in tool call arguments: '{}'",
719                            args_str
720                        ));
721                    }
722                }
723            }
724        }
725
726        // Check for unbalanced code blocks that could hide content
727        let open_blocks = content.matches("```").count();
728        #[allow(clippy::manual_is_multiple_of)]
729        if open_blocks % 2 != 0 {
730            return Some("Unbalanced code block delimiters".to_string());
731        }
732
733        // Check for extremely long content that might be a smuggling attempt
734        if content.len() > 100_000 {
735            return Some(format!(
736                "Response too long ({} chars), possible smuggling attempt",
737                content.len()
738            ));
739        }
740
741        None
742    }
743}
744
745impl Default for InjectionDetector {
746    fn default() -> Self {
747        Self::new()
748    }
749}
750
751// ── Helper functions ───────────────────────────────────────────────────────
752
/// Serde default for `ShellPolicy::max_timeout_secs`: 60 seconds.
fn default_shell_timeout() -> u64 {
    const SECONDS: u64 = 60;
    SECONDS
}
756
/// Serde default for `PathPolicy::max_read_bytes`: 64 KiB.
fn default_max_read_bytes() -> usize {
    64 * 1024
}
760
/// Serde default for `PathPolicy::max_write_bytes`: 1 MiB.
fn default_max_write_bytes() -> usize {
    1024 * 1024
}
764
/// Serde default for boolean fields that are enabled unless configured otherwise.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
768
/// Reports whether `host` designates the local machine.
///
/// Accepts the name `localhost` (ASCII case-insensitive) and IP literals that
/// are loopback or unspecified: `127.0.0.0/8`, `0.0.0.0`, `::1`, `::` and their
/// IPv4-mapped IPv6 forms. IPv6 literals may be wrapped in brackets. Abbreviated
/// numeric forms made only of digits and dots that begin with `127.` (such as
/// `127.1`) are also accepted.
///
/// Domain names are never treated as local just because they *start* with
/// `127.` — `127.evil.com` is an ordinary remote host.
fn is_localhost(host: &str) -> bool {
    let host = host.trim_start_matches('[').trim_end_matches(']');

    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }

    match host.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(ip)) => ip.is_loopback() || ip.is_unspecified(),
        Ok(std::net::IpAddr::V6(ip)) => {
            ip.is_loopback()
                || ip.is_unspecified()
                || ip
                    .to_ipv4_mapped()
                    .map_or(false, |v4| v4.is_loopback() || v4.is_unspecified())
        }
        // Not a full IP literal: only a purely numeric shorthand can be loopback.
        Err(_) => {
            host.starts_with("127.") && host.bytes().all(|b| b.is_ascii_digit() || b == b'.')
        }
    }
}
776
/// Reports whether `host` is an IP literal in a private or link-local range.
///
/// IPv4: `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16` and the link-local
/// block `169.254.0.0/16` (which includes common cloud metadata addresses).
/// IPv6: unique-local `fc00::/7`, link-local `fe80::/10` and IPv4-mapped forms
/// of the IPv4 ranges above. IPv6 literals may be wrapped in brackets.
///
/// Host names are never private just because they *start* with digits such as
/// `10.` — `10.example.com` is an ordinary domain. Abbreviated numeric forms
/// made only of digits and dots keep the prefix-based classification.
fn is_private_ip(host: &str) -> bool {
    fn v4_private(ip: std::net::Ipv4Addr) -> bool {
        ip.is_private() || ip.is_link_local()
    }

    let host = host.trim_start_matches('[').trim_end_matches(']');

    match host.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(ip)) => v4_private(ip),
        Ok(std::net::IpAddr::V6(ip)) => {
            let first = ip.segments()[0];
            (first & 0xfe00) == 0xfc00 // unique local fc00::/7
                || (first & 0xffc0) == 0xfe80 // link-local fe80::/10
                || ip.to_ipv4_mapped().map_or(false, v4_private)
        }
        Err(_) => {
            // Not a full IP literal: only a purely numeric shorthand qualifies.
            let numeric =
                !host.is_empty() && host.bytes().all(|b| b.is_ascii_digit() || b == b'.');
            numeric
                && (host.starts_with("10.")
                    || host.starts_with("192.168.")
                    || (16..=31).any(|second| host.starts_with(&format!("172.{}.", second))))
        }
    }
}
798
799// ── Tests ──────────────────────────────────────────────────────────────────
800
801#[cfg(test)]
802mod tests {
803    use super::*;
804
805    #[test]
806    fn test_decision_allow() {
807        let d = Decision::Allow;
808        assert!(d.is_allowed());
809    }
810
811    #[test]
812    fn test_decision_deny() {
813        let d = Decision::Deny("test".to_string());
814        assert!(!d.is_allowed());
815    }
816
817    #[test]
818    fn test_default_policy_denies_unknown_command() {
819        let engine = PolicyEngine::default_secure();
820        let args = serde_json::json!({"command": "sudo rm -rf /"});
821        let decision = engine.check_shell_command(&args);
822        assert!(!decision.is_allowed());
823    }
824
825    #[test]
826    fn test_default_policy_allows_echo() {
827        let engine = PolicyEngine::default_secure();
828        let args = serde_json::json!({"command": "echo hello"});
829        let decision = engine.check_shell_command(&args);
830        assert!(decision.is_allowed());
831    }
832
833    #[test]
834    fn test_default_policy_allows_ls() {
835        let engine = PolicyEngine::default_secure();
836        let args = serde_json::json!({"command": "ls -la"});
837        let decision = engine.check_shell_command(&args);
838        assert!(decision.is_allowed());
839    }
840
841    #[test]
842    fn test_default_policy_denies_shutdown() {
843        let engine = PolicyEngine::default_secure();
844        let args = serde_json::json!({"command": "shutdown -h now"});
845        let decision = engine.check_shell_command(&args);
846        assert!(!decision.is_allowed());
847    }
848
849    #[test]
850    fn test_default_policy_denies_rm_rf_root() {
851        let engine = PolicyEngine::default_secure();
852        let args = serde_json::json!({"command": "rm -rf /"});
853        let decision = engine.check_shell_command(&args);
854        assert!(!decision.is_allowed());
855    }
856
857    #[test]
858    fn test_deny_all_shell() {
859        let policy = SecurityPolicy {
860            shell: ShellPolicy {
861                deny_all: true,
862                ..ShellPolicy::default()
863            },
864            ..SecurityPolicy::default()
865        };
866        let engine = PolicyEngine::new(policy);
867        let args = serde_json::json!({"command": "echo hello"});
868        let decision = engine.check_shell_command(&args);
869        assert!(!decision.is_allowed());
870    }
871
872    #[test]
873    fn test_timeout_exceeded() {
874        let engine = PolicyEngine::default_secure();
875        let args = serde_json::json!({"command": "echo hello", "timeout_secs": 999});
876        let decision = engine.check_shell_command(&args);
877        assert!(!decision.is_allowed());
878    }
879
880    #[test]
881    fn test_empty_command() {
882        let engine = PolicyEngine::default_secure();
883        let args = serde_json::json!({"command": ""});
884        let decision = engine.check_shell_command(&args);
885        assert!(!decision.is_allowed());
886    }
887
888    #[test]
889    fn test_permissive_allows_all() {
890        let engine = PolicyEngine::permissive();
891        let args = serde_json::json!({"command": "curl https://example.com"});
892        let decision = engine.check_shell_command(&args);
893        assert!(decision.is_allowed());
894    }
895
896    #[test]
897    fn test_path_read_allowed() {
898        let engine = PolicyEngine::default_secure();
899        let args = serde_json::json!({"path": "/tmp/test.txt"});
900        let decision = engine.check_file_operation("read_file", &args);
901        assert!(decision.is_allowed());
902    }
903
904    #[test]
905    fn test_path_write_allowed() {
906        let engine = PolicyEngine::default_secure();
907        let args = serde_json::json!({"path": "/tmp/test.txt", "content": "data"});
908        let decision = engine.check_file_operation("write_file", &args);
909        assert!(decision.is_allowed());
910    }
911
912    #[test]
913    fn test_path_denied() {
914        let engine = PolicyEngine::default_secure();
915        let args = serde_json::json!({"path": "/etc/shadow"});
916        let decision = engine.check_file_operation("read_file", &args);
917        assert!(!decision.is_allowed());
918    }
919
920    #[test]
921    fn test_path_denied_write() {
922        let engine = PolicyEngine::default_secure();
923        let args = serde_json::json!({"path": "/etc/shadow", "content": "data"});
924        let decision = engine.check_file_operation("write_file", &args);
925        assert!(!decision.is_allowed());
926    }
927
928    #[test]
929    fn test_empty_path() {
930        let engine = PolicyEngine::default_secure();
931        let args = serde_json::json!({"path": ""});
932        let decision = engine.check_file_operation("read_file", &args);
933        assert!(!decision.is_allowed());
934    }
935
936    #[test]
937    fn test_network_allowed_host() {
938        let engine = PolicyEngine::default_secure();
939        let args = serde_json::json!({"url": "https://github.com/egkristi/RavenClaws"});
940        let decision = engine.check_network_request(&args);
941        assert!(decision.is_allowed());
942    }
943
944    #[test]
945    fn test_network_denied_host() {
946        let engine = PolicyEngine::default_secure();
947        let args = serde_json::json!({"url": "https://evil.com/malware"});
948        let decision = engine.check_network_request(&args);
949        assert!(!decision.is_allowed());
950    }
951
952    #[test]
953    fn test_network_localhost_allowed() {
954        let engine = PolicyEngine::default_secure();
955        let args = serde_json::json!({"url": "http://localhost:11434/api/chat"});
956        let decision = engine.check_network_request(&args);
957        assert!(decision.is_allowed());
958    }
959
960    #[test]
961    fn test_network_deny_all() {
962        let policy = SecurityPolicy {
963            network: NetworkPolicy {
964                deny_all: true,
965                ..NetworkPolicy::default()
966            },
967            ..SecurityPolicy::default()
968        };
969        let engine = PolicyEngine::new(policy);
970        let args = serde_json::json!({"url": "https://github.com"});
971        let decision = engine.check_network_request(&args);
972        assert!(!decision.is_allowed());
973    }
974
975    #[test]
976    fn test_network_empty_url() {
977        let engine = PolicyEngine::default_secure();
978        let args = serde_json::json!({"url": ""});
979        let decision = engine.check_network_request(&args);
980        assert!(!decision.is_allowed());
981    }
982
983    #[test]
984    fn test_network_invalid_url() {
985        let engine = PolicyEngine::default_secure();
986        let args = serde_json::json!({"url": "not-a-url"});
987        let decision = engine.check_network_request(&args);
988        assert!(!decision.is_allowed());
989    }
990
991    #[test]
992    fn test_requires_approval_default() {
993        let engine = PolicyEngine::default_secure();
994        assert!(engine.requires_approval("shell_exec"));
995        assert!(engine.requires_approval("write_file"));
996        assert!(!engine.requires_approval("read_file"));
997        assert!(!engine.requires_approval("web_fetch"));
998    }
999
1000    #[test]
1001    fn test_requires_approval_all() {
1002        let policy = SecurityPolicy {
1003            require_approval_all: true,
1004            ..SecurityPolicy::default()
1005        };
1006        let engine = PolicyEngine::new(policy);
1007        assert!(engine.requires_approval("shell_exec"));
1008        assert!(engine.requires_approval("read_file"));
1009        assert!(engine.requires_approval("web_fetch"));
1010    }
1011
1012    #[test]
1013    fn test_check_tool_call_shell() {
1014        let engine = PolicyEngine::default_secure();
1015        let args = serde_json::json!({"command": "echo hello"});
1016        let decision = engine.check_tool_call("shell_exec", &args);
1017        assert!(decision.is_allowed());
1018    }
1019
1020    #[test]
1021    fn test_check_tool_call_read_file() {
1022        let engine = PolicyEngine::default_secure();
1023        let args = serde_json::json!({"path": "/tmp/test.txt"});
1024        let decision = engine.check_tool_call("read_file", &args);
1025        assert!(decision.is_allowed());
1026    }
1027
1028    #[test]
1029    fn test_check_tool_call_web_fetch() {
1030        let engine = PolicyEngine::default_secure();
1031        let args = serde_json::json!({"url": "https://github.com"});
1032        let decision = engine.check_tool_call("web_fetch", &args);
1033        assert!(decision.is_allowed());
1034    }
1035
1036    #[test]
1037    fn test_check_tool_call_unknown() {
1038        let engine = PolicyEngine::default_secure();
1039        let args = serde_json::json!({});
1040        let decision = engine.check_tool_call("unknown_tool", &args);
1041        assert!(decision.is_allowed());
1042    }
1043
1044    #[test]
1045    fn test_policy_error_denied() {
1046        let err = PolicyError::Denied("test".to_string());
1047        assert_eq!(format!("{}", err), "Policy denied: test");
1048    }
1049
1050    #[test]
1051    fn test_policy_error_invalid_config() {
1052        let err = PolicyError::InvalidConfig("bad config".to_string());
1053        assert_eq!(
1054            format!("{}", err),
1055            "Invalid policy configuration: bad config"
1056        );
1057    }
1058
1059    #[test]
1060    fn test_is_localhost() {
1061        assert!(is_localhost("localhost"));
1062        assert!(is_localhost("127.0.0.1"));
1063        assert!(is_localhost("::1"));
1064        assert!(is_localhost("0.0.0.0"));
1065        assert!(is_localhost("127.0.0.2"));
1066        assert!(!is_localhost("example.com"));
1067    }
1068
1069    #[test]
1070    fn test_is_private_ip() {
1071        assert!(is_private_ip("10.0.0.1"));
1072        assert!(is_private_ip("192.168.1.1"));
1073        assert!(is_private_ip("172.16.0.1"));
1074        assert!(!is_private_ip("8.8.8.8"));
1075        assert!(!is_private_ip("example.com"));
1076    }
1077
1078    #[test]
1079    fn test_shell_policy_default() {
1080        let policy = ShellPolicy::default();
1081        assert!(!policy.deny_all);
1082        assert!(policy.allowed_commands.contains(&"echo".to_string()));
1083        assert!(policy.denied_commands.contains(&"rm -rf /".to_string()));
1084    }
1085
1086    #[test]
1087    fn test_path_policy_default() {
1088        let policy = PathPolicy::default();
1089        assert!(policy.allowed_read_paths.contains(&"/tmp".to_string()));
1090        assert!(policy.allowed_write_paths.contains(&"/tmp".to_string()));
1091        assert!(policy.denied_paths.contains(&"/etc/shadow".to_string()));
1092    }
1093
1094    #[test]
1095    fn test_network_policy_default() {
1096        let policy = NetworkPolicy::default();
1097        assert!(!policy.deny_all);
1098        assert!(policy.allow_localhost);
1099        assert!(!policy.allow_private_networks);
1100    }
1101
1102    #[test]
1103    fn test_security_policy_default() {
1104        let policy = SecurityPolicy::default();
1105        assert!(!policy.require_approval_all);
1106        assert!(policy
1107            .require_approval_for
1108            .contains(&"shell_exec".to_string()));
1109    }
1110
1111    #[test]
1112    fn test_permissive_policy() {
1113        let engine = PolicyEngine::permissive();
1114        let policy = engine.policy();
1115        assert!(policy.shell.allowed_commands.contains(&"*".to_string()));
1116        assert!(policy.network.allowed_hosts.contains(&"*".to_string()));
1117        assert!(policy.network.allow_private_networks);
1118    }
1119}