Skip to main content

capo_agent/mcp/
config.rs

1#![cfg_attr(test, allow(clippy::expect_used, clippy::unwrap_used))]
2
3use std::collections::HashMap;
4use std::path::Path;
5
6use serde::Deserialize;
7
8#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
9pub struct McpConfig {
10    #[serde(default)]
11    pub servers: HashMap<String, McpServerConfig>,
12}
13
14#[derive(Debug, Clone, Deserialize, PartialEq)]
15#[serde(tag = "transport", rename_all = "snake_case")]
16pub enum McpServerConfig {
17    Stdio {
18        command: String,
19        #[serde(default)]
20        args: Vec<String>,
21        #[serde(default)]
22        env: HashMap<String, String>,
23        #[serde(default = "default_timeout_ms")]
24        startup_timeout_ms: u64,
25        #[serde(default = "default_enabled")]
26        enabled: bool,
27    },
28    Http {
29        url: String,
30        #[serde(default)]
31        headers: HashMap<String, String>,
32        #[serde(default = "default_timeout_ms")]
33        startup_timeout_ms: u64,
34        #[serde(default = "default_enabled")]
35        enabled: bool,
36    },
37}
38
/// Value used for `startup_timeout_ms` when the key is omitted: 10 000 ms.
fn default_timeout_ms() -> u64 {
    const DEFAULT_STARTUP_TIMEOUT_MS: u64 = 10 * 1_000;
    DEFAULT_STARTUP_TIMEOUT_MS
}
/// Value used for `enabled` when the key is omitted: servers are on by default.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
45
46pub fn parse_str(raw: &str) -> Result<McpConfig, toml::de::Error> {
47    toml::from_str(raw)
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// A fully specified stdio table populates every field; `enabled` falls
    /// back to its default.
    #[test]
    fn parses_stdio_server() {
        let raw = r#"
            [servers.github]
            transport = "stdio"
            command = "github-mcp-server"
            args = ["--scope", "read-only"]
            env = { GITHUB_TOKEN = "abc" }
            startup_timeout_ms = 5000
        "#;
        let cfg = parse_str(raw).unwrap();
        let McpServerConfig::Stdio {
            command,
            args,
            env,
            startup_timeout_ms,
            enabled,
        } = cfg.servers.get("github").unwrap()
        else {
            panic!("expected stdio");
        };

        assert_eq!(command, "github-mcp-server");
        assert_eq!(args, &vec!["--scope".to_owned(), "read-only".to_owned()]);
        assert_eq!(env.get("GITHUB_TOKEN").map(String::as_str), Some("abc"));
        assert_eq!(*startup_timeout_ms, 5000);
        assert!(*enabled);
    }

    /// An http table without timeout/enabled keys picks up both defaults.
    #[test]
    fn parses_http_server() {
        let raw = r#"
            [servers.sentry]
            transport = "http"
            url = "https://mcp.sentry.io/v1"
            headers = { Authorization = "Bearer t" }
        "#;
        let cfg = parse_str(raw).unwrap();
        let McpServerConfig::Http {
            url,
            headers,
            startup_timeout_ms,
            enabled,
        } = cfg.servers.get("sentry").unwrap()
        else {
            panic!("expected http");
        };

        assert_eq!(url, "https://mcp.sentry.io/v1");
        assert_eq!(
            headers.get("Authorization").map(String::as_str),
            Some("Bearer t")
        );
        assert_eq!(*startup_timeout_ms, 10_000);
        assert!(*enabled);
    }

    /// A `transport` value that names no variant is rejected at parse time.
    #[test]
    fn unknown_transport_errors() {
        let raw = r#"
            [servers.x]
            transport = "carrier-pigeon"
            url = "..."
        "#;
        let parsed = parse_str(raw);
        assert!(parsed.is_err());
    }
}
123
/// Expand `${VAR}` placeholders in `s` using `lookup`.
///
/// * `${VAR}` is replaced by the value of `lookup("VAR")`; the substituted
///   text is not scanned again.
/// * `$$` produces a single literal `$`.
/// * Any other `$` — including a `${` with no closing `}` — is copied
///   through unchanged.
///
/// # Errors
///
/// Returns `Err(name)` with the first placeholder name (scanning left to
/// right) for which `lookup` returns `None`.
pub fn expand_env<F>(s: &str, lookup: &F) -> Result<String, String>
where
    F: Fn(&str) -> Option<String>,
{
    let mut expanded = String::with_capacity(s.len());
    let mut rest = s;

    // Jump from one `$` to the next; everything in between is plain text.
    while let Some(dollar) = rest.find('$') {
        expanded.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];

        if let Some(tail) = after.strip_prefix('$') {
            // `$$` escape.
            expanded.push('$');
            rest = tail;
        } else if let Some((var, tail)) = after
            .strip_prefix('{')
            .and_then(|body| body.split_once('}'))
        {
            // `${VAR}` with a closing brace somewhere later in the input.
            match lookup(var) {
                Some(value) => expanded.push_str(&value),
                None => return Err(var.to_string()),
            }
            rest = tail;
        } else {
            // Lone `$` or unclosed `${`: keep the `$` and carry on after it.
            expanded.push('$');
            rest = after;
        }
    }

    expanded.push_str(rest);
    Ok(expanded)
}
166
#[cfg(test)]
mod expand_env_tests {
    use super::*;

    /// Build a lookup closure from a static table; the last entry for a key wins.
    fn lk(values: &'static [(&'static str, &'static str)]) -> impl Fn(&str) -> Option<String> {
        move |key| {
            values
                .iter()
                .rfind(|(name, _)| *name == key)
                .map(|(_, value)| (*value).to_string())
        }
    }

    #[test]
    fn substitutes_single_var() {
        let lookup = lk(&[("TOK", "abc")]);
        assert_eq!(expand_env("Bearer ${TOK}", &lookup).unwrap(), "Bearer abc");
    }

    #[test]
    fn substitutes_multiple_vars() {
        let lookup = lk(&[("A", "x"), ("B", "y")]);
        assert_eq!(expand_env("${A}-${B}", &lookup).unwrap(), "x-y");
    }

    #[test]
    fn missing_var_errors_with_var_name() {
        let lookup = lk(&[]);
        assert_eq!(expand_env("${MISSING}", &lookup).unwrap_err(), "MISSING");
    }

    #[test]
    fn dollar_dollar_is_literal() {
        let lookup = lk(&[]);
        assert_eq!(expand_env("price: $$5", &lookup).unwrap(), "price: $5");
    }

    #[test]
    fn unclosed_dollar_brace_is_literal() {
        let lookup = lk(&[]);
        assert_eq!(
            expand_env("oops ${INCOMPLETE", &lookup).unwrap(),
            "oops ${INCOMPLETE"
        );
    }

    #[test]
    fn empty_string_is_empty() {
        let lookup = lk(&[]);
        assert_eq!(expand_env("", &lookup).unwrap(), "");
    }

    #[test]
    fn preserves_non_ascii_text() {
        let lookup = lk(&[("TOK", "🔑")]);
        assert_eq!(
            expand_env("Bearer café-${TOK}", &lookup).unwrap(),
            "Bearer café-🔑"
        );
    }
}
217
218/// Resolve `${VAR}` in every `env` and `headers` value. Stdio servers with
219/// a missing `${VAR}` in their `env` are dropped (returns a diagnostic);
220/// http servers with a missing `${VAR}` in their `headers` are dropped.
221pub fn resolve_env<F>(mut cfg: McpConfig, lookup: &F) -> (McpConfig, Vec<String>)
222where
223    F: Fn(&str) -> Option<String>,
224{
225    let mut diags = Vec::new();
226    cfg.servers.retain(|name, server| {
227        match server {
228            McpServerConfig::Stdio { env, .. } => {
229                for (k, v) in env.iter_mut() {
230                    match expand_env(v, lookup) {
231                        Ok(new) => *v = new,
232                        Err(var) => {
233                            diags.push(format!(
234                                "server `{name}` env `{k}` references unset `${{{var}}}`; server skipped"
235                            ));
236                            return false;
237                        }
238                    }
239                }
240                true
241            }
242            McpServerConfig::Http { headers, .. } => {
243                for (k, v) in headers.iter_mut() {
244                    match expand_env(v, lookup) {
245                        Ok(new) => *v = new,
246                        Err(var) => {
247                            diags.push(format!(
248                                "server `{name}` header `{k}` references unset `${{{var}}}`; server skipped"
249                            ));
250                            return false;
251                        }
252                    }
253                }
254                true
255            }
256        }
257    });
258    (cfg, diags)
259}
260
#[cfg(test)]
mod resolve_env_tests {
    use super::*;

    /// Build a lookup closure from a static table; the last entry for a key wins.
    fn lk(values: &'static [(&'static str, &'static str)]) -> impl Fn(&str) -> Option<String> {
        move |key| {
            values
                .iter()
                .rfind(|(name, _)| *name == key)
                .map(|(_, value)| (*value).to_string())
        }
    }

    /// A resolvable placeholder in a stdio `env` value is replaced in place.
    #[test]
    fn resolves_stdio_env() {
        let raw = r#"
            [servers.x]
            transport = "stdio"
            command = "c"
            env = { TOK = "${T}" }
        "#;
        let parsed = parse_str(raw).unwrap();

        let (cfg, diags) = resolve_env(parsed, &lk(&[("T", "value")]));

        assert!(diags.is_empty());
        let McpServerConfig::Stdio { env, .. } = cfg.servers.get("x").unwrap() else {
            panic!()
        };
        assert_eq!(env.get("TOK").map(String::as_str), Some("value"));
    }

    /// A server with an unresolvable placeholder is removed and reported;
    /// unrelated servers survive.
    #[test]
    fn drops_server_with_missing_env_var() {
        let raw = r#"
            [servers.bad]
            transport = "stdio"
            command = "c"
            env = { TOK = "${MISSING}" }
            [servers.good]
            transport = "stdio"
            command = "c"
        "#;
        let parsed = parse_str(raw).unwrap();

        let (cfg, diags) = resolve_env(parsed, &lk(&[]));

        assert!(!cfg.servers.contains_key("bad"));
        assert!(cfg.servers.contains_key("good"));
        assert_eq!(diags.len(), 1);
        let only = &diags[0];
        assert!(only.contains("bad"));
        assert!(only.contains("MISSING"));
    }
}
312
313/// Merge a project config over a global config. Project may add servers
314/// or toggle `enabled` on existing ones, but may NOT override transport,
315/// command, args, url, env, or headers on a globally-defined server
316/// (security boundary — see design spec §5.5).
317pub fn merge(mut global: McpConfig, project: McpConfig) -> Result<McpConfig, String> {
318    for (name, proj_server) in project.servers {
319        if let Some(global_server) = global.servers.get(&name) {
320            // Override allowed: only `enabled`. Everything else must match.
321            match (global_server, &proj_server) {
322                (
323                    McpServerConfig::Stdio {
324                        command: gc,
325                        args: ga,
326                        env: ge,
327                        startup_timeout_ms: gt,
328                        enabled: _,
329                    },
330                    McpServerConfig::Stdio {
331                        command: pc,
332                        args: pa,
333                        env: pe,
334                        startup_timeout_ms: pt,
335                        enabled: _,
336                    },
337                ) => {
338                    if gc != pc {
339                        return Err(format!("project mcp.toml may not override `command` for global server `{name}`"));
340                    }
341                    if ga != pa {
342                        return Err(format!(
343                            "project mcp.toml may not override `args` for global server `{name}`"
344                        ));
345                    }
346                    if ge != pe {
347                        return Err(format!(
348                            "project mcp.toml may not override `env` for global server `{name}`"
349                        ));
350                    }
351                    if gt != pt {
352                        return Err(format!("project mcp.toml may not override `startup_timeout_ms` for global server `{name}`"));
353                    }
354                }
355                (
356                    McpServerConfig::Http {
357                        url: gu,
358                        headers: gh,
359                        startup_timeout_ms: gt,
360                        enabled: _,
361                    },
362                    McpServerConfig::Http {
363                        url: pu,
364                        headers: ph,
365                        startup_timeout_ms: pt,
366                        enabled: _,
367                    },
368                ) => {
369                    if gu != pu {
370                        return Err(format!(
371                            "project mcp.toml may not override `url` for global server `{name}`"
372                        ));
373                    }
374                    if gh != ph {
375                        return Err(format!("project mcp.toml may not override `headers` for global server `{name}`"));
376                    }
377                    if gt != pt {
378                        return Err(format!("project mcp.toml may not override `startup_timeout_ms` for global server `{name}`"));
379                    }
380                }
381                _ => {
382                    return Err(format!(
383                        "project mcp.toml may not change `transport` for global server `{name}`"
384                    ))
385                }
386            }
387            // Only `enabled` differs (or nothing); keep the project value to allow toggling.
388            global.servers.insert(name, proj_server);
389        } else {
390            global.servers.insert(name, proj_server);
391        }
392    }
393    Ok(global)
394}
395
#[cfg(test)]
mod merge_tests {
    use super::*;
    use std::collections::HashMap;

    /// Stdio server with default-like fields and the given `enabled` flag.
    fn stdio_with_enabled(cmd: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig::Stdio {
            command: cmd.into(),
            args: vec![],
            env: HashMap::new(),
            startup_timeout_ms: 10_000,
            enabled,
        }
    }

    /// Enabled stdio server running `cmd`.
    fn stdio(cmd: &str) -> McpServerConfig {
        stdio_with_enabled(cmd, true)
    }

    fn cfg(entries: &[(&str, McpServerConfig)]) -> McpConfig {
        let mut c = McpConfig::default();
        for (k, v) in entries {
            c.servers.insert((*k).into(), v.clone());
        }
        c
    }

    /// Read the `enabled` flag of stdio server `name`.
    ///
    /// Panics when the entry is missing or is not a stdio server, so a test
    /// cannot pass without the assertion actually running.
    fn stdio_enabled(config: &McpConfig, name: &str) -> bool {
        match config.servers.get(name) {
            Some(McpServerConfig::Stdio { enabled, .. }) => *enabled,
            other => panic!("expected stdio server `{}`, got {:?}", name, other),
        }
    }

    #[test]
    fn project_adds_new_server() {
        let merged = merge(cfg(&[("a", stdio("ca"))]), cfg(&[("b", stdio("cb"))])).unwrap();
        assert!(merged.servers.contains_key("a"));
        assert!(merged.servers.contains_key("b"));
    }

    #[test]
    fn project_can_disable_global_server() {
        let project = cfg(&[("a", stdio_with_enabled("ca", false))]);
        let merged = merge(cfg(&[("a", stdio("ca"))]), project).unwrap();
        assert!(!stdio_enabled(&merged, "a"));
    }

    #[test]
    fn project_can_reenable_disabled_global_server() {
        let global = cfg(&[("a", stdio_with_enabled("ca", false))]);
        let merged = merge(global, cfg(&[("a", stdio_with_enabled("ca", true))])).unwrap();
        assert!(stdio_enabled(&merged, "a"));
    }

    #[test]
    fn project_cannot_override_command() {
        let err = merge(
            cfg(&[("a", stdio("ca"))]),
            cfg(&[("a", stdio("DIFFERENT"))]),
        )
        .unwrap_err();
        assert!(err.contains("command"), "{err}");
        assert!(err.contains("`a`"), "{err}");
    }

    #[test]
    fn project_cannot_override_transport() {
        let http = McpServerConfig::Http {
            url: "x".into(),
            headers: HashMap::new(),
            startup_timeout_ms: 10_000,
            enabled: true,
        };
        let err = merge(cfg(&[("a", stdio("ca"))]), cfg(&[("a", http)])).unwrap_err();
        assert!(err.contains("transport"), "{err}");
    }
}
475
476/// Load merged MCP config: `<agent_dir>/mcp.toml` (global) overlaid by
477/// `<cwd>/.capo/mcp.toml` (project). Env vars in `env`/`headers` are
478/// resolved from `lookup`; servers with missing vars are dropped with
479/// diagnostics. Missing files are treated as empty configs.
480pub fn load_config<F>(
481    cwd: &Path,
482    agent_dir: &Path,
483    lookup: &F,
484) -> Result<(McpConfig, Vec<String>), String>
485where
486    F: Fn(&str) -> Option<String>,
487{
488    let global = read_or_default(&agent_dir.join("mcp.toml"))?;
489    let merged = read_project_overlay(&cwd.join(".capo").join("mcp.toml"), global)?;
490    let (resolved, diags) = resolve_env(merged, lookup);
491    Ok((resolved, diags))
492}
493
494fn read_or_default(path: &Path) -> Result<McpConfig, String> {
495    if !path.exists() {
496        return Ok(McpConfig::default());
497    }
498    let raw = std::fs::read_to_string(path)
499        .map_err(|e| format!("read {} failed: {e}", path.display()))?;
500    parse_str(&raw).map_err(|e| format!("parse {} failed: {e}", path.display()))
501}
502
503fn read_project_overlay(path: &Path, mut global: McpConfig) -> Result<McpConfig, String> {
504    if !path.exists() {
505        return Ok(global);
506    }
507    let raw = std::fs::read_to_string(path)
508        .map_err(|e| format!("read {} failed: {e}", path.display()))?;
509    let project: ProjectMcpConfig =
510        toml::from_str(&raw).map_err(|e| format!("parse {} failed: {e}", path.display()))?;
511
512    let mut full_project = McpConfig::default();
513    for (name, raw_server) in project.servers {
514        if raw_server.is_enabled_only() {
515            let enabled = raw_server.enabled.unwrap_or(true);
516            let Some(existing) = global.servers.get_mut(&name) else {
517                return Err(format!(
518                    "project mcp.toml cannot define enabled-only server `{name}` without a global server"
519                ));
520            };
521            set_enabled(existing, enabled);
522            continue;
523        }
524        full_project
525            .servers
526            .insert(name.clone(), raw_server.into_server_config(&name)?);
527    }
528    merge(global, full_project)
529}
530
531#[derive(Debug, Default, Deserialize)]
532struct ProjectMcpConfig {
533    #[serde(default)]
534    servers: HashMap<String, ProjectMcpServerConfig>,
535}
536
537#[derive(Debug, Default, Deserialize)]
538struct ProjectMcpServerConfig {
539    transport: Option<String>,
540    command: Option<String>,
541    #[serde(default)]
542    args: Option<Vec<String>>,
543    #[serde(default)]
544    env: Option<HashMap<String, String>>,
545    url: Option<String>,
546    #[serde(default)]
547    headers: Option<HashMap<String, String>>,
548    startup_timeout_ms: Option<u64>,
549    enabled: Option<bool>,
550}
551
552impl ProjectMcpServerConfig {
553    fn is_enabled_only(&self) -> bool {
554        self.enabled.is_some()
555            && self.transport.is_none()
556            && self.command.is_none()
557            && self.args.is_none()
558            && self.env.is_none()
559            && self.url.is_none()
560            && self.headers.is_none()
561            && self.startup_timeout_ms.is_none()
562    }
563
564    fn into_server_config(self, name: &str) -> Result<McpServerConfig, String> {
565        match self.transport.as_deref() {
566            Some("stdio") => Ok(McpServerConfig::Stdio {
567                command: self.command.ok_or_else(|| {
568                    format!("project mcp.toml stdio server `{name}` missing `command`")
569                })?,
570                args: self.args.unwrap_or_default(),
571                env: self.env.unwrap_or_default(),
572                startup_timeout_ms: self.startup_timeout_ms.unwrap_or_else(default_timeout_ms),
573                enabled: self.enabled.unwrap_or_else(default_enabled),
574            }),
575            Some("http") => Ok(McpServerConfig::Http {
576                url: self.url.ok_or_else(|| {
577                    format!("project mcp.toml http server `{name}` missing `url`")
578                })?,
579                headers: self.headers.unwrap_or_default(),
580                startup_timeout_ms: self.startup_timeout_ms.unwrap_or_else(default_timeout_ms),
581                enabled: self.enabled.unwrap_or_else(default_enabled),
582            }),
583            Some(other) => Err(format!(
584                "project mcp.toml server `{name}` has unknown transport `{other}`"
585            )),
586            None => Err(format!(
587                "project mcp.toml server `{name}` missing `transport`"
588            )),
589        }
590    }
591}
592
593fn set_enabled(server: &mut McpServerConfig, value: bool) {
594    match server {
595        McpServerConfig::Stdio { enabled, .. } | McpServerConfig::Http { enabled, .. } => {
596            *enabled = value;
597        }
598    }
599}