Skip to main content

astrid_tools/
bash.rs

1//! Bash tool — executes shell commands with persistent working directory.
2
3use crate::{BuiltinTool, ToolContext, ToolError, ToolResult};
4use serde_json::Value;
5use std::path::PathBuf;
6use tokio::process::Command;
7use uuid::Uuid;
8
/// Timeout used when the caller supplies none: two minutes, expressed in milliseconds.
const DEFAULT_TIMEOUT_MS: u64 = 2 * 60 * 1_000;
/// Ceiling applied to any caller-supplied timeout: ten minutes, expressed in milliseconds.
const MAX_TIMEOUT_MS: u64 = 10 * 60 * 1_000;
13
/// Built-in tool for executing bash commands.
///
/// The type itself carries no state: the persistent working directory is read from and
/// written back to the `ToolContext` passed to each invocation.
#[derive(Debug, Clone, Copy)]
pub struct BashTool;
16
17#[async_trait::async_trait]
18impl BuiltinTool for BashTool {
19    fn name(&self) -> &'static str {
20        "bash"
21    }
22
23    fn description(&self) -> &'static str {
24        "Executes a bash command. The working directory persists between invocations. \
25         Use for git, npm, cargo, docker, and other terminal operations. \
26         Optional timeout in milliseconds (max 600000)."
27    }
28
29    fn input_schema(&self) -> Value {
30        serde_json::json!({
31            "type": "object",
32            "properties": {
33                "command": {
34                    "type": "string",
35                    "description": "The bash command to execute"
36                },
37                "timeout": {
38                    "type": "integer",
39                    "description": "Timeout in milliseconds (default: 120000, max: 600000)"
40                }
41            },
42            "required": ["command"]
43        })
44    }
45
46    async fn execute(&self, args: Value, ctx: &ToolContext) -> ToolResult {
47        let command = args
48            .get("command")
49            .and_then(Value::as_str)
50            .ok_or_else(|| ToolError::InvalidArguments("command is required".into()))?;
51
52        let timeout_ms = args
53            .get("timeout")
54            .and_then(Value::as_u64)
55            .unwrap_or(DEFAULT_TIMEOUT_MS)
56            .min(MAX_TIMEOUT_MS);
57
58        let cwd = ctx.cwd.read().await.clone();
59
60        // Use a per-invocation UUID as the CWD sentinel. A fixed sentinel string
61        // could be injected by command output (e.g. `echo __ASTRID_CWD__; echo /evil`),
62        // poisoning the shared `cwd`. A UUID is unguessable, so command output cannot
63        // replicate it and claim a false working directory.
64        let sentinel = Uuid::new_v4().to_string();
65        let wrapped = format!(
66            "{command}\n__ASTRID_EXIT__=$?\necho \"{sentinel}\"\npwd\nexit $__ASTRID_EXIT__"
67        );
68
69        let result = tokio::time::timeout(
70            std::time::Duration::from_millis(timeout_ms),
71            run_bash(&wrapped, &cwd),
72        )
73        .await;
74
75        match result {
76            Ok(Ok((stdout, stderr, exit_code))) => {
77                // Parse stdout: split on sentinel to get output and new cwd
78                let (output, new_cwd) = parse_sentinel_output(&stdout, &sentinel);
79
80                // Update persistent cwd
81                if let Some(new_cwd) = new_cwd {
82                    let mut cwd_lock = ctx.cwd.write().await;
83                    *cwd_lock = new_cwd;
84                }
85
86                let mut result_text = String::new();
87
88                if !output.is_empty() {
89                    result_text.push_str(&output);
90                }
91
92                if !stderr.is_empty() {
93                    if !result_text.is_empty() {
94                        result_text.push('\n');
95                    }
96                    result_text.push_str("STDERR:\n");
97                    result_text.push_str(&stderr);
98                }
99
100                if exit_code != 0 {
101                    if !result_text.is_empty() {
102                        result_text.push('\n');
103                    }
104                    result_text.push_str("(exit code: ");
105                    result_text.push_str(&exit_code.to_string());
106                    result_text.push(')');
107                }
108
109                if result_text.is_empty() {
110                    result_text.push_str("(no output)");
111                }
112
113                Ok(result_text)
114            },
115            Ok(Err(e)) => Err(ToolError::ExecutionFailed(e.to_string())),
116            Err(_) => Err(ToolError::Timeout(timeout_ms)),
117        }
118    }
119}
120
121/// Run a bash command and capture stdout, stderr, and exit code.
122async fn run_bash(command: &str, cwd: &std::path::Path) -> std::io::Result<(String, String, i32)> {
123    let output = Command::new("bash")
124        .arg("-c")
125        .arg(command)
126        .current_dir(cwd)
127        .output()
128        .await?;
129
130    let stdout = String::from_utf8_lossy(&output.stdout).to_string();
131    let stderr = String::from_utf8_lossy(&output.stderr).to_string();
132    let exit_code = output.status.code().unwrap_or(-1);
133
134    Ok((stdout, stderr, exit_code))
135}
136
/// Split captured stdout into the command's own output and the working directory
/// reported by the shell wrapper.
///
/// The wrapper built in `BashTool::execute` prints, after the user command, `sentinel`
/// followed by a newline, and then the output of `pwd`. When that trailer is found,
/// this returns:
/// - the text before the sentinel, with trailing whitespace removed;
/// - `Some(dir)`, where `dir` is the line right after the sentinel, taken verbatim.
///
/// When no valid trailer is present, the full `stdout` is returned unchanged, paired with
/// `None`. This happens, for example, when the command called `exit` before the trailer ran.
///
/// `sentinel` is a per-invocation UUID. Data the command merely prints cannot contain it.
/// The running command can still observe it, though: `ps` shows the `bash -c` script, and
/// `$BASH_EXECUTION_STRING` holds it. Two guards keep such echoes from being mistaken for
/// the trailer:
/// - The LAST occurrence is used, because the wrapper prints its trailer after the command
///   has finished.
/// - An occurrence only counts if it ends its line and the following line is a rooted path,
///   which `pwd` always prints. A copy embedded mid-line (e.g. in a `ps` row) is rejected.
fn parse_sentinel_output(stdout: &str, sentinel: &str) -> (String, Option<PathBuf>) {
    let trailer = stdout.rsplit_once(sentinel).and_then(|(before, after)| {
        // The sentinel must end its line (`echo` terminates it with a newline).
        let rest = after
            .strip_prefix("\r\n")
            .or_else(|| after.strip_prefix('\n'))?;
        // The next line is the directory. It is not trimmed, because directory names may
        // legitimately begin or end with spaces.
        let cwd = PathBuf::from(rest.lines().next()?);
        cwd.has_root().then_some((before, cwd))
    });

    match trailer {
        Some((before, cwd)) => (before.trim_end().to_string(), Some(cwd)),
        None => (stdout.to_string(), None),
    }
}
157
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Stand-in for the per-invocation UUID, used by the parser tests.
    const FAKE_SENTINEL: &str = "550e8400-e29b-41d4-a716-446655440000";

    /// Creates a fresh tool context whose working directory starts at `root`.
    fn context_at(root: &std::path::Path) -> ToolContext {
        ToolContext::new(root.to_path_buf(), None)
    }

    /// Invokes the bash tool with the JSON arguments `args` against `ctx`.
    async fn invoke(args: Value, ctx: &ToolContext) -> ToolResult {
        BashTool.execute(args, ctx).await
    }

    #[tokio::test]
    async fn test_bash_echo() {
        let ctx = context_at(&std::env::temp_dir());
        let text = invoke(serde_json::json!({"command": "echo hello"}), &ctx)
            .await
            .unwrap();

        assert!(text.contains("hello"));
    }

    #[tokio::test]
    async fn test_bash_exit_code() {
        let ctx = context_at(&std::env::temp_dir());
        let text = invoke(serde_json::json!({"command": "exit 42"}), &ctx)
            .await
            .unwrap();

        assert!(text.contains("exit code: 42"));
    }

    #[tokio::test]
    async fn test_bash_stderr() {
        let ctx = context_at(&std::env::temp_dir());
        let text = invoke(serde_json::json!({"command": "echo error >&2"}), &ctx)
            .await
            .unwrap();

        assert!(text.contains("STDERR:"));
        assert!(text.contains("error"));
    }

    #[tokio::test]
    async fn test_bash_cwd_persistence() {
        let tmp = TempDir::new().unwrap();
        let ctx = context_at(tmp.path());
        std::fs::create_dir(tmp.path().join("subdir")).unwrap();

        // A `cd` in one invocation must be recorded in the shared context...
        invoke(serde_json::json!({"command": "cd subdir"}), &ctx)
            .await
            .unwrap();
        let recorded = ctx.cwd.read().await.clone();
        assert!(recorded.ends_with("subdir"));

        // ...and the following invocation must start from there.
        let text = invoke(serde_json::json!({"command": "pwd"}), &ctx)
            .await
            .unwrap();
        assert!(text.contains("subdir"));
    }

    #[tokio::test]
    async fn test_bash_timeout() {
        let ctx = context_at(&std::env::temp_dir());
        let outcome = invoke(
            serde_json::json!({"command": "sleep 10", "timeout": 100}),
            &ctx,
        )
        .await;

        assert!(matches!(outcome, Err(ToolError::Timeout(100))));
    }

    #[test]
    fn test_parse_sentinel_output() {
        let stdout = format!("hello world\n{FAKE_SENTINEL}\n/tmp/test\n");
        let (output, cwd) = parse_sentinel_output(&stdout, FAKE_SENTINEL);

        assert_eq!(output, "hello world");
        assert_eq!(cwd, Some(PathBuf::from("/tmp/test")));
    }

    #[test]
    fn test_parse_sentinel_no_sentinel() {
        let (output, cwd) = parse_sentinel_output("hello world\n", FAKE_SENTINEL);

        assert_eq!(output, "hello world\n");
        assert!(cwd.is_none());
    }

    #[test]
    fn test_parse_sentinel_injection_blocked() {
        // A command printing the old static marker plus a crafted path must not be able to
        // steer the cwd: only the per-invocation UUID delimits the real trailer.
        let injected = "__ASTRID_CWD__";
        let stdout = format!("{injected}\n/evil\n{FAKE_SENTINEL}\n/actual/cwd\n");
        let (output, cwd) = parse_sentinel_output(&stdout, FAKE_SENTINEL);

        // The injected marker is ordinary output text, not a cwd report.
        assert!(output.contains(injected));
        assert_eq!(cwd, Some(PathBuf::from("/actual/cwd")));
    }
}