use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
/// Top-level configuration for the MCP client integration.
///
/// Every field carries a serde default, so an empty table deserializes to the
/// same value as [`McpClientConfig::default`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpClientConfig {
    /// Master switch for MCP support (defaults to `false`).
    #[serde(default = "default_mcp_enabled")]
    pub enabled: bool,
    /// UI presentation settings for MCP activity.
    #[serde(default)]
    pub ui: McpUiConfig,
    /// External MCP providers (empty by default).
    #[serde(default)]
    pub providers: Vec<McpProviderConfig>,
    /// Settings for the built-in MCP server.
    #[serde(default)]
    pub server: McpServerConfig,
    /// Allow-list rules consulted before using provider capabilities.
    #[serde(default)]
    pub allowlist: McpAllowListConfig,
    /// Cap on concurrent connections (defaults to 5).
    #[serde(default = "default_max_concurrent_connections")]
    pub max_concurrent_connections: usize,
    /// Request timeout in seconds (defaults to 30).
    #[serde(default = "default_request_timeout_seconds")]
    pub request_timeout_seconds: u64,
    /// Retry attempt count (defaults to 3).
    #[serde(default = "default_retry_attempts")]
    pub retry_attempts: u32,
}
impl Default for McpClientConfig {
    /// Builds the configuration from the same helpers the serde defaults use,
    /// so programmatic and deserialized defaults always agree.
    fn default() -> Self {
        let ui = McpUiConfig::default();
        let server = McpServerConfig::default();
        let allowlist = McpAllowListConfig::default();
        Self {
            enabled: default_mcp_enabled(),
            ui,
            providers: Vec::default(),
            server,
            allowlist,
            max_concurrent_connections: default_max_concurrent_connections(),
            request_timeout_seconds: default_request_timeout_seconds(),
            retry_attempts: default_retry_attempts(),
        }
    }
}
/// Display settings for MCP activity in the UI.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpUiConfig {
    /// Rendering mode (defaults to `McpUiMode::Compact`).
    #[serde(default = "default_mcp_ui_mode")]
    pub mode: McpUiMode,
    /// Maximum number of MCP events (defaults to 50).
    #[serde(default = "default_max_mcp_events")]
    pub max_events: usize,
    /// Whether to show provider names (defaults to `true`).
    #[serde(default = "default_show_provider_names")]
    pub show_provider_names: bool,
}
impl Default for McpUiConfig {
    /// Mirrors the serde field defaults.
    fn default() -> Self {
        Self {
            show_provider_names: default_show_provider_names(),
            max_events: default_max_mcp_events(),
            mode: default_mcp_ui_mode(),
        }
    }
}
/// How MCP activity is rendered in the UI.
///
/// Serialized in `snake_case` (`"compact"` / `"full"`); `Compact` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum McpUiMode {
    /// Condensed rendering (the default).
    #[default]
    Compact,
    /// Full rendering.
    Full,
}
impl std::fmt::Display for McpUiMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same spelling as the serde `snake_case` names, so displayed and
        // configured values agree.
        f.write_str(match self {
            McpUiMode::Compact => "compact",
            McpUiMode::Full => "full",
        })
    }
}
/// One external MCP provider entry.
///
/// The transport settings are `flatten`ed into the provider table. A provider is
/// therefore written as `name` plus either the stdio fields (`command`, `args`)
/// or the HTTP fields (`endpoint`, ...).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpProviderConfig {
    /// Provider name; also the key used for per-provider allow-list rules.
    pub name: String,
    /// Stdio or HTTP transport settings.
    #[serde(flatten)]
    pub transport: McpTransportConfig,
    /// Environment variables configured for this provider (empty by default).
    #[serde(default)]
    pub env: HashMap<String, String>,
    /// Whether this provider is enabled (defaults to `true`).
    #[serde(default = "default_provider_enabled")]
    pub enabled: bool,
    /// Per-provider concurrent request cap (defaults to 3).
    #[serde(default = "default_provider_max_concurrent")]
    pub max_concurrent_requests: usize,
}
impl Default for McpProviderConfig {
    /// An unnamed, enabled provider with an empty stdio transport.
    fn default() -> Self {
        let transport = McpTransportConfig::Stdio(McpStdioServerConfig::default());
        Self {
            name: String::default(),
            transport,
            env: HashMap::default(),
            enabled: default_provider_enabled(),
            max_concurrent_requests: default_provider_max_concurrent(),
        }
    }
}
/// Allow-list controlling which provider tools, resources, prompts, logging
/// channels and configuration keys may be used.
///
/// While `enforce` is `false` (the default) every check passes. When it is
/// `true`, a provider's own rules take precedence over the `default` rules.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpAllowListConfig {
    /// Turns allow-list checking on (defaults to `false`).
    #[serde(default = "default_allowlist_enforced")]
    pub enforce: bool,
    /// Fallback rules for providers without their own list.
    #[serde(default)]
    pub default: McpAllowListRules,
    /// Per-provider rule sets, keyed by provider name.
    #[serde(default)]
    pub providers: BTreeMap<String, McpAllowListRules>,
}
impl Default for McpAllowListConfig {
    /// Non-enforcing allow-list with empty rule sets.
    fn default() -> Self {
        let enforce = default_allowlist_enforced();
        Self {
            enforce,
            default: McpAllowListRules::default(),
            providers: BTreeMap::default(),
        }
    }
}
impl McpAllowListConfig {
    /// Returns `true` if `provider` may call the tool named `tool_name`.
    ///
    /// With enforcement off every tool is allowed. Otherwise the provider's own
    /// `tools` list decides by itself when present. The default `tools` list is
    /// used only when the provider has none. With no applicable list at all the
    /// tool is denied.
    pub fn is_tool_allowed(&self, provider: &str, tool_name: &str) -> bool {
        !self.enforce || self.resolve_match(provider, tool_name, |rules| &rules.tools)
    }
    /// Returns `true` if `provider` may access `resource`, resolving the
    /// `resources` lists with the same provider-then-default precedence as tools.
    pub fn is_resource_allowed(&self, provider: &str, resource: &str) -> bool {
        !self.enforce || self.resolve_match(provider, resource, |rules| &rules.resources)
    }
    /// Returns `true` if `provider` may use `prompt`, resolving the `prompts`
    /// lists with the same provider-then-default precedence as tools.
    pub fn is_prompt_allowed(&self, provider: &str, prompt: &str) -> bool {
        !self.enforce || self.resolve_match(provider, prompt, |rules| &rules.prompts)
    }
    /// Returns `true` if logging on `channel` is permitted.
    ///
    /// `provider` is optional. The default `logging` list decides in three cases:
    /// `provider` is `None`, the provider is unknown, or it has no `logging` list.
    /// With no applicable list the channel is denied while enforcing.
    pub fn is_logging_channel_allowed(&self, provider: Option<&str>, channel: &str) -> bool {
        if !self.enforce {
            return true;
        }
        let provider_patterns = provider
            .and_then(|name| self.providers.get(name))
            .and_then(|rules| rules.logging.as_ref());
        match provider_patterns.or(self.default.logging.as_ref()) {
            Some(patterns) => pattern_matches(patterns, channel),
            None => false,
        }
    }
    /// Returns `true` if `key` in configuration `category` may be changed.
    ///
    /// A provider rule set only overrides the defaults for categories it
    /// actually lists. Otherwise the default rules are checked. A category
    /// listed nowhere is denied while enforcing.
    pub fn is_configuration_allowed(
        &self,
        provider: Option<&str>,
        category: &str,
        key: &str,
    ) -> bool {
        if !self.enforce {
            return true;
        }
        provider
            .and_then(|name| self.providers.get(name))
            .and_then(|rules| configuration_allowed(rules, category, key))
            .or_else(|| configuration_allowed(&self.default, category, key))
            .unwrap_or(false)
    }
    /// Picks the pattern list selected by `accessor`, preferring `provider`'s
    /// rules and falling back to the defaults, then matches `candidate`.
    ///
    /// Returns `false` when neither scope defines a list.
    fn resolve_match<'a, F>(&'a self, provider: &str, candidate: &str, accessor: F) -> bool
    where
        F: Fn(&'a McpAllowListRules) -> &'a Option<Vec<String>>,
    {
        let selected = self
            .providers
            .get(provider)
            .and_then(|rules| accessor(rules).as_ref())
            .or_else(|| accessor(&self.default).as_ref());
        match selected {
            Some(patterns) => pattern_matches(patterns, candidate),
            None => false,
        }
    }
}
/// Evaluates the `configuration` rules of a single rule set.
///
/// Returns `None` when the rule set has no `configuration` map or no entry for
/// `category`, which lets the caller fall back to another scope. Otherwise it
/// returns `Some` with whether `key` matches one of that category's patterns.
fn configuration_allowed(rules: &McpAllowListRules, category: &str, key: &str) -> Option<bool> {
    let patterns = rules.configuration.as_ref()?.get(category)?;
    Some(pattern_matches(patterns, key))
}
/// Returns `true` if `candidate` matches at least one glob in `patterns`.
///
/// An empty pattern list matches nothing.
fn pattern_matches(patterns: &[String], candidate: &str) -> bool {
    patterns
        .iter()
        .map(String::as_str)
        .any(|glob| wildcard_match(glob, candidate))
}
fn wildcard_match(pattern: &str, candidate: &str) -> bool {
if pattern == "*" {
return true;
}
let mut regex_pattern = String::from("^");
let mut literal_buffer = String::new();
for ch in pattern.chars() {
match ch {
'*' => {
if !literal_buffer.is_empty() {
regex_pattern.push_str(®ex::escape(&literal_buffer));
literal_buffer.clear();
}
regex_pattern.push_str(".*");
}
'?' => {
if !literal_buffer.is_empty() {
regex_pattern.push_str(®ex::escape(&literal_buffer));
literal_buffer.clear();
}
regex_pattern.push('.');
}
_ => literal_buffer.push(ch),
}
}
if !literal_buffer.is_empty() {
regex_pattern.push_str(®ex::escape(&literal_buffer));
}
regex_pattern.push('$');
Regex::new(®ex_pattern)
.map(|regex| regex.is_match(candidate))
.unwrap_or(false)
}
/// Allow-list patterns for one scope: the defaults, or a single provider.
///
/// Each list is optional. `None` means "no rule at this scope", so lookups fall
/// back to the default rules. `Some(vec![])` allows nothing. Patterns are globs
/// in which `*` matches any run of characters and `?` matches one character.
/// Missing fields deserialize to `None` via the container-level serde default.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(default)]
pub struct McpAllowListRules {
    /// Allowed tool names.
    pub tools: Option<Vec<String>>,
    /// Allowed resource identifiers.
    pub resources: Option<Vec<String>>,
    /// Allowed prompt names.
    pub prompts: Option<Vec<String>>,
    /// Allowed logging channels.
    pub logging: Option<Vec<String>>,
    /// Allowed configuration keys, grouped by category name.
    pub configuration: Option<BTreeMap<String, Vec<String>>>,
}
/// Settings for the built-in MCP server.
///
/// All fields have serde defaults, so an empty table yields
/// [`McpServerConfig::default`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpServerConfig {
    /// Whether the server is enabled (defaults to `false`).
    #[serde(default = "default_mcp_server_enabled")]
    pub enabled: bool,
    /// Bind address (defaults to `127.0.0.1`).
    #[serde(default = "default_mcp_server_bind")]
    pub bind_address: String,
    /// Listening port (defaults to 3000).
    #[serde(default = "default_mcp_server_port")]
    pub port: u16,
    /// Server transport (defaults to SSE).
    #[serde(default = "default_mcp_server_transport")]
    pub transport: McpServerTransport,
    /// Server name (defaults to `vtcode-mcp-server`).
    #[serde(default = "default_mcp_server_name")]
    pub name: String,
    /// Server version (defaults to the crate's package version).
    #[serde(default = "default_mcp_server_version")]
    pub version: String,
    /// Names of tools to expose (empty by default).
    #[serde(default)]
    pub exposed_tools: Vec<String>,
}
impl Default for McpServerConfig {
    /// Mirrors the serde field defaults.
    fn default() -> Self {
        Self {
            exposed_tools: Vec::default(),
            version: default_mcp_server_version(),
            name: default_mcp_server_name(),
            transport: default_mcp_server_transport(),
            port: default_mcp_server_port(),
            bind_address: default_mcp_server_bind(),
            enabled: default_mcp_server_enabled(),
        }
    }
}
/// Transport used by the built-in MCP server.
///
/// Serialized in `snake_case` (`"sse"` / `"http"`); `Sse` is the default.
/// Derives equality like [`McpUiMode`] so configured values can be compared.
#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum McpServerTransport {
    /// SSE transport (the default).
    #[default]
    Sse,
    /// HTTP transport.
    Http,
}
/// Transport settings for an MCP provider.
///
/// Deserialized as `untagged`, so serde tries the variants in declaration order.
/// A table is read as `Stdio` if it fits [`McpStdioServerConfig`], and only
/// otherwise as `Http`. Keep the variant order stable: reordering changes which
/// variant ambiguous input resolves to.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum McpTransportConfig {
    /// Stdio-based provider, configured with `command` / `args`.
    Stdio(McpStdioServerConfig),
    /// HTTP-based provider, configured with an `endpoint`.
    Http(McpHttpServerConfig),
}
/// Stdio transport settings: the command to run and its arguments.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct McpStdioServerConfig {
    /// Command for the stdio provider (required).
    pub command: String,
    /// Arguments for `command`. May be omitted, which means no arguments.
    ///
    /// Without this default, a provider written with only `command` would fail
    /// to match the `Stdio` variant of the untagged `McpTransportConfig` enum,
    /// and deserialization would then fail on the `Http` variant as well.
    #[serde(default)]
    pub args: Vec<String>,
    /// Optional working directory (`None` by default).
    #[serde(default)]
    pub working_directory: Option<String>,
}
/// HTTP transport settings for an MCP provider.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpHttpServerConfig {
    /// Provider endpoint (required).
    pub endpoint: String,
    /// Presumably the name of the environment variable holding the API key;
    /// `None` by default.
    #[serde(default)]
    pub api_key_env: Option<String>,
    /// MCP protocol version string (defaults to `2024-11-05`).
    #[serde(default = "default_mcp_protocol_version")]
    pub protocol_version: String,
    /// Extra headers (empty by default).
    #[serde(default)]
    pub headers: HashMap<String, String>,
}
impl Default for McpHttpServerConfig {
    /// Empty endpoint with the default protocol version and no extras.
    fn default() -> Self {
        Self {
            protocol_version: default_mcp_protocol_version(),
            endpoint: String::default(),
            api_key_env: None,
            headers: HashMap::default(),
        }
    }
}
// Serde default helpers. The `Default` impls above call the same functions, so
// programmatic and deserialized defaults cannot drift apart.
/// MCP support is off unless enabled explicitly.
fn default_mcp_enabled() -> bool {
    false
}
/// Compact UI rendering.
fn default_mcp_ui_mode() -> McpUiMode {
    McpUiMode::Compact
}
/// Event limit for the UI.
fn default_max_mcp_events() -> usize {
    50
}
/// Provider names are shown by default.
fn default_show_provider_names() -> bool {
    true
}
/// Client-wide connection cap.
fn default_max_concurrent_connections() -> usize {
    5
}
/// Request timeout, in seconds.
fn default_request_timeout_seconds() -> u64 {
    30
}
/// Retry attempt count.
fn default_retry_attempts() -> u32 {
    3
}
/// Providers are enabled unless disabled explicitly.
fn default_provider_enabled() -> bool {
    true
}
/// Per-provider request concurrency cap.
fn default_provider_max_concurrent() -> usize {
    3
}
/// Allow-list enforcement is off by default.
fn default_allowlist_enforced() -> bool {
    false
}
/// MCP protocol version string used for HTTP providers.
fn default_mcp_protocol_version() -> String {
    String::from("2024-11-05")
}
/// The built-in server is off by default.
fn default_mcp_server_enabled() -> bool {
    false
}
/// Loopback bind address.
fn default_mcp_server_bind() -> String {
    String::from("127.0.0.1")
}
/// Default server port.
fn default_mcp_server_port() -> u16 {
    3000
}
/// SSE server transport.
fn default_mcp_server_transport() -> McpServerTransport {
    McpServerTransport::Sse
}
/// Default server name.
fn default_mcp_server_name() -> String {
    String::from("vtcode-mcp-server")
}
/// Server version taken from the crate's package version at compile time.
fn default_mcp_server_version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    #[test]
    fn test_mcp_config_defaults() {
        let config = McpClientConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.ui.mode, McpUiMode::Compact);
        assert_eq!(config.ui.max_events, 50);
        assert!(config.ui.show_provider_names);
        assert_eq!(config.max_concurrent_connections, 5);
        assert_eq!(config.request_timeout_seconds, 30);
        assert_eq!(config.retry_attempts, 3);
        assert!(config.providers.is_empty());
        assert!(!config.server.enabled);
        assert!(!config.allowlist.enforce);
        assert!(config.allowlist.default.tools.is_none());
    }
    #[test]
    fn test_allowlist_pattern_matching() {
        let patterns = vec!["get_*".to_string(), "convert_timezone".to_string()];
        assert!(pattern_matches(&patterns, "get_current_time"));
        assert!(pattern_matches(&patterns, "convert_timezone"));
        assert!(!pattern_matches(&patterns, "delete_timezone"));
    }
    // `?` matches exactly one character; regex metacharacters in a pattern
    // must be matched literally, not interpreted.
    #[test]
    fn test_wildcard_single_char_and_literal_escaping() {
        assert!(wildcard_match("get_?", "get_a"));
        assert!(!wildcard_match("get_?", "get_"));
        assert!(!wildcard_match("get_?", "get_ab"));
        assert!(wildcard_match("docs.v1/*", "docs.v1/intro"));
        assert!(!wildcard_match("docs.v1/*", "docsXv1/intro"));
        assert!(wildcard_match("a+b", "a+b"));
        assert!(!wildcard_match("a+b", "aab"));
        assert!(wildcard_match("*", ""));
        assert!(wildcard_match("", ""));
        assert!(!wildcard_match("", "x"));
        assert!(!pattern_matches(&[], "anything"));
    }
    #[test]
    fn test_allowlist_disabled_allows_everything() {
        let config = McpAllowListConfig::default();
        assert!(config.is_tool_allowed("any", "delete_everything"));
        assert!(config.is_resource_allowed("any", "secret"));
        assert!(config.is_prompt_allowed("any", "summarize"));
        assert!(config.is_logging_channel_allowed(None, "trace"));
        assert!(config.is_configuration_allowed(None, "ui", "mode"));
    }
    #[test]
    fn test_allowlist_enforced_without_rules_denies() {
        let config = McpAllowListConfig {
            enforce: true,
            ..Default::default()
        };
        assert!(!config.is_tool_allowed("any", "get_time"));
        assert!(!config.is_resource_allowed("any", "docs/manual"));
        assert!(!config.is_prompt_allowed("any", "summarize"));
        assert!(!config.is_logging_channel_allowed(None, "info"));
        assert!(!config.is_configuration_allowed(Some("any"), "ui", "mode"));
    }
    #[test]
    fn test_allowlist_provider_override() {
        let mut config = McpAllowListConfig::default();
        config.enforce = true;
        config.default.tools = Some(vec!["get_*".to_string()]);
        let mut provider_rules = McpAllowListRules::default();
        provider_rules.tools = Some(vec!["list_*".to_string()]);
        config
            .providers
            .insert("context7".to_string(), provider_rules);
        assert!(config.is_tool_allowed("context7", "list_documents"));
        assert!(!config.is_tool_allowed("context7", "get_current_time"));
        assert!(config.is_tool_allowed("other", "get_timezone"));
        assert!(!config.is_tool_allowed("other", "list_documents"));
    }
    #[test]
    fn test_allowlist_prompt_override() {
        let mut providers = BTreeMap::new();
        providers.insert(
            "writer".to_string(),
            McpAllowListRules {
                prompts: Some(vec!["draft_*".to_string()]),
                ..Default::default()
            },
        );
        let config = McpAllowListConfig {
            enforce: true,
            default: McpAllowListRules {
                prompts: Some(vec!["summarize".to_string()]),
                ..Default::default()
            },
            providers,
        };
        assert!(config.is_prompt_allowed("writer", "draft_email"));
        assert!(!config.is_prompt_allowed("writer", "summarize"));
        assert!(config.is_prompt_allowed("other", "summarize"));
        assert!(!config.is_prompt_allowed("other", "draft_email"));
    }
    #[test]
    fn test_allowlist_configuration_rules() {
        let mut config = McpAllowListConfig::default();
        config.enforce = true;
        let mut default_rules = McpAllowListRules::default();
        default_rules.configuration = Some(BTreeMap::from([(
            "ui".to_string(),
            vec!["mode".to_string(), "max_events".to_string()],
        )]));
        config.default = default_rules;
        let mut provider_rules = McpAllowListRules::default();
        provider_rules.configuration = Some(BTreeMap::from([(
            "provider".to_string(),
            vec!["max_concurrent_requests".to_string()],
        )]));
        config.providers.insert("time".to_string(), provider_rules);
        assert!(config.is_configuration_allowed(None, "ui", "mode"));
        assert!(!config.is_configuration_allowed(None, "ui", "show_provider_names"));
        assert!(config.is_configuration_allowed(
            Some("time"),
            "provider",
            "max_concurrent_requests"
        ));
        assert!(!config.is_configuration_allowed(Some("time"), "provider", "retry_attempts"));
    }
    #[test]
    fn test_allowlist_resource_override() {
        let mut config = McpAllowListConfig::default();
        config.enforce = true;
        config.default.resources = Some(vec!["docs/*".to_string()]);
        let mut provider_rules = McpAllowListRules::default();
        provider_rules.resources = Some(vec!["journals/*".to_string()]);
        config
            .providers
            .insert("context7".to_string(), provider_rules);
        assert!(config.is_resource_allowed("context7", "journals/2024"));
        assert!(!config.is_resource_allowed("context7", "docs/manual"));
        assert!(config.is_resource_allowed("other", "docs/reference"));
        assert!(!config.is_resource_allowed("other", "journals/2023"));
    }
    #[test]
    fn test_allowlist_logging_override() {
        let mut config = McpAllowListConfig::default();
        config.enforce = true;
        config.default.logging = Some(vec!["info".to_string(), "debug".to_string()]);
        let mut provider_rules = McpAllowListRules::default();
        provider_rules.logging = Some(vec!["audit".to_string()]);
        config
            .providers
            .insert("sequential".to_string(), provider_rules);
        assert!(config.is_logging_channel_allowed(Some("sequential"), "audit"));
        assert!(!config.is_logging_channel_allowed(Some("sequential"), "info"));
        assert!(config.is_logging_channel_allowed(Some("other"), "info"));
        assert!(!config.is_logging_channel_allowed(Some("other"), "trace"));
    }
}