use std::collections::HashMap;
use std::path::PathBuf;
use crate::services::mcp::normalization::normalize_name_for_mcp;
use crate::services::mcp::types::*;
/// Path fragments that `unwrap_ccr_proxy_url` looks for with a plain substring
/// test; a URL containing any of them is treated as a proxy URL whose real
/// target may be carried in an `mcp_url` query parameter.
// NOTE(review): "CCR" presumably names the proxying service — confirm with the owning team.
const CCR_PROXY_PATH_MARKERS: &[&str] = &["/v2/session_ingress/shttp/mcp/", "/v2/ccr-sessions/"];
/// Returns the full argv (`command` followed by `args`) of a stdio server.
///
/// Any non-stdio configuration yields `None`.
pub fn get_server_command_array(config: &McpServerConfig) -> Option<Vec<String>> {
    if let McpServerConfig::Stdio(stdio) = config {
        let argv: Vec<String> = std::iter::once(stdio.command.clone())
            .chain(stdio.args.iter().cloned())
            .collect();
        Some(argv)
    } else {
        None
    }
}
/// Returns `true` when both command arrays have the same length and the same
/// elements in the same order.
fn command_arrays_match(a: &[String], b: &[String]) -> bool {
    // Slice equality already compares length first and then element-wise.
    a == b
}
/// Returns a copy of the configured URL for every URL-based transport
/// (SSE, SSE-IDE, WebSocket-IDE, HTTP, WebSocket).
///
/// Configurations without a URL yield `None`.
pub fn get_server_url(config: &McpServerConfig) -> Option<String> {
    let url = match config {
        McpServerConfig::Sse(server) => &server.url,
        McpServerConfig::SseIde(server) => &server.url,
        McpServerConfig::WebSocketIde(server) => &server.url,
        McpServerConfig::Http(server) => &server.url,
        McpServerConfig::WebSocket(server) => &server.url,
        _ => return None,
    };
    Some(url.clone())
}
/// Extracts the real MCP server URL from a proxy URL.
///
/// A URL is considered a proxy URL when it contains one of
/// `CCR_PROXY_PATH_MARKERS`. In that case the first `mcp_url=` query parameter
/// that percent-decodes successfully is returned. In every other situation
/// (not a proxy URL, no query string, no `mcp_url` parameter, or an
/// undecodable value) the input is returned unchanged.
pub fn unwrap_ccr_proxy_url(url: &str) -> String {
    let is_proxy_url = CCR_PROXY_PATH_MARKERS
        .iter()
        .any(|marker| url.contains(*marker));
    if !is_proxy_url {
        return url.to_string();
    }

    // Only the part after the first '?' is relevant; the path itself is not needed.
    let Some((_, query)) = url.split_once('?') else {
        return url.to_string();
    };

    query
        .split('&')
        .filter_map(|param| param.strip_prefix("mcp_url="))
        // A value that fails to decode is skipped so a later parameter can still match.
        .find_map(|encoded| urlencoding_decode(encoded).ok())
        .unwrap_or_else(|| url.to_string())
}
/// Decodes a form-style percent-encoded string.
///
/// * `%XX` (exactly two hexadecimal digits) becomes the byte `0xXX`.
/// * `+` becomes a space.
/// * Every other byte is copied through unchanged.
///
/// The decoded bytes are interpreted as UTF-8, so multi-byte sequences such as
/// `%C3%A9` produce a single character (`é`) instead of one character per byte.
///
/// # Errors
///
/// Returns `Err(())` when a `%` is not followed by two hexadecimal digits or
/// when the decoded bytes are not valid UTF-8.
fn urlencoding_decode(input: &str) -> Result<String, ()> {
    /// Value of a single ASCII hex digit, or `None` for anything else
    /// (including the sign characters that `from_str_radix` would accept).
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(byte) = bytes.next() {
        match byte {
            b'%' => {
                let high = bytes.next().and_then(hex_value).ok_or(())?;
                let low = bytes.next().and_then(hex_value).ok_or(())?;
                decoded.push((high << 4) | low);
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
    }

    // Escapes describe raw bytes; they only become text once validated as UTF-8.
    String::from_utf8(decoded).map_err(|_| ())
}
/// Builds a string that identifies what a server configuration connects to.
///
/// * Stdio servers: `stdio:` followed by the JSON encoding of the argv.
/// * URL-based servers: `url:` followed by the URL after
///   `unwrap_ccr_proxy_url` has been applied.
/// * Anything else: `None`.
pub fn get_mcp_server_signature(config: &McpServerConfig) -> Option<String> {
    match get_server_command_array(config) {
        Some(argv) => {
            let encoded = serde_json::to_string(&argv).unwrap_or_default();
            Some(format!("stdio:{}", encoded))
        }
        None => get_server_url(config)
            .map(|raw_url| format!("url:{}", unwrap_ccr_proxy_url(&raw_url))),
    }
}
/// Removes plugin servers that duplicate a manually-configured server or
/// another plugin server.
///
/// Two servers are duplicates when `get_mcp_server_signature` returns the same
/// signature for both. Plugin servers without a signature are always kept.
///
/// Both maps are visited in ascending name order. `HashMap` iteration order is
/// arbitrary, so without the sort the choice of which duplicate "came first"
/// (and therefore which server survives and what `duplicate_of` reports) could
/// change from run to run.
///
/// Returns the surviving plugin servers and the list of suppressed ones, the
/// latter in ascending order of the suppressed server's name.
pub fn dedup_plugin_mcp_servers(
    plugin_servers: &HashMap<String, ScopedMcpServerConfig>,
    manual_servers: &HashMap<String, ScopedMcpServerConfig>,
) -> (
    HashMap<String, ScopedMcpServerConfig>,
    Vec<SuppressedServer>,
) {
    /// Entries of `map` in ascending key order.
    fn sorted_by_name(
        map: &HashMap<String, ScopedMcpServerConfig>,
    ) -> Vec<(&String, &ScopedMcpServerConfig)> {
        let mut entries: Vec<_> = map.iter().collect();
        // Keys are unique, so an unstable sort cannot reorder equal elements.
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        entries
    }

    // signature -> name of the first manual server carrying it.
    let mut manual_sigs: HashMap<String, String> = HashMap::new();
    for (name, config) in sorted_by_name(manual_servers) {
        if let Some(sig) = get_mcp_server_signature(&config.config) {
            manual_sigs.entry(sig).or_insert_with(|| name.clone());
        }
    }

    let mut servers: HashMap<String, ScopedMcpServerConfig> = HashMap::new();
    let mut suppressed: Vec<SuppressedServer> = Vec::new();
    // signature -> name of the first plugin server that was kept with it.
    let mut seen_plugin_sigs: HashMap<String, String> = HashMap::new();

    for (name, config) in sorted_by_name(plugin_servers) {
        let Some(sig) = get_mcp_server_signature(&config.config) else {
            // Nothing to compare against: keep the server.
            servers.insert(name.clone(), config.clone());
            continue;
        };

        // Manual configuration always wins over a plugin-provided server.
        if let Some(manual_dup) = manual_sigs.get(&sig) {
            log::debug!(
                "Suppressing plugin MCP server \"{}\": duplicates manually-configured \"{}\"",
                name,
                manual_dup
            );
            suppressed.push(SuppressedServer {
                name: name.clone(),
                duplicate_of: manual_dup.clone(),
            });
            continue;
        }

        if let Some(plugin_dup) = seen_plugin_sigs.get(&sig) {
            log::debug!(
                "Suppressing plugin MCP server \"{}\": duplicates earlier plugin server \"{}\"",
                name,
                plugin_dup
            );
            suppressed.push(SuppressedServer {
                name: name.clone(),
                duplicate_of: plugin_dup.clone(),
            });
            continue;
        }

        seen_plugin_sigs.insert(sig, name.clone());
        servers.insert(name.clone(), config.clone());
    }

    (servers, suppressed)
}
/// Record of a plugin MCP server that `dedup_plugin_mcp_servers` dropped
/// because another server had the same signature.
#[derive(Debug, Clone)]
pub struct SuppressedServer {
/// Name of the plugin server that was suppressed.
pub name: String,
/// Name of the manually-configured or earlier plugin server it duplicates.
pub duplicate_of: String,
}
/// Converts a wildcard URL pattern into regular-expression source text.
///
/// `*` becomes `.*`; the regex metacharacters
/// `. + ? ^ $ { } ( ) | [ ] \` are backslash-escaped so they match literally.
/// All other characters are copied unchanged. The result is not anchored.
///
/// Escaping `.` matters for policy matching: without it the pattern
/// `https://example.com/*` would also match `https://exampleXcom/...`.
fn url_pattern_to_regex(pattern: &str) -> String {
    let mut regex_source = String::with_capacity(pattern.len() + 8);
    for c in pattern.chars() {
        match c {
            '*' => regex_source.push_str(".*"),
            '.' | '+' | '?' | '^' | '$' | '{' | '}' | '(' | ')' | '|' | '[' | ']' | '\\' => {
                regex_source.push('\\');
                regex_source.push(c);
            }
            _ => regex_source.push(c),
        }
    }
    regex_source
}
/// Tests whether `url` matches the wildcard `pattern` in its entirety.
///
/// The pattern is converted with `url_pattern_to_regex` and anchored with
/// `^`/`$`. If the resulting expression fails to compile, the URL is treated
/// as not matching.
fn url_matches_pattern(url: &str, pattern: &str) -> bool {
    let anchored = format!("^{}$", url_pattern_to_regex(pattern));
    match regex::Regex::new(&anchored) {
        Ok(re) => re.is_match(url),
        Err(_) => false,
    }
}
/// Reports whether `server_name` appears in the disabled list.
///
/// Both the server name and every list entry are passed through
/// `normalize_name_for_mcp` before comparison. A missing list means nothing
/// is disabled.
pub fn is_mcp_server_disabled(server_name: &str, disabled_servers: Option<&[String]>) -> bool {
    let Some(disabled) = disabled_servers else {
        return false;
    };
    let target = normalize_name_for_mcp(server_name);
    disabled
        .iter()
        .map(|entry| normalize_name_for_mcp(entry))
        .any(|candidate| candidate == target)
}
pub fn is_mcp_server_denied(
server_name: &str,
config: Option<&McpServerConfig>,
denied_servers: Option<&[McpServerDenialEntry]>,
) -> bool {
let Some(denied) = denied_servers else {
return false;
};
for entry in denied {
match entry {
McpServerDenialEntry::Name(name) => {
if name == server_name {
return true;
}
}
McpServerDenialEntry::Command(cmd) => {
if let Some(cfg) = config {
if let Some(server_cmd) = get_server_command_array(cfg) {
if command_arrays_match(&server_cmd, cmd) {
return true;
}
}
}
}
McpServerDenialEntry::Url(url_pattern) => {
if let Some(cfg) = config {
if let Some(server_url) = get_server_url(cfg) {
if url_matches_pattern(&server_url, url_pattern) {
return true;
}
}
}
}
}
}
false
}
/// Decides whether policy permits a server.
///
/// Evaluation order:
/// 1. A server matched by the deny list is never allowed.
/// 2. With no allow list, everything else is allowed.
/// 3. An empty allow list allows nothing.
/// 4. A stdio server must match a `Command` entry if the allow list has any;
///    otherwise it must match a `Name` entry.
/// 5. A URL-based server must match a `Url` entry if the allow list has any;
///    otherwise it must match a `Name` entry.
/// 6. Without a config (or with a config that has neither command nor URL)
///    only `Name` entries are considered.
pub fn is_mcp_server_allowed_by_policy(
    server_name: &str,
    config: Option<&McpServerConfig>,
    allowed_servers: Option<&[McpServerAllowanceEntry]>,
    denied_servers: Option<&[McpServerDenialEntry]>,
) -> bool {
    if is_mcp_server_denied(server_name, config, denied_servers) {
        return false;
    }

    let allowed = match allowed_servers {
        None => return true,
        Some(list) => list,
    };
    if allowed.is_empty() {
        return false;
    }

    // Shared fallback: exact match against any `Name` entry.
    let allowed_by_name = || {
        allowed
            .iter()
            .any(|entry| matches!(entry, McpServerAllowanceEntry::Name(n) if n == server_name))
    };

    let Some(cfg) = config else {
        return allowed_by_name();
    };

    if let Some(argv) = get_server_command_array(cfg) {
        let mut command_rules = allowed
            .iter()
            .filter_map(|entry| match entry {
                McpServerAllowanceEntry::Command(cmd) => Some(cmd),
                _ => None,
            })
            .peekable();
        // No command rules at all: stdio servers are judged by name instead.
        if command_rules.peek().is_none() {
            return allowed_by_name();
        }
        return command_rules.any(|cmd| command_arrays_match(&argv, cmd));
    }

    if let Some(server_url) = get_server_url(cfg) {
        let mut url_rules = allowed
            .iter()
            .filter_map(|entry| match entry {
                McpServerAllowanceEntry::Url(pattern) => Some(pattern),
                _ => None,
            })
            .peekable();
        // No URL rules at all: remote servers are judged by name instead.
        if url_rules.peek().is_none() {
            return allowed_by_name();
        }
        return url_rules.any(|pattern| url_matches_pattern(&server_url, pattern));
    }

    allowed_by_name()
}
/// One rule of the deny list evaluated by `is_mcp_server_denied`.
#[derive(Debug, Clone)]
pub enum McpServerDenialEntry {
/// Denies the server whose name equals this string exactly.
Name(String),
/// Denies stdio servers whose command followed by its args equals this array.
Command(Vec<String>),
/// Denies servers whose URL matches this wildcard pattern (`*` matches any run of characters).
Url(String),
}
/// One rule of the allow list evaluated by `is_mcp_server_allowed_by_policy`.
#[derive(Debug, Clone)]
pub enum McpServerAllowanceEntry {
/// Allows the server whose name equals this string exactly.
Name(String),
/// Allows stdio servers whose command followed by its args equals this array.
Command(Vec<String>),
/// Allows servers whose URL matches this wildcard pattern (`*` matches any run of characters).
Url(String),
}
/// Wraps every server configuration in a `ScopedMcpServerConfig` carrying the
/// given scope.
///
/// Names and configurations are cloned; `plugin_source` is set to `None` for
/// every entry.
pub fn add_scope_to_servers(
    servers: &HashMap<String, McpServerConfig>,
    scope: ConfigScope,
) -> HashMap<String, ScopedMcpServerConfig> {
    let mut scoped = HashMap::with_capacity(servers.len());
    for (name, config) in servers {
        let entry = ScopedMcpServerConfig {
            config: config.clone(),
            scope: scope.clone(),
            plugin_source: None,
        };
        scoped.insert(name.clone(), entry);
    }
    scoped
}
/// Returns the path of the `.mcp.json` file located directly inside `cwd`.
pub fn get_project_mcp_file_path(cwd: &PathBuf) -> PathBuf {
    let mut project_file = cwd.clone();
    project_file.push(".mcp.json");
    project_file
}
/// Returns `<config dir>/ai-agent/mcp.json`.
///
/// The config directory comes from `dirs::config_dir()`; when that is
/// unavailable the relative directory `.` is used instead.
pub fn get_global_mcp_file_path() -> PathBuf {
    let mut path = match dirs::config_dir() {
        Some(dir) => dir,
        None => PathBuf::from("."),
    };
    path.push("ai-agent");
    path.push("mcp.json");
    path
}
/// Returns `<config dir>/ai-agent/managed-mcp.json`.
///
/// The config directory comes from `dirs::config_dir()`; when that is
/// unavailable the relative directory `.` is used instead.
pub fn get_enterprise_mcp_file_path() -> PathBuf {
    let mut path = match dirs::config_dir() {
        Some(dir) => dir,
        None => PathBuf::from("."),
    };
    path.push("ai-agent");
    path.push("managed-mcp.json");
    path
}
// Unit tests for the command/URL extraction helpers, wildcard URL matching,
// signature generation and command-array comparison defined above.
#[cfg(test)]
mod tests {
use super::*;
// A stdio config yields the command followed by its args.
#[test]
fn test_get_server_command_array_stdio() {
let config = McpServerConfig::Stdio(McpStdioServerConfig {
config_type: Some("stdio".to_string()),
command: "node".to_string(),
args: vec!["server.js".to_string()],
env: None,
});
let cmd = get_server_command_array(&config);
assert_eq!(cmd, Some(vec!["node".to_string(), "server.js".to_string()]));
}
// Non-stdio configs have no command array.
#[test]
fn test_get_server_command_array_non_stdio() {
let config = McpServerConfig::Http(McpHttpServerConfig {
config_type: "http".to_string(),
url: "https://example.com".to_string(),
headers: None,
headers_helper: None,
oauth: None,
});
let cmd = get_server_command_array(&config);
assert!(cmd.is_none());
}
// An HTTP config exposes its URL unchanged.
#[test]
fn test_get_server_url_http() {
let config = McpServerConfig::Http(McpHttpServerConfig {
config_type: "http".to_string(),
url: "https://example.com/mcp".to_string(),
headers: None,
headers_helper: None,
oauth: None,
});
let url = get_server_url(&config);
assert_eq!(url, Some("https://example.com/mcp".to_string()));
}
// `*` matches path suffixes and subdomain labels; a different host must not match.
#[test]
fn test_url_matches_pattern() {
assert!(url_matches_pattern(
"https://example.com/api/v1",
"https://example.com/*"
));
assert!(url_matches_pattern(
"https://api.example.com/path",
"https://*.example.com/*"
));
assert!(!url_matches_pattern(
"https://other.com/path",
"https://example.com/*"
));
}
// Stdio signatures carry the `stdio:` prefix.
#[test]
fn test_mcp_server_signature() {
let config = McpServerConfig::Stdio(McpStdioServerConfig {
config_type: Some("stdio".to_string()),
command: "npx".to_string(),
args: vec!["-y".to_string(), "server".to_string()],
env: None,
});
let sig = get_mcp_server_signature(&config);
assert!(sig.is_some());
assert!(sig.unwrap().starts_with("stdio:"));
}
// Equal arrays match; arrays of different length do not.
#[test]
fn test_command_arrays_match() {
assert!(command_arrays_match(
&["node".to_string(), "server.js".to_string()],
&["node".to_string(), "server.js".to_string()]
));
assert!(!command_arrays_match(
&["node".to_string(), "server.js".to_string()],
&["node".to_string()]
));
}
}