use std::collections::HashSet;
use std::sync::OnceLock;
use regex::Regex;
use serde::{Deserialize, Serialize};
use chio_kernel::{Guard, GuardContext, KernelError, Verdict};
use crate::action::{extract_action, ToolAction};
/// Returns the module names denied by default.
///
/// The first nine entries are the original denylist. `posix` and `nt` are appended because
/// they are the platform modules that Python's `os` is implemented on top of and expose the
/// same process-spawning calls (e.g. `posix.system`). `pty` is appended because `pty.spawn`
/// launches processes. Without these, code could bypass the `os`/`subprocess` entries.
pub fn default_dangerous_modules() -> Vec<String> {
    [
        "os",
        "subprocess",
        "socket",
        "sys",
        "ctypes",
        "shutil",
        "pickle",
        "marshal",
        "importlib",
        "posix",
        "nt",
        "pty",
    ]
    .iter()
    .copied()
    .map(String::from)
    .collect()
}
/// Module names whose import is treated as network usage when network access is disabled.
fn default_network_modules() -> &'static [&'static str] {
    const NETWORK_MODULES: &[&str] = &[
        "socket",
        "requests",
        "urllib",
        "urllib2",
        "urllib3",
        "http",
        "httpx",
        "aiohttp",
        "websockets",
        "ftplib",
        "smtplib",
        "telnetlib",
    ];
    NETWORK_MODULES
}
/// Errors raised while building a `CodeExecutionGuard` from a `CodeExecutionConfig`.
#[derive(Debug, thiserror::Error)]
pub enum CodeExecutionError {
    /// A `module_denylist` entry produced a detector regex that failed to compile.
    #[error("invalid module pattern `{pattern}`: {source}")]
    InvalidPattern {
        /// The denylist entry as written in the configuration (not the generated regex).
        pattern: String,
        /// The underlying regex compilation error.
        #[source]
        source: regex::Error,
    },
}
/// Serializable configuration for the code-execution guard.
///
/// Unknown fields are rejected during deserialization.
///
/// NOTE(review): the serde defaults used for omitted fields (`network_access = true`,
/// empty `language_allowlist`) are more permissive than `CodeExecutionConfig::default()`
/// (network off, `["python"]` only). A deserialized `{}` is therefore not equal to
/// `Default::default()` — confirm this asymmetry is intended.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct CodeExecutionConfig {
    /// Master switch; when `false` the guard allows every request. Serde default: `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Languages permitted for code execution, compared case-insensitively.
    /// An empty list allows any language. Serde default: empty.
    #[serde(default)]
    pub language_allowlist: Vec<String>,
    /// Module names whose use causes a deny. Serde default: `default_dangerous_modules()`.
    #[serde(default = "default_dangerous_modules")]
    pub module_denylist: Vec<String>,
    /// Whether executed code may use the network. Serde default: `true`.
    #[serde(default = "default_true")]
    pub network_access: bool,
    /// Upper bound on the execution time a request may ask for, in milliseconds.
    /// `None` disables the check; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_execution_time_ms: Option<u64>,
    /// Scan budget in bytes for each code snippet. Serde default: `default_max_scan_bytes()`.
    /// Values below 1 are raised to 1 when the guard is built.
    #[serde(default = "default_max_scan_bytes")]
    pub max_scan_bytes: usize,
}
/// Serde default for boolean fields that should be switched on when omitted.
fn default_true() -> bool {
    const ON: bool = true;
    ON
}
/// Default per-snippet scan budget: 64 KiB.
fn default_max_scan_bytes() -> usize {
    const SCAN_BUDGET_BYTES: usize = 1 << 16;
    SCAN_BUDGET_BYTES
}
impl Default for CodeExecutionConfig {
    /// Conservative programmatic defaults. The guard is enabled and only Python is allowed.
    /// The built-in dangerous-module denylist applies, network access is off, there is no
    /// execution-time cap, and the scan budget is `default_max_scan_bytes()`.
    ///
    /// NOTE(review): serde fills omitted fields differently (network on, any language),
    /// so a deserialized empty config is more permissive than this — confirm intent.
    fn default() -> Self {
        let language_allowlist = vec![String::from("python")];
        Self {
            module_denylist: default_dangerous_modules(),
            language_allowlist,
            max_scan_bytes: default_max_scan_bytes(),
            network_access: false,
            max_execution_time_ms: None,
            enabled: true,
        }
    }
}
/// Guard that inspects code-execution tool calls and denies risky ones.
///
/// Built from a `CodeExecutionConfig`, with one pre-compiled detector regex per
/// denylisted module.
pub struct CodeExecutionGuard {
    /// Master switch; a disabled guard allows every request.
    enabled: bool,
    /// Per-snippet scan budget in bytes.
    max_scan_bytes: usize,
    /// Lower-cased language names that may run; empty means any language.
    language_allowlist: HashSet<String>,
    /// `(module name, compiled detector)` pairs checked against each snippet.
    module_patterns: Vec<(String, Regex)>,
    /// Whether snippets may use the network.
    network_access: bool,
    /// Cap on the execution time a request may ask for; `None` disables the check.
    max_execution_time_ms: Option<u64>,
}
impl CodeExecutionGuard {
pub fn new() -> Self {
match Self::with_config(CodeExecutionConfig::default()) {
Ok(g) => g,
Err(_) => Self::empty_failclosed(),
}
}
fn empty_failclosed() -> Self {
Self {
enabled: true,
language_allowlist: HashSet::new(),
module_patterns: Vec::new(),
network_access: false,
max_execution_time_ms: Some(0),
max_scan_bytes: default_max_scan_bytes(),
}
}
pub fn with_config(config: CodeExecutionConfig) -> Result<Self, CodeExecutionError> {
let mut module_patterns = Vec::with_capacity(config.module_denylist.len());
for module in &config.module_denylist {
let pattern = module_regex_source(module);
let re = Regex::new(&pattern).map_err(|e| CodeExecutionError::InvalidPattern {
pattern: module.clone(),
source: e,
})?;
module_patterns.push((module.clone(), re));
}
let language_allowlist: HashSet<String> = config
.language_allowlist
.into_iter()
.map(|s| s.to_ascii_lowercase())
.collect();
Ok(Self {
enabled: config.enabled,
language_allowlist,
module_patterns,
network_access: config.network_access,
max_execution_time_ms: config.max_execution_time_ms,
max_scan_bytes: config.max_scan_bytes.max(1),
})
}
fn read_execution_time_ms(arguments: &serde_json::Value) -> Option<u64> {
for key in [
"execution_time_ms",
"executionTimeMs",
"timeout_ms",
"timeoutMs",
"max_execution_time_ms",
"maxExecutionTimeMs",
] {
if let Some(v) = arguments.get(key).and_then(|v| v.as_u64()) {
return Some(v);
}
}
None
}
fn requested_network_access(arguments: &serde_json::Value) -> Option<bool> {
for key in [
"network_access",
"networkAccess",
"allow_network",
"allowNetwork",
] {
if let Some(v) = arguments.get(key).and_then(|v| v.as_bool()) {
return Some(v);
}
}
None
}
fn code_uses_network(code: &str) -> bool {
let net_re = network_module_regex();
net_re.is_match(code)
}
}
impl Default for CodeExecutionGuard {
fn default() -> Self {
Self::new()
}
}
impl Guard for CodeExecutionGuard {
    /// Stable identifier for this guard.
    fn name(&self) -> &str {
        "code-execution"
    }

    /// Decides whether a code-execution tool call may proceed.
    ///
    /// When the guard is disabled, or the action is not code execution, the call is
    /// allowed. Otherwise the checks below run in order, and the first failure denies:
    /// 1. When an allowlist is configured, the language must be known and on it.
    /// 2. The snippet must fit in `max_scan_bytes`. Bytes beyond the budget would never be
    ///    inspected, so oversized snippets are denied rather than partially scanned.
    /// 3. No denylisted module may be referenced.
    /// 4. When network access is disabled, the request must not ask for it, and the code
    ///    must not match the network detector.
    /// 5. Any requested execution time must not exceed `max_execution_time_ms`.
    fn evaluate(&self, ctx: &GuardContext) -> Result<Verdict, KernelError> {
        if !self.enabled {
            return Ok(Verdict::Allow);
        }
        let arguments = &ctx.request.arguments;
        let ToolAction::CodeExecution { language, code } =
            extract_action(&ctx.request.tool_name, arguments)
        else {
            return Ok(Verdict::Allow);
        };

        if !self.language_allowlist.is_empty() {
            let lang = language.to_ascii_lowercase();
            if lang == "unknown" || !self.language_allowlist.contains(&lang) {
                return Ok(Verdict::Deny);
            }
        }

        // Scanning only a prefix would let a payload hide after padding, so anything
        // larger than the scan budget is rejected outright (fail closed).
        if code.len() > self.max_scan_bytes {
            tracing::warn!(
                guard = "code-execution",
                code_bytes = code.len(),
                max_scan_bytes = self.max_scan_bytes,
                "denying code execution: code exceeds scan limit"
            );
            return Ok(Verdict::Deny);
        }
        let source = code.as_str();

        for (name, re) in &self.module_patterns {
            if re.is_match(source) {
                tracing::warn!(
                    guard = "code-execution",
                    module = %name,
                    "denying code execution: dangerous module detected"
                );
                return Ok(Verdict::Deny);
            }
        }

        if !self.network_access {
            let requested = Self::requested_network_access(arguments).unwrap_or(false);
            if requested || Self::code_uses_network(source) {
                return Ok(Verdict::Deny);
            }
        }

        if let Some(max_ms) = self.max_execution_time_ms {
            if let Some(requested_ms) = Self::read_execution_time_ms(arguments) {
                if requested_ms > max_ms {
                    return Ok(Verdict::Deny);
                }
            }
        }

        Ok(Verdict::Allow)
    }
}
/// Builds the regex source that detects a reference to `module` in a code snippet.
///
/// Each form must start at the beginning of a line or after a non-identifier character.
/// Recognised forms:
/// - `import module` / `import module.sub`, including when `module` is a later entry of a
///   comma-separated list (`import math, module`, `import a as b, module`) or follows a
///   `\` line continuation;
/// - `from module ...`;
/// - `require('module')` / `require("module")`;
/// - `__import__('module')` or `__import__('module.sub')`;
/// - attribute access such as `module.func(...)`.
///
/// In the import forms the module name must end at a non-identifier character or end of
/// line. So `os` matches `os;`, `os#`, `os,` and `os.path`, but not `osx`. `module` is
/// passed through `regex::escape`, so dotted names match literally.
fn module_regex_source(module: &str) -> String {
    // `import` followed by any number of earlier comma-separated entries (`a`, `a.b`,
    // `a as b`). A backslash is accepted wherever whitespace is, for `\` continuations.
    const IMPORT_PREFIX: &str =
        r"import[\s\\]+(?:[A-Za-z0-9_.]+(?:[\s\\]+as[\s\\]+[A-Za-z0-9_]+)?[\s\\]*,[\s\\]*)*";
    // The module name must not run on into a longer identifier.
    const NAME_END: &str = r"(?:[^A-Za-z0-9_]|$)";
    format!(
        r#"(?m)(?:^|[^A-Za-z0-9_])(?:{import}{m}{end}|from[\s\\]+{m}{end}|require\s*\(\s*['"]{m}['"]\s*\)|__import__\s*\(\s*['"]{m}['".]|{m}\s*\.)"#,
        import = IMPORT_PREFIX,
        m = regex::escape(module),
        end = NAME_END,
    )
}
/// Returns the lazily compiled regex that detects network usage in a code snippet.
///
/// Each form must start at the beginning of a line or after a non-identifier character.
/// It matches:
/// - imports of any module in `default_network_modules()`, including later entries of
///   a comma-separated `import` list and across `\` line continuations;
/// - `from <module>` imports;
/// - `require('<module>')` and `__import__('<module>')`;
/// - calls to `fetch(`.
///
/// The regex is compiled once and cached in a `OnceLock`.
///
/// If compilation ever fails, the fallback regex matches every input. All code is then
/// treated as using the network (fail closed), instead of network use silently never
/// being detected.
fn network_module_regex() -> &'static Regex {
    static RE: OnceLock<Regex> = OnceLock::new();
    RE.get_or_init(|| {
        // `import` followed by any number of earlier comma-separated entries (`a`, `a.b`,
        // `a as b`). A backslash is accepted wherever whitespace is, for `\` continuations.
        const IMPORT_PREFIX: &str =
            r"import[\s\\]+(?:[A-Za-z0-9_.]+(?:[\s\\]+as[\s\\]+[A-Za-z0-9_]+)?[\s\\]*,[\s\\]*)*";
        // The module name must not run on into a longer identifier.
        const NAME_END: &str = r"(?:[^A-Za-z0-9_]|$)";
        let alternation = default_network_modules()
            .iter()
            .map(|m| regex::escape(m))
            .collect::<Vec<_>>()
            .join("|");
        let pattern = format!(
            r#"(?m)(?:^|[^A-Za-z0-9_])(?:{import}(?:{a}){end}|from[\s\\]+(?:{a}){end}|require\s*\(\s*['"](?:{a})['"]\s*\)|__import__\s*\(\s*['"](?:{a})['".]|\bfetch\s*\()"#,
            import = IMPORT_PREFIX,
            a = alternation,
            end = NAME_END,
        );
        match Regex::new(&pattern) {
            Ok(re) => re,
            Err(err) => {
                tracing::error!(error = %err, "code-execution: failed to compile network regex");
                // Fail closed: the empty pattern matches every input, so all code counts as
                // network use. `expect` is acceptable because the empty pattern is always
                // valid.
                #[allow(clippy::expect_used)]
                {
                    Regex::new("").expect("empty regex compiles")
                }
            }
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn module_regex_matches_import_forms() {
        let detector = Regex::new(&module_regex_source("subprocess")).unwrap();
        let should_match = [
            "import subprocess\n",
            "from subprocess import call",
            "require('subprocess')",
            "subprocess.run(['ls'])",
        ];
        let should_not_match = ["import subprocesses\n", "# subprocess comment with no code"];
        for snippet in should_match {
            assert!(detector.is_match(snippet), "expected match: {snippet:?}");
        }
        for snippet in should_not_match {
            assert!(!detector.is_match(snippet), "unexpected match: {snippet:?}");
        }
    }

    #[test]
    fn network_module_regex_detects_requests() {
        let detector = network_module_regex();
        for snippet in ["import requests\n", "from urllib import parse", "fetch('https://x')"] {
            assert!(detector.is_match(snippet), "expected match: {snippet:?}");
        }
        assert!(!detector.is_match("import math"));
    }
}