Skip to main content

mofa_plugins/tools/
shell.rs

1use super::*;
2use serde_json::json;
3use tokio::process::Command;
4
5/// Shell 命令工具 - 执行系统命令(受限)
6pub struct ShellCommandTool {
7    definition: ToolDefinition,
8    allowed_commands: Vec<String>,
9}
10
11impl ShellCommandTool {
12    pub fn new(allowed_commands: Vec<String>) -> Self {
13        Self {
14            definition: ToolDefinition {
15                name: "shell".to_string(),
16                description:
17                    "Execute shell commands. Only whitelisted commands are allowed for security."
18                        .to_string(),
19                parameters: json!({
20                    "type": "object",
21                    "properties": {
22                        "command": {
23                            "type": "string",
24                            "description": "The command to execute"
25                        },
26                        "args": {
27                            "type": "array",
28                            "items": { "type": "string" },
29                            "description": "Command arguments"
30                        },
31                        "working_dir": {
32                            "type": "string",
33                            "description": "Working directory for command execution"
34                        }
35                    },
36                    "required": ["command"]
37                }),
38                requires_confirmation: true,
39            },
40            allowed_commands,
41        }
42    }
43
44    /// Create with default allowed commands
45    pub fn new_with_defaults() -> Self {
46        Self::new(vec![
47            "ls".to_string(),
48            "pwd".to_string(),
49            "echo".to_string(),
50            "date".to_string(),
51            "whoami".to_string(),
52            "cat".to_string(),
53            "head".to_string(),
54            "tail".to_string(),
55            "wc".to_string(),
56            "grep".to_string(),
57            "find".to_string(),
58        ])
59    }
60
61    fn is_command_allowed(&self, command: &str) -> bool {
62        if self.allowed_commands.is_empty() {
63            return false; // Default deny if no whitelist
64        }
65        self.allowed_commands
66            .iter()
67            .any(|allowed| command == allowed || command.starts_with(&format!("{} ", allowed)))
68    }
69}
70
71#[async_trait::async_trait]
72impl ToolExecutor for ShellCommandTool {
73    fn definition(&self) -> &ToolDefinition {
74        &self.definition
75    }
76
77    async fn execute(&self, arguments: serde_json::Value) -> PluginResult<serde_json::Value> {
78        let command = arguments["command"]
79            .as_str()
80            .ok_or_else(|| anyhow::anyhow!("Command is required"))?;
81
82        if !self.is_command_allowed(command) {
83            return Err(anyhow::anyhow!(
84                "Command '{}' is not in the allowed commands list. Allowed: {:?}",
85                command,
86                self.allowed_commands
87            ));
88        }
89
90        let args: Vec<String> = arguments
91            .get("args")
92            .and_then(|a| a.as_array())
93            .map(|arr| {
94                arr.iter()
95                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
96                    .collect()
97            })
98            .unwrap_or_default();
99
100        let mut cmd = Command::new(command);
101        cmd.args(&args);
102
103        if let Some(dir) = arguments.get("working_dir").and_then(|d| d.as_str()) {
104            cmd.current_dir(dir);
105        }
106
107        let output = cmd.output().await?;
108
109        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
110        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
111
112        Ok(json!({
113            "success": output.status.success(),
114            "exit_code": output.status.code(),
115            "stdout": if stdout.len() > 5000 { format!("{}...[truncated]", &stdout[..5000]) } else { stdout },
116            "stderr": if stderr.len() > 5000 { format!("{}...[truncated]", &stderr[..5000]) } else { stderr }
117        }))
118    }
119}