Skip to main content

codetether_agent/tool/
bash.rs

1//! Bash tool: execute shell commands
2
3use super::sandbox::{SandboxPolicy, execute_sandboxed};
4use super::{Tool, ToolResult};
5use crate::audit::{AuditCategory, AuditOutcome, try_audit_log};
6use anyhow::Result;
7use async_trait::async_trait;
8use serde_json::{Value, json};
9use std::process::Stdio;
10use std::time::Instant;
11use tokio::process::Command;
12use tokio::time::{Duration, timeout};
13
14use crate::telemetry::{TOOL_EXECUTIONS, ToolExecution, record_persistent};
15
/// Tool that runs shell commands via `bash -c`, optionally inside the sandbox.
pub struct BashTool {
    /// Timeout in seconds applied when a call does not supply its own `timeout`.
    timeout_secs: u64,
    /// When true, execute commands through the sandbox with restricted env.
    sandboxed: bool,
}

impl BashTool {
    /// Timeout used by [`BashTool::new`] and [`BashTool::sandboxed`].
    const DEFAULT_TIMEOUT_SECS: u64 = 120;

    /// Creates a tool with the default 120-second timeout.
    ///
    /// Sandboxing is switched on when the `CODETETHER_SANDBOX_BASH` environment
    /// variable equals `1` or `true` (case-insensitive); any other value, or an
    /// unset/unreadable variable, leaves it off.
    pub fn new() -> Self {
        let sandboxed = match std::env::var("CODETETHER_SANDBOX_BASH") {
            Ok(flag) => flag == "1" || flag.eq_ignore_ascii_case("true"),
            Err(_) => false,
        };
        Self {
            timeout_secs: Self::DEFAULT_TIMEOUT_SECS,
            sandboxed,
        }
    }

    /// Creates an unsandboxed tool with a custom default timeout.
    ///
    /// Unlike [`BashTool::new`], this does not consult `CODETETHER_SANDBOX_BASH`.
    // Public constructor that may have no caller in this build.
    #[allow(dead_code)]
    pub fn with_timeout(timeout_secs: u64) -> Self {
        Self {
            sandboxed: false,
            timeout_secs,
        }
    }

    /// Creates a tool that always runs commands through the sandbox, with the
    /// default 120-second timeout.
    // Public constructor that may have no caller in this build.
    #[allow(dead_code)]
    pub fn sandboxed() -> Self {
        Self {
            sandboxed: true,
            timeout_secs: Self::DEFAULT_TIMEOUT_SECS,
        }
    }
}
52
/// Heuristically detects commands likely to block on an interactive credential
/// prompt: `sudo` without non-interactive mode, SSH-family tools without
/// `BatchMode=yes`, `su`, `passwd`, and `ssh-add`.
///
/// The lower-cased command is split into segments at shell separators (`;`, `&`,
/// `|`, newline, `(` and backtick). The first word of each segment is treated as
/// the program being run, so chained commands (`cd x && sudo ...`) are checked
/// too. Shell quoting is not understood; this is a best-effort check used for
/// warnings only.
///
/// Returns a human-readable reason for the first risk found, or `None`.
fn interactive_auth_risk_reason(command: &str) -> Option<&'static str> {
    let lower = command.to_ascii_lowercase();
    let segments: Vec<&str> = lower
        .split(|c: char| matches!(c, ';' | '&' | '|' | '\n' | '(' | '`'))
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .collect();
    let programs: Vec<&str> = segments.iter().map(|segment| first_word(segment)).collect();

    // Each sudo invocation is checked on its own, so `sudo -n a; sudo b` still warns.
    let interactive_sudo = segments
        .iter()
        .any(|segment| first_word(segment) == "sudo" && !sudo_is_non_interactive(segment));
    if interactive_sudo {
        return Some("Command uses sudo without non-interactive mode (-n).");
    }

    let has_ssh_family = programs
        .iter()
        .any(|program| matches!(*program, "ssh" | "scp" | "sftp" | "rsync"));
    if has_ssh_family && !lower.contains("batchmode=yes") {
        return Some(
            "SSH-family command may prompt for password/passphrase (missing -o BatchMode=yes).",
        );
    }

    let has_prompting_program = programs
        .iter()
        .any(|program| matches!(*program, "su" | "passwd"));
    if has_prompting_program || lower.contains("ssh-add") {
        return Some("Command is interactive and may require a password prompt.");
    }

    None
}

/// Returns the first whitespace-separated word of `segment`, or `""` if it has none.
fn first_word(segment: &str) -> &str {
    segment.split_whitespace().next().unwrap_or("")
}

/// Returns true when the option words directly following `sudo` in `segment`
/// include `--non-interactive` or a short-option cluster containing `n`
/// (e.g. `-n`, `-En`).
///
/// Scanning stops at the first non-option word. Options that take a separate
/// argument (e.g. `-u user`) therefore hide any later `-n`.
fn sudo_is_non_interactive(segment: &str) -> bool {
    segment
        .split_whitespace()
        .skip(1)
        .take_while(|word| word.starts_with('-'))
        .any(|word| {
            // `word` starts with the one-byte '-', so `word[1..]` is a valid slice.
            word == "--non-interactive" || (!word.starts_with("--") && word[1..].contains('n'))
        })
}
91
/// Returns `true` when `output` (compared case-insensitively) contains text that
/// typically appears when a program asks for a password or passphrase, or fails
/// because it could not prompt for one.
fn looks_like_auth_prompt(output: &str) -> bool {
    const PROMPT_MARKERS: [&str; 7] = [
        "[sudo] password for",
        "password:",
        "passphrase",
        "no tty present and no askpass program specified",
        "a terminal is required to read the password",
        "could not read password",
        "permission denied (publickey,password",
    ];

    let haystack = output.to_ascii_lowercase();
    for &marker in PROMPT_MARKERS.iter() {
        if haystack.contains(marker) {
            return true;
        }
    }
    false
}
106
107#[async_trait]
108impl Tool for BashTool {
109    fn id(&self) -> &str {
110        "bash"
111    }
112
113    fn name(&self) -> &str {
114        "Bash"
115    }
116
117    fn description(&self) -> &str {
118        "bash(command: string, cwd?: string, timeout?: int) - Execute a shell command. Commands run in a bash shell with the current working directory."
119    }
120
121    fn parameters(&self) -> Value {
122        json!({
123            "type": "object",
124            "properties": {
125                "command": {
126                    "type": "string",
127                    "description": "The shell command to execute"
128                },
129                "cwd": {
130                    "type": "string",
131                    "description": "Working directory for the command (optional)"
132                },
133                "timeout": {
134                    "type": "integer",
135                    "description": "Timeout in seconds (default: 120)"
136                }
137            },
138            "required": ["command"],
139            "example": {
140                "command": "ls -la src/",
141                "cwd": "/path/to/project"
142            }
143        })
144    }
145
146    async fn execute(&self, args: Value) -> Result<ToolResult> {
147        let exec_start = Instant::now();
148
149        let command = match args["command"].as_str() {
150            Some(c) => c,
151            None => {
152                return Ok(ToolResult::structured_error(
153                    "INVALID_ARGUMENT",
154                    "bash",
155                    "command is required",
156                    Some(vec!["command"]),
157                    Some(json!({"command": "ls -la", "cwd": "."})),
158                ));
159            }
160        };
161        let cwd = args["cwd"].as_str();
162        let timeout_secs = args["timeout"].as_u64().unwrap_or(self.timeout_secs);
163
164        if let Some(reason) = interactive_auth_risk_reason(command) {
165            // Log warning but don't block anymore per user request
166            tracing::warn!("Interactive auth risk detected: {}", reason);
167        }
168
169        // Sandboxed execution path: restricted env, resource limits, audit logged
170        if self.sandboxed {
171            let policy = SandboxPolicy {
172                allowed_paths: cwd
173                    .map(|d| vec![std::path::PathBuf::from(d)])
174                    .unwrap_or_default(),
175                allow_network: false,
176                allow_exec: true,
177                timeout_secs,
178                ..SandboxPolicy::default()
179            };
180            let work_dir = cwd.map(std::path::Path::new);
181            let sandbox_result = execute_sandboxed(
182                "bash",
183                &["-c".to_string(), command.to_string()],
184                &policy,
185                work_dir,
186            )
187            .await;
188
189            // Audit log the sandboxed execution
190            if let Some(audit) = try_audit_log() {
191                let (outcome, detail) = match &sandbox_result {
192                    Ok(r) => (
193                        if r.success {
194                            AuditOutcome::Success
195                        } else {
196                            AuditOutcome::Failure
197                        },
198                        json!({
199                            "sandboxed": true,
200                            "exit_code": r.exit_code,
201                            "duration_ms": r.duration_ms,
202                            "violations": r.sandbox_violations,
203                        }),
204                    ),
205                    Err(e) => (
206                        AuditOutcome::Failure,
207                        json!({ "sandboxed": true, "error": e.to_string() }),
208                    ),
209                };
210                audit
211                    .log(
212                        AuditCategory::Sandbox,
213                        format!("bash:{}", &command[..command.len().min(80)]),
214                        outcome,
215                        None,
216                        Some(detail),
217                    )
218                    .await;
219            }
220
221            return match sandbox_result {
222                Ok(r) => {
223                    let duration = exec_start.elapsed();
224                    let exec = ToolExecution::start(
225                        "bash",
226                        json!({ "command": command, "sandboxed": true }),
227                    );
228                    let exec = if r.success {
229                        exec.complete_success(format!("exit_code={:?}", r.exit_code), duration)
230                    } else {
231                        exec.complete_error(format!("exit_code={:?}", r.exit_code), duration)
232                    };
233                    TOOL_EXECUTIONS.record(exec.success);
234                    let data = serde_json::json!({
235                        "tool": "bash",
236                        "command": command,
237                        "success": r.success,
238                        "exit_code": r.exit_code,
239                    });
240                    let _ = record_persistent("tool_execution", &data);
241
242                    Ok(ToolResult {
243                        output: r.output,
244                        success: r.success,
245                        metadata: [
246                            ("exit_code".to_string(), json!(r.exit_code)),
247                            ("sandboxed".to_string(), json!(true)),
248                            (
249                                "sandbox_violations".to_string(),
250                                json!(r.sandbox_violations),
251                            ),
252                        ]
253                        .into_iter()
254                        .collect(),
255                    })
256                }
257                Err(e) => {
258                    let duration = exec_start.elapsed();
259                    let exec = ToolExecution::start(
260                        "bash",
261                        json!({ "command": command, "sandboxed": true }),
262                    )
263                    .complete_error(e.to_string(), duration);
264                    TOOL_EXECUTIONS.record(exec.success);
265                    let data = serde_json::json!({
266                        "tool": "bash",
267                        "command": command,
268                        "success": false,
269                        "error": e.to_string(),
270                    });
271                    let _ = record_persistent("tool_execution", &data);
272                    Ok(ToolResult::error(format!("Sandbox error: {}", e)))
273                }
274            };
275        }
276
277        let mut cmd = Command::new("bash");
278        cmd.arg("-c")
279            .arg(command)
280            .stdin(Stdio::null())
281            .stdout(Stdio::piped())
282            .stderr(Stdio::piped())
283            .env("GIT_TERMINAL_PROMPT", "0")
284            .env("GCM_INTERACTIVE", "never")
285            .env("DEBIAN_FRONTEND", "noninteractive")
286            .env("SUDO_ASKPASS", "/bin/false")
287            .env("SSH_ASKPASS", "/bin/false");
288
289        if let Some(dir) = cwd {
290            cmd.current_dir(dir);
291        }
292
293        let result = timeout(Duration::from_secs(timeout_secs), cmd.output()).await;
294
295        match result {
296            Ok(Ok(output)) => {
297                let stdout = String::from_utf8_lossy(&output.stdout);
298                let stderr = String::from_utf8_lossy(&output.stderr);
299                let exit_code = output.status.code().unwrap_or(-1);
300
301                let combined = if stderr.is_empty() {
302                    stdout.to_string()
303                } else if stdout.is_empty() {
304                    stderr.to_string()
305                } else {
306                    format!("{}\n--- stderr ---\n{}", stdout, stderr)
307                };
308
309                let success = output.status.success();
310
311                if !success && looks_like_auth_prompt(&combined) {
312                    tracing::warn!("Interactive auth prompt detected in output");
313                }
314
315                // Truncate if too long
316                let max_len = 50_000;
317                let (output_str, truncated) = if combined.len() > max_len {
318                    let truncated_output = format!(
319                        "{}...\n[Output truncated, {} bytes total]",
320                        &combined[..max_len],
321                        combined.len()
322                    );
323                    (truncated_output, true)
324                } else {
325                    (combined.clone(), false)
326                };
327
328                let duration = exec_start.elapsed();
329
330                // Record telemetry
331                let exec = ToolExecution::start(
332                    "bash",
333                    json!({
334                        "command": command,
335                        "cwd": cwd,
336                        "timeout": timeout_secs,
337                    }),
338                );
339                let exec = if success {
340                    exec.complete_success(
341                        format!("exit_code={}, output_len={}", exit_code, combined.len()),
342                        duration,
343                    )
344                } else {
345                    exec.complete_error(
346                        format!(
347                            "exit_code={}: {}",
348                            exit_code,
349                            combined.lines().next().unwrap_or("(no output)")
350                        ),
351                        duration,
352                    )
353                };
354                TOOL_EXECUTIONS.record(exec.success);
355                let _ = record_persistent("tool_execution", &serde_json::to_value(&exec).unwrap_or_default());
356
357                Ok(ToolResult {
358                    output: output_str,
359                    success,
360                    metadata: [
361                        ("exit_code".to_string(), json!(exit_code)),
362                        ("truncated".to_string(), json!(truncated)),
363                    ]
364                    .into_iter()
365                    .collect(),
366                })
367            }
368            Ok(Err(e)) => {
369                let duration = exec_start.elapsed();
370                let exec = ToolExecution::start(
371                    "bash",
372                    json!({
373                        "command": command,
374                        "cwd": cwd,
375                    }),
376                )
377                .complete_error(format!("Failed to execute: {}", e), duration);
378                TOOL_EXECUTIONS.record(exec.success);
379                let _ = record_persistent("tool_execution", &serde_json::to_value(&exec).unwrap_or_default());
380
381                Ok(ToolResult::structured_error(
382                    "EXECUTION_FAILED",
383                    "bash",
384                    &format!("Failed to execute command: {}", e),
385                    None,
386                    Some(json!({"command": command})),
387                ))
388            }
389            Err(_) => {
390                let duration = exec_start.elapsed();
391                let exec = ToolExecution::start(
392                    "bash",
393                    json!({
394                        "command": command,
395                        "cwd": cwd,
396                    }),
397                )
398                .complete_error(format!("Timeout after {}s", timeout_secs), duration);
399                TOOL_EXECUTIONS.record(exec.success);
400                let _ = record_persistent("tool_execution", &serde_json::to_value(&exec).unwrap_or_default());
401
402                Ok(ToolResult::structured_error(
403                    "TIMEOUT",
404                    "bash",
405                    &format!("Command timed out after {} seconds", timeout_secs),
406                    None,
407                    Some(json!({
408                        "command": command,
409                        "hint": "Consider increasing timeout or breaking into smaller commands"
410                    })),
411                ))
412            }
413        }
414    }
415}
416
417impl Default for BashTool {
418    fn default() -> Self {
419        Self::new()
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool that always uses the sandbox, with the given default timeout.
    fn sandboxed_tool(timeout_secs: u64) -> BashTool {
        BashTool {
            sandboxed: true,
            timeout_secs,
        }
    }

    #[tokio::test]
    async fn sandboxed_bash_basic() {
        let outcome = sandboxed_tool(10)
            .execute(json!({ "command": "echo hello sandbox" }))
            .await
            .expect("execute returns Ok on every path");

        assert!(outcome.success);
        assert!(outcome.output.contains("hello sandbox"));
        assert_eq!(outcome.metadata.get("sandboxed"), Some(&json!(true)));
    }

    #[tokio::test]
    async fn sandboxed_bash_timeout() {
        let outcome = sandboxed_tool(1)
            .execute(json!({ "command": "sleep 30" }))
            .await
            .expect("execute returns Ok on every path");

        assert!(!outcome.success);
    }

    #[test]
    fn detects_interactive_auth_risk() {
        for risky in ["sudo apt update", "ssh user@host"] {
            assert!(
                interactive_auth_risk_reason(risky).is_some(),
                "expected a risk for {:?}",
                risky
            );
        }
        for safe in ["sudo -n apt update", "ssh -o BatchMode=yes user@host"] {
            assert!(
                interactive_auth_risk_reason(safe).is_none(),
                "expected no risk for {:?}",
                safe
            );
        }
    }

    #[test]
    fn detects_auth_prompt_output() {
        for prompt in [
            "[sudo] password for riley:",
            "sudo: a terminal is required to read the password",
        ] {
            assert!(looks_like_auth_prompt(prompt), "expected prompt in {:?}", prompt);
        }
        assert!(!looks_like_auth_prompt("command completed successfully"));
    }
}