use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use super::transport::TransportConfig;
/// Configuration for a single MCP server entry (one `[servers.<key>]` table).
///
/// A server is reached either over stdio (when `command` is set) or over SSE (when
/// only `url` is set); see `to_transport_config`. Fields are (de)serialized with serde
/// in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
/// Optional display name; `get_name` falls back to a caller-supplied key when unset.
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Program to run for a stdio server. Takes precedence over `url` when both are set.
#[serde(skip_serializing_if = "Option::is_none")]
pub command: Option<String>,
/// Arguments for `command`; empty when absent. Not used for SSE servers.
#[serde(default)]
pub args: Vec<String>,
/// Extra environment variables handed to the stdio transport; empty when absent.
#[serde(default)]
pub env: HashMap<String, String>,
/// SSE endpoint URL, used only when `command` is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub url: Option<String>,
/// Timeout in milliseconds (default 30000). `to_transport_config` forwards it only
/// to the SSE transport.
#[serde(default = "default_timeout")]
pub timeout_ms: u64,
/// Whether the server is active (default `true`); `McpConfig::enabled_servers`
/// skips entries where this is `false`.
#[serde(default = "default_enabled")]
pub enabled: bool,
}
/// Serde default for `McpServerConfig::timeout_ms`: 30 seconds, expressed in milliseconds.
fn default_timeout() -> u64 {
    30 * 1_000
}

/// Serde default for `McpServerConfig::enabled`: entries are active unless disabled.
fn default_enabled() -> bool {
    true
}
impl McpServerConfig {
pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
Self {
name: None,
command: Some(command.into()),
args,
env: HashMap::new(),
url: None,
timeout_ms: default_timeout(),
enabled: true,
}
}
pub fn sse(url: impl Into<String>) -> Self {
Self {
name: None,
command: None,
args: Vec::new(),
env: HashMap::new(),
url: Some(url.into()),
timeout_ms: default_timeout(),
enabled: true,
}
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.env.insert(key.into(), value.into());
self
}
pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
self.timeout_ms = timeout_ms;
self
}
pub fn with_enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
pub fn to_transport_config(&self) -> Result<TransportConfig> {
if let Some(command) = &self.command {
let env_vec: Vec<(String, String)> = self.env.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
Ok(TransportConfig::Stdio {
command: command.clone(),
args: self.args.clone(),
env: if env_vec.is_empty() { None } else { Some(env_vec) },
})
} else if let Some(url) = &self.url {
Ok(TransportConfig::Sse {
url: url.clone(),
timeout_ms: Some(self.timeout_ms),
})
} else {
Err(anyhow!("MCP server config must have either 'command' or 'url'"))
}
}
pub fn get_name(&self, key: &str) -> String {
self.name.clone().unwrap_or_else(|| key.to_string())
}
}
/// Top-level MCP configuration: server entries keyed by identifier, plus global settings.
///
/// Both fields carry `#[serde(default)]`, so either may be omitted from the document;
/// an empty document deserializes to `McpConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpConfig {
/// Server entries keyed by identifier (e.g. `playwright`); empty when absent.
#[serde(default)]
pub servers: HashMap<String, McpServerConfig>,
/// Global settings (`[settings]` table); `McpSettings::default()` when absent.
#[serde(default)]
pub settings: McpSettings,
}
/// Global MCP settings. Every field has a serde default, so any subset may be given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpSettings {
/// Defaults to `true`. Not read anywhere in this file; NOTE(review): confirm what
/// the consuming client does with it.
#[serde(default = "default_auto_discover")]
pub auto_discover: bool,
/// Connection timeout in milliseconds; defaults to 10000. Not read in this file;
/// NOTE(review): confirm where it is applied.
#[serde(default = "default_connect_timeout")]
pub connect_timeout_ms: u64,
}
/// Serde default for `McpSettings::auto_discover`: enabled.
fn default_auto_discover() -> bool {
    true
}

/// Serde default for `McpSettings::connect_timeout_ms`: 10 seconds, in milliseconds.
fn default_connect_timeout() -> u64 {
    10 * 1_000
}
impl Default for McpSettings {
fn default() -> Self {
Self {
auto_discover: default_auto_discover(),
connect_timeout_ms: default_connect_timeout(),
}
}
}
impl McpConfig {
    /// Creates an empty config: no servers and default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `config` under `key`, replacing any existing entry with that key.
    #[must_use]
    pub fn add_server(mut self, key: impl Into<String>, config: McpServerConfig) -> Self {
        self.servers.insert(key.into(), config);
        self
    }

    /// Reads and parses a config file.
    ///
    /// The format (TOML or JSON) is detected from the content rather than the file
    /// extension; see [`McpConfig::from_str`].
    ///
    /// # Errors
    /// Fails when the file cannot be read or its content is neither valid TOML nor JSON.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = std::fs::read_to_string(path.as_ref())
            .map_err(|e| anyhow!("Failed to read MCP config file: {}", e))?;
        Self::from_str(&content)
    }

    /// Parses a config from text, trying TOML first and then JSON.
    ///
    /// # Errors
    /// Fails when the content is valid in neither format. The error carries both
    /// parser messages so the actual problem (e.g. a typo on a given line) is visible.
    // Kept as an inherent method (not `FromStr`) for API compatibility.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(content: &str) -> Result<Self> {
        let toml_err = match toml::from_str::<Self>(content) {
            Ok(config) => return Ok(config),
            Err(e) => e,
        };
        serde_json::from_str::<Self>(content).map_err(|json_err| {
            anyhow!(
                "Failed to parse MCP config as TOML or JSON (TOML error: {}; JSON error: {})",
                toml_err,
                json_err
            )
        })
    }

    /// Serializes the config and writes it to `path`, replacing any existing file.
    ///
    /// A path whose extension is `json` (case-insensitive, e.g. `mcp.json` or
    /// `.mcp.json`) is written as pretty-printed JSON. Any other path is written as
    /// pretty-printed TOML. Both are read back by [`McpConfig::from_file`].
    ///
    /// # Errors
    /// Fails when serialization or the file write fails.
    pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let path = path.as_ref();
        let is_json = path
            .extension()
            .and_then(|ext| ext.to_str())
            .map_or(false, |ext| ext.eq_ignore_ascii_case("json"));
        let content = if is_json {
            serde_json::to_string_pretty(self)
                .map_err(|e| anyhow!("Failed to serialize MCP config: {}", e))?
        } else {
            toml::to_string_pretty(self)
                .map_err(|e| anyhow!("Failed to serialize MCP config: {}", e))?
        };
        std::fs::write(path, content)
            .map_err(|e| anyhow!("Failed to write MCP config file: {}", e))?;
        Ok(())
    }

    /// Returns `(key, config)` for every server whose `enabled` flag is set.
    ///
    /// The result is sorted by key. `HashMap` iteration order is unspecified and
    /// changes between runs, so sorting keeps startup order and logs stable.
    pub fn enabled_servers(&self) -> Vec<(String, &McpServerConfig)> {
        let mut enabled: Vec<(String, &McpServerConfig)> = self
            .servers
            .iter()
            .filter(|(_, config)| config.enabled)
            .map(|(key, config)| (key.clone(), config))
            .collect();
        enabled.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        enabled
    }
}
/// Builds a config containing only the Playwright MCP server.
///
/// The server is registered under the key `playwright` and uses the stdio transport
/// with `npx -y @playwright/mcp@latest`; every other setting is left at its default.
pub fn playwright_config() -> McpConfig {
    let npx_args = vec![String::from("-y"), String::from("@playwright/mcp@latest")];
    let server = McpServerConfig::stdio("npx", npx_args);
    McpConfig::new().add_server("playwright", server)
}
pub fn default_mcp_config() -> McpConfig {
McpConfig::new()
.add_server("playwright", McpServerConfig::stdio(
"npx",
vec!["-y".into(), "@playwright/mcp@latest".into()]
))
}
/// File names probed by `find_mcp_config`, in priority order. Within a directory the
/// first name found wins, so `mcp.toml` beats `mcp.json` and the non-hidden names
/// beat their dot-prefixed variants.
pub const MCP_CONFIG_FILENAMES: &[&str] = &[
"mcp.toml",
"mcp.json",
".mcp.toml",
".mcp.json",
];
/// Locates the MCP config file to load, if any.
///
/// Searches `start_dir` first and then the user's home directory (when it can be
/// determined). Within each directory the names in `MCP_CONFIG_FILENAMES` are tried in
/// order. Only `start_dir` itself is searched, not its ancestors.
///
/// Only regular files (symlinks are followed) match. A directory that happens to
/// carry a config name is skipped instead of shadowing a real config file further
/// down the list.
pub fn find_mcp_config(start_dir: &Path) -> Option<std::path::PathBuf> {
    let find_in = |dir: &Path| {
        MCP_CONFIG_FILENAMES
            .iter()
            .map(|name| dir.join(name))
            .find(|candidate| candidate.is_file())
    };
    // The home directory is only looked up if nothing was found in `start_dir`.
    find_in(start_dir).or_else(|| dirs::home_dir().and_then(|home| find_in(&home)))
}
/// Loads the config file located by `find_mcp_config`, falling back to an empty config.
///
/// A successful load is logged at `info` level. A file that cannot be read or parsed
/// is logged at `warn` level and ignored. In that case, and when no file is found,
/// `McpConfig::new()` is returned.
pub fn load_mcp_config(start_dir: &Path) -> McpConfig {
    let loaded = find_mcp_config(start_dir).and_then(|path| match McpConfig::from_file(&path) {
        Ok(config) => {
            tracing::info!("Loaded MCP config from {:?}", path);
            Some(config)
        }
        Err(err) => {
            tracing::warn!("Failed to load MCP config from {:?}: {}", path, err);
            None
        }
    });
    loaded.unwrap_or_else(McpConfig::new)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `stdio` fills in only `command`, starts enabled, and maps to a stdio transport.
    #[test]
    fn test_server_config_stdio() {
        let npx_args: Vec<String> = vec!["-y".into(), "@playwright/mcp".into()];
        let config = McpServerConfig::stdio("npx", npx_args);
        assert!(config.command.is_some());
        assert!(config.url.is_none());
        assert!(config.enabled);

        if let TransportConfig::Stdio { command, args, .. } = config.to_transport_config().unwrap() {
            assert_eq!(command, "npx");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio transport");
        }
    }

    /// `sse` fills in only `url` and maps to an SSE transport carrying that URL.
    #[test]
    fn test_server_config_sse() {
        let config = McpServerConfig::sse("http://localhost:3000");
        assert!(config.command.is_none());
        assert!(config.url.is_some());

        if let TransportConfig::Sse { url, .. } = config.to_transport_config().unwrap() {
            assert_eq!(url, "http://localhost:3000");
        } else {
            panic!("Expected SSE transport");
        }
    }

    /// A config survives a TOML round trip with its server table intact.
    #[test]
    fn test_config_serialization() {
        let server = McpServerConfig::stdio("npx", vec!["-y".into(), "@playwright/mcp".into()]);
        let config = McpConfig::new().add_server("playwright", server);

        let serialized = toml::to_string(&config).unwrap();
        assert!(serialized.contains("[servers.playwright]"));

        let round_tripped: McpConfig = toml::from_str(&serialized).unwrap();
        assert!(round_tripped.servers.contains_key("playwright"));
    }

    /// The Playwright preset registers an `npx`-launched server under `playwright`.
    #[test]
    fn test_playwright_config() {
        let config = playwright_config();
        assert!(config.servers.contains_key("playwright"));
        let server = &config.servers["playwright"];
        assert_eq!(server.command.as_deref(), Some("npx"));
    }
}