Skip to main content

lean_ctx/core/
sandbox.rs

1use std::collections::HashMap;
2use std::process::Command;
3
/// Outcome of running one snippet through [`execute`].
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare whole results
/// instead of checking field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SandboxResult {
    /// Captured standard output, truncated to the module's stdout byte budget.
    /// Empty when the language is unsupported or execution failed to start.
    pub stdout: String,
    /// Captured standard error (truncated), or a diagnostic produced by the
    /// sandbox itself (`Unsupported language: ...` / `Execution error: ...`).
    pub stderr: String,
    /// Exit code of the child process; `1` when the sandbox could not run the
    /// snippet or the child reported no exit code.
    pub exit_code: i32,
    /// The language identifier exactly as the caller supplied it.
    pub language: String,
    /// Wall-clock time spent executing, in milliseconds (`0` when the
    /// language was rejected before anything ran).
    pub duration_ms: u64,
}
12
/// Timeout, in seconds, applied when the caller does not supply one.
const TIMEOUT_SECS: u64 = 30;
/// Byte budget for captured stdout before it is truncated (32 KiB).
const MAX_OUTPUT_BYTES: usize = 32 * 1024;
15
16pub fn execute(language: &str, code: &str, timeout_secs: Option<u64>) -> SandboxResult {
17    let timeout = timeout_secs.unwrap_or(TIMEOUT_SECS);
18    let start = std::time::Instant::now();
19
20    let Some(runtime) = resolve_runtime(language) else {
21        return SandboxResult {
22                stdout: String::new(),
23                stderr: format!("Unsupported language: {language}. Supported: javascript, typescript, python, shell, ruby, go, rust, php, perl, r, elixir"),
24                exit_code: 1,
25                language: language.to_string(),
26                duration_ms: 0,
27            };
28    };
29
30    let result = if runtime.needs_temp_file {
31        execute_with_file(&runtime, code, timeout)
32    } else {
33        execute_with_stdin(&runtime, code, timeout)
34    };
35
36    let duration_ms = start.elapsed().as_millis() as u64;
37
38    match result {
39        Ok((stdout, stderr, code)) => SandboxResult {
40            stdout: truncate_output(&stdout),
41            stderr: truncate_smart(&stderr, 2048),
42            exit_code: code,
43            language: language.to_string(),
44            duration_ms,
45        },
46        Err(e) => SandboxResult {
47            stdout: String::new(),
48            stderr: format!("Execution error: {e}"),
49            exit_code: 1,
50            language: language.to_string(),
51            duration_ms,
52        },
53    }
54}
55
56pub fn batch_execute(items: &[(String, String)]) -> Vec<SandboxResult> {
57    items
58        .iter()
59        .map(|(lang, code)| execute(lang, code, None))
60        .collect()
61}
62
/// How to launch the interpreter or compiler for one language.
#[derive(Debug, Clone)]
struct RuntimeConfig {
    /// Program handed to `Command::new`, or the sentinel `"rustc_script"`,
    /// which the file-based executor treats as "compile with rustc, then run".
    command: String,
    /// Arguments placed before the snippet (or before the temp-file path).
    args: Vec<String>,
    /// When `true` the snippet is written to a temp file and the path is
    /// passed; otherwise the snippet text itself is the final argument.
    needs_temp_file: bool,
    /// Extension given to the temp file when one is written.
    file_extension: String,
    /// Extra environment variables set on the child process.
    env: HashMap<String, String>,
}
70
71fn resolve_runtime(language: &str) -> Option<RuntimeConfig> {
72    let lang = language.to_lowercase();
73    let lang = lang.as_str();
74
75    match lang {
76        "javascript" | "js" | "node" => Some(RuntimeConfig {
77            command: find_binary(&["bun", "node"])?,
78            args: vec!["-e".to_string()],
79            needs_temp_file: false,
80            file_extension: "js".to_string(),
81            env: HashMap::new(),
82        }),
83        "typescript" | "ts" => Some(RuntimeConfig {
84            command: find_binary(&["bun", "npx"])?,
85            args: if which_exists("bun") {
86                vec!["-e".to_string()]
87            } else {
88                vec!["tsx".to_string(), "-e".to_string()]
89            },
90            needs_temp_file: false,
91            file_extension: "ts".to_string(),
92            env: HashMap::new(),
93        }),
94        "python" | "py" => Some(RuntimeConfig {
95            command: find_binary(&["python3", "python"])?,
96            args: vec!["-c".to_string()],
97            needs_temp_file: false,
98            file_extension: "py".to_string(),
99            env: HashMap::from([("PYTHONDONTWRITEBYTECODE".into(), "1".into())]),
100        }),
101        "shell" | "bash" | "sh" => {
102            #[cfg(target_os = "windows")]
103            {
104                Some(RuntimeConfig {
105                    command: "cmd".to_string(),
106                    args: vec!["/C".to_string()],
107                    needs_temp_file: false,
108                    file_extension: "bat".to_string(),
109                    env: HashMap::new(),
110                })
111            }
112            #[cfg(not(target_os = "windows"))]
113            {
114                Some(RuntimeConfig {
115                    command: find_binary(&["bash", "sh"])?,
116                    args: vec!["-c".to_string()],
117                    needs_temp_file: false,
118                    file_extension: "sh".to_string(),
119                    env: HashMap::new(),
120                })
121            }
122        }
123        "ruby" | "rb" => Some(RuntimeConfig {
124            command: find_binary(&["ruby"])?,
125            args: vec!["-e".to_string()],
126            needs_temp_file: false,
127            file_extension: "rb".to_string(),
128            env: HashMap::new(),
129        }),
130        "go" | "golang" => Some(RuntimeConfig {
131            command: find_binary(&["go"])?,
132            args: vec!["run".to_string()],
133            needs_temp_file: true,
134            file_extension: "go".to_string(),
135            env: HashMap::new(),
136        }),
137        "rust" | "rs" => Some(RuntimeConfig {
138            command: "rustc_script".to_string(),
139            args: vec![],
140            needs_temp_file: true,
141            file_extension: "rs".to_string(),
142            env: HashMap::new(),
143        }),
144        "php" => Some(RuntimeConfig {
145            command: find_binary(&["php"])?,
146            args: vec!["-r".to_string()],
147            needs_temp_file: false,
148            file_extension: "php".to_string(),
149            env: HashMap::new(),
150        }),
151        "perl" | "pl" => Some(RuntimeConfig {
152            command: find_binary(&["perl"])?,
153            args: vec!["-e".to_string()],
154            needs_temp_file: false,
155            file_extension: "pl".to_string(),
156            env: HashMap::new(),
157        }),
158        "r" => Some(RuntimeConfig {
159            command: find_binary(&["Rscript"])?,
160            args: vec!["-e".to_string()],
161            needs_temp_file: false,
162            file_extension: "R".to_string(),
163            env: HashMap::new(),
164        }),
165        "elixir" | "ex" => Some(RuntimeConfig {
166            command: find_binary(&["elixir"])?,
167            args: vec!["-e".to_string()],
168            needs_temp_file: false,
169            file_extension: "exs".to_string(),
170            env: HashMap::new(),
171        }),
172        _ => None,
173    }
174}
175
176fn execute_with_stdin(
177    runtime: &RuntimeConfig,
178    code: &str,
179    timeout: u64,
180) -> Result<(String, String, i32), String> {
181    let mut cmd = Command::new(&runtime.command);
182    for arg in &runtime.args {
183        cmd.arg(arg);
184    }
185    cmd.arg(code);
186
187    for (k, v) in &runtime.env {
188        cmd.env(k, v);
189    }
190
191    cmd.env("LEAN_CTX_SANDBOX", "1");
192    cmd.stdout(std::process::Stdio::piped());
193    cmd.stderr(std::process::Stdio::piped());
194
195    let child = cmd
196        .spawn()
197        .map_err(|e| format!("Failed to spawn {}: {e}", runtime.command))?;
198
199    let output = wait_with_timeout(child, timeout)?;
200    Ok((
201        crate::shell::decode_output(&output.stdout),
202        crate::shell::decode_output(&output.stderr),
203        output.status.code().unwrap_or(1),
204    ))
205}
206
207fn execute_with_file(
208    runtime: &RuntimeConfig,
209    code: &str,
210    timeout: u64,
211) -> Result<(String, String, i32), String> {
212    let tmp_dir = std::env::temp_dir().join("lean-ctx-sandbox");
213    let _ = std::fs::create_dir_all(&tmp_dir);
214
215    let file_name = format!(
216        "exec_{}.{}",
217        std::time::SystemTime::now()
218            .duration_since(std::time::UNIX_EPOCH)
219            .unwrap_or_default()
220            .as_nanos(),
221        runtime.file_extension
222    );
223    let file_path = tmp_dir.join(&file_name);
224
225    std::fs::write(&file_path, code).map_err(|e| format!("Failed to write temp file: {e}"))?;
226
227    let result = if runtime.command == "rustc_script" {
228        execute_rust(&file_path, timeout)
229    } else {
230        let mut cmd = Command::new(&runtime.command);
231        for arg in &runtime.args {
232            cmd.arg(arg);
233        }
234        cmd.arg(&file_path);
235        for (k, v) in &runtime.env {
236            cmd.env(k, v);
237        }
238        cmd.env("LEAN_CTX_SANDBOX", "1");
239        cmd.stdout(std::process::Stdio::piped());
240        cmd.stderr(std::process::Stdio::piped());
241
242        let child = cmd
243            .spawn()
244            .map_err(|e| format!("Failed to spawn {}: {e}", runtime.command))?;
245        let output = wait_with_timeout(child, timeout)?;
246        Ok((
247            crate::shell::decode_output(&output.stdout),
248            crate::shell::decode_output(&output.stderr),
249            output.status.code().unwrap_or(1),
250        ))
251    };
252
253    let _ = std::fs::remove_file(&file_path);
254    result
255}
256
257fn execute_rust(
258    source_path: &std::path::Path,
259    timeout: u64,
260) -> Result<(String, String, i32), String> {
261    let binary_path = source_path.with_extension("");
262
263    let compile = Command::new("rustc")
264        .arg(source_path)
265        .arg("-o")
266        .arg(&binary_path)
267        .output()
268        .map_err(|e| format!("rustc not found: {e}"))?;
269
270    if !compile.status.success() {
271        let stderr = crate::shell::decode_output(&compile.stderr);
272        let _ = std::fs::remove_file(&binary_path);
273        return Ok((String::new(), stderr, compile.status.code().unwrap_or(1)));
274    }
275
276    let child = Command::new(&binary_path)
277        .env("LEAN_CTX_SANDBOX", "1")
278        .stdout(std::process::Stdio::piped())
279        .stderr(std::process::Stdio::piped())
280        .spawn()
281        .map_err(|e| format!("Failed to run compiled binary: {e}"))?;
282
283    let output = wait_with_timeout(child, timeout)?;
284    let _ = std::fs::remove_file(&binary_path);
285
286    Ok((
287        crate::shell::decode_output(&output.stdout),
288        crate::shell::decode_output(&output.stderr),
289        output.status.code().unwrap_or(1),
290    ))
291}
292
/// Waits for `child` to exit, collecting its piped stdout/stderr, and kills
/// it if it is still running after `timeout_secs` seconds.
///
/// Both pipes are drained on background threads *while* the child runs.
/// Without that, a child producing more output than the OS pipe buffer holds
/// would block on `write`, never exit, and be reported as a timeout.
///
/// On timeout the child is killed and then reaped so it does not linger as a
/// zombie. The reader threads are left to finish on their own in that case,
/// because joining them could block if a descendant of the child still holds
/// the pipe open.
///
/// # Errors
/// Returns `"Execution timed out after {timeout_secs}s"` on timeout, or the
/// I/O error text if polling the child fails.
fn wait_with_timeout(
    child: std::process::Child,
    timeout_secs: u64,
) -> Result<std::process::Output, String> {
    // Reads a pipe to EOF on its own thread; a missing pipe yields no bytes.
    fn drain<R>(pipe: Option<R>) -> std::thread::JoinHandle<Vec<u8>>
    where
        R: std::io::Read + Send + 'static,
    {
        std::thread::spawn(move || {
            let mut buf = Vec::new();
            if let Some(mut pipe) = pipe {
                // Keep whatever was read before an error; partial output is
                // more useful than none.
                let _ = std::io::Read::read_to_end(&mut pipe, &mut buf);
            }
            buf
        })
    }

    let mut child = child;
    let stdout_reader = drain(child.stdout.take());
    let stderr_reader = drain(child.stderr.take());

    // `checked_add` avoids the panic `Instant + Duration` raises on overflow;
    // an unrepresentable deadline simply means "no deadline".
    let deadline =
        std::time::Instant::now().checked_add(std::time::Duration::from_secs(timeout_secs));

    let status = loop {
        match child.try_wait() {
            Ok(Some(status)) => break status,
            Ok(None) => {
                if deadline.is_some_and(|limit| std::time::Instant::now() > limit) {
                    let _ = child.kill();
                    // Reap the killed process.
                    let _ = child.wait();
                    return Err(format!("Execution timed out after {timeout_secs}s"));
                }
                std::thread::sleep(std::time::Duration::from_millis(50));
            }
            Err(e) => return Err(e.to_string()),
        }
    };

    Ok(std::process::Output {
        status,
        stdout: stdout_reader.join().unwrap_or_default(),
        stderr: stderr_reader.join().unwrap_or_default(),
    })
}
314
315fn find_binary(candidates: &[&str]) -> Option<String> {
316    for name in candidates {
317        if which_exists(name) {
318            return Some(name.to_string());
319        }
320    }
321    None
322}
323
/// Reports whether `name` can be located, by running the platform's lookup
/// tool (`where` on Windows, `which` elsewhere) with its output discarded.
///
/// Returns `false` when the tool exits unsuccessfully or cannot be run.
fn which_exists(name: &str) -> bool {
    let locator = if cfg!(target_os = "windows") {
        "where"
    } else {
        "which"
    };

    let outcome = Command::new(locator)
        .arg(name)
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status();

    matches!(outcome, Ok(status) if status.success())
}
341
342fn truncate_output(output: &str) -> String {
343    if output.len() <= MAX_OUTPUT_BYTES {
344        return output.to_string();
345    }
346    truncate_smart(output, MAX_OUTPUT_BYTES)
347}
348
/// Shortens `output` to roughly `max_bytes` by keeping its beginning and end
/// and inserting a marker describing what was dropped.
///
/// Text of at most `max_bytes` bytes is returned unchanged. The marker itself
/// is not counted against the budget, so the result may exceed `max_bytes`.
///
/// Cuts always fall on UTF-8 character boundaries (the head cut moves
/// backwards, the tail cut forwards), so multi-byte text never panics.
fn truncate_smart(output: &str, max_bytes: usize) -> String {
    if output.len() <= max_bytes {
        return output.to_string();
    }

    let lines: Vec<&str> = output.lines().collect();
    let total_lines = lines.len();

    // 60% of the lines go to the head, the remainder to the tail.
    let head_count = (total_lines * 60) / 100;
    let tail_count = total_lines - head_count;
    let (head, tail) = lines.split_at(head_count);

    let head_text = head.join("\n");
    let tail_text = tail.join("\n");

    // Head and tail together contain every line, so the line-based form
    // below only fits when dropping line terminators (e.g. `\r`) saved enough
    // bytes. Otherwise fall back to a byte-based excerpt.
    if head_text.len() + tail_text.len() + 100 > max_bytes {
        let half = max_bytes / 2;

        // Snap both cut points to character boundaries; slicing a `str` in
        // the middle of a multi-byte character panics.
        let mut head_end = half.min(output.len());
        while !output.is_char_boundary(head_end) {
            head_end -= 1;
        }
        let mut tail_start = output.len().saturating_sub(half);
        while !output.is_char_boundary(tail_start) {
            tail_start += 1;
        }

        let h = &output[..head_end];
        let t = &output[tail_start..];
        let skipped = output.len() - h.len() - t.len();
        return format!("{h}\n\n... [{skipped} bytes truncated — showing head + tail] ...\n\n{t}");
    }

    // Always zero with the current split (head + tail cover all lines); kept
    // so the message stays correct if the split ever changes.
    let skipped_lines = total_lines - head_count - tail_count;
    let skipped_bytes = output.len() - head_text.len() - tail_text.len();
    format!(
        "{head_text}\n\n... [{skipped_lines} lines / {skipped_bytes} bytes truncated — showing first {head_count} + last {tail_count} lines] ...\n\n{tail_text}"
    )
}
385
/// Canonical names of the languages the sandbox advertises, in display
/// order. Aliases such as `js` or `py` are not listed.
pub fn supported_languages() -> &'static [&'static str] {
    const CANONICAL_NAMES: &[&str] = &[
        "javascript",
        "typescript",
        "python",
        "shell",
        "ruby",
        "go",
        "rust",
        "php",
        "perl",
        "r",
        "elixir",
    ];
    CANONICAL_NAMES
}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// True when a Python interpreter can be located; tests that need Python
    /// return early (and therefore pass vacuously) when it cannot.
    fn python_available() -> bool {
        find_binary(&["python3", "python"]).is_some()
    }

    #[test]
    fn execute_python_hello() {
        if !python_available() {
            return;
        }
        let result = execute("python", "print('hello sandbox')", None);
        assert_eq!(result.exit_code, 0);
        assert!(result.stdout.contains("hello sandbox"));
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn execute_shell_echo() {
        let result = execute("shell", "echo 'test output'", None);
        assert_eq!(result.exit_code, 0);
        assert!(result.stdout.contains("test output"));
    }

    #[test]
    fn execute_unsupported_language() {
        let result = execute("brainfuck", "++++", None);
        assert_eq!(result.exit_code, 1);
        assert!(result.stderr.contains("Unsupported language"));
        assert!(result.stdout.is_empty());
        assert_eq!(result.language, "brainfuck");
    }

    #[test]
    fn execute_python_error() {
        if !python_available() {
            return;
        }
        let result = execute("python", "raise ValueError('test error')", None);
        assert_ne!(result.exit_code, 0);
        assert!(result.stderr.contains("ValueError"));
    }

    #[test]
    fn execute_with_timeout() {
        if !python_available() {
            return;
        }
        let result = execute("python", "import time; time.sleep(60)", Some(1));
        assert_ne!(result.exit_code, 0);
        // A non-zero exit alone would also pass if Python merely crashed;
        // make sure the timeout path is what produced it.
        assert!(result.stderr.contains("timed out"));
    }

    #[test]
    fn truncate_preserves_head_and_tail() {
        let lines: Vec<String> = (0..100)
            .map(|i| format!("line {i}: some content here"))
            .collect();
        let output = lines.join("\n");
        let truncated = truncate_smart(&output, 500);
        assert!(truncated.contains("line 0:"));
        assert!(truncated.contains("line 99:"));
        assert!(truncated.contains("truncated"));
    }

    #[test]
    fn supported_languages_list() {
        let langs = supported_languages();
        assert!(langs.contains(&"python"));
        assert!(langs.contains(&"javascript"));
        assert!(langs.contains(&"rust"));
        assert_eq!(langs.len(), 11);
    }

    #[test]
    fn sandbox_env_is_set() {
        if !python_available() {
            return;
        }
        let result = execute(
            "python",
            "import os; print(os.environ.get('LEAN_CTX_SANDBOX', 'missing'))",
            None,
        );
        assert_eq!(result.exit_code, 0);
        assert!(result.stdout.contains('1'));
        assert!(!result.stdout.contains("missing"));
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn batch_execute_multiple() {
        // The first item needs Python; skip like the other Python tests
        // instead of failing on machines without an interpreter.
        if !python_available() {
            return;
        }
        let items = vec![
            ("python".to_string(), "print(1+1)".to_string()),
            ("shell".to_string(), "echo hello".to_string()),
        ];
        let results = batch_execute(&items);
        assert_eq!(results.len(), 2);
        assert!(results[0].stdout.contains('2'));
        assert!(results[1].stdout.contains("hello"));
    }
}