Skip to main content

hh_cli/tool/
bash.rs

1use crate::tool::{Tool, ToolResult, ToolSchema};
2use async_trait::async_trait;
3use serde_json::{Value, json};
4use tokio::process::Command;
5use tokio::time::{Duration, timeout};
6
/// Shell-command tool guarded by a substring denylist.
///
/// Any command containing one of the `denylist` fragments is refused before a
/// shell is ever spawned.
#[derive(Debug, Clone)]
pub struct BashTool {
    /// Command fragments that cause a command to be rejected.
    denylist: Vec<String>,
}
10
11impl Default for BashTool {
12    fn default() -> Self {
13        Self::new()
14    }
15}
16
17impl BashTool {
18    pub fn new() -> Self {
19        Self {
20            denylist: vec![
21                "rm -rf /".to_string(),
22                "mkfs".to_string(),
23                "shutdown".to_string(),
24                "reboot".to_string(),
25            ],
26        }
27    }
28
29    fn denied(&self, command: &str) -> bool {
30        self.denylist.iter().any(|needle| command.contains(needle))
31    }
32}
33
34#[async_trait]
35impl Tool for BashTool {
36    fn schema(&self) -> ToolSchema {
37        ToolSchema {
38            name: "bash".to_string(),
39            description: "Run a shell command".to_string(),
40            capability: Some("bash".to_string()),
41            mutating: Some(true),
42            parameters: json!({
43                "type": "object",
44                "properties": {
45                    "command": {"type": "string"},
46                    "timeout_ms": {"type": "integer", "minimum": 1}
47                },
48                "required": ["command"]
49            }),
50        }
51    }
52
53    async fn execute(&self, args: Value) -> ToolResult {
54        let command = args
55            .get("command")
56            .and_then(|v| v.as_str())
57            .unwrap_or_default();
58        let timeout_ms = args
59            .get("timeout_ms")
60            .and_then(|v| v.as_u64())
61            .unwrap_or(60_000);
62
63        if self.denied(command) {
64            return ToolResult::err_json(
65                "blocked",
66                json!({
67                    "status": "blocked",
68                    "ok": false,
69                    "command": command,
70                    "error": "command blocked by denylist"
71                }),
72            );
73        }
74
75        let fut = Command::new("sh").arg("-lc").arg(command).output();
76        let output = match timeout(Duration::from_millis(timeout_ms), fut).await {
77            Ok(Ok(out)) => out,
78            Ok(Err(err)) => {
79                return ToolResult::err_json(
80                    "spawn_error",
81                    json!({
82                        "status": "spawn_error",
83                        "ok": false,
84                        "command": command,
85                        "error": err.to_string()
86                    }),
87                );
88            }
89            Err(_) => {
90                return ToolResult::err_json(
91                    "timeout",
92                    json!({
93                        "status": "timeout",
94                        "ok": false,
95                        "command": command,
96                        "timeout_ms": timeout_ms,
97                        "error": format!("command timed out after {} ms", timeout_ms)
98                    }),
99                );
100            }
101        };
102
103        let stdout = String::from_utf8_lossy(&output.stdout);
104        let stderr = String::from_utf8_lossy(&output.stderr);
105        let combined = if stderr.trim().is_empty() {
106            stdout.to_string()
107        } else if stdout.trim().is_empty() {
108            stderr.to_string()
109        } else {
110            format!("{}\n{}", stdout, stderr)
111        };
112
113        let is_error = !output.status.success();
114        let status_text = if is_error { "error" } else { "success" };
115        let exit_code = output.status.code();
116
117        let payload = json!({
118            "status": status_text,
119            "ok": !is_error,
120            "command": command,
121            "exit_code": exit_code,
122            "stdout": stdout,
123            "stderr": stderr,
124            "output": combined
125        });
126
127        if is_error {
128            ToolResult::err_json("command_failed", payload)
129        } else {
130            ToolResult::ok_json("command_succeeded", payload)
131        }
132    }
133}