Skip to main content

astrid_tools/
bash.rs

1//! Bash tool — executes shell commands with persistent working directory.
2
3use crate::{BuiltinTool, ToolContext, ToolError, ToolResult};
4use serde_json::Value;
5use std::path::PathBuf;
6use tokio::process::Command;
7
/// Marker printed on stdout right after the user's command finishes; the line
/// following it is the `pwd` output used to update the persistent cwd.
const CWD_SENTINEL: &str = "__ASTRID_CWD__";
/// Timeout used when the caller does not pass one (2 minutes, in ms).
const DEFAULT_TIMEOUT_MS: u64 = 120_000;
/// Ceiling applied to any caller-supplied timeout (10 minutes, in ms).
const MAX_TIMEOUT_MS: u64 = 600_000;
14
/// Built-in tool that runs a command through `bash -c`.
///
/// The tool itself holds no state. The working directory that persists between
/// invocations is read from and written back to the [`ToolContext`] passed to
/// each call.
#[derive(Debug, Clone, Copy)]
pub struct BashTool;
17
18#[async_trait::async_trait]
19impl BuiltinTool for BashTool {
20    fn name(&self) -> &'static str {
21        "bash"
22    }
23
24    fn description(&self) -> &'static str {
25        "Executes a bash command. The working directory persists between invocations. \
26         Use for git, npm, cargo, docker, and other terminal operations. \
27         Optional timeout in milliseconds (max 600000)."
28    }
29
30    fn input_schema(&self) -> Value {
31        serde_json::json!({
32            "type": "object",
33            "properties": {
34                "command": {
35                    "type": "string",
36                    "description": "The bash command to execute"
37                },
38                "timeout": {
39                    "type": "integer",
40                    "description": "Timeout in milliseconds (default: 120000, max: 600000)"
41                }
42            },
43            "required": ["command"]
44        })
45    }
46
47    async fn execute(&self, args: Value, ctx: &ToolContext) -> ToolResult {
48        let command = args
49            .get("command")
50            .and_then(Value::as_str)
51            .ok_or_else(|| ToolError::InvalidArguments("command is required".into()))?;
52
53        let timeout_ms = args
54            .get("timeout")
55            .and_then(Value::as_u64)
56            .unwrap_or(DEFAULT_TIMEOUT_MS)
57            .min(MAX_TIMEOUT_MS);
58
59        let cwd = ctx.cwd.read().await.clone();
60
61        // Wrap command with sentinel-based cwd tracking
62        let wrapped = format!(
63            "{command}\n__ASTRID_EXIT__=$?\necho \"{CWD_SENTINEL}\"\npwd\nexit $__ASTRID_EXIT__"
64        );
65
66        let result = tokio::time::timeout(
67            std::time::Duration::from_millis(timeout_ms),
68            run_bash(&wrapped, &cwd),
69        )
70        .await;
71
72        match result {
73            Ok(Ok((stdout, stderr, exit_code))) => {
74                // Parse stdout: split on sentinel to get output and new cwd
75                let (output, new_cwd) = parse_sentinel_output(&stdout);
76
77                // Update persistent cwd
78                if let Some(new_cwd) = new_cwd {
79                    let mut cwd_lock = ctx.cwd.write().await;
80                    *cwd_lock = new_cwd;
81                }
82
83                let mut result_text = String::new();
84
85                if !output.is_empty() {
86                    result_text.push_str(&output);
87                }
88
89                if !stderr.is_empty() {
90                    if !result_text.is_empty() {
91                        result_text.push('\n');
92                    }
93                    result_text.push_str("STDERR:\n");
94                    result_text.push_str(&stderr);
95                }
96
97                if exit_code != 0 {
98                    if !result_text.is_empty() {
99                        result_text.push('\n');
100                    }
101                    result_text.push_str("(exit code: ");
102                    result_text.push_str(&exit_code.to_string());
103                    result_text.push(')');
104                }
105
106                if result_text.is_empty() {
107                    result_text.push_str("(no output)");
108                }
109
110                Ok(result_text)
111            },
112            Ok(Err(e)) => Err(ToolError::ExecutionFailed(e.to_string())),
113            Err(_) => Err(ToolError::Timeout(timeout_ms)),
114        }
115    }
116}
117
118/// Run a bash command and capture stdout, stderr, and exit code.
119async fn run_bash(command: &str, cwd: &std::path::Path) -> std::io::Result<(String, String, i32)> {
120    let output = Command::new("bash")
121        .arg("-c")
122        .arg(command)
123        .current_dir(cwd)
124        .output()
125        .await?;
126
127    let stdout = String::from_utf8_lossy(&output.stdout).to_string();
128    let stderr = String::from_utf8_lossy(&output.stderr).to_string();
129    let exit_code = output.status.code().unwrap_or(-1);
130
131    Ok((stdout, stderr, exit_code))
132}
133
134/// Parse the sentinel from stdout to extract command output and new cwd.
135fn parse_sentinel_output(stdout: &str) -> (String, Option<PathBuf>) {
136    if let Some(sentinel_pos) = stdout.find(CWD_SENTINEL) {
137        let output = stdout[..sentinel_pos].trim_end().to_string();
138        // Safety: sentinel_pos comes from find() and CWD_SENTINEL.len() is within bounds
139        #[allow(clippy::arithmetic_side_effects)]
140        let after_sentinel = &stdout[sentinel_pos + CWD_SENTINEL.len()..];
141        let new_cwd = after_sentinel
142            .lines()
143            .find(|l| !l.is_empty())
144            .map(|l| PathBuf::from(l.trim()));
145        (output, new_cwd)
146    } else {
147        (stdout.to_string(), None)
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a context whose persistent cwd starts at `root`.
    fn ctx_with_root(root: &std::path::Path) -> ToolContext {
        ToolContext::new(root.to_path_buf())
    }

    /// Runs `command` through the tool and unwraps the successful result text.
    async fn run_ok(command: &str, ctx: &ToolContext) -> String {
        BashTool
            .execute(serde_json::json!({ "command": command }), ctx)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_bash_echo() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = run_ok("echo hello", &ctx).await;

        assert!(result.contains("hello"));
    }

    #[tokio::test]
    async fn test_bash_exit_code() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = run_ok("exit 42", &ctx).await;

        assert!(result.contains("exit code: 42"));
    }

    #[tokio::test]
    async fn test_bash_stderr() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = run_ok("echo error >&2", &ctx).await;

        assert!(result.contains("STDERR:"));
        assert!(result.contains("error"));
    }

    #[tokio::test]
    async fn test_bash_no_output() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = run_ok("true", &ctx).await;

        assert_eq!(result, "(no output)");
    }

    #[tokio::test]
    async fn test_bash_stdout_stderr_and_exit_code_are_joined() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        // `false` (not `exit`) so the sentinel is still printed and stdout is trimmed.
        let result = run_ok("echo out; echo err >&2; false", &ctx).await;

        assert!(result.starts_with("out\nSTDERR:\nerr"), "got: {result:?}");
        assert!(result.ends_with("(exit code: 1)"), "got: {result:?}");
    }

    #[tokio::test]
    async fn test_bash_missing_command() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = BashTool.execute(serde_json::json!({}), &ctx).await;

        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    #[tokio::test]
    async fn test_bash_cwd_persistence() {
        let dir = TempDir::new().unwrap();
        let ctx = ctx_with_root(dir.path());

        // Create a subdirectory and cd into it
        std::fs::create_dir(dir.path().join("subdir")).unwrap();
        run_ok("cd subdir", &ctx).await;

        // Verify cwd was updated
        let cwd = ctx.cwd.read().await.clone();
        assert!(cwd.ends_with("subdir"));

        // Next command should run in the new cwd
        let result = run_ok("pwd", &ctx).await;

        assert!(result.contains("subdir"));
    }

    #[tokio::test]
    async fn test_bash_cwd_persists_after_failing_command() {
        let dir = TempDir::new().unwrap();
        let ctx = ctx_with_root(dir.path());
        std::fs::create_dir(dir.path().join("subdir")).unwrap();

        let result = run_ok("cd subdir && false", &ctx).await;

        assert!(result.contains("exit code: 1"));
        let cwd = ctx.cwd.read().await.clone();
        assert!(cwd.ends_with("subdir"));
    }

    #[tokio::test]
    async fn test_bash_timeout() {
        let ctx = ctx_with_root(&std::env::temp_dir());
        let result = BashTool
            .execute(
                serde_json::json!({"command": "sleep 10", "timeout": 100}),
                &ctx,
            )
            .await;

        assert!(matches!(result, Err(ToolError::Timeout(100))));
    }

    #[test]
    fn test_parse_sentinel_output() {
        let stdout = format!("hello world\n{CWD_SENTINEL}\n/tmp/test\n");
        let (output, cwd) = parse_sentinel_output(&stdout);
        assert_eq!(output, "hello world");
        assert_eq!(cwd, Some(PathBuf::from("/tmp/test")));
    }

    #[test]
    fn test_parse_sentinel_no_sentinel() {
        let (output, cwd) = parse_sentinel_output("hello world\n");
        assert_eq!(output, "hello world\n");
        assert!(cwd.is_none());
    }
}