use async_trait::async_trait;
use rucora_core::{
error::ToolError,
tool::{Tool, ToolCategory},
};
use serde_json::{Value, json};
use std::collections::HashSet;
use std::path::Path;
use std::time::Duration;
use tokio::time::timeout;
/// Returns the tool description for the operating system this binary targets.
///
/// The target is resolved at compile time through `cfg!`. Windows, macOS and
/// Linux each get a description that names the platform and lists its common
/// commands; every other target gets a generic fallback.
fn get_shell_description() -> &'static str {
    const WINDOWS: &str = "执行系统命令。当前平台:Windows。请使用 Windows 命令:dir (列表), cd (切换目录), type (查看文件), findstr (搜索), copy (复制), move (移动), del (删除), mkdir (创建目录)。命令必须与当前操作系统兼容。";
    const MACOS: &str = "执行系统命令。当前平台:macOS。请使用 macOS 命令:ls (列表), cd (切换目录), cat (查看文件), grep (搜索), cp (复制), mv (移动), rm (删除), mkdir (创建目录)。命令必须与当前操作系统兼容。";
    const LINUX: &str = "执行系统命令。当前平台:Linux。请使用 Linux 命令:ls (列表), cd (切换目录), cat (查看文件), grep (搜索), cp (复制), mv (移动), rm (删除), mkdir (创建目录)。命令必须与当前操作系统兼容。";
    const GENERIC: &str = "执行系统命令。请使用适合当前平台的命令。";

    // Checked in the same order as before: Windows, then macOS, then Linux.
    if cfg!(target_os = "windows") {
        return WINDOWS;
    }
    if cfg!(target_os = "macos") {
        return MACOS;
    }
    if cfg!(target_os = "linux") {
        return LINUX;
    }
    GENERIC
}
/// Default command timeout in seconds, used when the caller supplies no `timeout`.
pub const SHELL_TIMEOUT_SECS: u64 = 60;
/// Maximum number of bytes kept from stdout and from stderr (each checked
/// separately) before the captured output is truncated: 1 MiB.
pub const MAX_OUTPUT_BYTES: usize = 1_048_576;
/// Built-in substrings that cause a command to be rejected.
///
/// Matching is a plain substring test against the lower-cased command, so the
/// entries are written in lower case. The order is significant only in that the
/// first matching entry is the one reported in the error message.
const FORBIDDEN_COMMANDS: &[&str] = &[
    // Recursive / forced deletion.
    "rm -rf",
    "rm -fr",
    "del /f/s/q",
    // Disk formatting and partitioning.
    "format",
    "mkfs",
    "diskpart",
    // Power state changes.
    "shutdown",
    "reboot",
    "halt",
    // Network downloads.
    "wget",
    "curl",
];
/// Shell metacharacters that cause a command to be rejected.
///
/// Matching is a plain substring test against the command as given. The first
/// matching entry is the one reported in the error message, so `"&&"` is listed
/// before the lone `"&"` to keep reporting `&&` for `a && b`.
const DANGEROUS_OPERATORS: &[&str] = &[
    // Pipes and command chaining.
    "|",
    "||",
    "&&",
    // A single `&` also chains commands: it starts an asynchronous list in
    // POSIX shells and is a sequential command separator in `cmd.exe`.
    "&",
    ";",
    // Redirection.
    ">",
    ">>",
    "<",
    "<<<",
    // Command substitution and parameter expansion.
    "`",
    "$(",
    "${",
    // Line breaks (each line would run as its own command) and escapes.
    "\n",
    "\r",
    "\\",
];
/// Tool that runs system commands after validating them.
pub struct ShellTool {
    /// Whitelist of command names. When non-empty, the first
    /// whitespace-separated token of a command must be a member of this set.
    allowed_commands: HashSet<String>,
    /// Caller-supplied substrings that are rejected in addition to the
    /// built-in `FORBIDDEN_COMMANDS` list.
    forbidden_commands: HashSet<String>,
}
impl ShellTool {
    /// Creates a tool with an empty whitelist and no extra forbidden patterns.
    ///
    /// The built-in `FORBIDDEN_COMMANDS` and `DANGEROUS_OPERATORS` checks
    /// always apply.
    pub fn new() -> Self {
        Self {
            allowed_commands: HashSet::new(),
            forbidden_commands: HashSet::new(),
        }
    }

    /// Replaces the whitelist of command names.
    ///
    /// An empty list disables the whitelist check. Entries are compared
    /// exactly (case-sensitively) against the first whitespace-separated token
    /// of the command.
    pub fn with_allowed_commands(mut self, commands: Vec<String>) -> Self {
        self.allowed_commands = commands.into_iter().collect();
        self
    }

    /// Replaces the set of extra forbidden substrings.
    ///
    /// Entries are matched case-insensitively as substrings of the command.
    /// Empty entries are ignored.
    pub fn with_forbidden_commands(mut self, commands: Vec<String>) -> Self {
        self.forbidden_commands = commands.into_iter().collect();
        self
    }

    /// Checks `command` against every validation rule.
    ///
    /// Rules are applied in this order, and the first violation is returned:
    /// built-in forbidden substrings, caller-supplied forbidden substrings,
    /// dangerous shell operators, the whitelist (if any), and `..` sequences.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::Message` describing the first rule that was violated.
    fn validate_command(&self, command: &str) -> Result<(), ToolError> {
        let cmd_lower = command.to_lowercase();

        for forbidden in FORBIDDEN_COMMANDS {
            if cmd_lower.contains(forbidden) {
                return Err(ToolError::Message(format!(
                    "命令包含禁止的操作:{forbidden}"
                )));
            }
        }

        for forbidden in &self.forbidden_commands {
            // An empty pattern is a substring of every command; treating it as
            // a match would reject everything, so it is skipped.
            if forbidden.is_empty() {
                continue;
            }
            // The command was lower-cased above, so the pattern must be
            // lower-cased too or an entry such as "Netcat" could never match.
            let needle = forbidden.to_lowercase();
            if cmd_lower.contains(needle.as_str()) {
                return Err(ToolError::Message(format!(
                    "命令包含禁止的操作:{forbidden}"
                )));
            }
        }

        for operator in DANGEROUS_OPERATORS {
            if command.contains(operator) {
                return Err(ToolError::Message(format!(
                    "命令包含危险操作符:{operator}"
                )));
            }
        }

        if !self.allowed_commands.is_empty() {
            // Only the program name (first token) is compared to the whitelist.
            let cmd_name = command.split_whitespace().next().unwrap_or(command);
            if !self.allowed_commands.contains(cmd_name) {
                return Err(ToolError::Message(format!(
                    "命令 {cmd_name} 不在允许的白名单中"
                )));
            }
        }

        if command.contains("..") {
            return Err(ToolError::Message(
                "命令包含路径遍历(..),这是不安全的".to_string(),
            ));
        }

        Ok(())
    }

    /// Checks that `dir` is usable as a working directory.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::Message` when `dir` contains `..`, does not exist,
    /// or exists but is not a directory.
    fn validate_working_dir(&self, dir: &str) -> Result<(), ToolError> {
        let path = Path::new(dir);
        if dir.contains("..") {
            return Err(ToolError::Message(
                "工作目录包含路径遍历(..),这是不安全的".to_string(),
            ));
        }
        if !path.exists() {
            return Err(ToolError::Message(format!("工作目录不存在:{dir}")));
        }
        if !path.is_dir() {
            return Err(ToolError::Message(format!("工作目录路径不是目录:{dir}")));
        }
        Ok(())
    }
}
impl Default for ShellTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for ShellTool {
    /// Name under which the tool is registered.
    fn name(&self) -> &str {
        "shell"
    }

    /// Platform-specific description of the tool.
    fn description(&self) -> Option<&str> {
        Some(get_shell_description())
    }

    /// The tool belongs to the system category.
    fn categories(&self) -> &'static [ToolCategory] {
        &[ToolCategory::System]
    }

    /// JSON schema of the accepted input: a required `command` string plus
    /// optional `args`, `timeout` (seconds) and `working_dir`.
    fn input_schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "要执行的命令"
                },
                "args": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "命令参数列表"
                },
                "timeout": {
                    "type": "integer",
                    "description": "超时时间(秒),默认 60 秒",
                    "default": 60
                },
                "working_dir": {
                    "type": "string",
                    "description": "工作目录(可选)"
                }
            },
            "required": ["command"]
        })
    }

    /// Validates and runs the requested command.
    ///
    /// `args` are joined onto `command` with single spaces, and the resulting
    /// string is what the platform shell executes. Non-string elements of
    /// `args` are ignored. A missing or non-integer `timeout` falls back to
    /// `SHELL_TIMEOUT_SECS`.
    ///
    /// Returns a JSON object with `command`, `stdout`, `stderr`, `exit_code`,
    /// `success` (exit code is zero) and `truncated`.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::Message` when `command` is missing, when the
    /// command, the combined command line or the working directory fails
    /// validation, or when execution fails or times out.
    async fn call(&self, input: Value) -> Result<Value, ToolError> {
        let command = input
            .get("command")
            .and_then(|v| v.as_str())
            .ok_or_else(|| ToolError::Message("缺少必需的 'command' 字段".to_string()))?;
        let args: Vec<String> = input
            .get("args")
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(String::from))
                    .collect()
            })
            .unwrap_or_default();
        let timeout_secs = input
            .get("timeout")
            .and_then(|v| v.as_u64())
            .unwrap_or(SHELL_TIMEOUT_SECS);

        let full_command = if args.is_empty() {
            command.to_string()
        } else {
            format!("{} {}", command, args.join(" "))
        };

        self.validate_command(command)?;
        // The shell interprets the whole joined string, so the arguments must
        // pass the same checks as the command itself. Validating only
        // `command` would let `args` smuggle in operators or forbidden
        // commands.
        if !args.is_empty() {
            self.validate_command(&full_command)?;
        }

        let working_dir = input.get("working_dir").and_then(|v| v.as_str());
        if let Some(dir) = working_dir {
            self.validate_working_dir(dir)?;
        }

        let result = execute_shell_command(&full_command, timeout_secs, working_dir).await?;
        Ok(json!({
            "command": full_command,
            "stdout": result.stdout,
            "stderr": result.stderr,
            "exit_code": result.exit_code,
            "success": result.exit_code == 0,
            "truncated": result.truncated
        }))
    }
}
/// Captured outcome of a finished shell command.
pub struct CommandResult {
    /// Standard output, decoded lossily as UTF-8 and possibly truncated.
    pub stdout: String,
    /// Standard error, decoded lossily as UTF-8 and possibly truncated.
    pub stderr: String,
    /// Process exit code, or `-1` when the exit status carries no code.
    pub exit_code: i32,
    /// `true` when stdout or stderr was truncated.
    pub truncated: bool,
}
/// Runs `command` through the platform shell and captures its output.
///
/// The command is passed to `cmd /C` on Windows and to `sh -c` elsewhere.
/// This function performs no validation of `command`.
///
/// * `command` - the full command line handed to the shell.
/// * `timeout_secs` - maximum time to wait for the command to finish.
/// * `working_dir` - directory to run the command in, if given.
///
/// Output is collected in full and truncated afterwards, stdout and stderr
/// each independently.
///
/// # Errors
///
/// Returns `ToolError::Message` when the command does not finish within
/// `timeout_secs` or when the process cannot be spawned or awaited.
pub async fn execute_shell_command(
    command: &str,
    timeout_secs: u64,
    working_dir: Option<&str>,
) -> Result<CommandResult, ToolError> {
    let timeout_duration = Duration::from_secs(timeout_secs);
    let (shell, shell_arg) = if cfg!(target_os = "windows") {
        ("cmd", "/C")
    } else {
        ("sh", "-c")
    };

    let mut cmd = tokio::process::Command::new(shell);
    cmd.arg(shell_arg).arg(command);
    if let Some(dir) = working_dir {
        cmd.current_dir(dir);
    }

    // Keep these cloud credentials out of the child's environment.
    cmd.env_remove("AWS_SECRET_ACCESS_KEY");
    cmd.env_remove("AZURE_CLIENT_SECRET");
    cmd.env_remove("GCP_SERVICE_ACCOUNT_KEY");

    // On timeout the `output()` future is dropped. Without this flag tokio
    // leaves the child running, so a timed-out command would keep executing
    // in the background.
    // NOTE(review): this kills the shell process itself; processes it spawned
    // may survive — confirm whether process-group handling is needed.
    cmd.kill_on_drop(true);

    let output = timeout(timeout_duration, cmd.output())
        .await
        .map_err(|_| ToolError::Message(format!("命令执行超时({timeout_secs} 秒)")))?
        .map_err(|e| ToolError::Message(format!("命令执行失败:{e}")))?;

    // No exit code is available in some cases (e.g. termination by a signal on
    // Unix); callers see -1 then.
    let exit_code = output.status.code().unwrap_or(-1);
    let (stdout, stdout_truncated) = truncate_output(&output.stdout);
    let (stderr, stderr_truncated) = truncate_output(&output.stderr);

    Ok(CommandResult {
        stdout,
        stderr,
        exit_code,
        truncated: stdout_truncated || stderr_truncated,
    })
}
/// Decodes captured output as UTF-8 (lossily) and caps it at `MAX_OUTPUT_BYTES`.
///
/// Returns the decoded text and a flag telling whether it was truncated.
/// Truncated text ends with the marker `... [截断]`.
///
/// The cut is moved back by up to three bytes when it would otherwise fall
/// inside a multi-byte UTF-8 character, so truncation never introduces a
/// replacement character that was not caused by the data itself.
pub fn truncate_output(output: &[u8]) -> (String, bool) {
    if output.len() <= MAX_OUTPUT_BYTES {
        return (String::from_utf8_lossy(output).into_owned(), false);
    }

    // `output[end]` is the first byte that would be dropped. If it is a UTF-8
    // continuation byte (0b10xx_xxxx), the character it belongs to starts
    // before the cut and would be split. A character has at most three
    // continuation bytes, so the search is bounded; this also keeps arbitrary
    // binary data from shrinking the result by more than three bytes.
    let floor = MAX_OUTPUT_BYTES.saturating_sub(3);
    let mut end = MAX_OUTPUT_BYTES;
    while end > floor && (output[end] & 0xC0) == 0x80 {
        end -= 1;
    }

    let kept = String::from_utf8_lossy(&output[..end]);
    (format!("{kept}... [截断]"), true)
}