Skip to main content

matrixcode_core/mcp/
config.rs

1//! MCP Configuration
2//!
3//! 管理 MCP 服务器配置
4
5use anyhow::{anyhow, Result};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::Path;
9
10use super::transport::TransportConfig;
11
12// ============================================================================
13// Configuration Types
14// ============================================================================
15
16/// Configuration for a single MCP server.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct McpServerConfig {
19    /// Server name (optional; `get_name` falls back to the configuration key when unset).
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub name: Option<String>,
22    
23    /// Launch command (stdio mode).
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub command: Option<String>,
26    
27    /// Command arguments.
28    #[serde(default)]
29    pub args: Vec<String>,
30    
31    /// Environment variables.
32    #[serde(default)]
33    pub env: HashMap<String, String>,
34    
35    /// SSE URL (HTTP mode).
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub url: Option<String>,
38    
39    /// Request timeout in milliseconds.
40    #[serde(default = "default_timeout")]
41    pub timeout_ms: u64,
42    
43    /// Whether the server is enabled.
44    #[serde(default = "default_enabled")]
45    pub enabled: bool,
46}
47
48impl Default for McpServerConfig {
49    fn default() -> Self {
50        Self {
51            name: None,
52            command: None,
53            args: Vec::new(),
54            env: HashMap::new(),
55            url: None,
56            timeout_ms: default_timeout(),
57            enabled: default_enabled(),
58        }
59    }
60}
61
/// Default request timeout in milliseconds (30 seconds), used when
/// `timeout_ms` is missing from a server configuration.
fn default_timeout() -> u64 {
    30_000
}
65
/// Serde default for `McpServerConfig::enabled`: a server is enabled unless
/// the configuration says otherwise.
66fn default_enabled() -> bool {
67    true
68}
69
70impl McpServerConfig {
71    /// 创建 stdio 配置
72    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
73        Self {
74            name: None,
75            command: Some(command.into()),
76            args,
77            env: HashMap::new(),
78            url: None,
79            timeout_ms: default_timeout(),
80            enabled: true,
81        }
82    }
83    
84    /// 创建 SSE 配置
85    pub fn sse(url: impl Into<String>) -> Self {
86        Self {
87            name: None,
88            command: None,
89            args: Vec::new(),
90            env: HashMap::new(),
91            url: Some(url.into()),
92            timeout_ms: default_timeout(),
93            enabled: true,
94        }
95    }
96    
97    /// 设置名称
98    pub fn with_name(mut self, name: impl Into<String>) -> Self {
99        self.name = Some(name.into());
100        self
101    }
102    
103    /// 设置环境变量
104    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
105        self.env.insert(key.into(), value.into());
106        self
107    }
108    
109    /// 设置超时
110    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
111        self.timeout_ms = timeout_ms;
112        self
113    }
114    
115    /// 设置启用状态
116    pub fn with_enabled(mut self, enabled: bool) -> Self {
117        self.enabled = enabled;
118        self
119    }
120    
121    /// 转换为传输配置
122    pub fn to_transport_config(&self) -> Result<TransportConfig> {
123        if let Some(command) = &self.command {
124            // Stdio 模式
125            let env_vec: Vec<(String, String)> = self.env.iter()
126                .map(|(k, v)| (k.clone(), v.clone()))
127                .collect();
128            
129            Ok(TransportConfig::Stdio {
130                command: command.clone(),
131                args: self.args.clone(),
132                env: if env_vec.is_empty() { None } else { Some(env_vec) },
133            })
134        } else if let Some(url) = &self.url {
135            // SSE 模式
136            Ok(TransportConfig::Sse {
137                url: url.clone(),
138                timeout_ms: Some(self.timeout_ms),
139            })
140        } else {
141            Err(anyhow!("MCP server config must have either 'command' or 'url'"))
142        }
143    }
144    
145    /// 获取服务器名称
146    pub fn get_name(&self, key: &str) -> String {
147        self.name.clone().unwrap_or_else(|| key.to_string())
148    }
149}
150
151/// Contents of an MCP configuration file.
152#[derive(Debug, Clone, Serialize, Deserialize, Default)]
153pub struct McpConfig {
154    /// MCP server configurations, keyed by server key.
155    #[serde(default)]
156    pub servers: HashMap<String, McpServerConfig>,
157    
158    /// Global settings.
159    #[serde(default)]
160    pub settings: McpSettings,
161}
162
163/// Global MCP settings.
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct McpSettings {
166    /// Whether to auto-discover MCP servers.
167    #[serde(default = "default_auto_discover")]
168    pub auto_discover: bool,
169    
170    /// Connection timeout in milliseconds.
171    #[serde(default = "default_connect_timeout")]
172    pub connect_timeout_ms: u64,
173}
174
/// Serde default for `McpSettings::auto_discover`: auto-discovery is on
/// unless the configuration says otherwise.
175fn default_auto_discover() -> bool {
176    true
177}
178
/// Default connection timeout in milliseconds (10 seconds), used when
/// `connect_timeout_ms` is missing from the settings.
fn default_connect_timeout() -> u64 {
    10_000
}
182
183impl Default for McpSettings {
184    fn default() -> Self {
185        Self {
186            auto_discover: default_auto_discover(),
187            connect_timeout_ms: default_connect_timeout(),
188        }
189    }
190}
191
192impl McpConfig {
193    /// 创建空配置
194    pub fn new() -> Self {
195        Self::default()
196    }
197    
198    /// 添加服务器配置
199    pub fn add_server(mut self, key: impl Into<String>, config: McpServerConfig) -> Self {
200        self.servers.insert(key.into(), config);
201        self
202    }
203    
204    /// 从文件加载配置
205    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
206        let content = std::fs::read_to_string(path.as_ref())
207            .map_err(|e| anyhow!("Failed to read MCP config file: {}", e))?;
208        
209        Self::from_str(&content)
210    }
211    
212    /// 从字符串解析配置
213    pub fn from_str(content: &str) -> Result<Self> {
214        // 尝试解析为 TOML
215        if let Ok(config) = toml::from_str(content) {
216            return Ok(config);
217        }
218        
219        // 尝试解析为 JSON
220        if let Ok(config) = serde_json::from_str(content) {
221            return Ok(config);
222        }
223        
224        Err(anyhow!("Failed to parse MCP config as TOML or JSON"))
225    }
226    
227    /// 保存到文件
228    pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
229        let content = toml::to_string_pretty(self)
230            .map_err(|e| anyhow!("Failed to serialize MCP config: {}", e))?;
231        
232        std::fs::write(path.as_ref(), content)
233            .map_err(|e| anyhow!("Failed to write MCP config file: {}", e))?;
234        
235        Ok(())
236    }
237    
238    /// 获取所有启用的服务器配置
239    pub fn enabled_servers(&self) -> Vec<(String, &McpServerConfig)> {
240        self.servers
241            .iter()
242            .filter(|(_, config)| config.enabled)
243            .map(|(key, config)| (key.clone(), config))
244            .collect()
245    }
246    
247    /// 合并两个配置(other 覆盖 self)
248    pub fn merge(mut self, other: McpConfig) -> Self {
249        // 合并服务器配置(other 覆盖同名服务器)
250        for (key, config) in other.servers {
251            self.servers.insert(key, config);
252        }
253        
254        // 合并设置(other 的非默认值覆盖)
255        if !other.settings.auto_discover {
256            self.settings.auto_discover = false;
257        }
258        if other.settings.connect_timeout_ms != default_connect_timeout() {
259            self.settings.connect_timeout_ms = other.settings.connect_timeout_ms;
260        }
261        
262        self
263    }
264}
265
266// ============================================================================
267// Default Configurations
268// ============================================================================
269
270/// 创建 Playwright MCP 配置
271pub fn playwright_config() -> McpConfig {
272    McpConfig::new()
273        .add_server("playwright", McpServerConfig::stdio(
274            "npx",
275            vec!["-y".into(), "@playwright/mcp@latest".into()]
276        ))
277}
278
279/// 创建常用 MCP 配置
280pub fn default_mcp_config() -> McpConfig {
281    McpConfig::new()
282        // Playwright 浏览器自动化
283        .add_server("playwright", McpServerConfig::stdio(
284            "npx",
285            vec!["-y".into(), "@playwright/mcp@latest".into()]
286        ))
287        // 文件系统(可选)
288        // .add_server("filesystem", McpServerConfig::stdio(
289        //     "npx",
290        //     vec!["-y".into(), "@modelcontextprotocol/server-filesystem".into()]
291        // ))
292}
293
294// ============================================================================
295// Config File Discovery
296// ============================================================================
297
298/// File names recognized as MCP configuration files, in lookup order.
299pub const MCP_CONFIG_FILENAMES: &[&str] = &[
300    "mcp.toml",
301    "mcp.json",
302    ".mcp.toml",
303    ".mcp.json",
304];
305
306/// 查找工作目录中的 MCP 配置文件
307pub fn find_mcp_config(start_dir: &Path) -> Option<std::path::PathBuf> {
308    // 1. 项目级配置(优先)
309    for filename in MCP_CONFIG_FILENAMES {
310        let path = start_dir.join(filename);
311        if path.exists() {
312            return Some(path);
313        }
314    }
315    
316    // 2. 用户级配置目录 (~/.matrixcode/)
317    if let Some(home) = dirs::home_dir() {
318        let matrixcode_dir = home.join(".matrixcode");
319        for filename in MCP_CONFIG_FILENAMES {
320            let path = matrixcode_dir.join(filename);
321            if path.exists() {
322                return Some(path);
323            }
324        }
325        
326        // 3. 用户主目录 (~/.mcp.toml)
327        for filename in MCP_CONFIG_FILENAMES {
328            let path = home.join(filename);
329            if path.exists() {
330                return Some(path);
331            }
332        }
333    }
334    
335    None
336}
337
338/// 加载 MCP 配置(合并项目级和用户级)
339pub fn load_mcp_config(start_dir: &Path) -> McpConfig {
340    let mut config = McpConfig::new();
341    
342    // 1. 加载用户级配置(基础)
343    if let Some(home) = dirs::home_dir() {
344        // ~/.matrixcode/ 目录
345        let matrixcode_dir = home.join(".matrixcode");
346        for filename in MCP_CONFIG_FILENAMES {
347            let path = matrixcode_dir.join(filename);
348            if path.exists() {
349                if let Ok(user_config) = McpConfig::from_file(&path) {
350                    tracing::info!("Loaded user-level MCP config from {:?}", path);
351                    config = config.merge(user_config);
352                    break;
353                }
354            }
355        }
356        
357        // ~/.mcp.toml(备选)
358        if config.servers.is_empty() {
359            for filename in MCP_CONFIG_FILENAMES {
360                let path = home.join(filename);
361                if path.exists() {
362                    if let Ok(user_config) = McpConfig::from_file(&path) {
363                        tracing::info!("Loaded user MCP config from {:?}", path);
364                        config = config.merge(user_config);
365                        break;
366                    }
367                }
368            }
369        }
370    }
371    
372    // 2. 加载项目级配置(覆盖)
373    for filename in MCP_CONFIG_FILENAMES {
374        let path = start_dir.join(filename);
375        if path.exists() {
376            if let Ok(project_config) = McpConfig::from_file(&path) {
377                tracing::info!("Loaded project-level MCP config from {:?}", path);
378                config = config.merge(project_config);
379                break;
380            }
381        }
382    }
383    
384    config
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// A stdio configuration carries a command, no URL, is enabled, and maps
    /// to the stdio transport with the command and arguments intact.
    #[test]
    fn test_server_config_stdio() {
        let server = McpServerConfig::stdio("npx", vec!["-y".into(), "@playwright/mcp".into()]);

        assert!(server.command.is_some());
        assert!(server.url.is_none());
        assert!(server.enabled);

        if let TransportConfig::Stdio { command, args, .. } = server.to_transport_config().unwrap()
        {
            assert_eq!(command, "npx");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio transport");
        }
    }

    /// An SSE configuration carries a URL, no command, and maps to the SSE
    /// transport with the URL intact.
    #[test]
    fn test_server_config_sse() {
        let server = McpServerConfig::sse("http://localhost:3000");

        assert!(server.command.is_none());
        assert!(server.url.is_some());

        if let TransportConfig::Sse { url, .. } = server.to_transport_config().unwrap() {
            assert_eq!(url, "http://localhost:3000");
        } else {
            panic!("Expected SSE transport");
        }
    }

    /// A configuration survives a TOML serialize/deserialize round trip.
    #[test]
    fn test_config_serialization() {
        let playwright =
            McpServerConfig::stdio("npx", vec!["-y".into(), "@playwright/mcp".into()]);
        let original = McpConfig::new().add_server("playwright", playwright);

        // Serialize to TOML.
        let rendered = toml::to_string(&original).unwrap();
        assert!(rendered.contains("[servers.playwright]"));

        // Deserialize it back.
        let reparsed: McpConfig = toml::from_str(&rendered).unwrap();
        assert!(reparsed.servers.contains_key("playwright"));
    }

    /// The Playwright preset registers a `playwright` server run via `npx`.
    #[test]
    fn test_playwright_config() {
        let preset = playwright_config();
        assert!(preset.servers.contains_key("playwright"));

        let playwright = &preset.servers["playwright"];
        assert_eq!(playwright.command, Some("npx".to_string()));
    }
}
450}