use async_trait::async_trait;
use rucora_core::{
error::ToolError,
tool::{Tool, ToolCategory},
};
use serde_json::{Value, json};
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::time::timeout;
/// Returns the tool description shown to the model, specialised for the
/// operating system this binary was compiled for.
///
/// The text names the platform and lists a few commands that exist on it so
/// the caller is nudged towards platform-compatible commands. Targets other
/// than Windows, macOS and Linux get a generic description.
fn get_shell_description() -> &'static str {
    // `cfg!` is resolved at compile time, so only one branch is ever live.
    if cfg!(target_os = "windows") {
        return "执行系统命令。当前平台:Windows。常用命令:dir、cd、type、findstr、copy、move、del、mkdir。请只使用与 Windows 兼容的命令。";
    }
    if cfg!(target_os = "macos") {
        return "执行系统命令。当前平台:macOS。常用命令:ls、cd、cat、grep、cp、mv、rm、mkdir。请只使用与 macOS 兼容的命令。";
    }
    if cfg!(target_os = "linux") {
        return "执行系统命令。当前平台:Linux。常用命令:ls、cd、cat、grep、cp、mv、rm、mkdir。请只使用与 Linux 兼容的命令。";
    }
    "执行系统命令。请使用适合当前平台的命令。"
}
/// Default command timeout, in seconds, used when the caller supplies none.
pub const SHELL_TIMEOUT_SECS: u64 = 60;

/// Maximum number of bytes of stdout/stderr kept per stream (1 MiB).
pub const MAX_OUTPUT_BYTES: usize = 1024 * 1024;

/// Built-in deny list. Each entry is matched as a lower-case substring of the
/// command, so entries must themselves be lower-case.
const FORBIDDEN_COMMANDS: &[&str] = &[
    // Recursive / forced deletion.
    "rm -rf",
    "rm -fr",
    "del /f/s/q",
    // Disk formatting and partitioning.
    "format",
    "mkfs",
    "diskpart",
    // Power state changes.
    "shutdown",
    "reboot",
    "halt",
    // Network download tools.
    "wget",
    "curl",
];

/// Shell metacharacters and control characters that are rejected anywhere in a
/// command or in any of its arguments.
const DANGEROUS_OPERATORS: &[&str] = &[
    // Pipes, chaining and command separators.
    "|",
    "||",
    "&&",
    ";",
    // Redirection.
    ">",
    ">>",
    "<",
    "<<<",
    // Command and variable substitution.
    "`",
    "$(",
    "${",
    // Line breaks and escapes.
    "\n",
    "\r",
    "\\",
];
/// A tool that runs a single system command after validating it against an
/// allowlist, a deny list and a set of dangerous shell operators.
#[derive(Debug, Clone)]
pub struct ShellTool {
    /// Command names that may be executed. An empty set disables the
    /// allowlist check.
    allowed_commands: HashSet<String>,
    /// Caller-supplied substrings that cause a command to be rejected, in
    /// addition to the built-in deny list.
    forbidden_commands: HashSet<String>,
}
impl ShellTool {
    /// Creates a tool with the allowlist disabled and no extra forbidden
    /// patterns. The built-in deny list and operator checks still apply.
    pub fn new() -> Self {
        Self {
            allowed_commands: HashSet::new(),
            forbidden_commands: HashSet::new(),
        }
    }

    /// Restricts execution to commands whose first whitespace-separated token
    /// is one of `commands`. Passing an empty list leaves the allowlist
    /// disabled. Matching is exact and case-sensitive.
    pub fn with_allowed_commands(mut self, commands: Vec<String>) -> Self {
        self.allowed_commands = commands.into_iter().collect();
        self
    }

    /// Adds caller-defined forbidden substrings.
    ///
    /// Patterns are stored lower-cased because they are compared against the
    /// lower-cased command; without this a pattern such as `"Docker"` could
    /// never match. Blank patterns are dropped: an empty string is a substring
    /// of every command and would reject everything.
    pub fn with_forbidden_commands(mut self, commands: Vec<String>) -> Self {
        self.forbidden_commands = commands
            .into_iter()
            .filter(|pattern| !pattern.trim().is_empty())
            .map(|pattern| pattern.to_lowercase())
            .collect();
        self
    }

    /// Checks `command` against every policy, in this order: non-empty,
    /// built-in deny list, caller deny list, dangerous operators, allowlist,
    /// path traversal.
    ///
    /// # Errors
    /// Returns `ToolError::Message` describing the first rule that was violated.
    fn validate_command(&self, command: &str) -> Result<(), ToolError> {
        if command.trim().is_empty() {
            return Err(ToolError::Message("命令不能为空".to_string()));
        }

        // Deny lists are matched case-insensitively as substrings.
        let cmd_lower = command.to_lowercase();
        if let Some(forbidden) = FORBIDDEN_COMMANDS
            .iter()
            .find(|forbidden| cmd_lower.contains(**forbidden))
        {
            return Err(ToolError::Message(format!(
                "命令包含禁止的操作:{forbidden}"
            )));
        }
        if let Some(forbidden) = self
            .forbidden_commands
            .iter()
            .find(|forbidden| cmd_lower.contains(forbidden.as_str()))
        {
            return Err(ToolError::Message(format!(
                "命令包含禁止的操作:{forbidden}"
            )));
        }

        if let Some(operator) = DANGEROUS_OPERATORS
            .iter()
            .find(|operator| command.contains(**operator))
        {
            return Err(ToolError::Message(format!(
                "命令包含危险操作符:{operator}"
            )));
        }

        // The allowlist only applies once it has at least one entry.
        if !self.allowed_commands.is_empty() {
            let cmd_name = command.split_whitespace().next().unwrap_or(command);
            if !self.allowed_commands.contains(cmd_name) {
                return Err(ToolError::Message(format!(
                    "命令 {cmd_name} 不在允许的白名单中"
                )));
            }
        }

        if command.contains("..") {
            return Err(ToolError::Message(
                "命令包含路径遍历(..),这是不安全的".to_string(),
            ));
        }
        Ok(())
    }

    /// Rejects any argument that contains a dangerous operator or a `..`
    /// sequence.
    ///
    /// # Errors
    /// Returns `ToolError::Message` for the first offending argument.
    fn validate_args(&self, args: &[String]) -> Result<(), ToolError> {
        for arg in args {
            if let Some(operator) = DANGEROUS_OPERATORS
                .iter()
                .find(|operator| arg.contains(**operator))
            {
                return Err(ToolError::Message(format!(
                    "命令参数包含危险操作符:{operator}"
                )));
            }
            if arg.contains("..") {
                return Err(ToolError::Message(
                    "命令参数包含路径遍历(..),这是不安全的".to_string(),
                ));
            }
        }
        Ok(())
    }

    /// Checks that `dir` contains no `..` sequence and names an existing
    /// directory.
    ///
    /// # Errors
    /// Returns `ToolError::Message` if the path contains `..`, does not exist,
    /// or is not a directory.
    fn validate_working_dir(&self, dir: &str) -> Result<(), ToolError> {
        // Reject traversal before touching the filesystem.
        if dir.contains("..") {
            return Err(ToolError::Message(
                "工作目录包含路径遍历(..),这是不安全的".to_string(),
            ));
        }
        let path = Path::new(dir);
        if !path.exists() {
            return Err(ToolError::Message(format!("工作目录不存在:{dir}")));
        }
        if !path.is_dir() {
            return Err(ToolError::Message(format!("工作目录路径不是目录:{dir}")));
        }
        Ok(())
    }
}
impl Default for ShellTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for ShellTool {
fn name(&self) -> &str {
"shell"
}
fn description(&self) -> Option<&str> {
Some(get_shell_description())
}
fn categories(&self) -> &'static [ToolCategory] {
&[ToolCategory::System]
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "要执行的命令"
},
"args": {
"type": "array",
"items": {"type": "string"},
"description": "命令参数列表"
},
"timeout": {
"type": "integer",
"description": "超时时间(秒),默认 60 秒",
"default": 60
},
"working_dir": {
"type": "string",
"description": "工作目录(可选)"
}
},
"required": ["command"]
})
}
async fn call(&self, input: Value) -> Result<Value, ToolError> {
let command = input
.get("command")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::Message("缺少必需的 'command' 字段".to_string()))?;
let args: Vec<String> = input
.get("args")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let timeout_secs = input
.get("timeout")
.and_then(|v| v.as_u64())
.unwrap_or(SHELL_TIMEOUT_SECS);
self.validate_command(command)?;
self.validate_args(&args)?;
let working_dir = input.get("working_dir").and_then(|v| v.as_str());
let working_dir = if let Some(dir) = working_dir {
self.validate_working_dir(dir)?;
Some(PathBuf::from(dir))
} else {
None
};
let result =
execute_shell_command(command, &args, timeout_secs, working_dir.as_deref()).await?;
Ok(json!({
"command": command,
"args": args,
"stdout": result.stdout,
"stderr": result.stderr,
"exit_code": result.exit_code,
"success": result.exit_code == 0,
"truncated": result.truncated
}))
}
}
/// Captured outcome of one executed command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandResult {
    /// Standard output, lossily decoded as UTF-8 and possibly truncated.
    pub stdout: String,
    /// Standard error, lossily decoded as UTF-8 and possibly truncated.
    pub stderr: String,
    /// Process exit code.
    pub exit_code: i32,
    /// `true` when stdout or stderr was cut short.
    pub truncated: bool,
}
/// Runs `command` with `args` directly (no shell is involved) and captures
/// its output.
///
/// * `timeout_secs` - maximum time to wait for the process to finish.
/// * `working_dir` - directory to run in; `None` keeps the current one.
///
/// The returned `exit_code` is `-1` when the process reports no exit code.
/// Each output stream is passed through [`truncate_output`].
///
/// # Errors
/// Returns `ToolError::Message` if the process cannot be started, fails while
/// being waited on, or does not finish within `timeout_secs`.
pub async fn execute_shell_command(
    command: &str,
    args: &[String],
    timeout_secs: u64,
    working_dir: Option<&Path>,
) -> Result<CommandResult, ToolError> {
    let timeout_duration = Duration::from_secs(timeout_secs);

    let mut cmd = tokio::process::Command::new(command);
    cmd.args(args);
    if let Some(dir) = working_dir {
        cmd.current_dir(dir);
    }

    // Keep well-known cloud credentials out of the child's environment.
    for secret in [
        "AWS_SECRET_ACCESS_KEY",
        "AZURE_CLIENT_SECRET",
        "GCP_SERVICE_ACCOUNT_KEY",
    ] {
        cmd.env_remove(secret);
    }

    // Give the child an empty stdin so a command that reads input sees EOF
    // immediately instead of depending on the host process's stdin.
    cmd.stdin(std::process::Stdio::null());

    // When the timeout fires, the `output()` future is dropped. tokio does not
    // kill the child on drop by default, so without this the process would
    // keep running after the timeout error has been reported.
    cmd.kill_on_drop(true);

    let output = timeout(timeout_duration, cmd.output())
        .await
        .map_err(|_| ToolError::Message(format!("命令执行超时({timeout_secs} 秒)")))?
        .map_err(|e| ToolError::Message(format!("命令执行失败:{e}")))?;

    let exit_code = output.status.code().unwrap_or(-1);
    let (stdout, stdout_truncated) = truncate_output(&output.stdout);
    let (stderr, stderr_truncated) = truncate_output(&output.stderr);

    Ok(CommandResult {
        stdout,
        stderr,
        exit_code,
        truncated: stdout_truncated || stderr_truncated,
    })
}
/// Decodes `output` as UTF-8 (lossily) and limits it to `MAX_OUTPUT_BYTES`.
///
/// Returns the text and a flag that is `true` when the input was longer than
/// the limit. Truncated text ends with the marker `... [截断]`.
///
/// The cut is moved back to a character boundary when the limit falls inside
/// a multi-byte UTF-8 character, so truncation never adds a replacement
/// character of its own.
pub fn truncate_output(output: &[u8]) -> (String, bool) {
    if output.len() <= MAX_OUTPUT_BYTES {
        return (String::from_utf8_lossy(output).into_owned(), false);
    }
    let cut = utf8_safe_cut(output, MAX_OUTPUT_BYTES);
    let kept = String::from_utf8_lossy(&output[..cut]);
    (format!("{kept}... [截断]"), true)
}

/// Returns the largest index `<= limit` at which `bytes` can be cut without
/// splitting a multi-byte UTF-8 character that straddles `limit`.
///
/// Falls back to `limit` when the bytes around it do not form such a
/// character (for example binary data).
fn utf8_safe_cut(bytes: &[u8], limit: usize) -> usize {
    if limit >= bytes.len() {
        return bytes.len();
    }

    // A UTF-8 character has at most three continuation bytes (0b10xx_xxxx),
    // so its lead byte is at most three positions before `limit`.
    let floor = limit.saturating_sub(3);
    let mut start = limit;
    while start > floor && (bytes[start] & 0xC0) == 0x80 {
        start -= 1;
    }

    let width = match bytes[start] {
        0xC0..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xF7 => 4,
        // ASCII or a stray continuation byte: nothing straddles the limit.
        _ => return limit,
    };

    // Back up only if that character really extends past the limit.
    if start < limit && start + width > limit {
        start
    } else {
        limit
    }
}