use std::collections::HashMap;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use tokio::fs;
use crate::bootstrap::ironclaw_base_dir;
use crate::tools::tool::ToolError;
/// How an MCP server is reached, as stored in the configuration.
///
/// Serialized as an internally tagged object: the lowercase variant name is
/// written to a `"transport"` key alongside the variant's own fields, e.g.
/// `{"transport":"http"}` or `{"transport":"stdio","command":"npx",...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "transport", rename_all = "lowercase")]
pub enum McpTransportConfig {
    /// HTTP(S) transport; the endpoint is taken from `McpServerConfig::url`.
    Http,
    /// Stdio transport described by a program, its arguments and environment.
    /// NOTE(review): the process-spawning code is not in this file; presumably
    /// MCP traffic flows over the child's stdin/stdout — confirm.
    Stdio {
        /// Program to run; `validate` rejects an empty value.
        command: String,
        /// Arguments for `command`; an empty list when absent from JSON.
        #[serde(default)]
        args: Vec<String>,
        /// Extra environment variables; an empty map when absent from JSON.
        #[serde(default)]
        env: HashMap<String, String>,
    },
    /// Unix-socket transport addressed by `socket_path` (must be non-empty).
    Unix { socket_path: String },
}
/// One MCP server entry, as persisted in `mcp-servers.json` or the
/// `mcp_servers` DB setting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Unique server name; also embedded in derived secret names
    /// (`mcp_<name>_...`).
    pub name: String,
    /// Endpoint for the HTTP transport.
    ///
    /// Stdio and unix-socket servers do not use it and store an empty string.
    /// It defaults to empty when absent, so hand-written non-HTTP entries
    /// without a `"url"` key still load. `validate` still rejects an empty URL
    /// for HTTP servers.
    #[serde(default)]
    pub url: String,
    /// Explicit transport. `None` means HTTP, which keeps entries written
    /// before this field existed working. Because the field is not flattened,
    /// it is stored as a nested object, e.g.
    /// `"transport": {"transport": "stdio", "command": "npx"}`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub transport: Option<McpTransportConfig>,
    /// Custom HTTP headers (names and values are checked by `validate`).
    /// May contain credentials such as an `Authorization` header.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub headers: HashMap<String, String>,
    /// OAuth client settings, if the server uses OAuth.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub oauth: Option<OAuthConfig>,
    /// Whether the server is active; `true` when absent from JSON.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Optional human-readable description.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
/// Serde default for boolean flags that are on unless the JSON says otherwise
/// (`McpServerConfig::enabled`, `OAuthConfig::use_pkce`).
fn default_true() -> bool {
    const ON_BY_DEFAULT: bool = true;
    ON_BY_DEFAULT
}
impl McpServerConfig {
pub fn new(name: impl Into<String>, url: impl Into<String>) -> Self {
Self {
name: name.into(),
url: url.into(),
transport: None,
headers: HashMap::new(),
oauth: None,
enabled: true,
description: None,
}
}
pub fn new_stdio(
name: impl Into<String>,
command: impl Into<String>,
args: Vec<String>,
env: HashMap<String, String>,
) -> Self {
Self {
name: name.into(),
url: String::new(),
transport: Some(McpTransportConfig::Stdio {
command: command.into(),
args,
env,
}),
headers: HashMap::new(),
oauth: None,
enabled: true,
description: None,
}
}
pub fn new_unix(name: impl Into<String>, socket_path: impl Into<String>) -> Self {
Self {
name: name.into(),
url: String::new(),
transport: Some(McpTransportConfig::Unix {
socket_path: socket_path.into(),
}),
headers: HashMap::new(),
oauth: None,
enabled: true,
description: None,
}
}
pub fn with_oauth(mut self, oauth: OAuthConfig) -> Self {
self.oauth = Some(oauth);
self
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
self.headers = headers;
self
}
pub fn effective_transport(&self) -> EffectiveTransport<'_> {
match &self.transport {
Some(McpTransportConfig::Http) | None => EffectiveTransport::Http,
Some(McpTransportConfig::Stdio { command, args, env }) => {
EffectiveTransport::Stdio { command, args, env }
}
Some(McpTransportConfig::Unix { socket_path }) => {
EffectiveTransport::Unix { socket_path }
}
}
}
pub fn validate(&self) -> Result<(), ConfigError> {
if self.name.is_empty() {
return Err(ConfigError::InvalidConfig {
reason: "Server name cannot be empty".to_string(),
});
}
match self.effective_transport() {
EffectiveTransport::Http => {
if self.url.is_empty() {
return Err(ConfigError::InvalidConfig {
reason: "Server URL cannot be empty".to_string(),
});
}
let is_localhost = is_localhost_url(&self.url);
if !is_localhost && !self.url.to_lowercase().starts_with("https://") {
return Err(ConfigError::InvalidConfig {
reason: "Remote MCP servers must use HTTPS".to_string(),
});
}
}
EffectiveTransport::Stdio { command, .. } => {
if command.is_empty() {
return Err(ConfigError::InvalidConfig {
reason: "Stdio transport command cannot be empty".to_string(),
});
}
}
EffectiveTransport::Unix { socket_path } => {
if socket_path.is_empty() {
return Err(ConfigError::InvalidConfig {
reason: "Unix socket path cannot be empty".to_string(),
});
}
}
}
for (name, value) in &self.headers {
if name.is_empty() {
return Err(ConfigError::InvalidConfig {
reason: "Header name cannot be empty".to_string(),
});
}
if reqwest::header::HeaderName::from_bytes(name.as_bytes()).is_err() {
return Err(ConfigError::InvalidConfig {
reason: format!(
"Header name '{}' is not a valid HTTP header name (RFC 9110)",
name
),
});
}
if reqwest::header::HeaderValue::from_str(value).is_err() {
return Err(ConfigError::InvalidConfig {
reason: format!("Header value for '{}' contains invalid characters", name),
});
}
}
Ok(())
}
pub fn has_custom_auth_header(&self) -> bool {
self.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("authorization"))
}
pub fn requires_auth(&self) -> bool {
if !matches!(self.effective_transport(), EffectiveTransport::Http) {
return false;
}
if self.oauth.is_some() {
return true;
}
let url_lower = self.url.to_lowercase();
let is_localhost = is_localhost_url(&url_lower);
url_lower.starts_with("https://") && !is_localhost
}
pub fn token_secret_name(&self) -> String {
format!("mcp_{}_access_token", self.name)
}
pub fn refresh_token_secret_name(&self) -> String {
format!("{}_refresh_token", self.token_secret_name())
}
pub fn legacy_refresh_token_secret_name(&self) -> String {
format!("mcp_{}_refresh_token", self.name)
}
pub fn client_id_secret_name(&self) -> String {
format!("mcp_{}_client_id", self.name)
}
pub fn client_secret_secret_name(&self) -> String {
format!("mcp_{}_client_secret", self.name)
}
}
/// OAuth client settings attached to an HTTP MCP server entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
    /// OAuth client identifier.
    pub client_id: String,
    /// Authorization endpoint; omitted from JSON when `None`.
    /// NOTE(review): when unset, the endpoint is presumably discovered
    /// elsewhere — confirm against the OAuth flow code.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authorization_url: Option<String>,
    /// Token endpoint; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_url: Option<String>,
    /// Requested scopes; empty when absent from JSON.
    #[serde(default)]
    pub scopes: Vec<String>,
    /// Whether PKCE is used; `true` when absent from JSON.
    #[serde(default = "default_true")]
    pub use_pkce: bool,
    /// Additional key/value parameters; empty when absent from JSON.
    /// NOTE(review): presumably appended to the authorization request — confirm.
    #[serde(default)]
    pub extra_params: HashMap<String, String>,
}
impl OAuthConfig {
    /// Creates settings for `client_id` with no endpoints, no scopes, no extra
    /// parameters, and PKCE enabled.
    pub fn new(client_id: impl Into<String>) -> Self {
        let client_id = client_id.into();
        Self {
            client_id,
            authorization_url: None,
            token_url: None,
            scopes: Vec::new(),
            use_pkce: true,
            extra_params: HashMap::new(),
        }
    }

    /// Returns `self` with both the authorization and token endpoints set.
    pub fn with_endpoints(
        self,
        authorization_url: impl Into<String>,
        token_url: impl Into<String>,
    ) -> Self {
        let authorization_url = Some(authorization_url.into());
        let token_url = Some(token_url.into());
        Self {
            authorization_url,
            token_url,
            ..self
        }
    }

    /// Returns `self` with its scope list replaced by `scopes`.
    pub fn with_scopes(self, scopes: Vec<String>) -> Self {
        Self { scopes, ..self }
    }
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct McpServersFile {
#[serde(default)]
pub servers: Vec<McpServerConfig>,
#[serde(default = "default_schema_version")]
pub schema_version: u32,
}
/// Schema version assumed for configuration data that does not state one.
fn default_schema_version() -> u32 {
    const CURRENT_SCHEMA_VERSION: u32 = 1;
    CURRENT_SCHEMA_VERSION
}
impl McpServersFile {
    /// First server whose name equals `name`, if any.
    pub fn get(&self, name: &str) -> Option<&McpServerConfig> {
        self.servers
            .iter()
            .find(|server| server.name.as_str() == name)
    }

    /// Mutable access to the first server whose name equals `name`, if any.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut McpServerConfig> {
        self.servers
            .iter_mut()
            .find(|server| server.name.as_str() == name)
    }

    /// Replaces the first entry with the same name as `config`, or appends
    /// `config` when no such entry exists.
    pub fn upsert(&mut self, config: McpServerConfig) {
        let slot = self
            .servers
            .iter()
            .position(|server| server.name == config.name);
        match slot {
            Some(index) => self.servers[index] = config,
            None => self.servers.push(config),
        }
    }

    /// Removes every entry named `name`; returns whether anything was removed.
    pub fn remove(&mut self, name: &str) -> bool {
        let before = self.servers.len();
        self.servers.retain(|server| server.name.as_str() != name);
        before != self.servers.len()
    }

    /// Iterates over the entries whose `enabled` flag is set.
    pub fn enabled_servers(&self) -> impl Iterator<Item = &McpServerConfig> {
        self.servers.iter().filter(|server| server.enabled)
    }
}
/// Errors from loading, validating, or persisting MCP server configuration.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// Filesystem failure. `save_mcp_servers_to_db` also wraps database write
    /// errors in this variant via `std::io::Error::other`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// JSON (de)serialization failure.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// A server entry failed validation; `reason` is the human-readable cause.
    #[error("Invalid configuration: {reason}")]
    InvalidConfig { reason: String },
    /// No server with the given name exists in the loaded configuration.
    #[error("Server not found: {name}")]
    ServerNotFound { name: String },
}
impl From<ConfigError> for ToolError {
fn from(err: ConfigError) -> Self {
ToolError::ExternalService(err.to_string())
}
}
/// Location of the MCP server config file: `mcp-servers.json` inside the
/// IronClaw base directory.
pub fn default_config_path() -> PathBuf {
    const CONFIG_FILE_NAME: &str = "mcp-servers.json";
    ironclaw_base_dir().join(CONFIG_FILE_NAME)
}
/// Loads and validates the config from [`default_config_path`].
///
/// # Errors
///
/// Same as [`load_mcp_servers_from`].
pub async fn load_mcp_servers() -> Result<McpServersFile, ConfigError> {
    let path = default_config_path();
    load_mcp_servers_from(path).await
}
/// Reads `path` as JSON and validates every server entry.
///
/// A file that does not exist yields an empty [`McpServersFile`], so first-run
/// callers can proceed. The check is done on the read itself, not with a
/// separate blocking `Path::exists` call. That avoids a check-then-read race,
/// and other failures (permission denied, the path being a directory, ...) are
/// reported instead of being mistaken for "no file".
///
/// # Errors
///
/// - [`ConfigError::Io`] if the file exists but cannot be read.
/// - [`ConfigError::Json`] if it is not valid JSON for this schema.
/// - [`ConfigError::InvalidConfig`] naming the first server that fails
///   validation.
pub async fn load_mcp_servers_from(path: impl AsRef<Path>) -> Result<McpServersFile, ConfigError> {
    let path = path.as_ref();
    let content = match fs::read_to_string(path).await {
        Ok(content) => content,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            return Ok(McpServersFile::default());
        }
        Err(err) => return Err(err.into()),
    };
    let config: McpServersFile = serde_json::from_str(&content)?;
    for server in &config.servers {
        server.validate().map_err(|e| {
            // Re-wrap using the bare reason so the message does not read
            // "Invalid configuration: Server 'x': Invalid configuration: ...".
            let detail = match e {
                ConfigError::InvalidConfig { reason } => reason,
                other => other.to_string(),
            };
            ConfigError::InvalidConfig {
                reason: format!("Server '{}': {}", server.name, detail),
            }
        })?;
    }
    Ok(config)
}
/// Writes `config` to [`default_config_path`].
///
/// # Errors
///
/// Same as [`save_mcp_servers_to`].
pub async fn save_mcp_servers(config: &McpServersFile) -> Result<(), ConfigError> {
    let path = default_config_path();
    save_mcp_servers_to(config, path).await
}
pub async fn save_mcp_servers_to(
config: &McpServersFile,
path: impl AsRef<Path>,
) -> Result<(), ConfigError> {
let path = path.as_ref();
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).await?;
}
let content = serde_json::to_string_pretty(config)?;
let tmp_path = path.with_extension("json.tmp");
fs::write(&tmp_path, content).await?;
fs::rename(&tmp_path, path).await?;
Ok(())
}
/// Validates `config`, then inserts or replaces it (by name) in the on-disk
/// config and saves the result.
///
/// NOTE(review): the load/modify/save sequence is not guarded against
/// concurrent writers.
///
/// # Errors
///
/// Validation, load, or save failures.
pub async fn add_mcp_server(config: McpServerConfig) -> Result<(), ConfigError> {
    config.validate()?;
    let mut file = load_mcp_servers().await?;
    file.upsert(config);
    save_mcp_servers(&file).await
}
/// Removes the server(s) named `name` from the on-disk config and saves it.
///
/// # Errors
///
/// - [`ConfigError::ServerNotFound`] if no entry has that name; the file is
///   not rewritten in that case.
/// - Load or save failures.
pub async fn remove_mcp_server(name: &str) -> Result<(), ConfigError> {
    let mut file = load_mcp_servers().await?;
    let removed = file.remove(name);
    if removed {
        save_mcp_servers(&file).await
    } else {
        Err(ConfigError::ServerNotFound {
            name: name.to_owned(),
        })
    }
}
/// Returns a copy of the on-disk server entry named `name`.
///
/// # Errors
///
/// - [`ConfigError::ServerNotFound`] if absent.
/// - Load failures.
pub async fn get_mcp_server(name: &str) -> Result<McpServerConfig, ConfigError> {
    let file = load_mcp_servers().await?;
    match file.get(name) {
        Some(server) => Ok(server.clone()),
        None => Err(ConfigError::ServerNotFound {
            name: name.to_owned(),
        }),
    }
}
/// Loads the MCP config from the `mcp_servers` setting of `user_id`.
///
/// - A stored value is deserialized and every entry validated.
/// - If the setting is absent, the on-disk config is returned instead (via
///   [`load_mcp_servers`]).
/// - A DB read error is logged at `warn` level and also falls back to disk.
///
/// NOTE(review): `add_mcp_server_db` / `remove_mcp_server_db` save whatever
/// this returns back to the DB. After a transient read error they may
/// therefore overwrite the DB setting with the disk contents — confirm this is
/// intended.
///
/// # Errors
///
/// - [`ConfigError::Json`] if the stored value does not match the schema.
/// - [`ConfigError::InvalidConfig`] naming the first invalid server.
/// - Any error from the disk fallback.
pub async fn load_mcp_servers_from_db(
    store: &dyn crate::db::Database,
    user_id: &str,
) -> Result<McpServersFile, ConfigError> {
    let value = match store.get_setting(user_id, "mcp_servers").await {
        Ok(Some(value)) => value,
        Ok(None) => return load_mcp_servers().await,
        Err(e) => {
            tracing::warn!(
                "Failed to load MCP servers from DB: {}, falling back to disk",
                e
            );
            return load_mcp_servers().await;
        }
    };
    let config: McpServersFile = serde_json::from_value(value)?;
    for server in &config.servers {
        server.validate().map_err(|e| {
            // Re-wrap using the bare reason so the message does not read
            // "Invalid configuration: Server 'x': Invalid configuration: ...".
            let detail = match e {
                ConfigError::InvalidConfig { reason } => reason,
                other => other.to_string(),
            };
            ConfigError::InvalidConfig {
                reason: format!("Server '{}': {}", server.name, detail),
            }
        })?;
    }
    Ok(config)
}
pub async fn save_mcp_servers_to_db(
store: &dyn crate::db::Database,
user_id: &str,
config: &McpServersFile,
) -> Result<(), ConfigError> {
let value = serde_json::to_value(config)?;
store
.set_setting(user_id, "mcp_servers", &value)
.await
.map_err(std::io::Error::other)?;
Ok(())
}
/// Validates `config`, then inserts or replaces it (by name) in the user's DB
/// config and stores the result.
///
/// # Errors
///
/// Validation, load, or save failures.
pub async fn add_mcp_server_db(
    store: &dyn crate::db::Database,
    user_id: &str,
    config: McpServerConfig,
) -> Result<(), ConfigError> {
    config.validate()?;
    let mut current = load_mcp_servers_from_db(store, user_id).await?;
    current.upsert(config);
    save_mcp_servers_to_db(store, user_id, &current).await
}
/// Removes the server(s) named `name` from the user's DB config and stores it.
///
/// # Errors
///
/// - [`ConfigError::ServerNotFound`] if no entry has that name; nothing is
///   written in that case.
/// - Load or save failures.
pub async fn remove_mcp_server_db(
    store: &dyn crate::db::Database,
    user_id: &str,
    name: &str,
) -> Result<(), ConfigError> {
    let mut current = load_mcp_servers_from_db(store, user_id).await?;
    if current.remove(name) {
        save_mcp_servers_to_db(store, user_id, &current).await
    } else {
        Err(ConfigError::ServerNotFound {
            name: name.to_owned(),
        })
    }
}
/// True when `url` parses and its host is exactly `localhost` (ASCII
/// case-insensitive) or a loopback IPv4/IPv6 address.
///
/// The check uses the parsed host rather than substring matching, so
/// look-alikes such as `evil.localhost.attacker.com` or `notlocalhost.com` are
/// not treated as local. Unparseable URLs and host-less URLs return `false`.
pub(crate) fn is_localhost_url(url: &str) -> bool {
    url::Url::parse(url).is_ok_and(|parsed| match parsed.host() {
        Some(url::Host::Domain(domain)) => domain.eq_ignore_ascii_case("localhost"),
        Some(url::Host::Ipv4(addr)) => addr.is_loopback(),
        Some(url::Host::Ipv6(addr)) => addr.is_loopback(),
        None => false,
    })
}
/// Borrowed, resolved view of a server's transport, as returned by
/// `McpServerConfig::effective_transport` (a missing `transport` resolves to
/// `Http`).
///
/// It holds only shared references, so it is cheap to copy. Callers can
/// compare values with `==` rather than only with `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EffectiveTransport<'a> {
    /// HTTP(S) transport using `McpServerConfig::url`.
    Http,
    /// Stdio transport: program `command` with its `args` and extra `env`.
    Stdio {
        command: &'a str,
        args: &'a [String],
        env: &'a HashMap<String, String>,
    },
    /// Unix-socket transport at `socket_path`.
    Unix {
        socket_path: &'a str,
    },
}
// Unit tests for MCP server configuration: localhost detection, validation
// (including header-injection defenses), builders, file round-trips, secret
// naming, auth requirements, and transport (de)serialization.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Host-based detection: exact `localhost` and loopback IPs count as local;
    // look-alike domains and non-loopback IPv6 addresses do not.
    #[test]
    fn test_is_localhost_url() {
        assert!(is_localhost_url("http://localhost:3000/path"));
        assert!(is_localhost_url("https://localhost/path"));
        assert!(is_localhost_url("http://127.0.0.1:8080"));
        assert!(is_localhost_url("http://127.0.0.1"));
        assert!(!is_localhost_url("https://notlocalhost.com/path"));
        assert!(!is_localhost_url("https://example-localhost.io"));
        assert!(!is_localhost_url("https://mcp.notion.com"));
        // Userinfo in the authority must not affect host extraction.
        assert!(is_localhost_url("http://user:pass@localhost:3000/path"));
        assert!(is_localhost_url("http://[::1]:8080/path"));
        assert!(is_localhost_url("http://[::1]/path"));
        assert!(!is_localhost_url("http://[::2]:8080/path"));
    }

    // HTTPS remote and plain-HTTP localhost are accepted; an empty name or a
    // plain-HTTP remote URL is rejected.
    #[test]
    fn test_server_config_validation() {
        let config = McpServerConfig::new("notion", "https://mcp.notion.com");
        assert!(config.validate().is_ok());
        let config = McpServerConfig::new("local", "http://localhost:8080");
        assert!(config.validate().is_ok());
        let config = McpServerConfig::new("", "https://example.com");
        assert!(config.validate().is_err());
        let config = McpServerConfig::new("remote", "http://mcp.example.com");
        assert!(config.validate().is_err());
    }

    // Builder methods populate endpoints and scopes; PKCE stays on by default.
    #[test]
    fn test_oauth_config_builder() {
        let oauth = OAuthConfig::new("client-123")
            .with_endpoints(
                "https://auth.example.com/authorize",
                "https://auth.example.com/token",
            )
            .with_scopes(vec!["read".to_string(), "write".to_string()]);
        assert_eq!(oauth.client_id, "client-123");
        assert!(oauth.authorization_url.is_some());
        assert!(oauth.token_url.is_some());
        assert_eq!(oauth.scopes.len(), 2);
        assert!(oauth.use_pkce);
    }

    // `upsert` replaces by name instead of duplicating; `remove` reports
    // whether anything was removed.
    #[test]
    fn test_servers_file_operations() {
        let mut file = McpServersFile::default();
        file.upsert(McpServerConfig::new("notion", "https://mcp.notion.com"));
        assert_eq!(file.servers.len(), 1);
        let mut updated = McpServerConfig::new("notion", "https://mcp.notion.com/v2");
        updated.enabled = false;
        file.upsert(updated);
        assert_eq!(file.servers.len(), 1);
        assert!(!file.get("notion").unwrap().enabled);
        file.upsert(McpServerConfig::new("github", "https://mcp.github.com"));
        assert_eq!(file.servers.len(), 2);
        assert!(file.remove("notion"));
        assert_eq!(file.servers.len(), 1);
        assert!(file.get("notion").is_none());
        assert!(!file.remove("nonexistent"));
    }

    // Save-then-load round trip through a temp directory preserves the URL
    // and the OAuth settings.
    #[tokio::test]
    async fn test_load_save_config() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("mcp-servers.json");
        let mut config = McpServersFile::default();
        config.upsert(
            McpServerConfig::new("notion", "https://mcp.notion.com").with_oauth(
                OAuthConfig::new("client-123")
                    .with_scopes(vec!["read".to_string(), "write".to_string()]),
            ),
        );
        save_mcp_servers_to(&config, &path).await.unwrap();
        let loaded = load_mcp_servers_from(&path).await.unwrap();
        assert_eq!(loaded.servers.len(), 1);
        let server = loaded.get("notion").unwrap();
        assert_eq!(server.url, "https://mcp.notion.com");
        assert!(server.oauth.is_some());
        assert_eq!(server.oauth.as_ref().unwrap().client_id, "client-123");
    }

    // A missing config file is treated as an empty server list, not an error.
    #[tokio::test]
    async fn test_load_nonexistent_returns_empty() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("nonexistent.json");
        let config = load_mcp_servers_from(&path).await.unwrap();
        assert!(config.servers.is_empty());
    }

    // A hand-edited file containing an invalid header name must fail to load,
    // and the error must identify the offending server.
    #[tokio::test]
    async fn test_load_rejects_corrupted_headers() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("mcp-servers.json");
        let corrupted = serde_json::json!({
            "servers": [{
                "name": "bad-server",
                "url": "https://mcp.example.com",
                "enabled": true,
                "headers": { "X Bad": "value" }
            }]
        });
        tokio::fs::write(&path, corrupted.to_string())
            .await
            .unwrap();
        let result = load_mcp_servers_from(&path).await;
        assert!(result.is_err(), "Load should reject corrupted headers");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("bad-server"),
            "Error should name the offending server, got: {err}"
        );
    }

    // Pins the exact secret-store key formats. The refresh-token key is
    // derived from the access-token key; the legacy key uses the older shape.
    #[test]
    fn test_token_secret_names() {
        let config = McpServerConfig::new("notion", "https://mcp.notion.com");
        assert_eq!(config.token_secret_name(), "mcp_notion_access_token");
        assert_eq!(
            config.refresh_token_secret_name(),
            "mcp_notion_access_token_refresh_token"
        );
        assert_eq!(
            config.legacy_refresh_token_secret_name(),
            "mcp_notion_refresh_token"
        );
        assert_eq!(config.client_id_secret_name(), "mcp_notion_client_id");
        assert_eq!(
            config.client_secret_secret_name(),
            "mcp_notion_client_secret"
        );
    }

    // Explicit OAuth settings always require auth on HTTP transports.
    #[test]
    fn test_requires_auth_with_oauth() {
        let config = McpServerConfig::new("notion", "https://mcp.notion.com")
            .with_oauth(OAuthConfig::new("client-123"));
        assert!(config.requires_auth());
    }

    // Remote HTTPS servers require auth even without OAuth settings.
    #[test]
    fn test_requires_auth_remote_https_without_oauth() {
        let config = McpServerConfig::new("github-copilot", "https://api.githubcopilot.com/mcp/");
        assert!(config.requires_auth());
        let config = McpServerConfig::new("notion", "https://mcp.notion.com");
        assert!(config.requires_auth());
    }

    // Localhost servers (HTTP or HTTPS) do not require auth without OAuth.
    #[test]
    fn test_requires_auth_localhost_no_auth() {
        let config = McpServerConfig::new("local", "http://localhost:8080");
        assert!(!config.requires_auth());
        let config = McpServerConfig::new("local", "http://127.0.0.1:3000/mcp");
        assert!(!config.requires_auth());
        let config = McpServerConfig::new("local", "https://localhost:8443");
        assert!(!config.requires_auth());
    }

    // Plain-HTTP remote URLs do not trigger auth (validation rejects them
    // separately).
    #[test]
    fn test_requires_auth_http_remote_no_auth() {
        let config = McpServerConfig::new("bad", "http://mcp.example.com");
        assert!(!config.requires_auth());
    }

    // `new_stdio` stores command/args/env in the transport and leaves the
    // URL, OAuth and headers empty.
    #[test]
    fn test_stdio_config_creation() {
        let env = HashMap::from([("PATH".to_string(), "/usr/bin".to_string())]);
        let config = McpServerConfig::new_stdio(
            "my-server",
            "npx",
            vec!["-y".to_string(), "@modelcontextprotocol/server".to_string()],
            env.clone(),
        );
        assert_eq!(config.name, "my-server");
        assert!(config.url.is_empty());
        assert!(config.enabled);
        assert!(config.oauth.is_none());
        assert!(config.headers.is_empty());
        match &config.transport {
            Some(McpTransportConfig::Stdio {
                command,
                args,
                env: e,
            }) => {
                assert_eq!(command, "npx");
                assert_eq!(
                    args,
                    &["-y".to_string(), "@modelcontextprotocol/server".to_string()]
                );
                assert_eq!(e, &env);
            }
            other => panic!("Expected Stdio transport, got {:?}", other),
        }
    }

    // `new_unix` stores the socket path in the transport with an empty URL.
    #[test]
    fn test_unix_config_creation() {
        let config = McpServerConfig::new_unix("local-server", "/tmp/mcp.sock");
        assert_eq!(config.name, "local-server");
        assert!(config.url.is_empty());
        assert!(config.enabled);
        match &config.transport {
            Some(McpTransportConfig::Unix { socket_path }) => {
                assert_eq!(socket_path, "/tmp/mcp.sock");
            }
            other => panic!("Expected Unix transport, got {:?}", other),
        }
    }

    // Stdio entries need a non-empty command (and name); the empty URL is fine.
    #[test]
    fn test_stdio_validation() {
        let config = McpServerConfig::new_stdio("server", "npx", vec![], HashMap::new());
        assert!(config.validate().is_ok());
        let config = McpServerConfig::new_stdio("server", "", vec![], HashMap::new());
        assert!(config.validate().is_err());
        let err = config.validate().unwrap_err().to_string();
        assert!(
            err.contains("command"),
            "Error should mention command: {}",
            err
        );
        let config = McpServerConfig::new_stdio("", "npx", vec![], HashMap::new());
        assert!(config.validate().is_err());
    }

    // Unix entries need a non-empty socket path (and name).
    #[test]
    fn test_unix_validation() {
        let config = McpServerConfig::new_unix("server", "/tmp/mcp.sock");
        assert!(config.validate().is_ok());
        let config = McpServerConfig::new_unix("server", "");
        assert!(config.validate().is_err());
        let err = config.validate().unwrap_err().to_string();
        assert!(
            err.contains("socket"),
            "Error should mention socket: {}",
            err
        );
        let config = McpServerConfig::new_unix("", "/tmp/mcp.sock");
        assert!(config.validate().is_err());
    }

    // Non-HTTP transports never require auth, even with OAuth configured.
    #[test]
    fn test_requires_auth_stdio_never() {
        let mut config = McpServerConfig::new_stdio("server", "npx", vec![], HashMap::new());
        assert!(!config.requires_auth());
        config.oauth = Some(OAuthConfig::new("client-123"));
        assert!(!config.requires_auth());
    }

    #[test]
    fn test_requires_auth_unix_never() {
        let mut config = McpServerConfig::new_unix("server", "/tmp/mcp.sock");
        assert!(!config.requires_auth());
        config.oauth = Some(OAuthConfig::new("client-123"));
        assert!(!config.requires_auth());
    }

    // CR/LF in a header name (header injection) is rejected as an invalid
    // RFC 9110 name, even alongside a valid header.
    #[test]
    fn test_header_crlf_injection_rejected() {
        let mut headers = HashMap::new();
        headers.insert("X-Good".to_string(), "safe".to_string());
        headers.insert("X-Bad\r\nInjected: true".to_string(), "value".to_string());
        let config =
            McpServerConfig::new("server", "https://mcp.example.com").with_headers(headers);
        let err = config.validate().unwrap_err().to_string();
        assert!(
            err.contains("not a valid HTTP header name"),
            "Expected RFC 9110 error, got: {err}"
        );
    }

    // CR/LF in a header value is rejected too.
    #[test]
    fn test_header_value_crlf_injection_rejected() {
        let mut headers = HashMap::new();
        headers.insert(
            "X-Header".to_string(),
            "value\r\nInjected: true".to_string(),
        );
        let config =
            McpServerConfig::new("server", "https://mcp.example.com").with_headers(headers);
        let err = config.validate().unwrap_err().to_string();
        assert!(
            err.contains("invalid characters"),
            "Expected invalid characters error, got: {err}"
        );
    }

    // Spaces, colons and NUL bytes are not valid header-name characters.
    #[test]
    fn test_header_name_with_space_rejected() {
        let headers = HashMap::from([("X Bad".to_string(), "value".to_string())]);
        let config =
            McpServerConfig::new("server", "https://mcp.example.com").with_headers(headers);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_header_name_with_colon_rejected() {
        let headers = HashMap::from([("X:Bad".to_string(), "value".to_string())]);
        let config =
            McpServerConfig::new("server", "https://mcp.example.com").with_headers(headers);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_header_name_with_null_byte_rejected() {
        let headers = HashMap::from([("X-Bad\0".to_string(), "value".to_string())]);
        let config =
            McpServerConfig::new("server", "https://mcp.example.com").with_headers(headers);
        assert!(config.validate().is_err());
    }

    // An empty header name gets its own dedicated error message.
    #[test]
    fn test_header_empty_name_rejected() {
        let mut headers = HashMap::new();
        headers.insert(String::new(), "value".to_string());
        let config =
            McpServerConfig::new("server", "https://mcp.example.com").with_headers(headers);
        let err = config.validate().unwrap_err().to_string();
        assert!(
            err.contains("empty"),
            "Expected empty name error, got: {err}"
        );
    }

    // `Authorization` detection ignores ASCII case.
    #[test]
    fn test_has_custom_auth_header_case_insensitive() {
        let headers = HashMap::from([("authorization".to_string(), "Bearer token".to_string())]);
        let config =
            McpServerConfig::new("server", "https://mcp.example.com").with_headers(headers);
        assert!(config.has_custom_auth_header());
        let headers = HashMap::from([("AUTHORIZATION".to_string(), "Bearer token".to_string())]);
        let config =
            McpServerConfig::new("server", "https://mcp.example.com").with_headers(headers);
        assert!(config.has_custom_auth_header());
        let headers = HashMap::from([("X-Api-Key".to_string(), "key".to_string())]);
        let config =
            McpServerConfig::new("server", "https://mcp.example.com").with_headers(headers);
        assert!(!config.has_custom_auth_header());
    }

    // `with_headers` stores the map unchanged.
    #[test]
    fn test_custom_headers() {
        let headers = HashMap::from([
            ("X-Api-Key".to_string(), "secret".to_string()),
            ("Authorization".to_string(), "Bearer token".to_string()),
        ]);
        let config =
            McpServerConfig::new("server", "https://mcp.example.com").with_headers(headers.clone());
        assert_eq!(config.headers, headers);
        assert_eq!(config.headers.get("X-Api-Key").unwrap(), "secret");
    }

    // Each transport variant serializes with a lowercase `"transport"` tag and
    // round-trips.
    #[test]
    fn test_transport_config_serde_http() {
        let transport = McpTransportConfig::Http;
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"transport\":\"http\""));
        let parsed: McpTransportConfig = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, McpTransportConfig::Http));
    }

    #[test]
    fn test_transport_config_serde_stdio() {
        let transport = McpTransportConfig::Stdio {
            command: "npx".to_string(),
            args: vec!["-y".to_string(), "server".to_string()],
            env: HashMap::from([("KEY".to_string(), "val".to_string())]),
        };
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"transport\":\"stdio\""));
        assert!(json.contains("\"command\":\"npx\""));
        let parsed: McpTransportConfig = serde_json::from_str(&json).unwrap();
        match parsed {
            McpTransportConfig::Stdio { command, args, env } => {
                assert_eq!(command, "npx");
                assert_eq!(args, vec!["-y".to_string(), "server".to_string()]);
                assert_eq!(env.get("KEY").unwrap(), "val");
            }
            other => panic!("Expected Stdio, got {:?}", other),
        }
    }

    #[test]
    fn test_transport_config_serde_unix() {
        let transport = McpTransportConfig::Unix {
            socket_path: "/tmp/mcp.sock".to_string(),
        };
        let json = serde_json::to_string(&transport).unwrap();
        assert!(json.contains("\"transport\":\"unix\""));
        assert!(json.contains("\"socket_path\":\"/tmp/mcp.sock\""));
        let parsed: McpTransportConfig = serde_json::from_str(&json).unwrap();
        match parsed {
            McpTransportConfig::Unix { socket_path } => {
                assert_eq!(socket_path, "/tmp/mcp.sock");
            }
            other => panic!("Expected Unix, got {:?}", other),
        }
    }

    // Entries written before `transport`/`headers` existed still load and
    // resolve to HTTP.
    #[test]
    fn test_backward_compat_no_transport_field() {
        let json = r#"{
"name": "notion",
"url": "https://mcp.notion.com",
"enabled": true
}"#;
        let config: McpServerConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.name, "notion");
        assert_eq!(config.url, "https://mcp.notion.com");
        assert!(config.transport.is_none());
        assert!(config.headers.is_empty());
        assert!(matches!(
            config.effective_transport(),
            EffectiveTransport::Http
        ));
    }

    // Full-entry JSON round trips for stdio, unix, and HTTP-with-headers.
    #[test]
    fn test_config_roundtrip_with_transport() {
        let config = McpServerConfig::new_stdio(
            "test-server",
            "node",
            vec!["server.js".to_string()],
            HashMap::from([("NODE_ENV".to_string(), "production".to_string())]),
        )
        .with_description("A test server");
        let json = serde_json::to_string_pretty(&config).unwrap();
        let parsed: McpServerConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.name, "test-server");
        assert!(parsed.url.is_empty());
        assert_eq!(parsed.description.as_deref(), Some("A test server"));
        match &parsed.transport {
            Some(McpTransportConfig::Stdio { command, args, env }) => {
                assert_eq!(command, "node");
                assert_eq!(args, &["server.js".to_string()]);
                assert_eq!(env.get("NODE_ENV").unwrap(), "production");
            }
            other => panic!("Expected Stdio transport, got {:?}", other),
        }
        let config = McpServerConfig::new_unix("unix-server", "/var/run/mcp.sock");
        let json = serde_json::to_string_pretty(&config).unwrap();
        let parsed: McpServerConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.name, "unix-server");
        match &parsed.transport {
            Some(McpTransportConfig::Unix { socket_path }) => {
                assert_eq!(socket_path, "/var/run/mcp.sock");
            }
            other => panic!("Expected Unix transport, got {:?}", other),
        }
        let headers = HashMap::from([("X-Custom".to_string(), "value".to_string())]);
        let config =
            McpServerConfig::new("http-server", "https://mcp.example.com").with_headers(headers);
        let json = serde_json::to_string_pretty(&config).unwrap();
        let parsed: McpServerConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.name, "http-server");
        assert!(parsed.transport.is_none());
        assert_eq!(parsed.headers.get("X-Custom").unwrap(), "value");
    }

    // Regression guard: a hostname merely containing "localhost" is remote.
    #[test]
    fn test_is_localhost_url_rejects_attacker_subdomain() {
        assert!(
            !is_localhost_url("http://evil.localhost.attacker.com:8080/mcp"),
            "attacker subdomain containing 'localhost' must not be treated as local"
        );
    }

    #[test]
    fn test_is_localhost_url_accepts_real_localhost() {
        assert!(is_localhost_url("http://localhost:8080/mcp"));
        assert!(is_localhost_url("https://localhost/path"));
    }

    #[test]
    fn test_is_localhost_url_accepts_loopback_ip() {
        assert!(is_localhost_url("http://127.0.0.1:3000"));
        assert!(is_localhost_url("http://[::1]:3000"));
    }

    // Private-network addresses are not loopback and therefore not local.
    #[test]
    fn test_is_localhost_url_rejects_remote() {
        assert!(!is_localhost_url("https://mcp.example.com"));
        assert!(!is_localhost_url("http://192.168.1.1:8080"));
    }
}