agent_sdk/primitive_tools/
bash.rs

1use crate::{Environment, Tool, ToolContext, ToolResult, ToolTier};
2use anyhow::{Context, Result};
3use async_trait::async_trait;
4use serde::Deserialize;
5use serde_json::{Value, json};
6use std::fmt::Write;
7use std::sync::Arc;
8
9use super::PrimitiveToolContext;
10
11/// Tool for executing shell commands
12pub struct BashTool<E: Environment> {
13    ctx: PrimitiveToolContext<E>,
14}
15
16impl<E: Environment> BashTool<E> {
17    #[must_use]
18    pub const fn new(environment: Arc<E>, capabilities: crate::AgentCapabilities) -> Self {
19        Self {
20            ctx: PrimitiveToolContext::new(environment, capabilities),
21        }
22    }
23}
24
25#[derive(Debug, Deserialize)]
26struct BashInput {
27    /// Command to execute
28    command: String,
29    /// Timeout in milliseconds (default: 120000 = 2 minutes)
30    #[serde(default = "default_timeout")]
31    timeout_ms: u64,
32}
33
/// Timeout, in milliseconds, used when the input omits `timeout_ms`.
const fn default_timeout() -> u64 {
    // Two minutes, spelled out as minutes * seconds * milliseconds.
    const TWO_MINUTES_MS: u64 = 2 * 60 * 1_000;
    TWO_MINUTES_MS
}
37
38#[async_trait]
39impl<E: Environment + 'static> Tool<()> for BashTool<E> {
40    fn name(&self) -> &'static str {
41        "bash"
42    }
43
44    fn description(&self) -> &'static str {
45        "Execute a shell command. Use for git, npm, cargo, and other CLI tools. Returns stdout, stderr, and exit code."
46    }
47
48    fn tier(&self) -> ToolTier {
49        ToolTier::Confirm
50    }
51
52    fn input_schema(&self) -> Value {
53        json!({
54            "type": "object",
55            "properties": {
56                "command": {
57                    "type": "string",
58                    "description": "The shell command to execute"
59                },
60                "timeout_ms": {
61                    "type": "integer",
62                    "description": "Timeout in milliseconds. Default: 120000 (2 minutes)"
63                }
64            },
65            "required": ["command"]
66        })
67    }
68
69    async fn execute(&self, _ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
70        let input: BashInput =
71            serde_json::from_value(input).context("Invalid input for bash tool")?;
72
73        // Check exec capability
74        if !self.ctx.capabilities.exec {
75            return Ok(ToolResult::error(
76                "Permission denied: command execution is disabled",
77            ));
78        }
79
80        // Check if command is allowed
81        if !self.ctx.capabilities.can_exec(&input.command) {
82            return Ok(ToolResult::error(format!(
83                "Permission denied: command '{}' is not allowed",
84                truncate_command(&input.command, 100)
85            )));
86        }
87
88        // Validate timeout
89        let timeout_ms = input.timeout_ms.min(600_000); // Max 10 minutes
90
91        // Execute command
92        let result = self
93            .ctx
94            .environment
95            .exec(&input.command, Some(timeout_ms))
96            .await
97            .context("Failed to execute command")?;
98
99        // Format output
100        let mut output = String::new();
101
102        if !result.stdout.is_empty() {
103            output.push_str(&result.stdout);
104        }
105
106        if !result.stderr.is_empty() {
107            if !output.is_empty() {
108                output.push_str("\n\n--- stderr ---\n");
109            }
110            output.push_str(&result.stderr);
111        }
112
113        if output.is_empty() {
114            output = "(no output)".to_string();
115        }
116
117        // Truncate if too long
118        let max_output_len = 30_000;
119        if output.len() > max_output_len {
120            output = format!(
121                "{}...\n\n(output truncated, {} total characters)",
122                &output[..max_output_len],
123                output.len()
124            );
125        }
126
127        // Include exit code in output
128        let _ = write!(output, "\n\nExit code: {}", result.exit_code);
129
130        if result.success() {
131            Ok(ToolResult::success(output))
132        } else {
133            Ok(ToolResult::error(output))
134        }
135    }
136}
137
/// Shortens `s` for display in messages.
///
/// Strings of at most `max_len` bytes are returned unchanged. Longer strings
/// are cut to at most `max_len` bytes, never splitting a UTF-8 character, and
/// `...` is appended.
fn truncate_command(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    // Slicing at a byte offset inside a multi-byte character panics, so back up
    // to the closest preceding boundary. Offset 0 is always a boundary, which
    // guarantees the loop terminates.
    let mut end = max_len;
    while !s.is_char_boundary(end) {
        end -= 1;
    }

    format!("{}...", &s[..end])
}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AgentCapabilities;
    use crate::environment::ExecResult;
    use std::collections::HashMap;
    use std::sync::RwLock;

    /// Scripted `Environment` for exercising `BashTool` without a real shell.
    ///
    /// Only `exec` has behavior; every file-system method returns an empty or
    /// negative answer.
    struct MockBashEnvironment {
        root: String,
        /// Scripted responses: command -> (stdout, stderr, exit_code).
        commands: RwLock<HashMap<String, (String, String, i32)>>,
        /// Timeout argument of every `exec` call, in call order.
        timeouts: RwLock<Vec<Option<u64>>>,
    }

    impl MockBashEnvironment {
        fn new() -> Self {
            Self {
                root: "/workspace".to_string(),
                commands: RwLock::new(HashMap::new()),
                timeouts: RwLock::new(Vec::new()),
            }
        }

        /// Registers the result `exec` should return for exactly `cmd`.
        fn add_command(&self, cmd: &str, stdout: &str, stderr: &str, exit_code: i32) {
            self.commands.write().unwrap().insert(
                cmd.to_string(),
                (stdout.to_string(), stderr.to_string(), exit_code),
            );
        }

        /// Timeouts seen by `exec` so far, oldest first.
        fn recorded_timeouts(&self) -> Vec<Option<u64>> {
            self.timeouts.read().unwrap().clone()
        }
    }

    #[async_trait]
    impl crate::Environment for MockBashEnvironment {
        async fn read_file(&self, _path: &str) -> Result<String> {
            Ok(String::new())
        }

        async fn read_file_bytes(&self, _path: &str) -> Result<Vec<u8>> {
            Ok(vec![])
        }

        async fn write_file(&self, _path: &str, _content: &str) -> Result<()> {
            Ok(())
        }

        async fn write_file_bytes(&self, _path: &str, _content: &[u8]) -> Result<()> {
            Ok(())
        }

        async fn list_dir(&self, _path: &str) -> Result<Vec<crate::environment::FileEntry>> {
            Ok(vec![])
        }

        async fn exists(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_dir(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn is_file(&self, _path: &str) -> Result<bool> {
            Ok(false)
        }

        async fn create_dir(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_file(&self, _path: &str) -> Result<()> {
            Ok(())
        }

        async fn delete_dir(&self, _path: &str, _recursive: bool) -> Result<()> {
            Ok(())
        }

        async fn grep(
            &self,
            _pattern: &str,
            _path: &str,
            _recursive: bool,
        ) -> Result<Vec<crate::environment::GrepMatch>> {
            Ok(vec![])
        }

        async fn glob(&self, _pattern: &str) -> Result<Vec<String>> {
            Ok(vec![])
        }

        async fn exec(&self, command: &str, timeout_ms: Option<u64>) -> Result<ExecResult> {
            // Remember the timeout so tests can assert on what the tool passed.
            self.timeouts.write().unwrap().push(timeout_ms);

            let commands = self.commands.read().unwrap();
            if let Some((stdout, stderr, exit_code)) = commands.get(command) {
                Ok(ExecResult {
                    stdout: stdout.clone(),
                    stderr: stderr.clone(),
                    exit_code: *exit_code,
                })
            } else {
                // Default: command not found
                Ok(ExecResult {
                    stdout: String::new(),
                    stderr: format!("command not found: {command}"),
                    exit_code: 127,
                })
            }
        }

        fn root(&self) -> &str {
            &self.root
        }
    }

    fn create_test_tool(
        env: Arc<MockBashEnvironment>,
        capabilities: AgentCapabilities,
    ) -> BashTool<MockBashEnvironment> {
        BashTool::new(env, capabilities)
    }

    fn tool_ctx() -> ToolContext<()> {
        ToolContext::new(())
    }

    /// Executes `tool` with `input` under a fresh tool context.
    async fn run(
        tool: &BashTool<MockBashEnvironment>,
        input: Value,
    ) -> anyhow::Result<ToolResult> {
        let ctx = tool_ctx();
        tool.execute(&ctx, input).await
    }

    // ===================
    // Command execution
    // ===================

    #[tokio::test]
    async fn test_bash_simple_command() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("echo hello", "hello\n", "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "echo hello"})).await?;

        assert!(result.success);
        assert!(result.output.contains("hello"));
        assert!(result.output.contains("Exit code: 0"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_with_stderr() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cmd", "stdout output", "stderr output", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "cmd"})).await?;

        assert!(result.success);
        assert!(result.output.contains("stdout output"));
        assert!(result.output.contains("stderr output"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_nonzero_exit() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("failing_cmd", "", "error occurred", 1);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "failing_cmd"})).await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 1"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_command_not_found() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "nonexistent_cmd"})).await?;

        assert!(!result.success);
        assert!(result.output.contains("Exit code: 127"));
        Ok(())
    }

    // ===================
    // Capability checks
    // ===================

    #[tokio::test]
    async fn test_bash_exec_disabled() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Read-only capabilities (exec disabled)
        let tool = create_test_tool(env, AgentCapabilities::read_only());
        let result = run(&tool, json!({"command": "ls"})).await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("execution is disabled"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_dangerous_command_denied() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        // Full access but default denied commands
        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "rm -rf /"})).await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        assert!(result.output.contains("not allowed"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_sudo_command_denied() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "sudo apt-get install foo"})).await?;

        assert!(!result.success);
        assert!(result.output.contains("Permission denied"));
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_allowed_commands_restriction() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cargo build", "Compiling...", "", 0);

        // Only allow cargo and git commands
        let caps = AgentCapabilities::full_access()
            .with_denied_commands(vec![])
            .with_allowed_commands(vec![r"^cargo ".into(), r"^git ".into()]);

        // One tool serves both checks; `execute` only borrows it.
        let tool = create_test_tool(env, caps);

        // cargo should be allowed
        let allowed = run(&tool, json!({"command": "cargo build"})).await?;
        assert!(allowed.success);

        // ls should be denied
        let denied = run(&tool, json!({"command": "ls -la"})).await?;
        assert!(!denied.success);
        assert!(denied.output.contains("not allowed"));
        Ok(())
    }

    // ===================
    // Timeouts
    // ===================

    #[tokio::test]
    async fn test_bash_custom_timeout() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("slow_cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "slow_cmd", "timeout_ms": 5000})).await?;

        assert!(result.success);
        // The requested timeout must reach the environment unchanged.
        assert_eq!(env.recorded_timeouts(), vec![Some(5000)]);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_default_timeout() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "cmd"})).await?;

        assert!(result.success);
        assert_eq!(env.recorded_timeouts(), vec![Some(120_000)]);
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_timeout_clamped_to_max() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("cmd", "done", "", 0);

        let tool = create_test_tool(Arc::clone(&env), AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "cmd", "timeout_ms": 3_600_000})).await?;

        assert!(result.success);
        // Anything above ten minutes is capped at ten minutes.
        assert_eq!(env.recorded_timeouts(), vec![Some(600_000)]);
        Ok(())
    }

    // ===================
    // Edge Cases
    // ===================

    #[tokio::test]
    async fn test_bash_empty_output() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        env.add_command("true", "", "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "true"})).await?;

        assert!(result.success);
        assert!(result.output.contains("(no output)"));
        Ok(())
    }

    // Nothing here awaits, so a plain synchronous test is enough.
    #[test]
    fn test_bash_tool_metadata() {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        assert_eq!(tool.name(), "bash");
        assert_eq!(tool.tier(), ToolTier::Confirm);
        assert!(tool.description().contains("Execute"));

        let schema = tool.input_schema();
        assert!(schema.get("properties").is_some());
        assert!(schema["properties"].get("command").is_some());
        assert!(schema["properties"].get("timeout_ms").is_some());
    }

    #[tokio::test]
    async fn test_bash_invalid_input() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let tool = create_test_tool(env, AgentCapabilities::full_access());

        // Missing required command field
        let result = run(&tool, json!({})).await;
        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_bash_long_output_truncated() -> anyhow::Result<()> {
        let env = Arc::new(MockBashEnvironment::new());
        let long_output = "x".repeat(40_000);
        env.add_command("long_output_cmd", &long_output, "", 0);

        let tool = create_test_tool(env, AgentCapabilities::full_access());
        let result = run(&tool, json!({"command": "long_output_cmd"})).await?;

        assert!(result.success);
        assert!(result.output.contains("output truncated"));
        assert!(result.output.len() < 35_000); // Should be truncated
        Ok(())
    }

    // Pure string helper; no runtime needed.
    #[test]
    fn test_truncate_command_function() {
        assert_eq!(truncate_command("short", 10), "short");
        assert_eq!(
            truncate_command("this is a longer command", 10),
            "this is a ..."
        );
    }
}