use anyhow::Context;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Root MCP configuration: the list of server definitions plus shared settings.
///
/// The loaders in `impl McpConfig` (`from_file` and the per-format variants)
/// reject configurations that contain duplicate server IDs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
/// Server definitions. Required when deserializing: there is no serde default,
/// so a document without `servers` fails to parse.
pub servers: Vec<McpServerConfig>,
/// Shared settings; `McpSettings::default()` is used when the section is absent.
#[serde(default)]
pub settings: McpSettings,
}
/// A single MCP server entry.
///
/// The transport fields are flattened into this entry. A server is therefore
/// written with its `protocol` tag, plus `command`/`args`/`env` or `url`,
/// directly alongside `id` and `name`, rather than in a nested table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
/// Identifier; must be unique across all servers (checked when the config is loaded).
pub id: String,
/// Server name (presumably a human-readable label — confirm with consumers).
pub name: String,
/// Transport selection and its parameters, flattened into this entry.
#[serde(flatten)]
pub transport: McpTransportConfig,
/// Optional description.
pub description: Option<String>,
/// Defaults to `false` when omitted.
/// NOTE(review): presumably lets this server's actions skip user approval — confirm.
#[serde(default)]
pub auto_approve: bool,
/// Optional working directory.
/// NOTE(review): presumably the cwd for a `stdio` child process — confirm.
pub working_directory: Option<PathBuf>,
}
/// Transport settings for an MCP server.
///
/// Internally tagged: the variant is chosen by a `protocol` field whose value
/// is the lowercase variant name (`"stdio"`, `"sse"` or `"http"`). The
/// remaining fields of the chosen variant sit next to that tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "protocol", rename_all = "lowercase")]
pub enum McpTransportConfig {
/// Selected by `protocol = "stdio"`.
/// NOTE(review): presumably launches `command` as a local child process — confirm.
Stdio {
/// Command to launch.
command: String,
/// Command arguments; empty when omitted.
#[serde(default)]
args: Vec<String>,
/// Environment entries (presumably added to the child's environment — confirm);
/// empty when omitted.
#[serde(default)]
env: HashMap<String, String>,
},
/// Selected by `protocol = "sse"` (presumably Server-Sent Events); requires `url`.
Sse { url: String },
/// Selected by `protocol = "http"`; requires `url`.
Http { url: String },
}
/// Settings stored at the top level of [`McpConfig`], next to the server list.
///
/// Every field is optional when deserializing. An entirely absent `settings`
/// section resolves to [`McpSettings::default`]. Both paths produce the same
/// values: `max_retries = 3`, `debug = false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpSettings {
    /// Maximum number of retries (presumably for server operations — confirm
    /// with callers). Defaults to [`default_max_retries`] when omitted.
    #[serde(default = "default_max_retries")]
    pub max_retries: usize,
    /// Debug flag; defaults to `false` when omitted.
    #[serde(default)]
    pub debug: bool,
}

// Written by hand instead of derived: `#[derive(Default)]` would set
// `max_retries` to 0. That disagrees with the serde field default (3) used when
// a `settings` section exists but omits `max_retries`. `McpConfig` falls back to
// this impl when `settings` is missing entirely, so both paths must agree.
impl Default for McpSettings {
    fn default() -> Self {
        Self {
            max_retries: default_max_retries(),
            debug: false,
        }
    }
}

/// Default for [`McpSettings::max_retries`], shared by serde and [`Default`].
fn default_max_retries() -> usize {
    3
}
impl McpConfig {
pub fn from_yaml_file(path: &std::path::Path) -> anyhow::Result<Self> {
let contents = std::fs::read_to_string(path)?;
let config: McpConfig = serde_yaml::from_str(&contents)?;
config.validate()?;
Ok(config)
}
pub fn from_json_file(path: &std::path::Path) -> anyhow::Result<Self> {
let contents = std::fs::read_to_string(path)?;
let config: McpConfig = serde_json::from_str(&contents)?;
config.validate()?;
Ok(config)
}
pub fn from_toml_file(path: &std::path::Path) -> anyhow::Result<Self> {
let contents = std::fs::read_to_string(path)?;
let config: McpConfig = toml::from_str(&contents)?;
config.validate()?;
Ok(config)
}
pub fn from_file(path: &std::path::Path) -> anyhow::Result<Self> {
match path.extension().and_then(|s| s.to_str()) {
Some("yaml") | Some("yml") => Self::from_yaml_file(path),
Some("json") => Self::from_json_file(path),
Some("toml") => Self::from_toml_file(path),
_ => Err(anyhow::anyhow!(
"Unsupported config file format. Use .yaml, .yml, .json, or .toml"
)),
}
}
fn validate(&self) -> anyhow::Result<()> {
let mut seen_ids = std::collections::HashSet::new();
for server in &self.servers {
if !seen_ids.insert(&server.id) {
return Err(anyhow::anyhow!("Duplicate server ID found: {}", server.id));
}
}
Ok(())
}
}