Skip to main content

agent_sdk/primitive_tools/
bash.rs

1use crate::reminders::{append_reminder, builtin};
2use crate::{Environment, PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
3use anyhow::{Context, Result};
4use serde::Deserialize;
5use serde_json::{Value, json};
6use std::fmt::Write;
7use std::sync::Arc;
8
9use super::PrimitiveToolContext;
10
11/// Tool for executing shell commands
12pub struct BashTool<E: Environment> {
13    ctx: PrimitiveToolContext<E>,
14}
15
16impl<E: Environment> BashTool<E> {
17    #[must_use]
18    pub const fn new(environment: Arc<E>, capabilities: crate::AgentCapabilities) -> Self {
19        Self {
20            ctx: PrimitiveToolContext::new(environment, capabilities),
21        }
22    }
23}
24
25#[derive(Debug, Deserialize)]
26struct BashInput {
27    /// Command to execute
28    command: String,
29    /// Timeout in milliseconds (default: 120000 = 2 minutes).
30    /// Accepts either an integer or a numeric string such as "5000".
31    /// Uses `Option` so that explicit `null` from the model is handled
32    /// gracefully (falls back to the default rather than failing
33    /// deserialization).
34    #[serde(
35        default,
36        deserialize_with = "super::deserialize_optional_u64_from_string_or_int"
37    )]
38    timeout_ms: Option<u64>,
39}
40
/// Timeout applied when the caller does not supply `timeout_ms`: 2 minutes.
const DEFAULT_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
42
43impl<E: Environment + 'static> Tool<()> for BashTool<E> {
44    type Name = PrimitiveToolName;
45
46    fn name(&self) -> PrimitiveToolName {
47        PrimitiveToolName::Bash
48    }
49
50    fn display_name(&self) -> &'static str {
51        "Run Command"
52    }
53
54    fn description(&self) -> &'static str {
55        "Execute a shell command. Use for git, npm, cargo, and other CLI tools. Returns stdout, stderr, and exit code."
56    }
57
58    fn tier(&self) -> ToolTier {
59        ToolTier::Confirm
60    }
61
62    fn input_schema(&self) -> Value {
63        json!({
64            "type": "object",
65            "properties": {
66                "command": {
67                    "type": "string",
68                    "description": "The shell command to execute"
69                },
70                "timeout_ms": {
71                    "anyOf": [
72                        {"type": "integer"},
73                        {"type": "string", "pattern": "^[0-9]+$"}
74                    ],
75                    "description": "Timeout in milliseconds. Accepts either an integer or a numeric string. Default: 120000 (2 minutes)"
76                }
77            },
78            "required": ["command"]
79        })
80    }
81
82    async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
83        let input: BashInput = serde_json::from_value(input.clone())
84            .with_context(|| format!("Invalid input for bash tool: {input}"))?;
85
86        // Check exec capability
87        if !self.ctx.capabilities.exec {
88            return Ok(ToolResult::error(
89                "Permission denied: command execution is disabled",
90            ));
91        }
92
93        // Check if command is allowed
94        if !self.ctx.capabilities.can_exec(&input.command) {
95            return Ok(ToolResult::error(format!(
96                "Permission denied: command '{}' is not allowed",
97                truncate_command(&input.command, 100)
98            )));
99        }
100
101        // Validate timeout
102        let timeout_ms = input.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS).min(600_000); // Max 10 minutes
103
104        // Execute command
105        let result = self
106            .ctx
107            .environment
108            .exec(&input.command, Some(timeout_ms))
109            .await
110            .context("Failed to execute command")?;
111
112        // Format output
113        let mut output = String::new();
114
115        if !result.stdout.is_empty() {
116            output.push_str(&result.stdout);
117        }
118
119        if !result.stderr.is_empty() {
120            if !output.is_empty() {
121                output.push_str("\n\n--- stderr ---\n");
122            }
123            output.push_str(&result.stderr);
124        }
125
126        if output.is_empty() {
127            output = "(no output)".to_string();
128        }
129
130        // Truncate if too long
131        let max_output_len = 30_000;
132        if output.len() > max_output_len {
133            output = format!(
134                "{}...\n\n(output truncated, {} total characters)",
135                &output[..max_output_len],
136                output.len()
137            );
138        }
139
140        // Include exit code in output
141        let _ = write!(output, "\n\nExit code: {}", result.exit_code);
142
143        let mut tool_result = if result.success() {
144            ToolResult::success(output)
145        } else {
146            ToolResult::error(output)
147        };
148
149        // Add verification reminder
150        append_reminder(&mut tool_result, builtin::BASH_VERIFICATION_REMINDER);
151
152        Ok(tool_result)
153    }
154}
155
/// Shortens `s` to at most `max_len` bytes and appends `...` when it was cut.
///
/// Strings that already fit are returned unchanged. When `max_len` falls
/// inside a multi-byte UTF-8 character the cut is moved back to the previous
/// character boundary, so the function never panics on non-ASCII commands.
fn truncate_command(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    // `max_len < s.len()` here, and offset 0 is always a boundary, so this
    // loop terminates without underflow.
    let mut end = max_len;
    while !s.is_char_boundary(end) {
        end -= 1;
    }

    format!("{}...", &s[..end])
}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AgentCapabilities;
    use crate::environment::ExecResult;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use std::sync::RwLock;

    /// Mock environment for testing bash execution.
    ///
    /// Commands are looked up in a table of canned results; every call to
    /// `exec` is also recorded so tests can check which timeout was forwarded
    /// and whether `exec` was reached at all.
    struct MockBashEnvironment {
        root: String,
        // Map of command to (stdout, stderr, exit_code)
        commands: RwLock<HashMap<String, (String, String, i32)>>,
        // Timeout argument of every `exec` call, in call order
        exec_timeouts: RwLock<Vec<Option<u64>>>,
    }

    impl MockBashEnvironment {
        fn new() -> Self {
            Self {
                root: "/workspace".to_string(),
                commands: RwLock::new(HashMap::new()),
                exec_timeouts: RwLock::new(Vec::new()),
            }
        }

        /// Registers the canned result returned for `cmd`.
        fn add_command(&self, cmd: &str, stdout: &str, stderr: &str, exit_code: i32) {
            self.commands.write().unwrap().insert(
                cmd.to_string(),
                (stdout.to_string(), stderr.to_string(), exit_code),
            );
        }

        /// Number of times `exec` has been called on this mock.
        fn exec_call_count(&self) -> usize {
            self.exec_timeouts.read().unwrap().len()
        }

        /// Timeout passed to the most recent `exec` call, if any.
        fn last_timeout_ms(&self) -> Option<u64> {
            self.exec_timeouts.read().unwrap().last().copied().flatten()
        }
    }

    #[async_trait]
    impl crate::Environment for MockBashEnvironment {
        async fn read_file(&self, _path: &str) -> Result<String> {
            Ok(String::new())
        }

        async fn read_file_bytes(&self, _path: &str) -> Result<Vec<u8>> {
            Ok(vec![])
        }

        async fn write_file(&self, _path: &str, _content: &str) -> Result<()> {
            Ok(())
        }

        async fn write_file_bytes(&self, _path: &str, _content: &[u8]) -> Result<()> {
            Ok(())
        }

        async fn list_dir(&self, _path: &str) -> Result<Vec<crate::environment::FileEntry>> {
            Ok(vec![])
        }

        async fn exists(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_dir(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_file(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn create_dir(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_file(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_dir(&self, _path: &str, _recursive: bool) -> Result<()> {
            Ok(())
        }

        async fn grep(
            &self,
            _pattern: &str,
            _path: &str,
            _recursive: bool,
        ) -> Result<Vec<crate::environment::GrepMatch>> {
            Ok(vec![])
        }

        async fn glob(&self, _pattern: &str) -> Result<Vec<String>> {
            Ok(vec![])
        }

        async fn exec(&self, command: &str, timeout_ms: Option<u64>) -> Result<ExecResult> {
            // Record the call first; the write guard is released at the end
            // of this statement.
            self.exec_timeouts.write().unwrap().push(timeout_ms);

            let commands = self.commands.read().unwrap();
            if let Some((stdout, stderr, exit_code)) = commands.get(command) {
                Ok(ExecResult {
                    stdout: stdout.clone(),
                    stderr: stderr.clone(),
                    exit_code: *exit_code,
                })
            } else {
                // Default: command not found
                Ok(ExecResult {
                    stdout: String::new(),
                    stderr: format!("command not found: {command}"),
                    exit_code: 127,
                })
            }
        }

        fn root(&self) -> &str {
            &self.root
        }
    }

    fn create_test_tool(
        env: Arc<MockBashEnvironment>,
        capabilities: AgentCapabilities,
    ) -> BashTool<MockBashEnvironment> {
        BashTool::new(env, capabilities)
    }

    fn tool_ctx() -> ToolContext<()> {
        ToolContext::new(())
    }

    // ===================
    // Unit Tests
    // ===================

    #[tokio::test]
    async fn test_bash_simple_command() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hello", "hello\n", "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "echo hello"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("hello"));
        assert!(result.output.contains("Exit code: 0"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_with_stderr() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cmd", "stdout output", "stderr output", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool.execute(&tool_ctx(), json!({"command": "cmd"})).await?;

        assert!(result.success);
        assert!(result.output.contains("stdout output"));
        assert!(result.output.contains("stderr output"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_nonzero_exit() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("failing_cmd", "", "error occurred", 1);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "failing_cmd"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 1"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_not_found() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "nonexistent_cmd"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 127"));
        Ok(())
    }

    // ===================
    // Integration Tests
    // ===================

    #[tokio::test]
    async fn test_bash_exec_disabled() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Read-only capabilities (exec disabled)
        let caps = AgentCapabilities::read_only();

        let tool = create_test_tool(Arc::clone(&env), caps);
        let result = tool.execute(&tool_ctx(), json!({"command": "ls"})).await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("execution is disabled"));
        // The denial must happen before the environment is touched.
        assert_eq!(env.exec_call_count(), 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_dangerous_command_denied() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Full access but default denied commands
        let caps = AgentCapabilities::full_access();

        let tool = create_test_tool(Arc::clone(&env), caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "rm -rf /"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("not allowed"));
        assert_eq!(env.exec_call_count(), 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_sudo_command_denied() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let caps = AgentCapabilities::full_access();

        let tool = create_test_tool(Arc::clone(&env), caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "sudo apt-get install foo"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert_eq!(env.exec_call_count(), 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_allowed_commands_restriction() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cargo build", "Compiling...", "", 0);

        // Only allow cargo and git commands
        let caps = AgentCapabilities::full_access()
            .with_denied_commands(vec![])
            .with_allowed_commands(vec![r"^cargo ".into(), r"^git ".into()]);

        let tool = create_test_tool(Arc::clone(&env), caps.clone());

        // cargo should be allowed
        let result = tool
            .execute(&tool_ctx(), json!({"command": "cargo build"}))
            .await?;
        assert!(result.success);
        assert_eq!(env.exec_call_count(), 1);

        // ls should be denied
        let tool = create_test_tool(Arc::clone(&env), caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "ls -la"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("not allowed"));
        // Still only the one call made by the allowed command above.
        assert_eq!(env.exec_call_count(), 1);
        Ok(())
    }

    // ===================
    // Edge Cases
    // ===================

    #[tokio::test]
    async fn test_bash_empty_output() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("true", "", "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "true"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("(no output)"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_custom_timeout() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "slow_cmd", "timeout_ms": 5000}),
            )
            .await?;

        assert!(result.success);
        assert_eq!(env.last_timeout_ms(), Some(5000));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_timeout_clamped_to_max() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "slow_cmd", "timeout_ms": 3_600_000}),
            )
            .await?;

        assert!(result.success);
        // Requests above 10 minutes are capped, not rejected.
        assert_eq!(env.last_timeout_ms(), Some(600_000));
        Ok(())
    }

    #[test]
    fn test_bash_tool_metadata() {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        assert_eq!(tool.name(), PrimitiveToolName::Bash);
        assert_eq!(tool.tier(), ToolTier::Confirm);
        assert!(tool.description().contains("Execute"));

        let schema = tool.input_schema();
        assert!(schema.get("properties").is_some());
        assert!(schema["properties"].get("command").is_some());
        assert!(schema["properties"].get("timeout_ms").is_some());
    }

    #[tokio::test]
    async fn test_bash_invalid_input() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        // Missing required command field
        let result = tool.execute(&tool_ctx(), json!({})).await;
        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_null_timeout_ms() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hello", "hello", "", 0);
        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());

        // Model may send explicit null for optional fields — must not fail
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "echo hello", "timeout_ms": null}),
            )
            .await?;

        assert!(result.success);
        assert_eq!(env.last_timeout_ms(), Some(DEFAULT_TIMEOUT_MS));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_missing_timeout_uses_default() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hi", "hi", "", 0);
        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());

        // Omitted timeout_ms should use the default
        let result = tool
            .execute(&tool_ctx(), json!({"command": "echo hi"}))
            .await?;

        assert!(result.success);
        assert_eq!(env.last_timeout_ms(), Some(DEFAULT_TIMEOUT_MS));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_string_timeout_ms() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo timeout", "ok", "", 0);
        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());

        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "echo timeout", "timeout_ms": "5000"}),
            )
            .await?;

        assert!(result.success);
        assert_eq!(env.last_timeout_ms(), Some(5000));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_long_output_truncated() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let long_output = "x".repeat(40_000);
        env.add_command("long_output_cmd", &long_output, "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "long_output_cmd"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("output truncated"));
        assert!(result.output.len() < 35_000); // Should be truncated
        Ok(())
    }

    #[test]
    fn test_truncate_command_function() {
        assert_eq!(truncate_command("short", 10), "short");
        assert_eq!(
            truncate_command("this is a longer command", 10),
            "this is a ..."
        );
    }
}