agent_sdk/primitive_tools/
bash.rs

1use crate::reminders::{append_reminder, builtin};
2use crate::{Environment, Tool, ToolContext, ToolResult, ToolTier};
3use anyhow::{Context, Result};
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::{Value, json};
7use std::fmt::Write;
8use std::sync::Arc;
9
10use super::PrimitiveToolContext;
11
12/// Tool for executing shell commands
13pub struct BashTool<E: Environment> {
14    ctx: PrimitiveToolContext<E>,
15}
16
17impl<E: Environment> BashTool<E> {
18    #[must_use]
19    pub const fn new(environment: Arc<E>, capabilities: crate::AgentCapabilities) -> Self {
20        Self {
21            ctx: PrimitiveToolContext::new(environment, capabilities),
22        }
23    }
24}
25
26#[derive(Debug, Deserialize)]
27struct BashInput {
28    /// Command to execute
29    command: String,
30    /// Timeout in milliseconds (default: 120000 = 2 minutes)
31    #[serde(default = "default_timeout")]
32    timeout_ms: u64,
33}
34
/// Default command timeout in milliseconds: two minutes.
const fn default_timeout() -> u64 {
    // minutes * seconds per minute * milliseconds per second
    2 * 60 * 1_000
}
38
39#[async_trait]
40impl<E: Environment + 'static> Tool<()> for BashTool<E> {
41    fn name(&self) -> &'static str {
42        "bash"
43    }
44
45    fn description(&self) -> &'static str {
46        "Execute a shell command. Use for git, npm, cargo, and other CLI tools. Returns stdout, stderr, and exit code."
47    }
48
49    fn tier(&self) -> ToolTier {
50        ToolTier::Confirm
51    }
52
53    fn input_schema(&self) -> Value {
54        json!({
55            "type": "object",
56            "properties": {
57                "command": {
58                    "type": "string",
59                    "description": "The shell command to execute"
60                },
61                "timeout_ms": {
62                    "type": "integer",
63                    "description": "Timeout in milliseconds. Default: 120000 (2 minutes)"
64                }
65            },
66            "required": ["command"]
67        })
68    }
69
70    async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
71        let input: BashInput =
72            serde_json::from_value(input).context("Invalid input for bash tool")?;
73
74        // Check exec capability
75        if !self.ctx.capabilities.exec {
76            return Ok(ToolResult::error(
77                "Permission denied: command execution is disabled",
78            ));
79        }
80
81        // Check if command is allowed
82        if !self.ctx.capabilities.can_exec(&input.command) {
83            return Ok(ToolResult::error(format!(
84                "Permission denied: command '{}' is not allowed",
85                truncate_command(&input.command, 100)
86            )));
87        }
88
89        // Validate timeout
90        let timeout_ms = input.timeout_ms.min(600_000); // Max 10 minutes
91
92        // Execute command
93        let result = self
94            .ctx
95            .environment
96            .exec(&input.command, Some(timeout_ms))
97            .await
98            .context("Failed to execute command")?;
99
100        // Format output
101        let mut output = String::new();
102
103        if !result.stdout.is_empty() {
104            output.push_str(&result.stdout);
105        }
106
107        if !result.stderr.is_empty() {
108            if !output.is_empty() {
109                output.push_str("\n\n--- stderr ---\n");
110            }
111            output.push_str(&result.stderr);
112        }
113
114        if output.is_empty() {
115            output = "(no output)".to_string();
116        }
117
118        // Truncate if too long
119        let max_output_len = 30_000;
120        if output.len() > max_output_len {
121            output = format!(
122                "{}...\n\n(output truncated, {} total characters)",
123                &output[..max_output_len],
124                output.len()
125            );
126        }
127
128        // Include exit code in output
129        let _ = write!(output, "\n\nExit code: {}", result.exit_code);
130
131        let mut tool_result = if result.success() {
132            ToolResult::success(output)
133        } else {
134            ToolResult::error(output)
135        };
136
137        // Add verification reminder
138        append_reminder(&mut tool_result, builtin::BASH_VERIFICATION_REMINDER);
139
140        Ok(tool_result)
141    }
142}
143
/// Shortens `s` to at most `max_len` bytes for display, appending `...` when
/// anything was cut off.
///
/// Strings of `max_len` bytes or fewer are returned unchanged. When `max_len`
/// falls inside a multi-byte UTF-8 character, the cut moves back to the start
/// of that character so the slice is always valid (slicing a `str` off a
/// character boundary panics).
fn truncate_command(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    // Index 0 is always a character boundary, so this loop terminates
    // without underflowing.
    let mut cut = max_len;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}...", &s[..cut])
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AgentCapabilities;
    use crate::environment::ExecResult;
    use std::collections::HashMap;
    use std::sync::RwLock;

    /// Mock environment for testing bash execution.
    ///
    /// Commands are answered from a scripted table; every other environment
    /// operation is a no-op. The timeout of each `exec` call is recorded so
    /// tests can check what the tool actually forwarded.
    struct MockBashEnvironment {
        root: String,
        /// Map of command to (stdout, stderr, exit_code).
        commands: RwLock<HashMap<String, (String, String, i32)>>,
        /// Timeout argument of every `exec` call, in call order.
        timeouts: RwLock<Vec<Option<u64>>>,
    }

    impl MockBashEnvironment {
        fn new() -> Self {
            Self {
                root: "/workspace".to_string(),
                commands: RwLock::new(HashMap::new()),
                timeouts: RwLock::new(Vec::new()),
            }
        }

        /// Scripts the response returned when `cmd` is executed.
        fn add_command(&self, cmd: &str, stdout: &str, stderr: &str, exit_code: i32) {
            self.commands.write().unwrap().insert(
                cmd.to_string(),
                (stdout.to_string(), stderr.to_string(), exit_code),
            );
        }

        /// Timeouts passed to `exec` so far; empty if `exec` was never called.
        fn recorded_timeouts(&self) -> Vec<Option<u64>> {
            self.timeouts.read().unwrap().clone()
        }
    }

    #[async_trait]
    impl crate::Environment for MockBashEnvironment {
        async fn read_file(&self, _path: &str) -> Result<String> {
            Ok(String::new())
        }

        async fn read_file_bytes(&self, _path: &str) -> Result<Vec<u8>> {
            Ok(vec![])
        }

        async fn write_file(&self, _path: &str, _content: &str) -> Result<()> {
            Ok(())
        }

        async fn write_file_bytes(&self, _path: &str, _content: &[u8]) -> Result<()> {
            Ok(())
        }

        async fn list_dir(&self, _path: &str) -> Result<Vec<crate::environment::FileEntry>> {
            Ok(vec![])
        }

        async fn exists(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_dir(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_file(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn create_dir(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_file(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_dir(&self, _path: &str, _recursive: bool) -> Result<()> {
            Ok(())
        }

        async fn grep(
            &self,
            _pattern: &str,
            _path: &str,
            _recursive: bool,
        ) -> Result<Vec<crate::environment::GrepMatch>> {
            Ok(vec![])
        }

        async fn glob(&self, _pattern: &str) -> Result<Vec<String>> {
            Ok(vec![])
        }

        async fn exec(&self, command: &str, timeout_ms: Option<u64>) -> Result<ExecResult> {
            self.timeouts.write().unwrap().push(timeout_ms);

            let commands = self.commands.read().unwrap();
            let result = match commands.get(command) {
                Some((stdout, stderr, exit_code)) => ExecResult {
                    stdout: stdout.clone(),
                    stderr: stderr.clone(),
                    exit_code: *exit_code,
                },
                // Default: command not found
                None => ExecResult {
                    stdout: String::new(),
                    stderr: format!("command not found: {command}"),
                    exit_code: 127,
                },
            };
            Ok(result)
        }

        fn root(&self) -> &str {
            &self.root
        }
    }

    fn create_test_tool(
        env: Arc<MockBashEnvironment>,
        capabilities: AgentCapabilities,
    ) -> BashTool<MockBashEnvironment> {
        BashTool::new(env, capabilities)
    }

    fn tool_ctx() -> ToolContext<()> {
        ToolContext::new(())
    }

    // ===================
    // Unit Tests
    // ===================

    #[tokio::test]
    async fn test_bash_simple_command() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hello", "hello\n", "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "echo hello"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("hello"));
        assert!(result.output.contains("Exit code: 0"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_with_stderr() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cmd", "stdout output", "stderr output", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool.execute(&tool_ctx(), json!({"command": "cmd"})).await?;

        assert!(result.success);
        assert!(result.output.contains("stdout output"));
        assert!(result.output.contains("stderr output"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_nonzero_exit() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("failing_cmd", "", "error occurred", 1);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "failing_cmd"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 1"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_not_found() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "nonexistent_cmd"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 127"));
        Ok(())
    }

    // ===================
    // Integration Tests
    // ===================

    #[tokio::test]
    async fn test_bash_exec_disabled() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Read-only capabilities (exec disabled)
        let caps = AgentCapabilities::read_only();

        let tool = create_test_tool(Arc::clone(&env), caps);
        let result = tool.execute(&tool_ctx(), json!({"command": "ls"})).await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("execution is disabled"));
        // A denied command must never reach the environment.
        assert!(env.recorded_timeouts().is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_dangerous_command_denied() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Full access but default denied commands
        let caps = AgentCapabilities::full_access();

        let tool = create_test_tool(Arc::clone(&env), caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "rm -rf /"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("not allowed"));
        // A denied command must never reach the environment.
        assert!(env.recorded_timeouts().is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_sudo_command_denied() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let caps = AgentCapabilities::full_access();

        let tool = create_test_tool(env, caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "sudo apt-get install foo"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_allowed_commands_restriction() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cargo build", "Compiling...", "", 0);

        // Only allow cargo and git commands
        let caps = AgentCapabilities::full_access()
            .with_denied_commands(vec![])
            .with_allowed_commands(vec![r"^cargo ".into(), r"^git ".into()]);

        // One tool instance is enough: `execute` only borrows it.
        let tool = create_test_tool(env, caps);

        // cargo should be allowed
        let result = tool
            .execute(&tool_ctx(), json!({"command": "cargo build"}))
            .await?;
        assert!(result.success);

        // ls should be denied
        let result = tool
            .execute(&tool_ctx(), json!({"command": "ls -la"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("not allowed"));
        Ok(())
    }

    // ===================
    // Edge Cases
    // ===================

    #[tokio::test]
    async fn test_bash_empty_output() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("true", "", "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "true"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("(no output)"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_custom_timeout() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "slow_cmd", "timeout_ms": 5000}),
            )
            .await?;

        assert!(result.success);
        // The caller-supplied timeout is forwarded unchanged.
        assert_eq!(env.recorded_timeouts(), vec![Some(5000)]);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_default_timeout_forwarded() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "slow_cmd"}))
            .await?;

        assert!(result.success);
        // Without `timeout_ms` the two-minute default applies.
        assert_eq!(env.recorded_timeouts(), vec![Some(120_000)]);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_timeout_clamped_to_maximum() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "slow_cmd", "timeout_ms": 9_999_999}),
            )
            .await?;

        assert!(result.success);
        // Requests above ten minutes are capped at ten minutes.
        assert_eq!(env.recorded_timeouts(), vec![Some(600_000)]);
        Ok(())
    }

    // Synchronous: only non-async accessors are exercised, so no runtime is needed.
    #[test]
    fn test_bash_tool_metadata() {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        assert_eq!(tool.name(), "bash");
        assert_eq!(tool.tier(), ToolTier::Confirm);
        assert!(tool.description().contains("Execute"));

        let schema = tool.input_schema();
        assert!(schema.get("properties").is_some());
        assert!(schema["properties"].get("command").is_some());
        assert!(schema["properties"].get("timeout_ms").is_some());
    }

    #[tokio::test]
    async fn test_bash_invalid_input() {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        // Missing required command field
        let result = tool.execute(&tool_ctx(), json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_bash_long_output_truncated() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let long_output = "x".repeat(40_000);
        env.add_command("long_output_cmd", &long_output, "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "long_output_cmd"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("output truncated"));
        assert!(result.output.len() < 35_000); // Should be truncated
        Ok(())
    }

    // Synchronous: `truncate_command` is a plain function.
    #[test]
    fn test_truncate_command_function() {
        assert_eq!(truncate_command("short", 10), "short");
        assert_eq!(
            truncate_command("this is a longer command", 10),
            "this is a ..."
        );
    }
}