1use async_trait::async_trait;
6use rucora_core::{
7 error::ToolError,
8 tool::{Tool, ToolCategory},
9};
10use serde_json::{Value, json};
11use std::collections::HashSet;
12use std::path::{Path, PathBuf};
13use std::time::Duration;
14use tokio::time::timeout;
15
/// Returns the tool description shown to the model, tailored to the
/// operating system this binary was compiled for.
fn get_shell_description() -> &'static str {
    match std::env::consts::OS {
        "windows" => "执行系统命令。当前平台:Windows。常用命令:dir、cd、type、findstr、copy、move、del、mkdir。请只使用与 Windows 兼容的命令。",
        "macos" => "执行系统命令。当前平台:macOS。常用命令:ls、cd、cat、grep、cp、mv、rm、mkdir。请只使用与 macOS 兼容的命令。",
        "linux" => "执行系统命令。当前平台:Linux。常用命令:ls、cd、cat、grep、cp、mv、rm、mkdir。请只使用与 Linux 兼容的命令。",
        // Any other target gets a platform-neutral description.
        _ => "执行系统命令。请使用适合当前平台的命令。",
    }
}
28
/// Default per-command timeout, in seconds, used when the request omits `timeout`.
pub const SHELL_TIMEOUT_SECS: u64 = 60;

/// Maximum number of bytes kept from each captured stream (stdout / stderr).
/// Anything beyond this is cut off and the result is flagged as truncated.
pub const MAX_OUTPUT_BYTES: usize = 1024 * 1024;
33
/// Built-in deny-list: substrings that cause a command to be rejected.
///
/// Entries are lowercase and are matched with a plain `contains` against the
/// lowercased command text, so e.g. `"format"` also matches `"--format"`.
const FORBIDDEN_COMMANDS: &[&str] = &[
    // Recursive / forced deletion
    "rm -rf",
    "rm -fr",
    "del /f/s/q",
    // Disk formatting and partitioning
    "format",
    "mkfs",
    "diskpart",
    // Power state changes
    "shutdown",
    "reboot",
    "halt",
    // Network downloads
    "wget",
    "curl",
];
48
/// Substrings rejected anywhere in the command or in any of its arguments.
///
/// Matching is a plain `contains` in list order, so `"||"`, `">>"` and `"<<<"`
/// are already caught by `"|"`, `">"` and `"<"` respectively.
/// NOTE: `"\\"` also rejects Windows-style paths such as `C:\dir`.
const DANGEROUS_OPERATORS: &[&str] = &[
    // Pipes and command chaining
    "|",
    "||",
    "&&",
    ";",
    // Redirection
    ">",
    ">>",
    "<",
    "<<<",
    // Command / variable substitution
    "`",
    "$(",
    "${",
    // Line breaks and escapes
    "\n",
    "\r",
    "\\",
];
56
/// Tool that runs a single program directly (without a shell) after
/// validating the request.
///
/// Validation combines the built-in deny-lists with optional caller-supplied
/// allow/deny lists configured through the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ShellTool {
    /// Program names permitted to run; an empty set disables the allow-list check.
    allowed_commands: HashSet<String>,
    /// Extra substrings that cause a command to be rejected, on top of the built-in list.
    forbidden_commands: HashSet<String>,
}
88
89impl ShellTool {
90 pub fn new() -> Self {
92 Self {
93 allowed_commands: HashSet::new(),
94 forbidden_commands: HashSet::new(),
95 }
96 }
97
98 pub fn with_allowed_commands(mut self, commands: Vec<String>) -> Self {
102 self.allowed_commands = commands.into_iter().collect();
103 self
104 }
105
106 pub fn with_forbidden_commands(mut self, commands: Vec<String>) -> Self {
108 self.forbidden_commands = commands.into_iter().collect();
109 self
110 }
111
112 fn validate_command(&self, command: &str) -> Result<(), ToolError> {
114 let cmd_lower = command.to_lowercase();
115
116 for forbidden in FORBIDDEN_COMMANDS {
118 if cmd_lower.contains(forbidden) {
119 return Err(ToolError::Message(format!(
120 "命令包含禁止的操作:{forbidden}"
121 )));
122 }
123 }
124
125 for forbidden in &self.forbidden_commands {
127 if cmd_lower.contains(forbidden) {
128 return Err(ToolError::Message(format!(
129 "命令包含禁止的操作:{forbidden}"
130 )));
131 }
132 }
133
134 for operator in DANGEROUS_OPERATORS {
136 if command.contains(operator) {
137 return Err(ToolError::Message(format!(
138 "命令包含危险操作符:{operator}"
139 )));
140 }
141 }
142
143 if !self.allowed_commands.is_empty() {
145 let cmd_name = command.split_whitespace().next().unwrap_or(command);
146 if !self.allowed_commands.contains(cmd_name) {
147 return Err(ToolError::Message(format!(
148 "命令 {cmd_name} 不在允许的白名单中"
149 )));
150 }
151 }
152
153 if command.contains("..") {
155 return Err(ToolError::Message(
156 "命令包含路径遍历(..),这是不安全的".to_string(),
157 ));
158 }
159
160 Ok(())
161 }
162
163 fn validate_args(&self, args: &[String]) -> Result<(), ToolError> {
165 for arg in args {
166 for operator in DANGEROUS_OPERATORS {
167 if arg.contains(operator) {
168 return Err(ToolError::Message(format!(
169 "命令参数包含危险操作符:{operator}"
170 )));
171 }
172 }
173
174 if arg.contains("..") {
175 return Err(ToolError::Message(
176 "命令参数包含路径遍历(..),这是不安全的".to_string(),
177 ));
178 }
179 }
180 Ok(())
181 }
182
183 fn validate_working_dir(&self, dir: &str) -> Result<(), ToolError> {
185 let path = Path::new(dir);
186
187 if dir.contains("..") {
189 return Err(ToolError::Message(
190 "工作目录包含路径遍历(..),这是不安全的".to_string(),
191 ));
192 }
193
194 if !path.exists() {
196 return Err(ToolError::Message(format!("工作目录不存在:{dir}")));
197 }
198
199 if !path.is_dir() {
200 return Err(ToolError::Message(format!("工作目录路径不是目录:{dir}")));
201 }
202
203 Ok(())
204 }
205}
206
207impl Default for ShellTool {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213#[async_trait]
214impl Tool for ShellTool {
215 fn name(&self) -> &str {
216 "shell"
217 }
218
219 fn description(&self) -> Option<&str> {
220 Some(get_shell_description())
221 }
222
223 fn categories(&self) -> &'static [ToolCategory] {
224 &[ToolCategory::System]
225 }
226
227 fn input_schema(&self) -> Value {
228 json!({
229 "type": "object",
230 "properties": {
231 "command": {
232 "type": "string",
233 "description": "要执行的命令"
234 },
235 "args": {
236 "type": "array",
237 "items": {"type": "string"},
238 "description": "命令参数列表"
239 },
240 "timeout": {
241 "type": "integer",
242 "description": "超时时间(秒),默认 60 秒",
243 "default": 60
244 },
245 "working_dir": {
246 "type": "string",
247 "description": "工作目录(可选)"
248 }
249 },
250 "required": ["command"]
251 })
252 }
253
254 async fn call(&self, input: Value) -> Result<Value, ToolError> {
255 let command = input
256 .get("command")
257 .and_then(|v| v.as_str())
258 .ok_or_else(|| ToolError::Message("缺少必需的 'command' 字段".to_string()))?;
259
260 let args: Vec<String> = input
261 .get("args")
262 .and_then(|v| v.as_array())
263 .map(|arr| {
264 arr.iter()
265 .filter_map(|v| v.as_str().map(String::from))
266 .collect()
267 })
268 .unwrap_or_default();
269
270 let timeout_secs = input
271 .get("timeout")
272 .and_then(|v| v.as_u64())
273 .unwrap_or(SHELL_TIMEOUT_SECS);
274
275 self.validate_command(command)?;
277 self.validate_args(&args)?;
278
279 let working_dir = input.get("working_dir").and_then(|v| v.as_str());
281
282 let working_dir = if let Some(dir) = working_dir {
283 self.validate_working_dir(dir)?;
284 Some(PathBuf::from(dir))
285 } else {
286 None
287 };
288
289 let result =
291 execute_shell_command(command, &args, timeout_secs, working_dir.as_deref()).await?;
292
293 Ok(json!({
294 "command": command,
295 "args": args,
296 "stdout": result.stdout,
297 "stderr": result.stderr,
298 "exit_code": result.exit_code,
299 "success": result.exit_code == 0,
300 "truncated": result.truncated
301 }))
302 }
303}
304
/// Captured outcome of a finished child process.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandResult {
    /// Standard output, decoded as lossy UTF-8 and possibly truncated.
    pub stdout: String,
    /// Standard error, decoded as lossy UTF-8 and possibly truncated.
    pub stderr: String,
    /// Process exit code, or `-1` when the OS reported none
    /// (e.g. the process was terminated by a signal on Unix).
    pub exit_code: i32,
    /// `true` if either stream was cut off at the output size limit.
    pub truncated: bool,
}
312
313pub async fn execute_shell_command(
315 command: &str,
316 args: &[String],
317 timeout_secs: u64,
318 working_dir: Option<&Path>,
319) -> Result<CommandResult, ToolError> {
320 let timeout_duration = Duration::from_secs(timeout_secs);
321
322 let mut cmd = tokio::process::Command::new(command);
323 cmd.args(args);
324
325 if let Some(dir) = working_dir {
327 cmd.current_dir(dir);
328 }
329
330 cmd.env_remove("AWS_SECRET_ACCESS_KEY");
332 cmd.env_remove("AZURE_CLIENT_SECRET");
333 cmd.env_remove("GCP_SERVICE_ACCOUNT_KEY");
334
335 let output = timeout(timeout_duration, cmd.output())
337 .await
338 .map_err(|_| ToolError::Message(format!("命令执行超时({timeout_secs} 秒)")))?
339 .map_err(|e| ToolError::Message(format!("命令执行失败:{e}")))?;
340
341 let exit_code = output.status.code().unwrap_or(-1);
342
343 let (stdout, stdout_truncated) = truncate_output(&output.stdout);
345 let (stderr, stderr_truncated) = truncate_output(&output.stderr);
346
347 Ok(CommandResult {
348 stdout,
349 stderr,
350 exit_code,
351 truncated: stdout_truncated || stderr_truncated,
352 })
353}
354
355pub fn truncate_output(output: &[u8]) -> (String, bool) {
357 if output.len() > MAX_OUTPUT_BYTES {
358 let truncated = String::from_utf8_lossy(&output[..MAX_OUTPUT_BYTES]);
359 (format!("{truncated}... [截断]"), true)
360 } else {
361 (String::from_utf8_lossy(output).to_string(), false)
362 }
363}