Skip to main content

battlecommand_forge/
sandbox.rs

//! Sandboxed command execution with timeouts and environment isolation, plus
//! path validation against a root directory and parsers for ruff and pytest output.
2
3use std::io::Read as IoRead;
4use std::path::Path;
5use std::process::{Command, Stdio};
6use std::time::{Duration, Instant};
7
/// Time limit, in seconds, that `run_tool` applies before the child is killed.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
9
/// Environment variables that are forwarded to child processes. Every other
/// variable is removed before exec, except names that start with an entry of
/// `ALLOWED_ENV_PREFIXES`.
///
/// An allowlist replaced an earlier substring blocklist. The blocklist kept
/// missing real leak surfaces such as `OLLAMA_HOST`, `DATABASE_URL`,
/// `KUBECONFIG`, `AWS_ACCESS_KEY_ID` and `SSH_AUTH_SOCK`.
///
/// Entries must be upper-case. `is_allowed_env_name` compares them
/// ASCII-case-insensitively.
const ALLOWED_ENV_NAMES: &[&str] = &[
    // Universal essentials.
    "PATH", "HOME", "USER", "SHELL", "LANG", "TZ", "TERM", "TMPDIR", "TMP", "TEMP",
    // Python / virtualenv.
    "VIRTUAL_ENV", "PYTHONUNBUFFERED", "PYTHONDONTWRITEBYTECODE",
    // Windows essentials: child processes fail to start without these.
    "USERNAME", "USERPROFILE", "HOMEDRIVE", "HOMEPATH", "SYSTEMROOT", "SYSTEMDRIVE",
    "WINDIR", "COMSPEC", "PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER",
    "NUMBER_OF_PROCESSORS", "OS", "PATHEXT",
];
47
/// Name prefixes that are always forwarded. Currently this is only the `LC_*`
/// locale family (`LC_ALL`, `LC_CTYPE`, ...). Entries are upper-case, like
/// `ALLOWED_ENV_NAMES`.
const ALLOWED_ENV_PREFIXES: &[&str] = &["LC_"];
50
51fn is_allowed_env_name(name: &str) -> bool {
52    let upper = name.to_ascii_uppercase();
53    if ALLOWED_ENV_NAMES.iter().any(|n| *n == upper) {
54        return true;
55    }
56    ALLOWED_ENV_PREFIXES.iter().any(|p| upper.starts_with(p))
57}
58
/// Outcome of running an external tool through this module.
#[derive(Debug)]
pub struct ToolResult {
    /// `true` only when the tool ran to completion and exited successfully.
    pub success: bool,
    /// Captured standard output. Empty if the tool never started or timed out.
    pub stdout: String,
    /// Captured standard error, or a diagnostic written by this module itself
    /// (tool not found, spawn/wait failure, timeout).
    pub stderr: String,
    /// `true` when the tool was killed for exceeding its time limit.
    pub timed_out: bool,
}
67
68/// Run a command in a sandboxed environment.
69/// - Working directory set to `cwd`
70/// - Timeout enforced (default 30s)
71/// - Sensitive env vars stripped (any *API_KEY, *SECRET, *TOKEN, etc.)
72pub fn run_tool(program: &str, args: &[&str], cwd: &Path) -> ToolResult {
73    run_tool_with_timeout(program, args, cwd, DEFAULT_TIMEOUT_SECS)
74}
75
76pub fn run_tool_with_timeout(
77    program: &str,
78    args: &[&str],
79    cwd: &Path,
80    timeout_secs: u64,
81) -> ToolResult {
82    // Check if tool exists
83    if !tool_exists(program) {
84        return ToolResult {
85            success: false,
86            stdout: String::new(),
87            stderr: format!("{} not found in PATH", program),
88            timed_out: false,
89        };
90    }
91
92    let mut cmd = Command::new(program);
93    cmd.args(args)
94        .current_dir(cwd)
95        .stdout(Stdio::piped())
96        .stderr(Stdio::piped());
97
98    // Strip env to allowlist. We snapshot the parent env, env_clear() the
99    // child, then re-add only the entries that pass `is_allowed_env_name`.
100    // Anything else (API keys, tokens, OLLAMA_HOST, DATABASE_URL,
101    // KUBECONFIG, SSH_AUTH_SOCK, …) is dropped.
102    let allowed: Vec<(String, String)> = std::env::vars()
103        .filter(|(k, _)| is_allowed_env_name(k))
104        .collect();
105    cmd.env_clear();
106    for (k, v) in allowed {
107        cmd.env(k, v);
108    }
109
110    let mut child = match cmd.spawn() {
111        Ok(c) => c,
112        Err(e) => {
113            return ToolResult {
114                success: false,
115                stdout: String::new(),
116                stderr: format!("Failed to spawn {}: {}", program, e),
117                timed_out: false,
118            }
119        }
120    };
121
122    // Poll with timeout enforcement
123    let deadline = Instant::now() + Duration::from_secs(timeout_secs);
124    loop {
125        match child.try_wait() {
126            Ok(Some(status)) => {
127                let mut stdout = String::new();
128                let mut stderr = String::new();
129                if let Some(mut out) = child.stdout.take() {
130                    let _ = out.read_to_string(&mut stdout);
131                }
132                if let Some(mut err) = child.stderr.take() {
133                    let _ = err.read_to_string(&mut stderr);
134                }
135                return ToolResult {
136                    success: status.success(),
137                    stdout,
138                    stderr,
139                    timed_out: false,
140                };
141            }
142            Ok(None) => {
143                if Instant::now() >= deadline {
144                    let _ = child.kill();
145                    let _ = child.wait();
146                    return ToolResult {
147                        success: false,
148                        stdout: String::new(),
149                        stderr: format!("Killed: timed out after {}s", timeout_secs),
150                        timed_out: true,
151                    };
152                }
153                std::thread::sleep(Duration::from_millis(200));
154            }
155            Err(e) => {
156                return ToolResult {
157                    success: false,
158                    stdout: String::new(),
159                    stderr: format!("Failed to wait for {}: {}", program, e),
160                    timed_out: false,
161                }
162            }
163        }
164    }
165}
166
167/// Run a tool with network access denied (macOS sandbox-exec).
168/// Falls back to normal execution on non-macOS or if sandbox-exec unavailable.
169pub fn run_tool_sandboxed(
170    program: &str,
171    args: &[&str],
172    cwd: &Path,
173    timeout_secs: u64,
174    deny_network: bool,
175) -> ToolResult {
176    if !deny_network || !cfg!(target_os = "macos") || !tool_exists("sandbox-exec") {
177        return run_tool_with_timeout(program, args, cwd, timeout_secs);
178    }
179    let profile = "(version 1)\n(allow default)\n(deny network*)";
180    let mut sbox_args = vec!["-p", profile, program];
181    sbox_args.extend_from_slice(args);
182    run_tool_with_timeout("sandbox-exec", &sbox_args, cwd, timeout_secs)
183}
184
/// Validate that the relative path `relative` stays inside `root`, and return
/// `root.join(relative.trim())` on success.
///
/// Lexical checks (always applied) reject:
/// - empty or blank input and NUL bytes;
/// - a leading `/` or `\`, and absolute paths;
/// - Windows drive-letter prefixes (`C:foo`, `C:\foo`), on every platform;
/// - `..` used as a path *component*.
///
/// `..` inside a file name (`file..py`) is allowed.
///
/// Filesystem checks (applied only when `root` exists and can be
/// canonicalized):
/// - If the joined path exists, its canonical form must lie under the
///   canonical root. This catches a planted symlink anywhere in the chain.
/// - If it does not exist yet (e.g. a write target), the walk goes up to the
///   deepest ancestor that canonicalizes, and that ancestor must lie under the
///   canonical root. No unresolvable path passed on the way up (including the
///   leaf) may be a symlink. A dangling link such as
///   `root/k -> /outside/new_file` cannot be canonicalized; without this check
///   it would be accepted, and a later write through it would create a file
///   outside `root`. Dangling symlinks are rejected even when their target
///   would be inside `root` (conservative).
///
/// # Errors
///
/// Returns a human-readable message that includes the trimmed input.
pub fn validate_path_within(root: &Path, relative: &str) -> Result<std::path::PathBuf, String> {
    use std::path::Component;

    let trimmed = relative.trim();
    if trimmed.is_empty() {
        return Err("Empty path".to_string());
    }
    if trimmed.contains('\0') {
        return Err(format!("Null byte in path: {:?}", trimmed));
    }
    if trimmed.starts_with('/') || trimmed.starts_with('\\') {
        return Err(format!("Absolute path rejected: {}", trimmed));
    }
    // Windows drive letter, e.g. "C:foo" or "C:\foo", rejected on any platform.
    let bytes = trimmed.as_bytes();
    if bytes.len() >= 2 && bytes[1] == b':' && bytes[0].is_ascii_alphabetic() {
        return Err(format!("Drive-letter path rejected: {}", trimmed));
    }

    let rel_path = Path::new(trimmed);
    if rel_path.is_absolute() {
        return Err(format!("Absolute path rejected: {}", trimmed));
    }
    for comp in rel_path.components() {
        match comp {
            Component::ParentDir => {
                return Err(format!("Parent-dir traversal rejected: {}", trimmed));
            }
            Component::Prefix(_) | Component::RootDir => {
                return Err(format!("Absolute or prefixed path rejected: {}", trimmed));
            }
            Component::CurDir | Component::Normal(_) => {}
        }
    }

    let joined = root.join(rel_path);

    // If root doesn't exist on the filesystem, no symlink can have been
    // planted inside it, so the lexical checks above are sufficient.
    let canon_root = match std::fs::canonicalize(root) {
        Ok(c) => c,
        Err(_) => return Ok(joined),
    };

    let canon_joined = match std::fs::canonicalize(&joined) {
        Ok(c) => c,
        Err(_) => {
            // Walk up to the deepest ancestor that canonicalizes.
            let mut probe = joined.clone();
            loop {
                // `probe` could not be canonicalized. If it is still a symlink
                // (checked without following it), the link is dangling or
                // looping, and writing through it would land wherever it points.
                if probe.is_symlink() {
                    return Err(format!(
                        "Path escapes root directory (dangling symlink): {}",
                        trimmed
                    ));
                }
                match probe.parent() {
                    Some(parent) if parent != probe => {
                        probe = parent.to_path_buf();
                        if let Ok(c) = std::fs::canonicalize(&probe) {
                            break c;
                        }
                    }
                    _ => return Ok(joined),
                }
            }
        }
    };

    if !canon_joined.starts_with(&canon_root) {
        return Err(format!(
            "Path escapes root directory (symlink?): {}",
            trimmed
        ));
    }

    Ok(joined)
}
265
/// Report whether `program` can be executed.
///
/// - A program given as an absolute path, or containing a path separator
///   (`/`, or the platform's `MAIN_SEPARATOR`), is checked directly with
///   `Path::exists`, relative to this process's working directory.
/// - On Unix, a bare name is looked up natively in `PATH`. It is found when
///   some `PATH` entry contains a regular file of that name with any execute
///   bit set. The lookup no longer shells out to `which`: that binary is
///   missing from some minimal container images, where every tool would look
///   absent, and calling it cost one extra process spawn per tool run.
/// - On Windows, lookup is delegated to `where`, which also honours `PATHEXT`.
pub fn tool_exists(program: &str) -> bool {
    if program.is_empty() {
        return false;
    }
    let path = Path::new(program);
    // `/` is a valid separator on Windows too, so test it alongside MAIN_SEPARATOR.
    if path.is_absolute() || program.contains('/') || program.contains(std::path::MAIN_SEPARATOR)
    {
        return path.exists();
    }
    if cfg!(windows) {
        return Command::new("where")
            .arg(program)
            .output()
            .map(|o| o.status.success())
            .unwrap_or(false);
    }
    std::env::var_os("PATH").is_some_and(|paths| {
        std::env::split_paths(&paths).any(|dir| is_executable_file(&dir.join(program)))
    })
}

/// `true` if `candidate` (following symlinks) is a regular file that looks
/// executable. On Unix that means any execute bit is set; elsewhere any
/// regular file qualifies.
fn is_executable_file(candidate: &Path) -> bool {
    match std::fs::metadata(candidate) {
        Ok(meta) if meta.is_file() => {
            #[cfg(unix)]
            let executable = {
                use std::os::unix::fs::PermissionsExt;
                meta.permissions().mode() & 0o111 != 0
            };
            #[cfg(not(unix))]
            let executable = true;
            executable
        }
        _ => false,
    }
}
282
/// One diagnostic reported by an external lint tool.
#[derive(Debug, Clone)]
pub struct LintIssue {
    /// Path of the offending file, as printed by the tool.
    pub file: String,
    /// Line number as printed by the tool, if any.
    pub line: Option<u32>,
    /// Column number as printed by the tool, if any.
    pub column: Option<u32>,
    /// Severity label; `parse_ruff_output` always sets `"warning"`.
    pub severity: String,
    /// Human-readable message text.
    pub message: String,
    /// Rule code, e.g. `E501`.
    pub rule: String,
}

impl std::fmt::Display for LintIssue {
    /// Formats as `file:line: [rule] message`, or as `file: [rule] message`
    /// when no line is known. Column and severity are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.line {
            Some(line) => write!(
                f,
                "{}:{}: [{}] {}",
                self.file, line, self.rule, self.message
            ),
            None => write!(f, "{}: [{}] {}", self.file, self.rule, self.message),
        }
    }
}

/// Parse ruff's text output into structured lint issues.
///
/// A diagnostic line has the form `path:line:col: CODE message`, for example
/// `app/main.py:10:5: E501 Line too long (120 > 88)`. A line is accepted only
/// when its line and column fields are numeric. Anything else is skipped,
/// including summaries (`Found 2 errors.`), source-context lines
/// (`10 | def f(a: int, b: int):`) and tool errors such as
/// `error: Failed to parse x.py:1:2: ...`.
///
/// Windows paths with a drive letter (`C:\src\app.py:3:1: F401 ...`) are
/// supported. Every issue gets severity `"warning"`.
pub fn parse_ruff_output(output: &str) -> Vec<LintIssue> {
    output.lines().filter_map(parse_ruff_line).collect()
}

/// Parse one `path:line:col: CODE message` line; `None` if it is not one.
fn parse_ruff_line(raw: &str) -> Option<LintIssue> {
    let text = raw.trim_start();
    // A drive prefix (`C:\` or `C:/`) contains a colon that must not be taken
    // as the path/line separator. The third-byte check keeps one-letter file
    // names (`a:1:2: ...`) working.
    let bytes = text.as_bytes();
    let drive_len = if bytes.len() >= 3
        && bytes[0].is_ascii_alphabetic()
        && bytes[1] == b':'
        && matches!(bytes[2], b'\\' | b'/')
    {
        2
    } else {
        0
    };
    let (drive, tail) = text.split_at(drive_len);

    let mut fields = tail.splitn(4, ':');
    let path = fields.next()?;
    let line: u32 = fields.next()?.trim().parse().ok()?;
    let column: u32 = fields.next()?.trim().parse().ok()?;
    let rest = fields.next()?.trim();
    let (rule, message) = rest.split_once(' ').unwrap_or(("", rest));

    Some(LintIssue {
        file: format!("{}{}", drive, path).trim().to_string(),
        line: Some(line),
        column: Some(column),
        severity: "warning".into(),
        message: message.trim_start().to_string(),
        rule: rule.to_string(),
    })
}
332
/// Test counts extracted from a pytest run by `parse_pytest_output`.
#[derive(Debug)]
pub struct TestResult {
    /// Number of passed tests.
    pub passed: u32,
    /// Number of failed tests.
    pub failed: u32,
    /// Number of errors reported by pytest.
    pub errors: u32,
    /// The raw `stdout` and `stderr`, joined by a newline.
    pub output: String,
}

/// Parse pytest output into pass/fail/error counts.
///
/// The source of truth is pytest's final summary line. Examples:
/// - `===== 5 passed, 2 failed, 1 error in 3.2s =====` (default)
/// - `5 passed, 2 failed in 0.40s` (`-q`)
/// - a bare `3 passed`
///
/// The last such line in `stdout` is used; if `stdout` has none, `stderr` is
/// tried. Only that one line is parsed, so numbers inside assertion messages
/// or captured logs (`AssertionError: 2 errors found`) cannot leak into the
/// counts. There is no upper bound on the counts.
///
/// If no summary line exists (e.g. pytest crashed), `failed` and `errors` fall
/// back to counting short-summary lines such as
/// `FAILED tests/test_foo.py::test_bar - ...` and `ERROR tests/...::... - ...`.
/// In that case `passed` is 0.
pub fn parse_pytest_output(stdout: &str, stderr: &str) -> TestResult {
    let combined = format!("{}\n{}", stdout, stderr);
    let (passed, failed, errors) = find_pytest_summary(stdout)
        .or_else(|| find_pytest_summary(stderr))
        .unwrap_or_else(|| count_pytest_failure_lines(&combined));

    TestResult {
        passed,
        failed,
        errors,
        output: combined,
    }
}

/// Counts `(passed, failed, errors)` from the last pytest summary line in `text`.
fn find_pytest_summary(text: &str) -> Option<(u32, u32, u32)> {
    text.lines().rev().find_map(parse_pytest_summary_line)
}

/// Parse one candidate summary line into `(passed, failed, errors)`.
///
/// After removing surrounding `=` and a trailing ` in <duration>`, every
/// comma-separated segment must be `<count> <outcome>`, where outcome is a
/// known pytest word. Otherwise the line is not a summary and `None` is
/// returned.
fn parse_pytest_summary_line(line: &str) -> Option<(u32, u32, u32)> {
    let body = line.trim().trim_matches('=').trim();
    // Drop the duration: "... in 3.2s" or "... in 62.13s (0:01:02)".
    let counts = match body.rsplit_once(" in ") {
        Some((head, duration)) if duration.starts_with(|c: char| c.is_ascii_digit()) => head,
        _ => body,
    };
    if counts.is_empty() {
        return None;
    }

    let (mut passed, mut failed, mut errors) = (0u32, 0u32, 0u32);
    for segment in counts.split(',') {
        let (number, outcome) = segment.trim().split_once(' ')?;
        let number: u32 = number.parse().ok()?;
        match outcome.trim().trim_end_matches('.').to_ascii_lowercase().as_str() {
            "passed" => passed = passed.saturating_add(number),
            "failed" => failed = failed.saturating_add(number),
            "error" | "errors" => errors = errors.saturating_add(number),
            "skipped" | "xfailed" | "xpassed" | "warning" | "warnings" | "deselected"
            | "rerun" | "reruns" => {}
            _ => return None,
        }
    }
    Some((passed, failed, errors))
}

/// Fallback when no summary exists: count `FAILED x::y` / `ERROR x::y` lines
/// (case-insensitive). Returns `(0, failed, errors)`.
fn count_pytest_failure_lines(text: &str) -> (u32, u32, u32) {
    let mut failed = 0u32;
    let mut errors = 0u32;
    for line in text.lines() {
        let lower = line.trim().to_lowercase();
        if !lower.contains("::") {
            continue;
        }
        if lower.starts_with("failed ") {
            failed = failed.saturating_add(1);
        } else if lower.starts_with("error ") {
            errors = errors.saturating_add(1);
        }
    }
    (0, failed, errors)
}
408
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module that read or modify the process
    /// environment. Spawning a tool reads the environment, and
    /// `std::env::temp_dir` reads `TMPDIR`, while `test_env_var_stripping` calls
    /// `set_var`. The harness runs tests on parallel threads, and concurrent
    /// mutation plus reads of the environment is a data race.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire `ENV_LOCK`, recovering from poisoning. The guarded value is
    /// `()`, so a panic in another test leaves nothing inconsistent.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_parse_ruff_output() {
        let output =
            "app/main.py:10:5: E501 Line too long (120 > 88)\napp/main.py:25:1: F401 Unused import";
        let issues = parse_ruff_output(output);
        assert_eq!(issues.len(), 2);
        assert_eq!(issues[0].line, Some(10));
        assert_eq!(issues[0].rule, "E501");
        assert_eq!(issues[1].rule, "F401");
    }

    #[test]
    fn test_parse_pytest_summary_line() {
        let stdout = "5 passed, 2 failed, 1 error in 3.2s";
        let result = parse_pytest_output(stdout, "");
        assert_eq!(result.passed, 5);
        assert_eq!(result.failed, 2);
        assert_eq!(result.errors, 1);
    }

    #[test]
    fn test_parse_pytest_equals_format() {
        let stdout = "===== 3 passed in 0.5s =====";
        let result = parse_pytest_output(stdout, "");
        assert_eq!(result.passed, 3);
    }

    #[test]
    fn test_parse_pytest_q_format() {
        let stdout = "3 passed";
        let result = parse_pytest_output(stdout, "");
        assert_eq!(result.passed, 3);
    }

    // `ls` is a Unix tool, and the PATH lookup reads the environment.
    #[cfg(unix)]
    #[test]
    fn test_tool_exists() {
        let _env = env_lock();
        assert!(tool_exists("ls"));
        assert!(!tool_exists("nonexistent_tool_xyz_12345"));
    }

    #[test]
    fn test_validate_path_safe() {
        let root = std::path::Path::new("/tmp/test_root");
        assert!(validate_path_within(root, "app/main.py").is_ok());
        assert!(validate_path_within(root, "tests/test_foo.py").is_ok());
        assert!(validate_path_within(root, "README.md").is_ok());
    }

    #[test]
    fn test_validate_path_traversal_blocked() {
        let root = std::path::Path::new("/tmp/test_root");
        assert!(validate_path_within(root, "../etc/passwd").is_err());
        assert!(validate_path_within(root, "app/../../etc/shadow").is_err());
        assert!(validate_path_within(root, "/etc/passwd").is_err());
        assert!(validate_path_within(root, "\\windows\\system32").is_err());
        assert!(validate_path_within(root, "file\0.py").is_err());
    }

    // Set a test API key, run `env` through the sandbox, and verify the key is
    // stripped. Unix-only: it relies on the `env` tool and `/tmp`.
    #[cfg(unix)]
    #[test]
    fn test_env_var_stripping() {
        let _env = env_lock();
        // SAFETY: set_var/remove_var are unsafe because another thread may read
        // the environment at the same time. Every test in this module that
        // touches the environment holds ENV_LOCK, so none of them can race with
        // this. Tests outside this module are not covered by the lock.
        unsafe {
            std::env::set_var("TEST_API_KEY", "secret123");
        }
        let result = run_tool("env", &[], std::path::Path::new("/tmp"));
        // SAFETY: as above. Cleanup runs before the asserts, so a failure
        // cannot leave the variable set for later tests.
        unsafe {
            std::env::remove_var("TEST_API_KEY");
        }
        // Guard against a vacuous pass: if `env` never ran, stdout is empty.
        assert!(result.success, "env failed: {}", result.stderr);
        assert!(
            !result.stdout.contains("secret123"),
            "API key leaked to subprocess!"
        );
    }

    #[cfg(unix)]
    #[test]
    fn test_timeout_kills_process() {
        let _env = env_lock();
        let result = run_tool_with_timeout("sleep", &["30"], std::path::Path::new("/tmp"), 1);
        assert!(result.timed_out, "Process should have timed out");
        assert!(!result.success);
        assert!(result.stderr.contains("timed out"), "{}", result.stderr);
    }

    // Regression: the previous validator rejected `..` as a substring,
    // which blocked legitimate filenames (a false positive).
    #[test]
    fn test_validate_path_dotdot_in_filename_allowed() {
        let root = std::path::Path::new("/tmp/test_root");
        assert!(validate_path_within(root, "file..py").is_ok());
        assert!(validate_path_within(root, "a..b.txt").is_ok());
        assert!(validate_path_within(root, "my.backup..tar").is_ok());
    }

    // Regression: the prior validator missed Windows drive prefixes on Linux
    // (`C:\foo` slipped through `contains("..")`).
    #[test]
    fn test_validate_path_windows_drive_rejected() {
        let root = std::path::Path::new("/tmp/test_root");
        assert!(validate_path_within(root, "C:\\windows\\system32").is_err());
        assert!(validate_path_within(root, "D:foo").is_err());
    }

    // Regression: a planted symlink inside root that points outside must be
    // detected via canonicalize. Unix-only because it uses the Unix symlink
    // primitive.
    #[cfg(unix)]
    #[test]
    fn test_validate_path_symlink_escape_rejected() {
        use std::os::unix::fs::symlink;

        let root = {
            // temp_dir() reads TMPDIR, so hold the env lock while calling it.
            let _env = env_lock();
            std::env::temp_dir().join(format!("bcf-symlink-{}", std::process::id()))
        };
        std::fs::create_dir_all(&root).unwrap();
        let link = root.join("escape");
        let _ = std::fs::remove_file(&link);
        symlink("/etc", &link).unwrap();

        // Reading the symlinked target via `escape/passwd` would canonicalize
        // outside root, so validation must fail.
        let result = validate_path_within(&root, "escape/passwd");

        // Clean up before asserting, so a failure doesn't leave files behind.
        let _ = std::fs::remove_file(&link);
        let _ = std::fs::remove_dir(&root);

        assert!(
            result.is_err(),
            "planted-symlink escape should be rejected, got Ok({:?})",
            result
        );
    }

    // Allowlist drift detector: fails loudly if someone removes an
    // expected-allowed env var or un-strips a sensitive one.
    #[test]
    fn test_env_allowlist_keeps_essentials() {
        assert!(is_allowed_env_name("PATH"));
        assert!(is_allowed_env_name("HOME"));
        assert!(is_allowed_env_name("USER"));
        assert!(is_allowed_env_name("VIRTUAL_ENV"));
        assert!(is_allowed_env_name("LC_ALL"));
        assert!(is_allowed_env_name("LC_CTYPE"));
        // Case-insensitive
        assert!(is_allowed_env_name("path"));
    }

    #[test]
    fn test_env_allowlist_blocks_known_secrets() {
        // Provider API keys
        assert!(!is_allowed_env_name("ANTHROPIC_API_KEY"));
        assert!(!is_allowed_env_name("XAI_API_KEY"));
        assert!(!is_allowed_env_name("BRAVE_API_KEY"));
        assert!(!is_allowed_env_name("OPENAI_API_KEY"));
        assert!(!is_allowed_env_name("GH_TOKEN"));
        assert!(!is_allowed_env_name("GITHUB_TOKEN"));
        assert!(!is_allowed_env_name("HF_TOKEN"));
        // Cloud creds (the substring blocklist missed these)
        assert!(!is_allowed_env_name("AWS_ACCESS_KEY_ID"));
        assert!(!is_allowed_env_name("AWS_SECRET_ACCESS_KEY"));
        // Network pivots / metadata
        assert!(!is_allowed_env_name("OLLAMA_HOST"));
        assert!(!is_allowed_env_name("KUBECONFIG"));
        assert!(!is_allowed_env_name("SSH_AUTH_SOCK"));
        // Database URLs (often embed credentials)
        assert!(!is_allowed_env_name("DATABASE_URL"));
        assert!(!is_allowed_env_name("POSTGRES_URL"));
        assert!(!is_allowed_env_name("REDIS_URL"));
    }
}