use hashbrown::HashMap;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use vtcode_auth::McpOAuthConfig;
/// Settings for the MCP client.
///
/// Every field carries a serde default, so a partial or empty configuration
/// deserializes successfully; the `Default` impl uses the same default helpers.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpClientConfig {
    /// Master switch for MCP support; defaults to `false`.
    #[serde(default = "default_mcp_enabled")]
    pub enabled: bool,
    /// UI presentation settings for MCP activity.
    #[serde(default)]
    pub ui: McpUiConfig,
    /// Configured MCP providers; empty by default.
    #[serde(default)]
    pub providers: Vec<McpProviderConfig>,
    /// Permitted stdio commands / HTTP endpoints; not enforced by default.
    #[serde(default)]
    pub requirements: McpRequirementsConfig,
    /// Settings for the built-in MCP server (disabled by default).
    #[serde(default)]
    pub server: McpServerConfig,
    /// Allow-list rules for tools, resources, prompts, logging and configuration;
    /// not enforced by default.
    #[serde(default)]
    pub allowlist: McpAllowListConfig,
    /// Maximum number of concurrent connections; defaults to 5.
    #[serde(default = "default_max_concurrent_connections")]
    pub max_concurrent_connections: usize,
    /// Request timeout in seconds; defaults to 30.
    #[serde(default = "default_request_timeout_seconds")]
    pub request_timeout_seconds: u64,
    /// Number of retry attempts; defaults to 3.
    #[serde(default = "default_retry_attempts")]
    pub retry_attempts: u32,
    /// Optional startup timeout in seconds; `None` unless configured.
    #[serde(default)]
    pub startup_timeout_seconds: Option<u64>,
    /// Optional tool-call timeout in seconds; `None` unless configured.
    #[serde(default)]
    pub tool_timeout_seconds: Option<u64>,
    /// Experimental flag for using the rmcp client; defaults to `true`.
    #[serde(default = "default_experimental_use_rmcp_client")]
    pub experimental_use_rmcp_client: bool,
    /// Whether connection pooling is enabled; defaults to `true`.
    #[serde(default = "default_connection_pooling_enabled")]
    pub connection_pooling_enabled: bool,
    /// Capacity of the tool cache; defaults to 100.
    #[serde(default = "default_tool_cache_capacity")]
    pub tool_cache_capacity: usize,
    /// Connection timeout in seconds; defaults to 30.
    #[serde(default = "default_connection_timeout_seconds")]
    pub connection_timeout_seconds: u64,
    /// Authentication, rate-limit and argument-validation settings.
    #[serde(default)]
    pub security: McpSecurityConfig,
}
/// Security-related MCP client settings.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpSecurityConfig {
    /// Whether authentication is enabled; defaults to `false`.
    #[serde(default = "default_mcp_auth_enabled")]
    pub auth_enabled: bool,
    /// Name of an environment variable holding an API key
    /// (presumably read by the auth layer — confirm); `None` by default.
    #[serde(default)]
    pub api_key_env: Option<String>,
    /// Request rate limits.
    #[serde(default)]
    pub rate_limit: McpRateLimitConfig,
    /// Tool-argument validation settings.
    #[serde(default)]
    pub validation: McpValidationConfig,
}
impl Default for McpSecurityConfig {
    /// Builds the security settings used when none are configured.
    ///
    /// Authentication follows `default_mcp_auth_enabled` and no API-key
    /// variable is set. The rate-limit and validation sections use their own
    /// defaults.
    fn default() -> Self {
        let rate_limit = McpRateLimitConfig::default();
        let validation = McpValidationConfig::default();
        Self {
            auth_enabled: default_mcp_auth_enabled(),
            api_key_env: None,
            rate_limit,
            validation,
        }
    }
}
/// Rate-limit settings for MCP requests.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpRateLimitConfig {
    /// Requests allowed per minute; defaults to 100.
    #[serde(default = "default_requests_per_minute")]
    pub requests_per_minute: u32,
    /// Concurrent requests allowed; defaults to 10.
    #[serde(default = "default_concurrent_requests")]
    pub concurrent_requests: u32,
}
impl Default for McpRateLimitConfig {
    /// Returns the limits applied when none are configured, taken from
    /// `default_requests_per_minute` and `default_concurrent_requests`.
    fn default() -> Self {
        let requests_per_minute = default_requests_per_minute();
        let concurrent_requests = default_concurrent_requests();
        Self {
            requests_per_minute,
            concurrent_requests,
        }
    }
}
/// Validation applied to MCP tool arguments.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpValidationConfig {
    /// Whether schema validation is enabled; defaults to `true`.
    #[serde(default = "default_schema_validation_enabled")]
    pub schema_validation_enabled: bool,
    /// Whether path-traversal protection is enabled; defaults to `true`.
    #[serde(default = "default_path_traversal_protection_enabled")]
    pub path_traversal_protection: bool,
    /// Maximum argument size; defaults to 1024 * 1024
    /// (presumably bytes — confirm against the validator).
    #[serde(default = "default_max_argument_size")]
    pub max_argument_size: u32,
}
impl Default for McpValidationConfig {
    /// Returns validation settings built from the matching `default_*` helpers,
    /// i.e. the same values serde uses for missing keys.
    fn default() -> Self {
        let max_argument_size = default_max_argument_size();
        let path_traversal_protection = default_path_traversal_protection_enabled();
        let schema_validation_enabled = default_schema_validation_enabled();
        Self {
            schema_validation_enabled,
            path_traversal_protection,
            max_argument_size,
        }
    }
}
impl Default for McpClientConfig {
    /// Builds the client configuration used when no MCP settings are supplied.
    ///
    /// Each field mirrors the value serde fills in for a missing key. A defaulted
    /// config therefore agrees with one deserialized from an empty map. Fields are
    /// listed in declaration order.
    fn default() -> Self {
        Self {
            enabled: default_mcp_enabled(),
            ui: Default::default(),
            providers: Vec::new(),
            requirements: Default::default(),
            server: Default::default(),
            allowlist: Default::default(),
            max_concurrent_connections: default_max_concurrent_connections(),
            request_timeout_seconds: default_request_timeout_seconds(),
            retry_attempts: default_retry_attempts(),
            startup_timeout_seconds: None,
            tool_timeout_seconds: None,
            experimental_use_rmcp_client: default_experimental_use_rmcp_client(),
            connection_pooling_enabled: default_connection_pooling_enabled(),
            tool_cache_capacity: default_tool_cache_capacity(),
            connection_timeout_seconds: default_connection_timeout_seconds(),
            security: Default::default(),
        }
    }
}
/// Lists of permitted stdio commands and HTTP endpoints for providers.
///
/// NOTE(review): the enforcement itself is not in this file — confirm the
/// matching rules (exact vs. pattern) with the code that reads these lists.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpRequirementsConfig {
    /// Whether the lists are enforced; defaults to `false`.
    #[serde(default = "default_mcp_requirements_enforce")]
    pub enforce: bool,
    /// Permitted stdio commands; empty by default.
    #[serde(default)]
    pub allowed_stdio_commands: Vec<String>,
    /// Permitted HTTP endpoints; empty by default.
    #[serde(default)]
    pub allowed_http_endpoints: Vec<String>,
}
impl Default for McpRequirementsConfig {
    /// Returns requirements whose enforcement follows
    /// `default_mcp_requirements_enforce`, with empty command and endpoint lists.
    fn default() -> Self {
        let (allowed_stdio_commands, allowed_http_endpoints) = (Vec::default(), Vec::default());
        Self {
            enforce: default_mcp_requirements_enforce(),
            allowed_stdio_commands,
            allowed_http_endpoints,
        }
    }
}
/// UI presentation settings for MCP activity.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpUiConfig {
    /// Display mode; defaults to `McpUiMode::Compact`.
    #[serde(default = "default_mcp_ui_mode")]
    pub mode: McpUiMode,
    /// Maximum number of MCP events; defaults to 50.
    #[serde(default = "default_max_mcp_events")]
    pub max_events: usize,
    /// Whether provider names are shown; defaults to `true`.
    #[serde(default = "default_show_provider_names")]
    pub show_provider_names: bool,
    /// Renderer overrides keyed by identifier prefix. Keys are normalized and
    /// matched by `renderer_for_identifier`. The schema describes the field as a
    /// `BTreeMap` because the runtime type is a hashbrown `HashMap`.
    #[serde(default)]
    #[cfg_attr(
        feature = "schema",
        schemars(with = "BTreeMap<String, McpRendererProfile>")
    )]
    pub renderers: HashMap<String, McpRendererProfile>,
}
impl Default for McpUiConfig {
    /// Returns the UI settings used when none are configured.
    ///
    /// Mode, event cap and provider-name flag come from their `default_*`
    /// helpers, and there are no renderer overrides.
    fn default() -> Self {
        Self {
            renderers: Default::default(),
            show_provider_names: default_show_provider_names(),
            max_events: default_max_mcp_events(),
            mode: default_mcp_ui_mode(),
        }
    }
}
impl McpUiConfig {
    /// Looks up the renderer profile for an MCP identifier such as `context7_lookup`.
    ///
    /// `identifier` and every configured renderer key are normalized with
    /// `normalize_mcp_identifier` (ASCII alphanumerics only, lowercased). A key
    /// applies when its normalized form is a prefix of the normalized identifier.
    ///
    /// When several keys apply, the longest normalized key (the most specific
    /// one) wins. Ties are broken in favour of the lexicographically smallest
    /// original key, so the result never depends on `HashMap` iteration order.
    ///
    /// Keys that normalize to an empty string are ignored; an empty prefix would
    /// otherwise match every identifier.
    ///
    /// Returns `None` when `identifier` normalizes to an empty string or no key applies.
    pub fn renderer_for_identifier(&self, identifier: &str) -> Option<McpRendererProfile> {
        let normalized_identifier = normalize_mcp_identifier(identifier);
        if normalized_identifier.is_empty() {
            return None;
        }
        self.renderers
            .iter()
            .filter_map(|(key, profile)| {
                let normalized_key = normalize_mcp_identifier(key);
                let applies = !normalized_key.is_empty()
                    && normalized_identifier.starts_with(&normalized_key);
                applies.then_some((normalized_key.len(), key, *profile))
            })
            // Longest prefix wins; for equal lengths the smaller key is treated as "greater".
            .max_by(|(len_a, key_a, _), (len_b, key_b, _)| {
                len_a.cmp(len_b).then_with(|| key_b.cmp(key_a))
            })
            .map(|(_, _, profile)| profile)
    }

    /// Looks up the renderer profile for a tool name.
    ///
    /// A leading `mcp_` prefix is stripped, then `renderer_for_identifier` is used.
    pub fn renderer_for_tool(&self, tool_name: &str) -> Option<McpRendererProfile> {
        let identifier = tool_name.strip_prefix("mcp_").unwrap_or(tool_name);
        self.renderer_for_identifier(identifier)
    }
}
/// Display mode for MCP output (`McpUiConfig::mode`).
///
/// Serialized in `snake_case` (`compact` / `full`); defaults to `Compact`.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum McpUiMode {
    /// Default mode.
    #[default]
    Compact,
    /// Full mode.
    Full,
}
impl std::fmt::Display for McpUiMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
McpUiMode::Compact => write!(f, "compact"),
McpUiMode::Full => write!(f, "full"),
}
}
}
/// Renderer profile selectable through `McpUiConfig::renderers`.
///
/// Serialized in kebab-case: `context7`, `sequential-thinking`.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum McpRendererProfile {
    Context7,
    SequentialThinking,
}
/// One configured MCP provider.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpProviderConfig {
    /// Provider name (presumably the key looked up in
    /// `McpAllowListConfig::providers` — confirm with callers).
    pub name: String,
    /// Stdio or HTTP transport settings.
    ///
    /// Flattened, so the transport's fields appear directly in the provider entry.
    #[serde(flatten)]
    pub transport: McpTransportConfig,
    /// Extra environment variables, name → value
    /// (presumably applied to the provider process — confirm).
    #[serde(default)]
    #[cfg_attr(feature = "schema", schemars(with = "BTreeMap<String, String>"))]
    pub env: HashMap<String, String>,
    /// Whether this provider is enabled; defaults to `true`.
    #[serde(default = "default_provider_enabled")]
    pub enabled: bool,
    /// Maximum concurrent requests to this provider; defaults to 3.
    #[serde(default = "default_provider_max_concurrent")]
    pub max_concurrent_requests: usize,
    /// Optional startup timeout in milliseconds; `None` unless configured.
    #[serde(default)]
    pub startup_timeout_ms: Option<u64>,
}
impl Default for McpProviderConfig {
    /// Returns an unnamed provider using a default stdio transport.
    ///
    /// It has no extra environment variables, and its enabled flag and
    /// concurrency limit come from their `default_*` helpers. No startup
    /// timeout override is set.
    fn default() -> Self {
        let transport = McpTransportConfig::Stdio(Default::default());
        Self {
            name: String::default(),
            transport,
            env: Default::default(),
            enabled: default_provider_enabled(),
            max_concurrent_requests: default_provider_max_concurrent(),
            startup_timeout_ms: None,
        }
    }
}
/// Allow-list controlling which MCP tools, resources, prompts, logging channels
/// and configuration keys are permitted.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpAllowListConfig {
    /// When `false` (the default) every check returns `true`.
    #[serde(default = "default_allowlist_enforced")]
    pub enforce: bool,
    /// Rules used when a provider has no override for the category being checked.
    #[serde(default)]
    pub default: McpAllowListRules,
    /// Per-provider rule overrides, keyed by provider name.
    #[serde(default)]
    pub providers: BTreeMap<String, McpAllowListRules>,
}
impl Default for McpAllowListConfig {
    /// Returns an allow-list whose enforcement follows `default_allowlist_enforced`.
    ///
    /// The default rules are empty and there are no per-provider overrides.
    fn default() -> Self {
        Self {
            enforce: default_allowlist_enforced(),
            default: Default::default(),
            providers: Default::default(),
        }
    }
}
impl McpAllowListConfig {
    /// Reports whether `provider` may expose the tool `tool_name`.
    ///
    /// Always `true` when enforcement is off. Otherwise the provider's `tools`
    /// patterns decide if it defines any. Failing that, the default rules'
    /// `tools` patterns are used. The tool is denied when neither defines
    /// patterns.
    pub fn is_tool_allowed(&self, provider: &str, tool_name: &str) -> bool {
        !self.enforce || self.resolve_match(provider, tool_name, |rules| &rules.tools)
    }

    /// Reports whether `provider` may expose `resource`.
    ///
    /// Uses the same override/fallback rules as `is_tool_allowed`, applied to
    /// `resources`.
    pub fn is_resource_allowed(&self, provider: &str, resource: &str) -> bool {
        !self.enforce || self.resolve_match(provider, resource, |rules| &rules.resources)
    }

    /// Reports whether `provider` may expose `prompt`.
    ///
    /// Uses the same override/fallback rules as `is_tool_allowed`, applied to
    /// `prompts`.
    pub fn is_prompt_allowed(&self, provider: &str, prompt: &str) -> bool {
        !self.enforce || self.resolve_match(provider, prompt, |rules| &rules.prompts)
    }

    /// Reports whether log messages on `channel` are permitted.
    ///
    /// If `provider` is given and its rules define `logging` patterns, those
    /// decide. Otherwise the default `logging` patterns are used, and the
    /// channel is denied when none exist. Always `true` when enforcement is off.
    pub fn is_logging_channel_allowed(&self, provider: Option<&str>, channel: &str) -> bool {
        if !self.enforce {
            return true;
        }
        let provider_patterns = provider
            .and_then(|name| self.providers.get(name))
            .and_then(|rules| rules.logging.as_ref());
        match provider_patterns {
            Some(patterns) => pattern_matches(patterns, channel),
            None => self
                .default
                .logging
                .as_ref()
                .is_some_and(|patterns| pattern_matches(patterns, channel)),
        }
    }

    /// Reports whether configuration `key` within `category` may be changed.
    ///
    /// A provider's rules decide only if they contain an entry for `category`.
    /// Otherwise the default rules' entry for `category` decides. Anything
    /// without a matching category entry is denied. Always `true` when
    /// enforcement is off.
    pub fn is_configuration_allowed(
        &self,
        provider: Option<&str>,
        category: &str,
        key: &str,
    ) -> bool {
        if !self.enforce {
            return true;
        }
        provider
            .and_then(|name| self.providers.get(name))
            .and_then(|rules| configuration_allowed(rules, category, key))
            .or_else(|| configuration_allowed(&self.default, category, key))
            .unwrap_or(false)
    }

    /// Shared resolution for the pattern-list categories.
    ///
    /// `accessor` selects one list from a rule set. The provider's list wins
    /// whenever it is present, even if it does not match. Otherwise the default
    /// list is consulted. A missing list means "deny".
    fn resolve_match<'a, F>(&'a self, provider: &str, candidate: &str, accessor: F) -> bool
    where
        F: Fn(&'a McpAllowListRules) -> &'a Option<Vec<String>>,
    {
        let provider_patterns = self
            .providers
            .get(provider)
            .and_then(|rules| accessor(rules).as_ref());
        match provider_patterns {
            Some(patterns) => pattern_matches(patterns, candidate),
            None => accessor(&self.default)
                .as_ref()
                .is_some_and(|patterns| pattern_matches(patterns, candidate)),
        }
    }
}
/// Evaluates the `configuration` rules in `rules` for `category` / `key`.
///
/// Returns `None` when `rules` has no configuration table or no entry for
/// `category`, so the caller can fall back to other rules. Otherwise returns
/// `Some(matched)`, where `matched` says whether `key` fits one of that
/// category's patterns.
fn configuration_allowed(rules: &McpAllowListRules, category: &str, key: &str) -> Option<bool> {
    let category_patterns = rules.configuration.as_ref()?.get(category)?;
    Some(pattern_matches(category_patterns, key))
}
/// Returns `true` if `candidate` matches at least one wildcard pattern in `patterns`.
///
/// An empty pattern list matches nothing.
fn pattern_matches(patterns: &[String], candidate: &str) -> bool {
    for pattern in patterns {
        if wildcard_match(pattern, candidate) {
            return true;
        }
    }
    false
}
/// Returns `true` when `candidate` matches the glob-style `pattern`.
///
/// Supported syntax:
/// * `*` matches any run of characters, including an empty run;
/// * `?` matches exactly one character;
/// * every other character must match literally (case-sensitive).
///
/// The whole candidate must match; the pattern is anchored at both ends.
/// Wildcards never match a line feed (`\n`). A candidate containing one can
/// therefore only be matched by a literal `\n` in the pattern, or by the bare
/// pattern `"*"`, which is special-cased to accept every candidate.
///
/// Matching walks the characters directly and keeps a single backtrack point
/// for the most recent `*`. No regex is compiled per call, which matters
/// because allow-list checks run this for every pattern on every lookup.
fn wildcard_match(pattern: &str, candidate: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    let pattern: Vec<char> = pattern.chars().collect();
    let candidate: Vec<char> = candidate.chars().collect();
    let (mut p, mut c) = (0usize, 0usize);
    // (index of the most recent `*`, next candidate index that `*` would absorb)
    let mut star: Option<(usize, usize)> = None;

    while c < candidate.len() {
        let ch = candidate[c];
        match pattern.get(p).copied() {
            Some('*') => {
                // Tentatively let `*` match nothing; remember where to resume.
                star = Some((p, c));
                p += 1;
            }
            Some('?') if ch != '\n' => {
                p += 1;
                c += 1;
            }
            Some(expected) if expected == ch => {
                p += 1;
                c += 1;
            }
            _ => match star {
                // Let the last `*` absorb one more character and retry after it.
                // A `*` cannot absorb `\n`, and no earlier `*` could get past it
                // either, so hitting one here means no match is possible.
                Some((star_p, star_c)) if candidate[star_c] != '\n' => {
                    star = Some((star_p, star_c + 1));
                    p = star_p + 1;
                    c = star_c + 1;
                }
                _ => return false,
            },
        }
    }
    // Candidate exhausted: any remaining pattern must be `*`s matching empty runs.
    pattern[p..].iter().all(|&ch| ch == '*')
}
/// One set of allow-list patterns (the default set or a provider override).
///
/// Each list is optional. On a provider override, `None` means "defer to the
/// default rules". On the default rules, `None` means "deny" while the
/// allow-list is enforced. Patterns support `*` and `?` wildcards.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct McpAllowListRules {
    /// Tool-name patterns.
    #[serde(default)]
    pub tools: Option<Vec<String>>,
    /// Resource patterns.
    #[serde(default)]
    pub resources: Option<Vec<String>>,
    /// Prompt-name patterns.
    #[serde(default)]
    pub prompts: Option<Vec<String>>,
    /// Logging-channel patterns.
    #[serde(default)]
    pub logging: Option<Vec<String>>,
    /// Configuration-key patterns, keyed by configuration category.
    #[serde(default)]
    pub configuration: Option<BTreeMap<String, Vec<String>>>,
}
/// Settings for the built-in MCP server.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpServerConfig {
    /// Whether the server is enabled; defaults to `false`.
    #[serde(default = "default_mcp_server_enabled")]
    pub enabled: bool,
    /// Bind address; defaults to `127.0.0.1`.
    #[serde(default = "default_mcp_server_bind")]
    pub bind_address: String,
    /// Listening port; defaults to 3000.
    #[serde(default = "default_mcp_server_port")]
    pub port: u16,
    /// Server transport; defaults to `McpServerTransport::Sse`.
    #[serde(default = "default_mcp_server_transport")]
    pub transport: McpServerTransport,
    /// Server name; defaults to `vtcode-mcp-server`.
    #[serde(default = "default_mcp_server_name")]
    pub name: String,
    /// Server version; defaults to this crate's `CARGO_PKG_VERSION`.
    #[serde(default = "default_mcp_server_version")]
    pub version: String,
    /// Names of tools to expose; empty by default.
    #[serde(default)]
    pub exposed_tools: Vec<String>,
}
impl Default for McpServerConfig {
    /// Returns the server settings used when none are configured.
    ///
    /// Values come from the `default_mcp_server_*` helpers, and no tools are
    /// exposed.
    fn default() -> Self {
        let bind_address = default_mcp_server_bind();
        let port = default_mcp_server_port();
        Self {
            enabled: default_mcp_server_enabled(),
            bind_address,
            port,
            transport: default_mcp_server_transport(),
            name: default_mcp_server_name(),
            version: default_mcp_server_version(),
            exposed_tools: Vec::default(),
        }
    }
}
/// Transport used by the built-in MCP server.
///
/// Serialized in `snake_case` (`sse` / `http`); defaults to `Sse`.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum McpServerTransport {
    /// Default transport.
    #[default]
    Sse,
    Http,
}
/// How a provider is reached: a local stdio process or an HTTP endpoint.
///
/// `untagged`: serde tries `Stdio` first and falls back to `Http`. An entry
/// that satisfies `McpStdioServerConfig` is therefore treated as stdio even if
/// it also carries HTTP fields. When neither variant fits, serde reports a
/// generic "did not match any variant" error.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
// The `Http` variant is much larger than `Stdio`; the lint is silenced rather
// than boxing it.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum McpTransportConfig {
    Stdio(McpStdioServerConfig),
    Http(McpHttpServerConfig),
}
/// Transport settings for an MCP provider started as a local process and spoken
/// to over stdio.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct McpStdioServerConfig {
    /// Program to run for this provider. Required; its presence is what selects
    /// the stdio variant of the untagged `McpTransportConfig`.
    pub command: String,
    /// Arguments passed to `command`.
    ///
    /// Optional in configuration and empty by default. Without this default, an
    /// argument-less stdio entry would fail both untagged variants and surface
    /// only a generic "did not match any variant" error.
    #[serde(default)]
    pub args: Vec<String>,
    /// Working directory for the process; `None` when not configured.
    #[serde(default)]
    pub working_directory: Option<String>,
}
/// Transport settings for an MCP provider reached over HTTP.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpHttpServerConfig {
    /// Provider endpoint. Required, and the field that distinguishes this
    /// variant when `Stdio` does not match.
    pub endpoint: String,
    /// Name of an environment variable holding an API key
    /// (presumably used as a request credential — confirm); `None` by default.
    #[serde(default)]
    pub api_key_env: Option<String>,
    /// Optional OAuth settings.
    #[serde(default)]
    pub oauth: Option<McpOAuthConfig>,
    /// Protocol version string; defaults to `"2024-11-05"`.
    #[serde(default = "default_mcp_protocol_version")]
    pub protocol_version: String,
    /// Static HTTP headers, name → value; also accepted under the key `headers`.
    #[serde(default, alias = "headers")]
    #[cfg_attr(feature = "schema", schemars(with = "BTreeMap<String, String>"))]
    pub http_headers: HashMap<String, String>,
    /// Headers whose values come from the environment
    /// (presumably header name → variable name — confirm with the client code).
    #[serde(default)]
    #[cfg_attr(feature = "schema", schemars(with = "BTreeMap<String, String>"))]
    pub env_http_headers: HashMap<String, String>,
}
impl Default for McpHttpServerConfig {
    /// Returns an HTTP transport with an empty endpoint and no API-key variable.
    ///
    /// It has no OAuth settings, uses the default protocol version and carries
    /// no headers.
    fn default() -> Self {
        Self {
            endpoint: String::default(),
            api_key_env: None,
            oauth: None,
            protocol_version: default_mcp_protocol_version(),
            http_headers: Default::default(),
            env_http_headers: Default::default(),
        }
    }
}
/// serde/`Default` value for `McpClientConfig::enabled`: MCP is off unless enabled.
fn default_mcp_enabled() -> bool {
    false
}
/// serde/`Default` value for `McpUiConfig::mode`.
fn default_mcp_ui_mode() -> McpUiMode {
    McpUiMode::Compact
}
/// Default for `McpUiConfig::max_events`.
fn default_max_mcp_events() -> usize {
    50
}
/// Default for `McpUiConfig::show_provider_names`.
fn default_show_provider_names() -> bool {
    true
}
/// Default for `McpClientConfig::max_concurrent_connections`.
fn default_max_concurrent_connections() -> usize {
    5
}
/// Default for `McpClientConfig::request_timeout_seconds`.
fn default_request_timeout_seconds() -> u64 {
    30
}
/// Default for `McpClientConfig::retry_attempts`.
fn default_retry_attempts() -> u32 {
    3
}
/// Default for `McpClientConfig::experimental_use_rmcp_client`.
fn default_experimental_use_rmcp_client() -> bool {
    true
}
/// Default for `McpProviderConfig::enabled`.
fn default_provider_enabled() -> bool {
    true
}
/// Default for `McpProviderConfig::max_concurrent_requests`.
fn default_provider_max_concurrent() -> usize {
    3
}
/// Default for `McpAllowListConfig::enforce`: the allow-list is not enforced.
fn default_allowlist_enforced() -> bool {
    false
}
/// Default for `McpHttpServerConfig::protocol_version`.
fn default_mcp_protocol_version() -> String {
    String::from("2024-11-05")
}
/// Default for `McpServerConfig::enabled`.
fn default_mcp_server_enabled() -> bool {
    false
}
/// Default for `McpClientConfig::connection_pooling_enabled`.
fn default_connection_pooling_enabled() -> bool {
    true
}
/// Default for `McpClientConfig::tool_cache_capacity`.
fn default_tool_cache_capacity() -> usize {
    100
}
/// Default for `McpClientConfig::connection_timeout_seconds`.
fn default_connection_timeout_seconds() -> u64 {
    30
}
/// Default for `McpServerConfig::bind_address` (loopback only).
fn default_mcp_server_bind() -> String {
    String::from("127.0.0.1")
}
/// Default for `McpServerConfig::port`.
fn default_mcp_server_port() -> u16 {
    3000
}
/// Default for `McpServerConfig::transport`.
fn default_mcp_server_transport() -> McpServerTransport {
    McpServerTransport::Sse
}
/// Default for `McpServerConfig::name`.
fn default_mcp_server_name() -> String {
    String::from("vtcode-mcp-server")
}
/// Default for `McpServerConfig::version`: this crate's version, fixed at compile time.
fn default_mcp_server_version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
/// Canonicalizes an MCP identifier for comparison.
///
/// Keeps only ASCII letters and digits and lowercases them, so
/// `Sequential-Thinking` becomes `sequentialthinking`. Non-ASCII characters are
/// dropped.
fn normalize_mcp_identifier(value: &str) -> String {
    let mut normalized = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch.is_ascii_alphanumeric() {
            normalized.push(ch.to_ascii_lowercase());
        }
    }
    normalized
}
/// Default for `McpSecurityConfig::auth_enabled`.
fn default_mcp_auth_enabled() -> bool {
    false
}
/// Default for `McpRateLimitConfig::requests_per_minute`.
fn default_requests_per_minute() -> u32 {
    100
}
/// Default for `McpRateLimitConfig::concurrent_requests`.
fn default_concurrent_requests() -> u32 {
    10
}
/// Default for `McpValidationConfig::schema_validation_enabled`.
fn default_schema_validation_enabled() -> bool {
    true
}
/// Default for `McpValidationConfig::path_traversal_protection`.
fn default_path_traversal_protection_enabled() -> bool {
    true
}
/// Default for `McpValidationConfig::max_argument_size`: 1 MiB (1,048,576),
/// presumably in bytes — confirm against the validator.
fn default_max_argument_size() -> u32 {
    1 << 20
}
/// Default for `McpRequirementsConfig::enforce`.
fn default_mcp_requirements_enforce() -> bool {
    false
}
// Unit tests for MCP configuration defaults, allow-list resolution and
// renderer lookup.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::mcp as mcp_constants;
    use std::collections::BTreeMap;

    // `McpClientConfig::default()` yields the expected out-of-the-box values.
    #[test]
    fn test_mcp_config_defaults() {
        let config = McpClientConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.ui.mode, McpUiMode::Compact);
        assert_eq!(config.ui.max_events, 50);
        assert!(config.ui.show_provider_names);
        assert!(config.ui.renderers.is_empty());
        assert_eq!(config.max_concurrent_connections, 5);
        assert_eq!(config.request_timeout_seconds, 30);
        assert_eq!(config.retry_attempts, 3);
        assert!(config.providers.is_empty());
        assert!(!config.requirements.enforce);
        assert!(config.requirements.allowed_stdio_commands.is_empty());
        assert!(config.requirements.allowed_http_endpoints.is_empty());
        assert!(!config.server.enabled);
        assert!(!config.allowlist.enforce);
        assert!(config.allowlist.default.tools.is_none());
    }

    // Trailing `*` wildcards and exact literals both match; other names do not.
    #[test]
    fn test_allowlist_pattern_matching() {
        let patterns = vec!["get_*".to_string(), "convert_timezone".to_string()];
        assert!(pattern_matches(&patterns, "get_current_time"));
        assert!(pattern_matches(&patterns, "convert_timezone"));
        assert!(!pattern_matches(&patterns, "delete_timezone"));
    }

    // A provider's own `tools` list replaces (not extends) the default list.
    #[test]
    fn test_allowlist_provider_override() {
        let mut config = McpAllowListConfig {
            enforce: true,
            default: McpAllowListRules {
                tools: Some(vec!["get_*".to_string()]),
                ..Default::default()
            },
            ..Default::default()
        };
        let provider_rules = McpAllowListRules {
            tools: Some(vec!["list_*".to_string()]),
            ..Default::default()
        };
        config
            .providers
            .insert("context7".to_string(), provider_rules);
        assert!(config.is_tool_allowed("context7", "list_documents"));
        assert!(!config.is_tool_allowed("context7", "get_current_time"));
        assert!(config.is_tool_allowed("other", "get_timezone"));
        assert!(!config.is_tool_allowed("other", "list_documents"));
    }

    // Configuration rules are resolved per category, with provider rules first.
    #[test]
    fn test_allowlist_configuration_rules() {
        let mut config = McpAllowListConfig {
            enforce: true,
            default: McpAllowListRules {
                configuration: Some(BTreeMap::from([(
                    "ui".to_string(),
                    vec!["mode".to_string(), "max_events".to_string()],
                )])),
                ..Default::default()
            },
            ..Default::default()
        };
        let provider_rules = McpAllowListRules {
            configuration: Some(BTreeMap::from([(
                "provider".to_string(),
                vec!["max_concurrent_requests".to_string()],
            )])),
            ..Default::default()
        };
        config.providers.insert("time".to_string(), provider_rules);
        assert!(config.is_configuration_allowed(None, "ui", "mode"));
        assert!(!config.is_configuration_allowed(None, "ui", "show_provider_names"));
        assert!(config.is_configuration_allowed(
            Some("time"),
            "provider",
            "max_concurrent_requests"
        ));
        assert!(!config.is_configuration_allowed(Some("time"), "provider", "retry_attempts"));
    }

    // `*` spans `/`, so `docs/**/*` matches nested paths; provider override wins.
    #[test]
    fn test_allowlist_resource_override() {
        let mut config = McpAllowListConfig {
            enforce: true,
            default: McpAllowListRules {
                resources: Some(vec!["docs/**/*".to_string()]),
                ..Default::default()
            },
            ..Default::default()
        };
        let provider_rules = McpAllowListRules {
            resources: Some(vec!["journals/*".to_string()]),
            ..Default::default()
        };
        config
            .providers
            .insert("context7".to_string(), provider_rules);
        assert!(config.is_resource_allowed("context7", "journals/2024"));
        assert!(config.is_resource_allowed("other", "docs/config/config.md"));
        assert!(config.is_resource_allowed("other", "docs/guides/zed-acp.md"));
        assert!(!config.is_resource_allowed("other", "journals/2023"));
    }

    // Logging channels follow the same provider-override / default-fallback rules.
    #[test]
    fn test_allowlist_logging_override() {
        let mut config = McpAllowListConfig {
            enforce: true,
            default: McpAllowListRules {
                logging: Some(vec!["info".to_string(), "debug".to_string()]),
                ..Default::default()
            },
            ..Default::default()
        };
        let provider_rules = McpAllowListRules {
            logging: Some(vec!["audit".to_string()]),
            ..Default::default()
        };
        config
            .providers
            .insert("sequential".to_string(), provider_rules);
        assert!(config.is_logging_channel_allowed(Some("sequential"), "audit"));
        assert!(!config.is_logging_channel_allowed(Some("sequential"), "info"));
        assert!(config.is_logging_channel_allowed(Some("other"), "info"));
        assert!(!config.is_logging_channel_allowed(Some("other"), "trace"));
    }

    // Renderer lookup normalizes punctuation/case and strips the `mcp_` prefix.
    #[test]
    fn test_mcp_ui_renderer_resolution() {
        let mut config = McpUiConfig::default();
        config.renderers.insert(
            mcp_constants::RENDERER_CONTEXT7.to_string(),
            McpRendererProfile::Context7,
        );
        config.renderers.insert(
            mcp_constants::RENDERER_SEQUENTIAL_THINKING.to_string(),
            McpRendererProfile::SequentialThinking,
        );
        assert_eq!(
            config.renderer_for_tool("mcp_context7_lookup"),
            Some(McpRendererProfile::Context7)
        );
        assert_eq!(
            config.renderer_for_tool("mcp_context7lookup"),
            Some(McpRendererProfile::Context7)
        );
        assert_eq!(
            config.renderer_for_tool("mcp_sequentialthinking_run"),
            Some(McpRendererProfile::SequentialThinking)
        );
        assert_eq!(
            config.renderer_for_identifier("sequential-thinking-analyze"),
            Some(McpRendererProfile::SequentialThinking)
        );
        assert_eq!(config.renderer_for_tool("mcp_unknown"), None);
    }
}