Skip to main content

rucora_tools/system/
shell.rs

1//! Shell 工具模块
2//!
3//! 提供系统命令执行功能,支持超时和安全限制
4
5use async_trait::async_trait;
6use rucora_core::{
7    error::ToolError,
8    tool::{Tool, ToolCategory},
9};
10use serde_json::{Value, json};
11use std::collections::HashSet;
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14use tokio::time::timeout;
15
16/// 获取 Shell 工具描述
17fn get_shell_description() -> &'static str {
18    if cfg!(target_os = "windows") {
19        "执行系统命令。当前平台:Windows。常用命令:dir、cd、type、findstr、copy、move、del、mkdir。请只使用与 Windows 兼容的命令。"
20    } else if cfg!(target_os = "macos") {
21        "执行系统命令。当前平台:macOS。常用命令:ls、cd、cat、grep、cp、mv、rm、mkdir。请只使用与 macOS 兼容的命令。"
22    } else if cfg!(target_os = "linux") {
23        "执行系统命令。当前平台:Linux。常用命令:ls、cd、cat、grep、cp、mv、rm、mkdir。请只使用与 Linux 兼容的命令。"
24    } else {
25        "执行系统命令。请使用适合当前平台的命令。"
26    }
27}
28
29/// Shell 命令执行的超时时间(秒)
30pub const SHELL_TIMEOUT_SECS: u64 = 60;
31/// 最大输出大小(1MB),防止内存溢出
32pub const MAX_OUTPUT_BYTES: usize = 1_048_576;
33
34/// 默认禁止的命令列表
35const FORBIDDEN_COMMANDS: &[&str] = &[
36    "rm -rf",
37    "rm -fr",
38    "del /f/s/q", // 强制删除
39    "format",
40    "mkfs",
41    "diskpart", // 磁盘操作
42    "shutdown",
43    "reboot",
44    "halt", // 系统操作
45    "wget",
46    "curl", // 网络下载(可用受限版本替代)
47];
48
49/// 默认禁止的危险操作符
50const DANGEROUS_OPERATORS: &[&str] = &[
51    "|", "||", "&&", ";", ">", ">>", "<", "<<<", // 管道和重定向
52    "`", "$(", "${", // 命令替换
53    "\n", "\r", // 多行命令
54    "\\", // 续行符
55];
56
57/// Shell 工具:执行系统命令。
58///
59/// 支持超时和输出限制,防止命令执行时间过长或输出过大。
60///
61/// # 安全机制
62///
63/// - 命令黑名单检查
64/// - 危险操作符检测
65/// - 路径遍历防护
66/// - 安全的环境变量(清除敏感变量)
67/// - 超时和输出大小限制
68///
69/// 适用场景:
70/// - 执行系统命令
71/// - 运行脚本
72///
73/// 输入格式:
74/// ```json
75/// {
76///   "command": "要执行的命令",
77///   "args": ["命令参数"],
78///   "timeout": 60,  // 可选,超时时间(秒)
79///   "working_dir": "/path/to/dir"  // 可选,工作目录
80/// }
81/// ```
82pub struct ShellTool {
83    /// 允许的命令白名单(如果为空,则只检查黑名单)
84    allowed_commands: HashSet<String>,
85    /// 额外的禁止命令列表
86    forbidden_commands: HashSet<String>,
87}
88
89impl ShellTool {
90    /// 创建一个新的 ShellTool 实例(使用默认安全配置)。
91    pub fn new() -> Self {
92        Self {
93            allowed_commands: HashSet::new(),
94            forbidden_commands: HashSet::new(),
95        }
96    }
97
98    /// 设置允许的命令白名单。
99    ///
100    /// 如果设置了白名单,只有白名单中的命令可以执行。
101    pub fn with_allowed_commands(mut self, commands: Vec<String>) -> Self {
102        self.allowed_commands = commands.into_iter().collect();
103        self
104    }
105
106    /// 添加额外的禁止命令。
107    pub fn with_forbidden_commands(mut self, commands: Vec<String>) -> Self {
108        self.forbidden_commands = commands.into_iter().collect();
109        self
110    }
111
112    /// 检查命令是否安全
113    fn validate_command(&self, command: &str) -> Result<(), ToolError> {
114        let cmd_lower = command.to_lowercase();
115
116        // 检查是否在禁止列表中
117        for forbidden in FORBIDDEN_COMMANDS {
118            if cmd_lower.contains(forbidden) {
119                return Err(ToolError::Message(format!(
120                    "命令包含禁止的操作:{forbidden}"
121                )));
122            }
123        }
124
125        // 检查额外的禁止命令
126        for forbidden in &self.forbidden_commands {
127            if cmd_lower.contains(forbidden) {
128                return Err(ToolError::Message(format!(
129                    "命令包含禁止的操作:{forbidden}"
130                )));
131            }
132        }
133
134        // 检查危险操作符
135        for operator in DANGEROUS_OPERATORS {
136            if command.contains(operator) {
137                return Err(ToolError::Message(format!(
138                    "命令包含危险操作符:{operator}"
139                )));
140            }
141        }
142
143        // 如果设置了白名单,检查命令是否在白名单中
144        if !self.allowed_commands.is_empty() {
145            let cmd_name = command.split_whitespace().next().unwrap_or(command);
146            if !self.allowed_commands.contains(cmd_name) {
147                return Err(ToolError::Message(format!(
148                    "命令 {cmd_name} 不在允许的白名单中"
149                )));
150            }
151        }
152
153        // 检查路径遍历
154        if command.contains("..") {
155            return Err(ToolError::Message(
156                "命令包含路径遍历(..),这是不安全的".to_string(),
157            ));
158        }
159
160        Ok(())
161    }
162
163    /// 检查命令参数是否安全
164    fn validate_args(&self, args: &[String]) -> Result<(), ToolError> {
165        for arg in args {
166            for operator in DANGEROUS_OPERATORS {
167                if arg.contains(operator) {
168                    return Err(ToolError::Message(format!(
169                        "命令参数包含危险操作符:{operator}"
170                    )));
171                }
172            }
173
174            if arg.contains("..") {
175                return Err(ToolError::Message(
176                    "命令参数包含路径遍历(..),这是不安全的".to_string(),
177                ));
178            }
179        }
180        Ok(())
181    }
182
183    /// 检查工作目录是否安全
184    fn validate_working_dir(&self, dir: &str) -> Result<(), ToolError> {
185        let path = Path::new(dir);
186
187        // 检查路径遍历
188        if dir.contains("..") {
189            return Err(ToolError::Message(
190                "工作目录包含路径遍历(..),这是不安全的".to_string(),
191            ));
192        }
193
194        // 检查目录是否存在
195        if !path.exists() {
196            return Err(ToolError::Message(format!("工作目录不存在:{dir}")));
197        }
198
199        if !path.is_dir() {
200            return Err(ToolError::Message(format!("工作目录路径不是目录:{dir}")));
201        }
202
203        Ok(())
204    }
205}
206
207impl Default for ShellTool {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213#[async_trait]
214impl Tool for ShellTool {
215    fn name(&self) -> &str {
216        "shell"
217    }
218
219    fn description(&self) -> Option<&str> {
220        Some(get_shell_description())
221    }
222
223    fn categories(&self) -> &'static [ToolCategory] {
224        &[ToolCategory::System]
225    }
226
227    fn input_schema(&self) -> Value {
228        json!({
229            "type": "object",
230            "properties": {
231                "command": {
232                    "type": "string",
233                    "description": "要执行的命令"
234                },
235                "args": {
236                    "type": "array",
237                    "items": {"type": "string"},
238                    "description": "命令参数列表"
239                },
240                "timeout": {
241                    "type": "integer",
242                    "description": "超时时间(秒),默认 60 秒",
243                    "default": 60
244                },
245                "working_dir": {
246                    "type": "string",
247                    "description": "工作目录(可选)"
248                }
249            },
250            "required": ["command"]
251        })
252    }
253
254    async fn call(&self, input: Value) -> Result<Value, ToolError> {
255        let command = input
256            .get("command")
257            .and_then(|v| v.as_str())
258            .ok_or_else(|| ToolError::Message("缺少必需的 'command' 字段".to_string()))?;
259
260        let args: Vec<String> = input
261            .get("args")
262            .and_then(|v| v.as_array())
263            .map(|arr| {
264                arr.iter()
265                    .filter_map(|v| v.as_str().map(String::from))
266                    .collect()
267            })
268            .unwrap_or_default();
269
270        let timeout_secs = input
271            .get("timeout")
272            .and_then(|v| v.as_u64())
273            .unwrap_or(SHELL_TIMEOUT_SECS);
274
275        // 验证命令安全性
276        self.validate_command(command)?;
277        self.validate_args(&args)?;
278
279        // 处理工作目录
280        let working_dir = input.get("working_dir").and_then(|v| v.as_str());
281
282        let working_dir = if let Some(dir) = working_dir {
283            self.validate_working_dir(dir)?;
284            Some(PathBuf::from(dir))
285        } else {
286            None
287        };
288
289        // 执行命令
290        let result =
291            execute_shell_command(command, &args, timeout_secs, working_dir.as_deref()).await?;
292
293        Ok(json!({
294            "command": command,
295            "args": args,
296            "stdout": result.stdout,
297            "stderr": result.stderr,
298            "exit_code": result.exit_code,
299            "success": result.exit_code == 0,
300            "truncated": result.truncated
301        }))
302    }
303}
304
305/// 命令执行结果
306pub struct CommandResult {
307    pub stdout: String,
308    pub stderr: String,
309    pub exit_code: i32,
310    pub truncated: bool,
311}
312
313/// 执行 shell 命令
314pub async fn execute_shell_command(
315    command: &str,
316    args: &[String],
317    timeout_secs: u64,
318    working_dir: Option<&Path>,
319) -> Result<CommandResult, ToolError> {
320    let timeout_duration = Duration::from_secs(timeout_secs);
321
322    let mut cmd = tokio::process::Command::new(command);
323    cmd.args(args);
324
325    // 设置工作目录
326    if let Some(dir) = working_dir {
327        cmd.current_dir(dir);
328    }
329
330    // 清除敏感环境变量
331    cmd.env_remove("AWS_SECRET_ACCESS_KEY");
332    cmd.env_remove("AZURE_CLIENT_SECRET");
333    cmd.env_remove("GCP_SERVICE_ACCOUNT_KEY");
334
335    // 执行命令(带超时)
336    let output = timeout(timeout_duration, cmd.output())
337        .await
338        .map_err(|_| ToolError::Message(format!("命令执行超时({timeout_secs} 秒)")))?
339        .map_err(|e| ToolError::Message(format!("命令执行失败:{e}")))?;
340
341    let exit_code = output.status.code().unwrap_or(-1);
342
343    // 处理输出(截断过长的输出)
344    let (stdout, stdout_truncated) = truncate_output(&output.stdout);
345    let (stderr, stderr_truncated) = truncate_output(&output.stderr);
346
347    Ok(CommandResult {
348        stdout,
349        stderr,
350        exit_code,
351        truncated: stdout_truncated || stderr_truncated,
352    })
353}
354
355/// 截断输出
356pub fn truncate_output(output: &[u8]) -> (String, bool) {
357    if output.len() > MAX_OUTPUT_BYTES {
358        let truncated = String::from_utf8_lossy(&output[..MAX_OUTPUT_BYTES]);
359        (format!("{truncated}... [截断]"), true)
360    } else {
361        (String::from_utf8_lossy(output).to_string(), false)
362    }
363}