1use crate::{BuiltinTool, ToolContext, ToolError, ToolResult};
4use serde_json::Value;
5use std::path::PathBuf;
6use tokio::process::Command;
7use uuid::Uuid;
8
/// Timeout used when the caller does not pass one: two minutes, expressed in milliseconds.
const DEFAULT_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
/// Ceiling applied to any caller-supplied timeout: ten minutes, expressed in milliseconds.
const MAX_TIMEOUT_MS: u64 = 10 * 60 * 1_000;

/// Built-in tool that runs a command through `bash -c` and carries the shell's working
/// directory forward between invocations via the context's `cwd`.
pub struct BashTool;
16
17#[async_trait::async_trait]
18impl BuiltinTool for BashTool {
19 fn name(&self) -> &'static str {
20 "bash"
21 }
22
23 fn description(&self) -> &'static str {
24 "Executes a bash command. The working directory persists between invocations. \
25 Use for git, npm, cargo, docker, and other terminal operations. \
26 Optional timeout in milliseconds (max 600000)."
27 }
28
29 fn input_schema(&self) -> Value {
30 serde_json::json!({
31 "type": "object",
32 "properties": {
33 "command": {
34 "type": "string",
35 "description": "The bash command to execute"
36 },
37 "timeout": {
38 "type": "integer",
39 "description": "Timeout in milliseconds (default: 120000, max: 600000)"
40 }
41 },
42 "required": ["command"]
43 })
44 }
45
46 async fn execute(&self, args: Value, ctx: &ToolContext) -> ToolResult {
47 let command = args
48 .get("command")
49 .and_then(Value::as_str)
50 .ok_or_else(|| ToolError::InvalidArguments("command is required".into()))?;
51
52 let timeout_ms = args
53 .get("timeout")
54 .and_then(Value::as_u64)
55 .unwrap_or(DEFAULT_TIMEOUT_MS)
56 .min(MAX_TIMEOUT_MS);
57
58 let cwd = ctx.cwd.read().await.clone();
59
60 let sentinel = Uuid::new_v4().to_string();
65 let wrapped = format!(
66 "{command}\n__ASTRID_EXIT__=$?\necho \"{sentinel}\"\npwd\nexit $__ASTRID_EXIT__"
67 );
68
69 let result = tokio::time::timeout(
70 std::time::Duration::from_millis(timeout_ms),
71 run_bash(&wrapped, &cwd),
72 )
73 .await;
74
75 match result {
76 Ok(Ok((stdout, stderr, exit_code))) => {
77 let (output, new_cwd) = parse_sentinel_output(&stdout, &sentinel);
79
80 if let Some(new_cwd) = new_cwd {
82 let mut cwd_lock = ctx.cwd.write().await;
83 *cwd_lock = new_cwd;
84 }
85
86 let mut result_text = String::new();
87
88 if !output.is_empty() {
89 result_text.push_str(&output);
90 }
91
92 if !stderr.is_empty() {
93 if !result_text.is_empty() {
94 result_text.push('\n');
95 }
96 result_text.push_str("STDERR:\n");
97 result_text.push_str(&stderr);
98 }
99
100 if exit_code != 0 {
101 if !result_text.is_empty() {
102 result_text.push('\n');
103 }
104 result_text.push_str("(exit code: ");
105 result_text.push_str(&exit_code.to_string());
106 result_text.push(')');
107 }
108
109 if result_text.is_empty() {
110 result_text.push_str("(no output)");
111 }
112
113 Ok(result_text)
114 },
115 Ok(Err(e)) => Err(ToolError::ExecutionFailed(e.to_string())),
116 Err(_) => Err(ToolError::Timeout(timeout_ms)),
117 }
118 }
119}
120
121async fn run_bash(command: &str, cwd: &std::path::Path) -> std::io::Result<(String, String, i32)> {
123 let output = Command::new("bash")
124 .arg("-c")
125 .arg(command)
126 .current_dir(cwd)
127 .output()
128 .await?;
129
130 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
131 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
132 let exit_code = output.status.code().unwrap_or(-1);
133
134 Ok((stdout, stderr, exit_code))
135}
136
/// Splits the wrapper script's stdout into the user command's output and the shell's final
/// working directory.
///
/// The wrapper built in `BashTool::execute` prints `sentinel` and then `pwd` only after the
/// user command has run, so the authentic marker is the *last* occurrence of `sentinel`.
/// The split therefore happens at the last occurrence rather than the first. Otherwise a
/// command that echoes its own script text (e.g. `echo "$BASH_EXECUTION_STRING"`) would
/// have that echo mistaken for the marker. Its output would be truncated, and a bogus
/// relative directory (the literal word `pwd`) would be persisted as the cwd.
///
/// Returns `(output, cwd)`:
/// - When the sentinel is present, `output` is everything before it with trailing
///   whitespace removed. `cwd` is the first non-empty line after it, used verbatim, and
///   kept only if it is a rooted path, since `pwd` always prints one.
/// - When the sentinel is absent (e.g. the command called `exit` before the wrapper could
///   print it), `output` is `stdout` unchanged and `cwd` is `None`.
fn parse_sentinel_output(stdout: &str, sentinel: &str) -> (String, Option<PathBuf>) {
    let Some((before, after)) = stdout.rsplit_once(sentinel) else {
        return (stdout.to_string(), None);
    };

    let output = before.trim_end().to_string();

    // The line is taken verbatim, with no trimming, so directory names with leading or
    // trailing spaces survive. `lines()` already strips the `\n` / `\r\n` terminator.
    let new_cwd = after
        .lines()
        .find(|line| !line.is_empty())
        .map(PathBuf::from)
        .filter(|path| path.has_root());

    (output, new_cwd)
}
157
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Fixed sentinel used by the pure parsing tests.
    const SENTINEL: &str = "550e8400-e29b-41d4-a716-446655440000";

    /// Builds a `ToolContext` rooted at `root`, with the second `ToolContext::new` argument
    /// left as `None`.
    fn ctx_with_root(root: &std::path::Path) -> ToolContext {
        ToolContext::new(root.to_path_buf(), None)
    }

    /// Runs the bash tool with `args` against `ctx`, panicking if it returns an error.
    async fn run_ok(ctx: &ToolContext, args: serde_json::Value) -> String {
        BashTool.execute(args, ctx).await.unwrap()
    }

    #[tokio::test]
    async fn test_bash_echo() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let out = run_ok(&ctx, serde_json::json!({ "command": "echo hello" })).await;
        assert!(out.contains("hello"));
    }

    #[tokio::test]
    async fn test_bash_exit_code() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        // A non-zero status is reported inside the result text, not as a `ToolError`.
        let out = run_ok(&ctx, serde_json::json!({ "command": "exit 42" })).await;
        assert!(out.contains("exit code: 42"));
    }

    #[tokio::test]
    async fn test_bash_stderr() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let out = run_ok(&ctx, serde_json::json!({ "command": "echo error >&2" })).await;
        assert!(out.contains("STDERR:"));
        assert!(out.contains("error"));
    }

    #[tokio::test]
    async fn test_bash_cwd_persistence() {
        let dir = TempDir::new().unwrap();
        let ctx = ctx_with_root(dir.path());
        std::fs::create_dir(dir.path().join("subdir")).unwrap();

        // A `cd` in one call must show up both in the context and in the next call.
        run_ok(&ctx, serde_json::json!({ "command": "cd subdir" })).await;
        assert!(ctx.cwd.read().await.ends_with("subdir"));

        let out = run_ok(&ctx, serde_json::json!({ "command": "pwd" })).await;
        assert!(out.contains("subdir"));
    }

    #[tokio::test]
    async fn test_bash_timeout() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let outcome = BashTool
            .execute(
                serde_json::json!({ "command": "sleep 10", "timeout": 100 }),
                &ctx,
            )
            .await;
        assert!(matches!(outcome, Err(ToolError::Timeout(100))));
    }

    #[test]
    fn test_parse_sentinel_output() {
        let stdout = format!("hello world\n{SENTINEL}\n/tmp/test\n");
        let (output, cwd) = parse_sentinel_output(&stdout, SENTINEL);
        assert_eq!(output, "hello world");
        assert_eq!(cwd, Some(PathBuf::from("/tmp/test")));
    }

    #[test]
    fn test_parse_sentinel_no_sentinel() {
        let (output, cwd) = parse_sentinel_output("hello world\n", SENTINEL);
        assert_eq!(output, "hello world\n");
        assert!(cwd.is_none());
    }

    #[test]
    fn test_parse_sentinel_injection_blocked() {
        // A fixed, guessable marker printed by the command is ordinary output; only the
        // per-call sentinel delimits the working-directory line.
        let injected = "__ASTRID_CWD__";
        let stdout = format!("{injected}\n/evil\n{SENTINEL}\n/actual/cwd\n");
        let (output, cwd) = parse_sentinel_output(&stdout, SENTINEL);
        assert!(output.contains(injected));
        assert_eq!(cwd, Some(PathBuf::from("/actual/cwd")));
    }
}