#![cfg_attr(test, allow(clippy::expect_used, clippy::unwrap_used))]
use std::collections::HashMap;
use std::path::Path;
use serde::Deserialize;
/// Parsed contents of an `mcp.toml` file.
///
/// When the `[servers]` table is absent the map deserializes as empty.
#[derive(Clone, Debug, Default, PartialEq, Deserialize)]
pub struct McpConfig {
    /// Server definitions keyed by the `<name>` in `[servers.<name>]`.
    #[serde(default)]
    pub servers: HashMap<String, McpServerConfig>,
}
/// A single MCP server definition.
///
/// The variant is chosen by the TOML `transport` key (`"stdio"` or `"http"`).
/// Optional keys fall back to the following defaults when omitted:
/// - empty `args` / `env` / `headers`;
/// - `startup_timeout_ms` from `default_timeout_ms`;
/// - `enabled` from `default_enabled`.
#[derive(Clone, Debug, PartialEq, Deserialize)]
#[serde(tag = "transport", rename_all = "snake_case")]
pub enum McpServerConfig {
    /// `transport = "stdio"`: a server described by a command line plus extra
    /// environment variables.
    Stdio {
        /// Program to run (required).
        command: String,
        /// Arguments passed to `command`.
        #[serde(default)]
        args: Vec<String>,
        /// Extra environment entries. Values may contain `${VAR}` references that
        /// `resolve_env` expands.
        #[serde(default)]
        env: HashMap<String, String>,
        /// Startup timeout in milliseconds.
        #[serde(default = "default_timeout_ms")]
        startup_timeout_ms: u64,
        /// Whether the server is enabled; a project overlay may toggle this.
        #[serde(default = "default_enabled")]
        enabled: bool,
    },
    /// `transport = "http"`: a server reached at a URL.
    Http {
        /// Endpoint URL (required).
        url: String,
        /// Extra request headers. Values may contain `${VAR}` references that
        /// `resolve_env` expands.
        #[serde(default)]
        headers: HashMap<String, String>,
        /// Startup timeout in milliseconds.
        #[serde(default = "default_timeout_ms")]
        startup_timeout_ms: u64,
        /// Whether the server is enabled; a project overlay may toggle this.
        #[serde(default = "default_enabled")]
        enabled: bool,
    },
}
/// Serde default for `startup_timeout_ms`: ten seconds, in milliseconds.
fn default_timeout_ms() -> u64 {
    const TEN_SECONDS_MS: u64 = 10 * 1_000;
    TEN_SECONDS_MS
}
/// Serde default for `enabled`: a server is on unless explicitly turned off.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
/// Parses the TOML text of an `mcp.toml` file into an [`McpConfig`].
///
/// # Errors
/// Returns the `toml` deserialization error when `raw` is not valid TOML or does
/// not fit the schema (for example, an unrecognised `transport` value).
pub fn parse_str(raw: &str) -> Result<McpConfig, toml::de::Error> {
    let config: McpConfig = toml::from_str(raw)?;
    Ok(config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn parses_stdio_server() {
        let raw = r#"
[servers.github]
transport = "stdio"
command = "github-mcp-server"
args = ["--scope", "read-only"]
env = { GITHUB_TOKEN = "abc" }
startup_timeout_ms = 5000
"#;
        let cfg = parse_str(raw).unwrap();
        let McpServerConfig::Stdio {
            command,
            args,
            env,
            startup_timeout_ms,
            enabled,
        } = cfg.servers.get("github").unwrap()
        else {
            panic!("expected stdio");
        };
        assert_eq!(command, "github-mcp-server");
        assert_eq!(args, &vec!["--scope".to_string(), "read-only".into()]);
        assert_eq!(env.get("GITHUB_TOKEN").map(String::as_str), Some("abc"));
        assert_eq!(*startup_timeout_ms, 5000);
        assert!(*enabled);
    }

    #[test]
    fn parses_http_server() {
        let raw = r#"
[servers.sentry]
transport = "http"
url = "https://mcp.sentry.io/v1"
headers = { Authorization = "Bearer t" }
"#;
        let cfg = parse_str(raw).unwrap();
        let McpServerConfig::Http {
            url,
            headers,
            startup_timeout_ms,
            enabled,
        } = cfg.servers.get("sentry").unwrap()
        else {
            panic!("expected http");
        };
        assert_eq!(url, "https://mcp.sentry.io/v1");
        assert_eq!(
            headers.get("Authorization").map(String::as_str),
            Some("Bearer t")
        );
        // Omitted keys fall back to their serde defaults.
        assert_eq!(*startup_timeout_ms, 10_000);
        assert!(*enabled);
    }

    #[test]
    fn unknown_transport_errors() {
        let raw = r#"
[servers.x]
transport = "carrier-pigeon"
url = "..."
"#;
        let parsed = parse_str(raw);
        assert!(parsed.is_err());
    }
}
/// Expands `${NAME}` references in `s` using `lookup`.
///
/// - `$$` becomes a single literal `$`.
/// - `${NAME}` is replaced by `lookup(NAME)`. `NAME` is whatever text lies between
///   `${` and the next `}`, without further validation. Substituted values are not
///   re-scanned.
/// - A `$` that starts neither form is copied through unchanged. So is a `${` with
///   no `}` anywhere after it.
///
/// # Errors
/// Returns `Err(NAME)` for the first (leftmost) `${NAME}` whose lookup yields `None`.
pub fn expand_env<F>(s: &str, lookup: &F) -> Result<String, String>
where
    F: Fn(&str) -> Option<String>,
{
    let mut expanded = String::with_capacity(s.len());
    let mut rest = s;
    while let Some(first) = rest.chars().next() {
        if first == '$' {
            let after_dollar = &rest[1..];
            // `$$` escape: emit one `$` and skip both characters.
            if let Some(tail) = after_dollar.strip_prefix('$') {
                expanded.push('$');
                rest = tail;
                continue;
            }
            // `${NAME}` reference, only when a closing brace exists later on.
            if let Some(body) = after_dollar.strip_prefix('{') {
                if let Some(end) = body.find('}') {
                    let name = &body[..end];
                    let Some(value) = lookup(name) else {
                        return Err(name.to_string());
                    };
                    expanded.push_str(&value);
                    rest = &body[end + 1..];
                    continue;
                }
            }
        }
        // Plain character (or a `$` that did not start an escape/reference).
        expanded.push(first);
        rest = &rest[first.len_utf8()..];
    }
    Ok(expanded)
}
#[cfg(test)]
mod expand_env_tests {
    use super::*;

    /// Builds a lookup closure backed by a fixed table of `(name, value)` pairs.
    fn lk(values: &'static [(&'static str, &'static str)]) -> impl Fn(&str) -> Option<String> {
        let table: std::collections::HashMap<&str, &str> = values.iter().copied().collect();
        move |name| table.get(name).map(|value| (*value).to_string())
    }

    #[test]
    fn substitutes_single_var() {
        assert_eq!(
            expand_env("Bearer ${TOK}", &lk(&[("TOK", "abc")])),
            Ok("Bearer abc".to_string())
        );
    }

    #[test]
    fn substitutes_multiple_vars() {
        assert_eq!(
            expand_env("${A}-${B}", &lk(&[("A", "x"), ("B", "y")])),
            Ok("x-y".to_string())
        );
    }

    #[test]
    fn missing_var_errors_with_var_name() {
        assert_eq!(expand_env("${MISSING}", &lk(&[])), Err("MISSING".to_string()));
    }

    #[test]
    fn dollar_dollar_is_literal() {
        assert_eq!(expand_env("price: $$5", &lk(&[])), Ok("price: $5".to_string()));
    }

    #[test]
    fn unclosed_dollar_brace_is_literal() {
        assert_eq!(
            expand_env("oops ${INCOMPLETE", &lk(&[])),
            Ok("oops ${INCOMPLETE".to_string())
        );
    }

    #[test]
    fn empty_string_is_empty() {
        assert_eq!(expand_env("", &lk(&[])), Ok(String::new()));
    }

    #[test]
    fn preserves_non_ascii_text() {
        assert_eq!(
            expand_env("Bearer café-${TOK}", &lk(&[("TOK", "🔑")])),
            Ok("Bearer café-🔑".to_string())
        );
    }
}
pub fn resolve_env<F>(mut cfg: McpConfig, lookup: &F) -> (McpConfig, Vec<String>)
where
F: Fn(&str) -> Option<String>,
{
let mut diags = Vec::new();
cfg.servers.retain(|name, server| {
match server {
McpServerConfig::Stdio { env, .. } => {
for (k, v) in env.iter_mut() {
match expand_env(v, lookup) {
Ok(new) => *v = new,
Err(var) => {
diags.push(format!(
"server `{name}` env `{k}` references unset `${{{var}}}`; server skipped"
));
return false;
}
}
}
true
}
McpServerConfig::Http { headers, .. } => {
for (k, v) in headers.iter_mut() {
match expand_env(v, lookup) {
Ok(new) => *v = new,
Err(var) => {
diags.push(format!(
"server `{name}` header `{k}` references unset `${{{var}}}`; server skipped"
));
return false;
}
}
}
true
}
}
});
(cfg, diags)
}
#[cfg(test)]
mod resolve_env_tests {
    use super::*;

    /// Builds a lookup closure backed by a fixed table of `(name, value)` pairs.
    fn lk(values: &'static [(&'static str, &'static str)]) -> impl Fn(&str) -> Option<String> {
        let table: std::collections::HashMap<&str, &str> = values.iter().copied().collect();
        move |name| table.get(name).map(|value| (*value).to_string())
    }

    #[test]
    fn resolves_stdio_env() {
        let raw = r#"
[servers.x]
transport = "stdio"
command = "c"
env = { TOK = "${T}" }
"#;
        let parsed = parse_str(raw).unwrap();
        let (cfg, diags) = resolve_env(parsed, &lk(&[("T", "value")]));
        assert!(diags.is_empty());
        let McpServerConfig::Stdio { env, .. } = cfg.servers.get("x").unwrap() else {
            panic!()
        };
        assert_eq!(env.get("TOK").map(String::as_str), Some("value"));
    }

    #[test]
    fn drops_server_with_missing_env_var() {
        let raw = r#"
[servers.bad]
transport = "stdio"
command = "c"
env = { TOK = "${MISSING}" }
[servers.good]
transport = "stdio"
command = "c"
"#;
        let parsed = parse_str(raw).unwrap();
        let (cfg, diags) = resolve_env(parsed, &lk(&[]));
        assert!(!cfg.servers.contains_key("bad"));
        assert!(cfg.servers.contains_key("good"));
        assert_eq!(diags.len(), 1);
        let only = &diags[0];
        assert!(only.contains("bad"));
        assert!(only.contains("MISSING"));
    }
}
pub fn merge(mut global: McpConfig, project: McpConfig) -> Result<McpConfig, String> {
for (name, proj_server) in project.servers {
if let Some(global_server) = global.servers.get(&name) {
match (global_server, &proj_server) {
(
McpServerConfig::Stdio {
command: gc,
args: ga,
env: ge,
startup_timeout_ms: gt,
enabled: _,
},
McpServerConfig::Stdio {
command: pc,
args: pa,
env: pe,
startup_timeout_ms: pt,
enabled: _,
},
) => {
if gc != pc {
return Err(format!("project mcp.toml may not override `command` for global server `{name}`"));
}
if ga != pa {
return Err(format!(
"project mcp.toml may not override `args` for global server `{name}`"
));
}
if ge != pe {
return Err(format!(
"project mcp.toml may not override `env` for global server `{name}`"
));
}
if gt != pt {
return Err(format!("project mcp.toml may not override `startup_timeout_ms` for global server `{name}`"));
}
}
(
McpServerConfig::Http {
url: gu,
headers: gh,
startup_timeout_ms: gt,
enabled: _,
},
McpServerConfig::Http {
url: pu,
headers: ph,
startup_timeout_ms: pt,
enabled: _,
},
) => {
if gu != pu {
return Err(format!(
"project mcp.toml may not override `url` for global server `{name}`"
));
}
if gh != ph {
return Err(format!("project mcp.toml may not override `headers` for global server `{name}`"));
}
if gt != pt {
return Err(format!("project mcp.toml may not override `startup_timeout_ms` for global server `{name}`"));
}
}
_ => {
return Err(format!(
"project mcp.toml may not change `transport` for global server `{name}`"
))
}
}
global.servers.insert(name, proj_server);
} else {
global.servers.insert(name, proj_server);
}
}
Ok(global)
}
#[cfg(test)]
mod merge_tests {
    use super::*;
    use std::collections::HashMap;

    fn stdio(cmd: &str) -> McpServerConfig {
        McpServerConfig::Stdio {
            command: cmd.into(),
            args: vec![],
            env: HashMap::new(),
            startup_timeout_ms: 10_000,
            enabled: true,
        }
    }

    fn stdio_with_enabled(cmd: &str, enabled: bool) -> McpServerConfig {
        let mut s = stdio(cmd);
        if let McpServerConfig::Stdio { enabled: e, .. } = &mut s {
            *e = enabled;
        }
        s
    }

    fn http(url: &str) -> McpServerConfig {
        McpServerConfig::Http {
            url: url.into(),
            headers: HashMap::new(),
            startup_timeout_ms: 10_000,
            enabled: true,
        }
    }

    fn cfg(entries: &[(&str, McpServerConfig)]) -> McpConfig {
        let mut c = McpConfig::default();
        for (k, v) in entries {
            c.servers.insert((*k).into(), v.clone());
        }
        c
    }

    #[test]
    fn project_adds_new_server() {
        let merged = merge(cfg(&[("a", stdio("ca"))]), cfg(&[("b", stdio("cb"))])).unwrap();
        assert!(merged.servers.contains_key("a"));
        assert!(merged.servers.contains_key("b"));
    }

    #[test]
    fn project_can_disable_global_server() {
        let project = cfg(&[("a", stdio_with_enabled("ca", false))]);
        let merged = merge(cfg(&[("a", stdio("ca"))]), project).unwrap();
        // Compare the whole entry: an `if let` without `else` would pass vacuously
        // if the merged server were ever a different variant.
        assert_eq!(merged.servers.get("a"), Some(&stdio_with_enabled("ca", false)));
    }

    #[test]
    fn project_can_reenable_disabled_global_server() {
        let global = cfg(&[("a", stdio_with_enabled("ca", false))]);
        let merged = merge(global, cfg(&[("a", stdio_with_enabled("ca", true))])).unwrap();
        assert_eq!(merged.servers.get("a"), Some(&stdio_with_enabled("ca", true)));
    }

    #[test]
    fn identical_redeclaration_is_accepted() {
        let merged = merge(cfg(&[("a", stdio("ca"))]), cfg(&[("a", stdio("ca"))])).unwrap();
        assert_eq!(merged.servers.len(), 1);
        assert_eq!(merged.servers.get("a"), Some(&stdio("ca")));
    }

    #[test]
    fn project_cannot_override_command() {
        let err = merge(
            cfg(&[("a", stdio("ca"))]),
            cfg(&[("a", stdio("DIFFERENT"))]),
        )
        .unwrap_err();
        assert!(err.contains("command"), "{err}");
        assert!(err.contains("`a`"), "{err}");
    }

    #[test]
    fn project_cannot_override_http_headers() {
        let mut overridden = http("x");
        if let McpServerConfig::Http { headers, .. } = &mut overridden {
            headers.insert("Authorization".into(), "Bearer other".into());
        }
        let err = merge(cfg(&[("h", http("x"))]), cfg(&[("h", overridden)])).unwrap_err();
        assert!(err.contains("headers"), "{err}");
        assert!(err.contains("`h`"), "{err}");
    }

    #[test]
    fn project_cannot_override_transport() {
        let err = merge(cfg(&[("a", stdio("ca"))]), cfg(&[("a", http("x"))])).unwrap_err();
        assert!(err.contains("transport"), "{err}");
    }
}
/// Loads the effective MCP configuration for a working directory.
///
/// Steps:
/// 1. Read `<agent_dir>/mcp.toml` as the global config; a missing file counts as
///    an empty config.
/// 2. If `<cwd>/.capo/mcp.toml` is present, overlay it on the global config.
/// 3. Expand `${VAR}` references in server `env`/`headers` values via `lookup`.
///
/// On success, returns the config plus diagnostics for any servers dropped because
/// they referenced unset variables.
///
/// # Errors
/// Returns a human-readable message in either of these cases:
/// - a file cannot be read or parsed;
/// - the project overlay conflicts with the global config.
pub fn load_config<F>(
    cwd: &Path,
    agent_dir: &Path,
    lookup: &F,
) -> Result<(McpConfig, Vec<String>), String>
where
    F: Fn(&str) -> Option<String>,
{
    let global_path = agent_dir.join("mcp.toml");
    let project_path = cwd.join(".capo").join("mcp.toml");

    let global = read_or_default(&global_path)?;
    let merged = read_project_overlay(&project_path, global)?;
    Ok(resolve_env(merged, lookup))
}
/// Reads and parses the config file at `path`.
///
/// Returns an empty [`McpConfig`] when the file does not exist.
///
/// # Errors
/// - `read <path> failed: ...` for any I/O error other than "not found" (e.g.
///   permission denied, or `path` is a directory).
/// - `parse <path> failed: ...` when the contents are not a valid config.
fn read_or_default(path: &Path) -> Result<McpConfig, String> {
    // Read directly instead of probing with `Path::exists` first. `exists` returns
    // `false` for *any* metadata error (e.g. permission denied), which would
    // silently ignore an unreadable config. A separate probe is also racy with
    // the read that follows it.
    let raw = match std::fs::read_to_string(path) {
        Ok(raw) => raw,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(McpConfig::default()),
        Err(e) => return Err(format!("read {} failed: {e}", path.display())),
    };
    parse_str(&raw).map_err(|e| format!("parse {} failed: {e}", path.display()))
}
/// Applies the project-level config at `path` on top of `global`.
///
/// If the file does not exist, `global` is returned unchanged. Otherwise each entry
/// in its `[servers]` table is handled in one of two ways:
/// - An entry that sets only `enabled` toggles that flag on the existing global
///   server of the same name.
/// - Any other entry is converted into a full server definition. All such
///   definitions are then combined with `global` via [`merge`].
///
/// # Errors
/// Returns a message in any of these cases:
/// - the file cannot be read (other than "not found") or parsed;
/// - an `enabled`-only entry names a server absent from `global`;
/// - a full entry is incomplete or has an unknown transport;
/// - [`merge`] rejects an override.
fn read_project_overlay(path: &Path, mut global: McpConfig) -> Result<McpConfig, String> {
    // Read directly rather than probing with `Path::exists`, which maps every
    // metadata error (e.g. permission denied) to "absent" and would silently
    // skip an unreadable project config.
    let raw = match std::fs::read_to_string(path) {
        Ok(raw) => raw,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(global),
        Err(e) => return Err(format!("read {} failed: {e}", path.display())),
    };
    let project: ProjectMcpConfig =
        toml::from_str(&raw).map_err(|e| format!("parse {} failed: {e}", path.display()))?;

    let mut full_project = McpConfig::default();
    for (name, raw_server) in project.servers {
        if raw_server.is_enabled_only() {
            // `is_enabled_only` guarantees `enabled` is `Some`; the fallback is never used.
            let enabled = raw_server.enabled.unwrap_or(true);
            let Some(existing) = global.servers.get_mut(&name) else {
                return Err(format!(
                    "project mcp.toml cannot define enabled-only server `{name}` without a global server"
                ));
            };
            set_enabled(existing, enabled);
            continue;
        }
        let server = raw_server.into_server_config(&name)?;
        full_project.servers.insert(name, server);
    }
    merge(global, full_project)
}
/// Lenient view of a project-level `mcp.toml`.
///
/// Unlike [`McpConfig`], each entry may be partial; see `ProjectMcpServerConfig`.
#[derive(Default, Debug, Deserialize)]
struct ProjectMcpConfig {
    /// Entries keyed by server name; an absent `[servers]` table yields an empty map.
    #[serde(default)]
    servers: HashMap<String, ProjectMcpServerConfig>,
}
/// One project-level server entry, with every key optional.
///
/// Keeping every key optional lets the overlay tell two kinds of entry apart:
/// - an entry that sets only `enabled` (a toggle for a global server);
/// - a complete definition, converted by `into_server_config`.
#[derive(Default, Debug, Deserialize)]
struct ProjectMcpServerConfig {
    /// `"stdio"` or `"http"`; required for a full definition.
    transport: Option<String>,
    /// Stdio: program to run.
    command: Option<String>,
    /// Stdio: program arguments.
    #[serde(default)]
    args: Option<Vec<String>>,
    /// Stdio: extra environment entries.
    #[serde(default)]
    env: Option<HashMap<String, String>>,
    /// Http: endpoint URL.
    url: Option<String>,
    /// Http: extra request headers.
    #[serde(default)]
    headers: Option<HashMap<String, String>>,
    /// Startup timeout in milliseconds.
    startup_timeout_ms: Option<u64>,
    /// Enable/disable flag; may appear on its own as a toggle.
    enabled: Option<bool>,
}
impl ProjectMcpServerConfig {
    /// Returns `true` when this entry sets `enabled` and no other key, i.e. it only
    /// toggles an existing global server.
    fn is_enabled_only(&self) -> bool {
        // Exhaustive destructuring: adding a field to the struct forces this check
        // to be revisited.
        let Self {
            transport,
            command,
            args,
            env,
            url,
            headers,
            startup_timeout_ms,
            enabled,
        } = self;
        let nothing_else_set = transport.is_none()
            && command.is_none()
            && args.is_none()
            && env.is_none()
            && url.is_none()
            && headers.is_none()
            && startup_timeout_ms.is_none();
        enabled.is_some() && nothing_else_set
    }

    /// Converts this entry into a full [`McpServerConfig`].
    ///
    /// Omitted optional keys get the following defaults:
    /// - empty collections for `args`/`env`/`headers`;
    /// - `default_timeout_ms()` for `startup_timeout_ms`;
    /// - `default_enabled()` for `enabled`.
    ///
    /// # Errors
    /// Returns a message naming the server `name` when `transport` is missing or
    /// unknown, or when the key required by the transport (`command` for stdio,
    /// `url` for http) is absent.
    fn into_server_config(self, name: &str) -> Result<McpServerConfig, String> {
        let startup_timeout_ms = self.startup_timeout_ms.unwrap_or_else(default_timeout_ms);
        let enabled = self.enabled.unwrap_or_else(default_enabled);
        let Some(transport) = self.transport.as_deref() else {
            return Err(format!(
                "project mcp.toml server `{name}` missing `transport`"
            ));
        };
        match transport {
            "stdio" => {
                let Some(command) = self.command else {
                    return Err(format!(
                        "project mcp.toml stdio server `{name}` missing `command`"
                    ));
                };
                Ok(McpServerConfig::Stdio {
                    command,
                    args: self.args.unwrap_or_default(),
                    env: self.env.unwrap_or_default(),
                    startup_timeout_ms,
                    enabled,
                })
            }
            "http" => {
                let Some(url) = self.url else {
                    return Err(format!(
                        "project mcp.toml http server `{name}` missing `url`"
                    ));
                };
                Ok(McpServerConfig::Http {
                    url,
                    headers: self.headers.unwrap_or_default(),
                    startup_timeout_ms,
                    enabled,
                })
            }
            other => Err(format!(
                "project mcp.toml server `{name}` has unknown transport `{other}`"
            )),
        }
    }
}
/// Overwrites the `enabled` flag of `server` with `value`, whatever its transport.
fn set_enabled(server: &mut McpServerConfig, value: bool) {
    let flag = match server {
        McpServerConfig::Stdio { enabled, .. } => enabled,
        McpServerConfig::Http { enabled, .. } => enabled,
    };
    *flag = value;
}