Skip to main content

lean_ctx/core/
sandbox.rs

1use std::collections::HashMap;
2use std::process::Command;
3
/// Outcome of running a single snippet through `execute`.
#[derive(Debug, Clone)]
pub struct SandboxResult {
    /// Captured standard output, lossily decoded as UTF-8 and possibly truncated.
    pub stdout: String,
    /// Captured standard error (or an internal "Execution error: …" message),
    /// possibly truncated.
    pub stderr: String,
    /// Process exit code; `1` for an unsupported language, a spawn/timeout
    /// error, or a process that ended without reporting an exit code.
    pub exit_code: i32,
    /// The language string exactly as the caller supplied it.
    pub language: String,
    /// Elapsed wall-clock time in milliseconds.
    pub duration_ms: u64,
}
12
/// Wall-clock limit in seconds used when the caller does not pass one.
const TIMEOUT_SECS: u64 = 30;
/// Maximum stdout size (32 KiB) kept before head/tail truncation applies.
const MAX_OUTPUT_BYTES: usize = 32 * 1024;
15
16pub fn execute(language: &str, code: &str, timeout_secs: Option<u64>) -> SandboxResult {
17    let timeout = timeout_secs.unwrap_or(TIMEOUT_SECS);
18    let start = std::time::Instant::now();
19
20    let runtime = match resolve_runtime(language) {
21        Some(r) => r,
22        None => {
23            return SandboxResult {
24                stdout: String::new(),
25                stderr: format!("Unsupported language: {language}. Supported: javascript, typescript, python, shell, ruby, go, rust, php, perl, r, elixir"),
26                exit_code: 1,
27                language: language.to_string(),
28                duration_ms: 0,
29            };
30        }
31    };
32
33    let result = if runtime.needs_temp_file {
34        execute_with_file(&runtime, code, timeout)
35    } else {
36        execute_with_stdin(&runtime, code, timeout)
37    };
38
39    let duration_ms = start.elapsed().as_millis() as u64;
40
41    match result {
42        Ok((stdout, stderr, code)) => SandboxResult {
43            stdout: truncate_output(&stdout),
44            stderr: truncate_smart(&stderr, 2048),
45            exit_code: code,
46            language: language.to_string(),
47            duration_ms,
48        },
49        Err(e) => SandboxResult {
50            stdout: String::new(),
51            stderr: format!("Execution error: {e}"),
52            exit_code: 1,
53            language: language.to_string(),
54            duration_ms,
55        },
56    }
57}
58
59pub fn batch_execute(items: &[(String, String)]) -> Vec<SandboxResult> {
60    items
61        .iter()
62        .map(|(lang, code)| execute(lang, code, None))
63        .collect()
64}
65
/// Describes how to launch the interpreter or compiler for one language.
struct RuntimeConfig {
    /// Executable to run, or the `"rustc_script"` sentinel for Rust.
    command: String,
    /// Arguments placed before the code (or before the temp-file path).
    args: Vec<String>,
    /// Extra environment variables set on the child process.
    env: HashMap<String, String>,
    /// `true` when the code must be written to a temp file rather than
    /// passed inline as an argument.
    needs_temp_file: bool,
    /// Extension (without the dot) used for that temp file.
    file_extension: String,
}
73
74fn resolve_runtime(language: &str) -> Option<RuntimeConfig> {
75    let lang = language.to_lowercase();
76    let lang = lang.as_str();
77
78    match lang {
79        "javascript" | "js" | "node" => Some(RuntimeConfig {
80            command: find_binary(&["bun", "node"])?,
81            args: vec!["-e".to_string()],
82            needs_temp_file: false,
83            file_extension: "js".to_string(),
84            env: HashMap::new(),
85        }),
86        "typescript" | "ts" => Some(RuntimeConfig {
87            command: find_binary(&["bun", "npx"])?,
88            args: if which_exists("bun") {
89                vec!["-e".to_string()]
90            } else {
91                vec!["tsx".to_string(), "-e".to_string()]
92            },
93            needs_temp_file: false,
94            file_extension: "ts".to_string(),
95            env: HashMap::new(),
96        }),
97        "python" | "py" => Some(RuntimeConfig {
98            command: find_binary(&["python3", "python"])?,
99            args: vec!["-c".to_string()],
100            needs_temp_file: false,
101            file_extension: "py".to_string(),
102            env: HashMap::from([("PYTHONDONTWRITEBYTECODE".into(), "1".into())]),
103        }),
104        "shell" | "bash" | "sh" => {
105            #[cfg(target_os = "windows")]
106            {
107                Some(RuntimeConfig {
108                    command: "cmd".to_string(),
109                    args: vec!["/C".to_string()],
110                    needs_temp_file: false,
111                    file_extension: "bat".to_string(),
112                    env: HashMap::new(),
113                })
114            }
115            #[cfg(not(target_os = "windows"))]
116            {
117                Some(RuntimeConfig {
118                    command: find_binary(&["bash", "sh"])?,
119                    args: vec!["-c".to_string()],
120                    needs_temp_file: false,
121                    file_extension: "sh".to_string(),
122                    env: HashMap::new(),
123                })
124            }
125        }
126        "ruby" | "rb" => Some(RuntimeConfig {
127            command: find_binary(&["ruby"])?,
128            args: vec!["-e".to_string()],
129            needs_temp_file: false,
130            file_extension: "rb".to_string(),
131            env: HashMap::new(),
132        }),
133        "go" | "golang" => Some(RuntimeConfig {
134            command: find_binary(&["go"])?,
135            args: vec!["run".to_string()],
136            needs_temp_file: true,
137            file_extension: "go".to_string(),
138            env: HashMap::new(),
139        }),
140        "rust" | "rs" => Some(RuntimeConfig {
141            command: "rustc_script".to_string(),
142            args: vec![],
143            needs_temp_file: true,
144            file_extension: "rs".to_string(),
145            env: HashMap::new(),
146        }),
147        "php" => Some(RuntimeConfig {
148            command: find_binary(&["php"])?,
149            args: vec!["-r".to_string()],
150            needs_temp_file: false,
151            file_extension: "php".to_string(),
152            env: HashMap::new(),
153        }),
154        "perl" | "pl" => Some(RuntimeConfig {
155            command: find_binary(&["perl"])?,
156            args: vec!["-e".to_string()],
157            needs_temp_file: false,
158            file_extension: "pl".to_string(),
159            env: HashMap::new(),
160        }),
161        "r" => Some(RuntimeConfig {
162            command: find_binary(&["Rscript"])?,
163            args: vec!["-e".to_string()],
164            needs_temp_file: false,
165            file_extension: "R".to_string(),
166            env: HashMap::new(),
167        }),
168        "elixir" | "ex" => Some(RuntimeConfig {
169            command: find_binary(&["elixir"])?,
170            args: vec!["-e".to_string()],
171            needs_temp_file: false,
172            file_extension: "exs".to_string(),
173            env: HashMap::new(),
174        }),
175        _ => None,
176    }
177}
178
179fn execute_with_stdin(
180    runtime: &RuntimeConfig,
181    code: &str,
182    timeout: u64,
183) -> Result<(String, String, i32), String> {
184    let mut cmd = Command::new(&runtime.command);
185    for arg in &runtime.args {
186        cmd.arg(arg);
187    }
188    cmd.arg(code);
189
190    for (k, v) in &runtime.env {
191        cmd.env(k, v);
192    }
193
194    cmd.env("LEAN_CTX_SANDBOX", "1");
195    cmd.stdout(std::process::Stdio::piped());
196    cmd.stderr(std::process::Stdio::piped());
197
198    let child = cmd
199        .spawn()
200        .map_err(|e| format!("Failed to spawn {}: {e}", runtime.command))?;
201
202    let output = wait_with_timeout(child, timeout)?;
203    Ok((
204        String::from_utf8_lossy(&output.stdout).to_string(),
205        String::from_utf8_lossy(&output.stderr).to_string(),
206        output.status.code().unwrap_or(1),
207    ))
208}
209
210fn execute_with_file(
211    runtime: &RuntimeConfig,
212    code: &str,
213    timeout: u64,
214) -> Result<(String, String, i32), String> {
215    let tmp_dir = std::env::temp_dir().join("lean-ctx-sandbox");
216    let _ = std::fs::create_dir_all(&tmp_dir);
217
218    let file_name = format!(
219        "exec_{}.{}",
220        std::time::SystemTime::now()
221            .duration_since(std::time::UNIX_EPOCH)
222            .unwrap_or_default()
223            .as_nanos(),
224        runtime.file_extension
225    );
226    let file_path = tmp_dir.join(&file_name);
227
228    std::fs::write(&file_path, code).map_err(|e| format!("Failed to write temp file: {e}"))?;
229
230    let result = if runtime.command == "rustc_script" {
231        execute_rust(&file_path, timeout)
232    } else {
233        let mut cmd = Command::new(&runtime.command);
234        for arg in &runtime.args {
235            cmd.arg(arg);
236        }
237        cmd.arg(&file_path);
238        for (k, v) in &runtime.env {
239            cmd.env(k, v);
240        }
241        cmd.env("LEAN_CTX_SANDBOX", "1");
242        cmd.stdout(std::process::Stdio::piped());
243        cmd.stderr(std::process::Stdio::piped());
244
245        let child = cmd
246            .spawn()
247            .map_err(|e| format!("Failed to spawn {}: {e}", runtime.command))?;
248        let output = wait_with_timeout(child, timeout)?;
249        Ok((
250            String::from_utf8_lossy(&output.stdout).to_string(),
251            String::from_utf8_lossy(&output.stderr).to_string(),
252            output.status.code().unwrap_or(1),
253        ))
254    };
255
256    let _ = std::fs::remove_file(&file_path);
257    result
258}
259
260fn execute_rust(
261    source_path: &std::path::Path,
262    timeout: u64,
263) -> Result<(String, String, i32), String> {
264    let binary_path = source_path.with_extension("");
265
266    let compile = Command::new("rustc")
267        .arg(source_path)
268        .arg("-o")
269        .arg(&binary_path)
270        .output()
271        .map_err(|e| format!("rustc not found: {e}"))?;
272
273    if !compile.status.success() {
274        let stderr = String::from_utf8_lossy(&compile.stderr).to_string();
275        let _ = std::fs::remove_file(&binary_path);
276        return Ok((String::new(), stderr, compile.status.code().unwrap_or(1)));
277    }
278
279    let child = Command::new(&binary_path)
280        .env("LEAN_CTX_SANDBOX", "1")
281        .stdout(std::process::Stdio::piped())
282        .stderr(std::process::Stdio::piped())
283        .spawn()
284        .map_err(|e| format!("Failed to run compiled binary: {e}"))?;
285
286    let output = wait_with_timeout(child, timeout)?;
287    let _ = std::fs::remove_file(&binary_path);
288
289    Ok((
290        String::from_utf8_lossy(&output.stdout).to_string(),
291        String::from_utf8_lossy(&output.stderr).to_string(),
292        output.status.code().unwrap_or(1),
293    ))
294}
295
/// Waits for `child` to exit, killing it if it is still running after
/// `timeout_secs` seconds. Polls every 50 ms.
///
/// The child's piped stdout and stderr, if any, are drained on helper threads
/// while waiting. Without this, a child that writes more than the OS pipe
/// buffer would block on a full pipe and be misreported as a timeout.
///
/// A timeout too large to represent as a deadline means "no deadline" instead
/// of panicking.
///
/// # Errors
/// - `"Execution timed out after {timeout_secs}s"` when the deadline passes.
///   The child is killed and reaped.
/// - The OS error text if polling the child fails.
fn wait_with_timeout(
    child: std::process::Child,
    timeout_secs: u64,
) -> Result<std::process::Output, String> {
    let mut child = child;
    let deadline =
        std::time::Instant::now().checked_add(std::time::Duration::from_secs(timeout_secs));

    let stdout_reader = child.stdout.take().map(spawn_pipe_reader);
    let stderr_reader = child.stderr.take().map(spawn_pipe_reader);

    let status = loop {
        match child.try_wait() {
            Ok(Some(status)) => break status,
            Ok(None) => {
                if matches!(deadline, Some(limit) if std::time::Instant::now() > limit) {
                    let _ = child.kill();
                    // Reap the killed process so it does not linger as a zombie.
                    let _ = child.wait();
                    return Err(format!("Execution timed out after {timeout_secs}s"));
                }
                std::thread::sleep(std::time::Duration::from_millis(50));
            }
            Err(e) => {
                let _ = child.kill();
                let _ = child.wait();
                return Err(e.to_string());
            }
        }
    };

    Ok(std::process::Output {
        status,
        stdout: join_pipe_reader(stdout_reader),
        stderr: join_pipe_reader(stderr_reader),
    })
}

/// Reads `pipe` to EOF on a new thread and returns the collected bytes.
/// A read error ends collection early and keeps the bytes read so far.
fn spawn_pipe_reader<R>(mut pipe: R) -> std::thread::JoinHandle<Vec<u8>>
where
    R: std::io::Read + Send + 'static,
{
    std::thread::spawn(move || {
        let mut buf = Vec::new();
        let _ = std::io::Read::read_to_end(&mut pipe, &mut buf);
        buf
    })
}

/// Joins a reader thread from `spawn_pipe_reader`. A missing pipe or a
/// panicked reader yields an empty buffer.
fn join_pipe_reader(reader: Option<std::thread::JoinHandle<Vec<u8>>>) -> Vec<u8> {
    reader
        .and_then(|handle| handle.join().ok())
        .unwrap_or_default()
}
317
318fn find_binary(candidates: &[&str]) -> Option<String> {
319    for name in candidates {
320        if which_exists(name) {
321            return Some(name.to_string());
322        }
323    }
324    None
325}
326
/// Reports whether `name` can be located as an executable.
///
/// Runs `where` on Windows and `which` elsewhere, with output discarded, and
/// treats a successful exit status as "found". A failure to run the locator
/// itself counts as "not found".
fn which_exists(name: &str) -> bool {
    let locator = if cfg!(target_os = "windows") {
        "where"
    } else {
        "which"
    };
    let probe = Command::new(locator)
        .arg(name)
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status();
    matches!(probe, Ok(status) if status.success())
}
344
345fn truncate_output(output: &str) -> String {
346    if output.len() <= MAX_OUTPUT_BYTES {
347        return output.to_string();
348    }
349    truncate_smart(output, MAX_OUTPUT_BYTES)
350}
351
/// Shortens `output` to roughly `max_bytes` by keeping its beginning and end
/// and replacing the middle with a marker that reports what was dropped.
///
/// Input no longer than `max_bytes` is returned unchanged. Otherwise whole
/// lines are kept from the head, using about 60% of the budget, and from the
/// tail, using the rest.
///
/// If not even one line fits at either end (for example one very long line),
/// it falls back to keeping `max_bytes / 2` bytes from each end. Every cut
/// lands on a UTF-8 character boundary, so multi-byte text cannot cause a
/// slicing panic. The marker text is not counted against `max_bytes`.
fn truncate_smart(output: &str, max_bytes: usize) -> String {
    /// Largest char boundary `<= idx` (clamped to the string length).
    fn floor_boundary(s: &str, idx: usize) -> usize {
        let mut i = idx.min(s.len());
        while !s.is_char_boundary(i) {
            i -= 1;
        }
        i
    }

    /// Smallest char boundary `>= idx` (clamped to the string length).
    fn ceil_boundary(s: &str, idx: usize) -> usize {
        let mut i = idx.min(s.len());
        while !s.is_char_boundary(i) {
            i += 1;
        }
        i
    }

    if output.len() <= max_bytes {
        return output.to_string();
    }

    // Segments keep their trailing '\n', so head and tail are exact slices of
    // `output` that end/start on line boundaries (always char boundaries).
    let segments: Vec<&str> = output.split_inclusive('\n').collect();
    let total_lines = segments.len();
    // ~40% for the tail, the rest for the head; written to avoid overflow.
    let tail_budget = max_bytes / 5 * 2;
    let head_budget = max_bytes - tail_budget;

    let mut head_end = 0;
    let mut head_count = 0;
    for seg in &segments {
        if head_end + seg.len() > head_budget {
            break;
        }
        head_end += seg.len();
        head_count += 1;
    }

    let mut tail_len = 0;
    let mut tail_count = 0;
    for seg in segments[head_count..].iter().rev() {
        if tail_len + seg.len() > tail_budget {
            break;
        }
        tail_len += seg.len();
        tail_count += 1;
    }

    if head_count == 0 || tail_count == 0 {
        // Lines are too long to show whole ones: keep raw bytes from each end.
        let half = max_bytes / 2;
        let h_end = floor_boundary(output, half);
        let t_start = ceil_boundary(output, output.len().saturating_sub(half)).max(h_end);
        let h = &output[..h_end];
        let t = &output[t_start..];
        let skipped = t_start - h_end;
        return format!("{h}\n\n... [{skipped} bytes truncated — showing head + tail] ...\n\n{t}");
    }

    // The budgets sum to `max_bytes` < `output.len()`, so head and tail never
    // meet and at least one line is skipped.
    let tail_start = output.len() - tail_len;
    let skipped_lines = total_lines - head_count - tail_count;
    let skipped_bytes = tail_start - head_end;

    // Drop the head's final line terminator; the marker supplies the spacing.
    let head_slice = &output[..head_end];
    let head_text = head_slice.strip_suffix('\n').unwrap_or(head_slice);
    let head_text = head_text.strip_suffix('\r').unwrap_or(head_text);
    let tail_text = &output[tail_start..];

    format!(
        "{head_text}\n\n... [{skipped_lines} lines / {skipped_bytes} bytes truncated — showing first {head_count} + last {tail_count} lines] ...\n\n{tail_text}"
    )
}
388
/// Canonical names of every language `execute` understands. Aliases such as
/// `py` or `js` are also accepted by `execute` but are not listed here.
pub fn supported_languages() -> &'static [&'static str] {
    static CANONICAL: [&str; 11] = [
        "javascript",
        "typescript",
        "python",
        "shell",
        "ruby",
        "go",
        "rust",
        "php",
        "perl",
        "r",
        "elixir",
    ];
    &CANONICAL
}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn execute_python_hello() {
        let result = execute("python", "print('hello sandbox')", None);
        assert_eq!(result.exit_code, 0);
        assert!(result.stdout.contains("hello sandbox"));
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn execute_shell_echo() {
        let result = execute("shell", "echo 'test output'", None);
        assert_eq!(result.exit_code, 0);
        assert!(result.stdout.contains("test output"));
    }

    #[test]
    fn execute_unsupported_language() {
        let result = execute("brainfuck", "++++", None);
        assert_eq!(result.exit_code, 1);
        assert!(result.stderr.contains("Unsupported language"));
        // The message should point the caller at the valid choices.
        assert!(result.stderr.contains("python"));
        assert_eq!(result.language, "brainfuck");
    }

    #[test]
    fn execute_python_error() {
        let result = execute("python", "raise ValueError('test error')", None);
        assert_ne!(result.exit_code, 0);
        assert!(result.stderr.contains("ValueError"));
    }

    #[test]
    fn execute_with_timeout() {
        let result = execute("python", "import time; time.sleep(60)", Some(1));
        assert_ne!(result.exit_code, 0);
        // Fail loudly if the non-zero exit came from something other than the timeout.
        assert!(
            result.stderr.contains("timed out"),
            "unexpected stderr: {}",
            result.stderr
        );
    }

    #[test]
    fn truncate_preserves_head_and_tail() {
        let lines: Vec<String> = (0..100)
            .map(|i| format!("line {i}: some content here"))
            .collect();
        let output = lines.join("\n");
        let truncated = truncate_smart(&output, 500);
        assert!(truncated.contains("line 0:"));
        assert!(truncated.contains("line 99:"));
        assert!(truncated.contains("truncated"));
    }

    #[test]
    fn truncate_passes_short_output_through() {
        assert_eq!(truncate_smart("short", 100), "short");
        assert_eq!(truncate_output("short"), "short");
    }

    #[test]
    fn supported_languages_list() {
        let langs = supported_languages();
        assert!(langs.contains(&"python"));
        assert!(langs.contains(&"javascript"));
        assert!(langs.contains(&"rust"));
        assert_eq!(langs.len(), 11);
    }

    #[test]
    fn resolve_runtime_is_case_insensitive_and_knows_aliases() {
        // Rust resolves without probing PATH, so this is environment-independent.
        assert_eq!(
            resolve_runtime("RUST").map(|r| r.needs_temp_file),
            Some(true)
        );
        assert_eq!(
            resolve_runtime("rs").map(|r| r.command),
            Some("rustc_script".to_string())
        );
        assert!(resolve_runtime("brainfuck").is_none());
    }

    #[test]
    fn sandbox_env_is_set() {
        let result = execute(
            "python",
            "import os; print(os.environ.get('LEAN_CTX_SANDBOX', 'missing'))",
            None,
        );
        assert_eq!(result.exit_code, 0);
        // Exact match: a bare `contains("1")` would also accept unrelated output.
        assert_eq!(result.stdout.trim(), "1");
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn batch_execute_multiple() {
        let items = vec![
            ("python".to_string(), "print(1+1)".to_string()),
            ("shell".to_string(), "echo hello".to_string()),
        ];
        let results = batch_execute(&items);
        assert_eq!(results.len(), 2);
        assert!(results[0].stdout.contains("2"));
        assert!(results[1].stdout.contains("hello"));
    }
}