Skip to main content

chio_guards/
code_execution.rs

1//! CodeExecutionGuard -- language allowlist, dangerous-module detection,
2//! network gating, and execution-time bounds for sandboxed interpreter
3//! actions.
4//!
5//! Roadmap phase 8.1.  The guard applies to
6//! [`ToolAction::CodeExecution`] derived from tool calls like `python`,
7//! `eval`, `run_code`, `jupyter`, etc.  See [`crate::action::extract_action`]
8//! for the full list of tool names that map to code execution.
9//!
10//! # Enforcement surface
11//!
12//! | Policy                    | Behavior                                                   |
13//! |---------------------------|------------------------------------------------------------|
14//! | `language_allowlist`      | Languages outside the set are denied                       |
//! | `module_denylist`         | Imports/uses of named modules (e.g. `subprocess`) are denied |
16//! | `network_access`          | When `false`, calls requesting network are denied          |
17//! | `max_execution_time_ms`   | When the arguments exceed this bound, the call is denied   |
18//!
19//! Network access is considered requested when either:
20//!
//! - the arguments carry `network_access = true` / `allow_network = true`
//!   (or the camelCase `networkAccess` / `allowNetwork` spellings);
//! - or the code contains an obvious network module import (e.g.
//!   `socket`, `requests`, `urllib`, `http`, `httpx`, `aiohttp`) or a
//!   `fetch(` call.
24//!
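//! The module-detection regexes target Python and JavaScript source and
//! recognize the common `import X` / `from X import` / `require('X')`
//! forms (plus `X.attr` references for denylisted modules).  Detection is
//! heuristic and intentionally one-directional: a regex match is a
//! *denial signal*, never a permit signal, and the absence of a match
//! does not prove the code is safe.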
29//!
30//! # Fail-closed behavior
31//!
32//! - [`ToolAction::CodeExecution`] with no `language` value is denied when
33//!   a [`CodeExecutionConfig::language_allowlist`] is set;
//! - a [`CodeExecutionConfig::module_denylist`] entry whose generated
//!   detection regex fails to compile causes
//!   [`CodeExecutionGuard::with_config`] to return
//!   [`CodeExecutionError::InvalidPattern`] (entries are escaped and
//!   matched literally, so ordinary module names always compile);
38//! - non-code-execution actions pass through with [`Verdict::Allow`].
39
40use std::collections::HashSet;
41use std::sync::OnceLock;
42
43use regex::Regex;
44use serde::{Deserialize, Serialize};
45
46use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
47
48use crate::action::{extract_action, ToolAction};
49
50/// Default dangerous module names (Python-focused; matches are case
51/// sensitive and use word boundaries).
52pub fn default_dangerous_modules() -> Vec<String> {
53    vec![
54        "os".to_string(),
55        "subprocess".to_string(),
56        "socket".to_string(),
57        "sys".to_string(),
58        "ctypes".to_string(),
59        "shutil".to_string(),
60        "pickle".to_string(),
61        "marshal".to_string(),
62        "importlib".to_string(),
63    ]
64}
65
66/// Default network-module names that signal a code body wants network
67/// access.  Used by the `network_access` gate when arguments do not carry
68/// an explicit flag.
69fn default_network_modules() -> &'static [&'static str] {
70    &[
71        "socket",
72        "requests",
73        "urllib",
74        "urllib2",
75        "urllib3",
76        "http",
77        "httpx",
78        "aiohttp",
79        "websockets",
80        "ftplib",
81        "smtplib",
82        "telnetlib",
83    ]
84}
85
86/// Errors produced when building a [`CodeExecutionGuard`] or parsing its
87/// configuration.
88#[derive(Debug, thiserror::Error)]
89pub enum CodeExecutionError {
90    /// A denylist entry was not a valid regex literal.
91    #[error("invalid module pattern `{pattern}`: {source}")]
92    InvalidPattern {
93        pattern: String,
94        #[source]
95        source: regex::Error,
96    },
97}
98
99/// Configuration for [`CodeExecutionGuard`].
100#[derive(Clone, Debug, Deserialize, Serialize)]
101#[serde(deny_unknown_fields)]
102pub struct CodeExecutionConfig {
103    /// Enable/disable the guard entirely.
104    #[serde(default = "default_true")]
105    pub enabled: bool,
106    /// Allowed interpreter languages.  Empty means "any language".
107    #[serde(default)]
108    pub language_allowlist: Vec<String>,
109    /// Dangerous module names (used as word-boundary literal matches
110    /// against the code body).  Defaults to
111    /// [`default_dangerous_modules`].
112    #[serde(default = "default_dangerous_modules")]
113    pub module_denylist: Vec<String>,
114    /// When `false`, deny code-execution calls that request network
115    /// access (either via argument flag or a detectable network import).
116    #[serde(default = "default_true")]
117    pub network_access: bool,
118    /// Maximum execution time in milliseconds.  When set, any call with
119    /// an `execution_time_ms` / `timeout_ms` argument above this value
120    /// is denied.  `None` disables the check.
121    #[serde(default, skip_serializing_if = "Option::is_none")]
122    pub max_execution_time_ms: Option<u64>,
123    /// Maximum bytes of code to scan for module detection.  Longer code
124    /// bodies are truncated at a UTF-8 boundary before scanning.
125    #[serde(default = "default_max_scan_bytes")]
126    pub max_scan_bytes: usize,
127}
128
129fn default_true() -> bool {
130    true
131}
132
133fn default_max_scan_bytes() -> usize {
134    64 * 1024
135}
136
137impl Default for CodeExecutionConfig {
138    fn default() -> Self {
139        Self {
140            enabled: true,
141            language_allowlist: vec!["python".to_string()],
142            module_denylist: default_dangerous_modules(),
143            network_access: false,
144            max_execution_time_ms: None,
145            max_scan_bytes: default_max_scan_bytes(),
146        }
147    }
148}
149
150/// Guard that enforces [`CodeExecutionConfig`] policies against
151/// [`ToolAction::CodeExecution`] calls.
152pub struct CodeExecutionGuard {
153    enabled: bool,
154    language_allowlist: HashSet<String>,
155    module_patterns: Vec<(String, Regex)>,
156    network_access: bool,
157    max_execution_time_ms: Option<u64>,
158    max_scan_bytes: usize,
159}
160
161impl CodeExecutionGuard {
162    /// Build a guard with default configuration.  Never fails because the
163    /// default patterns are known-valid regex fragments.
164    pub fn new() -> Self {
165        match Self::with_config(CodeExecutionConfig::default()) {
166            Ok(g) => g,
167            Err(_) => Self::empty_failclosed(),
168        }
169    }
170
171    /// Build an empty guard that denies every code-execution call.  Used
172    /// as a fallback when the default configuration somehow fails to
173    /// compile (defensive programming; should never trigger).
174    fn empty_failclosed() -> Self {
175        Self {
176            enabled: true,
177            language_allowlist: HashSet::new(),
178            module_patterns: Vec::new(),
179            network_access: false,
180            max_execution_time_ms: Some(0),
181            max_scan_bytes: default_max_scan_bytes(),
182        }
183    }
184
185    /// Build a guard with explicit configuration.  Returns an error when
186    /// any entry in `module_denylist` is not a valid literal identifier
187    /// (we build word-boundary regexes from the literal).
188    pub fn with_config(config: CodeExecutionConfig) -> Result<Self, CodeExecutionError> {
189        let mut module_patterns = Vec::with_capacity(config.module_denylist.len());
190        for module in &config.module_denylist {
191            let pattern = module_regex_source(module);
192            let re = Regex::new(&pattern).map_err(|e| CodeExecutionError::InvalidPattern {
193                pattern: module.clone(),
194                source: e,
195            })?;
196            module_patterns.push((module.clone(), re));
197        }
198        let language_allowlist: HashSet<String> = config
199            .language_allowlist
200            .into_iter()
201            .map(|s| s.to_ascii_lowercase())
202            .collect();
203        Ok(Self {
204            enabled: config.enabled,
205            language_allowlist,
206            module_patterns,
207            network_access: config.network_access,
208            max_execution_time_ms: config.max_execution_time_ms,
209            max_scan_bytes: config.max_scan_bytes.max(1),
210        })
211    }
212
213    /// Read the execution-time ceiling from the arguments.  Accepts
214    /// `execution_time_ms`, `timeout_ms`, `max_execution_time_ms`.
215    fn read_execution_time_ms(arguments: &serde_json::Value) -> Option<u64> {
216        for key in [
217            "execution_time_ms",
218            "executionTimeMs",
219            "timeout_ms",
220            "timeoutMs",
221            "max_execution_time_ms",
222            "maxExecutionTimeMs",
223        ] {
224            if let Some(v) = arguments.get(key).and_then(|v| v.as_u64()) {
225                return Some(v);
226            }
227        }
228        None
229    }
230
231    /// Read an explicit network-access flag from the arguments, if present.
232    fn requested_network_access(arguments: &serde_json::Value) -> Option<bool> {
233        for key in [
234            "network_access",
235            "networkAccess",
236            "allow_network",
237            "allowNetwork",
238        ] {
239            if let Some(v) = arguments.get(key).and_then(|v| v.as_bool()) {
240                return Some(v);
241            }
242        }
243        None
244    }
245
246    /// Return `true` if `code` appears to import or call into a
247    /// network-capable module.
248    fn code_uses_network(code: &str) -> bool {
249        let net_re = network_module_regex();
250        net_re.is_match(code)
251    }
252}
253
254impl Default for CodeExecutionGuard {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
260impl Guard for CodeExecutionGuard {
261    fn name(&self) -> &str {
262        "code-execution"
263    }
264
265    fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
266        if !self.enabled {
267            return Ok(Verdict::Allow);
268        }
269
270        let action = extract_action(&ctx.request.tool_name, &ctx.request.arguments);
271        let (language, code) = match action {
272            ToolAction::CodeExecution { language, code } => (language, code),
273            _ => return Ok(Verdict::Allow),
274        };
275
276        // 1. Language allowlist.
277        if !self.language_allowlist.is_empty() {
278            let lang = language.to_ascii_lowercase();
279            if lang == "unknown" || !self.language_allowlist.contains(&lang) {
280                return Ok(Verdict::Deny);
281            }
282        }
283
284        // Bound scan size for module detection.
285        let truncated = if code.len() > self.max_scan_bytes {
286            let mut end = self.max_scan_bytes;
287            while end > 0 && !code.is_char_boundary(end) {
288                end -= 1;
289            }
290            &code[..end]
291        } else {
292            code.as_str()
293        };
294
295        // 2. Dangerous-module detection.
296        for (name, re) in &self.module_patterns {
297            if re.is_match(truncated) {
298                tracing::warn!(
299                    guard = "code-execution",
300                    module = %name,
301                    "denying code execution: dangerous module detected"
302                );
303                return Ok(Verdict::Deny);
304            }
305        }
306
307        // 3. Network access gate.
308        if !self.network_access {
309            let requested = Self::requested_network_access(&ctx.request.arguments).unwrap_or(false);
310            if requested || Self::code_uses_network(truncated) {
311                return Ok(Verdict::Deny);
312            }
313        }
314
315        // 4. Execution-time bound.
316        if let Some(max_ms) = self.max_execution_time_ms {
317            if let Some(requested) = Self::read_execution_time_ms(&ctx.request.arguments) {
318                if requested > max_ms {
319                    return Ok(Verdict::Deny);
320                }
321            }
322        }
323
324        Ok(Verdict::Allow)
325    }
326}
327
328/// Build a regex that matches `import <module>`, `from <module> import`,
329/// `require('<module>')`, or a bare `<module>.something` reference in
330/// code.  The source is escaped so dotted module names are treated as
331/// literals.
332fn module_regex_source(module: &str) -> String {
333    let escaped = regex::escape(module);
334    // Word-boundary anchors handle the `import subprocess`,
335    // `from subprocess`, and `subprocess.call` forms; a trailing alternation
336    // picks up `require("subprocess")` and `require('subprocess')`.
337    format!(
338        r#"(?m)(?:^|[^A-Za-z0-9_])(?:import\s+{m}(?:\s|$|\.|,)|from\s+{m}(?:\s|\.)|require\s*\(\s*['"]{m}['"]\s*\)|{m}\s*\.)"#,
339        m = escaped
340    )
341}
342
343/// Compiled once per process: detects calls/imports of the well-known
344/// network modules listed in [`default_network_modules`].
345fn network_module_regex() -> &'static Regex {
346    static RE: OnceLock<Regex> = OnceLock::new();
347    RE.get_or_init(|| {
348        let alternation = default_network_modules()
349            .iter()
350            .map(|m| regex::escape(m))
351            .collect::<Vec<_>>()
352            .join("|");
353        // Fall back to a never-matching regex rather than panicking.
354        match Regex::new(&format!(
355            r#"(?m)(?:^|[^A-Za-z0-9_])(?:import\s+(?:{a})(?:\s|$|\.|,)|from\s+(?:{a})(?:\s|\.)|require\s*\(\s*['"](?:{a})['"]\s*\)|\bfetch\s*\()"#,
356            a = alternation
357        )) {
358            Ok(re) => re,
359            Err(err) => {
360                tracing::error!(error = %err, "code-execution: failed to compile network regex");
361                // Safe fallback: regex that never matches anything.
362                #[allow(clippy::expect_used)]
363                {
364                    Regex::new(r"\A\z").expect("empty-string regex compiles")
365                }
366            }
367        }
368    })
369}
370
371#[cfg(test)]
372mod tests {
373    use super::*;
374
375    #[test]
376    fn module_regex_matches_import_forms() {
377        let re = Regex::new(&module_regex_source("subprocess")).unwrap();
378        assert!(re.is_match("import subprocess\n"));
379        assert!(re.is_match("from subprocess import call"));
380        assert!(re.is_match("require('subprocess')"));
381        assert!(re.is_match("subprocess.run(['ls'])"));
382        assert!(!re.is_match("import subprocesses\n"));
383        assert!(!re.is_match("# subprocess comment with no code"));
384    }
385
386    #[test]
387    fn network_module_regex_detects_requests() {
388        let re = network_module_regex();
389        assert!(re.is_match("import requests\n"));
390        assert!(re.is_match("from urllib import parse"));
391        assert!(re.is_match("fetch('https://x')"));
392        assert!(!re.is_match("import math"));
393    }
394}