use async_trait::async_trait;
use rucora_core::{
error::ToolError,
tool::{Tool, ToolCategory},
};
use serde_json::{Value, json};
use super::shell::{SHELL_TIMEOUT_SECS, execute_shell_command};
/// Tool that executes a restricted command line.
///
/// A command is only accepted when, after trimming, it equals one of
/// `allowed_prefixes` or starts with one of them followed by a space.
#[derive(Debug, Clone)]
pub struct CmdExecTool {
    /// Executable names a command line is allowed to begin with.
    pub allowed_prefixes: &'static [&'static str],
}
impl CmdExecTool {
pub fn new() -> Self {
Self {
allowed_prefixes: &["curl", "curl.exe"],
}
}
fn validate_command(&self, cmd: &str) -> Result<(), ToolError> {
let t = cmd.trim();
let prefix_ok = self
.allowed_prefixes
.iter()
.any(|p| t == *p || t.starts_with(&format!("{p} ")));
if !prefix_ok {
return Err(ToolError::Message(
"出于安全考虑,cmd_exec 目前仅允许执行 curl 命令".to_string(),
));
}
let forbidden = ["|", "&&", ";", ">", "<", "`", "$ (", "$(", "\n", "\r"];
if forbidden.iter().any(|x| t.contains(x)) {
return Err(ToolError::Message(
"出于安全考虑,cmd_exec 禁止管道/重定向/链式/多行命令".to_string(),
));
}
Ok(())
}
}
impl Default for CmdExecTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for CmdExecTool {
fn name(&self) -> &str {
"cmd_exec"
}
fn description(&self) -> Option<&str> {
Some("执行受限的命令行(当前仅允许 curl)")
}
fn categories(&self) -> &'static [ToolCategory] {
&[ToolCategory::System]
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "要执行的命令行(仅允许以 curl 开头)"
},
"timeout": {
"type": "integer",
"description": "超时时间(秒),默认 60 秒",
"default": 60
}
},
"required": ["command"]
})
}
async fn call(&self, input: Value) -> Result<Value, ToolError> {
let command = input
.get("command")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::Message("缺少必需的 'command' 字段".to_string()))?;
let timeout_secs = input
.get("timeout")
.and_then(|v| v.as_u64())
.unwrap_or(SHELL_TIMEOUT_SECS);
self.validate_command(command)?;
let mut parts = command.split_whitespace();
let executable = parts
.next()
.ok_or_else(|| ToolError::Message("命令不能为空".to_string()))?;
let args: Vec<String> = parts.map(String::from).collect();
let result = execute_shell_command(executable, &args, timeout_secs, None).await?;
Ok(json!({
"command": command,
"stdout": result.stdout,
"stderr": result.stderr,
"exit_code": result.exit_code,
"success": result.exit_code == 0,
"truncated": result.truncated
}))
}
}