use std::collections::HashMap;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use tokio::fs;
use crate::tools::tool::ToolError;
/// A single MCP server entry: its name, endpoint URL, optional OAuth
/// settings, enabled flag, and optional description.
///
/// Persisted as an element of `McpServersFile::servers`, either in the JSON
/// config file or in the `mcp_servers` database setting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
/// Identifier for the entry; `McpServersFile::get`, `upsert`, and `remove`
/// match on it, and the `*_secret_name` helpers embed it.
pub name: String,
/// Server endpoint URL (checked by `McpServerConfig::validate`).
pub url: String,
/// OAuth settings, if any; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub oauth: Option<OAuthConfig>,
/// Whether the entry is active (`McpServersFile::enabled_servers` filters on
/// it); defaults to `true` when missing from the input.
#[serde(default = "default_true")]
pub enabled: bool,
/// Optional free-form description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
/// Serde default for boolean fields that are on unless the input explicitly
/// sets them to `false` (`McpServerConfig::enabled`, `OAuthConfig::use_pkce`).
fn default_true() -> bool {
true
}
impl McpServerConfig {
    /// Creates an enabled entry for `name` at `url`, with no OAuth settings
    /// and no description.
    pub fn new(name: impl Into<String>, url: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            url: url.into(),
            oauth: None,
            enabled: true,
            description: None,
        }
    }

    /// Attaches OAuth settings, replacing any previously set.
    pub fn with_oauth(mut self, oauth: OAuthConfig) -> Self {
        self.oauth = Some(oauth);
        self
    }

    /// Sets the description, replacing any previously set.
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Checks that the entry can be saved.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::InvalidConfig`] when the name or URL is empty,
    /// or when the URL's host is not loopback and the URL does not start with
    /// `https://`.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if self.name.is_empty() {
            return Err(ConfigError::InvalidConfig {
                reason: "Server name cannot be empty".to_string(),
            });
        }
        if self.url.is_empty() {
            return Err(ConfigError::InvalidConfig {
                reason: "Server URL cannot be empty".to_string(),
            });
        }
        // Decide "local" from the parsed host, exactly like `requires_auth`.
        // A substring test would let remote hosts such as
        // `http://localhost.evil.com` or `http://127.0.0.1.nip.io` skip the
        // HTTPS requirement, while rejecting real loopback like `http://[::1]`.
        let url_lower = self.url.to_lowercase();
        if !is_localhost_url(&url_lower) && !url_lower.starts_with("https://") {
            return Err(ConfigError::InvalidConfig {
                reason: "Remote MCP servers must use HTTPS".to_string(),
            });
        }
        Ok(())
    }

    /// Reports whether this server is treated as needing authentication.
    ///
    /// True when OAuth settings are present, or when the URL is `https://`
    /// and its host is not loopback. Plain-HTTP remote URLs report `false`
    /// (`validate` rejects those).
    pub fn requires_auth(&self) -> bool {
        if self.oauth.is_some() {
            return true;
        }
        let url_lower = self.url.to_lowercase();
        let is_localhost = is_localhost_url(&url_lower);
        url_lower.starts_with("https://") && !is_localhost
    }

    /// Secret name for this server's access token: `mcp_<name>_access_token`.
    pub fn token_secret_name(&self) -> String {
        format!("mcp_{}_access_token", self.name)
    }

    /// Secret name for this server's refresh token: `mcp_<name>_refresh_token`.
    pub fn refresh_token_secret_name(&self) -> String {
        format!("mcp_{}_refresh_token", self.name)
    }

    /// Secret name for this server's client id: `mcp_<name>_client_id`.
    pub fn client_id_secret_name(&self) -> String {
        format!("mcp_{}_client_id", self.name)
    }
}
/// OAuth client settings attached to an `McpServerConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
/// OAuth client identifier.
pub client_id: String,
/// Authorization endpoint URL, if configured; omitted from serialized output
/// when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization_url: Option<String>,
/// Token endpoint URL, if configured; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_url: Option<String>,
/// Requested scopes; empty when missing from the input.
#[serde(default)]
pub scopes: Vec<String>,
/// Whether to use PKCE; defaults to `true` when missing from the input.
#[serde(default = "default_true")]
pub use_pkce: bool,
/// Additional string key/value parameters; empty when missing from the input.
/// NOTE(review): which OAuth request these are attached to is decided outside
/// this file — confirm before relying on it.
#[serde(default)]
pub extra_params: HashMap<String, String>,
}
impl OAuthConfig {
    /// Builds settings for `client_id` with PKCE on and with no endpoints,
    /// scopes, or extra parameters.
    pub fn new(client_id: impl Into<String>) -> Self {
        let client_id = client_id.into();
        Self {
            client_id,
            use_pkce: true,
            scopes: Vec::default(),
            extra_params: HashMap::default(),
            authorization_url: None,
            token_url: None,
        }
    }

    /// Returns a copy with both the authorization and token endpoints set.
    pub fn with_endpoints(
        self,
        authorization_url: impl Into<String>,
        token_url: impl Into<String>,
    ) -> Self {
        Self {
            authorization_url: Some(authorization_url.into()),
            token_url: Some(token_url.into()),
            ..self
        }
    }

    /// Returns a copy whose scope list is replaced by `scopes`.
    pub fn with_scopes(self, scopes: Vec<String>) -> Self {
        Self { scopes, ..self }
    }
}
/// Top-level contents of the MCP server configuration, as stored in the JSON
/// file or in the `mcp_servers` database setting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServersFile {
    /// Configured servers; empty when missing from the input.
    #[serde(default)]
    pub servers: Vec<McpServerConfig>,
    /// Format version; `default_schema_version()` when missing from the input.
    #[serde(default = "default_schema_version")]
    pub schema_version: u32,
}

impl Default for McpServersFile {
    /// An empty server list at the current schema version.
    ///
    /// Implemented by hand so it matches what serde produces for `{}`. A
    /// derived `Default` would set `schema_version` to `0`, and a config
    /// created from a missing file would then be saved with version 0.
    fn default() -> Self {
        Self {
            servers: Vec::new(),
            schema_version: default_schema_version(),
        }
    }
}

/// Schema version used when the input omits `schema_version` and by
/// `McpServersFile::default`.
fn default_schema_version() -> u32 {
    1
}
impl McpServersFile {
    /// Position of the first server named `name`, if any.
    fn index_of(&self, name: &str) -> Option<usize> {
        self.servers.iter().position(|server| server.name == name)
    }

    /// Borrows the first server named `name`, if present.
    pub fn get(&self, name: &str) -> Option<&McpServerConfig> {
        let index = self.index_of(name)?;
        self.servers.get(index)
    }

    /// Mutably borrows the first server named `name`, if present.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut McpServerConfig> {
        let index = self.index_of(name)?;
        self.servers.get_mut(index)
    }

    /// Replaces the first server with the same name as `config`, or appends
    /// `config` when no such server exists.
    pub fn upsert(&mut self, config: McpServerConfig) {
        match self.index_of(&config.name) {
            Some(index) => self.servers[index] = config,
            None => self.servers.push(config),
        }
    }

    /// Removes every server named `name`; returns whether anything was removed.
    pub fn remove(&mut self, name: &str) -> bool {
        let before = self.servers.len();
        self.servers.retain(|server| server.name != name);
        before != self.servers.len()
    }

    /// Iterates over the servers whose `enabled` flag is set, in stored order.
    pub fn enabled_servers(&self) -> impl Iterator<Item = &McpServerConfig> {
        self.servers.iter().filter(|server| server.enabled)
    }
}
/// Errors from reading, writing, or validating MCP server configuration.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
/// Filesystem failure. `save_mcp_servers_to_db` also wraps database write
/// errors in this variant.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// JSON (de)serialization failure for the config file or the DB setting.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// Entry rejected by `McpServerConfig::validate`.
#[error("Invalid configuration: {reason}")]
InvalidConfig { reason: String },
/// No server with the requested name exists.
#[error("Server not found: {name}")]
ServerNotFound { name: String },
}
impl From<ConfigError> for ToolError {
fn from(err: ConfigError) -> Self {
ToolError::ExternalService(err.to_string())
}
}
/// Location of the on-disk MCP server list: `<home>/.ironclaw/mcp-servers.json`.
///
/// If `dirs::home_dir()` yields nothing, the path is built relative to the
/// current working directory (`./.ironclaw/mcp-servers.json`).
pub fn default_config_path() -> PathBuf {
    let base = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("."),
    };
    base.join(".ironclaw").join("mcp-servers.json")
}
/// Loads the server list from [`default_config_path`]; see
/// [`load_mcp_servers_from`] for missing-file handling and errors.
pub async fn load_mcp_servers() -> Result<McpServersFile, ConfigError> {
    let path = default_config_path();
    load_mcp_servers_from(path).await
}
pub async fn load_mcp_servers_from(path: impl AsRef<Path>) -> Result<McpServersFile, ConfigError> {
let path = path.as_ref();
if !path.exists() {
return Ok(McpServersFile::default());
}
let content = fs::read_to_string(path).await?;
let config: McpServersFile = serde_json::from_str(&content)?;
Ok(config)
}
/// Writes `config` to [`default_config_path`]; see [`save_mcp_servers_to`].
pub async fn save_mcp_servers(config: &McpServersFile) -> Result<(), ConfigError> {
    let path = default_config_path();
    save_mcp_servers_to(config, path).await
}
/// Writes `config` as pretty-printed JSON to `path`, first creating any
/// missing parent directories.
///
/// NOTE(review): the write goes straight to `path`, not through a temp file
/// plus rename. A crash mid-write could leave a truncated file, which
/// `load_mcp_servers_from` would then reject as invalid JSON.
///
/// # Errors
///
/// [`ConfigError::Io`] if directory creation or the write fails;
/// [`ConfigError::Json`] if serialization fails.
pub async fn save_mcp_servers_to(
    config: &McpServersFile,
    path: impl AsRef<Path>,
) -> Result<(), ConfigError> {
    let target = path.as_ref();
    if let Some(dir) = target.parent() {
        fs::create_dir_all(dir).await?;
    }
    let json = serde_json::to_string_pretty(config)?;
    Ok(fs::write(target, json).await?)
}
/// Validates `config`, then inserts or replaces it (by name) in the on-disk
/// server list and saves the result.
///
/// # Errors
///
/// Validation errors from `McpServerConfig::validate`, or any load/save error.
pub async fn add_mcp_server(config: McpServerConfig) -> Result<(), ConfigError> {
    config.validate()?;
    let mut file = load_mcp_servers().await?;
    file.upsert(config);
    save_mcp_servers(&file).await
}
/// Deletes every server named `name` from the on-disk list and saves the result.
///
/// # Errors
///
/// [`ConfigError::ServerNotFound`] when no server has that name (nothing is
/// written in that case), or any load/save error.
pub async fn remove_mcp_server(name: &str) -> Result<(), ConfigError> {
    let mut file = load_mcp_servers().await?;
    if file.remove(name) {
        save_mcp_servers(&file).await
    } else {
        Err(ConfigError::ServerNotFound {
            name: name.to_owned(),
        })
    }
}
/// Returns a copy of the on-disk entry named `name`.
///
/// # Errors
///
/// [`ConfigError::ServerNotFound`] when absent, or any load error.
pub async fn get_mcp_server(name: &str) -> Result<McpServerConfig, ConfigError> {
    let file = load_mcp_servers().await?;
    match file.get(name) {
        Some(server) => Ok(server.clone()),
        None => Err(ConfigError::ServerNotFound {
            name: name.to_owned(),
        }),
    }
}
/// Loads the server list stored for `user_id` under the `mcp_servers` setting.
///
/// Falls back to the on-disk file ([`load_mcp_servers`]) when the setting is
/// absent, or when the database lookup fails; the failure is logged as a
/// warning.
///
/// # Errors
///
/// [`ConfigError::Json`] if the stored setting does not deserialize, plus any
/// error from the disk fallback.
pub async fn load_mcp_servers_from_db(
    store: &dyn crate::db::Database,
    user_id: &str,
) -> Result<McpServersFile, ConfigError> {
    let stored = store.get_setting(user_id, "mcp_servers").await;
    match stored {
        Ok(Some(value)) => Ok(serde_json::from_value(value)?),
        Ok(None) => load_mcp_servers().await,
        Err(err) => {
            tracing::warn!(
                "Failed to load MCP servers from DB: {}, falling back to disk",
                err
            );
            load_mcp_servers().await
        }
    }
}
pub async fn save_mcp_servers_to_db(
store: &dyn crate::db::Database,
user_id: &str,
config: &McpServersFile,
) -> Result<(), ConfigError> {
let value = serde_json::to_value(config)?;
store
.set_setting(user_id, "mcp_servers", &value)
.await
.map_err(std::io::Error::other)?;
Ok(())
}
/// Validates `config`, then inserts or replaces it (by name) in `user_id`'s
/// database-stored server list.
///
/// NOTE(review): `load_mcp_servers_from_db` substitutes the disk file when the
/// DB read fails. If that read fails but the following write succeeds, the
/// stored list is replaced with the disk copy plus `config` — confirm that
/// this is acceptable.
///
/// # Errors
///
/// Validation errors from `McpServerConfig::validate`, or any load/save error.
pub async fn add_mcp_server_db(
    store: &dyn crate::db::Database,
    user_id: &str,
    config: McpServerConfig,
) -> Result<(), ConfigError> {
    config.validate()?;
    let mut file = load_mcp_servers_from_db(store, user_id).await?;
    file.upsert(config);
    save_mcp_servers_to_db(store, user_id, &file).await
}
/// Deletes every server named `name` from `user_id`'s database-stored list.
///
/// NOTE(review): inherits the disk fallback of `load_mcp_servers_from_db` on
/// DB read errors; see the same caveat on `add_mcp_server_db`.
///
/// # Errors
///
/// [`ConfigError::ServerNotFound`] when no server has that name (nothing is
/// written in that case), or any load/save error.
pub async fn remove_mcp_server_db(
    store: &dyn crate::db::Database,
    user_id: &str,
    name: &str,
) -> Result<(), ConfigError> {
    let mut file = load_mcp_servers_from_db(store, user_id).await?;
    if file.remove(name) {
        save_mcp_servers_to_db(store, user_id, &file).await
    } else {
        Err(ConfigError::ServerNotFound {
            name: name.to_owned(),
        })
    }
}
/// True when `url` parses and its host is `localhost` (any ASCII case) or a
/// loopback IPv4/IPv6 address. Unparseable or host-less URLs yield `false`.
fn is_localhost_url(url: &str) -> bool {
    let parsed = match url::Url::parse(url) {
        Ok(parsed) => parsed,
        Err(_) => return false,
    };
    let Some(host) = parsed.host() else {
        return false;
    };
    match host {
        url::Host::Domain(name) => name.eq_ignore_ascii_case("localhost"),
        url::Host::Ipv4(addr) => addr.is_loopback(),
        url::Host::Ipv6(addr) => addr.is_loopback(),
    }
}
// Unit tests for URL classification, config validation, the in-memory
// `McpServersFile` helpers, and the file save/load round trip (run against a
// temporary directory).
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
// Loopback detection is host-based: hosts that merely contain "localhost"
// must not match, while IPv6 `::1` and URLs with userinfo must.
#[test]
fn test_is_localhost_url() {
assert!(is_localhost_url("http://localhost:3000/path"));
assert!(is_localhost_url("https://localhost/path"));
assert!(is_localhost_url("http://127.0.0.1:8080"));
assert!(is_localhost_url("http://127.0.0.1"));
assert!(!is_localhost_url("https://notlocalhost.com/path"));
assert!(!is_localhost_url("https://example-localhost.io"));
assert!(!is_localhost_url("https://mcp.notion.com"));
assert!(is_localhost_url("http://user:pass@localhost:3000/path"));
assert!(is_localhost_url("http://[::1]:8080/path"));
assert!(is_localhost_url("http://[::1]/path"));
assert!(!is_localhost_url("http://[::2]:8080/path"));
}
// HTTPS remote and HTTP localhost are accepted; an empty name or a
// plain-HTTP remote URL is rejected.
#[test]
fn test_server_config_validation() {
let config = McpServerConfig::new("notion", "https://mcp.notion.com");
assert!(config.validate().is_ok());
let config = McpServerConfig::new("local", "http://localhost:8080");
assert!(config.validate().is_ok());
let config = McpServerConfig::new("", "https://example.com");
assert!(config.validate().is_err());
let config = McpServerConfig::new("remote", "http://mcp.example.com");
assert!(config.validate().is_err());
}
// Builder methods set the fields they name and leave PKCE enabled.
#[test]
fn test_oauth_config_builder() {
let oauth = OAuthConfig::new("client-123")
.with_endpoints(
"https://auth.example.com/authorize",
"https://auth.example.com/token",
)
.with_scopes(vec!["read".to_string(), "write".to_string()]);
assert_eq!(oauth.client_id, "client-123");
assert!(oauth.authorization_url.is_some());
assert!(oauth.token_url.is_some());
assert_eq!(oauth.scopes.len(), 2);
assert!(oauth.use_pkce);
}
// upsert replaces by name rather than duplicating; remove reports whether
// anything was deleted.
#[test]
fn test_servers_file_operations() {
let mut file = McpServersFile::default();
file.upsert(McpServerConfig::new("notion", "https://mcp.notion.com"));
assert_eq!(file.servers.len(), 1);
let mut updated = McpServerConfig::new("notion", "https://mcp.notion.com/v2");
updated.enabled = false;
file.upsert(updated);
assert_eq!(file.servers.len(), 1);
assert!(!file.get("notion").unwrap().enabled);
file.upsert(McpServerConfig::new("github", "https://mcp.github.com"));
assert_eq!(file.servers.len(), 2);
assert!(file.remove("notion"));
assert_eq!(file.servers.len(), 1);
assert!(file.get("notion").is_none());
assert!(!file.remove("nonexistent"));
}
// Round trip through a real file keeps the URL and nested OAuth settings.
#[tokio::test]
async fn test_load_save_config() {
let dir = tempdir().unwrap();
let path = dir.path().join("mcp-servers.json");
let mut config = McpServersFile::default();
config.upsert(
McpServerConfig::new("notion", "https://mcp.notion.com").with_oauth(
OAuthConfig::new("client-123")
.with_scopes(vec!["read".to_string(), "write".to_string()]),
),
);
save_mcp_servers_to(&config, &path).await.unwrap();
let loaded = load_mcp_servers_from(&path).await.unwrap();
assert_eq!(loaded.servers.len(), 1);
let server = loaded.get("notion").unwrap();
assert_eq!(server.url, "https://mcp.notion.com");
assert!(server.oauth.is_some());
assert_eq!(server.oauth.as_ref().unwrap().client_id, "client-123");
}
// A missing file loads as an empty server list, not an error.
#[tokio::test]
async fn test_load_nonexistent_returns_empty() {
let dir = tempdir().unwrap();
let path = dir.path().join("nonexistent.json");
let config = load_mcp_servers_from(&path).await.unwrap();
assert!(config.servers.is_empty());
}
// Secret names embed the server name.
#[test]
fn test_token_secret_names() {
let config = McpServerConfig::new("notion", "https://mcp.notion.com");
assert_eq!(config.token_secret_name(), "mcp_notion_access_token");
assert_eq!(
config.refresh_token_secret_name(),
"mcp_notion_refresh_token"
);
}
// Explicit OAuth settings always imply auth.
#[test]
fn test_requires_auth_with_oauth() {
let config = McpServerConfig::new("notion", "https://mcp.notion.com")
.with_oauth(OAuthConfig::new("client-123"));
assert!(config.requires_auth());
}
// Remote HTTPS servers require auth even without OAuth settings.
#[test]
fn test_requires_auth_remote_https_without_oauth() {
let config = McpServerConfig::new("github-copilot", "https://api.githubcopilot.com/mcp/");
assert!(config.requires_auth());
let config = McpServerConfig::new("notion", "https://mcp.notion.com");
assert!(config.requires_auth());
}
// Loopback hosts never require auth, whatever the scheme.
#[test]
fn test_requires_auth_localhost_no_auth() {
let config = McpServerConfig::new("local", "http://localhost:8080");
assert!(!config.requires_auth());
let config = McpServerConfig::new("local", "http://127.0.0.1:3000/mcp");
assert!(!config.requires_auth());
let config = McpServerConfig::new("local", "https://localhost:8443");
assert!(!config.requires_auth());
}
// Plain-HTTP remote URLs report no auth (validation rejects them separately).
#[test]
fn test_requires_auth_http_remote_no_auth() {
let config = McpServerConfig::new("bad", "http://mcp.example.com");
assert!(!config.requires_auth());
}
}