Skip to main content

aster/tools/
bash.rs

//! Bash Tool Implementation
//!
//! This module implements the `BashTool` for executing shell commands with:
//! - Cross-platform support (PowerShell on Windows; `sh` on macOS and Linux)
//! - Safety checks for dangerous commands
//! - Warning pattern detection
//! - Background task execution
//! - Configurable timeout
//! - Output truncation
//!
//! Requirements: 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9
12
13use std::process::Stdio;
14use std::sync::Arc;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use regex::Regex;
19use serde::{Deserialize, Serialize};
20use tokio::process::Command;
21use tracing::{debug, warn};
22
23use super::base::{PermissionCheckResult, Tool};
24use super::context::{ToolContext, ToolOptions, ToolResult};
25use super::error::ToolError;
26use super::task::TaskManager;
27
/// Maximum size, in bytes, of command output returned by the tool; anything
/// longer is truncated (128 KiB).
pub const MAX_OUTPUT_LENGTH: usize = 128 << 10;

/// Timeout used when the caller does not supply one: five minutes, in seconds.
pub const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;

/// Ceiling applied to any requested timeout: thirty minutes, in seconds.
pub const MAX_TIMEOUT_SECS: u64 = 30 * 60;
36
37/// Safety check result for command validation
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct SafetyCheckResult {
40    /// Whether the command is safe to execute
41    pub safe: bool,
42    /// Reason for blocking (if not safe)
43    pub reason: Option<String>,
44    /// Warning message (if potentially dangerous but allowed)
45    pub warning: Option<String>,
46}
47
48impl SafetyCheckResult {
49    /// Create a safe result
50    pub fn safe() -> Self {
51        Self {
52            safe: true,
53            reason: None,
54            warning: None,
55        }
56    }
57
58    /// Create a safe result with a warning
59    pub fn safe_with_warning(warning: impl Into<String>) -> Self {
60        Self {
61            safe: true,
62            reason: None,
63            warning: Some(warning.into()),
64        }
65    }
66
67    /// Create an unsafe result with a reason
68    pub fn unsafe_with_reason(reason: impl Into<String>) -> Self {
69        Self {
70            safe: false,
71            reason: Some(reason.into()),
72            warning: None,
73        }
74    }
75}
76
/// Sandbox settings that can be attached to a `BashTool` via `with_sandbox`.
///
/// NOTE(review): in this file `BashTool` only reads `environment` (and applies it
/// whenever a config is present). `enabled` and `allowed_directories` are not
/// enforced here; confirm whether another component enforces them.
#[derive(Debug, Clone, Default)]
pub struct SandboxConfig {
    /// Flag indicating the sandbox is switched on.
    pub enabled: bool,
    /// Directories the sandboxed command is meant to be allowed to access.
    pub allowed_directories: Vec<String>,
    /// Extra environment variables set on spawned commands.
    pub environment: std::collections::HashMap<String, String>,
}
87
88/// Bash Tool for executing shell commands
89///
90/// Provides secure shell command execution with:
91/// - Dangerous command blacklist
92/// - Warning pattern detection
93/// - Cross-platform support
94/// - Timeout control
95/// - Output truncation
96///
97/// Requirements: 3.1
98#[derive(Debug)]
99pub struct BashTool {
100    /// Dangerous commands that are blocked
101    dangerous_commands: Vec<String>,
102    /// Warning patterns for potentially dangerous commands
103    warning_patterns: Vec<Regex>,
104    /// Task manager for background execution
105    task_manager: Arc<TaskManager>,
106    /// Sandbox configuration
107    sandbox_config: Option<SandboxConfig>,
108}
109
110impl Default for BashTool {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116impl BashTool {
117    /// Create a new BashTool with default settings
118    pub fn new() -> Self {
119        Self {
120            dangerous_commands: Self::default_dangerous_commands(),
121            warning_patterns: Self::default_warning_patterns(),
122            task_manager: Arc::new(TaskManager::new()),
123            sandbox_config: None,
124        }
125    }
126
127    /// Create a BashTool with custom task manager
128    pub fn with_task_manager(task_manager: Arc<TaskManager>) -> Self {
129        Self {
130            dangerous_commands: Self::default_dangerous_commands(),
131            warning_patterns: Self::default_warning_patterns(),
132            task_manager,
133            sandbox_config: None,
134        }
135    }
136
137    /// Set sandbox configuration
138    pub fn with_sandbox(mut self, config: SandboxConfig) -> Self {
139        self.sandbox_config = Some(config);
140        self
141    }
142
143    /// Set custom dangerous commands
144    pub fn with_dangerous_commands(mut self, commands: Vec<String>) -> Self {
145        self.dangerous_commands = commands;
146        self
147    }
148
149    /// Add additional dangerous commands
150    pub fn add_dangerous_commands(&mut self, commands: Vec<String>) {
151        self.dangerous_commands.extend(commands);
152    }
153
154    /// Set custom warning patterns
155    pub fn with_warning_patterns(mut self, patterns: Vec<Regex>) -> Self {
156        self.warning_patterns = patterns;
157        self
158    }
159
160    /// Get the task manager
161    pub fn task_manager(&self) -> &Arc<TaskManager> {
162        &self.task_manager
163    }
164
165    /// Default list of dangerous commands that should be blocked
166    fn default_dangerous_commands() -> Vec<String> {
167        vec![
168            // Destructive file operations
169            "rm -rf /".to_string(),
170            "rm -rf /*".to_string(),
171            "rm -rf ~".to_string(),
172            "rm -rf ~/*".to_string(),
173            "rm -rf .".to_string(),
174            "rm -rf ..".to_string(),
175            // Format/partition commands
176            "mkfs".to_string(),
177            "fdisk".to_string(),
178            "dd if=/dev/zero".to_string(),
179            "dd if=/dev/random".to_string(),
180            // Fork bombs
181            ":(){ :|:& };:".to_string(),
182            // System shutdown/reboot
183            "shutdown".to_string(),
184            "reboot".to_string(),
185            "halt".to_string(),
186            "poweroff".to_string(),
187            "init 0".to_string(),
188            "init 6".to_string(),
189            // Dangerous redirects
190            "> /dev/sda".to_string(),
191            "> /dev/hda".to_string(),
192            // Network attacks
193            "nc -l".to_string(),
194            // Privilege escalation attempts
195            "chmod 777 /".to_string(),
196            "chown -R".to_string(),
197        ]
198    }
199
200    /// Default warning patterns for potentially dangerous commands
201    fn default_warning_patterns() -> Vec<Regex> {
202        let patterns = [
203            // Recursive delete
204            r"rm\s+(-[a-zA-Z]*r[a-zA-Z]*|-[a-zA-Z]*f[a-zA-Z]*)+",
205            // Sudo commands
206            r"sudo\s+",
207            // Curl/wget piped to shell
208            r"(curl|wget)\s+.*\|\s*(bash|sh|zsh)",
209            // Chmod with dangerous permissions
210            r"chmod\s+[0-7]*7[0-7]*",
211            // Kill all processes
212            r"kill\s+-9\s+-1",
213            r"killall",
214            // Environment variable manipulation
215            r"export\s+PATH=",
216            r"export\s+LD_PRELOAD",
217            // Git force push
218            r"git\s+push\s+.*--force",
219            r"git\s+push\s+-f",
220            // Database drop
221            r"DROP\s+DATABASE",
222            r"DROP\s+TABLE",
223            // Docker dangerous operations
224            r"docker\s+rm\s+-f",
225            r"docker\s+system\s+prune",
226        ];
227
228        patterns.iter().filter_map(|p| Regex::new(p).ok()).collect()
229    }
230}
231
232// =============================================================================
233// Safety Check Implementation (Requirements: 3.2, 3.3)
234// =============================================================================
235
236impl BashTool {
237    /// Check if a command is safe to execute
238    ///
239    /// This method checks the command against:
240    /// 1. Dangerous command blacklist (blocks execution)
241    /// 2. Warning patterns (allows with warning)
242    ///
243    /// Requirements: 3.2, 3.3
244    pub fn check_command_safety(&self, command: &str) -> SafetyCheckResult {
245        let command_lower = command.to_lowercase();
246        let command_trimmed = command.trim();
247
248        // Check against dangerous command blacklist
249        for dangerous in &self.dangerous_commands {
250            let dangerous_lower = dangerous.to_lowercase();
251            if command_lower.contains(&dangerous_lower) {
252                return SafetyCheckResult::unsafe_with_reason(format!(
253                    "Command contains dangerous pattern: '{}'",
254                    dangerous
255                ));
256            }
257        }
258
259        // Check for fork bomb patterns
260        if self.is_fork_bomb(command_trimmed) {
261            return SafetyCheckResult::unsafe_with_reason("Command appears to be a fork bomb");
262        }
263
264        // Check for dangerous redirects to device files
265        if self.has_dangerous_redirect(command_trimmed) {
266            return SafetyCheckResult::unsafe_with_reason(
267                "Command contains dangerous redirect to device file",
268            );
269        }
270
271        // Check against warning patterns
272        let mut warnings = Vec::new();
273        for pattern in &self.warning_patterns {
274            if pattern.is_match(command_trimmed) {
275                warnings.push(format!("Matches warning pattern: {}", pattern.as_str()));
276            }
277        }
278
279        if !warnings.is_empty() {
280            return SafetyCheckResult::safe_with_warning(warnings.join("; "));
281        }
282
283        SafetyCheckResult::safe()
284    }
285
286    /// Check if command appears to be a fork bomb
287    fn is_fork_bomb(&self, command: &str) -> bool {
288        // Common fork bomb patterns
289        let fork_bomb_patterns = [
290            r":\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;\s*:", // :(){ :|:& };:
291            r"\$\{:\|:\&\}",                             // ${:|:&}
292            r"\.\/\s*\$0\s*&",                           // ./$0 &
293        ];
294
295        for pattern in fork_bomb_patterns {
296            if let Ok(re) = Regex::new(pattern) {
297                if re.is_match(command) {
298                    return true;
299                }
300            }
301        }
302
303        false
304    }
305
306    /// Check for dangerous redirects to device files
307    fn has_dangerous_redirect(&self, command: &str) -> bool {
308        let dangerous_devices = [
309            "/dev/sda",
310            "/dev/sdb",
311            "/dev/sdc",
312            "/dev/hda",
313            "/dev/hdb",
314            "/dev/nvme",
315            "/dev/mem",
316            "/dev/kmem",
317        ];
318
319        for device in dangerous_devices {
320            if command.contains(&format!("> {}", device))
321                || command.contains(&format!(">{}", device))
322                || command.contains(&format!(">> {}", device))
323                || command.contains(&format!(">>{}", device))
324            {
325                return true;
326            }
327        }
328
329        false
330    }
331
332    /// Check if a command is in the dangerous commands list
333    pub fn is_dangerous_command(&self, command: &str) -> bool {
334        !self.check_command_safety(command).safe
335    }
336
337    /// Check if a command triggers any warning patterns
338    pub fn has_warning(&self, command: &str) -> bool {
339        self.check_command_safety(command).warning.is_some()
340    }
341}
342
343// =============================================================================
344// Foreground Execution Implementation (Requirements: 3.1, 3.5)
345// =============================================================================
346
347impl BashTool {
348    /// Execute a command in the foreground with timeout
349    ///
350    /// Supports cross-platform execution:
351    /// - Windows: Uses PowerShell or CMD
352    /// - macOS/Linux: Uses sh -c
353    ///
354    /// Requirements: 3.1, 3.5
355    pub async fn execute_foreground(
356        &self,
357        command: &str,
358        timeout: Duration,
359        context: &ToolContext,
360    ) -> Result<ToolResult, ToolError> {
361        // Check for cancellation
362        if context.is_cancelled() {
363            return Err(ToolError::Cancelled);
364        }
365
366        // Enforce maximum timeout
367        let effective_timeout = if timeout.as_secs() > MAX_TIMEOUT_SECS {
368            warn!(
369                "Requested timeout {:?} exceeds maximum, using {} seconds",
370                timeout, MAX_TIMEOUT_SECS
371            );
372            Duration::from_secs(MAX_TIMEOUT_SECS)
373        } else {
374            timeout
375        };
376
377        debug!(
378            "Executing command with timeout {:?}: {}",
379            effective_timeout, command
380        );
381
382        // Build the command based on platform
383        let mut cmd = self.build_platform_command(command, context);
384
385        // Execute with timeout
386        let result = tokio::time::timeout(effective_timeout, async {
387            cmd.stdout(Stdio::piped())
388                .stderr(Stdio::piped())
389                .stdin(Stdio::null())
390                .kill_on_drop(true)
391                .output()
392                .await
393        })
394        .await;
395
396        match result {
397            Ok(Ok(output)) => {
398                let stdout = String::from_utf8_lossy(&output.stdout).to_string();
399                let stderr = String::from_utf8_lossy(&output.stderr).to_string();
400                let exit_code = output.status.code().unwrap_or(-1);
401
402                debug!(
403                    "Command completed with exit code {}, stdout: {} bytes, stderr: {} bytes",
404                    exit_code,
405                    stdout.len(),
406                    stderr.len()
407                );
408
409                // Combine and truncate output
410                let combined_output = self.format_output(&stdout, &stderr, exit_code);
411                let truncated_output = self.truncate_output(&combined_output);
412
413                if output.status.success() {
414                    Ok(ToolResult::success(truncated_output)
415                        .with_metadata("exit_code", serde_json::json!(exit_code))
416                        .with_metadata("stdout_length", serde_json::json!(stdout.len()))
417                        .with_metadata("stderr_length", serde_json::json!(stderr.len())))
418                } else {
419                    Ok(ToolResult::error(truncated_output)
420                        .with_metadata("exit_code", serde_json::json!(exit_code))
421                        .with_metadata("stdout_length", serde_json::json!(stdout.len()))
422                        .with_metadata("stderr_length", serde_json::json!(stderr.len())))
423                }
424            }
425            Ok(Err(e)) => {
426                warn!("Command execution failed: {}", e);
427                Err(ToolError::execution_failed(format!(
428                    "Failed to execute command: {}",
429                    e
430                )))
431            }
432            Err(_) => {
433                warn!("Command timed out after {:?}", effective_timeout);
434                Err(ToolError::timeout(effective_timeout))
435            }
436        }
437    }
438
439    /// Build a platform-specific command
440    fn build_platform_command(&self, command: &str, context: &ToolContext) -> Command {
441        let mut cmd = if cfg!(target_os = "windows") {
442            // Try PowerShell first, fall back to CMD
443            let mut cmd = Command::new("powershell");
444            cmd.args(["-NoProfile", "-NonInteractive", "-Command", command]);
445            cmd
446        } else {
447            // Unix-like systems (macOS, Linux)
448            let mut cmd = Command::new("sh");
449            cmd.args(["-c", command]);
450            cmd
451        };
452
453        // Set working directory
454        cmd.current_dir(&context.working_directory);
455
456        // Set environment variables
457        cmd.env("ASTER_TERMINAL", "1");
458        for (key, value) in &context.environment {
459            cmd.env(key, value);
460        }
461
462        // Apply sandbox environment if configured
463        if let Some(ref sandbox) = self.sandbox_config {
464            for (key, value) in &sandbox.environment {
465                cmd.env(key, value);
466            }
467        }
468
469        cmd
470    }
471
472    /// Format command output combining stdout and stderr
473    fn format_output(&self, stdout: &str, stderr: &str, exit_code: i32) -> String {
474        let mut output = String::new();
475
476        if !stdout.is_empty() {
477            output.push_str(stdout);
478        }
479
480        if !stderr.is_empty() {
481            if !output.is_empty() && !output.ends_with('\n') {
482                output.push('\n');
483            }
484            if !stdout.is_empty() {
485                output.push_str("--- stderr ---\n");
486            }
487            output.push_str(stderr);
488        }
489
490        if exit_code != 0 && output.is_empty() {
491            output = format!("Command exited with code {}", exit_code);
492        }
493
494        output
495    }
496}
497
498// =============================================================================
499// Background Execution Implementation (Requirements: 3.4)
500// =============================================================================
501
502impl BashTool {
503    /// Execute a command in the background
504    ///
505    /// Returns a task_id that can be used to query status and output.
506    /// The actual task management is delegated to TaskManager.
507    ///
508    /// Requirements: 3.4
509    pub async fn execute_background(
510        &self,
511        command: &str,
512        context: &ToolContext,
513    ) -> Result<ToolResult, ToolError> {
514        // Check for cancellation
515        if context.is_cancelled() {
516            return Err(ToolError::Cancelled);
517        }
518
519        // Delegate to task manager
520        let task_id = self.task_manager.start(command, context).await?;
521
522        Ok(
523            ToolResult::success(format!("Background task started with ID: {}", task_id))
524                .with_metadata("task_id", serde_json::json!(task_id))
525                .with_metadata("background", serde_json::json!(true)),
526        )
527    }
528}
529
530// =============================================================================
531// Tool Trait Implementation (Requirements: 3.6, 3.7, 3.8)
532// =============================================================================
533
534#[async_trait]
535impl Tool for BashTool {
536    /// Returns the tool name
537    fn name(&self) -> &str {
538        "bash"
539    }
540
541    /// Returns the tool description
542    fn description(&self) -> &str {
543        "Execute shell commands with safety checks and timeout control. \
544         Supports both foreground and background execution. \
545         Use 'background: true' parameter for long-running commands."
546    }
547
548    /// Returns the JSON Schema for input parameters
549    fn input_schema(&self) -> serde_json::Value {
550        serde_json::json!({
551            "type": "object",
552            "properties": {
553                "command": {
554                    "type": "string",
555                    "description": "The shell command to execute"
556                },
557                "timeout": {
558                    "type": "integer",
559                    "description": "Timeout in seconds (default: 300, max: 1800)",
560                    "default": 300,
561                    "minimum": 1,
562                    "maximum": 1800
563                },
564                "background": {
565                    "type": "boolean",
566                    "description": "Run command in background and return task_id",
567                    "default": false
568                }
569            },
570            "required": ["command"]
571        })
572    }
573
574    /// Execute the bash command
575    ///
576    /// Requirements: 3.6, 3.7
577    async fn execute(
578        &self,
579        params: serde_json::Value,
580        context: &ToolContext,
581    ) -> Result<ToolResult, ToolError> {
582        // Extract command parameter
583        let command = params
584            .get("command")
585            .and_then(|v| v.as_str())
586            .ok_or_else(|| ToolError::invalid_params("Missing required parameter: command"))?;
587
588        // Extract timeout parameter (default: 300 seconds)
589        let timeout_secs = params
590            .get("timeout")
591            .and_then(|v| v.as_u64())
592            .unwrap_or(DEFAULT_TIMEOUT_SECS);
593        let timeout = Duration::from_secs(timeout_secs);
594
595        // Extract background parameter (default: false)
596        let background = params
597            .get("background")
598            .and_then(|v| v.as_bool())
599            .unwrap_or(false);
600
601        // Execute based on mode
602        if background {
603            self.execute_background(command, context).await
604        } else {
605            self.execute_foreground(command, timeout, context).await
606        }
607    }
608
609    /// Check permissions before execution
610    ///
611    /// Performs safety check and returns appropriate permission result.
612    ///
613    /// Requirements: 3.8
614    async fn check_permissions(
615        &self,
616        params: &serde_json::Value,
617        _context: &ToolContext,
618    ) -> PermissionCheckResult {
619        // Extract command for safety check
620        let command = match params.get("command").and_then(|v| v.as_str()) {
621            Some(cmd) => cmd,
622            None => return PermissionCheckResult::deny("Missing command parameter"),
623        };
624
625        // Perform safety check
626        let safety_result = self.check_command_safety(command);
627
628        if !safety_result.safe {
629            let reason = safety_result
630                .reason
631                .unwrap_or_else(|| "Command blocked by safety check".to_string());
632            return PermissionCheckResult::deny(reason);
633        }
634
635        // If there's a warning, ask for confirmation
636        if let Some(warning) = safety_result.warning {
637            return PermissionCheckResult::ask(format!(
638                "Command may be dangerous: {}. Do you want to proceed?",
639                warning
640            ));
641        }
642
643        PermissionCheckResult::allow()
644    }
645
646    /// Get tool options
647    fn options(&self) -> ToolOptions {
648        ToolOptions::new()
649            .with_max_retries(0) // Don't retry shell commands by default
650            .with_base_timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
651            .with_dynamic_timeout(false)
652    }
653}
654
655// =============================================================================
656// Output Truncation Implementation (Requirements: 3.9)
657// =============================================================================
658
659impl BashTool {
660    /// Truncate output if it exceeds MAX_OUTPUT_LENGTH
661    ///
662    /// Adds a truncation indicator when output is truncated.
663    ///
664    /// Requirements: 3.9
665    pub fn truncate_output(&self, output: &str) -> String {
666        if output.len() <= MAX_OUTPUT_LENGTH {
667            return output.to_string();
668        }
669
670        // Calculate how much to keep
671        let truncation_message = format!(
672            "\n\n... [Output truncated. Showing first {} of {} bytes]",
673            MAX_OUTPUT_LENGTH,
674            output.len()
675        );
676        let keep_length = MAX_OUTPUT_LENGTH - truncation_message.len();
677
678        // Find a valid UTF-8 char boundary at or before keep_length
679        let mut safe_length = keep_length;
680        while safe_length > 0 && !output.is_char_boundary(safe_length) {
681            safe_length -= 1;
682        }
683
684        // Try to truncate at a line boundary
685        let truncated = output.get(..safe_length).unwrap_or(output);
686        let last_newline = truncated.rfind('\n').unwrap_or(truncated.len());
687
688        format!(
689            "{}{}",
690            output.get(..last_newline).unwrap_or(output),
691            truncation_message
692        )
693    }
694
695    /// Check if output would be truncated
696    pub fn would_truncate(&self, output: &str) -> bool {
697        output.len() > MAX_OUTPUT_LENGTH
698    }
699}
700
701// =============================================================================
702// Unit Tests
703// =============================================================================
704
705#[cfg(test)]
706mod tests {
707    use super::*;
708    use std::path::PathBuf;
709
710    fn create_test_context() -> ToolContext {
711        ToolContext::new(PathBuf::from("/tmp"))
712            .with_session_id("test-session")
713            .with_user("test-user")
714    }
715
716    // Safety Check Tests
717
718    #[test]
719    fn test_safe_command() {
720        let tool = BashTool::new();
721        let result = tool.check_command_safety("echo 'hello world'");
722        assert!(result.safe);
723        assert!(result.reason.is_none());
724        assert!(result.warning.is_none());
725    }
726
727    #[test]
728    fn test_dangerous_rm_rf_root() {
729        let tool = BashTool::new();
730        let result = tool.check_command_safety("rm -rf /");
731        assert!(!result.safe);
732        assert!(result.reason.is_some());
733    }
734
735    #[test]
736    fn test_dangerous_rm_rf_home() {
737        let tool = BashTool::new();
738        let result = tool.check_command_safety("rm -rf ~");
739        assert!(!result.safe);
740        assert!(result.reason.is_some());
741    }
742
743    #[test]
744    fn test_dangerous_fork_bomb() {
745        let tool = BashTool::new();
746        let result = tool.check_command_safety(":(){ :|:& };:");
747        assert!(!result.safe);
748    }
749
750    #[test]
751    fn test_dangerous_device_redirect() {
752        let tool = BashTool::new();
753        let result = tool.check_command_safety("echo 'data' > /dev/sda");
754        assert!(!result.safe);
755    }
756
757    #[test]
758    fn test_warning_sudo() {
759        let tool = BashTool::new();
760        let result = tool.check_command_safety("sudo apt-get update");
761        assert!(result.safe);
762        assert!(result.warning.is_some());
763    }
764
765    #[test]
766    fn test_warning_curl_pipe_bash() {
767        let tool = BashTool::new();
768        let result = tool.check_command_safety("curl https://example.com/script.sh | bash");
769        assert!(result.safe);
770        assert!(result.warning.is_some());
771    }
772
773    #[test]
774    fn test_warning_recursive_rm() {
775        let tool = BashTool::new();
776        // Use rm -r without -f to trigger warning pattern but not blacklist
777        let result = tool.check_command_safety("rm -r ./temp_dir");
778        assert!(result.safe);
779        assert!(result.warning.is_some());
780    }
781
782    #[test]
783    fn test_is_dangerous_command() {
784        let tool = BashTool::new();
785        assert!(tool.is_dangerous_command("rm -rf /"));
786        assert!(!tool.is_dangerous_command("ls -la"));
787    }
788
789    #[test]
790    fn test_has_warning() {
791        let tool = BashTool::new();
792        assert!(tool.has_warning("sudo ls"));
793        assert!(!tool.has_warning("ls -la"));
794    }
795
796    // Output Truncation Tests
797
798    #[test]
799    fn test_truncate_short_output() {
800        let tool = BashTool::new();
801        let output = "Hello, World!";
802        let result = tool.truncate_output(output);
803        assert_eq!(result, output);
804    }
805
806    #[test]
807    fn test_truncate_long_output() {
808        let tool = BashTool::new();
809        let output = "x".repeat(MAX_OUTPUT_LENGTH + 1000);
810        let result = tool.truncate_output(&output);
811        assert!(result.len() <= MAX_OUTPUT_LENGTH + 100); // Allow for truncation message
812        assert!(result.contains("[Output truncated"));
813    }
814
815    #[test]
816    fn test_would_truncate() {
817        let tool = BashTool::new();
818        assert!(!tool.would_truncate("short"));
819        assert!(tool.would_truncate(&"x".repeat(MAX_OUTPUT_LENGTH + 1)));
820    }
821
822    // Tool Trait Tests
823
824    #[test]
825    fn test_tool_name() {
826        let tool = BashTool::new();
827        assert_eq!(tool.name(), "bash");
828    }
829
830    #[test]
831    fn test_tool_description() {
832        let tool = BashTool::new();
833        assert!(!tool.description().is_empty());
834        assert!(tool.description().contains("shell"));
835    }
836
837    #[test]
838    fn test_tool_input_schema() {
839        let tool = BashTool::new();
840        let schema = tool.input_schema();
841        assert_eq!(schema["type"], "object");
842        assert!(schema["properties"]["command"].is_object());
843        assert!(schema["properties"]["timeout"].is_object());
844        assert!(schema["properties"]["background"].is_object());
845    }
846
847    #[test]
848    fn test_tool_options() {
849        let tool = BashTool::new();
850        let options = tool.options();
851        assert_eq!(options.max_retries, 0);
852        assert_eq!(
853            options.base_timeout,
854            Duration::from_secs(DEFAULT_TIMEOUT_SECS)
855        );
856    }
857
858    // Permission Check Tests
859
860    #[tokio::test]
861    async fn test_check_permissions_safe_command() {
862        let tool = BashTool::new();
863        let context = create_test_context();
864        let params = serde_json::json!({"command": "echo 'hello'"});
865
866        let result = tool.check_permissions(&params, &context).await;
867        assert!(result.is_allowed());
868    }
869
870    #[tokio::test]
871    async fn test_check_permissions_dangerous_command() {
872        let tool = BashTool::new();
873        let context = create_test_context();
874        let params = serde_json::json!({"command": "rm -rf /"});
875
876        let result = tool.check_permissions(&params, &context).await;
877        assert!(result.is_denied());
878    }
879
880    #[tokio::test]
881    async fn test_check_permissions_warning_command() {
882        let tool = BashTool::new();
883        let context = create_test_context();
884        let params = serde_json::json!({"command": "sudo ls"});
885
886        let result = tool.check_permissions(&params, &context).await;
887        assert!(result.requires_confirmation());
888    }
889
890    #[tokio::test]
891    async fn test_check_permissions_missing_command() {
892        let tool = BashTool::new();
893        let context = create_test_context();
894        let params = serde_json::json!({});
895
896        let result = tool.check_permissions(&params, &context).await;
897        assert!(result.is_denied());
898    }
899
900    // Execution Tests
901
902    #[tokio::test]
903    async fn test_execute_simple_command() {
904        let tool = BashTool::new();
905        let context = create_test_context();
906        let params = serde_json::json!({
907            "command": "echo 'hello world'"
908        });
909
910        let result = tool.execute(params, &context).await;
911        assert!(result.is_ok());
912        let tool_result = result.unwrap();
913        assert!(tool_result.is_success());
914        assert!(tool_result.output.unwrap().contains("hello world"));
915    }
916
917    #[tokio::test]
918    async fn test_execute_with_exit_code() {
919        let tool = BashTool::new();
920        let context = create_test_context();
921        let params = serde_json::json!({
922            "command": "exit 1"
923        });
924
925        let result = tool.execute(params, &context).await;
926        assert!(result.is_ok());
927        let tool_result = result.unwrap();
928        assert!(tool_result.is_error());
929        assert_eq!(
930            tool_result.metadata.get("exit_code"),
931            Some(&serde_json::json!(1))
932        );
933    }
934
935    #[tokio::test]
936    async fn test_execute_missing_command() {
937        let tool = BashTool::new();
938        let context = create_test_context();
939        let params = serde_json::json!({});
940
941        let result = tool.execute(params, &context).await;
942        assert!(result.is_err());
943        assert!(matches!(result.unwrap_err(), ToolError::InvalidParams(_)));
944    }
945
946    #[tokio::test]
947    async fn test_execute_with_timeout() {
948        let tool = BashTool::new();
949        let context = create_test_context();
950
951        // Use a very short timeout
952        let params = serde_json::json!({
953            "command": if cfg!(target_os = "windows") { "timeout /t 5" } else { "sleep 5" },
954            "timeout": 1
955        });
956
957        let result = tool.execute(params, &context).await;
958        assert!(result.is_err());
959        assert!(matches!(result.unwrap_err(), ToolError::Timeout(_)));
960    }
961
962    #[tokio::test]
963    async fn test_execute_background() {
964        use tempfile::TempDir;
965
966        let temp_dir = TempDir::new().unwrap();
967        let task_manager =
968            Arc::new(TaskManager::new().with_output_directory(temp_dir.path().to_path_buf()));
969        let tool = BashTool::with_task_manager(task_manager.clone());
970        let context = create_test_context();
971        let params = serde_json::json!({
972            "command": "echo 'hello'",
973            "background": true
974        });
975
976        let result = tool.execute(params, &context).await;
977        // Background execution is now implemented
978        assert!(result.is_ok());
979        let tool_result = result.unwrap();
980        assert!(tool_result.is_success());
981        assert!(tool_result.metadata.contains_key("task_id"));
982        assert!(tool_result.metadata.contains_key("background"));
983
984        // Clean up
985        let _ = task_manager.kill_all().await;
986    }
987
988    // Builder Tests
989
990    #[test]
991    fn test_builder_with_task_manager() {
992        let task_manager = Arc::new(TaskManager::new());
993        let tool = BashTool::with_task_manager(task_manager.clone());
994        assert!(Arc::ptr_eq(&tool.task_manager, &task_manager));
995    }
996
997    #[test]
998    fn test_builder_with_sandbox() {
999        let sandbox = SandboxConfig {
1000            enabled: true,
1001            allowed_directories: vec!["/tmp".to_string()],
1002            environment: std::collections::HashMap::new(),
1003        };
1004        let tool = BashTool::new().with_sandbox(sandbox);
1005        assert!(tool.sandbox_config.is_some());
1006        assert!(tool.sandbox_config.unwrap().enabled);
1007    }
1008
1009    #[test]
1010    fn test_builder_with_dangerous_commands() {
1011        let commands = vec!["custom_dangerous".to_string()];
1012        let tool = BashTool::new().with_dangerous_commands(commands);
1013        assert!(tool.is_dangerous_command("custom_dangerous"));
1014    }
1015
1016    #[test]
1017    fn test_add_dangerous_commands() {
1018        let mut tool = BashTool::new();
1019        tool.add_dangerous_commands(vec!["new_dangerous".to_string()]);
1020        assert!(tool.is_dangerous_command("new_dangerous"));
1021    }
1022
1023    // Format Output Tests
1024
1025    #[test]
1026    fn test_format_output_stdout_only() {
1027        let tool = BashTool::new();
1028        let result = tool.format_output("stdout content", "", 0);
1029        assert_eq!(result, "stdout content");
1030    }
1031
1032    #[test]
1033    fn test_format_output_stderr_only() {
1034        let tool = BashTool::new();
1035        let result = tool.format_output("", "stderr content", 1);
1036        assert_eq!(result, "stderr content");
1037    }
1038
1039    #[test]
1040    fn test_format_output_both() {
1041        let tool = BashTool::new();
1042        let result = tool.format_output("stdout", "stderr", 0);
1043        assert!(result.contains("stdout"));
1044        assert!(result.contains("stderr"));
1045    }
1046
1047    #[test]
1048    fn test_format_output_empty_with_error() {
1049        let tool = BashTool::new();
1050        let result = tool.format_output("", "", 1);
1051        assert!(result.contains("exited with code 1"));
1052    }
1053
1054    // Safety Check Result Tests
1055
1056    #[test]
1057    fn test_safety_check_result_safe() {
1058        let result = SafetyCheckResult::safe();
1059        assert!(result.safe);
1060        assert!(result.reason.is_none());
1061        assert!(result.warning.is_none());
1062    }
1063
1064    #[test]
1065    fn test_safety_check_result_safe_with_warning() {
1066        let result = SafetyCheckResult::safe_with_warning("Be careful");
1067        assert!(result.safe);
1068        assert!(result.reason.is_none());
1069        assert_eq!(result.warning, Some("Be careful".to_string()));
1070    }
1071
1072    #[test]
1073    fn test_safety_check_result_unsafe() {
1074        let result = SafetyCheckResult::unsafe_with_reason("Dangerous");
1075        assert!(!result.safe);
1076        assert_eq!(result.reason, Some("Dangerous".to_string()));
1077        assert!(result.warning.is_none());
1078    }
1079}