Skip to main content

enact_core/tool/
shell.rs

1//! Shell command execution tool
2
3use crate::tool::Tool;
4use async_trait::async_trait;
5use serde_json::json;
6use std::time::Duration;
7use tokio::process::Command;
8
/// Timeout, in seconds, applied when the caller's `timeout` argument is absent
/// or is not a non-negative integer.
const SHELL_TIMEOUT_SECS: u64 = 60;
/// Size cap applied separately to the returned stdout and stderr text (1 MiB).
const MAX_OUTPUT_BYTES: usize = 1024 * 1024;
11
/// Shell command execution tool.
///
/// The type carries no state; build one with [`ShellTool::new`] or
/// [`Default::default`] (both produce the same value).
#[derive(Debug, Clone, Copy, Default)]
pub struct ShellTool;

impl ShellTool {
    /// Creates a new shell tool.
    pub fn new() -> Self {
        Self
    }
}
26
27#[async_trait]
28impl Tool for ShellTool {
29    fn name(&self) -> &str {
30        "shell"
31    }
32
33    fn description(&self) -> &str {
34        "Execute a shell command in the workspace directory"
35    }
36
37    fn parameters_schema(&self) -> serde_json::Value {
38        json!({
39            "type": "object",
40            "properties": {
41                "command": {
42                    "type": "string",
43                    "description": "The shell command to execute"
44                },
45                "timeout": {
46                    "type": "integer",
47                    "description": "Timeout in seconds (default: 60)",
48                    "minimum": 1,
49                    "maximum": 300
50                }
51            },
52            "required": ["command"]
53        })
54    }
55
56    fn requires_network(&self) -> bool {
57        false // Shell can work offline, though commands might need network
58    }
59
60    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
61        let command = args
62            .get("command")
63            .and_then(|v| v.as_str())
64            .ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?;
65
66        let timeout_secs = args
67            .get("timeout")
68            .and_then(|v| v.as_u64())
69            .unwrap_or(SHELL_TIMEOUT_SECS);
70
71        let timeout = Duration::from_secs(timeout_secs.min(300));
72
73        // Use sh -c for shell commands
74        let output =
75            tokio::time::timeout(timeout, Command::new("sh").arg("-c").arg(command).output())
76                .await
77                .map_err(|_| {
78                    anyhow::anyhow!("Command timed out after {} seconds", timeout_secs)
79                })??;
80
81        let stdout = String::from_utf8_lossy(&output.stdout);
82        let stderr = String::from_utf8_lossy(&output.stderr);
83
84        // Truncate if too large
85        let stdout = if stdout.len() > MAX_OUTPUT_BYTES {
86            format!("{}... [truncated]", &stdout[..MAX_OUTPUT_BYTES])
87        } else {
88            stdout.to_string()
89        };
90
91        let stderr = if stderr.len() > MAX_OUTPUT_BYTES {
92            format!("{}... [truncated]", &stderr[..MAX_OUTPUT_BYTES])
93        } else {
94            stderr.to_string()
95        };
96
97        Ok(json!({
98            "success": output.status.success(),
99            "stdout": stdout,
100            "stderr": stderr,
101            "exit_code": output.status.code(),
102            "command": command
103        }))
104    }
105}
106
// These tests spawn real `sh` processes, so they need a POSIX shell on PATH.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_shell_echo() {
        let tool = ShellTool::new();
        let result = tool
            .execute(json!({
                "command": "echo 'Hello World'"
            }))
            .await
            .unwrap();

        assert_eq!(result["success"], true);
        assert!(result["stdout"].as_str().unwrap().contains("Hello World"));
        // The original command string is echoed back in the result.
        assert_eq!(result["command"], "echo 'Hello World'");
    }

    #[tokio::test]
    async fn test_shell_error() {
        let tool = ShellTool::new();
        let result = tool
            .execute(json!({
                "command": "exit 1"
            }))
            .await
            .unwrap();

        assert_eq!(result["success"], false);
        assert_eq!(result["exit_code"], 1);
    }

    #[tokio::test]
    async fn test_shell_stderr_is_captured() {
        let tool = ShellTool::new();
        let result = tool
            .execute(json!({
                "command": "echo oops >&2"
            }))
            .await
            .unwrap();

        assert_eq!(result["success"], true);
        assert_eq!(result["stdout"], "");
        assert!(result["stderr"].as_str().unwrap().contains("oops"));
    }

    #[tokio::test]
    async fn test_shell_missing_command() {
        let tool = ShellTool::new();
        let err = tool.execute(json!({})).await.unwrap_err();

        assert!(err.to_string().contains("Missing 'command' parameter"));
    }

    #[tokio::test]
    async fn test_shell_timeout() {
        let tool = ShellTool::new();
        let result = tool
            .execute(json!({
                "command": "sleep 10",
                "timeout": 1
            }))
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timed out"));
    }
}