oxify_mcp/servers/
shell.rs

//! Shell MCP server - provides whitelist-restricted shell command execution
2
3use crate::{McpServer, Result};
4use async_trait::async_trait;
5use serde_json::{json, Value};
6use tokio::process::Command;
7
/// Built-in MCP server for shell command execution.
///
/// `shell_exec` commands run through the platform shell (`sh -c` on Unix,
/// `cmd /C` on Windows) in `working_dir`, with `env_vars` added to the
/// environment. When `allowed_commands` is non-empty, a command is accepted only
/// if its first whitespace-separated word is whitelisted *and* it contains no
/// shell syntax that could start another command (see `is_command_allowed`).
#[derive(Debug, Clone)]
pub struct ShellServer {
    /// Allowed base commands (whitelist). Empty means unrestricted mode.
    allowed_commands: Vec<String>,
    /// Working directory for spawned commands.
    working_dir: std::path::PathBuf,
    /// Environment variables added to every spawned command.
    env_vars: Vec<(String, String)>,
}

/// Characters that let one command line run further commands under `sh`/`cmd`:
/// sequencing (`;`, newline, carriage return), background / logical AND (`&`),
/// pipes / logical OR (`|`), redirection (`<`, `>`), and backtick substitution.
/// `$(` is checked separately so plain variable expansion (`$HOME`) stays usable.
const SHELL_CONTROL_CHARS: &[char] = &[';', '&', '|', '<', '>', '`', '\n', '\r'];

/// Whitelist installed by `ShellServer::default()`.
///
/// NOTE(review): the working directory is not a sandbox. `cat`/`head`/`tail`
/// can read any file the process can read, and `find` accepts actions such as
/// `-exec ... +` and `-delete` that the metacharacter filter does not catch.
/// Confirm this list is acceptable for untrusted callers.
const DEFAULT_ALLOWED_COMMANDS: &[&str] = &[
    "ls", "cat", "echo", "pwd", "date", "whoami", "grep", "find", "wc", "head", "tail",
];

impl ShellServer {
    /// Create a new shell server restricted to `allowed_commands`.
    ///
    /// An empty list puts the server in unrestricted mode. The working directory
    /// defaults to the process's current directory, falling back to the system
    /// temp directory if that cannot be determined.
    pub fn new(allowed_commands: Vec<String>) -> Self {
        Self {
            allowed_commands,
            // `temp_dir()` rather than a hard-coded "/tmp", which is Unix-only.
            working_dir: std::env::current_dir().unwrap_or_else(|_| std::env::temp_dir()),
            env_vars: Vec::new(),
        }
    }

    /// Set the working directory used for spawned commands.
    pub fn with_working_dir(mut self, dir: impl Into<std::path::PathBuf>) -> Self {
        self.working_dir = dir.into();
        self
    }

    /// Add an environment variable to every spawned command.
    ///
    /// May be called repeatedly; pairs are stored in insertion order.
    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env_vars.push((key.into(), value.into()));
        self
    }

    /// Check whether `command` may be executed.
    ///
    /// Always returns `true` in unrestricted mode (empty whitelist). Otherwise
    /// the command must:
    /// 1. contain none of `SHELL_CONTROL_CHARS` and no `$(`, and
    /// 2. have a first whitespace-separated word that exactly matches a
    ///    whitelisted command.
    ///
    /// Quoting is not parsed, so e.g. `grep 'a|b'` is rejected too. This is
    /// deliberately conservative.
    fn is_command_allowed(&self, command: &str) -> bool {
        if self.allowed_commands.is_empty() {
            return true;
        }

        // The command is handed to `sh -c` / `cmd /C`, so vetting only the
        // first word is not enough: `echo hi; rm -rf /` starts with `echo`.
        if command.contains(SHELL_CONTROL_CHARS) || command.contains("$(") {
            return false;
        }

        let base_cmd = command.split_whitespace().next().unwrap_or("");
        self.allowed_commands
            .iter()
            .any(|allowed| allowed.as_str() == base_cmd)
    }
}

impl Default for ShellServer {
    /// Server whitelisted to `DEFAULT_ALLOWED_COMMANDS`, in that order.
    fn default() -> Self {
        Self::new(
            DEFAULT_ALLOWED_COMMANDS
                .iter()
                .map(|cmd| (*cmd).to_string())
                .collect(),
        )
    }
}
68
69#[async_trait]
70impl McpServer for ShellServer {
71    async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
72        match name {
73            "shell_exec" => {
74                let command = arguments["command"].as_str().ok_or_else(|| {
75                    crate::McpError::InvalidRequest("Missing 'command'".to_string())
76                })?;
77
78                // Security check
79                if !self.is_command_allowed(command) {
80                    return Err(crate::McpError::ToolExecutionError(format!(
81                        "Command '{}' is not allowed. Allowed commands: {:?}",
82                        command, self.allowed_commands
83                    )));
84                }
85
86                // Execute command
87                let output = if cfg!(target_os = "windows") {
88                    Command::new("cmd")
89                        .args(["/C", command])
90                        .current_dir(&self.working_dir)
91                        .envs(self.env_vars.iter().cloned())
92                        .output()
93                        .await
94                } else {
95                    Command::new("sh")
96                        .arg("-c")
97                        .arg(command)
98                        .current_dir(&self.working_dir)
99                        .envs(self.env_vars.iter().cloned())
100                        .output()
101                        .await
102                }
103                .map_err(|e| crate::McpError::ToolExecutionError(e.to_string()))?;
104
105                Ok(json!({
106                    "stdout": String::from_utf8_lossy(&output.stdout),
107                    "stderr": String::from_utf8_lossy(&output.stderr),
108                    "exit_code": output.status.code().unwrap_or(-1),
109                    "success": output.status.success(),
110                }))
111            }
112
113            "shell_which" => {
114                let command = arguments["command"].as_str().ok_or_else(|| {
115                    crate::McpError::InvalidRequest("Missing 'command'".to_string())
116                })?;
117
118                let output = if cfg!(target_os = "windows") {
119                    Command::new("where").arg(command).output().await
120                } else {
121                    Command::new("which").arg(command).output().await
122                }
123                .map_err(|e| crate::McpError::ToolExecutionError(e.to_string()))?;
124
125                Ok(json!({
126                    "path": String::from_utf8_lossy(&output.stdout).trim(),
127                    "found": output.status.success(),
128                }))
129            }
130
131            _ => Err(crate::McpError::ToolNotFound(name.to_string())),
132        }
133    }
134
135    async fn list_tools(&self) -> Result<Vec<Value>> {
136        Ok(vec![
137            json!({
138                "name": "shell_exec",
139                "description": "Execute a shell command",
140                "inputSchema": {
141                    "type": "object",
142                    "properties": {
143                        "command": {
144                            "type": "string",
145                            "description": "Shell command to execute"
146                        }
147                    },
148                    "required": ["command"]
149                }
150            }),
151            json!({
152                "name": "shell_which",
153                "description": "Find the path of a command",
154                "inputSchema": {
155                    "type": "object",
156                    "properties": {
157                        "command": {
158                            "type": "string",
159                            "description": "Command name to locate"
160                        }
161                    },
162                    "required": ["command"]
163                }
164            }),
165        ])
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Run the `shell_exec` tool with `command` on `server`.
    async fn exec(server: &ShellServer, command: &str) -> Result<Value> {
        server
            .call_tool("shell_exec", json!({ "command": command }))
            .await
    }

    #[tokio::test]
    async fn test_shell_exec_echo() {
        let server = ShellServer::default();

        let result = exec(&server, "echo hello").await.unwrap();

        assert_eq!(result["success"], true);
        assert!(result["stdout"].as_str().unwrap().contains("hello"));
    }

    #[tokio::test]
    async fn test_shell_exec_pwd() {
        let server = ShellServer::default();

        let result = exec(&server, "pwd").await.unwrap();

        assert_eq!(result["success"], true);
        assert!(!result["stdout"].as_str().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_shell_exec_missing_command() {
        let server = ShellServer::default();

        let result = server.call_tool("shell_exec", json!({})).await;

        assert!(matches!(result, Err(crate::McpError::InvalidRequest(_))));
    }

    #[tokio::test]
    async fn test_shell_which() {
        let server = ShellServer::default();

        let result = server
            .call_tool("shell_which", json!({ "command": "ls" }))
            .await
            .unwrap();

        // On most Unix systems, 'ls' should be found
        if cfg!(not(target_os = "windows")) {
            assert_eq!(result["found"], true);
            assert!(!result["path"].as_str().unwrap().is_empty());
        }
    }

    #[tokio::test]
    async fn test_shell_disallowed_command() {
        let server = ShellServer::new(vec!["echo".to_string()]);

        let result = exec(&server, "rm -rf /").await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unknown_tool() {
        let server = ShellServer::default();

        let result = server
            .call_tool("shell_nope", json!({ "command": "ls" }))
            .await;

        assert!(matches!(result, Err(crate::McpError::ToolNotFound(_))));
    }

    #[tokio::test]
    async fn test_shell_list_tools() {
        let server = ShellServer::default();

        let tools = server.list_tools().await.unwrap();

        assert_eq!(tools.len(), 2);
        assert!(tools.iter().any(|t| t["name"] == "shell_exec"));
        assert!(tools.iter().any(|t| t["name"] == "shell_which"));
    }

    #[tokio::test]
    async fn test_shell_with_working_dir() {
        let temp_dir = std::env::temp_dir();
        let server = ShellServer::default().with_working_dir(temp_dir.clone());

        let result = exec(&server, "pwd").await.unwrap();

        assert_eq!(result["success"], true);

        // Check that the command really ran in `temp_dir`. Both sides are
        // canonicalized because the temp dir may be reached through a symlink
        // (e.g. `/var` -> `/private/var` on macOS) and `pwd` may print either form.
        if cfg!(unix) {
            let printed = result["stdout"].as_str().unwrap().trim();
            assert_eq!(
                std::fs::canonicalize(printed).unwrap(),
                std::fs::canonicalize(&temp_dir).unwrap()
            );
        }
    }

    #[tokio::test]
    async fn test_shell_with_env() {
        let server =
            ShellServer::default().with_env("TEST_VAR".to_string(), "test_value".to_string());

        let command = if cfg!(target_os = "windows") {
            "echo %TEST_VAR%"
        } else {
            "echo $TEST_VAR"
        };

        // Previously an `Err` here made the test pass vacuously; require success
        // and check that the variable actually reached the child process.
        let result = exec(&server, command).await.unwrap();

        assert_eq!(result["success"], true);
        assert!(result["stdout"].as_str().unwrap().contains("test_value"));
    }
}