rs_adk/tools/
bash_tool.rs1use std::path::PathBuf;
7use std::process::Stdio;
8
9use async_trait::async_trait;
10
11use crate::error::ToolError;
12use crate::tool::ToolFunction;
13
/// Allow-list policy deciding which shell commands may be executed.
#[derive(Debug, Clone)]
pub struct BashToolPolicy {
    /// Textual prefixes a command must start with to be accepted.
    ///
    /// An entry equal to `"*"` anywhere in the list permits every non-empty command.
    pub allowed_command_prefixes: Vec<String>,
}

impl Default for BashToolPolicy {
    /// Returns the permissive policy whose only entry is the `"*"` wildcard.
    fn default() -> Self {
        let allowed_command_prefixes = vec![String::from("*")];
        Self {
            allowed_command_prefixes,
        }
    }
}

impl BashToolPolicy {
    /// Checks `command` against the allow-list.
    ///
    /// Surrounding whitespace is trimmed before matching. Matching is a plain
    /// string-prefix test, so a prefix of `ls` also admits `lsof` or
    /// `ls; other-command`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the trimmed command is empty, or
    /// when the list contains no `"*"` entry and no entry is a prefix of the
    /// trimmed command.
    pub fn validate(&self, command: &str) -> Result<(), String> {
        let candidate = command.trim();
        if candidate.is_empty() {
            return Err(String::from("Command is required."));
        }

        // A wildcard entry or any matching prefix is sufficient to accept.
        let permitted = self
            .allowed_command_prefixes
            .iter()
            .any(|entry| entry == "*" || candidate.starts_with(entry.as_str()));

        if permitted {
            Ok(())
        } else {
            let listing = self.allowed_command_prefixes.join(", ");
            Err(format!(
                "Command blocked. Permitted prefixes are: {}",
                listing
            ))
        }
    }
}
53
54#[derive(Debug, Clone)]
60pub struct ExecuteBashTool {
61 workspace: PathBuf,
63 policy: BashToolPolicy,
65 timeout_secs: u64,
67}
68
69impl ExecuteBashTool {
70 pub fn new(workspace: PathBuf) -> Self {
72 Self {
73 workspace,
74 policy: BashToolPolicy::default(),
75 timeout_secs: 30,
76 }
77 }
78
79 pub fn with_policy(mut self, policy: BashToolPolicy) -> Self {
81 self.policy = policy;
82 self
83 }
84
85 pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
87 self.timeout_secs = timeout_secs;
88 self
89 }
90}
91
92#[async_trait]
93impl ToolFunction for ExecuteBashTool {
94 fn name(&self) -> &str {
95 "execute_bash"
96 }
97
98 fn description(&self) -> &str {
99 "Executes a bash command with the working directory set to the workspace. \
100 All commands require validation against the configured policy."
101 }
102
103 fn parameters(&self) -> Option<serde_json::Value> {
104 Some(serde_json::json!({
105 "type": "object",
106 "properties": {
107 "command": {
108 "type": "string",
109 "description": "The bash command to execute."
110 }
111 },
112 "required": ["command"]
113 }))
114 }
115
116 async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
117 let command = args
118 .get("command")
119 .and_then(|v| v.as_str())
120 .ok_or_else(|| ToolError::InvalidArgs("Missing command".into()))?;
121
122 if let Err(e) = self.policy.validate(command) {
124 return Ok(serde_json::json!({"error": e}));
125 }
126
127 let output = tokio::process::Command::new("sh")
129 .arg("-c")
130 .arg(command)
131 .current_dir(&self.workspace)
132 .stdout(Stdio::piped())
133 .stderr(Stdio::piped())
134 .output()
135 .await
136 .map_err(|e| ToolError::ExecutionFailed(format!("Failed to execute command: {e}")))?;
137
138 Ok(serde_json::json!({
139 "stdout": String::from_utf8_lossy(&output.stdout),
140 "stderr": String::from_utf8_lossy(&output.stderr),
141 "returncode": output.status.code().unwrap_or(-1)
142 }))
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a policy restricted to exactly the given prefixes.
    fn policy_with(prefixes: &[&str]) -> BashToolPolicy {
        BashToolPolicy {
            allowed_command_prefixes: prefixes.iter().map(|&p| String::from(p)).collect(),
        }
    }

    /// Builds a tool whose workspace is `/tmp`.
    fn tmp_tool() -> ExecuteBashTool {
        ExecuteBashTool::new(PathBuf::from("/tmp"))
    }

    // The default policy accepts arbitrary non-empty commands.
    #[test]
    fn policy_allows_all_by_default() {
        let open = BashToolPolicy::default();
        for cmd in ["ls -la", "echo hello"] {
            assert!(open.validate(cmd).is_ok());
        }
    }

    // Only commands starting with a listed prefix pass a restricted policy.
    #[test]
    fn policy_blocks_unmatched_prefix() {
        let restricted = policy_with(&["ls", "echo"]);
        for cmd in ["ls -la", "echo hello"] {
            assert!(restricted.validate(cmd).is_ok());
        }
        assert!(restricted.validate("rm -rf /").is_err());
    }

    // Empty and whitespace-only commands are rejected even by the default policy.
    #[test]
    fn policy_rejects_empty_command() {
        let open = BashToolPolicy::default();
        for cmd in ["", " "] {
            assert!(open.validate(cmd).is_err());
        }
    }

    // The tool exposes its name and a parameter schema.
    #[test]
    fn tool_metadata() {
        let tool = tmp_tool();
        assert_eq!(tool.name(), "execute_bash");
        assert!(tool.parameters().is_some());
    }

    // A permitted command runs and its output and exit code are reported.
    #[tokio::test]
    async fn execute_simple_command() {
        let args = serde_json::json!({"command": "echo hello"});
        let result = tmp_tool().call(args).await.unwrap();
        assert_eq!(result["stdout"].as_str().unwrap().trim(), "hello");
        assert_eq!(result["returncode"], 0);
    }

    // A command rejected by the policy yields an in-band `error` payload.
    #[tokio::test]
    async fn blocked_command_returns_error() {
        let tool = tmp_tool().with_policy(policy_with(&["ls"]));
        let args = serde_json::json!({"command": "rm -rf /"});
        let result = tool.call(args).await.unwrap();
        assert!(result["error"].as_str().unwrap().contains("blocked"));
    }
}