Skip to main content

matrixcode_core/mcp/
config.rs

1//! MCP Configuration
2//!
3//! 管理 MCP 服务器配置
4
5use anyhow::{Result, anyhow};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::Path;
9
10use super::transport::TransportConfig;
11
12// ============================================================================
13// Configuration Types
14// ============================================================================
15
16/// MCP 服务器配置
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct McpServerConfig {
19    /// 服务器名称(可选,默认使用配置键名)
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub name: Option<String>,
22
23    /// 启动命令(stdio 模式)
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub command: Option<String>,
26
27    /// 命令参数
28    #[serde(default)]
29    pub args: Vec<String>,
30
31    /// 环境变量
32    #[serde(default)]
33    pub env: HashMap<String, String>,
34
35    /// SSE URL(HTTP 模式)
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub url: Option<String>,
38
39    /// 请求超时(毫秒)
40    #[serde(default = "default_timeout")]
41    pub timeout_ms: u64,
42
43    /// 是否启用
44    #[serde(default = "default_enabled")]
45    pub enabled: bool,
46}
47
48impl Default for McpServerConfig {
49    fn default() -> Self {
50        Self {
51            name: None,
52            command: None,
53            args: Vec::new(),
54            env: HashMap::new(),
55            url: None,
56            timeout_ms: default_timeout(),
57            enabled: default_enabled(),
58        }
59    }
60}
61
/// Default per-request timeout for an MCP server, in milliseconds (30 seconds).
fn default_timeout() -> u64 {
    const THIRTY_SECONDS_MS: u64 = 30 * 1_000;
    THIRTY_SECONDS_MS
}
65
/// Servers are enabled unless a config explicitly sets `enabled = false`.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
69
70impl McpServerConfig {
71    /// 创建 stdio 配置
72    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
73        Self {
74            name: None,
75            command: Some(command.into()),
76            args,
77            env: HashMap::new(),
78            url: None,
79            timeout_ms: default_timeout(),
80            enabled: true,
81        }
82    }
83
84    /// 创建 SSE 配置
85    pub fn sse(url: impl Into<String>) -> Self {
86        Self {
87            name: None,
88            command: None,
89            args: Vec::new(),
90            env: HashMap::new(),
91            url: Some(url.into()),
92            timeout_ms: default_timeout(),
93            enabled: true,
94        }
95    }
96
97    /// 设置名称
98    pub fn with_name(mut self, name: impl Into<String>) -> Self {
99        self.name = Some(name.into());
100        self
101    }
102
103    /// 设置环境变量
104    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
105        self.env.insert(key.into(), value.into());
106        self
107    }
108
109    /// 设置超时
110    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
111        self.timeout_ms = timeout_ms;
112        self
113    }
114
115    /// 设置启用状态
116    pub fn with_enabled(mut self, enabled: bool) -> Self {
117        self.enabled = enabled;
118        self
119    }
120
121    /// 转换为传输配置
122    pub fn to_transport_config(&self) -> Result<TransportConfig> {
123        if let Some(command) = &self.command {
124            // Stdio 模式
125            let env_vec: Vec<(String, String)> = self
126                .env
127                .iter()
128                .map(|(k, v)| (k.clone(), v.clone()))
129                .collect();
130
131            Ok(TransportConfig::Stdio {
132                command: command.clone(),
133                args: self.args.clone(),
134                env: if env_vec.is_empty() {
135                    None
136                } else {
137                    Some(env_vec)
138                },
139            })
140        } else if let Some(url) = &self.url {
141            // SSE 模式
142            Ok(TransportConfig::Sse {
143                url: url.clone(),
144                timeout_ms: Some(self.timeout_ms),
145            })
146        } else {
147            Err(anyhow!(
148                "MCP server config must have either 'command' or 'url'"
149            ))
150        }
151    }
152
153    /// 获取服务器名称
154    pub fn get_name(&self, key: &str) -> String {
155        self.name.clone().unwrap_or_else(|| key.to_string())
156    }
157}
158
159/// MCP 配置文件
160#[derive(Debug, Clone, Serialize, Deserialize, Default)]
161pub struct McpConfig {
162    /// MCP 服务器配置
163    #[serde(default)]
164    pub servers: HashMap<String, McpServerConfig>,
165
166    /// 全局设置
167    #[serde(default)]
168    pub settings: McpSettings,
169}
170
171/// MCP 全局设置
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct McpSettings {
174    /// 自动发现 MCP 服务器
175    #[serde(default = "default_auto_discover")]
176    pub auto_discover: bool,
177
178    /// 连接超时(毫秒)
179    #[serde(default = "default_connect_timeout")]
180    pub connect_timeout_ms: u64,
181}
182
/// Auto-discovery is on unless a config explicitly disables it.
fn default_auto_discover() -> bool {
    const AUTO_DISCOVER_BY_DEFAULT: bool = true;
    AUTO_DISCOVER_BY_DEFAULT
}
186
/// Default connection timeout, in milliseconds (10 seconds).
fn default_connect_timeout() -> u64 {
    const TEN_SECONDS_MS: u64 = 10 * 1_000;
    TEN_SECONDS_MS
}
190
191impl Default for McpSettings {
192    fn default() -> Self {
193        Self {
194            auto_discover: default_auto_discover(),
195            connect_timeout_ms: default_connect_timeout(),
196        }
197    }
198}
199
200impl McpConfig {
201    /// 创建空配置
202    pub fn new() -> Self {
203        Self::default()
204    }
205
206    /// 添加服务器配置
207    pub fn add_server(mut self, key: impl Into<String>, config: McpServerConfig) -> Self {
208        self.servers.insert(key.into(), config);
209        self
210    }
211
212    /// 从文件加载配置
213    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
214        let content = std::fs::read_to_string(path.as_ref())
215            .map_err(|e| anyhow!("Failed to read MCP config file: {}", e))?;
216
217        Self::from_str(&content)
218    }
219
220    /// 从字符串解析配置
221    pub fn from_str(content: &str) -> Result<Self> {
222        // 尝试解析为 TOML
223        if let Ok(config) = toml::from_str(content) {
224            return Ok(config);
225        }
226
227        // 尝试解析为 JSON
228        if let Ok(config) = serde_json::from_str(content) {
229            return Ok(config);
230        }
231
232        Err(anyhow!("Failed to parse MCP config as TOML or JSON"))
233    }
234
235    /// 保存到文件
236    pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
237        let content = toml::to_string_pretty(self)
238            .map_err(|e| anyhow!("Failed to serialize MCP config: {}", e))?;
239
240        std::fs::write(path.as_ref(), content)
241            .map_err(|e| anyhow!("Failed to write MCP config file: {}", e))?;
242
243        Ok(())
244    }
245
246    /// 获取所有启用的服务器配置
247    pub fn enabled_servers(&self) -> Vec<(String, &McpServerConfig)> {
248        self.servers
249            .iter()
250            .filter(|(_, config)| config.enabled)
251            .map(|(key, config)| (key.clone(), config))
252            .collect()
253    }
254
255    /// 合并两个配置(other 覆盖 self)
256    pub fn merge(mut self, other: McpConfig) -> Self {
257        // 合并服务器配置(other 覆盖同名服务器)
258        for (key, config) in other.servers {
259            self.servers.insert(key, config);
260        }
261
262        // 合并设置(other 的非默认值覆盖)
263        if !other.settings.auto_discover {
264            self.settings.auto_discover = false;
265        }
266        if other.settings.connect_timeout_ms != default_connect_timeout() {
267            self.settings.connect_timeout_ms = other.settings.connect_timeout_ms;
268        }
269
270        self
271    }
272}
273
274// ============================================================================
275// Default Configurations
276// ============================================================================
277
278/// 创建 Playwright MCP 配置
279pub fn playwright_config() -> McpConfig {
280    McpConfig::new().add_server(
281        "playwright",
282        McpServerConfig::stdio("npx", vec!["-y".into(), "@playwright/mcp@latest".into()]),
283    )
284}
285
286/// 创建常用 MCP 配置
287pub fn default_mcp_config() -> McpConfig {
288    McpConfig::new()
289        // Playwright 浏览器自动化
290        .add_server(
291            "playwright",
292            McpServerConfig::stdio("npx", vec!["-y".into(), "@playwright/mcp@latest".into()]),
293        )
294    // 文件系统(可选)
295    // .add_server("filesystem", McpServerConfig::stdio(
296    //     "npx",
297    //     vec!["-y".into(), "@modelcontextprotocol/server-filesystem".into()]
298    // ))
299}
300
301// ============================================================================
302// Config File Discovery
303// ============================================================================
304
/// Candidate MCP config file names, in lookup priority order: visible TOML, visible JSON,
/// then the hidden (dot-prefixed) variant of each.
pub const MCP_CONFIG_FILENAMES: &[&str] = &["mcp.toml", "mcp.json", ".mcp.toml", ".mcp.json"];
307
308/// 查找工作目录中的 MCP 配置文件
309pub fn find_mcp_config(start_dir: &Path) -> Option<std::path::PathBuf> {
310    // 1. 项目级配置(优先)
311    for filename in MCP_CONFIG_FILENAMES {
312        let path = start_dir.join(filename);
313        if path.exists() {
314            return Some(path);
315        }
316    }
317
318    // 2. 用户级配置目录 (~/.matrixcode/)
319    if let Some(home) = dirs::home_dir() {
320        let matrixcode_dir = home.join(".matrixcode");
321        for filename in MCP_CONFIG_FILENAMES {
322            let path = matrixcode_dir.join(filename);
323            if path.exists() {
324                return Some(path);
325            }
326        }
327
328        // 3. 用户主目录 (~/.mcp.toml)
329        for filename in MCP_CONFIG_FILENAMES {
330            let path = home.join(filename);
331            if path.exists() {
332                return Some(path);
333            }
334        }
335    }
336
337    None
338}
339
340/// 加载 MCP 配置(合并项目级和用户级)
341pub fn load_mcp_config(start_dir: &Path) -> McpConfig {
342    let mut config = McpConfig::new();
343
344    // 1. 加载用户级配置(基础)
345    if let Some(home) = dirs::home_dir() {
346        // ~/.matrixcode/ 目录
347        let matrixcode_dir = home.join(".matrixcode");
348        for filename in MCP_CONFIG_FILENAMES {
349            let path = matrixcode_dir.join(filename);
350            if path.exists() {
351                if let Ok(user_config) = McpConfig::from_file(&path) {
352                    tracing::info!("Loaded user-level MCP config from {:?}", path);
353                    config = config.merge(user_config);
354                    break;
355                }
356            }
357        }
358
359        // ~/.mcp.toml(备选)
360        if config.servers.is_empty() {
361            for filename in MCP_CONFIG_FILENAMES {
362                let path = home.join(filename);
363                if path.exists() {
364                    if let Ok(user_config) = McpConfig::from_file(&path) {
365                        tracing::info!("Loaded user MCP config from {:?}", path);
366                        config = config.merge(user_config);
367                        break;
368                    }
369                }
370            }
371        }
372    }
373
374    // 2. 加载项目级配置(覆盖)
375    for filename in MCP_CONFIG_FILENAMES {
376        let path = start_dir.join(filename);
377        if path.exists() {
378            if let Ok(project_config) = McpConfig::from_file(&path) {
379                tracing::info!("Loaded project-level MCP config from {:?}", path);
380                config = config.merge(project_config);
381                break;
382            }
383        }
384    }
385
386    config
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-element `npx` argument list shared by the stdio-based tests.
    fn npx_playwright_args() -> Vec<String> {
        vec!["-y".to_string(), "@playwright/mcp".to_string()]
    }

    #[test]
    fn test_server_config_stdio() {
        let server = McpServerConfig::stdio("npx", npx_playwright_args());

        assert!(server.url.is_none());
        assert!(server.command.is_some());
        assert!(server.enabled);

        if let TransportConfig::Stdio { command, args, .. } = server.to_transport_config().unwrap()
        {
            assert_eq!(command, "npx");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio transport");
        }
    }

    #[test]
    fn test_server_config_sse() {
        let endpoint = "http://localhost:3000";
        let server = McpServerConfig::sse(endpoint);

        assert!(server.url.is_some());
        assert!(server.command.is_none());

        match server.to_transport_config().unwrap() {
            TransportConfig::Sse { url, .. } => assert_eq!(url, endpoint),
            _ => panic!("Expected SSE transport"),
        }
    }

    #[test]
    fn test_config_serialization() {
        let original = McpConfig::new()
            .add_server("playwright", McpServerConfig::stdio("npx", npx_playwright_args()));

        // TOML round trip: the server table must be emitted and read back.
        let encoded = toml::to_string(&original).unwrap();
        assert!(encoded.contains("[servers.playwright]"));

        let decoded: McpConfig = toml::from_str(&encoded).unwrap();
        assert!(decoded.servers.contains_key("playwright"));
    }

    #[test]
    fn test_playwright_config() {
        let config = playwright_config();
        let server = config
            .servers
            .get("playwright")
            .expect("playwright server should be registered");
        assert_eq!(server.command.as_deref(), Some("npx"));
    }
}