Skip to main content

lean_ctx/core/
sandbox.rs

1use std::collections::HashMap;
2use std::process::Command;
3
/// Outcome of running a single snippet through the sandbox.
#[derive(Clone, Debug)]
pub struct SandboxResult {
    /// Captured standard output of the snippet (empty when it could not be run).
    pub stdout: String,
    /// Captured standard error, or a diagnostic message when the snippet could not be run.
    pub stderr: String,
    /// Exit code of the process; `1` is reported when no code is available or on failure to run.
    pub exit_code: i32,
    /// Language identifier exactly as supplied by the caller.
    pub language: String,
    /// Wall-clock duration of the run in milliseconds.
    pub duration_ms: u64,
}
12
/// Default wall-clock limit, in seconds, used when the caller supplies no timeout.
const TIMEOUT_SECS: u64 = 30;
/// Largest stdout size, in bytes, returned without truncation (32 KiB).
const MAX_OUTPUT_BYTES: usize = 32 * 1024;
15
16pub fn execute(language: &str, code: &str, timeout_secs: Option<u64>) -> SandboxResult {
17    let timeout = timeout_secs.unwrap_or(TIMEOUT_SECS);
18    let start = std::time::Instant::now();
19
20    let Some(runtime) = resolve_runtime(language) else {
21        return SandboxResult {
22                stdout: String::new(),
23                stderr: format!("Unsupported language: {language}. Supported: javascript, typescript, python, shell, ruby, go, rust, php, perl, r, elixir"),
24                exit_code: 1,
25                language: language.to_string(),
26                duration_ms: 0,
27            };
28    };
29
30    let result = if runtime.needs_temp_file {
31        execute_with_file(&runtime, code, timeout)
32    } else {
33        execute_with_stdin(&runtime, code, timeout)
34    };
35
36    let duration_ms = start.elapsed().as_millis() as u64;
37
38    match result {
39        Ok((stdout, stderr, code)) => SandboxResult {
40            stdout: truncate_output(&stdout),
41            stderr: truncate_smart(&stderr, 2048),
42            exit_code: code,
43            language: language.to_string(),
44            duration_ms,
45        },
46        Err(e) => SandboxResult {
47            stdout: String::new(),
48            stderr: format!("Execution error: {e}"),
49            exit_code: 1,
50            language: language.to_string(),
51            duration_ms,
52        },
53    }
54}
55
56pub fn batch_execute(items: &[(String, String)]) -> Vec<SandboxResult> {
57    items
58        .iter()
59        .map(|(lang, code)| execute(lang, code, None))
60        .collect()
61}
62
/// Describes how to launch the interpreter or compiler for one language.
struct RuntimeConfig {
    /// Program to launch; the value `"rustc_script"` selects the compile-then-run Rust path.
    command: String,
    /// Arguments placed before the program text or source-file path.
    args: Vec<String>,
    /// `true` when the code must be written to a file instead of passed as an argument.
    needs_temp_file: bool,
    /// Extension (without the dot) given to the temporary source file.
    file_extension: String,
    /// Extra environment variables set for the child on top of the allowlisted ones.
    env: HashMap<String, String>,
}
70
71fn resolve_runtime(language: &str) -> Option<RuntimeConfig> {
72    let lang = language.to_lowercase();
73    let lang = lang.as_str();
74
75    match lang {
76        "javascript" | "js" | "node" => Some(RuntimeConfig {
77            command: find_binary(&["bun", "node"])?,
78            args: vec!["-e".to_string()],
79            needs_temp_file: false,
80            file_extension: "js".to_string(),
81            env: HashMap::new(),
82        }),
83        "typescript" | "ts" => Some(RuntimeConfig {
84            command: find_binary(&["bun", "npx"])?,
85            args: if which_exists("bun") {
86                vec!["-e".to_string()]
87            } else {
88                vec!["tsx".to_string(), "-e".to_string()]
89            },
90            needs_temp_file: false,
91            file_extension: "ts".to_string(),
92            env: HashMap::new(),
93        }),
94        "python" | "py" => Some(RuntimeConfig {
95            command: find_binary(&["python3", "python"])?,
96            args: vec!["-c".to_string()],
97            needs_temp_file: false,
98            file_extension: "py".to_string(),
99            env: HashMap::from([("PYTHONDONTWRITEBYTECODE".into(), "1".into())]),
100        }),
101        "shell" | "bash" | "sh" => {
102            #[cfg(target_os = "windows")]
103            {
104                Some(RuntimeConfig {
105                    command: "cmd".to_string(),
106                    args: vec!["/C".to_string()],
107                    needs_temp_file: false,
108                    file_extension: "bat".to_string(),
109                    env: HashMap::new(),
110                })
111            }
112            #[cfg(not(target_os = "windows"))]
113            {
114                Some(RuntimeConfig {
115                    command: find_binary(&["bash", "sh"])?,
116                    args: vec!["-c".to_string()],
117                    needs_temp_file: false,
118                    file_extension: "sh".to_string(),
119                    env: HashMap::new(),
120                })
121            }
122        }
123        "ruby" | "rb" => Some(RuntimeConfig {
124            command: find_binary(&["ruby"])?,
125            args: vec!["-e".to_string()],
126            needs_temp_file: false,
127            file_extension: "rb".to_string(),
128            env: HashMap::new(),
129        }),
130        "go" | "golang" => Some(RuntimeConfig {
131            command: find_binary(&["go"])?,
132            args: vec!["run".to_string()],
133            needs_temp_file: true,
134            file_extension: "go".to_string(),
135            env: HashMap::new(),
136        }),
137        "rust" | "rs" => Some(RuntimeConfig {
138            command: "rustc_script".to_string(),
139            args: vec![],
140            needs_temp_file: true,
141            file_extension: "rs".to_string(),
142            env: HashMap::new(),
143        }),
144        "php" => Some(RuntimeConfig {
145            command: find_binary(&["php"])?,
146            args: vec!["-r".to_string()],
147            needs_temp_file: false,
148            file_extension: "php".to_string(),
149            env: HashMap::new(),
150        }),
151        "perl" | "pl" => Some(RuntimeConfig {
152            command: find_binary(&["perl"])?,
153            args: vec!["-e".to_string()],
154            needs_temp_file: false,
155            file_extension: "pl".to_string(),
156            env: HashMap::new(),
157        }),
158        "r" => Some(RuntimeConfig {
159            command: find_binary(&["Rscript"])?,
160            args: vec!["-e".to_string()],
161            needs_temp_file: false,
162            file_extension: "R".to_string(),
163            env: HashMap::new(),
164        }),
165        "elixir" | "ex" => Some(RuntimeConfig {
166            command: find_binary(&["elixir"])?,
167            args: vec!["-e".to_string()],
168            needs_temp_file: false,
169            file_extension: "exs".to_string(),
170            env: HashMap::new(),
171        }),
172        _ => None,
173    }
174}
175
/// Names of the only host environment variables that are copied into a sandboxed child.
const SANDBOX_ENV_ALLOWLIST: &[&str] = &[
    // Binary lookup and user identity.
    "PATH",
    "HOME",
    "USER",
    // Locale and terminal settings.
    "LANG",
    "LC_ALL",
    "TERM",
    // Temporary-directory locations.
    "TMPDIR",
    "TMP",
    "TEMP",
    // Windows system directories.
    "SYSTEMROOT",
    "WINDIR",
];
189
190fn apply_sandbox_env(cmd: &mut Command, runtime: &RuntimeConfig) {
191    cmd.env_clear();
192    for key in SANDBOX_ENV_ALLOWLIST {
193        if let Ok(val) = std::env::var(key) {
194            cmd.env(key, val);
195        }
196    }
197    for (k, v) in &runtime.env {
198        cmd.env(k, v);
199    }
200    cmd.env("LEAN_CTX_SANDBOX", "1");
201}
202
203fn execute_with_stdin(
204    runtime: &RuntimeConfig,
205    code: &str,
206    timeout: u64,
207) -> Result<(String, String, i32), String> {
208    let mut cmd = Command::new(&runtime.command);
209    for arg in &runtime.args {
210        cmd.arg(arg);
211    }
212    cmd.arg(code);
213    apply_sandbox_env(&mut cmd, runtime);
214    cmd.stdout(std::process::Stdio::piped());
215    cmd.stderr(std::process::Stdio::piped());
216
217    let child = cmd
218        .spawn()
219        .map_err(|e| format!("Failed to spawn {}: {e}", runtime.command))?;
220
221    let output = wait_with_timeout(child, timeout)?;
222    Ok((
223        crate::shell::decode_output(&output.stdout),
224        crate::shell::decode_output(&output.stderr),
225        output.status.code().unwrap_or(1),
226    ))
227}
228
229fn execute_with_file(
230    runtime: &RuntimeConfig,
231    code: &str,
232    timeout: u64,
233) -> Result<(String, String, i32), String> {
234    let tmp_dir = std::env::temp_dir().join("lean-ctx-sandbox");
235    let _ = std::fs::create_dir_all(&tmp_dir);
236
237    let suffix = format!(".{}", runtime.file_extension);
238    let tmp = tempfile::Builder::new()
239        .prefix("exec_")
240        .suffix(&suffix)
241        .tempfile_in(&tmp_dir)
242        .map_err(|e| format!("Failed to create temp file: {e}"))?;
243    let file_path = tmp.into_temp_path();
244
245    std::fs::write(&file_path, code).map_err(|e| format!("Failed to write temp file: {e}"))?;
246
247    let result = if runtime.command == "rustc_script" {
248        execute_rust(&file_path, timeout)
249    } else {
250        let mut cmd = Command::new(&runtime.command);
251        for arg in &runtime.args {
252            cmd.arg(arg);
253        }
254        cmd.arg(&file_path);
255        apply_sandbox_env(&mut cmd, runtime);
256        cmd.stdout(std::process::Stdio::piped());
257        cmd.stderr(std::process::Stdio::piped());
258
259        let child = cmd
260            .spawn()
261            .map_err(|e| format!("Failed to spawn {}: {e}", runtime.command))?;
262        let output = wait_with_timeout(child, timeout)?;
263        Ok((
264            crate::shell::decode_output(&output.stdout),
265            crate::shell::decode_output(&output.stderr),
266            output.status.code().unwrap_or(1),
267        ))
268    };
269
270    let _ = std::fs::remove_file(&file_path);
271    result
272}
273
274fn execute_rust(
275    source_path: &std::path::Path,
276    timeout: u64,
277) -> Result<(String, String, i32), String> {
278    let binary_path = source_path.with_extension("");
279
280    let mut compile_cmd = Command::new("rustc");
281    compile_cmd.arg(source_path).arg("-o").arg(&binary_path);
282    compile_cmd.env_clear();
283    for key in SANDBOX_ENV_ALLOWLIST {
284        if let Ok(val) = std::env::var(key) {
285            compile_cmd.env(key, val);
286        }
287    }
288    compile_cmd.env("LEAN_CTX_SANDBOX", "1");
289
290    let compile = compile_cmd
291        .output()
292        .map_err(|e| format!("rustc not found: {e}"))?;
293
294    if !compile.status.success() {
295        let stderr = crate::shell::decode_output(&compile.stderr);
296        let _ = std::fs::remove_file(&binary_path);
297        return Ok((String::new(), stderr, compile.status.code().unwrap_or(1)));
298    }
299
300    let mut run_cmd = Command::new(&binary_path);
301    run_cmd.env_clear();
302    for key in SANDBOX_ENV_ALLOWLIST {
303        if let Ok(val) = std::env::var(key) {
304            run_cmd.env(key, val);
305        }
306    }
307    run_cmd.env("LEAN_CTX_SANDBOX", "1");
308    run_cmd.stdout(std::process::Stdio::piped());
309    run_cmd.stderr(std::process::Stdio::piped());
310
311    let child = run_cmd
312        .spawn()
313        .map_err(|e| format!("Failed to run compiled binary: {e}"))?;
314
315    let output = wait_with_timeout(child, timeout)?;
316    let _ = std::fs::remove_file(&binary_path);
317
318    Ok((
319        crate::shell::decode_output(&output.stdout),
320        crate::shell::decode_output(&output.stderr),
321        output.status.code().unwrap_or(1),
322    ))
323}
324
/// Waits for `child` to exit, killing it once `timeout_secs` seconds have elapsed.
///
/// Piped stdout/stderr are drained on background threads while the exit status is
/// polled. Without that, a child that writes more than the OS pipe buffer blocks on
/// its write, never exits, and is misreported as a timeout.
///
/// Returns the exit status with everything the child wrote to its piped streams
/// (empty for streams that were not piped).
///
/// # Errors
///
/// Returns a message when the deadline passes (the child is killed and reaped),
/// when polling the child fails, or when reading one of its pipes fails.
fn wait_with_timeout(
    child: std::process::Child,
    timeout_secs: u64,
) -> Result<std::process::Output, String> {
    let mut child = child;
    // `None` means the deadline is not representable (enormous timeout); treat that
    // as "no deadline" instead of panicking on `Instant` overflow.
    let deadline =
        std::time::Instant::now().checked_add(std::time::Duration::from_secs(timeout_secs));

    // Mirror `Child::wait`: close our end of a piped stdin so the child sees EOF.
    drop(child.stdin.take());
    let stdout_reader = child.stdout.take().map(spawn_pipe_reader);
    let stderr_reader = child.stderr.take().map(spawn_pipe_reader);

    let status = loop {
        match child.try_wait() {
            Ok(Some(status)) => break status,
            Ok(None) => {
                if deadline.is_some_and(|d| std::time::Instant::now() > d) {
                    let _ = child.kill();
                    // Reap the killed process so it does not linger as a zombie. The
                    // reader threads are left detached; they end when the pipes close.
                    let _ = child.wait();
                    return Err(format!("Execution timed out after {timeout_secs}s"));
                }
                std::thread::sleep(std::time::Duration::from_millis(50));
            }
            Err(e) => return Err(e.to_string()),
        }
    };

    Ok(std::process::Output {
        status,
        stdout: join_pipe_reader(stdout_reader)?,
        stderr: join_pipe_reader(stderr_reader)?,
    })
}

/// Reads `pipe` to end-of-file on a background thread and hands back the bytes.
fn spawn_pipe_reader<R>(mut pipe: R) -> std::thread::JoinHandle<std::io::Result<Vec<u8>>>
where
    R: std::io::Read + Send + 'static,
{
    std::thread::spawn(move || {
        let mut buf = Vec::new();
        std::io::Read::read_to_end(&mut pipe, &mut buf)?;
        Ok(buf)
    })
}

/// Collects the bytes gathered by a pipe-reader thread; a stream that was not piped yields no bytes.
fn join_pipe_reader(
    reader: Option<std::thread::JoinHandle<std::io::Result<Vec<u8>>>>,
) -> Result<Vec<u8>, String> {
    match reader {
        None => Ok(Vec::new()),
        Some(handle) => handle
            .join()
            .map_err(|_| "output reader thread panicked".to_string())?
            .map_err(|e| e.to_string()),
    }
}
346
347fn find_binary(candidates: &[&str]) -> Option<String> {
348    for name in candidates {
349        if which_exists(name) {
350            return Some(name.to_string());
351        }
352    }
353    None
354}
355
/// Reports whether the platform's lookup tool can locate `name`.
///
/// Runs `where <name>` on Windows and `which <name>` elsewhere with output
/// discarded; the result is `true` only when the tool ran and exited successfully.
fn which_exists(name: &str) -> bool {
    let locator = if cfg!(target_os = "windows") {
        "where"
    } else {
        "which"
    };

    let probe = Command::new(locator)
        .arg(name)
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status();

    matches!(probe, Ok(status) if status.success())
}
373
374fn truncate_output(output: &str) -> String {
375    if output.len() <= MAX_OUTPUT_BYTES {
376        return output.to_string();
377    }
378    truncate_smart(output, MAX_OUTPUT_BYTES)
379}
380
/// Shortens `output` to roughly `max_bytes` by keeping its head and tail around a marker.
///
/// Text that already fits is returned unchanged. Otherwise the first and last
/// `max_bytes / 2` bytes are kept, with each cut moved to the nearest UTF-8 character
/// boundary so slicing never panics on multi-byte text. The marker is added on top of
/// the kept text, so the result can be somewhat longer than `max_bytes`.
fn truncate_smart(output: &str, max_bytes: usize) -> String {
    if output.len() <= max_bytes {
        return output.to_string();
    }

    let lines: Vec<&str> = output.lines().collect();
    let total_lines = lines.len();

    let head_count = (total_lines * 60) / 100;
    let tail_count = total_lines - head_count;

    let (head_lines, tail_lines) = lines.split_at(head_count);
    let head_text = head_lines.join("\n");
    let tail_text = tail_lines.join("\n");

    if head_text.len() + tail_text.len() + 100 > max_bytes {
        let half = max_bytes / 2;

        // `output.len() > max_bytes >= 2 * half`, so both cut points are in range and
        // the two slices cannot overlap. Shrink each slice to a character boundary;
        // offsets 0 and `output.len()` are always boundaries, so the loops terminate.
        let mut head_end = half;
        while !output.is_char_boundary(head_end) {
            head_end -= 1;
        }
        let mut tail_start = output.len() - half;
        while !output.is_char_boundary(tail_start) {
            tail_start += 1;
        }

        let h = &output[..head_end];
        let t = &output[tail_start..];
        let skipped = output.len() - h.len() - t.len();
        return format!("{h}\n\n... [{skipped} bytes truncated — showing head + tail] ...\n\n{t}");
    }

    // NOTE(review): head and tail together cover every line, so this path drops no
    // lines (`skipped_lines` is always 0). It is only reached when re-joining the
    // lines with "\n" shrinks the text enough to fit, e.g. CRLF line endings.
    let skipped_lines = total_lines - head_count - tail_count;
    let skipped_bytes = output.len() - head_text.len() - tail_text.len();
    format!(
        "{head_text}\n\n... [{skipped_lines} lines / {skipped_bytes} bytes truncated — showing first {head_count} + last {tail_count} lines] ...\n\n{tail_text}"
    )
}
417
/// Canonical names of the languages the sandbox can run, in display order.
///
/// Aliases accepted by the runtime lookup (for example `py` or `js`) are not listed.
pub fn supported_languages() -> &'static [&'static str] {
    static LANGUAGES: [&str; 11] = [
        "javascript",
        "typescript",
        "python",
        "shell",
        "ruby",
        "go",
        "rust",
        "php",
        "perl",
        "r",
        "elixir",
    ];
    &LANGUAGES
}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Tests that need a Python interpreter return early when none is installed.
    fn python_available() -> bool {
        find_binary(&["python3", "python"]).is_some()
    }

    #[test]
    fn execute_python_hello() {
        if !python_available() {
            return;
        }
        let result = execute("python", "print('hello sandbox')", None);
        assert_eq!(result.exit_code, 0);
        assert!(result.stdout.contains("hello sandbox"));
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn execute_shell_echo() {
        let result = execute("shell", "echo 'test output'", None);
        assert_eq!(result.exit_code, 0);
        assert!(result.stdout.contains("test output"));
    }

    #[test]
    fn execute_unsupported_language() {
        let result = execute("brainfuck", "++++", None);
        assert_eq!(result.exit_code, 1);
        assert!(result.stderr.contains("Unsupported language"));
    }

    #[test]
    fn execute_python_error() {
        if !python_available() {
            return;
        }
        let result = execute("python", "raise ValueError('test error')", None);
        assert_ne!(result.exit_code, 0);
        assert!(result.stderr.contains("ValueError"));
    }

    #[test]
    fn execute_with_timeout() {
        if !python_available() {
            return;
        }
        let result = execute("python", "import time; time.sleep(60)", Some(1));
        assert_ne!(result.exit_code, 0);
        // Make sure the failure is the timeout itself, not some other error.
        assert!(result.stderr.contains("timed out"));
    }

    #[test]
    fn truncate_preserves_head_and_tail() {
        let lines: Vec<String> = (0..100)
            .map(|i| format!("line {i}: some content here"))
            .collect();
        let output = lines.join("\n");
        let truncated = truncate_smart(&output, 500);
        assert!(truncated.contains("line 0:"));
        assert!(truncated.contains("line 99:"));
        assert!(truncated.contains("truncated"));
    }

    #[test]
    fn supported_languages_list() {
        let langs = supported_languages();
        assert!(langs.contains(&"python"));
        assert!(langs.contains(&"javascript"));
        assert!(langs.contains(&"rust"));
        assert_eq!(langs.len(), 11);
    }

    #[test]
    fn sandbox_env_is_set() {
        if !python_available() {
            return;
        }
        let result = execute(
            "python",
            "import os; print(os.environ.get('LEAN_CTX_SANDBOX', 'missing'))",
            None,
        );
        assert_eq!(result.exit_code, 0);
        assert!(result.stdout.contains('1'));
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn batch_execute_multiple() {
        // The first item needs Python; skip like the other Python-dependent tests
        // instead of failing on machines without an interpreter.
        if !python_available() {
            return;
        }
        let items = vec![
            ("python".to_string(), "print(1+1)".to_string()),
            ("shell".to_string(), "echo hello".to_string()),
        ];
        let results = batch_execute(&items);
        assert_eq!(results.len(), 2);
        assert!(results[0].stdout.contains('2'));
        assert!(results[1].stdout.contains("hello"));
    }
}