Skip to main content

atomcode_core/mcp/
config.rs

1//! MCP configuration loading.
2
3use std::collections::BTreeMap;
4use std::path::Path;
5
6use anyhow::{bail, Context, Result};
7use serde_json::{json, Map, Value};
8
/// MCP server transport configuration.
///
/// Built by `server_entry_to_config`: an entry with a `command` becomes `Stdio`; otherwise an
/// entry with a `url` becomes `Http`.
#[derive(Debug, Clone)]
pub enum McpTransportConfig {
    /// Server started from a local command (stdio transport).
    Stdio {
        /// Program to run; `${VAR}` references and a leading `~` are expanded at load time.
        command: String,
        /// Program arguments; `${VAR}` references and a leading `~` are expanded in each one.
        args: Vec<String>,
        /// Extra environment variables; values (not keys) have `${VAR}` references expanded.
        env: BTreeMap<String, String>,
        /// Optional timeout copied from the entry's `timeout_ms` (presumably milliseconds;
        /// how it is enforced is up to the client code — NOTE(review): confirm).
        timeout_ms: Option<u64>,
    },
    /// Server reached over HTTP at `url`.
    Http {
        /// Endpoint URL; `${VAR}` references (and a leading `~`) are expanded at load time.
        url: String,
        /// Extra HTTP headers: the entry's `headers` (values `${VAR}`-expanded) plus any static
        /// header derived from `auth`; explicitly configured `headers` win on name conflicts.
        headers: BTreeMap<String, String>,
        /// OAuth settings, present when the entry's `auth.type` is `"oauth"`.
        auth: Option<McpHttpAuthConfig>,
        /// Optional timeout copied from the entry's `timeout_ms` (same caveat as for `Stdio`).
        timeout_ms: Option<u64>,
    },
}
25
/// Authentication configuration for HTTP MCP servers.
///
/// Only OAuth is represented here. Static header credentials from a config entry's `auth`
/// object are merged into the transport's `headers` at load time instead.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpHttpAuthConfig {
    /// OAuth-based authentication.
    OAuth(McpOAuthConfig),
}

/// OAuth configuration for HTTP MCP servers.
///
/// When loaded from a config file, `issuer`, `resource`, `client_id` and each scope have
/// `${VAR}` references expanded; `provider` and `client_secret_env` are stored as written.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct McpOAuthConfig {
    /// Human-readable/provider compatibility name. Older configs only set this.
    /// The config loader always fills it in, defaulting to the server name.
    pub provider: Option<String>,
    /// Optional authorization server issuer URL.
    pub issuer: Option<String>,
    /// Optional protected resource metadata URL or fixed resource identifier.
    pub resource: Option<String>,
    /// Optional pre-registered OAuth client id.
    pub client_id: Option<String>,
    /// Optional name of the environment variable containing a confidential client secret
    /// (this struct holds only the variable name, not the secret).
    pub client_secret_env: Option<String>,
    /// Optional requested scopes (empty when none are configured).
    pub scopes: Vec<String>,
}
48
/// MCP server configuration.
///
/// One resolved server definition as produced by the loaders in this module.
#[derive(Debug, Clone)]
pub struct McpServerConfig {
    /// Server name: the key of the entry under `mcpServers` (or the legacy `servers`).
    pub name: String,
    /// Copied from the entry's `disabled` flag; `load_mcp_config` drops servers where it is `true`.
    pub disabled: bool,
    /// Transport-specific settings (stdio command or HTTP endpoint).
    pub config: McpTransportConfig,
    /// Where this server config was loaded from (user-level or project-level).
    pub source: McpConfigSource,
}
58
/// Which config file a server definition was read from.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum McpConfigSource {
    /// The project's `.mcp.json`.
    Project,
    /// The user-level `mcp.json` under `Config::config_dir()`.
    User,
}

impl McpConfigSource {
    /// Label used for this source in telemetry JSON.
    ///
    /// User-level configs are reported as `"global"`, not `"user"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Project => "project",
            Self::User => "global",
        }
    }
}
75
/// Raw MCP config file format (for deserialization).
#[derive(Debug, Deserialize)]
struct McpConfigFile {
    /// JSON key `mcpServers` (consistent with Cursor and similar tools); `servers` is still
    /// accepted as an alias so older configs can be read.
    ///
    /// NOTE(review): with `rename` + `alias`, a file containing *both* keys is likely rejected
    /// by serde as a duplicate field — confirm if such mixed files must load.
    #[serde(default, rename = "mcpServers", alias = "servers")]
    mcp_servers: BTreeMap<String, McpServerEntry>,
}

/// One raw server entry, before `${VAR}`/`~` expansion and transport selection.
#[derive(Debug, Deserialize)]
struct McpServerEntry {
    /// Ignored for transport selection (stdio vs HTTP is inferred from `command` vs `url`).
    /// Accepted so configs copied from Claude / Cursor validate.
    #[serde(default, rename = "type")]
    _transport_hint: Option<String>,
    /// When `true`, `load_mcp_config` filters the server out.
    #[serde(default)]
    disabled: bool,
    /// Program for a stdio server; takes precedence over `url` when both are set.
    #[serde(default)]
    command: Option<String>,
    /// Arguments for `command`.
    #[serde(default)]
    args: Option<Vec<String>>,
    /// Extra environment variables for `command` (values may contain `${VAR}`).
    #[serde(default)]
    env: Option<BTreeMap<String, String>>,
    /// Endpoint of an HTTP server; only used when `command` is absent.
    #[serde(default)]
    url: Option<String>,
    /// Extra HTTP headers (values may contain `${VAR}`).
    #[serde(default)]
    headers: Option<BTreeMap<String, String>>,
    /// Authentication settings; only consulted for HTTP servers.
    #[serde(default)]
    auth: Option<McpAuthEntry>,
    /// Optional timeout passed through unchanged to the transport config.
    #[serde(default)]
    timeout_ms: Option<u64>,
}
107
108use serde::Deserialize;
109
/// Raw `auth` object of a server entry.
#[derive(Debug, Deserialize)]
struct McpAuthEntry {
    /// Auth scheme (`"type"` in JSON). `"oauth"` enables OAuth, any other value is rejected by
    /// `parse_http_auth`, and omitting it means "no OAuth".
    #[serde(default, rename = "type")]
    kind: Option<String>,
    /// OAuth provider name; defaults to the server name when `type` is `"oauth"`.
    #[serde(default)]
    provider: Option<String>,
    /// OAuth authorization server issuer URL (`${VAR}` expanded).
    #[serde(default)]
    issuer: Option<String>,
    /// OAuth resource identifier / metadata URL (`${VAR}` expanded).
    #[serde(default)]
    resource: Option<String>,
    /// Pre-registered OAuth client id (`${VAR}` expanded).
    #[serde(default)]
    client_id: Option<String>,
    /// Name of the environment variable holding the client secret (stored as written).
    #[serde(default)]
    client_secret_env: Option<String>,
    /// Requested OAuth scopes (each `${VAR}` expanded).
    #[serde(default)]
    scopes: Vec<String>,
    /// Static credential value (`${VAR}` expanded); `parse_http_auth` decides which header
    /// carries it.
    #[serde(default)]
    bearer: Option<String>,
    /// Header name used to send `bearer`.
    #[serde(default)]
    header: Option<String>,
}

/// Result of `parse_http_auth`: OAuth settings (if any) plus static headers to merge into
/// the server's `headers`.
#[derive(Debug, Clone)]
struct ParsedHttpAuth {
    /// `Some` when `auth.type` was `"oauth"`.
    oauth: Option<McpHttpAuthConfig>,
    /// Headers derived from `auth`; explicitly configured entry `headers` take precedence.
    headers: BTreeMap<String, String>,
}
137
138/// Load and merge MCP configurations from project and user levels.
139///
140/// Project config (`.mcp.json` in project root) overrides user config
141/// (`ATOMCODE_HOME/mcp.json`) for servers with the same name.
142pub fn load_mcp_config(project_dir: &Path) -> Result<Vec<McpServerConfig>> {
143    let user_config = load_config_file(
144        &crate::config::Config::config_dir().join("mcp.json"),
145        McpConfigSource::User,
146    )
147    .unwrap_or_default();
148
149    let project_config = load_config_file(&project_dir.join(".mcp.json"), McpConfigSource::Project)
150        .unwrap_or_default();
151
152    // Merge: project overrides user
153    let mut merged: BTreeMap<String, McpServerConfig> = BTreeMap::new();
154
155    for config in user_config {
156        merged.insert(config.name.clone(), config);
157    }
158
159    for config in project_config {
160        merged.insert(config.name.clone(), config);
161    }
162
163    Ok(merged.into_values().filter(|c| !c.disabled).collect())
164}
165
166fn load_config_file(path: &Path, source: McpConfigSource) -> Result<Vec<McpServerConfig>> {
167    if !path.exists() {
168        return Ok(Vec::new());
169    }
170
171    let content = std::fs::read_to_string(path)
172        .with_context(|| format!("Failed to read MCP config from {}", path.display()))?;
173
174    let raw: McpConfigFile = serde_json::from_str(&content)
175        .with_context(|| format!("Failed to parse MCP config from {}", path.display()))?;
176
177    let mut configs = Vec::new();
178
179    for (name, entry) in raw.mcp_servers {
180        let mut config = server_entry_to_config(&name, entry)?;
181        config.source = source;
182        configs.push(config);
183    }
184
185    Ok(configs)
186}
187
188fn server_entry_to_config(name: &str, entry: McpServerEntry) -> Result<McpServerConfig> {
189    let transport = if let Some(command) = entry.command {
190        McpTransportConfig::Stdio {
191            command: expand_tilde(&expand_env_vars(&command)),
192            args: entry
193                .args
194                .unwrap_or_default()
195                .into_iter()
196                .map(|a| expand_tilde(&expand_env_vars(&a)))
197                .collect(),
198            env: entry
199                .env
200                .unwrap_or_default()
201                .into_iter()
202                .map(|(k, v)| (k, expand_env_vars(&v)))
203                .collect(),
204            timeout_ms: entry.timeout_ms,
205        }
206    } else if let Some(url) = entry.url {
207        let parsed_auth = parse_http_auth(name, entry.auth)?;
208        let mut headers: BTreeMap<String, String> = entry
209            .headers
210            .unwrap_or_default()
211            .into_iter()
212            .map(|(k, v)| (k, expand_env_vars(&v)))
213            .collect();
214        for (k, v) in parsed_auth.headers {
215            headers.entry(k).or_insert(v);
216        }
217        McpTransportConfig::Http {
218            url: expand_tilde(&expand_env_vars(&url)),
219            headers,
220            auth: parsed_auth.oauth,
221            timeout_ms: entry.timeout_ms,
222        }
223    } else {
224        bail!(
225            "MCP server '{}' must have either 'command' (stdio) or 'url' (http)",
226            name
227        );
228    };
229
230    Ok(McpServerConfig {
231        name: name.to_string(),
232        disabled: entry.disabled,
233        config: transport,
234        source: McpConfigSource::Project, // default; overwritten by load_config_file
235    })
236}
237
238fn parse_http_auth(name: &str, auth: Option<McpAuthEntry>) -> Result<ParsedHttpAuth> {
239    let mut parsed = ParsedHttpAuth {
240        oauth: None,
241        headers: BTreeMap::new(),
242    };
243    let Some(auth) = auth else {
244        return Ok(parsed);
245    };
246
247    if let (Some(header), Some(bearer)) = (auth.header, auth.bearer) {
248        parsed.headers.insert(header, expand_env_vars(&bearer));
249    }
250
251    match auth.kind.as_deref() {
252        Some("oauth") => {
253            parsed.oauth = Some(McpHttpAuthConfig::OAuth(McpOAuthConfig {
254                provider: Some(auth.provider.unwrap_or_else(|| name.to_string())),
255                issuer: auth.issuer.map(|v| expand_env_vars(&v)),
256                resource: auth.resource.map(|v| expand_env_vars(&v)),
257                client_id: auth.client_id.map(|v| expand_env_vars(&v)),
258                client_secret_env: auth.client_secret_env,
259                scopes: auth
260                    .scopes
261                    .into_iter()
262                    .map(|s| expand_env_vars(&s))
263                    .collect(),
264            }));
265            Ok(parsed)
266        }
267        Some(other) => bail!(
268            "MCP server '{}' has unsupported auth.type '{}'",
269            name,
270            other
271        ),
272        None => Ok(parsed),
273    }
274}
275
276fn collect_merged_mcp_server_maps(root: &Map<String, Value>) -> Map<String, Value> {
277    let mut out = Map::new();
278    if let Some(Value::Object(m)) = root.get("servers") {
279        for (k, v) in m {
280            out.insert(k.clone(), v.clone());
281        }
282    }
283    if let Some(Value::Object(m)) = root.get("mcpServers") {
284        for (k, v) in m {
285            out.insert(k.clone(), v.clone());
286        }
287    }
288    out
289}
290
291/// Add or replace a **stdio** MCP server entry in a JSON config file (`.mcp.json` or `$ATOMCODE_HOME/mcp.json`).
292///
293/// Merges existing `servers` and `mcpServers` maps, then writes a single `mcpServers` object (drops the legacy
294/// `servers` key). Other top-level JSON keys are preserved.
295pub fn merge_stdio_mcp_server_into_json_file(
296    path: &Path,
297    server_key: &str,
298    program: &str,
299    args: &[String],
300) -> Result<()> {
301    if server_key.is_empty() {
302        bail!("MCP server name must not be empty");
303    }
304    if program.is_empty() {
305        bail!("command must not be empty");
306    }
307
308    let mut root: Value = if path.exists() {
309        let text = std::fs::read_to_string(path)
310            .with_context(|| format!("Failed to read MCP config from {}", path.display()))?;
311        serde_json::from_str(&text)
312            .with_context(|| format!("Failed to parse MCP config JSON from {}", path.display()))?
313    } else {
314        json!({})
315    };
316
317    let root_obj = root
318        .as_object_mut()
319        .ok_or_else(|| anyhow::anyhow!("MCP config root must be a JSON object"))?;
320
321    let mut servers = collect_merged_mcp_server_maps(root_obj);
322    let entry = json!({
323        "command": program,
324        "args": args,
325    });
326    servers.insert(server_key.to_string(), entry);
327    root_obj.insert("mcpServers".to_string(), Value::Object(servers));
328    root_obj.remove("servers");
329
330    if let Some(parent) = path.parent() {
331        if !parent.as_os_str().is_empty() {
332            std::fs::create_dir_all(parent).with_context(|| {
333                format!("Failed to create parent directory for {}", path.display())
334            })?;
335        }
336    }
337
338    let text = serde_json::to_string_pretty(&root).context("Failed to serialize MCP config")?;
339    std::fs::write(path, format!("{text}\n"))
340        .with_context(|| format!("Failed to write MCP config to {}", path.display()))?;
341
342    Ok(())
343}
344
345/// Add or replace an **HTTP OAuth** MCP server entry in a JSON config file.
346pub fn merge_http_oauth_mcp_server_into_json_file(
347    path: &Path,
348    server_key: &str,
349    url: &str,
350    provider: &str,
351) -> Result<()> {
352    if server_key.is_empty() {
353        bail!("MCP server name must not be empty");
354    }
355    if url.is_empty() {
356        bail!("url must not be empty");
357    }
358    if provider.is_empty() {
359        bail!("provider must not be empty");
360    }
361
362    let mut root: Value = if path.exists() {
363        let text = std::fs::read_to_string(path)
364            .with_context(|| format!("Failed to read MCP config from {}", path.display()))?;
365        serde_json::from_str(&text)
366            .with_context(|| format!("Failed to parse MCP config JSON from {}", path.display()))?
367    } else {
368        json!({})
369    };
370
371    let root_obj = root
372        .as_object_mut()
373        .ok_or_else(|| anyhow::anyhow!("MCP config root must be a JSON object"))?;
374
375    let mut servers = collect_merged_mcp_server_maps(root_obj);
376    let entry = json!({
377        "url": url,
378        "auth": {
379            "type": "oauth",
380            "provider": provider,
381        },
382    });
383    servers.insert(server_key.to_string(), entry);
384    root_obj.insert("mcpServers".to_string(), Value::Object(servers));
385    root_obj.remove("servers");
386
387    if let Some(parent) = path.parent() {
388        if !parent.as_os_str().is_empty() {
389            std::fs::create_dir_all(parent).with_context(|| {
390                format!("Failed to create parent directory for {}", path.display())
391            })?;
392        }
393    }
394
395    let pretty = serde_json::to_string_pretty(&root).context("Failed to serialize MCP config")?;
396    std::fs::write(path, format!("{}\n", pretty))
397        .with_context(|| format!("Failed to write MCP config to {}", path.display()))?;
398    Ok(())
399}
400
/// Expand environment variable references in a string.
///
/// Supports `${VAR}` and `${VAR:-default}` syntax:
/// - `${VAR}` becomes the value of `VAR`, or the empty string when it is unset (or not
///   valid Unicode).
/// - `${VAR:-default}` becomes the value of `VAR`, or `default` when it is unset (or not
///   valid Unicode). Only the first `:-` splits the name from the default.
/// - A `${` with no closing `}` is copied to the output verbatim.
/// - Everything else, including a bare `$VAR`, is copied unchanged. Non-ASCII text is
///   preserved because the scan works on `&str` slices rather than individual bytes.
fn expand_env_vars(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(open) = rest.find("${") {
        result.push_str(&rest[..open]);

        // `${` is ASCII, so `open + 2` is always a char boundary.
        let after_open = &rest[open + 2..];
        let Some(close) = after_open.find('}') else {
            // Unterminated reference: keep it literally rather than guessing where it ends.
            result.push_str(&rest[open..]);
            return result;
        };

        let expr = &after_open[..close];
        let (var_name, default) = match expr.split_once(":-") {
            Some((name, default)) => (name, Some(default)),
            None => (expr, None),
        };

        match std::env::var(var_name) {
            Ok(value) => result.push_str(&value),
            Err(_) => result.push_str(default.unwrap_or("")),
        }

        rest = &after_open[close + 1..];
    }

    result.push_str(rest);
    result
}
450
451/// Expand a leading `~` (home) in a string.
452///
453/// - `~/path` → `$HOME/path`
454/// - `~` → `$HOME`
455/// - Other forms (e.g. `~user/...`) are left unchanged.
456fn expand_tilde(s: &str) -> String {
457    if s == "~" {
458        return crate::tool::real_home_dir()
459            .map(|h| h.to_string_lossy().to_string())
460            .unwrap_or_else(|| s.to_string());
461    }
462    let Some(rest) = s.strip_prefix("~/") else {
463        return s.to_string();
464    };
465    let Some(home) = crate::tool::real_home_dir() else {
466        return s.to_string();
467    };
468    home.join(rest).to_string_lossy().to_string()
469}
470
// Unit tests for env/tilde expansion, config-file deserialization, HTTP auth parsing and the
// JSON merge helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    // The env-var tests mutate the process environment. Each test uses its own variable
    // names so tests running in parallel threads do not observe each other's values.

    #[test]
    fn test_expand_env_vars_simple() {
        std::env::set_var("TEST_VAR", "test_value");
        let result = expand_env_vars("${TEST_VAR}");
        assert_eq!(result, "test_value");
    }

    #[test]
    fn test_expand_env_vars_with_default() {
        std::env::remove_var("NONEXISTENT_VAR");
        let result = expand_env_vars("${NONEXISTENT_VAR:-default_value}");
        assert_eq!(result, "default_value");
    }

    #[test]
    fn test_expand_env_vars_existing_with_default() {
        // A set variable takes precedence over the `:-` default.
        std::env::set_var("EXISTING_VAR", "actual");
        let result = expand_env_vars("${EXISTING_VAR:-unused}");
        assert_eq!(result, "actual");
    }

    #[test]
    fn test_expand_env_vars_no_var() {
        // Unset variable without a default expands to the empty string.
        std::env::remove_var("MISSING_VAR");
        let result = expand_env_vars("${MISSING_VAR}");
        assert_eq!(result, "");
    }

    #[test]
    fn test_expand_env_vars_mixed() {
        std::env::set_var("VAR1", "a");
        std::env::set_var("VAR2", "b");
        let result = expand_env_vars("prefix_${VAR1}_middle_${VAR2}_suffix");
        assert_eq!(result, "prefix_a_middle_b_suffix");
    }

    // The tilde tests silently pass when no home directory can be determined.

    #[test]
    fn test_expand_tilde_home_only() {
        let Some(home) = crate::tool::real_home_dir() else {
            return;
        };
        assert_eq!(expand_tilde("~"), home.to_string_lossy());
    }

    #[test]
    fn test_expand_tilde_home_prefix() {
        let Some(home) = crate::tool::real_home_dir() else {
            return;
        };
        assert_eq!(
            expand_tilde("~/x/y"),
            home.join("x/y").to_string_lossy().to_string()
        );
    }

    #[test]
    fn test_expand_tilde_does_not_expand_other_forms() {
        assert_eq!(expand_tilde("~someone/x"), "~someone/x");
        assert_eq!(expand_tilde("/abs/path"), "/abs/path");
    }

    // Both the current `mcpServers` key and the legacy `servers` alias must deserialize.

    #[test]
    fn mcp_config_file_accepts_mcp_servers_key() {
        let raw: McpConfigFile =
            serde_json::from_str(r#"{"mcpServers":{"a":{"command":"echo","args":[]}}}"#).unwrap();
        assert!(raw.mcp_servers.contains_key("a"));
    }

    #[test]
    fn mcp_config_file_accepts_servers_alias() {
        let raw: McpConfigFile =
            serde_json::from_str(r#"{"servers":{"b":{"command":"echo","args":[]}}}"#).unwrap();
        assert!(raw.mcp_servers.contains_key("b"));
    }

    // Merge-helper tests write into a throwaway temp directory.

    #[test]
    fn merge_stdio_creates_mcp_servers() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("mcp.json");
        merge_stdio_mcp_server_into_json_file(&path, "p", "npx", &["@x/y".to_string()]).unwrap();
        let v: Value = serde_json::from_str(&std::fs::read_to_string(&path).unwrap()).unwrap();
        let p = v["mcpServers"]["p"].as_object().unwrap();
        assert_eq!(p["command"].as_str(), Some("npx"));
        assert_eq!(p["args"].as_array().unwrap()[0].as_str(), Some("@x/y"));
    }

    #[test]
    fn merge_stdio_preserves_other_top_level_keys() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("mcp.json");
        std::fs::write(
            &path,
            r#"{"note":"keep","mcpServers":{"old":{"command":"true","args":[]}}}"#,
        )
        .unwrap();
        merge_stdio_mcp_server_into_json_file(&path, "new", "uv", &[]).unwrap();
        let v: Value = serde_json::from_str(&std::fs::read_to_string(&path).unwrap()).unwrap();
        assert_eq!(v.get("note").and_then(|x| x.as_str()), Some("keep"));
        let m = v.get("mcpServers").unwrap().as_object().unwrap();
        assert!(m.contains_key("old"));
        assert!(m.contains_key("new"));
    }

    // HTTP auth parsing: explicit OAuth provider, generic OAuth (provider defaults to the
    // server name), and a static custom-header credential without `type`.

    #[test]
    fn http_config_accepts_oauth_auth() {
        let cfg = server_entry_to_config(
            "github",
            serde_json::from_str(
                r#"{
                    "url":"https://api.githubcopilot.com/mcp/",
                    "auth":{"type":"oauth","provider":"github"}
                }"#,
            )
            .unwrap(),
        )
        .unwrap();
        match cfg.config {
            McpTransportConfig::Http { auth, .. } => {
                assert_eq!(
                    auth,
                    Some(McpHttpAuthConfig::OAuth(McpOAuthConfig {
                        provider: Some("github".to_string()),
                        ..McpOAuthConfig::default()
                    }))
                );
            }
            _ => panic!("expected http config"),
        }
    }

    #[test]
    fn http_config_accepts_generic_oauth_auth() {
        let cfg = server_entry_to_config(
            "notion",
            serde_json::from_str(
                r#"{
                    "url":"https://mcp.notion.com/mcp",
                    "auth":{
                        "type":"oauth",
                        "issuer":"https://mcp.notion.com",
                        "resource":"https://mcp.notion.com/mcp",
                        "client_id":"client",
                        "client_secret_env":"NOTION_SECRET",
                        "scopes":["read","write"]
                    }
                }"#,
            )
            .unwrap(),
        )
        .unwrap();
        match cfg.config {
            McpTransportConfig::Http { auth, .. } => {
                assert_eq!(
                    auth,
                    Some(McpHttpAuthConfig::OAuth(McpOAuthConfig {
                        provider: Some("notion".to_string()),
                        issuer: Some("https://mcp.notion.com".to_string()),
                        resource: Some("https://mcp.notion.com/mcp".to_string()),
                        client_id: Some("client".to_string()),
                        client_secret_env: Some("NOTION_SECRET".to_string()),
                        scopes: vec!["read".to_string(), "write".to_string()],
                    }))
                );
            }
            _ => panic!("expected http config"),
        }
    }

    #[test]
    fn http_config_accepts_bearer_header_auth_without_type() {
        let cfg = server_entry_to_config(
            "figma",
            serde_json::from_str(
                r#"{
                    "url":"https://mcp.figma.com/mcp",
                    "auth":{"bearer":"figd_token","header":"X-Figma-Token"}
                }"#,
            )
            .unwrap(),
        )
        .unwrap();
        match cfg.config {
            McpTransportConfig::Http { headers, auth, .. } => {
                assert_eq!(
                    headers.get("X-Figma-Token").map(String::as_str),
                    Some("figd_token")
                );
                assert_eq!(auth, None);
            }
            _ => panic!("expected http config"),
        }
    }

    #[test]
    fn merge_http_oauth_creates_mcp_servers() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("mcp.json");
        merge_http_oauth_mcp_server_into_json_file(
            &path,
            "github",
            "https://api.githubcopilot.com/mcp/",
            "github",
        )
        .unwrap();
        let v: Value = serde_json::from_str(&std::fs::read_to_string(&path).unwrap()).unwrap();
        let p = v["mcpServers"]["github"].as_object().unwrap();
        assert_eq!(
            p["url"].as_str(),
            Some("https://api.githubcopilot.com/mcp/")
        );
        assert_eq!(p["auth"]["type"].as_str(), Some("oauth"));
        assert_eq!(p["auth"]["provider"].as_str(), Some("github"));
    }
}