use rucora_core::{
error::ToolError,
tool::{Tool, ToolCategory},
};
use async_trait::async_trait;
use serde_json::{Value, json};
use std::collections::HashSet;
use std::path::Path;
use std::time::Duration;
use tokio::time::timeout;
/// Returns the tool description advertised to callers, tailored to the
/// operating system this binary was compiled for.
///
/// Windows, macOS and Linux each get a text listing the usual commands for
/// that platform; any other target gets a generic hint.
fn get_shell_description() -> &'static str {
    match std::env::consts::OS {
        "windows" => {
            "执行系统命令。当前平台:Windows。请使用 Windows 命令:dir (列表), cd (切换目录), type (查看文件), findstr (搜索), copy (复制), move (移动), del (删除), mkdir (创建目录)。命令必须与当前操作系统兼容。"
        }
        "macos" => {
            "执行系统命令。当前平台:macOS。请使用 macOS 命令:ls (列表), cd (切换目录), cat (查看文件), grep (搜索), cp (复制), mv (移动), rm (删除), mkdir (创建目录)。命令必须与当前操作系统兼容。"
        }
        "linux" => {
            "执行系统命令。当前平台:Linux。请使用 Linux 命令:ls (列表), cd (切换目录), cat (查看文件), grep (搜索), cp (复制), mv (移动), rm (删除), mkdir (创建目录)。命令必须与当前操作系统兼容。"
        }
        _ => "执行系统命令。请使用适合当前平台的命令。",
    }
}
/// Timeout, in seconds, used when a request does not carry its own
/// `timeout` value.
pub const SHELL_TIMEOUT_SECS: u64 = 60;

/// Upper bound, in bytes, on the captured output kept by `truncate_output`
/// (1 MiB).
pub const MAX_OUTPUT_BYTES: usize = 1024 * 1024;
/// Built-in blacklist that seeds every new `ShellTool`.
///
/// Each entry is compared as a plain substring of the command line
/// (command and arguments joined by single spaces), so an entry also matches
/// when it appears inside a longer word.
const FORBIDDEN_COMMANDS: &[&str] = &[
    // Recursive / forced deletion.
    "rm -rf",
    "rm -fr",
    "del /f/s/q",
    // Disk formatting and partitioning.
    "format",
    "mkfs",
    "diskpart",
    // Power state changes.
    "shutdown",
    "reboot",
    "halt",
    // Network download tools.
    "wget",
    "curl",
];
/// Shell metacharacters that must not appear in the command or in any
/// argument.
///
/// The command line is handed to a shell (`sh -c` / `cmd /C`), so any of
/// these would let a request chain, redirect or substitute commands.
/// Entries are matched as substrings in the order listed, and the first
/// match is the one reported in the error; longer forms such as `||` or `>>`
/// are therefore already covered by their single-character prefix and are
/// kept only for readability.
///
/// A bare `&` is listed (after `&&`, so `&&` is still reported as such)
/// because both `sh` and `cmd` treat it as a command separator:
/// `allowed & anything` would otherwise run a second, unchecked command.
const DANGEROUS_OPERATORS: &[&str] = &[
    // Pipes and command chaining.
    "|", "||", "&&", "&", ";",
    // Redirection.
    ">", ">>", "<", "<<<",
    // Command / parameter substitution.
    "`", "$(", "${",
    // Line breaks and escapes.
    "\n", "\r", "\\",
];
/// Tool that runs a system command after checking it against a set of
/// safety rules (whitelist, blacklist, shell operators, path traversal and
/// sensitive-name patterns).
pub struct ShellTool {
/// Command names that may be executed. An empty set disables the whitelist
/// check entirely.
allowed_commands: HashSet<String>,
/// Substrings that must not occur anywhere in the joined command line.
forbidden_commands: HashSet<String>,
}
impl ShellTool {
    /// Creates a tool with no whitelist and with the built-in blacklist
    /// (`FORBIDDEN_COMMANDS`).
    pub fn new() -> Self {
        let forbidden_commands: HashSet<String> = FORBIDDEN_COMMANDS
            .iter()
            .copied()
            .map(String::from)
            .collect();
        Self {
            allowed_commands: HashSet::new(),
            forbidden_commands,
        }
    }

    /// Replaces the whitelist with `commands`.
    ///
    /// Passing an empty vector turns the whitelist check off again.
    pub fn with_allowed_commands(self, commands: Vec<String>) -> Self {
        Self {
            allowed_commands: commands.into_iter().collect(),
            ..self
        }
    }

    /// Adds `commands` to the blacklist; existing entries are kept.
    pub fn with_forbidden_commands(mut self, commands: Vec<String>) -> Self {
        for extra in commands {
            self.forbidden_commands.insert(extra);
        }
        self
    }

    /// Checks `command` and `args` against every safety rule.
    ///
    /// Rules are evaluated in this order, and the first violation is
    /// returned as `ToolError::PolicyDenied`:
    ///
    /// 1. `shell.whitelist` — when a whitelist is configured, the first
    ///    whitespace-separated word of `command` must be in it.
    /// 2. `shell.blacklist` — no blacklist entry may occur as a substring of
    ///    the command line (command and args joined by spaces).
    /// 3. `shell.dangerous_operators` — no entry of `DANGEROUS_OPERATORS`
    ///    may occur in the command or in any argument.
    /// 4. `shell.path_traversal` — `..` may not occur in the command or in
    ///    any argument.
    /// 5. `shell.env_leak` — a fixed list of upper-case sensitive names may
    ///    not occur in the command or in any argument (case-sensitive).
    fn validate_command(&self, command: &str, args: &[String]) -> Result<(), ToolError> {
        // Builds the policy error for a given rule.
        let deny = |rule_id: &str, reason: String| ToolError::PolicyDenied {
            rule_id: rule_id.to_string(),
            reason,
        };
        // True when `needle` occurs in the command or in any single argument.
        let mentions = |needle: &str| {
            command.contains(needle) || args.iter().any(|arg| arg.contains(needle))
        };

        // 1. Whitelist (skipped when no whitelist is configured).
        if !self.allowed_commands.is_empty() {
            let cmd_name = command.split_whitespace().next().unwrap_or(command);
            if !self.allowed_commands.contains(cmd_name) {
                return Err(deny(
                    "shell.whitelist",
                    format!("命令 '{cmd_name}' 不在白名单中"),
                ));
            }
        }

        // 2. Blacklist, matched against the space-joined command line.
        let mut full_command = command.to_string();
        if !args.is_empty() {
            full_command.push(' ');
            full_command.push_str(&args.join(" "));
        }
        let hit = self
            .forbidden_commands
            .iter()
            .find(|entry| full_command.contains(entry.as_str()));
        if let Some(forbidden) = hit {
            return Err(deny(
                "shell.blacklist",
                format!("命令包含禁止的操作: {forbidden}"),
            ));
        }

        // 3. Shell operators; the first one in list order is reported.
        let operator = DANGEROUS_OPERATORS
            .iter()
            .copied()
            .find(|&candidate| mentions(candidate));
        if let Some(op) = operator {
            return Err(deny(
                "shell.dangerous_operators",
                format!("命令包含危险操作符: {op}"),
            ));
        }

        // 4. Parent-directory references.
        if mentions("..") {
            return Err(deny(
                "shell.path_traversal",
                "命令可能包含路径遍历攻击".to_string(),
            ));
        }

        // 5. Names that suggest secrets.
        let env_patterns = ["PASSWORD", "SECRET", "TOKEN", "API_KEY", "CREDENTIAL"];
        if env_patterns.iter().any(|&pattern| mentions(pattern)) {
            return Err(deny(
                "shell.env_leak",
                "命令可能泄露敏感环境变量".to_string(),
            ));
        }

        Ok(())
    }
}
/// The default tool is the one produced by `ShellTool::new`.
impl Default for ShellTool {
/// Delegates to `Self::new()`.
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for ShellTool {
fn name(&self) -> &str {
"shell"
}
fn description(&self) -> Option<&str> {
Some(get_shell_description())
}
fn categories(&self) -> &'static [ToolCategory] {
&[ToolCategory::System]
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "要执行的命令(注意:必须与当前操作系统兼容。Windows 使用 dir/findstr/type 等命令,Linux/Mac 使用 ls/grep/cat 等命令)"
},
"args": {
"type": "array",
"items": {
"type": "string"
},
"description": "命令参数"
},
"timeout": {
"type": "integer",
"description": "超时时间(秒),默认 60 秒"
},
"working_dir": {
"type": "string",
"description": "工作目录(可选)"
}
},
"required": ["command"]
})
}
async fn call(&self, input: Value) -> Result<Value, ToolError> {
let command = input
.get("command")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::Message("缺少必需的 'command' 字段".to_string()))?;
let args: Vec<String> = input
.get("args")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|x| x.as_str())
.map(|s| s.to_string())
.collect::<Vec<_>>()
})
.unwrap_or_default();
self.validate_command(command, &args)?;
let timeout_secs = input
.get("timeout")
.and_then(|v| v.as_u64())
.unwrap_or(SHELL_TIMEOUT_SECS);
let working_dir = input
.get("working_dir")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
if let Some(ref dir) = working_dir {
let path = Path::new(dir);
if !path.exists() || !path.is_dir() {
return Err(ToolError::Message(format!(
"工作目录不存在或不是目录: {dir}"
)));
}
if dir.contains("..") {
return Err(ToolError::PolicyDenied {
rule_id: "shell.path_traversal".to_string(),
reason: "工作目录可能包含路径遍历".to_string(),
});
}
}
let result = timeout(
Duration::from_secs(timeout_secs),
execute_shell_command(command, &args, working_dir.as_deref()),
)
.await;
match result {
Ok(Ok(output)) => {
let stdout = truncate_output(String::from_utf8_lossy(&output.stdout).to_string());
let stderr = truncate_output(String::from_utf8_lossy(&output.stderr).to_string());
Ok(json!({
"stdout": stdout,
"stderr": stderr,
"exit_code": output.status.code().unwrap_or(-1),
"success": output.status.success()
}))
}
Ok(Err(e)) => Err(ToolError::Message(format!("命令执行失败: {e}"))),
Err(_) => Err(ToolError::Message(format!(
"命令执行超时(超过 {timeout_secs} 秒)"
))),
}
}
}
/// Runs `command` with `args` and collects its exit status and output.
///
/// On Windows the command line is passed to `cmd /C`; on Linux and macOS it
/// is passed to `sh -c`. In both cases command and arguments are joined with
/// single spaces into one string, so the shell performs its own parsing. On
/// any other target the program is started directly with `args` as its
/// argument list.
///
/// The child starts in `working_dir` when one is given. Its environment is
/// cleared and then only a fixed allow-list of variables is copied from the
/// current process.
///
/// The blocking `output()` call is run through `tokio::task::spawn_blocking`.
///
/// # Errors
///
/// Returns the `std::io::Error` from running the process, or an
/// `std::io::Error` wrapping the join error when the blocking task fails.
pub async fn execute_shell_command(
    command: &str,
    args: &[String],
    working_dir: Option<&str>,
) -> Result<std::process::Output, std::io::Error> {
    use std::process::Command;

    // Variables copied into the otherwise empty child environment.
    const PASSTHROUGH_ENV: [&str; 9] = [
        "PATH",
        "HOME",
        "USER",
        "SHELL",
        "TMPDIR",
        "TEMP",
        "TMP",
        "SystemRoot",
        "USERPROFILE",
    ];

    // (interpreter, flag that makes it run one command string), if the
    // platform has a known shell.
    let shell = if cfg!(target_os = "windows") {
        Some(("cmd", "/C"))
    } else if cfg!(any(target_os = "linux", target_os = "macos")) {
        Some(("sh", "-c"))
    } else {
        None
    };

    let mut cmd = match shell {
        Some((program, flag)) => {
            let line = if args.is_empty() {
                command.to_string()
            } else {
                format!("{} {}", command, args.join(" "))
            };
            let mut shell_cmd = Command::new(program);
            shell_cmd.arg(flag).arg(line);
            shell_cmd
        }
        None => {
            // No known shell: start the program itself.
            let mut direct = Command::new(command);
            direct.args(args);
            direct
        }
    };

    if let Some(dir) = working_dir {
        cmd.current_dir(dir);
    }

    cmd.env_clear();
    cmd.envs(
        PASSTHROUGH_ENV
            .iter()
            .filter_map(|name| std::env::var(name).ok().map(|value| (name, value))),
    );

    match tokio::task::spawn_blocking(move || cmd.output()).await {
        Ok(outcome) => outcome,
        Err(e) => Err(std::io::Error::other(format!("任务执行失败: {e}"))),
    }
}
/// Limits `output` to `MAX_OUTPUT_BYTES` bytes of original content.
///
/// Text that already fits is returned unchanged. Longer text is cut at the
/// last UTF-8 character boundary at or below the limit, and a notice saying
/// the output was truncated is appended (so the result is slightly longer
/// than the limit).
pub fn truncate_output(mut output: String) -> String {
    if output.len() <= MAX_OUTPUT_BYTES {
        return output;
    }

    // Walk back from the limit to a position that does not split a character;
    // index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=MAX_OUTPUT_BYTES)
        .rev()
        .find(|&idx| output.is_char_boundary(idx))
        .unwrap_or(0);

    output.truncate(cut);
    output += "\n... [输出已截断,超过 1MB 限制]";
    output
}