Skip to main content

agent_sdk/primitive_tools/
bash.rs

1use crate::reminders::{append_reminder, builtin};
2use crate::{Environment, PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
3use anyhow::{Context, Result};
4use serde::Deserialize;
5use serde_json::{Value, json};
6use std::fmt::Write;
7use std::sync::Arc;
8
9use super::PrimitiveToolContext;
10
11/// Tool for executing shell commands
12pub struct BashTool<E: Environment> {
13    ctx: PrimitiveToolContext<E>,
14}
15
16impl<E: Environment> BashTool<E> {
17    #[must_use]
18    pub const fn new(environment: Arc<E>, capabilities: crate::AgentCapabilities) -> Self {
19        Self {
20            ctx: PrimitiveToolContext::new(environment, capabilities),
21        }
22    }
23}
24
25#[derive(Debug, Deserialize)]
26struct BashInput {
27    /// Command to execute
28    command: String,
29    /// Timeout in milliseconds (default: 120000 = 2 minutes).
30    /// Accepts either an integer or a numeric string such as "5000".
31    /// Uses `Option` so that explicit `null` from the model is handled
32    /// gracefully (falls back to the default rather than failing
33    /// deserialization).
34    #[serde(
35        default,
36        deserialize_with = "super::deserialize_optional_u64_from_string_or_int"
37    )]
38    timeout_ms: Option<u64>,
39}
40
/// Timeout used when the caller does not supply `timeout_ms`: two minutes.
const DEFAULT_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
42
43impl<E: Environment + 'static> Tool<()> for BashTool<E> {
44    type Name = PrimitiveToolName;
45
46    fn name(&self) -> PrimitiveToolName {
47        PrimitiveToolName::Bash
48    }
49
50    fn display_name(&self) -> &'static str {
51        "Run Command"
52    }
53
54    fn description(&self) -> &'static str {
55        "Execute a shell command. Use for git, npm, cargo, and other CLI tools. Returns stdout, stderr, and exit code."
56    }
57
58    fn tier(&self) -> ToolTier {
59        ToolTier::Confirm
60    }
61
62    fn input_schema(&self) -> Value {
63        json!({
64            "type": "object",
65            "properties": {
66                "command": {
67                    "type": "string",
68                    "description": "The shell command to execute"
69                },
70                "timeout_ms": {
71                    "anyOf": [
72                        {"type": "integer"},
73                        {"type": "string", "pattern": "^[0-9]+$"}
74                    ],
75                    "description": "Timeout in milliseconds. Accepts either an integer or a numeric string. Default: 120000 (2 minutes)"
76                }
77            },
78            "required": ["command"]
79        })
80    }
81
82    async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
83        let input: BashInput = serde_json::from_value(input.clone())
84            .with_context(|| format!("Invalid input for bash tool: {input}"))?;
85
86        // Check exec capability
87        if !self.ctx.capabilities.exec {
88            return Ok(ToolResult::error(
89                "Permission denied: command execution is disabled",
90            ));
91        }
92
93        // Check if command is allowed
94        if !self.ctx.capabilities.can_exec(&input.command) {
95            return Ok(ToolResult::error(format!(
96                "Permission denied: command '{}' is not allowed",
97                truncate_command(&input.command, 100)
98            )));
99        }
100
101        // Validate timeout
102        let timeout_ms = input.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS).min(600_000); // Max 10 minutes
103
104        // Execute command
105        let result = self
106            .ctx
107            .environment
108            .exec(&input.command, Some(timeout_ms))
109            .await
110            .context("Failed to execute command")?;
111
112        // Format output
113        let mut output = String::new();
114
115        if !result.stdout.is_empty() {
116            output.push_str(&result.stdout);
117        }
118
119        if !result.stderr.is_empty() {
120            if !output.is_empty() {
121                output.push_str("\n\n--- stderr ---\n");
122            }
123            output.push_str(&result.stderr);
124        }
125
126        if output.is_empty() {
127            output = "(no output)".to_string();
128        }
129
130        // Truncate if too long
131        let max_output_len = 30_000;
132        if output.len() > max_output_len {
133            output = format!(
134                "{}...\n\n(output truncated, {} total characters)",
135                &output[..max_output_len],
136                output.len()
137            );
138        }
139
140        // Include exit code in output
141        let _ = write!(output, "\n\nExit code: {}", result.exit_code);
142
143        let mut tool_result = if result.success() {
144            ToolResult::success(output)
145        } else {
146            ToolResult::error(output)
147        };
148
149        // Add verification reminder
150        append_reminder(&mut tool_result, builtin::BASH_VERIFICATION_REMINDER);
151
152        Ok(tool_result)
153    }
154}
155
/// Shortens `s` to at most `max_len` bytes for display, appending `...` when
/// anything was cut off.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so a
/// command containing multi-byte characters is shortened slightly further
/// instead of causing a slicing panic.
fn truncate_command(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    // Index 0 is always a boundary, so this loop terminates.
    let mut cut = max_len;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AgentCapabilities;
    use crate::environment::ExecResult;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use std::sync::RwLock;

    // Mock environment for testing bash execution.
    //
    // Commands are looked up in a table of canned results; anything not
    // registered behaves like "command not found" (exit code 127). The
    // timeout passed to the most recent `exec` call is recorded so tests can
    // check what the tool actually forwarded.
    struct MockBashEnvironment {
        root: String,
        // Map of command to (stdout, stderr, exit_code)
        commands: RwLock<HashMap<String, (String, String, i32)>>,
        // Timeout argument of the most recent `exec` call; `None` until
        // `exec` has been called (or if it was called without a timeout)
        last_timeout_ms: RwLock<Option<u64>>,
    }

    impl MockBashEnvironment {
        fn new() -> Self {
            Self {
                root: "/workspace".to_string(),
                commands: RwLock::new(HashMap::new()),
                last_timeout_ms: RwLock::new(None),
            }
        }

        fn add_command(&self, cmd: &str, stdout: &str, stderr: &str, exit_code: i32) {
            self.commands.write().unwrap().insert(
                cmd.to_string(),
                (stdout.to_string(), stderr.to_string(), exit_code),
            );
        }

        // Timeout that the most recent `exec` call received
        fn last_timeout_ms(&self) -> Option<u64> {
            *self.last_timeout_ms.read().unwrap()
        }
    }

    #[async_trait]
    impl crate::Environment for MockBashEnvironment {
        async fn read_file(&self, _path: &str) -> Result<String> {
            Ok(String::new())
        }

        async fn read_file_bytes(&self, _path: &str) -> Result<Vec<u8>> {
            Ok(vec![])
        }

        async fn write_file(&self, _path: &str, _content: &str) -> Result<()> {
            Ok(())
        }

        async fn write_file_bytes(&self, _path: &str, _content: &[u8]) -> Result<()> {
            Ok(())
        }

        async fn list_dir(&self, _path: &str) -> Result<Vec<crate::environment::FileEntry>> {
            Ok(vec![])
        }

        async fn exists(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_dir(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_file(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn create_dir(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_file(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_dir(&self, _path: &str, _recursive: bool) -> Result<()> {
            Ok(())
        }

        async fn grep(
            &self,
            _pattern: &str,
            _path: &str,
            _recursive: bool,
        ) -> Result<Vec<crate::environment::GrepMatch>> {
            Ok(vec![])
        }

        async fn glob(&self, _pattern: &str) -> Result<Vec<String>> {
            Ok(vec![])
        }

        async fn exec(&self, command: &str, timeout_ms: Option<u64>) -> Result<ExecResult> {
            // Record the timeout so tests can assert on it
            *self.last_timeout_ms.write().unwrap() = timeout_ms;

            let commands = self.commands.read().unwrap();
            if let Some((stdout, stderr, exit_code)) = commands.get(command) {
                Ok(ExecResult {
                    stdout: stdout.clone(),
                    stderr: stderr.clone(),
                    exit_code: *exit_code,
                })
            } else {
                // Default: command not found
                Ok(ExecResult {
                    stdout: String::new(),
                    stderr: format!("command not found: {command}"),
                    exit_code: 127,
                })
            }
        }

        fn root(&self) -> &str {
            &self.root
        }
    }

    fn create_test_tool(
        env: Arc<MockBashEnvironment>,
        capabilities: AgentCapabilities,
    ) -> BashTool<MockBashEnvironment> {
        BashTool::new(env, capabilities)
    }

    fn tool_ctx() -> ToolContext<()> {
        ToolContext::new(())
    }

    // ===================
    // Unit Tests
    // ===================

    #[tokio::test]
    async fn test_bash_simple_command() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hello", "hello\n", "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "echo hello"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("hello"));
        assert!(result.output.contains("Exit code: 0"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_with_stderr() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cmd", "stdout output", "stderr output", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool.execute(&tool_ctx(), json!({"command": "cmd"})).await?;

        assert!(result.success);
        assert!(result.output.contains("stdout output"));
        assert!(result.output.contains("stderr output"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_nonzero_exit() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("failing_cmd", "", "error occurred", 1);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "failing_cmd"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 1"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_not_found() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "nonexistent_cmd"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 127"));
        Ok(())
    }

    // ===================
    // Integration Tests
    // ===================

    #[tokio::test]
    async fn test_bash_exec_disabled() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Read-only capabilities (exec disabled)
        let caps = AgentCapabilities::read_only();

        let tool = create_test_tool(env, caps);
        let result = tool.execute(&tool_ctx(), json!({"command": "ls"})).await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("execution is disabled"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_denied_commands() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Client configures denied commands
        let caps = AgentCapabilities::full_access()
            .with_denied_commands(vec![r"rm\s+-rf\s+/".into(), r"^sudo\s".into()]);

        let tool = create_test_tool(Arc::clone(&env), caps.clone());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "rm -rf /"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("not allowed"));

        let tool = create_test_tool(env, caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "sudo apt-get install foo"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_allowed_commands_restriction() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cargo build", "Compiling...", "", 0);

        // Only allow cargo and git commands
        let caps = AgentCapabilities::full_access()
            .with_allowed_commands(vec![r"^cargo ".into(), r"^git ".into()]);

        let tool = create_test_tool(Arc::clone(&env), caps.clone());

        // cargo should be allowed
        let result = tool
            .execute(&tool_ctx(), json!({"command": "cargo build"}))
            .await?;
        assert!(result.success);

        // ls should be denied
        let tool = create_test_tool(env, caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "ls -la"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("not allowed"));
        Ok(())
    }

    // ===================
    // Edge Cases
    // ===================

    #[tokio::test]
    async fn test_bash_empty_output() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("true", "", "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "true"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("(no output)"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_custom_timeout() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "slow_cmd", "timeout_ms": 5000}),
            )
            .await?;

        assert!(result.success);
        // The requested timeout must reach the environment unchanged
        assert_eq!(env.last_timeout_ms(), Some(5000));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_timeout_clamped_to_max() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "slow_cmd", "timeout_ms": 3_600_000}),
            )
            .await?;

        assert!(result.success);
        // Requests above the 10-minute ceiling are clamped to it
        assert_eq!(env.last_timeout_ms(), Some(600_000));
        Ok(())
    }

    #[test]
    fn test_bash_tool_metadata() {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        assert_eq!(tool.name(), PrimitiveToolName::Bash);
        assert_eq!(tool.tier(), ToolTier::Confirm);
        assert!(tool.description().contains("Execute"));

        let schema = tool.input_schema();
        assert!(schema.get("properties").is_some());
        assert!(schema["properties"].get("command").is_some());
        assert!(schema["properties"].get("timeout_ms").is_some());
    }

    #[tokio::test]
    async fn test_bash_invalid_input() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        // Missing required command field
        let result = tool.execute(&tool_ctx(), json!({})).await;
        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_null_timeout_ms() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hello", "hello", "", 0);
        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());

        // Model may send explicit null for optional fields — must not fail
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "echo hello", "timeout_ms": null}),
            )
            .await?;

        assert!(result.success);
        // Explicit null falls back to the default timeout
        assert_eq!(env.last_timeout_ms(), Some(DEFAULT_TIMEOUT_MS));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_missing_timeout_uses_default() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hi", "hi", "", 0);
        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());

        // Omitted timeout_ms should use the default
        let result = tool
            .execute(&tool_ctx(), json!({"command": "echo hi"}))
            .await?;

        assert!(result.success);
        assert_eq!(env.last_timeout_ms(), Some(DEFAULT_TIMEOUT_MS));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_string_timeout_ms() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo timeout", "ok", "", 0);
        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());

        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "echo timeout", "timeout_ms": "5000"}),
            )
            .await?;

        assert!(result.success);
        // A numeric string is parsed to the same value as the integer form
        assert_eq!(env.last_timeout_ms(), Some(5000));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_long_output_truncated() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let long_output = "x".repeat(40_000);
        env.add_command("long_output_cmd", &long_output, "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "long_output_cmd"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("output truncated"));
        assert!(result.output.len() < 35_000); // Should be truncated
        Ok(())
    }

    #[test]
    fn test_truncate_command_function() {
        assert_eq!(truncate_command("short", 10), "short");
        assert_eq!(
            truncate_command("this is a longer command", 10),
            "this is a ..."
        );
    }
}