Skip to main content

agent_sdk/primitive_tools/
bash.rs

1use crate::reminders::{append_reminder, builtin};
2use crate::{Environment, PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
3use anyhow::{Context, Result};
4use serde::Deserialize;
5use serde_json::{Value, json};
6use std::fmt::Write;
7use std::sync::Arc;
8
9use super::PrimitiveToolContext;
10
11/// Tool for executing shell commands
12pub struct BashTool<E: Environment> {
13    ctx: PrimitiveToolContext<E>,
14}
15
16impl<E: Environment> BashTool<E> {
17    #[must_use]
18    pub const fn new(environment: Arc<E>, capabilities: crate::AgentCapabilities) -> Self {
19        Self {
20            ctx: PrimitiveToolContext::new(environment, capabilities),
21        }
22    }
23}
24
25#[derive(Debug, Deserialize)]
26struct BashInput {
27    /// Command to execute
28    command: String,
29    /// Timeout in milliseconds (default: 120000 = 2 minutes)
30    #[serde(default = "default_timeout")]
31    timeout_ms: u64,
32}
33
/// Timeout used when the input omits `timeout_ms`: two minutes, expressed in milliseconds.
const fn default_timeout() -> u64 {
    const TWO_MINUTES_MS: u64 = 2 * 60 * 1_000;
    TWO_MINUTES_MS
}
37
38impl<E: Environment + 'static> Tool<()> for BashTool<E> {
39    type Name = PrimitiveToolName;
40
41    fn name(&self) -> PrimitiveToolName {
42        PrimitiveToolName::Bash
43    }
44
45    fn display_name(&self) -> &'static str {
46        "Run Command"
47    }
48
49    fn description(&self) -> &'static str {
50        "Execute a shell command. Use for git, npm, cargo, and other CLI tools. Returns stdout, stderr, and exit code."
51    }
52
53    fn tier(&self) -> ToolTier {
54        ToolTier::Confirm
55    }
56
57    fn input_schema(&self) -> Value {
58        json!({
59            "type": "object",
60            "properties": {
61                "command": {
62                    "type": "string",
63                    "description": "The shell command to execute"
64                },
65                "timeout_ms": {
66                    "type": "integer",
67                    "description": "Timeout in milliseconds. Default: 120000 (2 minutes)"
68                }
69            },
70            "required": ["command"]
71        })
72    }
73
74    async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
75        let input: BashInput =
76            serde_json::from_value(input).context("Invalid input for bash tool")?;
77
78        // Check exec capability
79        if !self.ctx.capabilities.exec {
80            return Ok(ToolResult::error(
81                "Permission denied: command execution is disabled",
82            ));
83        }
84
85        // Check if command is allowed
86        if !self.ctx.capabilities.can_exec(&input.command) {
87            return Ok(ToolResult::error(format!(
88                "Permission denied: command '{}' is not allowed",
89                truncate_command(&input.command, 100)
90            )));
91        }
92
93        // Validate timeout
94        let timeout_ms = input.timeout_ms.min(600_000); // Max 10 minutes
95
96        // Execute command
97        let result = self
98            .ctx
99            .environment
100            .exec(&input.command, Some(timeout_ms))
101            .await
102            .context("Failed to execute command")?;
103
104        // Format output
105        let mut output = String::new();
106
107        if !result.stdout.is_empty() {
108            output.push_str(&result.stdout);
109        }
110
111        if !result.stderr.is_empty() {
112            if !output.is_empty() {
113                output.push_str("\n\n--- stderr ---\n");
114            }
115            output.push_str(&result.stderr);
116        }
117
118        if output.is_empty() {
119            output = "(no output)".to_string();
120        }
121
122        // Truncate if too long
123        let max_output_len = 30_000;
124        if output.len() > max_output_len {
125            output = format!(
126                "{}...\n\n(output truncated, {} total characters)",
127                &output[..max_output_len],
128                output.len()
129            );
130        }
131
132        // Include exit code in output
133        let _ = write!(output, "\n\nExit code: {}", result.exit_code);
134
135        let mut tool_result = if result.success() {
136            ToolResult::success(output)
137        } else {
138            ToolResult::error(output)
139        };
140
141        // Add verification reminder
142        append_reminder(&mut tool_result, builtin::BASH_VERIFICATION_REMINDER);
143
144        Ok(tool_result)
145    }
146}
147
/// Shortens `s` for inclusion in an error message.
///
/// Returns `s` unchanged when it is at most `max_len` bytes long. Otherwise it returns a
/// prefix of at most `max_len` bytes, cut back to the nearest UTF-8 character boundary,
/// followed by `"..."`.
fn truncate_command(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    // `&s[..max_len]` panics if `max_len` falls inside a multi-byte character. The command
    // string comes from the model and may contain arbitrary Unicode, so step back to a
    // boundary (`is_char_boundary(0)` is always true, so the loop terminates).
    let mut end = max_len;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}
155
156#[cfg(test)]
157mod tests {
158    use super::*;
159    use crate::AgentCapabilities;
160    use crate::environment::ExecResult;
161    use async_trait::async_trait;
162    use std::collections::HashMap;
163    use std::sync::RwLock;
164
165    // Mock environment for testing bash execution
166    struct MockBashEnvironment {
167        root: String,
168        // Map of command to (stdout, stderr, exit_code)
169        commands: RwLock<HashMap<String, (String, String, i32)>>,
170    }
171
172    impl MockBashEnvironment {
173        fn new() -> Self {
174            Self {
175                root: "/workspace".to_string(),
176                commands: RwLock::new(HashMap::new()),
177            }
178        }
179
180        fn add_command(&self, cmd: &str, stdout: &str, stderr: &str, exit_code: i32) {
181            self.commands.write().unwrap().insert(
182                cmd.to_string(),
183                (stdout.to_string(), stderr.to_string(), exit_code),
184            );
185        }
186    }
187
188    #[async_trait]
189    impl crate::Environment for MockBashEnvironment {
190        async fn read_file(&self, _path: &str) -> Result<String> {
191            Ok(String::new())
192        }
193
194        async fn read_file_bytes(&self, _path: &str) -> Result<Vec<u8>> {
195            Ok(vec![])
196        }
197
198        async fn write_file(&self, _path: &str, _content: &str) -> Result<()> {
199            Ok(())
200        }
201
202        async fn write_file_bytes(&self, _path: &str, _content: &[u8]) -> Result<()> {
203            Ok(())
204        }
205
206        async fn list_dir(&self, _path: &str) -> Result<Vec<crate::environment::FileEntry>> {
207            Ok(vec![])
208        }
209
210        async fn exists(&self, _path: &str) -> Result<bool> {
211            Ok(false)
212        }
213
214        async fn is_dir(&self, _path: &str) -> Result<bool> {
215            Ok(false)
216        }
217
218        async fn is_file(&self, _path: &str) -> Result<bool> {
219            Ok(false)
220        }
221
222        async fn create_dir(&self, _path: &str) -> Result<()> {
223            Ok(())
224        }
225
226        async fn delete_file(&self, _path: &str) -> Result<()> {
227            Ok(())
228        }
229
230        async fn delete_dir(&self, _path: &str, _recursive: bool) -> Result<()> {
231            Ok(())
232        }
233
234        async fn grep(
235            &self,
236            _pattern: &str,
237            _path: &str,
238            _recursive: bool,
239        ) -> Result<Vec<crate::environment::GrepMatch>> {
240            Ok(vec![])
241        }
242
243        async fn glob(&self, _pattern: &str) -> Result<Vec<String>> {
244            Ok(vec![])
245        }
246
247        async fn exec(&self, command: &str, _timeout_ms: Option<u64>) -> Result<ExecResult> {
248            let commands = self.commands.read().unwrap();
249            if let Some((stdout, stderr, exit_code)) = commands.get(command) {
250                Ok(ExecResult {
251                    stdout: stdout.clone(),
252                    stderr: stderr.clone(),
253                    exit_code: *exit_code,
254                })
255            } else {
256                // Default: command not found
257                Ok(ExecResult {
258                    stdout: String::new(),
259                    stderr: format!("command not found: {command}"),
260                    exit_code: 127,
261                })
262            }
263        }
264
265        fn root(&self) -> &str {
266            &self.root
267        }
268    }
269
270    fn create_test_tool(
271        env: Arc<MockBashEnvironment>,
272        capabilities: AgentCapabilities,
273    ) -> BashTool<MockBashEnvironment> {
274        BashTool::new(env, capabilities)
275    }
276
277    fn tool_ctx() -> ToolContext<()> {
278        ToolContext::new(())
279    }
280
281    // ===================
282    // Unit Tests
283    // ===================
284
285    #[tokio::test]
286    async fn test_bash_simple_command() -> anyhow::Result<()> {
287        let env = Arc::new(MockBashEnvironment::new());
288        env.add_command("echo hello", "hello\n", "", 0);
289
290        let tool = create_test_tool(env, AgentCapabilities::full_access());
291        let result = tool
292            .execute(&tool_ctx(), json!({"command": "echo hello"}))
293            .await?;
294
295        assert!(result.success);
296        assert!(result.output.contains("hello"));
297        assert!(result.output.contains("Exit code: 0"));
298        Ok(())
299    }
300
301    #[tokio::test]
302    async fn test_bash_command_with_stderr() -> anyhow::Result<()> {
303        let env = Arc::new(MockBashEnvironment::new());
304        env.add_command("cmd", "stdout output", "stderr output", 0);
305
306        let tool = create_test_tool(env, AgentCapabilities::full_access());
307        let result = tool.execute(&tool_ctx(), json!({"command": "cmd"})).await?;
308
309        assert!(result.success);
310        assert!(result.output.contains("stdout output"));
311        assert!(result.output.contains("stderr output"));
312        Ok(())
313    }
314
315    #[tokio::test]
316    async fn test_bash_command_nonzero_exit() -> anyhow::Result<()> {
317        let env = Arc::new(MockBashEnvironment::new());
318        env.add_command("failing_cmd", "", "error occurred", 1);
319
320        let tool = create_test_tool(env, AgentCapabilities::full_access());
321        let result = tool
322            .execute(&tool_ctx(), json!({"command": "failing_cmd"}))
323            .await?;
324
325        assert!(!result.success);
326        assert!(result.output.contains("Exit code: 1"));
327        Ok(())
328    }
329
330    #[tokio::test]
331    async fn test_bash_command_not_found() -> anyhow::Result<()> {
332        let env = Arc::new(MockBashEnvironment::new());
333
334        let tool = create_test_tool(env, AgentCapabilities::full_access());
335        let result = tool
336            .execute(&tool_ctx(), json!({"command": "nonexistent_cmd"}))
337            .await?;
338
339        assert!(!result.success);
340        assert!(result.output.contains("Exit code: 127"));
341        Ok(())
342    }
343
344    // ===================
345    // Integration Tests
346    // ===================
347
348    #[tokio::test]
349    async fn test_bash_exec_disabled() -> anyhow::Result<()> {
350        let env = Arc::new(MockBashEnvironment::new());
351
352        // Read-only capabilities (exec disabled)
353        let caps = AgentCapabilities::read_only();
354
355        let tool = create_test_tool(env, caps);
356        let result = tool.execute(&tool_ctx(), json!({"command": "ls"})).await?;
357
358        assert!(!result.success);
359        assert!(result.output.contains("Permission denied"));
360        assert!(result.output.contains("execution is disabled"));
361        Ok(())
362    }
363
364    #[tokio::test]
365    async fn test_bash_dangerous_command_denied() -> anyhow::Result<()> {
366        let env = Arc::new(MockBashEnvironment::new());
367
368        // Full access but default denied commands
369        let caps = AgentCapabilities::full_access();
370
371        let tool = create_test_tool(env, caps);
372        let result = tool
373            .execute(&tool_ctx(), json!({"command": "rm -rf /"}))
374            .await?;
375
376        assert!(!result.success);
377        assert!(result.output.contains("Permission denied"));
378        assert!(result.output.contains("not allowed"));
379        Ok(())
380    }
381
382    #[tokio::test]
383    async fn test_bash_sudo_command_denied() -> anyhow::Result<()> {
384        let env = Arc::new(MockBashEnvironment::new());
385        let caps = AgentCapabilities::full_access();
386
387        let tool = create_test_tool(env, caps);
388        let result = tool
389            .execute(&tool_ctx(), json!({"command": "sudo apt-get install foo"}))
390            .await?;
391
392        assert!(!result.success);
393        assert!(result.output.contains("Permission denied"));
394        Ok(())
395    }
396
397    #[tokio::test]
398    async fn test_bash_allowed_commands_restriction() -> anyhow::Result<()> {
399        let env = Arc::new(MockBashEnvironment::new());
400        env.add_command("cargo build", "Compiling...", "", 0);
401
402        // Only allow cargo and git commands
403        let caps = AgentCapabilities::full_access()
404            .with_denied_commands(vec![])
405            .with_allowed_commands(vec![r"^cargo ".into(), r"^git ".into()]);
406
407        let tool = create_test_tool(Arc::clone(&env), caps.clone());
408
409        // cargo should be allowed
410        let result = tool
411            .execute(&tool_ctx(), json!({"command": "cargo build"}))
412            .await?;
413        assert!(result.success);
414
415        // ls should be denied
416        let tool = create_test_tool(env, caps);
417        let result = tool
418            .execute(&tool_ctx(), json!({"command": "ls -la"}))
419            .await?;
420        assert!(!result.success);
421        assert!(result.output.contains("not allowed"));
422        Ok(())
423    }
424
425    // ===================
426    // Edge Cases
427    // ===================
428
429    #[tokio::test]
430    async fn test_bash_empty_output() -> anyhow::Result<()> {
431        let env = Arc::new(MockBashEnvironment::new());
432        env.add_command("true", "", "", 0);
433
434        let tool = create_test_tool(env, AgentCapabilities::full_access());
435        let result = tool
436            .execute(&tool_ctx(), json!({"command": "true"}))
437            .await?;
438
439        assert!(result.success);
440        assert!(result.output.contains("(no output)"));
441        Ok(())
442    }
443
444    #[tokio::test]
445    async fn test_bash_custom_timeout() -> anyhow::Result<()> {
446        let env = Arc::new(MockBashEnvironment::new());
447        env.add_command("slow_cmd", "done", "", 0);
448
449        let tool = create_test_tool(env, AgentCapabilities::full_access());
450        let result = tool
451            .execute(
452                &tool_ctx(),
453                json!({"command": "slow_cmd", "timeout_ms": 5000}),
454            )
455            .await?;
456
457        assert!(result.success);
458        Ok(())
459    }
460
461    #[tokio::test]
462    async fn test_bash_tool_metadata() {
463        let env = Arc::new(MockBashEnvironment::new());
464        let tool = create_test_tool(env, AgentCapabilities::full_access());
465
466        assert_eq!(tool.name(), PrimitiveToolName::Bash);
467        assert_eq!(tool.tier(), ToolTier::Confirm);
468        assert!(tool.description().contains("Execute"));
469
470        let schema = tool.input_schema();
471        assert!(schema.get("properties").is_some());
472        assert!(schema["properties"].get("command").is_some());
473        assert!(schema["properties"].get("timeout_ms").is_some());
474    }
475
476    #[tokio::test]
477    async fn test_bash_invalid_input() -> anyhow::Result<()> {
478        let env = Arc::new(MockBashEnvironment::new());
479        let tool = create_test_tool(env, AgentCapabilities::full_access());
480
481        // Missing required command field
482        let result = tool.execute(&tool_ctx(), json!({})).await;
483        assert!(result.is_err());
484        Ok(())
485    }
486
487    #[tokio::test]
488    async fn test_bash_long_output_truncated() -> anyhow::Result<()> {
489        let env = Arc::new(MockBashEnvironment::new());
490        let long_output = "x".repeat(40_000);
491        env.add_command("long_output_cmd", &long_output, "", 0);
492
493        let tool = create_test_tool(env, AgentCapabilities::full_access());
494        let result = tool
495            .execute(&tool_ctx(), json!({"command": "long_output_cmd"}))
496            .await?;
497
498        assert!(result.success);
499        assert!(result.output.contains("output truncated"));
500        assert!(result.output.len() < 35_000); // Should be truncated
501        Ok(())
502    }
503
504    #[tokio::test]
505    async fn test_truncate_command_function() {
506        assert_eq!(truncate_command("short", 10), "short");
507        assert_eq!(
508            truncate_command("this is a longer command", 10),
509            "this is a ..."
510        );
511    }
512}