pub mod fs;
pub mod isolation;
pub mod network;
#[cfg(test)]
pub mod tests;
pub use fs::{FileSystemRestrictions, SecurityError as FsSecurityError};
pub use isolation::{
IsolationConfig, apply_isolation, current_gid, current_uid, is_running_as_root,
recommend_safe_uid,
};
pub use network::{NetworkRestrictions, NetworkSecurityError};
use crate::types::{
ErrorType, ExecutionContext, ExecutionResult, SandboxConfig, SecurityViolationType,
};
use anyhow::{Context, Result};
use std::process::Stdio;
use std::time::{Duration, Instant};
use tokio::process::Command;
use tracing::{debug, warn};
/// Runs untrusted JavaScript snippets in a separate `node` process.
///
/// The wrapped [`SandboxConfig`] decides which capabilities (filesystem,
/// network, subprocesses) submitted code may reference, and carries the
/// execution-time and memory limits that [`CodeSandbox::new`] validates.
#[derive(Debug)]
pub struct CodeSandbox {
    /// Capability switches and resource limits applied to every execution.
    config: SandboxConfig,
}
impl CodeSandbox {
pub fn new(config: SandboxConfig) -> Result<Self> {
if config.max_execution_time_ms == 0 {
anyhow::bail!("max_execution_time_ms must be greater than 0");
}
if config.max_memory_mb == 0 {
anyhow::bail!("max_memory_mb must be greater than 0");
}
Ok(Self { config })
}
pub async fn execute(&self, code: &str, context: ExecutionContext) -> Result<ExecutionResult> {
let start = Instant::now();
if let Some(violation) = self.detect_security_violations(code) {
warn!("Security violation detected: {:?}", violation);
return Ok(ExecutionResult::SecurityViolation {
reason: format!("Security violation: {:?}", violation),
violation_type: violation,
});
}
if code.len() > 100_000 {
return Ok(ExecutionResult::SecurityViolation {
reason: "Code exceeds maximum length (100KB)".to_string(),
violation_type: SecurityViolationType::MaliciousCode,
});
}
let wrapper = self.create_secure_wrapper(code, &context)?;
let result = self.execute_isolated(wrapper, start).await?;
Ok(result)
}
fn detect_security_violations(&self, code: &str) -> Option<SecurityViolationType> {
if !self.config.allow_filesystem {
let fs_patterns = [
"require('fs')",
"require(\"fs\")",
"require(`fs`)",
"import fs from",
"import * as fs",
"readFile",
"writeFile",
"mkdir",
"rmdir",
"unlink",
"__dirname",
"__filename",
];
for pattern in &fs_patterns {
if code.contains(pattern) {
return Some(SecurityViolationType::FileSystemAccess);
}
}
}
if !self.config.allow_network {
let network_patterns = [
"require('http')",
"require('https')",
"require('net')",
"fetch(",
"XMLHttpRequest",
"WebSocket",
"import('http')",
"import('https')",
];
for pattern in &network_patterns {
if code.contains(pattern) {
return Some(SecurityViolationType::NetworkAccess);
}
}
}
if !self.config.allow_subprocesses {
let process_patterns = [
"require('child_process')",
"exec(",
"execSync(",
"spawn(",
"spawnSync(",
"fork(",
"execFile(",
"process.exit",
];
for pattern in &process_patterns {
if code.contains(pattern) {
return Some(SecurityViolationType::ProcessExecution);
}
}
}
let loop_count = code.matches("while(true)").count()
+ code.matches("for(;;)").count()
+ code.matches("while (true)").count()
+ code.matches("for (;;)").count();
if loop_count > 0 {
return Some(SecurityViolationType::InfiniteLoop);
}
if code.contains("eval(") || code.contains("Function(") {
return Some(SecurityViolationType::MaliciousCode);
}
None
}
fn create_secure_wrapper(&self, user_code: &str, context: &ExecutionContext) -> Result<String> {
let context_json =
serde_json::to_string(context).context("Failed to serialize execution context")?;
let escaped_code = user_code
.replace('\\', "\\\\") .replace('`', "\\`") .replace("${", "\\${") .replace("\x00", "\\x00") .replace("\x0b", "\\x0b") .replace("\x0c", "\\x0c");
let wrapper = format!(
r#"
'use strict';
// Disable dangerous globals
delete global.process;
delete global.require;
delete global.module;
delete global.__dirname;
delete global.__filename;
// Set up restricted console
const outputs = [];
const errors = [];
const safeConsole = {{
log: (...args) => outputs.push(args.map(String).join(' ')),
error: (...args) => errors.push(args.map(String).join(' ')),
warn: (...args) => errors.push('WARN: ' + args.map(String).join(' ')),
info: (...args) => outputs.push('INFO: ' + args.map(String).join(' ')),
}};
// Execution context
const context = {};
// Main execution wrapper
(async () => {{
try {{
// Set timeout to prevent infinite loops
const timeout = setTimeout(() => {{
throw new Error('TIMEOUT_EXCEEDED');
}}, {});
// User code execution
const userFn = async () => {{
const console = safeConsole;
{};
}};
const result = await userFn();
clearTimeout(timeout);
// Output results
console.log(JSON.stringify({{
success: true,
result: result,
stdout: outputs.join('\n'),
stderr: errors.join('\n'),
}}));
}} catch (error) {{
console.error(JSON.stringify({{
success: false,
error: error.message,
stack: error.stack,
stdout: outputs.join('\n'),
stderr: errors.join('\n'),
}}));
process.exit(1);
}}
}})();
"#,
context_json, self.config.max_execution_time_ms, escaped_code
);
Ok(wrapper)
}
async fn execute_isolated(
&self,
wrapper_code: String,
start_time: Instant,
) -> Result<ExecutionResult> {
let child = Command::new("node")
.arg("--no-warnings")
.arg("-e")
.arg(&wrapper_code)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true) .spawn()
.context("Failed to spawn Node.js process")?;
let timeout = Duration::from_millis(self.config.max_execution_time_ms);
let output = match tokio::time::timeout(timeout, child.wait_with_output()).await {
Ok(Ok(output)) => output,
Ok(Err(e)) => {
warn!("Process execution failed: {}", e);
return Ok(ExecutionResult::Error {
message: format!("Process execution failed: {}", e),
error_type: ErrorType::Runtime,
stdout: String::new(),
stderr: String::new(),
});
}
Err(_) => {
return Ok(ExecutionResult::Timeout {
elapsed_ms: start_time.elapsed().as_millis() as u64,
partial_output: None,
});
}
};
let elapsed_ms = start_time.elapsed().as_millis() as u64;
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
debug!(
"Execution completed in {}ms, status: {}",
elapsed_ms,
output.status.code().unwrap_or(-1)
);
if output.status.success() {
if let Some(result_line) = stdout.lines().last() {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(result_line) {
if let Some(true) = parsed.get("success").and_then(|v| v.as_bool()) {
return Ok(ExecutionResult::Success {
output: parsed
.get("result")
.map(|v| v.to_string())
.unwrap_or_default(),
stdout: parsed
.get("stdout")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
stderr: parsed
.get("stderr")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
execution_time_ms: elapsed_ms,
});
}
}
}
Ok(ExecutionResult::Success {
output: stdout.clone(),
stdout,
stderr,
execution_time_ms: elapsed_ms,
})
} else {
let error_type = if stderr.contains("SyntaxError") {
ErrorType::Syntax
} else if stderr.contains("TIMEOUT_EXCEEDED") {
return Ok(ExecutionResult::Timeout {
elapsed_ms,
partial_output: Some(stdout),
});
} else if stderr.contains("EACCES") || stderr.contains("EPERM") {
ErrorType::Permission
} else {
ErrorType::Runtime
};
if let Some(error_line) = stderr.lines().last() {
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(error_line) {
if let Some(error_msg) = parsed.get("error").and_then(|v| v.as_str()) {
return Ok(ExecutionResult::Error {
message: error_msg.to_string(),
error_type,
stdout: parsed
.get("stdout")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
stderr: parsed
.get("stderr")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
});
}
}
}
Ok(ExecutionResult::Error {
message: stderr.clone(),
error_type,
stdout,
stderr,
})
}
}
}