Skip to main content

do_memory_mcp/
sandbox.rs

1//! Secure code execution sandbox
2//!
3//! This module provides a secure sandbox for executing TypeScript/JavaScript code
4//! with multiple layers of security:
5//!
6//! 1. Input validation and sanitization
7//! 2. Timeout enforcement
8//! 3. Resource limits (memory, CPU)
9//! 4. Process isolation
10//! 5. Network access controls (deny by default)
11//! 6. File system restrictions (whitelist approach)
12//! 7. Subprocess execution prevention
13//! 8. Malicious code pattern detection
14//!
15//! ## Security Architecture
16//!
17//! The sandbox uses a defense-in-depth approach with multiple security layers:
18//!
19//! - **Input Validation**: All code is scanned for malicious patterns before execution
20//! - **Process Isolation**: Code runs in a separate Node.js process with restricted permissions
21//! - **Resource Limits**: CPU and memory usage are constrained
22//! - **Timeout Enforcement**: Long-running code is terminated
23//! - **Access Controls**: Network and filesystem access are denied by default
24//!
25//! ## Example
26//!
27//! ```no_run
28//! use do_memory_mcp::sandbox::CodeSandbox;
29//! use do_memory_mcp::types::{SandboxConfig, ExecutionContext};
30//!
31//! #[tokio::main]
32//! async fn main() -> anyhow::Result<()> {
33//!     let sandbox = CodeSandbox::new(SandboxConfig::restrictive())?;
34//!     let code = "const result = 1 + 1; console.log(result);";
35//!     let context = ExecutionContext::new("test".to_string(), serde_json::json!({}));
36//!
37//!     let result = sandbox.execute(code, context).await?;
38//!     println!("Result: {:?}", result);
39//!     Ok(())
40//! }
41//! ```
42
43// Security submodules
44pub mod fs;
45pub mod isolation;
46pub mod network;
47
48#[cfg(test)]
49pub mod tests;
50
51pub use fs::{FileSystemRestrictions, SecurityError as FsSecurityError};
52pub use isolation::{
53    IsolationConfig, apply_isolation, current_gid, current_uid, is_running_as_root,
54    recommend_safe_uid,
55};
56pub use network::{NetworkRestrictions, NetworkSecurityError};
57
58use crate::types::{
59    ErrorType, ExecutionContext, ExecutionResult, SandboxConfig, SecurityViolationType,
60};
61use anyhow::{Context, Result};
62use std::process::Stdio;
63use std::time::{Duration, Instant};
64use tokio::process::Command;
65use tracing::{debug, warn};
66
67/// Secure code execution sandbox
68#[derive(Debug)]
69pub struct CodeSandbox {
70    config: SandboxConfig,
71}
72
73impl CodeSandbox {
74    /// Create a new sandbox with the given configuration
75    pub fn new(config: SandboxConfig) -> Result<Self> {
76        // Validate configuration
77        if config.max_execution_time_ms == 0 {
78            anyhow::bail!("max_execution_time_ms must be greater than 0");
79        }
80        if config.max_memory_mb == 0 {
81            anyhow::bail!("max_memory_mb must be greater than 0");
82        }
83
84        Ok(Self { config })
85    }
86
87    /// Execute code in the sandbox
88    ///
89    /// # Security
90    ///
91    /// This method performs multiple security checks:
92    /// 1. Validates and sanitizes input code
93    /// 2. Detects malicious patterns
94    /// 3. Enforces timeout limits
95    /// 4. Restricts resource usage
96    /// 5. Isolates execution in separate process
97    ///
98    /// # Arguments
99    ///
100    /// * `code` - TypeScript/JavaScript code to execute
101    /// * `context` - Execution context with input data
102    ///
103    /// # Returns
104    ///
105    /// Returns `ExecutionResult` containing output or error information
106    pub async fn execute(&self, code: &str, context: ExecutionContext) -> Result<ExecutionResult> {
107        let start = Instant::now();
108
109        // Security check: validate input
110        if let Some(violation) = self.detect_security_violations(code) {
111            warn!("Security violation detected: {:?}", violation);
112            return Ok(ExecutionResult::SecurityViolation {
113                reason: format!("Security violation: {:?}", violation),
114                violation_type: violation,
115            });
116        }
117
118        // Security check: validate code length (prevent DoS)
119        if code.len() > 100_000 {
120            return Ok(ExecutionResult::SecurityViolation {
121                reason: "Code exceeds maximum length (100KB)".to_string(),
122                violation_type: SecurityViolationType::MaliciousCode,
123            });
124        }
125
126        // Create wrapper code with security restrictions
127        let wrapper = self.create_secure_wrapper(code, &context)?;
128
129        // Execute with timeout and resource limits
130        let result = self.execute_isolated(wrapper, start).await?;
131
132        Ok(result)
133    }
134
135    /// Detect potential security violations in code
136    fn detect_security_violations(&self, code: &str) -> Option<SecurityViolationType> {
137        // Check for file system access attempts
138        if !self.config.allow_filesystem {
139            let fs_patterns = [
140                "require('fs')",
141                "require(\"fs\")",
142                "require(`fs`)",
143                "import fs from",
144                "import * as fs",
145                "readFile",
146                "writeFile",
147                "mkdir",
148                "rmdir",
149                "unlink",
150                "__dirname",
151                "__filename",
152            ];
153
154            for pattern in &fs_patterns {
155                if code.contains(pattern) {
156                    return Some(SecurityViolationType::FileSystemAccess);
157                }
158            }
159        }
160
161        // Check for network access attempts
162        if !self.config.allow_network {
163            let network_patterns = [
164                "require('http')",
165                "require('https')",
166                "require('net')",
167                "fetch(",
168                "XMLHttpRequest",
169                "WebSocket",
170                "import('http')",
171                "import('https')",
172            ];
173
174            for pattern in &network_patterns {
175                if code.contains(pattern) {
176                    return Some(SecurityViolationType::NetworkAccess);
177                }
178            }
179        }
180
181        // Check for subprocess execution attempts
182        if !self.config.allow_subprocesses {
183            let process_patterns = [
184                "require('child_process')",
185                "exec(",
186                "execSync(",
187                "spawn(",
188                "spawnSync(",
189                "fork(",
190                "execFile(",
191                "process.exit",
192            ];
193
194            for pattern in &process_patterns {
195                if code.contains(pattern) {
196                    return Some(SecurityViolationType::ProcessExecution);
197                }
198            }
199        }
200
201        // Check for potential infinite loops (basic heuristic)
202        let loop_count = code.matches("while(true)").count()
203            + code.matches("for(;;)").count()
204            + code.matches("while (true)").count()
205            + code.matches("for (;;)").count();
206
207        if loop_count > 0 {
208            return Some(SecurityViolationType::InfiniteLoop);
209        }
210
211        // Check for eval and Function constructor (code injection risks)
212        if code.contains("eval(") || code.contains("Function(") {
213            return Some(SecurityViolationType::MaliciousCode);
214        }
215
216        None
217    }
218
219    /// Create a secure wrapper around user code
220    fn create_secure_wrapper(&self, user_code: &str, context: &ExecutionContext) -> Result<String> {
221        let context_json =
222            serde_json::to_string(context).context("Failed to serialize execution context")?;
223
224        // Escape user code for safe inclusion in template
225        // This prevents command injection and script termination attacks
226        let escaped_code = user_code
227            .replace('\\', "\\\\") // Escape backslashes first
228            .replace('`', "\\`") // Escape template literal backticks
229            .replace("${", "\\${") // Escape template literal expressions
230            // Note: Newlines, carriage returns, and tabs are NOT escaped
231            // They work correctly in JavaScript template literals
232            .replace("\x00", "\\x00") // Escape null bytes
233            .replace("\x0b", "\\x0b") // Escape vertical tabs
234            .replace("\x0c", "\\x0c"); // Escape form feeds
235
236        // Create wrapper that:
237        // 1. Sets up restricted environment
238        // 2. Provides context to user code
239        // 3. Captures output and errors
240        // 4. Enforces timeout
241        let wrapper = format!(
242            r#"
243'use strict';
244
245// Disable dangerous globals
246delete global.process;
247delete global.require;
248delete global.module;
249delete global.__dirname;
250delete global.__filename;
251
252// Set up restricted console
253const outputs = [];
254const errors = [];
255
256const safeConsole = {{
257    log: (...args) => outputs.push(args.map(String).join(' ')),
258    error: (...args) => errors.push(args.map(String).join(' ')),
259    warn: (...args) => errors.push('WARN: ' + args.map(String).join(' ')),
260    info: (...args) => outputs.push('INFO: ' + args.map(String).join(' ')),
261}};
262
263// Execution context
264const context = {};
265
266// Main execution wrapper
267(async () => {{
268    try {{
269        // Set timeout to prevent infinite loops
270        const timeout = setTimeout(() => {{
271            throw new Error('TIMEOUT_EXCEEDED');
272        }}, {});
273
274        // User code execution
275        const userFn = async () => {{
276            const console = safeConsole;
277            {};
278        }};
279
280        const result = await userFn();
281        clearTimeout(timeout);
282
283        // Output results
284        console.log(JSON.stringify({{
285            success: true,
286            result: result,
287            stdout: outputs.join('\n'),
288            stderr: errors.join('\n'),
289        }}));
290    }} catch (error) {{
291        console.error(JSON.stringify({{
292            success: false,
293            error: error.message,
294            stack: error.stack,
295            stdout: outputs.join('\n'),
296            stderr: errors.join('\n'),
297        }}));
298        process.exit(1);
299    }}
300}})();
301"#,
302            context_json, self.config.max_execution_time_ms, escaped_code
303        );
304
305        Ok(wrapper)
306    }
307
308    /// Execute code in an isolated Node.js process
309    async fn execute_isolated(
310        &self,
311        wrapper_code: String,
312        start_time: Instant,
313    ) -> Result<ExecutionResult> {
314        // Spawn Node.js process with restricted permissions
315        let child = Command::new("node")
316            .arg("--no-warnings")
317            .arg("-e")
318            .arg(&wrapper_code)
319            .stdin(Stdio::null())
320            .stdout(Stdio::piped())
321            .stderr(Stdio::piped())
322            .kill_on_drop(true) // Ensure cleanup
323            .spawn()
324            .context("Failed to spawn Node.js process")?;
325
326        // Wait for completion with timeout
327        let timeout = Duration::from_millis(self.config.max_execution_time_ms);
328        let output = match tokio::time::timeout(timeout, child.wait_with_output()).await {
329            Ok(Ok(output)) => output,
330            Ok(Err(e)) => {
331                warn!("Process execution failed: {}", e);
332                return Ok(ExecutionResult::Error {
333                    message: format!("Process execution failed: {}", e),
334                    error_type: ErrorType::Runtime,
335                    stdout: String::new(),
336                    stderr: String::new(),
337                });
338            }
339            Err(_) => {
340                // Timeout occurred - process will be killed by kill_on_drop
341                return Ok(ExecutionResult::Timeout {
342                    elapsed_ms: start_time.elapsed().as_millis() as u64,
343                    partial_output: None,
344                });
345            }
346        };
347
348        let elapsed_ms = start_time.elapsed().as_millis() as u64;
349
350        // Parse output
351        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
352        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
353
354        debug!(
355            "Execution completed in {}ms, status: {}",
356            elapsed_ms,
357            output.status.code().unwrap_or(-1)
358        );
359
360        // Check if execution was successful
361        if output.status.success() {
362            // Try to parse structured output
363            if let Some(result_line) = stdout.lines().last() {
364                if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(result_line) {
365                    if let Some(true) = parsed.get("success").and_then(|v| v.as_bool()) {
366                        return Ok(ExecutionResult::Success {
367                            output: parsed
368                                .get("result")
369                                .map(|v| v.to_string())
370                                .unwrap_or_default(),
371                            stdout: parsed
372                                .get("stdout")
373                                .and_then(|v| v.as_str())
374                                .unwrap_or("")
375                                .to_string(),
376                            stderr: parsed
377                                .get("stderr")
378                                .and_then(|v| v.as_str())
379                                .unwrap_or("")
380                                .to_string(),
381                            execution_time_ms: elapsed_ms,
382                        });
383                    }
384                }
385            }
386
387            // Fallback to raw stdout
388            Ok(ExecutionResult::Success {
389                output: stdout.clone(),
390                stdout,
391                stderr,
392                execution_time_ms: elapsed_ms,
393            })
394        } else {
395            // Execution failed - parse error
396            let error_type = if stderr.contains("SyntaxError") {
397                ErrorType::Syntax
398            } else if stderr.contains("TIMEOUT_EXCEEDED") {
399                return Ok(ExecutionResult::Timeout {
400                    elapsed_ms,
401                    partial_output: Some(stdout),
402                });
403            } else if stderr.contains("EACCES") || stderr.contains("EPERM") {
404                ErrorType::Permission
405            } else {
406                ErrorType::Runtime
407            };
408
409            // Try to parse structured error
410            if let Some(error_line) = stderr.lines().last() {
411                if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(error_line) {
412                    if let Some(error_msg) = parsed.get("error").and_then(|v| v.as_str()) {
413                        return Ok(ExecutionResult::Error {
414                            message: error_msg.to_string(),
415                            error_type,
416                            stdout: parsed
417                                .get("stdout")
418                                .and_then(|v| v.as_str())
419                                .unwrap_or("")
420                                .to_string(),
421                            stderr: parsed
422                                .get("stderr")
423                                .and_then(|v| v.as_str())
424                                .unwrap_or("")
425                                .to_string(),
426                        });
427                    }
428                }
429            }
430
431            Ok(ExecutionResult::Error {
432                message: stderr.clone(),
433                error_type,
434                stdout,
435                stderr,
436            })
437        }
438    }
439}