oxify_mcp/servers/
shell.rs1use crate::{McpServer, Result};
4use async_trait::async_trait;
5use serde_json::{json, Value};
6use tokio::process::Command;
7
/// MCP server exposing the `shell_exec` and `shell_which` tools.
///
/// `shell_exec` commands are started in `working_dir` with `env_vars` added to the
/// environment, and are gated by an allowlist of program names.
#[derive(Debug, Clone)]
pub struct ShellServer {
    /// Program names (the first word of a command) that `shell_exec` may run.
    /// An empty list disables the allowlist entirely.
    allowed_commands: Vec<String>,
    /// Directory every command is started in.
    working_dir: std::path::PathBuf,
    /// Extra environment variables set on every spawned command.
    env_vars: Vec<(String, String)>,
}

impl ShellServer {
    /// Creates a server that may only run the given programs.
    ///
    /// Pass an empty vector to allow any command, including arbitrary shell syntax.
    /// The working directory defaults to the process's current directory, falling back to
    /// the system temporary directory when that cannot be determined.
    pub fn new(allowed_commands: Vec<String>) -> Self {
        Self {
            allowed_commands,
            // `temp_dir()` rather than a hard-coded "/tmp", which does not exist on Windows.
            working_dir: std::env::current_dir().unwrap_or_else(|_| std::env::temp_dir()),
            env_vars: Vec::new(),
        }
    }

    /// Sets the directory that commands are started in.
    pub fn with_working_dir(mut self, dir: std::path::PathBuf) -> Self {
        self.working_dir = dir;
        self
    }

    /// Adds an environment variable that is set for every spawned command.
    pub fn with_env(mut self, key: String, value: String) -> Self {
        self.env_vars.push((key, value));
        self
    }

    /// Returns whether `shell_exec` may run `command`.
    ///
    /// An empty allowlist permits everything. Otherwise the first whitespace-separated word
    /// must be an allowlisted program, and the command must not contain shell
    /// metacharacters. The whole string is handed to `sh -c` / `cmd /C`, so without that
    /// second check `echo hi; rm -rf /` would pass the allowlist and still run `rm`.
    ///
    /// NOTE(review): arguments of an allowed program are not inspected, so e.g.
    /// `find . -delete` or `find . -exec rm {} +` is still accepted when `find` is allowed.
    fn is_command_allowed(&self, command: &str) -> bool {
        // These chain (`;` `&` `|` newline), group (`(` `)`), redirect (`<` `>`) or
        // substitute (`$` and backtick) commands in `sh`. The same set also covers
        // `cmd.exe`'s `&`, `|`, `<`, `>` and parentheses.
        const SHELL_METACHARACTERS: &[char] =
            &[';', '&', '|', '\n', '\r', '(', ')', '<', '>', '$', '`'];

        if self.allowed_commands.is_empty() {
            return true;
        }
        if command.contains(SHELL_METACHARACTERS) {
            return false;
        }

        let base_cmd = command.split_whitespace().next().unwrap_or("");
        self.allowed_commands
            .iter()
            .any(|allowed| allowed.as_str() == base_cmd)
    }
}

50impl Default for ShellServer {
51 fn default() -> Self {
52 Self::new(vec![
54 "ls".to_string(),
55 "cat".to_string(),
56 "echo".to_string(),
57 "pwd".to_string(),
58 "date".to_string(),
59 "whoami".to_string(),
60 "grep".to_string(),
61 "find".to_string(),
62 "wc".to_string(),
63 "head".to_string(),
64 "tail".to_string(),
65 ])
66 }
67}
68
69#[async_trait]
70impl McpServer for ShellServer {
71 async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
72 match name {
73 "shell_exec" => {
74 let command = arguments["command"].as_str().ok_or_else(|| {
75 crate::McpError::InvalidRequest("Missing 'command'".to_string())
76 })?;
77
78 if !self.is_command_allowed(command) {
80 return Err(crate::McpError::ToolExecutionError(format!(
81 "Command '{}' is not allowed. Allowed commands: {:?}",
82 command, self.allowed_commands
83 )));
84 }
85
86 let output = if cfg!(target_os = "windows") {
88 Command::new("cmd")
89 .args(["/C", command])
90 .current_dir(&self.working_dir)
91 .envs(self.env_vars.iter().cloned())
92 .output()
93 .await
94 } else {
95 Command::new("sh")
96 .arg("-c")
97 .arg(command)
98 .current_dir(&self.working_dir)
99 .envs(self.env_vars.iter().cloned())
100 .output()
101 .await
102 }
103 .map_err(|e| crate::McpError::ToolExecutionError(e.to_string()))?;
104
105 Ok(json!({
106 "stdout": String::from_utf8_lossy(&output.stdout),
107 "stderr": String::from_utf8_lossy(&output.stderr),
108 "exit_code": output.status.code().unwrap_or(-1),
109 "success": output.status.success(),
110 }))
111 }
112
113 "shell_which" => {
114 let command = arguments["command"].as_str().ok_or_else(|| {
115 crate::McpError::InvalidRequest("Missing 'command'".to_string())
116 })?;
117
118 let output = if cfg!(target_os = "windows") {
119 Command::new("where").arg(command).output().await
120 } else {
121 Command::new("which").arg(command).output().await
122 }
123 .map_err(|e| crate::McpError::ToolExecutionError(e.to_string()))?;
124
125 Ok(json!({
126 "path": String::from_utf8_lossy(&output.stdout).trim(),
127 "found": output.status.success(),
128 }))
129 }
130
131 _ => Err(crate::McpError::ToolNotFound(name.to_string())),
132 }
133 }
134
135 async fn list_tools(&self) -> Result<Vec<Value>> {
136 Ok(vec![
137 json!({
138 "name": "shell_exec",
139 "description": "Execute a shell command",
140 "inputSchema": {
141 "type": "object",
142 "properties": {
143 "command": {
144 "type": "string",
145 "description": "Shell command to execute"
146 }
147 },
148 "required": ["command"]
149 }
150 }),
151 json!({
152 "name": "shell_which",
153 "description": "Find the path of a command",
154 "inputSchema": {
155 "type": "object",
156 "properties": {
157 "command": {
158 "type": "string",
159 "description": "Command name to locate"
160 }
161 },
162 "required": ["command"]
163 }
164 }),
165 ])
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_shell_exec_echo() {
        let server = ShellServer::default();

        let result = server
            .call_tool("shell_exec", json!({ "command": "echo hello" }))
            .await
            .unwrap();

        assert_eq!(result["success"], true);
        assert!(result["stdout"].as_str().unwrap().contains("hello"));
    }

    #[tokio::test]
    async fn test_shell_exec_pwd() {
        let server = ShellServer::default();

        let result = server
            .call_tool("shell_exec", json!({ "command": "pwd" }))
            .await
            .unwrap();

        assert_eq!(result["success"], true);
        assert!(!result["stdout"].as_str().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_shell_allowed_command_with_arguments() {
        let server = ShellServer::new(vec!["echo".to_string()]);

        let result = server
            .call_tool("shell_exec", json!({ "command": "echo hello world" }))
            .await
            .unwrap();

        assert_eq!(result["success"], true);
        assert!(result["stdout"].as_str().unwrap().contains("hello world"));
    }

    #[tokio::test]
    async fn test_shell_which() {
        let server = ShellServer::default();

        let result = server
            .call_tool("shell_which", json!({ "command": "ls" }))
            .await
            .unwrap();

        // `ls` is not normally on PATH on Windows.
        if cfg!(not(target_os = "windows")) {
            assert_eq!(result["found"], true);
            assert!(!result["path"].as_str().unwrap().is_empty());
        }
    }

    #[tokio::test]
    async fn test_shell_disallowed_command() {
        let server = ShellServer::new(vec!["echo".to_string()]);

        let result = server
            .call_tool("shell_exec", json!({ "command": "rm -rf /" }))
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_shell_missing_or_invalid_command() {
        let server = ShellServer::default();

        for tool in ["shell_exec", "shell_which"] {
            for arguments in [json!({}), json!({ "command": 42 })] {
                let result = server.call_tool(tool, arguments).await;
                assert!(matches!(result, Err(crate::McpError::InvalidRequest(_))));
            }
        }
    }

    #[tokio::test]
    async fn test_shell_unknown_tool() {
        let server = ShellServer::default();

        let result = server
            .call_tool("shell_nope", json!({ "command": "ls" }))
            .await;

        assert!(matches!(
            &result,
            Err(crate::McpError::ToolNotFound(name)) if name == "shell_nope"
        ));
    }

    #[tokio::test]
    async fn test_shell_list_tools() {
        let server = ShellServer::default();

        let tools = server.list_tools().await.unwrap();

        assert_eq!(tools.len(), 2);
        assert!(tools.iter().any(|t| t["name"] == "shell_exec"));
        assert!(tools.iter().any(|t| t["name"] == "shell_which"));
    }

    #[tokio::test]
    async fn test_shell_with_working_dir() {
        let temp_dir = std::env::temp_dir();
        let server = ShellServer::default().with_working_dir(temp_dir);

        let result = server
            .call_tool("shell_exec", json!({ "command": "pwd" }))
            .await
            .unwrap();

        assert_eq!(result["success"], true);
    }

    #[tokio::test]
    async fn test_shell_with_env() {
        // An empty allowlist permits variable expansion, so this exercises `with_env`
        // regardless of whether the allowlist accepts `$` / `%`. Previously the test used
        // the default server and silently passed whenever the call returned an error.
        let server = ShellServer::new(Vec::new())
            .with_env("TEST_VAR".to_string(), "test_value".to_string());

        let command = if cfg!(target_os = "windows") {
            "echo %TEST_VAR%"
        } else {
            "echo $TEST_VAR"
        };

        let result = server
            .call_tool("shell_exec", json!({ "command": command }))
            .await
            .unwrap();

        assert_eq!(result["success"], true);
        if cfg!(not(target_os = "windows")) {
            assert!(result["stdout"].as_str().unwrap().contains("test_value"));
        }
    }
}