1use std::collections::HashSet;
41use std::sync::OnceLock;
42
43use regex::Regex;
44use serde::{Deserialize, Serialize};
45
46use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
47
48use crate::action::{extract_action, ToolAction};
49
/// Returns the built-in module denylist.
///
/// This list is used both as the serde default for
/// `CodeExecutionConfig::module_denylist` and by `CodeExecutionConfig::default`.
/// A fresh, owned `Vec` is produced on every call.
pub fn default_dangerous_modules() -> Vec<String> {
    const DANGEROUS_MODULES: [&str; 9] = [
        "os",
        "subprocess",
        "socket",
        "sys",
        "ctypes",
        "shutil",
        "pickle",
        "marshal",
        "importlib",
    ];
    DANGEROUS_MODULES
        .iter()
        .map(|&name| name.to_owned())
        .collect()
}
65
/// Module names whose import is treated as "this code uses the network".
///
/// The names are joined into a single alternation by `network_module_regex`.
fn default_network_modules() -> &'static [&'static str] {
    const NETWORK_MODULES: &[&str] = &[
        "socket",
        "requests",
        "urllib",
        "urllib2",
        "urllib3",
        "http",
        "httpx",
        "aiohttp",
        "websockets",
        "ftplib",
        "smtplib",
        "telnetlib",
    ];
    NETWORK_MODULES
}
85
/// Errors produced while building a `CodeExecutionGuard`.
#[derive(Debug, thiserror::Error)]
pub enum CodeExecutionError {
    /// A `module_denylist` entry could not be compiled into a regex.
    ///
    /// `pattern` holds the denylist entry as configured (not the generated
    /// regex source); `source` is the underlying compile error.
    #[error("invalid module pattern `{pattern}`: {source}")]
    InvalidPattern {
        pattern: String,
        #[source]
        source: regex::Error,
    },
}
98
/// User-facing configuration for `CodeExecutionGuard`.
///
/// Unknown keys are rejected on deserialization (`deny_unknown_fields`).
///
/// NOTE(review): the serde defaults for omitted fields are more permissive
/// than `CodeExecutionConfig::default()`: an omitted `language_allowlist`
/// deserializes to an empty list and an omitted `network_access` to `true`,
/// whereas `Default` yields `["python"]` and `false`. Confirm this asymmetry
/// is intended.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct CodeExecutionConfig {
    /// Master switch; when `false` the guard allows every request.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Languages permitted to run (compared case-insensitively via ASCII
    /// lowercasing). An empty list disables the language check entirely; a
    /// non-empty list also rejects the language `"unknown"`.
    #[serde(default)]
    pub language_allowlist: Vec<String>,
    /// Module names whose import or use causes a denial. Each entry is
    /// regex-escaped and compiled when the guard is built.
    #[serde(default = "default_dangerous_modules")]
    pub module_denylist: Vec<String>,
    /// When `false`, requests that ask for network access, or whose code
    /// matches the built-in network-module pattern, are denied.
    #[serde(default = "default_true")]
    pub network_access: bool,
    /// Upper bound, in milliseconds, for the execution time a request may
    /// ask for. `None` disables the check; requests that state no time are
    /// not checked either.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_execution_time_ms: Option<u64>,
    /// Number of leading bytes of the submitted code that are scanned.
    /// Values below 1 are raised to 1 when the guard is built.
    #[serde(default = "default_max_scan_bytes")]
    pub max_scan_bytes: usize,
}
128
/// Serde default helper for boolean fields that should default to `true`
/// when omitted from the serialized config.
fn default_true() -> bool {
    true
}
132
/// Default scan window for submitted code: 64 KiB.
fn default_max_scan_bytes() -> usize {
    const KIB: usize = 1024;
    64 * KIB
}
136
137impl Default for CodeExecutionConfig {
138 fn default() -> Self {
139 Self {
140 enabled: true,
141 language_allowlist: vec!["python".to_string()],
142 module_denylist: default_dangerous_modules(),
143 network_access: false,
144 max_execution_time_ms: None,
145 max_scan_bytes: default_max_scan_bytes(),
146 }
147 }
148}
149
/// Guard that inspects code-execution tool calls before they run.
///
/// This is the compiled form of `CodeExecutionConfig`: the language list is
/// lowercased into a set and every denylisted module has its regex built once
/// up front.
pub struct CodeExecutionGuard {
    /// When `false`, `evaluate` allows everything.
    enabled: bool,
    /// Lowercased permitted languages; empty means "no language restriction".
    language_allowlist: HashSet<String>,
    /// `(module name, compiled detection regex)` pairs, in config order.
    module_patterns: Vec<(String, Regex)>,
    /// Whether executed code may use the network.
    network_access: bool,
    /// Optional cap, in milliseconds, on the execution time a request may ask for.
    max_execution_time_ms: Option<u64>,
    /// Number of leading bytes of code that are scanned.
    max_scan_bytes: usize,
}
160
161impl CodeExecutionGuard {
162 pub fn new() -> Self {
165 match Self::with_config(CodeExecutionConfig::default()) {
166 Ok(g) => g,
167 Err(_) => Self::empty_failclosed(),
168 }
169 }
170
171 fn empty_failclosed() -> Self {
175 Self {
176 enabled: true,
177 language_allowlist: HashSet::new(),
178 module_patterns: Vec::new(),
179 network_access: false,
180 max_execution_time_ms: Some(0),
181 max_scan_bytes: default_max_scan_bytes(),
182 }
183 }
184
185 pub fn with_config(config: CodeExecutionConfig) -> Result<Self, CodeExecutionError> {
189 let mut module_patterns = Vec::with_capacity(config.module_denylist.len());
190 for module in &config.module_denylist {
191 let pattern = module_regex_source(module);
192 let re = Regex::new(&pattern).map_err(|e| CodeExecutionError::InvalidPattern {
193 pattern: module.clone(),
194 source: e,
195 })?;
196 module_patterns.push((module.clone(), re));
197 }
198 let language_allowlist: HashSet<String> = config
199 .language_allowlist
200 .into_iter()
201 .map(|s| s.to_ascii_lowercase())
202 .collect();
203 Ok(Self {
204 enabled: config.enabled,
205 language_allowlist,
206 module_patterns,
207 network_access: config.network_access,
208 max_execution_time_ms: config.max_execution_time_ms,
209 max_scan_bytes: config.max_scan_bytes.max(1),
210 })
211 }
212
213 fn read_execution_time_ms(arguments: &serde_json::Value) -> Option<u64> {
216 for key in [
217 "execution_time_ms",
218 "executionTimeMs",
219 "timeout_ms",
220 "timeoutMs",
221 "max_execution_time_ms",
222 "maxExecutionTimeMs",
223 ] {
224 if let Some(v) = arguments.get(key).and_then(|v| v.as_u64()) {
225 return Some(v);
226 }
227 }
228 None
229 }
230
231 fn requested_network_access(arguments: &serde_json::Value) -> Option<bool> {
233 for key in [
234 "network_access",
235 "networkAccess",
236 "allow_network",
237 "allowNetwork",
238 ] {
239 if let Some(v) = arguments.get(key).and_then(|v| v.as_bool()) {
240 return Some(v);
241 }
242 }
243 None
244 }
245
246 fn code_uses_network(code: &str) -> bool {
249 let net_re = network_module_regex();
250 net_re.is_match(code)
251 }
252}
253
254impl Default for CodeExecutionGuard {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
260impl Guard for CodeExecutionGuard {
261 fn name(&self) -> &str {
262 "code-execution"
263 }
264
265 fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
266 if !self.enabled {
267 return Ok(Verdict::Allow);
268 }
269
270 let action = extract_action(&ctx.request.tool_name, &ctx.request.arguments);
271 let (language, code) = match action {
272 ToolAction::CodeExecution { language, code } => (language, code),
273 _ => return Ok(Verdict::Allow),
274 };
275
276 if !self.language_allowlist.is_empty() {
278 let lang = language.to_ascii_lowercase();
279 if lang == "unknown" || !self.language_allowlist.contains(&lang) {
280 return Ok(Verdict::Deny);
281 }
282 }
283
284 let truncated = if code.len() > self.max_scan_bytes {
286 let mut end = self.max_scan_bytes;
287 while end > 0 && !code.is_char_boundary(end) {
288 end -= 1;
289 }
290 &code[..end]
291 } else {
292 code.as_str()
293 };
294
295 for (name, re) in &self.module_patterns {
297 if re.is_match(truncated) {
298 tracing::warn!(
299 guard = "code-execution",
300 module = %name,
301 "denying code execution: dangerous module detected"
302 );
303 return Ok(Verdict::Deny);
304 }
305 }
306
307 if !self.network_access {
309 let requested = Self::requested_network_access(&ctx.request.arguments).unwrap_or(false);
310 if requested || Self::code_uses_network(truncated) {
311 return Ok(Verdict::Deny);
312 }
313 }
314
315 if let Some(max_ms) = self.max_execution_time_ms {
317 if let Some(requested) = Self::read_execution_time_ms(&ctx.request.arguments) {
318 if requested > max_ms {
319 return Ok(Verdict::Deny);
320 }
321 }
322 }
323
324 Ok(Verdict::Allow)
325 }
326}
327
328fn module_regex_source(module: &str) -> String {
333 let escaped = regex::escape(module);
334 format!(
338 r#"(?m)(?:^|[^A-Za-z0-9_])(?:import\s+{m}(?:\s|$|\.|,)|from\s+{m}(?:\s|\.)|require\s*\(\s*['"]{m}['"]\s*\)|{m}\s*\.)"#,
339 m = escaped
340 )
341}
342
343fn network_module_regex() -> &'static Regex {
346 static RE: OnceLock<Regex> = OnceLock::new();
347 RE.get_or_init(|| {
348 let alternation = default_network_modules()
349 .iter()
350 .map(|m| regex::escape(m))
351 .collect::<Vec<_>>()
352 .join("|");
353 match Regex::new(&format!(
355 r#"(?m)(?:^|[^A-Za-z0-9_])(?:import\s+(?:{a})(?:\s|$|\.|,)|from\s+(?:{a})(?:\s|\.)|require\s*\(\s*['"](?:{a})['"]\s*\)|\bfetch\s*\()"#,
356 a = alternation
357 )) {
358 Ok(re) => re,
359 Err(err) => {
360 tracing::error!(error = %err, "code-execution: failed to compile network regex");
361 #[allow(clippy::expect_used)]
363 {
364 Regex::new(r"\A\z").expect("empty-string regex compiles")
365 }
366 }
367 }
368 })
369}
370
// Unit tests for the two detection patterns.
#[cfg(test)]
mod tests {
    use super::*;

    // The per-module pattern must catch every supported import/use form
    // without firing on longer identifiers or on bare mentions in comments.
    #[test]
    fn module_regex_matches_import_forms() {
        let re = Regex::new(&module_regex_source("subprocess")).unwrap();
        assert!(re.is_match("import subprocess\n"));
        assert!(re.is_match("from subprocess import call"));
        assert!(re.is_match("require('subprocess')"));
        assert!(re.is_match("subprocess.run(['ls'])"));
        // A longer identifier sharing the prefix must not match.
        assert!(!re.is_match("import subprocesses\n"));
        // A plain mention without import or attribute access must not match.
        assert!(!re.is_match("# subprocess comment with no code"));
    }

    // The shared network pattern covers imports of listed modules and
    // `fetch(` calls, and ignores unrelated imports.
    #[test]
    fn network_module_regex_detects_requests() {
        let re = network_module_regex();
        assert!(re.is_match("import requests\n"));
        assert!(re.is_match("from urllib import parse"));
        assert!(re.is_match("fetch('https://x')"));
        assert!(!re.is_match("import math"));
    }
}