// agent_sdk/primitive_tools/bash.rs

1use crate::{Environment, PrimitiveToolName, Tool, ToolContext, ToolResult, ToolTier};
2use anyhow::{Context, Result};
3use serde::Deserialize;
4use serde_json::{Value, json};
5use std::fmt::Write;
6use std::sync::Arc;
7
8use super::PrimitiveToolContext;
9
10/// Tool for executing shell commands
11pub struct BashTool<E: Environment> {
12    ctx: PrimitiveToolContext<E>,
13}
14
15impl<E: Environment> BashTool<E> {
16    #[must_use]
17    pub const fn new(environment: Arc<E>, capabilities: crate::AgentCapabilities) -> Self {
18        Self {
19            ctx: PrimitiveToolContext::new(environment, capabilities),
20        }
21    }
22}
23
24#[derive(Debug, Deserialize)]
25struct BashInput {
26    /// Command to execute
27    command: String,
28    /// Timeout in milliseconds (default: 120000 = 2 minutes).
29    /// Accepts either an integer or a numeric string such as "5000".
30    /// Uses `Option` so that explicit `null` from the model is handled
31    /// gracefully (falls back to the default rather than failing
32    /// deserialization).
33    #[serde(
34        default,
35        deserialize_with = "super::deserialize_optional_u64_from_string_or_int"
36    )]
37    timeout_ms: Option<u64>,
38}
39
/// Timeout used when the caller omits `timeout_ms` (or sends `null`): two minutes.
const DEFAULT_TIMEOUT_MS: u64 = 2 * 60 * 1_000;

/// Largest timeout forwarded to the environment: ten minutes. Larger requests
/// are reduced to this value, and the reduction is reported in the tool result.
const MAX_TIMEOUT_MS: u64 = 10 * 60 * 1_000;

/// Byte budget for the tool result text. Command output that would exceed it
/// is truncated, leaving room for the truncation notice and the exit-code line.
const MAX_OUTPUT_BYTES: usize = 30 * 1_000;
50
51impl<E: Environment + 'static, Ctx: Send + Sync + 'static> Tool<Ctx> for BashTool<E> {
52    type Name = PrimitiveToolName;
53
54    fn name(&self) -> PrimitiveToolName {
55        PrimitiveToolName::Bash
56    }
57
58    fn display_name(&self) -> &'static str {
59        "Run Command"
60    }
61
62    fn description(&self) -> &'static str {
63        "Execute a shell command. Use for git, npm, cargo, and other CLI tools. Returns stdout, stderr, and exit code."
64    }
65
66    fn tier(&self) -> ToolTier {
67        ToolTier::Confirm
68    }
69
70    fn input_schema(&self) -> Value {
71        json!({
72            "type": "object",
73            "properties": {
74                "command": {
75                    "type": "string",
76                    "description": "The shell command to execute"
77                },
78                "timeout_ms": {
79                    "anyOf": [
80                        {"type": "integer"},
81                        {"type": "string", "pattern": "^[0-9]+$"}
82                    ],
83                    "description": "Timeout in milliseconds. Accepts either an integer or a numeric string. Default: 120000 (2 minutes). Maximum: 600000 (10 minutes); larger values are clamped."
84                }
85            },
86            "required": ["command"]
87        })
88    }
89
90    async fn execute(&self, _ctx: &ToolContext<Ctx>, input: Value) -> Result<ToolResult> {
91        // Headroom reserved for the truncation notice appended below, so the
92        // final result stays within `MAX_OUTPUT_BYTES`.
93        const TRUNCATION_NOTICE_RESERVE: usize = 80;
94
95        let input: BashInput = BashInput::deserialize(&input)
96            .with_context(|| format!("Invalid input for bash tool: {input}"))?;
97
98        // Check exec capability and command allow/deny rules
99        if let Err(reason) = self.ctx.capabilities.check_exec(&input.command) {
100            return Ok(ToolResult::error(format!(
101                "Permission denied: cannot execute '{}': {reason}",
102                truncate_command(&input.command, 100)
103            )));
104        }
105
106        // Validate timeout. Requests above the maximum are clamped; the clamp is
107        // surfaced in the result so the model knows its value was reduced.
108        let requested_timeout_ms = input.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS);
109        let timeout_ms = requested_timeout_ms.min(MAX_TIMEOUT_MS);
110
111        // Execute command
112        let result = self
113            .ctx
114            .environment
115            .exec(&input.command, Some(timeout_ms))
116            .await
117            .context("Failed to execute command")?;
118
119        // Format output
120        let mut output = String::new();
121
122        if !result.stdout.is_empty() {
123            output.push_str(&result.stdout);
124        }
125
126        if !result.stderr.is_empty() {
127            if !output.is_empty() {
128                output.push_str("\n\n--- stderr ---\n");
129            }
130            output.push_str(&result.stderr);
131        }
132
133        if output.is_empty() {
134            output = "(no output)".to_string();
135        }
136
137        // Truncate if too long (UTF-8 safe). Reserve room for the exit-code line
138        // and the truncation notice so the final result stays within the cap.
139        let exit_suffix = format!("\n\nExit code: {}", result.exit_code);
140        let content_budget =
141            MAX_OUTPUT_BYTES.saturating_sub(exit_suffix.len() + TRUNCATION_NOTICE_RESERVE);
142        if output.len() > content_budget {
143            let original_len = output.len();
144            let truncated = super::truncate_str(&output, content_budget);
145            output = format!("{truncated}...\n\n(output truncated, {original_len} total bytes)");
146        }
147
148        // Include exit code in output
149        output.push_str(&exit_suffix);
150
151        // Tell the model when its requested timeout was reduced to the maximum.
152        if requested_timeout_ms > MAX_TIMEOUT_MS {
153            let _ = write!(
154                output,
155                "\n\n(requested timeout {requested_timeout_ms}ms exceeds the maximum of {MAX_TIMEOUT_MS}ms; clamped to {MAX_TIMEOUT_MS}ms)"
156            );
157        }
158
159        let tool_result = if result.success() {
160            ToolResult::success(output)
161        } else {
162            ToolResult::error(output)
163        };
164
165        Ok(tool_result)
166    }
167}
168
169fn truncate_command(s: &str, max_len: usize) -> String {
170    if s.len() <= max_len {
171        s.to_string()
172    } else {
173        format!("{}...", super::truncate_str(s, max_len))
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AgentCapabilities;
    use crate::environment::ExecResult;
    use async_trait::async_trait;
    use std::collections::HashMap;
    use std::sync::RwLock;

    /// Test double for [`crate::Environment`] in which only `exec` does real work.
    ///
    /// Registered commands replay a canned `(stdout, stderr, exit_code)`.
    /// Unregistered commands report "command not found" with exit code 127.
    /// The timeout passed to each `exec` call is recorded, so tests can assert
    /// that default, parsed and clamped values reach the environment.
    struct MockBashEnvironment {
        root: String,
        // Map of command to (stdout, stderr, exit_code)
        commands: RwLock<HashMap<String, (String, String, i32)>>,
        // Timeout forwarded to the most recent `exec` call.
        last_timeout_ms: RwLock<Option<u64>>,
    }

    impl MockBashEnvironment {
        fn new() -> Self {
            Self {
                root: "/workspace".to_string(),
                commands: RwLock::new(HashMap::new()),
                last_timeout_ms: RwLock::new(None),
            }
        }

        /// Registers the canned result returned when `cmd` is executed.
        fn add_command(&self, cmd: &str, stdout: &str, stderr: &str, exit_code: i32) -> Result<()> {
            self.commands.write().ok().context("lock poisoned")?.insert(
                cmd.to_string(),
                (stdout.to_string(), stderr.to_string(), exit_code),
            );
            Ok(())
        }

        /// Timeout received by the most recent `exec` call, if any.
        fn recorded_timeout(&self) -> Result<Option<u64>> {
            Ok(*self.last_timeout_ms.read().ok().context("lock poisoned")?)
        }
    }

    #[async_trait]
    impl crate::Environment for MockBashEnvironment {
        async fn read_file(&self, _path: &str) -> Result<String> {
            Ok(String::new())
        }

        async fn read_file_bytes(&self, _path: &str) -> Result<Vec<u8>> {
            Ok(vec![])
        }

        async fn write_file(&self, _path: &str, _content: &str) -> Result<()> {
            Ok(())
        }

        async fn write_file_bytes(&self, _path: &str, _content: &[u8]) -> Result<()> {
            Ok(())
        }

        async fn list_dir(&self, _path: &str) -> Result<Vec<crate::environment::FileEntry>> {
            Ok(vec![])
        }

        async fn exists(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_dir(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_file(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn create_dir(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_file(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_dir(&self, _path: &str, _recursive: bool) -> Result<()> {
            Ok(())
        }

        async fn grep(
            &self,
            _pattern: &str,
            _path: &str,
            _recursive: bool,
        ) -> Result<Vec<crate::environment::GrepMatch>> {
            Ok(vec![])
        }

        async fn glob(&self, _pattern: &str) -> Result<Vec<String>> {
            Ok(vec![])
        }

        async fn exec(&self, command: &str, timeout_ms: Option<u64>) -> Result<ExecResult> {
            *self.last_timeout_ms.write().ok().context("lock poisoned")? = timeout_ms;
            let commands = self.commands.read().ok().context("lock poisoned")?;
            if let Some((stdout, stderr, exit_code)) = commands.get(command) {
                Ok(ExecResult {
                    stdout: stdout.clone(),
                    stderr: stderr.clone(),
                    exit_code: *exit_code,
                })
            } else {
                // Default: command not found
                Ok(ExecResult {
                    stdout: String::new(),
                    stderr: format!("command not found: {command}"),
                    exit_code: 127,
                })
            }
        }

        fn root(&self) -> &str {
            &self.root
        }
    }

    fn create_test_tool(
        env: Arc<MockBashEnvironment>,
        capabilities: AgentCapabilities,
    ) -> BashTool<MockBashEnvironment> {
        BashTool::new(env, capabilities)
    }

    fn tool_ctx() -> ToolContext<()> {
        ToolContext::new(())
    }

    // ===================
    // Unit Tests
    // ===================

    #[tokio::test]
    async fn test_bash_simple_command() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hello", "hello\n", "", 0)?;

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "echo hello"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("hello"));
        assert!(result.output.contains("Exit code: 0"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_with_stderr() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cmd", "stdout output", "stderr output", 0)?;

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool.execute(&tool_ctx(), json!({"command": "cmd"})).await?;

        assert!(result.success);
        // stdout comes first; stderr follows a labelled separator.
        assert!(
            result
                .output
                .contains("stdout output\n\n--- stderr ---\nstderr output")
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_stderr_only_has_no_separator() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("warn_cmd", "", "warning: deprecated", 0)?;

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "warn_cmd"}))
            .await?;

        assert!(result.success);
        assert!(result.output.starts_with("warning: deprecated"));
        assert!(!result.output.contains("--- stderr ---"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_nonzero_exit() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("failing_cmd", "", "error occurred", 1)?;

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "failing_cmd"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 1"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_not_found() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "nonexistent_cmd"}))
            .await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 127"));
        Ok(())
    }

    // ===================
    // Integration Tests
    // ===================

    #[tokio::test]
    async fn test_bash_exec_disabled() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Read-only capabilities (exec disabled)
        let caps = AgentCapabilities::read_only();

        let tool = create_test_tool(env, caps);
        let result = tool.execute(&tool_ctx(), json!({"command": "ls"})).await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("execution is disabled"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_denied_commands() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Client configures denied commands
        let caps = AgentCapabilities::full_access()
            .with_denied_commands(vec![r"rm\s+-rf\s+/".into(), r"^sudo\s".into()]);

        let tool = create_test_tool(Arc::clone(&env), caps.clone());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "rm -rf /"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("denied pattern"));

        let tool = create_test_tool(env, caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "sudo apt-get install foo"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_allowed_commands_restriction() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cargo build", "Compiling...", "", 0)?;

        // Only allow cargo and git commands
        let caps = AgentCapabilities::full_access()
            .with_allowed_commands(vec![r"^cargo ".into(), r"^git ".into()]);

        let tool = create_test_tool(Arc::clone(&env), caps.clone());

        // cargo should be allowed
        let result = tool
            .execute(&tool_ctx(), json!({"command": "cargo build"}))
            .await?;
        assert!(result.success);

        // ls should be denied
        let tool = create_test_tool(env, caps);
        let result = tool
            .execute(&tool_ctx(), json!({"command": "ls -la"}))
            .await?;
        assert!(!result.success);
        assert!(result.output.contains("not in allowed list"));
        Ok(())
    }

    // ===================
    // Edge Cases
    // ===================

    #[tokio::test]
    async fn test_bash_empty_output() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("true", "", "", 0)?;

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "true"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("(no output)"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_custom_timeout() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0)?;

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "slow_cmd", "timeout_ms": 5000}),
            )
            .await?;

        assert!(result.success);
        // A timeout within the limit must not produce a clamp note.
        assert!(!result.output.contains("clamped"));
        Ok(())
    }

    // Metadata accessors are synchronous, so no async runtime is needed.
    #[test]
    fn test_bash_tool_metadata() {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        assert_eq!(Tool::<()>::name(&tool), PrimitiveToolName::Bash);
        assert_eq!(Tool::<()>::tier(&tool), ToolTier::Confirm);
        assert!(Tool::<()>::description(&tool).contains("Execute"));

        let schema = Tool::<()>::input_schema(&tool);
        assert!(schema.get("properties").is_some());
        assert!(schema["properties"].get("command").is_some());
        assert!(schema["properties"].get("timeout_ms").is_some());
    }

    #[tokio::test]
    async fn test_bash_invalid_input() {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        // Missing required command field
        let result = tool.execute(&tool_ctx(), json!({})).await;
        let Err(err) = result else {
            panic!("input without `command` must be rejected");
        };
        assert!(err.to_string().contains("Invalid input for bash tool"));
    }

    #[tokio::test]
    async fn test_bash_null_timeout_ms() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hello", "hello", "", 0)?;
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        // Model may send explicit null for optional fields — must not fail
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "echo hello", "timeout_ms": null}),
            )
            .await?;

        assert!(result.success);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_missing_timeout_uses_default() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hi", "hi", "", 0)?;
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        // Omitted timeout_ms should use the default
        let result = tool
            .execute(&tool_ctx(), json!({"command": "echo hi"}))
            .await?;

        assert!(result.success);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_string_timeout_ms() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo timeout", "ok", "", 0)?;
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "echo timeout", "timeout_ms": "5000"}),
            )
            .await?;

        assert!(result.success);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_long_output_truncated() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let long_output = "x".repeat(40_000);
        env.add_command("long_output_cmd", &long_output, "", 0)?;

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = tool
            .execute(&tool_ctx(), json!({"command": "long_output_cmd"}))
            .await?;

        assert!(result.success);
        assert!(result.output.contains("output truncated"));
        assert!(result.output.contains("total bytes"));
        // The exit code survives truncation and stays at the end.
        assert!(result.output.ends_with("Exit code: 0"));
        // The truncation notice and exit-code line are accounted for, so the
        // final result stays within the cap.
        assert!(result.output.len() <= MAX_OUTPUT_BYTES);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_default_timeout_forwarded() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hi", "hi", "", 0)?;

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        tool.execute(&tool_ctx(), json!({"command": "echo hi"}))
            .await?;

        // Omitted timeout forwards the default.
        assert_eq!(env.recorded_timeout()?, Some(DEFAULT_TIMEOUT_MS));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_explicit_timeout_forwarded() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hi", "hi", "", 0)?;

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        tool.execute(
            &tool_ctx(),
            json!({"command": "echo hi", "timeout_ms": 5000}),
        )
        .await?;

        assert_eq!(env.recorded_timeout()?, Some(5000));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_string_timeout_forwarded() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hi", "hi", "", 0)?;

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        tool.execute(
            &tool_ctx(),
            json!({"command": "echo hi", "timeout_ms": "5000"}),
        )
        .await?;

        // Numeric strings are parsed before being forwarded.
        assert_eq!(env.recorded_timeout()?, Some(5000));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_timeout_clamped_to_max() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hi", "hi", "", 0)?;

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "echo hi", "timeout_ms": 999_999_999_u64}),
            )
            .await?;

        // Oversized requests are clamped to the maximum and the clamp is surfaced.
        assert_eq!(env.recorded_timeout()?, Some(MAX_TIMEOUT_MS));
        assert!(result.output.contains("clamped"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_timeout_at_max_not_clamped() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hi", "hi", "", 0)?;

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = tool
            .execute(
                &tool_ctx(),
                json!({"command": "echo hi", "timeout_ms": MAX_TIMEOUT_MS}),
            )
            .await?;

        // Exactly the maximum is allowed as-is: forwarded unchanged, no note.
        assert_eq!(env.recorded_timeout()?, Some(MAX_TIMEOUT_MS));
        assert!(!result.output.contains("clamped"));
        Ok(())
    }

    // Pure string helper; no async runtime needed.
    #[test]
    fn test_truncate_command_function() {
        assert_eq!(truncate_command("short", 10), "short");
        assert_eq!(
            truncate_command("this is a longer command", 10),
            "this is a ..."
        );
    }
}