1use crate::{BuiltinTool, ToolContext, ToolError, ToolResult};
4use serde_json::Value;
5use std::path::PathBuf;
6use tokio::process::Command;
7
/// Timeout used when the caller supplies no usable `timeout` argument (2 minutes).
const DEFAULT_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
/// Ceiling for a caller-supplied `timeout`; larger values are clamped down to it (10 minutes).
const MAX_TIMEOUT_MS: u64 = 10 * 60 * 1_000;
/// Marker echoed after the user's command so the trailing `pwd` output can be told apart
/// from the command's own stdout.
const CWD_SENTINEL: &str = "__ASTRID_CWD__";
14
15pub struct BashTool;
17
18#[async_trait::async_trait]
19impl BuiltinTool for BashTool {
20 fn name(&self) -> &'static str {
21 "bash"
22 }
23
24 fn description(&self) -> &'static str {
25 "Executes a bash command. The working directory persists between invocations. \
26 Use for git, npm, cargo, docker, and other terminal operations. \
27 Optional timeout in milliseconds (max 600000)."
28 }
29
30 fn input_schema(&self) -> Value {
31 serde_json::json!({
32 "type": "object",
33 "properties": {
34 "command": {
35 "type": "string",
36 "description": "The bash command to execute"
37 },
38 "timeout": {
39 "type": "integer",
40 "description": "Timeout in milliseconds (default: 120000, max: 600000)"
41 }
42 },
43 "required": ["command"]
44 })
45 }
46
47 async fn execute(&self, args: Value, ctx: &ToolContext) -> ToolResult {
48 let command = args
49 .get("command")
50 .and_then(Value::as_str)
51 .ok_or_else(|| ToolError::InvalidArguments("command is required".into()))?;
52
53 let timeout_ms = args
54 .get("timeout")
55 .and_then(Value::as_u64)
56 .unwrap_or(DEFAULT_TIMEOUT_MS)
57 .min(MAX_TIMEOUT_MS);
58
59 let cwd = ctx.cwd.read().await.clone();
60
61 let wrapped = format!(
63 "{command}\n__ASTRID_EXIT__=$?\necho \"{CWD_SENTINEL}\"\npwd\nexit $__ASTRID_EXIT__"
64 );
65
66 let result = tokio::time::timeout(
67 std::time::Duration::from_millis(timeout_ms),
68 run_bash(&wrapped, &cwd),
69 )
70 .await;
71
72 match result {
73 Ok(Ok((stdout, stderr, exit_code))) => {
74 let (output, new_cwd) = parse_sentinel_output(&stdout);
76
77 if let Some(new_cwd) = new_cwd {
79 let mut cwd_lock = ctx.cwd.write().await;
80 *cwd_lock = new_cwd;
81 }
82
83 let mut result_text = String::new();
84
85 if !output.is_empty() {
86 result_text.push_str(&output);
87 }
88
89 if !stderr.is_empty() {
90 if !result_text.is_empty() {
91 result_text.push('\n');
92 }
93 result_text.push_str("STDERR:\n");
94 result_text.push_str(&stderr);
95 }
96
97 if exit_code != 0 {
98 if !result_text.is_empty() {
99 result_text.push('\n');
100 }
101 result_text.push_str("(exit code: ");
102 result_text.push_str(&exit_code.to_string());
103 result_text.push(')');
104 }
105
106 if result_text.is_empty() {
107 result_text.push_str("(no output)");
108 }
109
110 Ok(result_text)
111 },
112 Ok(Err(e)) => Err(ToolError::ExecutionFailed(e.to_string())),
113 Err(_) => Err(ToolError::Timeout(timeout_ms)),
114 }
115 }
116}
117
118async fn run_bash(command: &str, cwd: &std::path::Path) -> std::io::Result<(String, String, i32)> {
120 let output = Command::new("bash")
121 .arg("-c")
122 .arg(command)
123 .current_dir(cwd)
124 .output()
125 .await?;
126
127 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
128 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
129 let exit_code = output.status.code().unwrap_or(-1);
130
131 Ok((stdout, stderr, exit_code))
132}
133
134fn parse_sentinel_output(stdout: &str) -> (String, Option<PathBuf>) {
136 if let Some(sentinel_pos) = stdout.find(CWD_SENTINEL) {
137 let output = stdout[..sentinel_pos].trim_end().to_string();
138 #[allow(clippy::arithmetic_side_effects)]
140 let after_sentinel = &stdout[sentinel_pos + CWD_SENTINEL.len()..];
141 let new_cwd = after_sentinel
142 .lines()
143 .find(|l| !l.is_empty())
144 .map(|l| PathBuf::from(l.trim()));
145 (output, new_cwd)
146 } else {
147 (stdout.to_string(), None)
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a tool context whose working directory starts at `root`.
    fn ctx_with_root(root: &std::path::Path) -> ToolContext {
        ToolContext::new(root.to_path_buf())
    }

    /// Invokes `BashTool` with `args` against `ctx`.
    async fn run(ctx: &ToolContext, args: Value) -> ToolResult {
        BashTool.execute(args, ctx).await
    }

    #[tokio::test]
    async fn test_bash_echo() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let text = run(&ctx, serde_json::json!({"command": "echo hello"}))
            .await
            .unwrap();
        assert!(text.contains("hello"));
    }

    #[tokio::test]
    async fn test_bash_exit_code() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let text = run(&ctx, serde_json::json!({"command": "exit 42"}))
            .await
            .unwrap();
        assert!(text.contains("exit code: 42"));
    }

    #[tokio::test]
    async fn test_bash_stderr() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let text = run(&ctx, serde_json::json!({"command": "echo error >&2"}))
            .await
            .unwrap();
        assert!(text.contains("STDERR:"));
        assert!(text.contains("error"));
    }

    #[tokio::test]
    async fn test_bash_cwd_persistence() {
        let tmp = TempDir::new().unwrap();
        std::fs::create_dir(tmp.path().join("subdir")).unwrap();
        let ctx = ctx_with_root(tmp.path());

        // A `cd` in one invocation must be recorded in the shared context...
        run(&ctx, serde_json::json!({"command": "cd subdir"}))
            .await
            .unwrap();
        let recorded = ctx.cwd.read().await.clone();
        assert!(recorded.ends_with("subdir"));

        // ...and be the starting directory of the next one.
        let text = run(&ctx, serde_json::json!({"command": "pwd"}))
            .await
            .unwrap();
        assert!(text.contains("subdir"));
    }

    #[tokio::test]
    async fn test_bash_timeout() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let outcome = run(
            &ctx,
            serde_json::json!({"command": "sleep 10", "timeout": 100}),
        )
        .await;
        assert!(matches!(outcome, Err(ToolError::Timeout(100))));
    }

    #[test]
    fn test_parse_sentinel_output() {
        let raw = format!("hello world\n{CWD_SENTINEL}\n/tmp/test\n");
        let (text, dir) = parse_sentinel_output(&raw);
        assert_eq!(text, "hello world");
        assert_eq!(dir, Some(PathBuf::from("/tmp/test")));
    }

    #[test]
    fn test_parse_sentinel_no_sentinel() {
        let (text, dir) = parse_sentinel_output("hello world\n");
        assert_eq!(text, "hello world\n");
        assert!(dir.is_none());
    }
}