// rs_adk/tools/bash_tool.rs

//! Bash execution tool — allows agents to execute shell commands.
//!
//! Mirrors ADK-Python's `ExecuteBashTool`. Provides policy-based command
//! validation. This module does not itself ask the user for confirmation
//! before execution; deployments should gate the tool behind the tool
//! confirmation mechanism.

use std::path::PathBuf;
use std::process::Stdio;
use std::time::Duration;

use async_trait::async_trait;

use crate::error::ToolError;
use crate::tool::ToolFunction;

/// Policy for allowed bash commands based on prefix matching.
///
/// Commands are executed through `sh -c`, so a plain prefix check is not
/// sufficient on its own: `ls; rm -rf /` starts with `ls` but runs a second
/// command. A *restricted* policy (one whose prefixes do not include `"*"`)
/// therefore also rejects any command containing shell control or
/// substitution characters. The wildcard policy allows everything.
#[derive(Debug, Clone)]
pub struct BashToolPolicy {
    /// Allowed command prefixes. Use `["*"]` to allow all commands.
    pub allowed_command_prefixes: Vec<String>,
}

impl Default for BashToolPolicy {
    /// A permissive policy that allows every command (`["*"]`).
    fn default() -> Self {
        Self {
            allowed_command_prefixes: vec!["*".into()],
        }
    }
}

impl BashToolPolicy {
    /// Check whether a command is allowed by this policy.
    ///
    /// The command is trimmed before any check is applied.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - the command is empty or whitespace-only;
    /// - the policy is restricted and the command contains a shell control
    ///   character (see `SHELL_CONTROL_CHARS` below);
    /// - the policy is restricted and no allowed prefix matches the command.
    pub fn validate(&self, command: &str) -> Result<(), String> {
        // Characters that let a single `sh -c` string chain (`;`, `&&`, `||`,
        // newline), pipe (`|`), redirect (`<`, `>`), or substitute (`$(...)`,
        // backticks) further commands — any of these defeats prefix matching.
        const SHELL_CONTROL_CHARS: &[char] = &[';', '&', '|', '`', '$', '<', '>', '\n', '\r'];

        let stripped = command.trim();
        if stripped.is_empty() {
            return Err("Command is required.".into());
        }

        if self.allowed_command_prefixes.iter().any(|p| p == "*") {
            return Ok(());
        }

        if let Some(bad) = stripped.chars().find(|c| SHELL_CONTROL_CHARS.contains(c)) {
            return Err(format!(
                "Command blocked. Shell control character {bad:?} is not allowed by this policy."
            ));
        }

        if self
            .allowed_command_prefixes
            .iter()
            .any(|prefix| stripped.starts_with(prefix.as_str()))
        {
            return Ok(());
        }

        Err(format!(
            "Command blocked. Permitted prefixes are: {}",
            self.allowed_command_prefixes.join(", ")
        ))
    }
}

54/// Tool that executes bash commands with policy-based validation.
55///
56/// Commands are validated against the configured policy before execution.
57/// In a real deployment, this tool should also require user confirmation
58/// via the tool confirmation mechanism.
59#[derive(Debug, Clone)]
60pub struct ExecuteBashTool {
61    /// Working directory for command execution.
62    workspace: PathBuf,
63    /// Command validation policy.
64    policy: BashToolPolicy,
65    /// Command execution timeout in seconds.
66    timeout_secs: u64,
67}
68
69impl ExecuteBashTool {
70    /// Create a new bash execution tool with the given workspace directory.
71    pub fn new(workspace: PathBuf) -> Self {
72        Self {
73            workspace,
74            policy: BashToolPolicy::default(),
75            timeout_secs: 30,
76        }
77    }
78
79    /// Set the command validation policy.
80    pub fn with_policy(mut self, policy: BashToolPolicy) -> Self {
81        self.policy = policy;
82        self
83    }
84
85    /// Set the execution timeout in seconds.
86    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
87        self.timeout_secs = timeout_secs;
88        self
89    }
90}
91
92#[async_trait]
93impl ToolFunction for ExecuteBashTool {
94    fn name(&self) -> &str {
95        "execute_bash"
96    }
97
98    fn description(&self) -> &str {
99        "Executes a bash command with the working directory set to the workspace. \
100         All commands require validation against the configured policy."
101    }
102
103    fn parameters(&self) -> Option<serde_json::Value> {
104        Some(serde_json::json!({
105            "type": "object",
106            "properties": {
107                "command": {
108                    "type": "string",
109                    "description": "The bash command to execute."
110                }
111            },
112            "required": ["command"]
113        }))
114    }
115
116    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
117        let command = args
118            .get("command")
119            .and_then(|v| v.as_str())
120            .ok_or_else(|| ToolError::InvalidArgs("Missing command".into()))?;
121
122        // Validate command against policy
123        if let Err(e) = self.policy.validate(command) {
124            return Ok(serde_json::json!({"error": e}));
125        }
126
127        // Execute command
128        let output = tokio::process::Command::new("sh")
129            .arg("-c")
130            .arg(command)
131            .current_dir(&self.workspace)
132            .stdout(Stdio::piped())
133            .stderr(Stdio::piped())
134            .output()
135            .await
136            .map_err(|e| ToolError::ExecutionFailed(format!("Failed to execute command: {e}")))?;
137
138        Ok(serde_json::json!({
139            "stdout": String::from_utf8_lossy(&output.stdout),
140            "stderr": String::from_utf8_lossy(&output.stderr),
141            "returncode": output.status.code().unwrap_or(-1)
142        }))
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a policy that permits exactly the given command prefixes.
    fn policy_with(prefixes: &[&str]) -> BashToolPolicy {
        BashToolPolicy {
            allowed_command_prefixes: prefixes.iter().map(|&p| p.to_owned()).collect(),
        }
    }

    /// A tool rooted at `/tmp` with default policy and timeout.
    fn tmp_tool() -> ExecuteBashTool {
        ExecuteBashTool::new(PathBuf::from("/tmp"))
    }

    #[test]
    fn policy_allows_all_by_default() {
        let permissive = BashToolPolicy::default();
        for cmd in ["ls -la", "echo hello"] {
            assert!(permissive.validate(cmd).is_ok());
        }
    }

    #[test]
    fn policy_blocks_unmatched_prefix() {
        let restricted = policy_with(&["ls", "echo"]);
        for allowed in ["ls -la", "echo hello"] {
            assert!(restricted.validate(allowed).is_ok());
        }
        assert!(restricted.validate("rm -rf /").is_err());
    }

    #[test]
    fn policy_rejects_empty_command() {
        let permissive = BashToolPolicy::default();
        for blank in ["", "  "] {
            assert!(permissive.validate(blank).is_err());
        }
    }

    #[test]
    fn tool_metadata() {
        let tool = tmp_tool();
        assert_eq!(tool.name(), "execute_bash");
        assert!(tool.parameters().is_some());
    }

    #[tokio::test]
    async fn execute_simple_command() {
        let tool = tmp_tool();
        let response = tool
            .call(serde_json::json!({"command": "echo hello"}))
            .await
            .unwrap();
        let stdout = response["stdout"].as_str().unwrap();
        assert_eq!(stdout.trim(), "hello");
        assert_eq!(response["returncode"], 0);
    }

    #[tokio::test]
    async fn blocked_command_returns_error() {
        let tool = tmp_tool().with_policy(policy_with(&["ls"]));
        let response = tool
            .call(serde_json::json!({"command": "rm -rf /"}))
            .await
            .unwrap();
        let message = response["error"].as_str().unwrap();
        assert!(message.contains("blocked"));
    }
}