Skip to main content

rustant_tools/
shell.rs

1//! Shell command execution tool with streaming output support.
2
3use crate::registry::Tool;
4use async_trait::async_trait;
5use rustant_core::error::ToolError;
6use rustant_core::types::{ProgressUpdate, RiskLevel, ToolOutput};
7use std::path::PathBuf;
8use std::time::Duration;
9use tokio::io::{AsyncBufReadExt, BufReader};
10use tokio::sync::mpsc;
11use tracing::{debug, warn};
12
13/// Execute shell commands within the workspace.
14///
15/// Supports optional streaming of stdout/stderr lines via a progress channel.
16pub struct ShellExecTool {
17    workspace: PathBuf,
18    /// Optional channel for streaming progress updates (shell output lines).
19    progress_tx: Option<mpsc::UnboundedSender<ProgressUpdate>>,
20}
21
22impl ShellExecTool {
23    pub fn new(workspace: PathBuf) -> Self {
24        Self {
25            workspace,
26            progress_tx: None,
27        }
28    }
29
30    /// Create a shell tool with a progress sender for streaming output.
31    pub fn with_progress(workspace: PathBuf, tx: mpsc::UnboundedSender<ProgressUpdate>) -> Self {
32        Self {
33            workspace,
34            progress_tx: Some(tx),
35        }
36    }
37}
38
39#[async_trait]
40impl Tool for ShellExecTool {
41    fn name(&self) -> &str {
42        "shell_exec"
43    }
44
45    fn description(&self) -> &str {
46        "Execute a shell command in the workspace directory. Returns stdout, stderr, and exit code."
47    }
48
49    fn parameters_schema(&self) -> serde_json::Value {
50        serde_json::json!({
51            "type": "object",
52            "properties": {
53                "command": {
54                    "type": "string",
55                    "description": "The shell command to execute"
56                },
57                "working_dir": {
58                    "type": "string",
59                    "description": "Working directory (relative to workspace). Defaults to workspace root."
60                }
61            },
62            "required": ["command"]
63        })
64    }
65
66    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
67        let command = args["command"]
68            .as_str()
69            .ok_or_else(|| ToolError::InvalidArguments {
70                name: "shell_exec".into(),
71                reason: "'command' parameter is required".into(),
72            })?;
73
74        let working_dir = if let Some(dir) = args["working_dir"].as_str() {
75            self.workspace.join(dir)
76        } else {
77            self.workspace.clone()
78        };
79
80        debug!(command = command, cwd = %working_dir.display(), "Executing shell command");
81
82        // If we have a progress sender, stream output line by line
83        if let Some(ref tx) = self.progress_tx {
84            self.execute_streaming(command, &working_dir, tx).await
85        } else {
86            self.execute_buffered(command, &working_dir).await
87        }
88    }
89
90    fn risk_level(&self) -> RiskLevel {
91        RiskLevel::Execute
92    }
93
94    fn timeout(&self) -> Duration {
95        Duration::from_secs(120)
96    }
97}
98
99impl ShellExecTool {
100    /// Execute a command with streaming output via the progress channel.
101    async fn execute_streaming(
102        &self,
103        command: &str,
104        working_dir: &PathBuf,
105        tx: &mpsc::UnboundedSender<ProgressUpdate>,
106    ) -> Result<ToolOutput, ToolError> {
107        use tokio::process::Command;
108
109        let mut child = Command::new("sh")
110            .arg("-c")
111            .arg(command)
112            .current_dir(working_dir)
113            .stdout(std::process::Stdio::piped())
114            .stderr(std::process::Stdio::piped())
115            .spawn()
116            .map_err(|e| ToolError::ExecutionFailed {
117                name: "shell_exec".into(),
118                message: format!("Failed to execute command: {}", e),
119            })?;
120
121        // Send initial progress
122        let _ = tx.send(ProgressUpdate::ToolProgress {
123            tool: "shell_exec".into(),
124            stage: format!("running: {}", truncate_cmd(command, 50)),
125            percent: None,
126        });
127
128        let stdout_pipe = child.stdout.take();
129        let stderr_pipe = child.stderr.take();
130
131        let mut stdout_lines = Vec::new();
132        let mut stderr_lines = Vec::new();
133
134        let tx_stdout = tx.clone();
135        let tx_stderr = tx.clone();
136
137        // Spawn tasks to read stdout and stderr concurrently
138        let stdout_task = tokio::spawn(async move {
139            let mut lines = Vec::new();
140            if let Some(pipe) = stdout_pipe {
141                let reader = BufReader::new(pipe);
142                let mut line_stream = reader.lines();
143                while let Ok(Some(line)) = line_stream.next_line().await {
144                    let _ = tx_stdout.send(ProgressUpdate::ShellOutput {
145                        line: line.clone(),
146                        is_stderr: false,
147                    });
148                    lines.push(line);
149                }
150            }
151            lines
152        });
153
154        let stderr_task = tokio::spawn(async move {
155            let mut lines = Vec::new();
156            if let Some(pipe) = stderr_pipe {
157                let reader = BufReader::new(pipe);
158                let mut line_stream = reader.lines();
159                while let Ok(Some(line)) = line_stream.next_line().await {
160                    let _ = tx_stderr.send(ProgressUpdate::ShellOutput {
161                        line: line.clone(),
162                        is_stderr: true,
163                    });
164                    lines.push(line);
165                }
166            }
167            lines
168        });
169
170        // Wait for the process to complete
171        let status = child.wait().await.map_err(|e| ToolError::ExecutionFailed {
172            name: "shell_exec".into(),
173            message: format!("Failed to wait for command: {}", e),
174        })?;
175
176        // Collect output from tasks
177        if let Ok(lines) = stdout_task.await {
178            stdout_lines = lines;
179        }
180        if let Ok(lines) = stderr_task.await {
181            stderr_lines = lines;
182        }
183
184        let exit_code = status.code().unwrap_or(-1);
185        let stdout = stdout_lines.join("\n");
186        let stderr = stderr_lines.join("\n");
187
188        let result = format!(
189            "Exit code: {}\n\n--- stdout ---\n{}\n--- stderr ---\n{}",
190            exit_code,
191            if stdout.is_empty() {
192                "(empty)"
193            } else {
194                &stdout
195            },
196            if stderr.is_empty() {
197                "(empty)"
198            } else {
199                &stderr
200            }
201        );
202
203        if exit_code != 0 {
204            warn!(
205                command = command,
206                exit_code, "Command exited with non-zero status"
207            );
208        }
209
210        Ok(ToolOutput::text(result))
211    }
212
213    /// Execute a command with buffered output (no streaming).
214    async fn execute_buffered(
215        &self,
216        command: &str,
217        working_dir: &PathBuf,
218    ) -> Result<ToolOutput, ToolError> {
219        let output = tokio::process::Command::new("sh")
220            .arg("-c")
221            .arg(command)
222            .current_dir(working_dir)
223            .output()
224            .await
225            .map_err(|e| ToolError::ExecutionFailed {
226                name: "shell_exec".into(),
227                message: format!("Failed to execute command: {}", e),
228            })?;
229
230        let stdout = String::from_utf8_lossy(&output.stdout);
231        let stderr = String::from_utf8_lossy(&output.stderr);
232        let exit_code = output.status.code().unwrap_or(-1);
233
234        let result = format!(
235            "Exit code: {}\n\n--- stdout ---\n{}\n--- stderr ---\n{}",
236            exit_code,
237            if stdout.is_empty() {
238                "(empty)"
239            } else {
240                &stdout
241            },
242            if stderr.is_empty() {
243                "(empty)"
244            } else {
245                &stderr
246            }
247        );
248
249        if exit_code != 0 {
250            warn!(
251                command = command,
252                exit_code, "Command exited with non-zero status"
253            );
254        }
255
256        Ok(ToolOutput::text(result))
257    }
258}
259
/// Truncate a command string for display.
///
/// Strings of at most `max` characters are returned unchanged. Longer strings are cut to
/// their first `max - 2` characters followed by `".."`, so the result has at most `max`
/// characters whenever `max >= 2`. Length is counted in `char`s and the cut always lands on
/// a character boundary, so multi-byte UTF-8 input cannot trigger a slicing panic.
fn truncate_cmd(cmd: &str, max: usize) -> String {
    if cmd.chars().count() <= max {
        return cmd.to_string();
    }
    let keep = max.saturating_sub(2);
    // Byte offset at which the `keep`-th character starts. It always exists here because the
    // string has more than `max >= keep` characters.
    let cut = cmd
        .char_indices()
        .nth(keep)
        .map_or(cmd.len(), |(offset, _)| offset);
    format!("{}..", &cmd[..cut])
}
268
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Create a throwaway workspace containing `test.txt` with the text "hello world".
    fn setup_workspace() -> TempDir {
        let workspace = TempDir::new().unwrap();
        std::fs::write(workspace.path().join("test.txt"), "hello world").unwrap();
        workspace
    }

    /// Run `args` through a non-streaming tool rooted at `workspace`, unwrapping success.
    async fn run_buffered(workspace: &TempDir, args: serde_json::Value) -> ToolOutput {
        let tool = ShellExecTool::new(workspace.path().to_path_buf());
        tool.execute(args).await.unwrap()
    }

    #[tokio::test]
    async fn test_shell_exec_basic() {
        let workspace = setup_workspace();
        let output = run_buffered(&workspace, serde_json::json!({"command": "echo hello"})).await;

        assert!(output.content.contains("hello"));
        assert!(output.content.contains("Exit code: 0"));
    }

    #[tokio::test]
    async fn test_shell_exec_with_cwd() {
        let workspace = setup_workspace();
        let nested = workspace.path().join("subdir");
        std::fs::create_dir_all(&nested).unwrap();
        std::fs::write(nested.join("file.txt"), "sub content").unwrap();

        let output = run_buffered(
            &workspace,
            serde_json::json!({ "command": "cat file.txt", "working_dir": "subdir" }),
        )
        .await;

        assert!(output.content.contains("sub content"));
    }

    #[tokio::test]
    async fn test_shell_exec_nonzero_exit() {
        let workspace = setup_workspace();
        let output = run_buffered(&workspace, serde_json::json!({"command": "exit 42"})).await;

        assert!(output.content.contains("Exit code: 42"));
    }

    #[tokio::test]
    async fn test_shell_exec_stderr() {
        let workspace = setup_workspace();
        let output =
            run_buffered(&workspace, serde_json::json!({"command": "echo error >&2"})).await;

        assert!(output.content.contains("error"));
        assert!(output.content.contains("stderr"));
    }

    #[tokio::test]
    async fn test_shell_exec_missing_command() {
        let workspace = setup_workspace();
        let tool = ShellExecTool::new(workspace.path().to_path_buf());

        match tool.execute(serde_json::json!({})).await {
            Err(ToolError::InvalidArguments { name, .. }) => assert_eq!(name, "shell_exec"),
            Err(other) => panic!("Expected InvalidArguments, got: {:?}", other),
            Ok(_) => panic!("Expected an error for a missing 'command' argument"),
        }
    }

    #[test]
    fn test_shell_exec_properties() {
        let tool = ShellExecTool::new(PathBuf::from("/tmp"));

        assert_eq!(tool.timeout(), Duration::from_secs(120));
        assert_eq!(tool.risk_level(), RiskLevel::Execute);
        assert_eq!(tool.name(), "shell_exec");
    }

    #[tokio::test]
    async fn test_shell_exec_streaming() {
        let workspace = setup_workspace();
        let (tx, mut rx) = mpsc::unbounded_channel();
        let tool = ShellExecTool::with_progress(workspace.path().to_path_buf(), tx);

        let output = tool
            .execute(serde_json::json!({"command": "echo line1 && echo line2"}))
            .await
            .unwrap();
        for needle in ["line1", "line2", "Exit code: 0"] {
            assert!(output.content.contains(needle));
        }

        // Every received update is checked; the total must cover the initial
        // ToolProgress plus one ShellOutput per stdout line.
        let mut received = 0usize;
        while let Ok(update) = rx.try_recv() {
            received += 1;
            match update {
                ProgressUpdate::ToolProgress { tool, .. } => assert_eq!(tool, "shell_exec"),
                ProgressUpdate::ShellOutput { is_stderr, .. } => assert!(!is_stderr),
                _ => {}
            }
        }
        assert!(
            received >= 3,
            "Expected at least 3 progress updates, got {}",
            received
        );
    }

    #[tokio::test]
    async fn test_shell_exec_streaming_stderr() {
        let workspace = setup_workspace();
        let (tx, mut rx) = mpsc::unbounded_channel();
        let tool = ShellExecTool::with_progress(workspace.path().to_path_buf(), tx);

        let output = tool
            .execute(serde_json::json!({"command": "echo err >&2"}))
            .await
            .unwrap();
        assert!(output.content.contains("err"));

        let saw_stderr = std::iter::from_fn(|| rx.try_recv().ok())
            .any(|update| matches!(update, ProgressUpdate::ShellOutput { is_stderr: true, .. }));
        assert!(saw_stderr, "Expected at least one stderr progress update");
    }

    #[test]
    fn test_truncate_cmd() {
        let long = "a very long command that should be truncated";
        assert_eq!(truncate_cmd("echo hello", 20), "echo hello");
        assert_eq!(truncate_cmd(long, 20), "a very long comman..");
    }

    #[test]
    fn test_shell_exec_schema() {
        let schema = ShellExecTool::new(PathBuf::from("/tmp")).parameters_schema();
        let props = &schema["properties"];
        assert!(props["command"].is_object());
        assert!(props["working_dir"].is_object());

        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::json!("command")));
        assert!(!required.contains(&serde_json::json!("working_dir")));
    }

    #[tokio::test]
    async fn test_shell_exec_empty_command() {
        let workspace = setup_workspace();
        // `sh -c ""` is a valid invocation that exits 0.
        let output = run_buffered(&workspace, serde_json::json!({"command": ""})).await;

        assert!(output.content.contains("Exit code: 0"));
    }

    #[tokio::test]
    async fn test_shell_exec_multiline_output() {
        let workspace = setup_workspace();
        let output = run_buffered(
            &workspace,
            serde_json::json!({"command": "echo line1; echo line2; echo line3"}),
        )
        .await;

        for needle in ["line1", "line2", "line3"] {
            assert!(output.content.contains(needle));
        }
    }

    #[tokio::test]
    async fn test_shell_exec_special_chars() {
        let workspace = setup_workspace();
        let output = run_buffered(
            &workspace,
            serde_json::json!({"command": "echo 'hello world' \"with quotes\""}),
        )
        .await;

        assert!(output.content.contains("hello world"));
        assert!(output.content.contains("with quotes"));
    }

    #[tokio::test]
    async fn test_shell_exec_reads_workspace_file() {
        let workspace = setup_workspace();
        let output = run_buffered(&workspace, serde_json::json!({"command": "cat test.txt"})).await;

        assert!(output.content.contains("hello world"));
    }

    #[test]
    fn test_truncate_cmd_exact_length() {
        assert_eq!(truncate_cmd("12345", 5), "12345");
    }

    #[test]
    fn test_truncate_cmd_empty() {
        assert_eq!(truncate_cmd("", 10), "");
    }
}