Skip to main content

agent_sdk/primitive_tools/
bash.rs

1use crate::{Environment, PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
2use anyhow::{Context, Result};
3use serde::Deserialize;
4use serde_json::{Value, json};
5use std::fmt::Write;
6use std::sync::Arc;
7
8use super::PrimitiveToolContext;
9
10/// Tool for executing shell commands
11pub struct BashTool<E: Environment> {
12    ctx: PrimitiveToolContext<E>,
13}
14
15impl<E: Environment> BashTool<E> {
16    #[must_use]
17    pub const fn new(environment: Arc<E>, capabilities: crate::AgentCapabilities) -> Self {
18        Self {
19            ctx: PrimitiveToolContext::new(environment, capabilities),
20        }
21    }
22}
23
24#[derive(Debug, Deserialize)]
25struct BashInput {
26    /// Command to execute
27    command: String,
28    /// Timeout in milliseconds (default: 120000 = 2 minutes).
29    /// Accepts either an integer or a numeric string such as "5000".
30    /// Uses `Option` so that explicit `null` from the model is handled
31    /// gracefully (falls back to the default rather than failing
32    /// deserialization).
33    #[serde(
34        default,
35        deserialize_with = "super::deserialize_optional_u64_from_string_or_int"
36    )]
37    timeout_ms: Option<u64>,
38}
39
/// Timeout used when the caller does not supply `timeout_ms`: two minutes, in milliseconds.
const DEFAULT_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
41
42impl<E: Environment + 'static> Tool<()> for BashTool<E> {
43    type Name = PrimitiveToolName;
44
45    fn name(&self) -> PrimitiveToolName {
46        PrimitiveToolName::Bash
47    }
48
49    fn display_name(&self) -> &'static str {
50        "Run Command"
51    }
52
53    fn description(&self) -> &'static str {
54        "Execute a shell command. Use for git, npm, cargo, and other CLI tools. Returns stdout, stderr, and exit code."
55    }
56
57    fn tier(&self) -> ToolTier {
58        ToolTier::Confirm
59    }
60
61    fn input_schema(&self) -> Value {
62        json!({
63            "type": "object",
64            "properties": {
65                "command": {
66                    "type": "string",
67                    "description": "The shell command to execute"
68                },
69                "timeout_ms": {
70                    "anyOf": [
71                        {"type": "integer"},
72                        {"type": "string", "pattern": "^[0-9]+$"}
73                    ],
74                    "description": "Timeout in milliseconds. Accepts either an integer or a numeric string. Default: 120000 (2 minutes)"
75                }
76            },
77            "required": ["command"]
78        })
79    }
80
81    async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
82        let input: BashInput = serde_json::from_value(input.clone())
83            .with_context(|| format!("Invalid input for bash tool: {input}"))?;
84
85        // Check exec capability and command allow/deny rules
86        if let Err(reason) = self.ctx.capabilities.check_exec(&input.command) {
87            return Ok(ToolResult::error(format!(
88                "Permission denied: cannot execute '{}': {reason}",
89                truncate_command(&input.command, 100)
90            )));
91        }
92
93        // Validate timeout
94        let timeout_ms = input.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS).min(600_000); // Max 10 minutes
95
96        // Execute command
97        let result = self
98            .ctx
99            .environment
100            .exec(&input.command, Some(timeout_ms))
101            .await
102            .context("Failed to execute command")?;
103
104        // Format output
105        let mut output = String::new();
106
107        if !result.stdout.is_empty() {
108            output.push_str(&result.stdout);
109        }
110
111        if !result.stderr.is_empty() {
112            if !output.is_empty() {
113                output.push_str("\n\n--- stderr ---\n");
114            }
115            output.push_str(&result.stderr);
116        }
117
118        if output.is_empty() {
119            output = "(no output)".to_string();
120        }
121
122        // Truncate if too long
123        let max_output_len = 30_000;
124        if output.len() > max_output_len {
125            output = format!(
126                "{}...\n\n(output truncated, {} total characters)",
127                &output[..max_output_len],
128                output.len()
129            );
130        }
131
132        // Include exit code in output
133        let _ = write!(output, "\n\nExit code: {}", result.exit_code);
134
135        let tool_result = if result.success() {
136            ToolResult::success(output)
137        } else {
138            ToolResult::error(output)
139        };
140
141        Ok(tool_result)
142    }
143}
144
/// Shortens `s` to at most `max_len` bytes for use in messages, appending
/// `...` when anything was cut off.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so a
/// command containing multi-byte characters never causes a slicing panic. As
/// a result the kept prefix may be a few bytes shorter than `max_len`.
fn truncate_command(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    // Index 0 is always a char boundary, so this loop terminates.
    let mut end = max_len;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AgentCapabilities;
    use crate::environment::ExecResult;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use std::sync::RwLock;

    /// In-memory `Environment` whose `exec` replays canned results.
    ///
    /// Commands registered with `add_command` return their configured
    /// stdout/stderr/exit code; any other command behaves like a missing
    /// binary (exit code 127). The timeout passed to the most recent `exec`
    /// call is recorded so tests can check how `BashTool` resolves
    /// `timeout_ms`. All filesystem methods are inert stubs.
    struct MockBashEnvironment {
        root: String,
        // Map of command to (stdout, stderr, exit_code)
        commands: RwLock<HashMap<String, (String, String, i32)>>,
        // Timeout argument of the most recent `exec` call; `None` if `exec`
        // has not been called (or was called with `None`).
        last_timeout_ms: RwLock<Option<u64>>,
    }

    impl MockBashEnvironment {
        fn new() -> Self {
            Self {
                root: "/workspace".to_string(),
                commands: RwLock::new(HashMap::new()),
                last_timeout_ms: RwLock::new(None),
            }
        }

        fn add_command(&self, cmd: &str, stdout: &str, stderr: &str, exit_code: i32) {
            self.commands.write().unwrap().insert(
                cmd.to_string(),
                (stdout.to_string(), stderr.to_string(), exit_code),
            );
        }

        fn last_timeout_ms(&self) -> Option<u64> {
            *self.last_timeout_ms.read().unwrap()
        }
    }

    #[async_trait]
    impl crate::Environment for MockBashEnvironment {
        async fn read_file(&self, _path: &str) -> Result<String> {
            Ok(String::new())
        }

        async fn read_file_bytes(&self, _path: &str) -> Result<Vec<u8>> {
            Ok(vec![])
        }

        async fn write_file(&self, _path: &str, _content: &str) -> Result<()> {
            Ok(())
        }

        async fn write_file_bytes(&self, _path: &str, _content: &[u8]) -> Result<()> {
            Ok(())
        }

        async fn list_dir(&self, _path: &str) -> Result<Vec<crate::environment::FileEntry>> {
            Ok(vec![])
        }

        async fn exists(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_dir(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_file(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn create_dir(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_file(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_dir(&self, _path: &str, _recursive: bool) -> Result<()> {
            Ok(())
        }

        async fn grep(
            &self,
            _pattern: &str,
            _path: &str,
            _recursive: bool,
        ) -> Result<Vec<crate::environment::GrepMatch>> {
            Ok(vec![])
        }

        async fn glob(&self, _pattern: &str) -> Result<Vec<String>> {
            Ok(vec![])
        }

        async fn exec(&self, command: &str, timeout_ms: Option<u64>) -> Result<ExecResult> {
            *self.last_timeout_ms.write().unwrap() = timeout_ms;
            let commands = self.commands.read().unwrap();
            if let Some((stdout, stderr, exit_code)) = commands.get(command) {
                Ok(ExecResult {
                    stdout: stdout.clone(),
                    stderr: stderr.clone(),
                    exit_code: *exit_code,
                })
            } else {
                // Default: command not found
                Ok(ExecResult {
                    stdout: String::new(),
                    stderr: format!("command not found: {command}"),
                    exit_code: 127,
                })
            }
        }

        fn root(&self) -> &str {
            &self.root
        }
    }

    fn create_test_tool(
        env: Arc<MockBashEnvironment>,
        capabilities: AgentCapabilities,
    ) -> BashTool<MockBashEnvironment> {
        BashTool::new(env, capabilities)
    }

    fn tool_ctx() -> ToolContext<()> {
        ToolContext::new(())
    }

    // ===================
    // Unit Tests
    // ===================

    #[tokio::test]
    async fn test_bash_simple_command() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hello", "hello\n", "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "echo hello"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("hello"));
        assert!(result.output.contains("Exit code: 0"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_with_stderr() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cmd", "stdout output", "stderr output", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool.execute(&tool_ctx(), json!({"command": "cmd"})).await?;

        assert!(result.success);
        assert!(result.output.contains("stdout output"));
        assert!(result.output.contains("stderr output"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_nonzero_exit() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("failing_cmd", "", "error occurred", 1);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "failing_cmd"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 1"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_not_found() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "nonexistent_cmd"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 127"));
        Ok(())
    }

    // ===================
    // Integration Tests
    // ===================

    #[tokio::test]
    async fn test_bash_exec_disabled() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Read-only capabilities (exec disabled)
        let caps = AgentCapabilities::read_only();

        let tool = create_test_tool(Arc::clone(&env), caps);
        let result = tool.execute(&tool_ctx(), json!({"command": "ls"})).await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("execution is disabled"));
        // The command must be rejected before it reaches the environment.
        assert_eq!(env.last_timeout_ms(), None);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_denied_commands() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Client configures denied commands
        let caps = AgentCapabilities::full_access()
            .with_denied_commands(vec![r"rm\s+-rf\s+/".into(), r"^sudo\s".into()]);

        let tool = create_test_tool(Arc::clone(&env), caps.clone());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "rm -rf /"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("denied pattern"));

        let tool = create_test_tool(Arc::clone(&env), caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "sudo apt-get install foo"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));

        // Neither denied command may have reached the environment.
        assert_eq!(env.last_timeout_ms(), None);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_allowed_commands_restriction() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cargo build", "Compiling...", "", 0);

        // Only allow cargo and git commands
        let caps = AgentCapabilities::full_access()
            .with_allowed_commands(vec![r"^cargo ".into(), r"^git ".into()]);

        let tool = create_test_tool(Arc::clone(&env), caps.clone());

        // cargo should be allowed
        let result = tool
            .execute(&tool_ctx(), json!({"command": "cargo build"}))
            .await?;
        assert!(result.success);

        // ls should be denied
        let tool = create_test_tool(env, caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "ls -la"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("not in allowed list"));
        Ok(())
    }

    // ===================
    // Edge Cases
    // ===================

    #[tokio::test]
    async fn test_bash_empty_output() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("true", "", "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "true"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("(no output)"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_custom_timeout() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "slow_cmd", "timeout_ms": 5000}),
            )
            .await?;

        assert!(result.success);
        assert_eq!(env.last_timeout_ms(), Some(5000));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_timeout_clamped_to_max() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "slow_cmd", "timeout_ms": 3_600_000}),
            )
            .await?;

        assert!(result.success);
        // Requests above the 10-minute ceiling are capped, not rejected.
        assert_eq!(env.last_timeout_ms(), Some(600_000));
        Ok(())
    }

    #[test]
    fn test_bash_tool_metadata() {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        assert_eq!(tool.name(), PrimitiveToolName::Bash);
        assert_eq!(tool.tier(), ToolTier::Confirm);
        assert!(tool.description().contains("Execute"));

        let schema = tool.input_schema();
        assert!(schema.get("properties").is_some());
        assert!(schema["properties"].get("command").is_some());
        assert!(schema["properties"].get("timeout_ms").is_some());
    }

    #[tokio::test]
    async fn test_bash_invalid_input() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        // Missing required command field
        let result = tool.execute(&tool_ctx(), json!({})).await;
        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_null_timeout_ms() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hello", "hello", "", 0);
        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());

        // Model may send explicit null for optional fields — must not fail
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "echo hello", "timeout_ms": null}),
            )
            .await?;

        assert!(result.success);
        assert_eq!(env.last_timeout_ms(), Some(DEFAULT_TIMEOUT_MS));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_missing_timeout_uses_default() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hi", "hi", "", 0);
        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());

        // Omitted timeout_ms should use the default
        let result = tool
            .execute(&tool_ctx(), json!({"command": "echo hi"}))
            .await?;

        assert!(result.success);
        assert_eq!(env.last_timeout_ms(), Some(120_000));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_string_timeout_ms() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo timeout", "ok", "", 0);
        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());

        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "echo timeout", "timeout_ms": "5000"}),
            )
            .await?;

        assert!(result.success);
        assert_eq!(env.last_timeout_ms(), Some(5000));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_long_output_truncated() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let long_output = "x".repeat(40_000);
        env.add_command("long_output_cmd", &long_output, "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "long_output_cmd"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("output truncated"));
        assert!(result.output.len() < 35_000); // Should be truncated
        Ok(())
    }

    #[test]
    fn test_truncate_command_function() {
        assert_eq!(truncate_command("short", 10), "short");
        assert_eq!(
            truncate_command("this is a longer command", 10),
            "this is a ..."
        );
    }
}