1use async_trait::async_trait;
4use portable_pty::{CommandBuilder, PtySize};
5use rho_core::tool::{AgentTool, ToolError};
6use rho_core::types::{Content, ToolResult};
7use serde_json::Value;
8use std::path::PathBuf;
9use tokio_util::sync::CancellationToken;
10
/// Timeout applied when the caller does not supply `timeout`, in seconds.
const DEFAULT_TIMEOUT_SECS: u64 = 300;
/// Upper bound applied to a caller-supplied `timeout`, in seconds.
const MAX_TIMEOUT_SECS: u64 = 3600;
/// Output longer than this many bytes is truncated before it is returned.
const MAX_OUTPUT_BYTES: usize = 102_400;
/// Maximum number of bytes kept from each end of truncated output.
const TRUNCATION_EDGE: usize = 10_240;

// `truncate_output` keeps a head and a tail of up to `TRUNCATION_EDGE` bytes each;
// they must never overlap for any output long enough to be truncated.
const _: () = assert!(MAX_OUTPUT_BYTES > 2 * TRUNCATION_EDGE);

/// Agent tool that runs shell commands through `bash -c` in a fixed working directory.
#[derive(Debug, Clone)]
pub struct BashTool {
    /// Directory every command is started in.
    working_dir: PathBuf,
}
19
20impl BashTool {
21 pub fn new(working_dir: PathBuf) -> Self {
22 Self { working_dir }
23 }
24}
25
26fn truncate_output(output: &str) -> String {
29 if output.len() <= MAX_OUTPUT_BYTES {
30 return output.to_string();
31 }
32 let total = output.len();
33
34 let head_end = {
36 let mut end = TRUNCATION_EDGE;
37 while end > 0 && !output.is_char_boundary(end) {
38 end -= 1;
39 }
40 end
41 };
42
43 let tail_start = {
45 let mut start = total.saturating_sub(TRUNCATION_EDGE);
46 while start < total && !output.is_char_boundary(start) {
47 start += 1;
48 }
49 start
50 };
51
52 format!(
53 "{}\n\n[...output truncated ({} bytes total)...]\n\n{}",
54 &output[..head_end],
55 total,
56 &output[tail_start..],
57 )
58}
59
60#[async_trait]
61impl AgentTool for BashTool {
62 fn name(&self) -> &str {
63 "bash"
64 }
65
66 fn label(&self) -> String {
67 "Bash".to_string()
68 }
69
70 fn description(&self) -> String {
71 "Execute a shell command and return stdout/stderr.".to_string()
72 }
73
74 fn parameters_schema(&self) -> Value {
75 serde_json::json!({
76 "type": "object",
77 "properties": {
78 "command": {
79 "type": "string",
80 "description": "The command to execute"
81 },
82 "timeout": {
83 "type": "integer",
84 "description": "Timeout in seconds (default 300, max 3600)"
85 }
86 },
87 "required": ["command"]
88 })
89 }
90
91 async fn execute(
92 &self,
93 _tool_call_id: &str,
94 params: Value,
95 _cancel: CancellationToken,
96 ) -> Result<ToolResult, ToolError> {
97 let command = params
98 .get("command")
99 .and_then(|v| v.as_str())
100 .ok_or_else(|| {
101 ToolError::InvalidParameters("missing or invalid 'command' parameter".into())
102 })?;
103
104 let timeout_secs = params
105 .get("timeout")
106 .and_then(|v| v.as_u64())
107 .unwrap_or(DEFAULT_TIMEOUT_SECS)
108 .min(MAX_TIMEOUT_SECS);
109
110 let pty_system = portable_pty::native_pty_system();
112 let pair = pty_system
113 .openpty(PtySize::default())
114 .map_err(|e| ToolError::ExecutionFailed(format!("failed to open pty: {e}")))?;
115
116 let mut cmd = CommandBuilder::new("bash");
117 cmd.arg("-c");
118 cmd.arg(command);
119 cmd.cwd(&self.working_dir);
120
121 let mut child = pair
123 .slave
124 .spawn_command(cmd)
125 .map_err(|e| ToolError::ExecutionFailed(format!("failed to spawn command: {e}")))?;
126
127 drop(pair.slave);
129
130 let mut reader = pair
131 .master
132 .try_clone_reader()
133 .map_err(|e| ToolError::ExecutionFailed(format!("failed to clone pty reader: {e}")))?;
134
135 let killer = child.clone_killer();
137
138 let read_handle = tokio::task::spawn_blocking(move || {
140 let mut buf = Vec::new();
141 let _ = std::io::Read::read_to_end(&mut reader, &mut buf);
142 buf
143 });
144
145 let wait_handle = tokio::task::spawn_blocking(move || child.wait());
147
148 let timeout_duration = std::time::Duration::from_secs(timeout_secs);
150 match tokio::time::timeout(timeout_duration, async {
151 let output_bytes = read_handle
152 .await
153 .map_err(|e| ToolError::ExecutionFailed(format!("read task panicked: {e}")))?;
154 let exit_status = wait_handle
155 .await
156 .map_err(|e| ToolError::ExecutionFailed(format!("wait task panicked: {e}")))?
157 .map_err(|e| ToolError::ExecutionFailed(format!("failed to wait on child: {e}")))?;
158 Ok::<_, ToolError>((output_bytes, exit_status))
159 })
160 .await
161 {
162 Ok(Ok((output_bytes, exit_status))) => {
163 let output = String::from_utf8_lossy(&output_bytes);
164 let output = truncate_output(&output);
165 let exit_code = exit_status.exit_code();
166
167 if exit_code == 0 {
168 Ok(ToolResult {
169 content: vec![Content::Text { text: output }],
170 details: serde_json::json!({"exit_code": exit_code}),
171 })
172 } else {
173 Ok(ToolResult {
174 content: vec![Content::Text {
175 text: format!("{output}\n\nExit code: {exit_code}"),
176 }],
177 details: serde_json::json!({"exit_code": exit_code}),
178 })
179 }
180 }
181 Ok(Err(e)) => Err(e),
182 Err(_) => {
183 let mut killer = killer;
185 let _ = killer.kill();
186
187 let partial = "(timeout — command did not complete)";
189
190 Ok(ToolResult {
191 content: vec![Content::Text {
192 text: format!(
193 "Command timed out after {timeout_secs} seconds.\n{partial}"
194 ),
195 }],
196 details: serde_json::json!({"timeout": true}),
197 })
198 }
199 }
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    fn tool_in(dir: &std::path::Path) -> BashTool {
        BashTool::new(dir.to_path_buf())
    }

    fn cancel() -> CancellationToken {
        CancellationToken::new()
    }

    /// Returns the text of the first content item, panicking on any other variant.
    fn text_of(result: &ToolResult) -> &str {
        match &result.content[0] {
            Content::Text { text } => text.as_str(),
            _ => panic!("expected Text content"),
        }
    }

    #[test]
    fn truncate_output_passes_short_output_through() {
        assert_eq!(truncate_output(""), "");
        assert_eq!(truncate_output("hello"), "hello");
        // Exactly at the limit is still not truncated.
        let exact = "x".repeat(MAX_OUTPUT_BYTES);
        assert_eq!(truncate_output(&exact), exact);
    }

    #[test]
    fn truncate_output_cuts_on_char_boundaries() {
        // A leading ASCII byte shifts the 2-byte 'é' characters so that some byte
        // offsets fall in the middle of a character; slicing there would panic.
        let input = format!("a{}", "é".repeat(MAX_OUTPUT_BYTES / 2 + 1));
        assert!(input.len() > MAX_OUTPUT_BYTES);

        let result = truncate_output(&input);
        let marker = format!(
            "\n\n[...output truncated ({} bytes total)...]\n\n",
            input.len()
        );
        let (head, tail) = result
            .split_once(marker.as_str())
            .expect("truncation marker present");

        assert!(!head.is_empty());
        assert!(head.len() <= TRUNCATION_EDGE);
        assert!(tail.len() <= TRUNCATION_EDGE);
        assert!(input.starts_with(head));
        assert!(input.ends_with(tail));
    }

    #[tokio::test]
    async fn simple_echo() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        let params = serde_json::json!({"command": "echo hello"});
        let result = tool.execute("c1", params, cancel()).await.unwrap();
        let output = text_of(&result);
        assert!(
            output.contains("hello"),
            "expected 'hello' in output: {output}"
        );
        assert_eq!(result.details["exit_code"], 0);
    }

    #[tokio::test]
    async fn working_directory_respected() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        let params = serde_json::json!({"command": "pwd"});
        let result = tool.execute("c2", params, cancel()).await.unwrap();
        let output = text_of(&result);
        let expected = dir.path().canonicalize().unwrap();
        // The pty may emit "\r\n" line endings; trim before comparing paths.
        let actual_trimmed = output.trim();
        let actual = std::path::Path::new(actual_trimmed)
            .canonicalize()
            .unwrap_or_else(|_| std::path::PathBuf::from(actual_trimmed));
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn nonzero_exit_code() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        let params = serde_json::json!({"command": "exit 42"});
        let result = tool.execute("c3", params, cancel()).await.unwrap();
        assert_eq!(result.details["exit_code"], 42);
        let output = text_of(&result);
        assert!(
            output.contains("Exit code: 42"),
            "expected exit code info: {output}"
        );
    }

    #[tokio::test]
    async fn output_truncation() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        // Prints several hundred KB, well above MAX_OUTPUT_BYTES.
        let params = serde_json::json!({"command": "seq 1 100000"});
        let result = tool.execute("c4", params, cancel()).await.unwrap();
        let output = text_of(&result);
        assert!(
            output.contains("[...output truncated"),
            "expected truncation marker: (output len = {})",
            output.len()
        );
        assert!(output.len() < MAX_OUTPUT_BYTES);
    }

    #[tokio::test]
    async fn timeout() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        let params = serde_json::json!({"command": "sleep 30", "timeout": 1});
        let start = Instant::now();
        let result = tool.execute("c5", params, cancel()).await.unwrap();
        let elapsed = start.elapsed();
        let output = text_of(&result);
        assert!(
            output.contains("timed out"),
            "expected timeout message: {output}"
        );
        assert_eq!(result.details["timeout"], true);
        assert!(elapsed.as_secs() < 10, "took too long: {elapsed:?}");
    }

    #[tokio::test]
    async fn missing_command_parameter() {
        let dir = tempfile::tempdir().unwrap();
        let tool = tool_in(dir.path());
        let params = serde_json::json!({});
        let err = tool.execute("c6", params, cancel()).await.unwrap_err();
        match err {
            ToolError::InvalidParameters(msg) => assert!(msg.contains("command")),
            _ => panic!("expected InvalidParameters, got: {err:?}"),
        }
    }
}