Skip to main content

rho_tools/
bash.rs

1// Bash tool — PTY command execution with timeout and output truncation
2
3use async_trait::async_trait;
4use portable_pty::{CommandBuilder, PtySize};
5use rho_core::tool::{AgentTool, ToolError};
6use rho_core::types::{Content, ToolResult};
7use serde_json::Value;
8use std::path::PathBuf;
9use tokio_util::sync::CancellationToken;
10
/// Timeout used when the caller supplies no usable `timeout` parameter (5 minutes).
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
/// Upper bound that any caller-supplied `timeout` is clamped to (1 hour).
const MAX_TIMEOUT_SECS: u64 = 60 * 60;
/// Output longer than this many bytes gets truncated (100 KiB).
const MAX_OUTPUT_BYTES: usize = 100 * 1024;
/// Number of bytes retained from each end of truncated output (10 KiB).
const TRUNCATION_EDGE: usize = 10 * 1024;
15
/// Agent tool that runs shell commands through `bash -c` inside a pseudo-terminal.
///
/// Every command is started with `working_dir` as its current directory.
#[derive(Debug, Clone)]
pub struct BashTool {
    /// Directory used as the current working directory of each spawned command.
    working_dir: PathBuf,
}

impl BashTool {
    /// Create a tool whose commands run in `working_dir`.
    pub fn new(working_dir: PathBuf) -> Self {
        Self { working_dir }
    }
}
25
26/// Truncate output that exceeds MAX_OUTPUT_BYTES, keeping the first and last
27/// TRUNCATION_EDGE bytes with a marker in the middle.
28fn truncate_output(output: &str) -> String {
29    if output.len() <= MAX_OUTPUT_BYTES {
30        return output.to_string();
31    }
32    let total = output.len();
33
34    // Find a valid UTF-8 boundary for the head slice
35    let head_end = {
36        let mut end = TRUNCATION_EDGE;
37        while end > 0 && !output.is_char_boundary(end) {
38            end -= 1;
39        }
40        end
41    };
42
43    // Find a valid UTF-8 boundary for the tail slice
44    let tail_start = {
45        let mut start = total.saturating_sub(TRUNCATION_EDGE);
46        while start < total && !output.is_char_boundary(start) {
47            start += 1;
48        }
49        start
50    };
51
52    format!(
53        "{}\n\n[...output truncated ({} bytes total)...]\n\n{}",
54        &output[..head_end],
55        total,
56        &output[tail_start..],
57    )
58}
59
60#[async_trait]
61impl AgentTool for BashTool {
62    fn name(&self) -> &str {
63        "bash"
64    }
65
66    fn label(&self) -> String {
67        "Bash".to_string()
68    }
69
70    fn description(&self) -> String {
71        "Execute a shell command and return stdout/stderr.".to_string()
72    }
73
74    fn parameters_schema(&self) -> Value {
75        serde_json::json!({
76            "type": "object",
77            "properties": {
78                "command": {
79                    "type": "string",
80                    "description": "The command to execute"
81                },
82                "timeout": {
83                    "type": "integer",
84                    "description": "Timeout in seconds (default 300, max 3600)"
85                }
86            },
87            "required": ["command"]
88        })
89    }
90
91    async fn execute(
92        &self,
93        _tool_call_id: &str,
94        params: Value,
95        _cancel: CancellationToken,
96    ) -> Result<ToolResult, ToolError> {
97        let command = params
98            .get("command")
99            .and_then(|v| v.as_str())
100            .ok_or_else(|| {
101                ToolError::InvalidParameters("missing or invalid 'command' parameter".into())
102            })?;
103
104        let timeout_secs = params
105            .get("timeout")
106            .and_then(|v| v.as_u64())
107            .unwrap_or(DEFAULT_TIMEOUT_SECS)
108            .min(MAX_TIMEOUT_SECS);
109
110        // Set up the PTY
111        let pty_system = portable_pty::native_pty_system();
112        let pair = pty_system
113            .openpty(PtySize::default())
114            .map_err(|e| ToolError::ExecutionFailed(format!("failed to open pty: {e}")))?;
115
116        let mut cmd = CommandBuilder::new("bash");
117        cmd.arg("-c");
118        cmd.arg(command);
119        cmd.cwd(&self.working_dir);
120
121        // Spawn the child process
122        let mut child = pair
123            .slave
124            .spawn_command(cmd)
125            .map_err(|e| ToolError::ExecutionFailed(format!("failed to spawn command: {e}")))?;
126
127        // Drop slave so the master reader gets EOF when the child exits
128        drop(pair.slave);
129
130        let mut reader = pair
131            .master
132            .try_clone_reader()
133            .map_err(|e| ToolError::ExecutionFailed(format!("failed to clone pty reader: {e}")))?;
134
135        // Clone a killer handle so we can kill the child on timeout
136        let killer = child.clone_killer();
137
138        // Read output in a blocking thread (PTY reader is blocking I/O)
139        let read_handle = tokio::task::spawn_blocking(move || {
140            let mut buf = Vec::new();
141            let _ = std::io::Read::read_to_end(&mut reader, &mut buf);
142            buf
143        });
144
145        // Wait for the child in a blocking thread
146        let wait_handle = tokio::task::spawn_blocking(move || child.wait());
147
148        // Apply timeout to both operations
149        let timeout_duration = std::time::Duration::from_secs(timeout_secs);
150        match tokio::time::timeout(timeout_duration, async {
151            let output_bytes = read_handle
152                .await
153                .map_err(|e| ToolError::ExecutionFailed(format!("read task panicked: {e}")))?;
154            let exit_status = wait_handle
155                .await
156                .map_err(|e| ToolError::ExecutionFailed(format!("wait task panicked: {e}")))?
157                .map_err(|e| ToolError::ExecutionFailed(format!("failed to wait on child: {e}")))?;
158            Ok::<_, ToolError>((output_bytes, exit_status))
159        })
160        .await
161        {
162            Ok(Ok((output_bytes, exit_status))) => {
163                let output = String::from_utf8_lossy(&output_bytes);
164                let output = truncate_output(&output);
165                let exit_code = exit_status.exit_code();
166
167                if exit_code == 0 {
168                    Ok(ToolResult {
169                        content: vec![Content::Text { text: output }],
170                        details: serde_json::json!({"exit_code": exit_code}),
171                    })
172                } else {
173                    Ok(ToolResult {
174                        content: vec![Content::Text {
175                            text: format!("{output}\n\nExit code: {exit_code}"),
176                        }],
177                        details: serde_json::json!({"exit_code": exit_code}),
178                    })
179                }
180            }
181            Ok(Err(e)) => Err(e),
182            Err(_) => {
183                // Timeout: kill the child process
184                let mut killer = killer;
185                let _ = killer.kill();
186
187                // Try to collect any partial output that was read
188                let partial = "(timeout — command did not complete)";
189
190                Ok(ToolResult {
191                    content: vec![Content::Text {
192                        text: format!(
193                            "Command timed out after {timeout_secs} seconds.\n{partial}"
194                        ),
195                    }],
196                    details: serde_json::json!({"timeout": true}),
197                })
198            }
199        }
200    }
201}
202
#[cfg(test)]
mod tests {
    // End-to-end tests spawn real `bash` processes; the `truncate_output_*` tests
    // exercise the truncation helper directly.
    use super::*;
    use std::time::Instant;

    fn tool_in(dir: &std::path::Path) -> BashTool {
        BashTool::new(dir.to_path_buf())
    }

    fn cancel() -> CancellationToken {
        CancellationToken::new()
    }

    fn text_of(result: &ToolResult) -> &str {
        match &result.content[0] {
            Content::Text { text } => text.as_str(),
            _ => panic!("expected Text content"),
        }
    }

    #[tokio::test]
    async fn simple_echo() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        let params = serde_json::json!({"command": "echo hello"});
        let result = tool.execute("c1", params, cancel()).await.unwrap();
        let output = text_of(&result);
        assert!(
            output.contains("hello"),
            "expected 'hello' in output: {output}"
        );
        assert_eq!(result.details["exit_code"], 0);
    }

    #[tokio::test]
    async fn working_directory_respected() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        let params = serde_json::json!({"command": "pwd"});
        let result = tool.execute("c2", params, cancel()).await.unwrap();
        let output = text_of(&result);
        // On macOS, /tmp -> /private/tmp, so canonicalize both
        let expected = dir.path().canonicalize().unwrap();
        let actual_trimmed = output.trim();
        let actual = std::path::Path::new(actual_trimmed)
            .canonicalize()
            .unwrap_or_else(|_| std::path::PathBuf::from(actual_trimmed));
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn nonzero_exit_code() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        let params = serde_json::json!({"command": "exit 42"});
        let result = tool.execute("c3", params, cancel()).await.unwrap();
        assert_eq!(result.details["exit_code"], 42);
        let output = text_of(&result);
        assert!(
            output.contains("Exit code: 42"),
            "expected exit code info: {output}"
        );
    }

    #[tokio::test]
    async fn output_truncation() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        // Generate output well over 100KB
        let params = serde_json::json!({"command": "seq 1 100000"});
        let result = tool.execute("c4", params, cancel()).await.unwrap();
        let output = text_of(&result);
        assert!(
            output.contains("[...output truncated"),
            "expected truncation marker: (output len = {})",
            output.len()
        );
        // The truncated output should be around 20KB + marker, well under 100KB
        assert!(output.len() < MAX_OUTPUT_BYTES);
    }

    #[tokio::test]
    async fn timeout() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        let params = serde_json::json!({"command": "sleep 30", "timeout": 1});
        let start = Instant::now();
        let result = tool.execute("c5", params, cancel()).await.unwrap();
        let elapsed = start.elapsed();
        let output = text_of(&result);
        assert!(
            output.contains("timed out"),
            "expected timeout message: {output}"
        );
        assert!(result.details["timeout"] == true);
        // Should complete in roughly 1-3 seconds, not 30
        assert!(elapsed.as_secs() < 10, "took too long: {elapsed:?}");
    }

    #[tokio::test]
    async fn missing_command_parameter() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        let params = serde_json::json!({});
        let err = tool.execute("c6", params, cancel()).await.unwrap_err();
        match err {
            ToolError::InvalidParameters(msg) => assert!(msg.contains("command")),
            _ => panic!("expected InvalidParameters, got: {err:?}"),
        }
    }

    #[test]
    fn truncate_output_leaves_short_output_alone() {
        assert_eq!(truncate_output(""), "");
        assert_eq!(truncate_output("hello"), "hello");
        let at_limit = "a".repeat(MAX_OUTPUT_BYTES);
        assert_eq!(truncate_output(&at_limit), at_limit);
    }

    #[test]
    fn truncate_output_cuts_on_char_boundaries() {
        // 'é' is two bytes and the leading 'x' shifts every boundary to an odd
        // offset, so a naive byte cut at TRUNCATION_EDGE would split a character.
        let input = format!("x{}", "é".repeat(60_000));
        let out = truncate_output(&input);
        let marker = format!(
            "\n\n[...output truncated ({} bytes total)...]\n\n",
            input.len()
        );
        let (head, tail) = out.split_once(&marker).expect("truncation marker present");
        assert!(input.starts_with(head));
        assert!(input.ends_with(tail));
        // At most one byte can be lost at each edge when backing off a 2-byte char.
        assert!(head.len() <= TRUNCATION_EDGE && TRUNCATION_EDGE - head.len() < 2);
        assert!(tail.len() <= TRUNCATION_EDGE && TRUNCATION_EDGE - tail.len() < 2);
    }
}