use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
/// Security settings layered on top of an MCP client configuration.
///
/// Every field has a serde default, so a partially specified (or empty)
/// section deserializes to the same values as [`Default::default`].
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EnhancedMcpSecurityConfig {
/// Whether authentication is required (defaults to `false`).
#[serde(default = "default_auth_enabled")]
pub auth_enabled: bool,
/// Environment variable that supplies the API key. Validation requires it
/// when `auth_enabled` is true and the MCP server is enabled.
#[serde(default)]
pub api_key_env: Option<String>,
/// Request rate limiting settings.
#[serde(default)]
pub rate_limit: McpRateLimitConfig,
/// Argument/schema validation settings.
#[serde(default)]
pub validation: McpValidationConfig,
}
impl Default for EnhancedMcpSecurityConfig {
    /// Mirrors the per-field serde defaults, so `Default` and deserializing an
    /// empty section produce the same value.
    fn default() -> Self {
        let rate_limit = McpRateLimitConfig::default();
        let validation = McpValidationConfig::default();
        Self {
            auth_enabled: default_auth_enabled(),
            api_key_env: Option::None,
            rate_limit,
            validation,
        }
    }
}
/// Rate limiting applied to MCP requests.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpRateLimitConfig {
/// Allowed requests per minute (defaults to 100).
#[serde(default = "default_requests_per_minute")]
pub requests_per_minute: u32,
/// Allowed number of simultaneous requests (defaults to 10).
#[serde(default = "default_concurrent_requests")]
pub concurrent_requests: u32,
}
impl Default for McpRateLimitConfig {
    /// Uses the same helper functions as the field-level serde defaults.
    fn default() -> Self {
        let requests_per_minute = default_requests_per_minute();
        let concurrent_requests = default_concurrent_requests();
        Self {
            requests_per_minute,
            concurrent_requests,
        }
    }
}
/// Input validation settings for MCP tool calls.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpValidationConfig {
/// Whether schema validation is enabled (defaults to `true`).
#[serde(default = "default_schema_validation_enabled")]
pub schema_validation_enabled: bool,
/// Whether path traversal protection is enabled (defaults to `true`).
#[serde(default = "default_path_traversal_protection_enabled")]
pub path_traversal_protection: bool,
/// Maximum argument size (defaults to 1024 * 1024); presumably measured in
/// bytes — TODO confirm with the enforcing code.
#[serde(default = "default_max_argument_size")]
pub max_argument_size: u32,
}
impl Default for McpValidationConfig {
    /// Uses the same helper functions as the field-level serde defaults.
    fn default() -> Self {
        let schema_validation_enabled = default_schema_validation_enabled();
        let path_traversal_protection = default_path_traversal_protection_enabled();
        let max_argument_size = default_max_argument_size();
        Self {
            schema_validation_enabled,
            path_traversal_protection,
            max_argument_size,
        }
    }
}
/// Serde/`Default` value for `EnhancedMcpSecurityConfig::auth_enabled`.
const fn default_auth_enabled() -> bool { false }

/// Serde/`Default` value for `McpRateLimitConfig::requests_per_minute`.
const fn default_requests_per_minute() -> u32 { 100 }

/// Serde/`Default` value for `McpRateLimitConfig::concurrent_requests`.
const fn default_concurrent_requests() -> u32 { 10 }

/// Serde/`Default` value for `McpValidationConfig::schema_validation_enabled`.
const fn default_schema_validation_enabled() -> bool { true }

/// Serde/`Default` value for `McpValidationConfig::path_traversal_protection`.
const fn default_path_traversal_protection_enabled() -> bool { true }

/// Serde/`Default` value for `McpValidationConfig::max_argument_size` (1 MiB).
const fn default_max_argument_size() -> u32 { 1 << 20 }
/// An MCP client configuration paired with the enhanced security settings
/// used when validating it (see `ValidatedMcpClientConfig::validate`).
#[derive(Debug, Clone)]
pub struct ValidatedMcpClientConfig {
/// The client configuration being validated.
pub original: crate::config::mcp::McpClientConfig,
/// Security settings consulted during validation.
pub security: EnhancedMcpSecurityConfig,
}
impl ValidatedMcpClientConfig {
    /// Largest accepted `startup_timeout_seconds` value (inclusive).
    const MAX_STARTUP_TIMEOUT_SECONDS: u64 = 300;
    /// Largest accepted `tool_timeout_seconds` value (inclusive).
    const MAX_TOOL_TIMEOUT_SECONDS: u64 = 3600;

    /// Wraps `original` together with the default [`EnhancedMcpSecurityConfig`].
    pub fn new(original: crate::config::mcp::McpClientConfig) -> Self {
        Self {
            original,
            security: EnhancedMcpSecurityConfig::default(),
        }
    }

    /// Checks the wrapped configuration and returns every problem found.
    ///
    /// Errors are reported in a stable order: server checks, then timeouts,
    /// then per-provider checks. Server checks (port, bind address, API key
    /// environment variable) only run when the MCP server is enabled. An
    /// empty vector means the configuration is valid.
    pub fn validate(&self) -> Vec<ValidationError> {
        let mut errors = Vec::new();
        self.validate_server(&mut errors);
        self.validate_timeouts(&mut errors);
        self.validate_providers(&mut errors);
        errors
    }

    /// Server-section checks; skipped entirely when the server is disabled.
    fn validate_server(&self, errors: &mut Vec<ValidationError>) {
        let server = &self.original.server;
        if !server.enabled {
            return;
        }
        if server.port == 0 {
            errors.push(ValidationError::InvalidPort(server.port.into()));
        }
        // A whitespace-only address is as unusable as an empty one.
        if server.bind_address.trim().is_empty() {
            errors.push(ValidationError::EmptyBindAddress);
        }
        // `Some("")` names no environment variable, so treat it like `None`.
        let api_key_env_missing = self
            .security
            .api_key_env
            .as_deref()
            .is_none_or(|name| name.trim().is_empty());
        if self.security.auth_enabled && api_key_env_missing {
            errors.push(ValidationError::MissingApiKeyEnv);
        }
    }

    /// Upper-bound checks for the optional startup and tool timeouts.
    fn validate_timeouts(&self, errors: &mut Vec<ValidationError>) {
        if let Some(startup_timeout) = self.original.startup_timeout_seconds
            && startup_timeout > Self::MAX_STARTUP_TIMEOUT_SECONDS
        {
            errors.push(ValidationError::InvalidStartupTimeout(startup_timeout));
        }
        if let Some(tool_timeout) = self.original.tool_timeout_seconds
            && tool_timeout > Self::MAX_TOOL_TIMEOUT_SECONDS
        {
            errors.push(ValidationError::InvalidToolTimeout(tool_timeout));
        }
    }

    /// Per-provider checks, applied to every listed provider.
    fn validate_providers(&self, errors: &mut Vec<ValidationError>) {
        for provider in &self.original.providers {
            if provider.name.trim().is_empty() {
                errors.push(ValidationError::EmptyProviderName);
            }
            if provider.max_concurrent_requests == 0 {
                errors.push(ValidationError::InvalidMaxConcurrentRequests(
                    provider.name.clone(),
                    provider.max_concurrent_requests,
                ));
            }
        }
    }

    /// Returns `true` when [`Self::validate`] reports no errors.
    pub fn is_valid(&self) -> bool {
        self.validate().is_empty()
    }

    /// Logs each validation error at `warn` level, or a `debug` message when
    /// the configuration is valid.
    pub fn log_warnings(&self) {
        let errors = self.validate();
        if errors.is_empty() {
            debug!("MCP configuration validation passed");
            return;
        }
        warn!("MCP configuration validation issues found:");
        for error in &errors {
            warn!(" - {}", error);
        }
    }
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum ValidationError {
#[error("Invalid server port: {0}")]
InvalidPort(u64),
#[error("Server bind address cannot be empty")]
EmptyBindAddress,
#[error("API key environment variable must be set when auth is enabled")]
MissingApiKeyEnv,
#[error("Startup timeout cannot exceed 300 seconds: {0}")]
InvalidStartupTimeout(u64),
#[error("Tool timeout cannot exceed 3600 seconds: {0}")]
InvalidToolTimeout(u64),
#[error("MCP provider name cannot be empty")]
EmptyProviderName,
#[error("Max concurrent requests must be greater than 0 for provider '{0}': {1}")]
InvalidMaxConcurrentRequests(String, usize),
}
/// Per-tool configuration entry.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EnhancedMcpToolConfig {
/// Tool name (required).
pub name: String,
/// Whether the tool is enabled (defaults to `true`).
#[serde(default = "default_tool_enabled")]
pub enabled: bool,
/// Optional description of the tool.
pub description: Option<String>,
/// Optional tool-specific rate limit; presumably overrides the global
/// security settings — TODO confirm with the consuming code.
#[serde(default)]
pub rate_limit: Option<McpRateLimitConfig>,
/// Optional tool-specific validation settings; presumably overrides the
/// global security settings — TODO confirm with the consuming code.
#[serde(default)]
pub validation: Option<McpValidationConfig>,
}
/// Serde default for `EnhancedMcpToolConfig::enabled`: tools start enabled.
const fn default_tool_enabled() -> bool { true }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::mcp::{
        McpClientConfig, McpProviderConfig, McpServerConfig, McpStdioServerConfig,
        McpTransportConfig,
    };
    use hashbrown::HashMap;

    /// A configuration that passes every validation check.
    fn create_test_config() -> McpClientConfig {
        McpClientConfig {
            enabled: true,
            ui: Default::default(),
            providers: vec![McpProviderConfig {
                name: "test_provider".to_owned(),
                transport: McpTransportConfig::Stdio(McpStdioServerConfig {
                    command: "test_command".to_owned(),
                    args: vec![],
                    working_directory: None,
                }),
                env: HashMap::new(),
                enabled: true,
                max_concurrent_requests: 5,
                startup_timeout_ms: None,
            }],
            server: McpServerConfig {
                enabled: true,
                bind_address: "127.0.0.1".to_owned(),
                port: 3000,
                transport: crate::config::mcp::McpServerTransport::Sse,
                name: "test_server".to_owned(),
                version: "1.0.0".to_owned(),
                exposed_tools: vec![],
            },
            allowlist: Default::default(),
            requirements: Default::default(),
            max_concurrent_connections: 10,
            request_timeout_seconds: 30,
            retry_attempts: 3,
            startup_timeout_seconds: Some(60),
            tool_timeout_seconds: Some(300),
            experimental_use_rmcp_client: false,
            connection_pooling_enabled: true,
            tool_cache_capacity: 128,
            connection_timeout_seconds: 30,
            security: Default::default(),
        }
    }

    #[test]
    fn test_validated_config_creation() {
        let original = create_test_config();
        let validated = ValidatedMcpClientConfig::new(original);
        assert!(validated.is_valid());
    }

    #[test]
    fn test_invalid_port_validation() {
        let mut original = create_test_config();
        original.server.port = 0;
        let errors = ValidatedMcpClientConfig::new(original).validate();
        assert!(matches!(errors.as_slice(), [ValidationError::InvalidPort(0)]));
    }

    #[test]
    fn test_max_port_is_valid() {
        let mut original = create_test_config();
        original.server.port = 65535;
        let validated = ValidatedMcpClientConfig::new(original);
        assert!(validated.is_valid());
    }

    #[test]
    fn test_disabled_server_skips_server_checks() {
        let mut original = create_test_config();
        original.server.enabled = false;
        original.server.port = 0;
        original.server.bind_address = String::new();
        let mut validated = ValidatedMcpClientConfig::new(original);
        validated.security.auth_enabled = true;
        assert!(validated.is_valid());
    }

    #[test]
    fn test_empty_bind_address_validation() {
        let mut original = create_test_config();
        original.server.bind_address = String::new();
        let errors = ValidatedMcpClientConfig::new(original).validate();
        assert!(matches!(errors.as_slice(), [ValidationError::EmptyBindAddress]));
    }

    #[test]
    fn test_missing_api_key_env_validation() {
        let mut validated = ValidatedMcpClientConfig::new(create_test_config());
        validated.security.auth_enabled = true;
        let errors = validated.validate();
        assert!(matches!(errors.as_slice(), [ValidationError::MissingApiKeyEnv]));

        validated.security.api_key_env = Some("MCP_API_KEY".to_owned());
        assert!(validated.is_valid());
    }

    #[test]
    fn test_timeout_validation() {
        let mut original = create_test_config();
        original.startup_timeout_seconds = Some(400);
        let errors = ValidatedMcpClientConfig::new(original).validate();
        assert!(matches!(
            errors.as_slice(),
            [ValidationError::InvalidStartupTimeout(400)]
        ));
    }

    #[test]
    fn test_tool_timeout_validation() {
        let mut original = create_test_config();
        original.tool_timeout_seconds = Some(3601);
        let errors = ValidatedMcpClientConfig::new(original).validate();
        assert!(matches!(
            errors.as_slice(),
            [ValidationError::InvalidToolTimeout(3601)]
        ));
    }

    #[test]
    fn test_timeout_limits_are_inclusive() {
        let mut original = create_test_config();
        original.startup_timeout_seconds = Some(300);
        original.tool_timeout_seconds = Some(3600);
        assert!(ValidatedMcpClientConfig::new(original).is_valid());
    }

    #[test]
    fn test_empty_provider_name_validation() {
        let mut original = create_test_config();
        original.providers[0].name = String::new();
        let errors = ValidatedMcpClientConfig::new(original).validate();
        assert!(matches!(errors.as_slice(), [ValidationError::EmptyProviderName]));
    }

    #[test]
    fn test_zero_concurrent_requests_validation() {
        let mut original = create_test_config();
        original.providers[0].max_concurrent_requests = 0;
        let errors = ValidatedMcpClientConfig::new(original).validate();
        assert!(matches!(
            errors.as_slice(),
            [ValidationError::InvalidMaxConcurrentRequests(name, 0)] if name == "test_provider"
        ));
    }
}