use super::{Tool, ToolParameters, ToolResult};
use crate::error::{Result, ToolError};
use crate::sandbox::{SandboxCommand, SandboxExecutor};
use futures::future::BoxFuture;
use serde_json::Value;
use shlex::split as shlex_split;
use std::collections::HashSet;
use std::sync::{Arc, LazyLock};
use tokio::process::Command;
/// Program names accepted in strict mode.
///
/// Being listed here only clears the program name itself; `git` and `cargo`
/// invocations are additionally classified by their subcommand.
static ALLOWED_COMMANDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    [
        // Inspecting files and directories.
        "ls", "cat", "head", "tail", "less", "more", "file", "stat", "wc", "pwd", "tree", "find",
        "du",
        // Version control and the Rust toolchain.
        "git", "cargo", "rustc", "clippy", "rustfmt",
        // Searching and text processing.
        "grep", "rg", "ag", "fd", "echo", "printf", "cut", "sort", "uniq", "diff",
        // Environment inspection.
        "which", "whereis", "env", "date", "uname",
    ]
    .iter()
    .copied()
    .collect()
});
/// Program names that are never run without human confirmation.
///
/// Checked before the strict-mode whitelist, so these yield "requires approval"
/// rather than a rejection in both strict and permissive mode.
static REQUIRE_APPROVAL_COMMANDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    [
        // File removal / relocation.
        "rm", "rmdir", "mv", "cp",
        // Network clients.
        "curl", "wget", "nc",
        // Process control.
        "kill", "killall", "pkill",
        // Package managers.
        "apt", "apt-get", "yum", "dnf", "brew", "pip", "pip3", "npm", "yarn", "pnpm",
        // Shells and interpreters (able to run arbitrary code).
        "bash", "sh", "zsh", "fish", "python", "python3", "node", "perl", "ruby", "php",
        // Stream editors.
        "sed", "awk",
    ]
    .iter()
    .copied()
    .collect()
});
/// Program names that are always refused, whatever mode the tool runs in.
static DANGEROUS_COMMANDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    [
        // Raw disk / filesystem destruction.
        "dd", "shred", "mkfs", "fdisk",
        // Privilege escalation and permission changes.
        "sudo", "su", "chmod", "chown", "chgrp",
        // Power state.
        "reboot", "shutdown", "halt", "poweroff", "init",
        // Network scanning.
        "nmap",
    ]
    .iter()
    .copied()
    .collect()
});
/// Git subcommands the tool recognizes.
///
/// Membership does not by itself make a subcommand safe: the git classifier still
/// asks for approval on some of the entries in the second group.
static GIT_SAFE_SUBCOMMANDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    [
        "status", "log", "show", "diff", "branch", "tag", "ls-files", "ls-tree", "remote",
        "config",
        "add", "commit", "checkout", "switch", "stash",
    ]
    .iter()
    .copied()
    .collect()
});
/// Cargo subcommands the tool recognizes.
///
/// `clean` and `update` are recognized but still require approval in the cargo
/// classifier.
static CARGO_SAFE_SUBCOMMANDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    [
        "check", "build", "test", "clippy", "fmt", "tree", "search", "metadata",
        "clean", "update",
    ]
    .iter()
    .copied()
    .collect()
});
/// Characters with special meaning to a POSIX shell; a command containing any of
/// them is rejected before it is even parsed.
const SHELL_METACHARACTERS: &[char] = &[
    '|', ';', '&', // pipelines, command separators, background/conditional execution
    '$', '`', // expansion and command substitution
    '>', '<', // redirection
    '(', ')', // subshells
    '\n', // implicit command separator
];
/// Verdict produced when classifying a command.
///
/// The `String` payloads carry a human-readable reason that is shown to the caller
/// together with the offending command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandSafety {
    /// The command may run without confirmation.
    Safe,
    /// The command may run only after a human confirms it.
    RequiresApproval(String),
    /// The command is refused outright.
    Dangerous(String),
}
/// Tool that runs a single program with arguments once it passes a safety check.
pub struct ShellTool {
    /// Optional executor used to run commands instead of spawning them directly.
    sandbox: Option<Arc<dyn SandboxExecutor>>,
    /// When `true`, programs outside the whitelist are rejected.
    strict_mode: bool,
}
impl Default for ShellTool {
fn default() -> Self {
Self::new()
}
}
impl ShellTool {
    /// Creates a tool in strict mode, which rejects any program not listed in
    /// `ALLOWED_COMMANDS`. No sandbox is attached.
    pub fn new() -> Self {
        Self {
            strict_mode: true,
            sandbox: None,
        }
    }

    /// Creates a tool in permissive mode: programs on none of the command lists are
    /// not rejected. The dangerous/approval lists, the argument checks and the
    /// git/cargo subcommand checks still apply. No sandbox is attached.
    pub fn new_permissive() -> Self {
        Self {
            strict_mode: false,
            sandbox: None,
        }
    }

    /// Uses `sandbox` to execute commands instead of spawning them directly.
    pub fn with_sandbox(mut self, sandbox: Arc<dyn SandboxExecutor>) -> Self {
        self.sandbox = Some(sandbox);
        self
    }

    /// Returns `true` if `cmd` contains any character from `SHELL_METACHARACTERS`.
    fn has_shell_metacharacters(&self, cmd: &str) -> bool {
        cmd.contains(SHELL_METACHARACTERS)
    }

    /// Classifies `command` as safe, requiring human approval, or dangerous.
    ///
    /// Checks run in this order:
    /// 1. any shell metacharacter rejects the command (it is never run through a shell);
    /// 2. the command must split cleanly with shell-style quoting and be non-empty;
    /// 3. the program's file name is looked up in the dangerous list, then the approval list;
    /// 4. in strict mode the program must be a bare (path-free) name on the whitelist;
    /// 5. options of whitelisted tools that can run other programs or delete files;
    /// 6. `git` / `cargo` subcommand rules.
    pub fn check_command_safety(&self, command: &str) -> CommandSafety {
        if self.has_shell_metacharacters(command) {
            return CommandSafety::Dangerous(format!(
                "命令包含 shell 元字符(| ; & $ ` > < () 等),已拒绝执行。\
                \n本工具仅支持简单命令(程序名 + 参数),不支持管道、重定向、命令替换等 shell 语法。\
                \n命令: {}",
                command
            ));
        }
        let parts = match shlex_split(command) {
            Some(parts) => parts,
            None => {
                return CommandSafety::Dangerous(format!(
                    "命令解析失败,可能包含未闭合引号或非法参数格式: {}",
                    command
                ));
            }
        };
        let Some((program, args)) = parts.split_first() else {
            return CommandSafety::Dangerous("空命令".to_string());
        };
        let program = program.as_str();
        // Classify by file name so `/bin/rm` or `/usr/bin/sudo` cannot slip past the
        // deny lists (matters in permissive mode, where unknown programs are allowed).
        let base_cmd = std::path::Path::new(program)
            .file_name()
            .and_then(|name| name.to_str())
            .unwrap_or(program);
        if DANGEROUS_COMMANDS.contains(base_cmd) {
            return CommandSafety::Dangerous(format!(
                "命令 '{}' 在危险命令黑名单中,已拒绝执行",
                base_cmd
            ));
        }
        if REQUIRE_APPROVAL_COMMANDS.contains(base_cmd) {
            return CommandSafety::RequiresApproval(format!(
                "命令 '{}' 可能造成系统变更,需要人工确认",
                base_cmd
            ));
        }
        // An explicit path such as `./ls` can name any executable even when its file
        // name is whitelisted, so strict mode only accepts bare program names.
        if self.strict_mode && (program != base_cmd || !ALLOWED_COMMANDS.contains(base_cmd)) {
            return CommandSafety::Dangerous(format!(
                "命令 '{}' 不在安全白名单中,已拒绝执行",
                program
            ));
        }
        if let Some(verdict) = Self::check_command_arguments(base_cmd, args) {
            return verdict;
        }
        match base_cmd {
            "git" => self.check_git_command(&parts),
            "cargo" => self.check_cargo_command(&parts),
            _ => CommandSafety::Safe,
        }
    }

    /// Inspects the arguments of programs whose options can run other programs or
    /// delete/write arbitrary files, even though the program name itself is allowed.
    ///
    /// Returns `Some(RequiresApproval(..))` when such an option is present, else `None`.
    fn check_command_arguments(base_cmd: &str, args: &[String]) -> Option<CommandSafety> {
        // find actions that execute commands, delete files or write to named files.
        const FIND_RISKY_ACTIONS: &[&str] = &[
            "-exec", "-execdir", "-ok", "-okdir", "-delete", "-fprint", "-fprint0", "-fprintf",
            "-fls",
        ];
        let reason = match base_cmd {
            // `env [NAME=VALUE]... PROGRAM` runs PROGRAM; only bare `env` just lists variables.
            "env" if !args.is_empty() => "env 带参数时可以执行任意程序,需要确认",
            "find" if args.iter().any(|arg| FIND_RISKY_ACTIONS.contains(&arg.as_str())) => {
                "find 的 -exec/-delete 等动作会执行程序或修改文件,需要确认"
            }
            "fd" if args.iter().any(|arg| Self::is_fd_exec_flag(arg)) => {
                "fd 的 --exec/-x 选项会执行任意程序,需要确认"
            }
            // `rg --pre COMMAND` pipes every searched file through COMMAND.
            "rg" if args.iter().any(|arg| arg == "--pre" || arg.starts_with("--pre=")) => {
                "rg --pre 会执行任意预处理程序,需要确认"
            }
            // ag's `--pager` names an external program to spawn.
            "ag" if args.iter().any(|arg| arg == "--pager" || arg.starts_with("--pager=")) => {
                "ag --pager 会执行外部程序,需要确认"
            }
            _ => return None,
        };
        Some(CommandSafety::RequiresApproval(reason.to_string()))
    }

    /// Returns `true` for fd options that execute a command (`--exec`, `--exec-batch`,
    /// `-x`, `-X`), including short flags bundled with others such as `-Hx`.
    ///
    /// Deliberately conservative: any short-flag cluster containing `x`/`X` matches,
    /// which at worst asks for an unnecessary approval.
    fn is_fd_exec_flag(arg: &str) -> bool {
        if let Some(long) = arg.strip_prefix("--") {
            return long.starts_with("exec");
        }
        arg.strip_prefix('-')
            .is_some_and(|short| short.chars().any(|c| matches!(c, 'x' | 'X')))
    }

    /// Classifies a `git` invocation by its subcommand.
    ///
    /// `parts` is the full split command line; `parts[0]` is the program.
    fn check_git_command(&self, parts: &[String]) -> CommandSafety {
        // Known subcommands that still change the index, work tree or stash.
        const MUTATING_SUBCOMMANDS: &[&str] = &["add", "commit", "checkout", "switch", "stash"];
        let Some(subcommand) = parts.get(1).map(String::as_str) else {
            return CommandSafety::Safe;
        };
        match subcommand {
            "push" | "pull" | "fetch" | "clone" => CommandSafety::RequiresApproval(format!(
                "git {} 涉及网络操作,需要确认",
                subcommand
            )),
            "reset" => {
                if parts.iter().any(|part| part == "--hard") {
                    CommandSafety::Dangerous(
                        "git reset --hard 会丢失数据,已拒绝。如需执行请手动操作".to_string(),
                    )
                } else {
                    CommandSafety::RequiresApproval(
                        "git reset 会修改 Git 状态,需要确认".to_string(),
                    )
                }
            }
            "clean" => {
                CommandSafety::RequiresApproval("git clean 会删除未跟踪文件,需要确认".to_string())
            }
            // Config writes can plant commands (e.g. `core.fsmonitor`, `diff.external`)
            // that later "safe" subcommands such as `git status` / `git diff` execute.
            "config" if !Self::is_read_only_git_config(&parts[2..]) => {
                CommandSafety::RequiresApproval(
                    "git config 写入配置可能让后续 git 命令执行任意程序,需要确认".to_string(),
                )
            }
            cmd if GIT_SAFE_SUBCOMMANDS.contains(cmd) => {
                if MUTATING_SUBCOMMANDS.contains(&cmd) {
                    CommandSafety::RequiresApproval(format!("git {} 会修改仓库,需要确认", cmd))
                } else {
                    CommandSafety::Safe
                }
            }
            _ => CommandSafety::RequiresApproval(format!(
                "git {} 不在已知安全列表中,需要确认",
                subcommand
            )),
        }
    }

    /// Returns `true` if the arguments following `git config` only read configuration.
    ///
    /// Conservative: a form without an explicit read action (e.g. `git config user.name`)
    /// is treated as a potential write.
    fn is_read_only_git_config(args: &[String]) -> bool {
        const READ_FLAGS: &[&str] = &[
            "--get", "--get-all", "--get-regexp", "--get-urlmatch", "--get-color",
            "--get-colorbool", "--list", "-l",
        ];
        const WRITE_FLAGS: &[&str] = &[
            "--add", "--replace-all", "--unset", "--unset-all", "--rename-section",
            "--remove-section", "--edit", "-e",
        ];
        // Newer git versions also accept a subcommand form: `git config <get|list|set|...>`.
        let first = args.first().map(String::as_str);
        let reads = matches!(first, Some("get" | "list"))
            || args.iter().any(|arg| READ_FLAGS.contains(&arg.as_str()));
        let writes = matches!(
            first,
            Some("set" | "unset" | "edit" | "rename-section" | "remove-section")
        ) || args.iter().any(|arg| WRITE_FLAGS.contains(&arg.as_str()));
        reads && !writes
    }

    /// Classifies a `cargo` invocation by its subcommand.
    ///
    /// `parts` is the full split command line; `parts[0]` is the program.
    fn check_cargo_command(&self, parts: &[String]) -> CommandSafety {
        let Some(subcommand) = parts.get(1).map(String::as_str) else {
            return CommandSafety::Safe;
        };
        match subcommand {
            "install" | "uninstall" | "publish" => CommandSafety::RequiresApproval(format!(
                "cargo {} 涉及包安装/发布,需要确认",
                subcommand
            )),
            "run" => CommandSafety::RequiresApproval("cargo run 会执行程序,需要确认".to_string()),
            cmd if CARGO_SAFE_SUBCOMMANDS.contains(cmd) => {
                if cmd == "clean" || cmd == "update" {
                    CommandSafety::RequiresApproval(format!("cargo {} 会修改项目,需要确认", cmd))
                } else {
                    CommandSafety::Safe
                }
            }
            _ => CommandSafety::RequiresApproval(format!(
                "cargo {} 不在已知安全列表中,需要确认",
                subcommand
            )),
        }
    }
}
impl Tool for ShellTool {
fn name(&self) -> &str {
"shell"
}
fn description(&self) -> &str {
"执行受限的 shell 命令(仅允许安全的只读操作和代码相关命令)。参数:command - 要执行的命令。注意:仅支持简单命令(程序名 + 参数),不支持管道、重定向、命令替换等 shell 语法。"
}
fn parameters(&self) -> Value {
serde_json::json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "要执行的命令(仅限白名单中的安全命令,不支持管道/重定向/命令替换等 shell 语法)"
}
},
"required": ["command"]
})
}
fn execute(&self, parameters: ToolParameters) -> BoxFuture<'_, Result<ToolResult>> {
Box::pin(async move {
let command = parameters
.get("command")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::MissingParameter("command".to_string()))?;
match self.check_command_safety(command) {
CommandSafety::Safe => {}
CommandSafety::RequiresApproval(reason) => {
return Ok(ToolResult::error(format!(
"⚠️ 需要人工确认:{}\n命令:{}\n\n请使用 human_loop 模块进行确认后再执行。",
reason, command
)));
}
CommandSafety::Dangerous(reason) => {
return Ok(ToolResult::error(format!(
"🚫 安全拒绝:{}\n命令:{}\n\n如需执行此类操作,请手动在终端中执行。",
reason, command
)));
}
}
let parts = shlex_split(command).ok_or_else(|| ToolError::ExecutionFailed {
tool: self.name().to_string(),
message: "命令解析失败,可能包含未闭合引号或非法参数格式".to_string(),
})?;
let program = parts[0].as_str();
let args = &parts[1..];
if let Some(sandbox) = &self.sandbox {
let sandbox_cmd = SandboxCommand::program(program, args.to_vec());
match sandbox.execute(sandbox_cmd).await {
Ok(result) => {
if result.success() {
Ok(ToolResult::success(result.stdout))
} else {
Ok(ToolResult::error(format!(
"命令执行失败,退出码: {}\n标准输出: {}\n错误输出: {}",
result.exit_code, result.stdout, result.stderr
)))
}
}
Err(e) => Ok(ToolResult::error(format!("沙箱执行失败: {}", e))),
}
} else {
match Command::new(program).args(args).output().await {
Ok(output) => {
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
if output.status.success() {
Ok(ToolResult::success(stdout))
} else {
Ok(ToolResult::error(format!(
"命令执行失败,退出码: {:?}\n标准输出: {}\n错误输出: {}",
output.status.code(),
stdout,
stderr
)))
}
}
Err(e) => Ok(ToolResult::error(format!("无法执行命令: {}", e))),
}
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// True when the verdict is an outright rejection.
    fn is_dangerous(verdict: &CommandSafety) -> bool {
        matches!(verdict, CommandSafety::Dangerous(_))
    }

    /// True when the verdict asks for human confirmation.
    fn needs_approval(verdict: &CommandSafety) -> bool {
        matches!(verdict, CommandSafety::RequiresApproval(_))
    }

    /// Builds the single-entry parameter map that `execute` expects.
    fn command_params(command: &str) -> HashMap<String, serde_json::Value> {
        HashMap::from([("command".to_string(), serde_json::json!(command))])
    }

    #[test]
    fn test_safe_commands() {
        let tool = ShellTool::new();
        for command in ["ls -la", "pwd", "cat README.md", "git status", "cargo check"] {
            assert_eq!(tool.check_command_safety(command), CommandSafety::Safe);
        }
    }

    #[test]
    fn test_shell_injection_rejected() {
        let tool = ShellTool::new();
        let cases = [
            ("ls | rm -rf /tmp", "管道注入应被拒绝"),
            ("echo $(id)", "命令替换注入应被拒绝"),
            ("echo `id`", "反引号注入应被拒绝"),
            ("ls; rm -rf /tmp/x", "分号注入应被拒绝"),
            ("cat file > /etc/passwd", "重定向注入应被拒绝"),
            ("echo hello && rm -rf /", "条件执行注入应被拒绝"),
            ("$(dangerous)", "子 shell 注入应被拒绝"),
        ];
        for (command, label) in cases {
            let verdict = tool.check_command_safety(command);
            assert!(is_dangerous(&verdict), "{},但得到: {:?}", label, verdict);
        }
    }

    #[test]
    fn test_require_approval_commands() {
        let tool = ShellTool::new();
        let cases = [
            ("rm -rf /tmp/test", "rm 命令应该需要确认"),
            ("curl http://example.com", "curl 命令应该需要确认"),
            ("npm install package", "npm 命令应该需要确认"),
            ("python script.py", "python 命令应该需要确认"),
        ];
        for (command, message) in cases {
            assert!(needs_approval(&tool.check_command_safety(command)), "{}", message);
        }
    }

    #[test]
    fn test_dangerous_commands() {
        let tool = ShellTool::new();
        let cases = [
            ("dd if=/dev/zero of=/dev/sda", "dd 命令应该被拒绝"),
            ("sudo apt install", "sudo 命令应该被拒绝"),
            ("chmod 777 /etc/passwd", "chmod 命令应该被拒绝"),
            ("reboot", "reboot 命令应该被拒绝"),
        ];
        for (command, message) in cases {
            assert!(is_dangerous(&tool.check_command_safety(command)), "{}", message);
        }
    }

    #[test]
    fn test_git_commands() {
        let tool = ShellTool::new();
        for command in ["git log", "git diff", "git status"] {
            assert_eq!(tool.check_command_safety(command), CommandSafety::Safe);
        }
        let approvals = [
            ("git commit -m 'test'", "git commit 应该需要确认"),
            ("git push origin main", "git push 应该需要确认"),
            ("git add .", "git add 应该需要确认"),
            ("git clean -fd", "git clean 应该需要确认"),
        ];
        for (command, message) in approvals {
            assert!(needs_approval(&tool.check_command_safety(command)), "{}", message);
        }
        assert!(
            is_dangerous(&tool.check_command_safety("git reset --hard HEAD~1")),
            "git reset --hard 应该被拒绝"
        );
    }

    #[test]
    fn test_cargo_commands() {
        let tool = ShellTool::new();
        for command in ["cargo check", "cargo test", "cargo clippy", "cargo build"] {
            assert_eq!(tool.check_command_safety(command), CommandSafety::Safe);
        }
        let approvals = [
            ("cargo run", "cargo run 应该需要确认"),
            ("cargo install some-package", "cargo install 应该需要确认"),
            ("cargo clean", "cargo clean 应该需要确认"),
        ];
        for (command, message) in approvals {
            assert!(needs_approval(&tool.check_command_safety(command)), "{}", message);
        }
    }

    #[test]
    fn test_unknown_command_in_strict_mode() {
        let tool = ShellTool::new();
        assert!(
            is_dangerous(&tool.check_command_safety("unknown_command")),
            "严格模式下应该拒绝未知命令"
        );
    }

    #[tokio::test]
    async fn test_shell_tool_execution() {
        let tool = ShellTool::new();

        let echoed = tool.execute(command_params("echo hello")).await.unwrap();
        assert!(echoed.success);
        assert!(echoed.output.contains("hello"));

        let removal = tool.execute(command_params("rm test.txt")).await.unwrap();
        assert!(!removal.success);
        assert!(removal.error.as_ref().unwrap().contains("确认"));

        let reboot = tool.execute(command_params("sudo reboot")).await.unwrap();
        assert!(!reboot.success);
        assert!(reboot.error.unwrap().contains("拒绝"));
    }

    #[tokio::test]
    async fn test_shell_injection_rejected_in_execution() {
        let tool = ShellTool::new();
        let cases = [
            ("ls | rm -rf /tmp", "管道注入应被拒绝"),
            ("echo $(id)", "命令替换应被拒绝"),
            ("ls; echo pwned", "分号注入应被拒绝"),
        ];
        for (command, label) in cases {
            let result = tool.execute(command_params(command)).await.unwrap();
            assert!(!result.success, "{}", label);
            assert!(result.error.as_ref().unwrap().contains("shell 元字符"));
        }
    }

    #[tokio::test]
    async fn test_shell_tool_with_sandbox() {
        use crate::sandbox::{LocalConfig, LocalSandbox};
        let sandbox = Arc::new(LocalSandbox::new(LocalConfig {
            enable_os_sandbox: false,
            ..Default::default()
        }));
        let tool = ShellTool::new().with_sandbox(sandbox);
        let result = tool
            .execute(command_params("echo sandbox_test"))
            .await
            .unwrap();
        assert!(result.success, "Tool failed: {:?}", result.error);
        assert!(result.output.contains("sandbox_test"));
    }
}