1use anyhow::{anyhow, Result};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::Path;
9
10use super::transport::TransportConfig;
11
/// Configuration for a single MCP server, as stored in an `mcp.toml` / `mcp.json` file.
///
/// A server is reached either by launching `command` (stdio transport) or by
/// connecting to `url` (SSE transport); see [`McpServerConfig::to_transport_config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Optional display name; [`McpServerConfig::get_name`] falls back to the map key when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Program to launch for a stdio transport. Takes precedence over `url` when both are set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub command: Option<String>,

    /// Arguments passed along with `command`; empty when omitted.
    #[serde(default)]
    pub args: Vec<String>,

    /// Environment variables handed to the stdio transport; empty when omitted.
    #[serde(default)]
    pub env: HashMap<String, String>,

    /// Endpoint for an SSE transport; only used when `command` is unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,

    /// Timeout in milliseconds (default 30000). `to_transport_config` forwards it to the
    /// SSE transport only; the stdio transport does not receive it.
    #[serde(default = "default_timeout")]
    pub timeout_ms: u64,

    /// Whether the server is active (default `true`); disabled servers are skipped by
    /// [`McpConfig::enabled_servers`].
    #[serde(default = "default_enabled")]
    pub enabled: bool,
}
47
48impl Default for McpServerConfig {
49 fn default() -> Self {
50 Self {
51 name: None,
52 command: None,
53 args: Vec::new(),
54 env: HashMap::new(),
55 url: None,
56 timeout_ms: default_timeout(),
57 enabled: default_enabled(),
58 }
59 }
60}
61
/// Serde default for `McpServerConfig::timeout_ms`: 30 seconds, in milliseconds.
fn default_timeout() -> u64 {
    30_000
}
65
/// Serde default for `McpServerConfig::enabled`: servers are on unless disabled explicitly.
fn default_enabled() -> bool {
    let enabled_by_default = true;
    enabled_by_default
}
69
70impl McpServerConfig {
71 pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
73 Self {
74 name: None,
75 command: Some(command.into()),
76 args,
77 env: HashMap::new(),
78 url: None,
79 timeout_ms: default_timeout(),
80 enabled: true,
81 }
82 }
83
84 pub fn sse(url: impl Into<String>) -> Self {
86 Self {
87 name: None,
88 command: None,
89 args: Vec::new(),
90 env: HashMap::new(),
91 url: Some(url.into()),
92 timeout_ms: default_timeout(),
93 enabled: true,
94 }
95 }
96
97 pub fn with_name(mut self, name: impl Into<String>) -> Self {
99 self.name = Some(name.into());
100 self
101 }
102
103 pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
105 self.env.insert(key.into(), value.into());
106 self
107 }
108
109 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
111 self.timeout_ms = timeout_ms;
112 self
113 }
114
115 pub fn with_enabled(mut self, enabled: bool) -> Self {
117 self.enabled = enabled;
118 self
119 }
120
121 pub fn to_transport_config(&self) -> Result<TransportConfig> {
123 if let Some(command) = &self.command {
124 let env_vec: Vec<(String, String)> = self.env.iter()
126 .map(|(k, v)| (k.clone(), v.clone()))
127 .collect();
128
129 Ok(TransportConfig::Stdio {
130 command: command.clone(),
131 args: self.args.clone(),
132 env: if env_vec.is_empty() { None } else { Some(env_vec) },
133 })
134 } else if let Some(url) = &self.url {
135 Ok(TransportConfig::Sse {
137 url: url.clone(),
138 timeout_ms: Some(self.timeout_ms),
139 })
140 } else {
141 Err(anyhow!("MCP server config must have either 'command' or 'url'"))
142 }
143 }
144
145 pub fn get_name(&self, key: &str) -> String {
147 self.name.clone().unwrap_or_else(|| key.to_string())
148 }
149}
150
/// Top-level MCP configuration: a set of named server entries plus global settings.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpConfig {
    /// Server configs keyed by identifier (e.g. `"playwright"`). In TOML each entry is
    /// a `[servers.<key>]` table. Empty when the section is omitted.
    #[serde(default)]
    pub servers: HashMap<String, McpServerConfig>,

    /// Global settings; an omitted section falls back to `McpSettings::default()`.
    #[serde(default)]
    pub settings: McpSettings,
}
162
/// Global MCP settings (the `settings` section of the config file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpSettings {
    /// Auto-discovery flag (default `true`). It is not consumed in this module.
    /// NOTE(review): it presumably gates automatic server/tool discovery; confirm against the caller.
    #[serde(default = "default_auto_discover")]
    pub auto_discover: bool,

    /// Connection timeout in milliseconds (default 10000). It is not consumed in this module.
    /// NOTE(review): it is presumably applied when establishing server connections; confirm.
    #[serde(default = "default_connect_timeout")]
    pub connect_timeout_ms: u64,
}
174
/// Serde default for `McpSettings::auto_discover`: enabled.
fn default_auto_discover() -> bool {
    let discover = true;
    discover
}
178
/// Serde default for `McpSettings::connect_timeout_ms`: 10 seconds, in milliseconds.
fn default_connect_timeout() -> u64 {
    10_000
}
182
183impl Default for McpSettings {
184 fn default() -> Self {
185 Self {
186 auto_discover: default_auto_discover(),
187 connect_timeout_ms: default_connect_timeout(),
188 }
189 }
190}
191
192impl McpConfig {
193 pub fn new() -> Self {
195 Self::default()
196 }
197
198 pub fn add_server(mut self, key: impl Into<String>, config: McpServerConfig) -> Self {
200 self.servers.insert(key.into(), config);
201 self
202 }
203
204 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
206 let content = std::fs::read_to_string(path.as_ref())
207 .map_err(|e| anyhow!("Failed to read MCP config file: {}", e))?;
208
209 Self::from_str(&content)
210 }
211
212 pub fn from_str(content: &str) -> Result<Self> {
214 if let Ok(config) = toml::from_str(content) {
216 return Ok(config);
217 }
218
219 if let Ok(config) = serde_json::from_str(content) {
221 return Ok(config);
222 }
223
224 Err(anyhow!("Failed to parse MCP config as TOML or JSON"))
225 }
226
227 pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
229 let content = toml::to_string_pretty(self)
230 .map_err(|e| anyhow!("Failed to serialize MCP config: {}", e))?;
231
232 std::fs::write(path.as_ref(), content)
233 .map_err(|e| anyhow!("Failed to write MCP config file: {}", e))?;
234
235 Ok(())
236 }
237
238 pub fn enabled_servers(&self) -> Vec<(String, &McpServerConfig)> {
240 self.servers
241 .iter()
242 .filter(|(_, config)| config.enabled)
243 .map(|(key, config)| (key.clone(), config))
244 .collect()
245 }
246
247 pub fn merge(mut self, other: McpConfig) -> Self {
249 for (key, config) in other.servers {
251 self.servers.insert(key, config);
252 }
253
254 if !other.settings.auto_discover {
256 self.settings.auto_discover = false;
257 }
258 if other.settings.connect_timeout_ms != default_connect_timeout() {
259 self.settings.connect_timeout_ms = other.settings.connect_timeout_ms;
260 }
261
262 self
263 }
264}
265
266pub fn playwright_config() -> McpConfig {
272 McpConfig::new()
273 .add_server("playwright", McpServerConfig::stdio(
274 "npx",
275 vec!["-y".into(), "@playwright/mcp@latest".into()]
276 ))
277}
278
279pub fn default_mcp_config() -> McpConfig {
281 McpConfig::new()
282 .add_server("playwright", McpServerConfig::stdio(
284 "npx",
285 vec!["-y".into(), "@playwright/mcp@latest".into()]
286 ))
287 }
293
/// Candidate config file names, in priority order. `find_mcp_config` and
/// `load_mcp_config` try them in this order within each directory they search and
/// stop at the first usable match, so `mcp.toml` wins over `mcp.json`, and so on.
pub const MCP_CONFIG_FILENAMES: &[&str] = &[
    "mcp.toml",
    "mcp.json",
    ".mcp.toml",
    ".mcp.json",
];
305
306pub fn find_mcp_config(start_dir: &Path) -> Option<std::path::PathBuf> {
308 for filename in MCP_CONFIG_FILENAMES {
310 let path = start_dir.join(filename);
311 if path.exists() {
312 return Some(path);
313 }
314 }
315
316 if let Some(home) = dirs::home_dir() {
318 let matrixcode_dir = home.join(".matrixcode");
319 for filename in MCP_CONFIG_FILENAMES {
320 let path = matrixcode_dir.join(filename);
321 if path.exists() {
322 return Some(path);
323 }
324 }
325
326 for filename in MCP_CONFIG_FILENAMES {
328 let path = home.join(filename);
329 if path.exists() {
330 return Some(path);
331 }
332 }
333 }
334
335 None
336}
337
338pub fn load_mcp_config(start_dir: &Path) -> McpConfig {
340 let mut config = McpConfig::new();
341
342 if let Some(home) = dirs::home_dir() {
344 let matrixcode_dir = home.join(".matrixcode");
346 for filename in MCP_CONFIG_FILENAMES {
347 let path = matrixcode_dir.join(filename);
348 if path.exists() {
349 if let Ok(user_config) = McpConfig::from_file(&path) {
350 tracing::info!("Loaded user-level MCP config from {:?}", path);
351 config = config.merge(user_config);
352 break;
353 }
354 }
355 }
356
357 if config.servers.is_empty() {
359 for filename in MCP_CONFIG_FILENAMES {
360 let path = home.join(filename);
361 if path.exists() {
362 if let Ok(user_config) = McpConfig::from_file(&path) {
363 tracing::info!("Loaded user MCP config from {:?}", path);
364 config = config.merge(user_config);
365 break;
366 }
367 }
368 }
369 }
370 }
371
372 for filename in MCP_CONFIG_FILENAMES {
374 let path = start_dir.join(filename);
375 if path.exists() {
376 if let Ok(project_config) = McpConfig::from_file(&path) {
377 tracing::info!("Loaded project-level MCP config from {:?}", path);
378 config = config.merge(project_config);
379 break;
380 }
381 }
382 }
383
384 config
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// A stdio config has a command, no URL, is enabled, and maps to a stdio transport.
    #[test]
    fn test_server_config_stdio() {
        let cfg = McpServerConfig::stdio("npx", vec!["-y".into(), "@playwright/mcp".into()]);

        assert!(cfg.command.is_some());
        assert!(cfg.url.is_none());
        assert!(cfg.enabled);

        if let TransportConfig::Stdio { command, args, .. } = cfg.to_transport_config().unwrap() {
            assert_eq!(command, "npx");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio transport");
        }
    }

    /// An SSE config has a URL, no command, and maps to an SSE transport with that URL.
    #[test]
    fn test_server_config_sse() {
        let cfg = McpServerConfig::sse("http://localhost:3000");

        assert!(cfg.command.is_none());
        assert!(cfg.url.is_some());

        if let TransportConfig::Sse { url, .. } = cfg.to_transport_config().unwrap() {
            assert_eq!(url, "http://localhost:3000");
        } else {
            panic!("Expected SSE transport");
        }
    }

    /// A config survives a TOML round trip with its server table intact.
    #[test]
    fn test_config_serialization() {
        let server = McpServerConfig::stdio("npx", vec!["-y".into(), "@playwright/mcp".into()]);
        let original = McpConfig::new().add_server("playwright", server);

        let rendered = toml::to_string(&original).unwrap();
        assert!(rendered.contains("[servers.playwright]"));

        let reparsed: McpConfig = toml::from_str(&rendered).unwrap();
        assert!(reparsed.servers.contains_key("playwright"));
    }

    /// The bundled Playwright preset registers a server launched with `npx`.
    #[test]
    fn test_playwright_config() {
        let preset = playwright_config();
        assert!(preset.servers.contains_key("playwright"));

        let server = &preset.servers["playwright"];
        assert_eq!(server.command.as_deref(), Some("npx"));
    }
}