use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// MCP-related portion of a Sofos TOML config file.
///
/// Only the `mcp-servers` table is modelled. No `deny_unknown_fields` is set, so any
/// other keys in the file are ignored during deserialization. `Default` yields an empty
/// server table, which is the same value produced by deserializing an empty document.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct McpConfig {
    /// Server definitions keyed by server name, read from the `[mcp-servers]` table.
    /// A file without that table deserializes to an empty map (`serde(default)`).
    #[serde(rename = "mcp-servers", default)]
    pub mcp_servers: HashMap<String, McpServerConfig>,
}
/// Definition of a single MCP server entry.
///
/// An entry is intended to set exactly one of `command` (stdio server) or `url`
/// (HTTP server); see `McpServerConfig::validate`. Every field is optional at the
/// serde level and is omitted from serialized output when `None`.
///
/// `Default` produces an entry with every field unset, which is convenient together
/// with struct-update syntax (`McpServerConfig { url: Some(..), ..Default::default() }`).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Program for a stdio server; its presence makes `is_stdio()` return `true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub command: Option<String>,
    /// Argument list, presumably passed to `command` on launch (launch code is not in
    /// this file — confirm against the caller).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub args: Option<Vec<String>>,
    /// Extra environment variables, presumably applied to the `command` process
    /// (confirm against the caller).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<HashMap<String, String>>,
    /// Endpoint for an HTTP server; its presence makes `is_http()` return `true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
    /// Header name/value pairs, presumably sent with requests to `url`
    /// (confirm against the caller).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}
impl McpServerConfig {
    /// Returns `true` when a `command` is configured (stdio server).
    ///
    /// This only checks presence; call [`McpServerConfig::validate`] to ensure the
    /// entry is not also configured with a `url`.
    pub fn is_stdio(&self) -> bool {
        self.command.is_some()
    }

    /// Returns `true` when a `url` is configured (HTTP server).
    ///
    /// This only checks presence; call [`McpServerConfig::validate`] to ensure the
    /// entry is not also configured with a `command`.
    pub fn is_http(&self) -> bool {
        self.url.is_some()
    }

    /// Checks that exactly one of `command` / `url` is set and that it is not blank.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - neither `command` nor `url` is set,
    /// - both are set, or
    /// - the one that is set is empty or whitespace-only (such a value can never name a
    ///   runnable program or a reachable endpoint).
    pub fn validate(&self) -> Result<(), String> {
        match (self.command.as_deref(), self.url.as_deref()) {
            (None, None) => Err("MCP server must have either 'command' or 'url'".to_string()),
            (Some(_), Some(_)) => {
                Err("MCP server cannot have both 'command' and 'url'".to_string())
            }
            (Some(command), None) if command.trim().is_empty() => {
                Err("MCP server 'command' must not be empty".to_string())
            }
            (None, Some(url)) if url.trim().is_empty() => {
                Err("MCP server 'url' must not be empty".to_string())
            }
            _ => Ok(()),
        }
    }
}
/// Loads the merged MCP server table for `workspace`.
///
/// Two files are consulted, in order:
/// 1. the global `$HOME/.sofos/config.toml` (skipped when `HOME` is unset), then
/// 2. the workspace-local `<workspace>/.sofos/config.local.toml`.
///
/// Because later entries are inserted with `HashMap::extend`, a local server replaces
/// a global server with the same name. A missing file contributes nothing. A file that
/// cannot be read or parsed is reported on stderr and skipped rather than silently
/// ignored, so a typo in a config file is visible to the user.
pub fn load_mcp_config(workspace: &Path) -> HashMap<String, McpServerConfig> {
    let mut servers = HashMap::new();

    let global_config_path =
        std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".sofos/config.toml"));
    let local_config_path = workspace.join(".sofos/config.local.toml");

    // Global first, local last: the order determines which entry wins on a name clash.
    for path in global_config_path
        .iter()
        .chain(std::iter::once(&local_config_path))
    {
        match load_mcp_config_from_file(path) {
            Ok(loaded) => servers.extend(loaded),
            Err(e) => eprintln!(
                "Warning: Failed to load MCP config from {}: {}",
                path.display(),
                e
            ),
        }
    }

    servers
}
fn load_mcp_config_from_file(
path: &PathBuf,
) -> Result<HashMap<String, McpServerConfig>, Box<dyn std::error::Error>> {
if !path.exists() {
return Ok(HashMap::new());
}
let content = std::fs::read_to_string(path)?;
let config: McpConfig = toml::from_str(&content)?;
for (name, server_config) in &config.mcp_servers {
if let Err(e) = server_config.validate() {
eprintln!("Warning: Invalid MCP server config for '{}': {}", name, e);
}
}
Ok(config.mcp_servers)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_stdio_server() {
        let toml_content = r#"
[mcp-servers.test-server]
command = "/path/to/server"
args = ["--port", "8080"]
env = { "API_KEY" = "secret" }
"#;
        let config: McpConfig = toml::from_str(toml_content).unwrap();
        let server = config.mcp_servers.get("test-server").unwrap();
        assert_eq!(server.command, Some("/path/to/server".to_string()));
        assert_eq!(
            server.args,
            Some(vec!["--port".to_string(), "8080".to_string()])
        );
        // The inline `env` table must be parsed, not just tolerated.
        let env = server.env.as_ref().expect("env table should be parsed");
        assert_eq!(env.get("API_KEY").map(String::as_str), Some("secret"));
        assert!(server.url.is_none());
        assert!(server.headers.is_none());
        assert!(server.is_stdio());
        assert!(!server.is_http());
    }

    #[test]
    fn test_parse_http_server() {
        let toml_content = r#"
[mcp-servers.http-server]
url = "https://example.com/mcp"
headers = { "Authorization" = "Bearer token" }
"#;
        let config: McpConfig = toml::from_str(toml_content).unwrap();
        let server = config.mcp_servers.get("http-server").unwrap();
        assert_eq!(server.url, Some("https://example.com/mcp".to_string()));
        // The inline `headers` table must be parsed, not just tolerated.
        let headers = server
            .headers
            .as_ref()
            .expect("headers table should be parsed");
        assert_eq!(
            headers.get("Authorization").map(String::as_str),
            Some("Bearer token")
        );
        assert!(server.command.is_none());
        assert!(!server.is_stdio());
        assert!(server.is_http());
    }

    #[test]
    fn test_parse_without_mcp_table_is_empty() {
        // `#[serde(default)]` on `mcp_servers` means a missing table is not an error.
        let config: McpConfig = toml::from_str("").unwrap();
        assert!(config.mcp_servers.is_empty());

        // Unrelated keys in a shared config file are ignored.
        let config: McpConfig = toml::from_str("[other]\nkey = 1\n").unwrap();
        assert!(config.mcp_servers.is_empty());
    }

    #[test]
    fn test_load_missing_file_returns_empty() {
        // The parent directory is never created, so the file cannot exist.
        let path = std::env::temp_dir()
            .join(format!("sofos-mcp-missing-{}", std::process::id()))
            .join("config.toml");
        let servers = load_mcp_config_from_file(&path).unwrap();
        assert!(servers.is_empty());
    }

    #[test]
    fn test_validation() {
        let valid_stdio = McpServerConfig {
            command: Some("/path/to/server".to_string()),
            args: None,
            env: None,
            url: None,
            headers: None,
        };
        assert!(valid_stdio.validate().is_ok());

        let valid_http = McpServerConfig {
            command: None,
            args: None,
            env: None,
            url: Some("https://example.com".to_string()),
            headers: None,
        };
        assert!(valid_http.validate().is_ok());

        let invalid_empty = McpServerConfig {
            command: None,
            args: None,
            env: None,
            url: None,
            headers: None,
        };
        assert!(invalid_empty.validate().is_err());

        let invalid_both = McpServerConfig {
            command: Some("/path".to_string()),
            args: None,
            env: None,
            url: Some("https://example.com".to_string()),
            headers: None,
        };
        assert!(invalid_both.validate().is_err());
    }
}