1#![cfg_attr(test, allow(clippy::expect_used, clippy::unwrap_used))]
2
3use std::collections::HashMap;
4use std::path::Path;
5
6use serde::Deserialize;
7
/// Parsed contents of an `mcp.toml` document.
#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
pub struct McpConfig {
    /// Server entries keyed by name (the `[servers.<name>]` tables).
    /// A document without a `servers` table yields an empty map.
    #[serde(default)]
    pub servers: HashMap<String, McpServerConfig>,
}
13
/// A single server entry, discriminated by its `transport` key.
///
/// Serde reads `transport` to choose the variant, using snake_case variant
/// names (`"stdio"` or `"http"`).
#[derive(Debug, Clone, Deserialize, PartialEq)]
#[serde(tag = "transport", rename_all = "snake_case")]
pub enum McpServerConfig {
    /// Entry declared with `transport = "stdio"`.
    Stdio {
        /// Required `command` string (presumably the program to launch;
        /// confirm against the code that consumes this config).
        command: String,
        /// `args` list; empty when omitted.
        #[serde(default)]
        args: Vec<String>,
        /// `env` table; empty when omitted. Values may carry `${VAR}`
        /// placeholders, which `resolve_env` expands.
        #[serde(default)]
        env: HashMap<String, String>,
        /// `startup_timeout_ms`; falls back to `default_timeout_ms()` when omitted.
        #[serde(default = "default_timeout_ms")]
        startup_timeout_ms: u64,
        /// `enabled` flag; falls back to `default_enabled()` when omitted.
        #[serde(default = "default_enabled")]
        enabled: bool,
    },
    /// Entry declared with `transport = "http"`.
    Http {
        /// Required `url` string.
        url: String,
        /// `headers` table; empty when omitted. Values may carry `${VAR}`
        /// placeholders, which `resolve_env` expands.
        #[serde(default)]
        headers: HashMap<String, String>,
        /// `startup_timeout_ms`; falls back to `default_timeout_ms()` when omitted.
        #[serde(default = "default_timeout_ms")]
        startup_timeout_ms: u64,
        /// `enabled` flag; falls back to `default_enabled()` when omitted.
        #[serde(default = "default_enabled")]
        enabled: bool,
    },
}
38
/// Value used for `startup_timeout_ms` when a server entry omits it.
fn default_timeout_ms() -> u64 {
    const DEFAULT_STARTUP_TIMEOUT_MS: u64 = 10_000;
    DEFAULT_STARTUP_TIMEOUT_MS
}
/// Value used for `enabled` when a server entry omits it.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
45
46pub fn parse_str(raw: &str) -> Result<McpConfig, toml::de::Error> {
47 toml::from_str(raw)
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// A stdio entry keeps its explicit values and defaults `enabled` to true.
    #[test]
    fn parses_stdio_server() {
        let cfg = parse_str(
            r#"
            [servers.github]
            transport = "stdio"
            command = "github-mcp-server"
            args = ["--scope", "read-only"]
            env = { GITHUB_TOKEN = "abc" }
            startup_timeout_ms = 5000
            "#,
        )
        .unwrap();
        let server = cfg.servers.get("github").unwrap();
        let McpServerConfig::Stdio {
            command,
            args,
            env,
            startup_timeout_ms,
            enabled,
        } = server
        else {
            panic!("expected stdio");
        };
        let expected_args = vec![String::from("--scope"), String::from("read-only")];
        assert_eq!(command, "github-mcp-server");
        assert_eq!(args, &expected_args);
        assert_eq!(
            env.get("GITHUB_TOKEN").map(|token| token.as_str()),
            Some("abc")
        );
        assert_eq!(*startup_timeout_ms, 5000);
        assert!(*enabled);
    }

    /// An http entry without a timeout or `enabled` picks up both defaults.
    #[test]
    fn parses_http_server() {
        let cfg = parse_str(
            r#"
            [servers.sentry]
            transport = "http"
            url = "https://mcp.sentry.io/v1"
            headers = { Authorization = "Bearer t" }
            "#,
        )
        .unwrap();
        let server = cfg.servers.get("sentry").unwrap();
        let McpServerConfig::Http {
            url,
            headers,
            startup_timeout_ms,
            enabled,
        } = server
        else {
            panic!("expected http");
        };
        assert_eq!(url, "https://mcp.sentry.io/v1");
        assert_eq!(
            headers.get("Authorization").map(|value| value.as_str()),
            Some("Bearer t")
        );
        assert_eq!(*startup_timeout_ms, 10_000);
        assert!(*enabled);
    }

    /// A transport value that names no variant is rejected.
    #[test]
    fn unknown_transport_errors() {
        let outcome = parse_str(
            r#"
            [servers.x]
            transport = "carrier-pigeon"
            url = "..."
            "#,
        );
        assert!(outcome.is_err());
    }
}
123
/// Expands `${VAR}` placeholders in `s` using `lookup`.
///
/// * `$$` produces a single literal `$`.
/// * `${NAME}` is replaced by `lookup("NAME")`.
/// * A `$` followed by anything else, or a `${` with no closing `}`, is
///   copied through unchanged.
///
/// # Errors
///
/// Returns the name of the first placeholder for which `lookup` yields
/// `None`.
pub fn expand_env<F>(s: &str, lookup: &F) -> Result<String, String>
where
    F: Fn(&str) -> Option<String>,
{
    let mut out = String::with_capacity(s.len());
    let mut rest = s;
    while let Some(dollar) = rest.find('$') {
        // Everything before the `$` is plain text.
        out.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];
        if let Some(tail) = after.strip_prefix('$') {
            // `$$` escape.
            out.push('$');
            rest = tail;
        } else if let Some((var, tail)) = after
            .strip_prefix('{')
            .and_then(|body| body.split_once('}'))
        {
            match lookup(var) {
                Some(value) => out.push_str(&value),
                None => return Err(var.to_string()),
            }
            rest = tail;
        } else {
            // Lone `$` or unterminated `${`: keep the `$` and resume right
            // after it, so the following characters are scanned normally.
            out.push('$');
            rest = after;
        }
    }
    out.push_str(rest);
    Ok(out)
}
166
#[cfg(test)]
mod expand_env_tests {
    use super::*;

    /// Builds a lookup closure over a fixed table; when a key is repeated the
    /// last entry wins.
    fn lk(values: &'static [(&'static str, &'static str)]) -> impl Fn(&str) -> Option<String> {
        move |wanted: &str| {
            values
                .iter()
                .rev()
                .find(|(key, _)| *key == wanted)
                .map(|(_, value)| (*value).to_string())
        }
    }

    #[test]
    fn substitutes_single_var() {
        let lookup = lk(&[("TOK", "abc")]);
        assert_eq!(expand_env("Bearer ${TOK}", &lookup).unwrap(), "Bearer abc");
    }

    #[test]
    fn substitutes_multiple_vars() {
        let lookup = lk(&[("A", "x"), ("B", "y")]);
        assert_eq!(expand_env("${A}-${B}", &lookup).unwrap(), "x-y");
    }

    #[test]
    fn missing_var_errors_with_var_name() {
        let lookup = lk(&[]);
        assert_eq!(expand_env("${MISSING}", &lookup).unwrap_err(), "MISSING");
    }

    #[test]
    fn dollar_dollar_is_literal() {
        let lookup = lk(&[]);
        assert_eq!(expand_env("price: $$5", &lookup).unwrap(), "price: $5");
    }

    #[test]
    fn unclosed_dollar_brace_is_literal() {
        let lookup = lk(&[]);
        assert_eq!(
            expand_env("oops ${INCOMPLETE", &lookup).unwrap(),
            "oops ${INCOMPLETE"
        );
    }

    #[test]
    fn empty_string_is_empty() {
        let lookup = lk(&[]);
        assert_eq!(expand_env("", &lookup).unwrap(), "");
    }

    #[test]
    fn preserves_non_ascii_text() {
        let lookup = lk(&[("TOK", "🔑")]);
        assert_eq!(
            expand_env("Bearer café-${TOK}", &lookup).unwrap(),
            "Bearer café-🔑"
        );
    }
}
217
218pub fn resolve_env<F>(mut cfg: McpConfig, lookup: &F) -> (McpConfig, Vec<String>)
222where
223 F: Fn(&str) -> Option<String>,
224{
225 let mut diags = Vec::new();
226 cfg.servers.retain(|name, server| {
227 match server {
228 McpServerConfig::Stdio { env, .. } => {
229 for (k, v) in env.iter_mut() {
230 match expand_env(v, lookup) {
231 Ok(new) => *v = new,
232 Err(var) => {
233 diags.push(format!(
234 "server `{name}` env `{k}` references unset `${{{var}}}`; server skipped"
235 ));
236 return false;
237 }
238 }
239 }
240 true
241 }
242 McpServerConfig::Http { headers, .. } => {
243 for (k, v) in headers.iter_mut() {
244 match expand_env(v, lookup) {
245 Ok(new) => *v = new,
246 Err(var) => {
247 diags.push(format!(
248 "server `{name}` header `{k}` references unset `${{{var}}}`; server skipped"
249 ));
250 return false;
251 }
252 }
253 }
254 true
255 }
256 }
257 });
258 (cfg, diags)
259}
260
#[cfg(test)]
mod resolve_env_tests {
    use super::*;

    /// Builds a lookup closure over a fixed table; when a key is repeated the
    /// last entry wins.
    fn lk(values: &'static [(&'static str, &'static str)]) -> impl Fn(&str) -> Option<String> {
        move |wanted: &str| {
            values
                .iter()
                .rev()
                .find(|(key, _)| *key == wanted)
                .map(|(_, value)| (*value).to_string())
        }
    }

    #[test]
    fn resolves_stdio_env() {
        let raw = r#"
            [servers.x]
            transport = "stdio"
            command = "c"
            env = { TOK = "${T}" }
        "#;
        let parsed = parse_str(raw).unwrap();
        let lookup = lk(&[("T", "value")]);
        let (cfg, diags) = resolve_env(parsed, &lookup);
        assert!(diags.is_empty());
        let McpServerConfig::Stdio { env, .. } = cfg.servers.get("x").unwrap() else {
            panic!()
        };
        assert_eq!(env.get("TOK").map(|v| v.as_str()), Some("value"));
    }

    #[test]
    fn drops_server_with_missing_env_var() {
        let raw = r#"
            [servers.bad]
            transport = "stdio"
            command = "c"
            env = { TOK = "${MISSING}" }
            [servers.good]
            transport = "stdio"
            command = "c"
        "#;
        let parsed = parse_str(raw).unwrap();
        let lookup = lk(&[]);
        let (cfg, diags) = resolve_env(parsed, &lookup);
        assert!(!cfg.servers.contains_key("bad"));
        assert!(cfg.servers.contains_key("good"));
        assert_eq!(diags.len(), 1);
        let only = &diags[0];
        assert!(only.contains("bad"));
        assert!(only.contains("MISSING"));
    }
}
312
313pub fn merge(mut global: McpConfig, project: McpConfig) -> Result<McpConfig, String> {
318 for (name, proj_server) in project.servers {
319 if let Some(global_server) = global.servers.get(&name) {
320 match (global_server, &proj_server) {
322 (
323 McpServerConfig::Stdio {
324 command: gc,
325 args: ga,
326 env: ge,
327 startup_timeout_ms: gt,
328 enabled: _,
329 },
330 McpServerConfig::Stdio {
331 command: pc,
332 args: pa,
333 env: pe,
334 startup_timeout_ms: pt,
335 enabled: _,
336 },
337 ) => {
338 if gc != pc {
339 return Err(format!("project mcp.toml may not override `command` for global server `{name}`"));
340 }
341 if ga != pa {
342 return Err(format!(
343 "project mcp.toml may not override `args` for global server `{name}`"
344 ));
345 }
346 if ge != pe {
347 return Err(format!(
348 "project mcp.toml may not override `env` for global server `{name}`"
349 ));
350 }
351 if gt != pt {
352 return Err(format!("project mcp.toml may not override `startup_timeout_ms` for global server `{name}`"));
353 }
354 }
355 (
356 McpServerConfig::Http {
357 url: gu,
358 headers: gh,
359 startup_timeout_ms: gt,
360 enabled: _,
361 },
362 McpServerConfig::Http {
363 url: pu,
364 headers: ph,
365 startup_timeout_ms: pt,
366 enabled: _,
367 },
368 ) => {
369 if gu != pu {
370 return Err(format!(
371 "project mcp.toml may not override `url` for global server `{name}`"
372 ));
373 }
374 if gh != ph {
375 return Err(format!("project mcp.toml may not override `headers` for global server `{name}`"));
376 }
377 if gt != pt {
378 return Err(format!("project mcp.toml may not override `startup_timeout_ms` for global server `{name}`"));
379 }
380 }
381 _ => {
382 return Err(format!(
383 "project mcp.toml may not change `transport` for global server `{name}`"
384 ))
385 }
386 }
387 global.servers.insert(name, proj_server);
389 } else {
390 global.servers.insert(name, proj_server);
391 }
392 }
393 Ok(global)
394}
395
#[cfg(test)]
mod merge_tests {
    use super::*;
    use std::collections::HashMap;

    /// Stdio server with the given command and `enabled` flag; all other
    /// fields hold their defaults.
    fn stdio_with_enabled(cmd: &str, enabled: bool) -> McpServerConfig {
        McpServerConfig::Stdio {
            command: cmd.into(),
            args: vec![],
            env: HashMap::new(),
            startup_timeout_ms: 10_000,
            enabled,
        }
    }

    /// Enabled stdio server with the given command.
    fn stdio(cmd: &str) -> McpServerConfig {
        stdio_with_enabled(cmd, true)
    }

    /// Builds a config from `(name, server)` pairs.
    fn cfg(entries: &[(&str, McpServerConfig)]) -> McpConfig {
        let mut c = McpConfig::default();
        for (k, v) in entries {
            c.servers.insert((*k).into(), v.clone());
        }
        c
    }

    /// Reads the `enabled` flag of stdio server `name`.
    ///
    /// Panics when the server is missing or is not a stdio server, so a test
    /// using it cannot pass without actually checking the flag.
    fn stdio_enabled(config: &McpConfig, name: &str) -> bool {
        match config.servers.get(name) {
            Some(McpServerConfig::Stdio { enabled, .. }) => *enabled,
            other => panic!("expected stdio server `{name}`, got {other:?}"),
        }
    }

    #[test]
    fn project_adds_new_server() {
        let merged = merge(cfg(&[("a", stdio("ca"))]), cfg(&[("b", stdio("cb"))])).unwrap();
        assert!(merged.servers.contains_key("a"));
        assert!(merged.servers.contains_key("b"));
    }

    #[test]
    fn project_can_disable_global_server() {
        let project = cfg(&[("a", stdio_with_enabled("ca", false))]);
        let merged = merge(cfg(&[("a", stdio("ca"))]), project).unwrap();
        assert!(!stdio_enabled(&merged, "a"));
    }

    #[test]
    fn project_can_reenable_disabled_global_server() {
        let global = cfg(&[("a", stdio_with_enabled("ca", false))]);
        let merged = merge(global, cfg(&[("a", stdio_with_enabled("ca", true))])).unwrap();
        assert!(stdio_enabled(&merged, "a"));
    }

    #[test]
    fn project_cannot_override_command() {
        let err = merge(
            cfg(&[("a", stdio("ca"))]),
            cfg(&[("a", stdio("DIFFERENT"))]),
        )
        .unwrap_err();
        assert!(err.contains("command"), "{err}");
        assert!(err.contains("`a`"), "{err}");
    }

    #[test]
    fn project_cannot_override_transport() {
        let http = McpServerConfig::Http {
            url: "x".into(),
            headers: HashMap::new(),
            startup_timeout_ms: 10_000,
            enabled: true,
        };
        let err = merge(cfg(&[("a", stdio("ca"))]), cfg(&[("a", http)])).unwrap_err();
        assert!(err.contains("transport"), "{err}");
    }
}
475
476pub fn load_config<F>(
481 cwd: &Path,
482 agent_dir: &Path,
483 lookup: &F,
484) -> Result<(McpConfig, Vec<String>), String>
485where
486 F: Fn(&str) -> Option<String>,
487{
488 let global = read_or_default(&agent_dir.join("mcp.toml"))?;
489 let merged = read_project_overlay(&cwd.join(".capo").join("mcp.toml"), global)?;
490 let (resolved, diags) = resolve_env(merged, lookup);
491 Ok((resolved, diags))
492}
493
494fn read_or_default(path: &Path) -> Result<McpConfig, String> {
495 if !path.exists() {
496 return Ok(McpConfig::default());
497 }
498 let raw = std::fs::read_to_string(path)
499 .map_err(|e| format!("read {} failed: {e}", path.display()))?;
500 parse_str(&raw).map_err(|e| format!("parse {} failed: {e}", path.display()))
501}
502
503fn read_project_overlay(path: &Path, mut global: McpConfig) -> Result<McpConfig, String> {
504 if !path.exists() {
505 return Ok(global);
506 }
507 let raw = std::fs::read_to_string(path)
508 .map_err(|e| format!("read {} failed: {e}", path.display()))?;
509 let project: ProjectMcpConfig =
510 toml::from_str(&raw).map_err(|e| format!("parse {} failed: {e}", path.display()))?;
511
512 let mut full_project = McpConfig::default();
513 for (name, raw_server) in project.servers {
514 if raw_server.is_enabled_only() {
515 let enabled = raw_server.enabled.unwrap_or(true);
516 let Some(existing) = global.servers.get_mut(&name) else {
517 return Err(format!(
518 "project mcp.toml cannot define enabled-only server `{name}` without a global server"
519 ));
520 };
521 set_enabled(existing, enabled);
522 continue;
523 }
524 full_project
525 .servers
526 .insert(name.clone(), raw_server.into_server_config(&name)?);
527 }
528 merge(global, full_project)
529}
530
/// Shape of the project-level `mcp.toml` as deserialized by
/// `read_project_overlay`, before entries are validated.
#[derive(Debug, Default, Deserialize)]
struct ProjectMcpConfig {
    /// Raw server entries keyed by name; empty when the `servers` table is
    /// absent.
    #[serde(default)]
    servers: HashMap<String, ProjectMcpServerConfig>,
}
536
/// Loosely typed project server entry.
///
/// Every field is optional so that an entry containing only `enabled` can be
/// told apart from a full server definition (see `is_enabled_only`).
#[derive(Debug, Default, Deserialize)]
struct ProjectMcpServerConfig {
    /// Transport name; `into_server_config` accepts `"stdio"` and `"http"`.
    transport: Option<String>,
    /// `command`; required by `into_server_config` for stdio entries.
    command: Option<String>,
    /// `args`; treated as empty when absent (stdio entries).
    #[serde(default)]
    args: Option<Vec<String>>,
    /// `env`; treated as empty when absent (stdio entries).
    #[serde(default)]
    env: Option<HashMap<String, String>>,
    /// `url`; required by `into_server_config` for http entries.
    url: Option<String>,
    /// `headers`; treated as empty when absent (http entries).
    #[serde(default)]
    headers: Option<HashMap<String, String>>,
    /// `startup_timeout_ms`; `default_timeout_ms()` is used when absent.
    startup_timeout_ms: Option<u64>,
    /// `enabled`; `default_enabled()` is used when absent from a full entry.
    enabled: Option<bool>,
}
551
552impl ProjectMcpServerConfig {
553 fn is_enabled_only(&self) -> bool {
554 self.enabled.is_some()
555 && self.transport.is_none()
556 && self.command.is_none()
557 && self.args.is_none()
558 && self.env.is_none()
559 && self.url.is_none()
560 && self.headers.is_none()
561 && self.startup_timeout_ms.is_none()
562 }
563
564 fn into_server_config(self, name: &str) -> Result<McpServerConfig, String> {
565 match self.transport.as_deref() {
566 Some("stdio") => Ok(McpServerConfig::Stdio {
567 command: self.command.ok_or_else(|| {
568 format!("project mcp.toml stdio server `{name}` missing `command`")
569 })?,
570 args: self.args.unwrap_or_default(),
571 env: self.env.unwrap_or_default(),
572 startup_timeout_ms: self.startup_timeout_ms.unwrap_or_else(default_timeout_ms),
573 enabled: self.enabled.unwrap_or_else(default_enabled),
574 }),
575 Some("http") => Ok(McpServerConfig::Http {
576 url: self.url.ok_or_else(|| {
577 format!("project mcp.toml http server `{name}` missing `url`")
578 })?,
579 headers: self.headers.unwrap_or_default(),
580 startup_timeout_ms: self.startup_timeout_ms.unwrap_or_else(default_timeout_ms),
581 enabled: self.enabled.unwrap_or_else(default_enabled),
582 }),
583 Some(other) => Err(format!(
584 "project mcp.toml server `{name}` has unknown transport `{other}`"
585 )),
586 None => Err(format!(
587 "project mcp.toml server `{name}` missing `transport`"
588 )),
589 }
590 }
591}
592
593fn set_enabled(server: &mut McpServerConfig, value: bool) {
594 match server {
595 McpServerConfig::Stdio { enabled, .. } | McpServerConfig::Http { enabled, .. } => {
596 *enabled = value;
597 }
598 }
599}