mcp_guard_core/config/
mod.rs

1// Copyright (c) 2025 Austin Green
2// SPDX-License-Identifier: AGPL-3.0
3//
4// This file is part of MCP-Guard.
5//
6// MCP-Guard is free software: you can redistribute it and/or modify
7// it under the terms of the GNU Affero General Public License as published by
8// the Free Software Foundation, either version 3 of the License, or
9// (at your option) any later version.
10//
11// MCP-Guard is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14// GNU Affero General Public License for more details.
15//
16// You should have received a copy of the GNU Affero General Public License
17// along with MCP-Guard. If not, see <https://www.gnu.org/licenses/>.
18//! Configuration types and parsing for mcp-guard
19//!
20//! This module provides strongly-typed configuration for all mcp-guard features:
21//! - Server settings (host, port, TLS)
22//! - Authentication (API keys, JWT, OAuth 2.1, mTLS)
23//! - Rate limiting (per-identity token bucket)
24//! - Audit logging (file, stdout, HTTP export)
25//! - Tracing (OpenTelemetry/OTLP)
26//! - Upstream routing (single server or multi-server)
27//!
28//! Configuration can be loaded from TOML or YAML files via [`Config::from_file`].
29
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::path::PathBuf;
33
34// ============================================================================
35// Error Types
36// ============================================================================
37
/// Errors produced while loading or validating a [`Config`].
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The configuration file could not be read (wraps the underlying I/O error).
    #[error("Failed to read config file: {0}")]
    Read(#[from] std::io::Error),

    /// The file contents could not be deserialized (YAML or TOML parser message).
    #[error("Failed to parse config: {0}")]
    Parse(String),

    /// The configuration parsed but violates a semantic rule checked by `Config::validate`.
    #[error("Validation error: {0}")]
    Validation(String),
}
50
51// ============================================================================
52// Core Configuration
53// ============================================================================
54
/// Main configuration struct
///
/// Every section except `upstream` carries `#[serde(default)]` and falls back to
/// its `Default` implementation when omitted from the file. `upstream` has no
/// default, so a config file without it fails to parse.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Server configuration (bind address, port, body size limit, CORS, TLS)
    #[serde(default)]
    pub server: ServerConfig,

    /// Authentication configuration (API keys, JWT, OAuth 2.1, mTLS)
    #[serde(default)]
    pub auth: AuthConfig,

    /// Rate limiting configuration
    #[serde(default)]
    pub rate_limit: RateLimitConfig,

    /// Audit logging configuration
    #[serde(default)]
    pub audit: AuditConfig,

    /// OpenTelemetry tracing configuration
    #[serde(default)]
    pub tracing: TracingConfig,

    /// Upstream MCP server configuration (required)
    pub upstream: UpstreamConfig,
}
81
/// Server configuration (the `server` section of the config file)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Host to bind to (default: "127.0.0.1")
    #[serde(default = "default_host")]
    pub host: String,

    /// Port to listen on (default: 3000). `Config::validate` rejects 0.
    #[serde(default = "default_port")]
    pub port: u16,

    /// Maximum request body size in bytes (default: 1MB = 1024 * 1024)
    /// Requests exceeding this size will receive 413 Payload Too Large
    #[serde(default = "default_max_request_size")]
    pub max_request_size: usize,

    /// CORS configuration (disabled by default)
    #[serde(default)]
    pub cors: CorsConfig,

    /// TLS settings; `None` (the default) means no TLS section is configured
    #[serde(default)]
    pub tls: Option<TlsConfig>,
}
106
107impl Default for ServerConfig {
108    fn default() -> Self {
109        Self {
110            host: default_host(),
111            port: default_port(),
112            max_request_size: default_max_request_size(),
113            cors: CorsConfig::default(),
114            tls: None,
115        }
116    }
117}
118
/// Serde default for `server.host`: bind to loopback only.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `server.port`.
fn default_port() -> u16 {
    3_000
}

/// Serde default for `server.max_request_size`: 1 MiB (2^20 bytes).
fn default_max_request_size() -> usize {
    1 << 20
}
130
/// CORS configuration (`server.cors`)
///
/// All fields are optional; an omitted section yields `CorsConfig::default()`
/// (disabled, no origins, GET/POST/OPTIONS, Authorization/Content-Type, 3600s).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsConfig {
    /// Enable CORS (default: false for API-only use)
    #[serde(default)]
    pub enabled: bool,

    /// Allowed origins (default: none - same-origin only when enabled)
    /// Use ["*"] for permissive mode (not recommended for production)
    #[serde(default)]
    pub allowed_origins: Vec<String>,

    /// Allowed methods (default: GET, POST, OPTIONS)
    #[serde(default = "default_cors_methods")]
    pub allowed_methods: Vec<String>,

    /// Allowed headers (default: Authorization, Content-Type)
    #[serde(default = "default_cors_headers")]
    pub allowed_headers: Vec<String>,

    /// Max age for preflight cache in seconds (default: 3600)
    #[serde(default = "default_cors_max_age")]
    pub max_age: u64,
}
155
156impl Default for CorsConfig {
157    fn default() -> Self {
158        Self {
159            enabled: false,
160            allowed_origins: vec![],
161            allowed_methods: default_cors_methods(),
162            allowed_headers: default_cors_headers(),
163            max_age: default_cors_max_age(),
164        }
165    }
166}
167
/// Serde default for `cors.allowed_methods`.
fn default_cors_methods() -> Vec<String> {
    ["GET", "POST", "OPTIONS"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}

/// Serde default for `cors.allowed_headers`.
fn default_cors_headers() -> Vec<String> {
    ["Authorization", "Content-Type"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}

/// Serde default for `cors.max_age`: one hour, in seconds.
fn default_cors_max_age() -> u64 {
    60 * 60
}
179
/// TLS configuration (`server.tls`)
///
/// `cert_path` and `key_path` have no serde default, so both are required
/// whenever this section is present; `client_ca_path` is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// Path to server certificate (PEM format)
    pub cert_path: PathBuf,
    /// Path to server private key (PEM format)
    pub key_path: PathBuf,
    /// Path to CA certificate for client certificate validation (mTLS)
    /// If set, client certificates will be required and validated against this CA
    pub client_ca_path: Option<PathBuf>,
}
191
/// mTLS authentication configuration (`auth.mtls`)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MtlsConfig {
    /// Whether to enable mTLS authentication (default: false)
    #[serde(default)]
    pub enabled: bool,
    /// Certificate field the user ID is taken from (CN or SAN)
    /// Default: "cn" (Common Name); see `MtlsIdentitySource` for accepted values
    #[serde(default = "default_mtls_identity_source")]
    pub identity_source: MtlsIdentitySource,
    /// Allowed tools for mTLS-authenticated identities (empty means all)
    #[serde(default)]
    pub allowed_tools: Vec<String>,
    /// Custom rate limit for mTLS-authenticated identities
    /// NOTE(review): presumably requests per second, like `rate_limit.requests_per_second` — confirm
    #[serde(default)]
    pub rate_limit: Option<u32>,
    /// Trusted proxy IP addresses/CIDR ranges that are allowed to set mTLS headers
    /// SECURITY: If empty, mTLS header authentication is DISABLED to prevent header spoofing
    /// You MUST configure this when using mTLS with a reverse proxy.
    /// `Config::validate` returns an error when `enabled` is true and this list is empty.
    /// Example: ["10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.1"]
    #[serde(default)]
    pub trusted_proxy_ips: Vec<String>,
}
215
216impl Default for MtlsConfig {
217    fn default() -> Self {
218        Self {
219            enabled: false,
220            identity_source: default_mtls_identity_source(),
221            allowed_tools: vec![],
222            rate_limit: None,
223            trusted_proxy_ips: vec![],
224        }
225    }
226}
227
/// Source for extracting identity from client certificate
///
/// Because of `rename_all = "lowercase"`, the accepted config strings are the
/// variant names lowercased with no separator: `"cn"`, `"sandns"`, `"sanemail"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MtlsIdentitySource {
    /// Extract from Common Name (CN). Config value: "cn"
    Cn,
    /// Extract from Subject Alternative Name (SAN) - DNS name. Config value: "sandns"
    SanDns,
    /// Extract from Subject Alternative Name (SAN) - Email. Config value: "sanemail"
    SanEmail,
}

/// Serde default for `MtlsConfig::identity_source`: the Common Name.
fn default_mtls_identity_source() -> MtlsIdentitySource {
    MtlsIdentitySource::Cn
}
243
244// ============================================================================
245// Authentication Configuration
246// ============================================================================
247
/// Authentication configuration (`auth` section)
///
/// Each method is configured independently. The derived `Default` (used when
/// the section is omitted) has no API keys and leaves JWT, OAuth and mTLS unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuthConfig {
    /// API key authentication
    #[serde(default)]
    pub api_keys: Vec<ApiKeyConfig>,

    /// JWT authentication
    #[serde(default)]
    pub jwt: Option<JwtConfig>,

    /// OAuth 2.1 configuration
    #[serde(default)]
    pub oauth: Option<OAuthConfig>,

    /// mTLS client certificate authentication
    #[serde(default)]
    pub mtls: Option<MtlsConfig>,
}
267
/// API key configuration (one entry of `auth.api_keys`)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyConfig {
    /// User/service identifier associated with this key
    pub id: String,

    /// The hashed API key (the plaintext key is not part of the config)
    /// NOTE(review): hash algorithm/encoding is defined by the auth module — confirm there
    pub key_hash: String,

    /// Allowed tools (empty means all)
    #[serde(default)]
    pub allowed_tools: Vec<String>,

    /// Custom rate limit (overrides global)
    /// NOTE(review): presumably requests per second, like `rate_limit.requests_per_second` — confirm
    #[serde(default)]
    pub rate_limit: Option<u32>,
}
285
/// JWT authentication mode
///
/// Internally tagged on the `mode` key, whose value is `"simple"` or `"jwks"`.
/// `JwtConfig` flattens this enum, so `mode` and the variant's fields appear
/// directly inside the `jwt` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "mode", rename_all = "lowercase")]
pub enum JwtMode {
    /// Simple mode: HS256 with local secret
    Simple {
        /// Shared secret for HS256 signing (min 32 characters recommended)
        secret: String,
    },
    /// JWKS mode: RS256/ES256 with remote JWKS endpoint
    Jwks {
        /// JWKS endpoint URL (must be HTTP(S); HTTPS-only in release builds)
        jwks_url: String,
        /// Allowed algorithms (default: ["RS256", "ES256"])
        #[serde(default = "default_jwks_algorithms")]
        algorithms: Vec<String>,
        /// JWKS cache duration in seconds (default: 3600 = 1 hour)
        #[serde(default = "default_cache_duration")]
        cache_duration_secs: u64,
    },
}
307
/// JWT configuration supporting both simple and JWKS modes (`auth.jwt`)
///
/// `issuer` and `audience` have no serde default and must be present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtConfig {
    /// JWT validation mode (simple or jwks); flattened, so its `mode` tag and
    /// fields sit at the same level as the fields below
    #[serde(flatten)]
    pub mode: JwtMode,

    /// Expected issuer (iss claim) - required for validation
    pub issuer: String,

    /// Expected audience (aud claim) - required for validation
    pub audience: String,

    /// Claim to extract user ID from (default: "sub")
    #[serde(default = "default_user_id_claim")]
    pub user_id_claim: String,

    /// Claim to extract scopes from (default: "scope")
    #[serde(default = "default_scopes_claim")]
    pub scopes_claim: String,

    /// Mapping from scopes to allowed tools
    /// e.g., {"read:files": ["read_file", "list_files"], "admin": ["*"]}
    #[serde(default)]
    pub scope_tool_mapping: HashMap<String, Vec<String>>,

    /// Leeway in seconds for exp/nbf validation (default: 0)
    #[serde(default)]
    pub leeway_secs: u64,
}
338
/// Serde default for `JwtMode::Jwks::algorithms`.
fn default_jwks_algorithms() -> Vec<String> {
    vec![String::from("RS256"), String::from("ES256")]
}

/// Serde default for `JwtMode::Jwks::cache_duration_secs`: one hour.
fn default_cache_duration() -> u64 {
    60 * 60
}

/// Serde default for the user-ID claim name (shared by JWT and OAuth).
fn default_user_id_claim() -> String {
    String::from("sub")
}

/// Serde default for `JwtConfig::scopes_claim`.
fn default_scopes_claim() -> String {
    String::from("scope")
}
354
/// OAuth 2.1 provider type
///
/// Config values are the lowercased variant names: `"github"`, `"google"`,
/// `"okta"`, `"custom"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OAuthProvider {
    /// GitHub OAuth
    GitHub,
    /// Google OAuth
    Google,
    /// Okta OAuth
    Okta,
    /// Custom OAuth provider
    Custom,
}
368
/// OAuth 2.1 configuration (`auth.oauth`)
///
/// `provider` and `client_id` are required; every other field is optional or
/// has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
    /// OAuth provider type
    pub provider: OAuthProvider,

    /// Client ID
    pub client_id: String,

    /// Client secret (for confidential clients)
    pub client_secret: Option<String>,

    /// Authorization endpoint URL (required for custom, auto-derived for known providers)
    pub authorization_url: Option<String>,

    /// Token endpoint URL (required for custom, auto-derived for known providers)
    pub token_url: Option<String>,

    /// Token introspection endpoint URL (for validating opaque tokens)
    pub introspection_url: Option<String>,

    /// User info endpoint URL (fallback if no introspection)
    pub userinfo_url: Option<String>,

    /// Redirect URI for authorization code flow
    /// (default: "http://localhost:3000/oauth/callback"). `Config::validate` requires an
    /// http:// or https:// URI and, in release builds, warns when it is plain HTTP.
    #[serde(default = "default_redirect_uri")]
    pub redirect_uri: String,

    /// OAuth scopes to request (default: ["openid", "profile"])
    #[serde(default = "default_oauth_scopes")]
    pub scopes: Vec<String>,

    /// Claim to extract user ID from (default: "sub")
    #[serde(default = "default_user_id_claim")]
    pub user_id_claim: String,

    /// Mapping from scopes to allowed tools (same as JWT)
    #[serde(default)]
    pub scope_tool_mapping: HashMap<String, Vec<String>>,

    /// Token cache TTL in seconds (default: 300 = 5 minutes)
    ///
    /// SECURITY NOTE: Revoked tokens remain valid in the cache until they expire.
    /// Lower values provide faster revocation detection but increase OAuth provider load.
    /// Set to 0 to disable caching (not recommended for production).
    #[serde(default = "default_token_cache_ttl")]
    pub token_cache_ttl_secs: u64,
}
417
/// Serde default for `OAuthConfig::token_cache_ttl_secs`: five minutes.
fn default_token_cache_ttl() -> u64 {
    5 * 60
}

/// Serde default for `OAuthConfig::redirect_uri` (local development callback).
fn default_redirect_uri() -> String {
    String::from("http://localhost:3000/oauth/callback")
}

/// Serde default for `OAuthConfig::scopes`.
fn default_oauth_scopes() -> Vec<String> {
    ["openid", "profile"]
        .iter()
        .map(|scope| (*scope).to_owned())
        .collect()
}
429
430// ============================================================================
431// Rate Limiting Configuration
432// ============================================================================
433
/// Rate limiting configuration (`rate_limit` section)
///
/// Limits are per identity (token bucket). Defaults: enabled, 25 requests per
/// second, burst of 10, no per-tool overrides.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Enable rate limiting (default: true)
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Requests per second (default: 25; must be > 0 when enabled)
    #[serde(default = "default_rps")]
    pub requests_per_second: u32,

    /// Burst size (default: 10; must be > 0 when enabled)
    #[serde(default = "default_burst")]
    pub burst_size: u32,

    /// Per-tool rate limits (optional)
    /// Apply stricter limits to specific tools matched by glob patterns
    #[serde(default)]
    pub tool_limits: Vec<ToolRateLimitConfig>,
}
454
/// Per-tool rate limit configuration
///
/// Allows applying stricter rate limits to expensive or dangerous operations
/// using glob patterns to match tool names. `tool_pattern` and
/// `requests_per_second` are required; `burst_size` defaults to 5.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolRateLimitConfig {
    /// Glob pattern to match tool names (e.g., "execute_*", "write_*", "delete_*")
    pub tool_pattern: String,

    /// Maximum requests per second for matched tools
    pub requests_per_second: u32,

    /// Burst size for matched tools (default: 5)
    #[serde(default = "default_tool_burst")]
    pub burst_size: u32,
}
471
/// Serde default for `ToolRateLimitConfig::burst_size`; deliberately smaller
/// than the global default burst because per-tool limits target costly tools.
fn default_tool_burst() -> u32 {
    let conservative_burst: u32 = 5;
    conservative_burst
}
475
476impl Default for RateLimitConfig {
477    fn default() -> Self {
478        Self {
479            enabled: true,
480            requests_per_second: default_rps(),
481            burst_size: default_burst(),
482            tool_limits: Vec::new(),
483        }
484    }
485}
486
/// Serde helper for boolean fields that should default to `true`.
fn default_true() -> bool {
    true
}

/// Serde default for `rate_limit.requests_per_second` (per identity).
fn default_rps() -> u32 {
    let per_identity_rps: u32 = 25;
    per_identity_rps
}

/// Serde default for `rate_limit.burst_size`.
fn default_burst() -> u32 {
    let burst: u32 = 10;
    burst
}
498
499// ============================================================================
500// Audit Configuration
501// ============================================================================
502
/// Audit logging configuration (`audit` section)
///
/// When `export_url` is set, `Config::validate` additionally requires it to be
/// an HTTP(S) URL, `export_batch_size` to be in 1..=10000, and
/// `export_interval_secs` to be greater than 0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditConfig {
    /// Enable audit logging (default: true)
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Log file path
    #[serde(default)]
    pub file: Option<PathBuf>,

    /// Log to stdout (default: false)
    #[serde(default)]
    pub stdout: bool,

    /// HTTP export URL for SIEM integration (e.g., `<https://siem.example.com/logs>`)
    /// If set, audit logs will be batched and sent to this endpoint
    #[serde(default)]
    pub export_url: Option<String>,

    /// Number of logs to batch before sending (default: 100)
    #[serde(default = "default_export_batch_size")]
    pub export_batch_size: usize,

    /// Interval in seconds to flush logs even if batch is not full (default: 30)
    #[serde(default = "default_export_interval_secs")]
    pub export_interval_secs: u64,

    /// Additional headers to include in export requests (e.g., for authentication)
    #[serde(default)]
    pub export_headers: HashMap<String, String>,

    /// Secret redaction rules to prevent sensitive data from being logged
    /// Each rule defines a regex pattern and replacement text
    #[serde(default)]
    pub redaction_rules: Vec<RedactionRule>,

    /// Log rotation configuration
    #[serde(default)]
    pub rotation: Option<LogRotationConfig>,
}
544
/// Secret redaction rule for audit logs
///
/// Matches sensitive data using regex patterns and replaces with safe text.
/// Patterns are applied in order, so more specific patterns should come first.
/// `name` and `pattern` are required.
/// NOTE(review): `pattern` is not compiled/validated by `Config::validate` in this
/// module — confirm invalid regexes are rejected where the rules are loaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedactionRule {
    /// Rule name for logging/debugging (e.g., "bearer_tokens", "api_keys")
    pub name: String,

    /// Regex pattern to match sensitive data
    /// Uses Rust regex syntax: <https://docs.rs/regex>
    pub pattern: String,

    /// Replacement text (default: `"[REDACTED]"`)
    #[serde(default = "default_redaction_replacement")]
    pub replacement: String,
}
562
/// Serde default for `RedactionRule::replacement`.
fn default_redaction_replacement() -> String {
    String::from("[REDACTED]")
}
566
/// Log rotation configuration (`audit.rotation`)
///
/// Prevents audit log files from growing indefinitely by rotating
/// based on size and/or age. Every field has a serde default; note that
/// `enabled` defaults to false even when the section is present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogRotationConfig {
    /// Enable log rotation (default: false)
    #[serde(default)]
    pub enabled: bool,

    /// Maximum file size in bytes before rotation (e.g., 104857600 = 100MB)
    #[serde(default)]
    pub max_size_bytes: Option<u64>,

    /// Maximum age in seconds before rotation (e.g., 86400 = 1 day)
    #[serde(default)]
    pub max_age_secs: Option<u64>,

    /// Number of backup files to keep (default: 10)
    #[serde(default = "default_max_backups")]
    pub max_backups: usize,

    /// Compress rotated files with gzip
    #[serde(default)]
    pub compress: bool,
}
593
/// Serde default for `LogRotationConfig::max_backups`.
fn default_max_backups() -> usize {
    10_usize
}

/// Serde default for `AuditConfig::export_batch_size`.
fn default_export_batch_size() -> usize {
    100_usize
}

/// Serde default for `AuditConfig::export_interval_secs`.
fn default_export_interval_secs() -> u64 {
    30_u64
}
605
606impl Default for AuditConfig {
607    fn default() -> Self {
608        Self {
609            enabled: true,
610            file: None,
611            // SECURITY: Default to false to prevent accidental PII exposure in logs.
612            // Users should explicitly configure their log destination.
613            stdout: false,
614            export_url: None,
615            export_batch_size: default_export_batch_size(),
616            export_interval_secs: default_export_interval_secs(),
617            export_headers: HashMap::new(),
618            redaction_rules: Vec::new(),
619            rotation: None,
620        }
621    }
622}
623
624// ============================================================================
625// Tracing Configuration
626// ============================================================================
627
628/// OpenTelemetry tracing configuration
629#[derive(Debug, Clone, Serialize, Deserialize)]
630pub struct TracingConfig {
631    /// Enable OpenTelemetry distributed tracing
632    #[serde(default)]
633    pub enabled: bool,
634
635    /// Service name for traces (default: "mcp-guard")
636    #[serde(default = "default_service_name")]
637    pub service_name: String,
638
639    /// OTLP exporter endpoint (e.g., "http://localhost:4317" for gRPC)
640    /// If not set, traces are only logged locally
641    pub otlp_endpoint: Option<String>,
642
643    /// Sample rate (0.0 to 1.0, default: 1.0 = sample all)
644    #[serde(default = "default_sample_rate")]
645    pub sample_rate: f64,
646
647    /// Propagate W3C trace context headers (traceparent, tracestate)
648    #[serde(default = "default_true")]
649    pub propagate_context: bool,
650}
651
652impl Default for TracingConfig {
653    fn default() -> Self {
654        Self {
655            enabled: false,
656            service_name: default_service_name(),
657            otlp_endpoint: None,
658            sample_rate: default_sample_rate(),
659            propagate_context: true,
660        }
661    }
662}
663
/// Serde default for `TracingConfig::service_name`.
fn default_service_name() -> String {
    String::from("mcp-guard")
}

/// Serde default for `TracingConfig::sample_rate`: sample 10% of traces.
///
/// Kept low to limit overhead and exporter cost in production; set to 1.0
/// for development or debugging.
fn default_sample_rate() -> f64 {
    0.1_f64
}
673
674// ============================================================================
675// Upstream Configuration
676// ============================================================================
677
/// Upstream MCP server configuration (`upstream` section)
///
/// Single-server mode uses `transport` together with `command`/`args` (stdio)
/// or `url` (http/sse). When `servers` is non-empty, multi-server routing is
/// used instead and `Config::validate` does not check the single-server
/// `command`/`url` fields. `transport` has no serde default, so it must be
/// present even in multi-server mode.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpstreamConfig {
    /// Transport type (used for single-server mode)
    pub transport: TransportType,

    /// Command to run (for stdio transport)
    pub command: Option<String>,

    /// Arguments for the command
    #[serde(default)]
    pub args: Vec<String>,

    /// URL for HTTP transport
    pub url: Option<String>,

    /// Multiple server routes (if configured, path-based routing is enabled)
    /// Requests are routed based on path prefix matching
    #[serde(default)]
    pub servers: Vec<ServerRouteConfig>,
}
699
/// Server route configuration for multi-server routing (one entry of `upstream.servers`)
///
/// `name`, `path_prefix` and `transport` are required; each route is checked by
/// its own `validate` method during `Config::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerRouteConfig {
    /// Unique name for this server
    pub name: String,

    /// Path prefix to match (e.g., "/github", "/filesystem")
    /// Requests with this prefix are routed to this server
    pub path_prefix: String,

    /// Transport type for this server
    pub transport: TransportType,

    /// Command to run (for stdio transport)
    pub command: Option<String>,

    /// Arguments for the command
    #[serde(default)]
    pub args: Vec<String>,

    /// URL for HTTP/SSE transport
    pub url: Option<String>,

    /// Strip the path prefix when forwarding requests (default: false)
    /// If true, "/github/repos" becomes "/repos" when sent to the server
    #[serde(default)]
    pub strip_prefix: bool,
}
728
/// Transport type for upstream connection
///
/// Config values are the lowercased variant names: `"stdio"`, `"http"`, `"sse"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TransportType {
    /// Local process launched from `command`/`args`, spoken to over stdio
    Stdio,
    /// HTTP upstream reached at `url`
    Http,
    /// Server-Sent Events upstream reached at `url`
    Sse,
}
737
738// ============================================================================
739// Implementation
740// ============================================================================
741
742impl Config {
743    /// Load configuration from a file
744    pub fn from_file(path: &PathBuf) -> Result<Self, ConfigError> {
745        let content = std::fs::read_to_string(path)?;
746
747        let config: Config = if path
748            .extension()
749            .map(|e| e == "yaml" || e == "yml")
750            .unwrap_or(false)
751        {
752            serde_yaml::from_str(&content).map_err(|e| ConfigError::Parse(e.to_string()))?
753        } else {
754            toml::from_str(&content).map_err(|e| ConfigError::Parse(e.to_string()))?
755        };
756
757        config.validate()?;
758        Ok(config)
759    }
760
761    /// Validate the configuration
762    pub fn validate(&self) -> Result<(), ConfigError> {
763        // First validate tier/license requirements
764        crate::tier::validate_tier(self)?;
765
766        // Then validate individual sections
767        self.validate_server()?;
768        self.validate_rate_limit()?;
769        self.validate_jwt()?;
770        self.validate_oauth()?;
771        self.validate_audit()?;
772        self.validate_mtls()?;
773        self.validate_tracing()?;
774        self.validate_upstream()
775    }
776
777    // ========================================================================
778    // Validation Helpers
779    // ========================================================================
780
781    /// Validate server configuration.
782    fn validate_server(&self) -> Result<(), ConfigError> {
783        if self.server.port == 0 {
784            return Err(ConfigError::Validation(
785                "server.port must be between 1 and 65535".to_string(),
786            ));
787        }
788        Ok(())
789    }
790
791    /// Validate rate limit configuration.
792    fn validate_rate_limit(&self) -> Result<(), ConfigError> {
793        if self.rate_limit.enabled {
794            if self.rate_limit.requests_per_second == 0 {
795                return Err(ConfigError::Validation(
796                    "rate_limit.requests_per_second must be greater than 0".to_string(),
797                ));
798            }
799            if self.rate_limit.burst_size == 0 {
800                return Err(ConfigError::Validation(
801                    "rate_limit.burst_size must be greater than 0".to_string(),
802                ));
803            }
804        }
805        Ok(())
806    }
807
808    /// Validate JWT configuration.
809    fn validate_jwt(&self) -> Result<(), ConfigError> {
810        if let Some(ref jwt_config) = self.auth.jwt {
811            if let JwtMode::Jwks { ref jwks_url, .. } = jwt_config.mode {
812                // JWKS URL must use HTTPS in production (allow HTTP in debug builds for local testing)
813                #[cfg(not(debug_assertions))]
814                if !jwks_url.starts_with("https://") {
815                    return Err(ConfigError::Validation(
816                        "jwt.jwks_url must use HTTPS in production".to_string(),
817                    ));
818                }
819                // Validate URL format
820                if !jwks_url.starts_with("http://") && !jwks_url.starts_with("https://") {
821                    return Err(ConfigError::Validation(
822                        "jwt.jwks_url must be a valid HTTP(S) URL".to_string(),
823                    ));
824                }
825            }
826        }
827        Ok(())
828    }
829
830    /// Validate OAuth configuration.
831    fn validate_oauth(&self) -> Result<(), ConfigError> {
832        if let Some(ref oauth_config) = self.auth.oauth {
833            // Validate redirect_uri is a valid URL
834            if !oauth_config.redirect_uri.starts_with("http://")
835                && !oauth_config.redirect_uri.starts_with("https://")
836            {
837                return Err(ConfigError::Validation(
838                    "oauth.redirect_uri must be a valid HTTP(S) URL".to_string(),
839                ));
840            }
841            // SECURITY: Warn about HTTP redirect_uri in production (allow in debug for local testing)
842            #[cfg(not(debug_assertions))]
843            if oauth_config.redirect_uri.starts_with("http://") {
844                tracing::warn!(
845                    "SECURITY WARNING: oauth.redirect_uri uses HTTP instead of HTTPS. \
846                     This is insecure in production and may allow authorization code interception."
847                );
848            }
849        }
850        Ok(())
851    }
852
853    /// Validate audit configuration.
854    fn validate_audit(&self) -> Result<(), ConfigError> {
855        if let Some(ref export_url) = self.audit.export_url {
856            // Validate URL format
857            if !export_url.starts_with("http://") && !export_url.starts_with("https://") {
858                return Err(ConfigError::Validation(
859                    "audit.export_url must be a valid HTTP(S) URL".to_string(),
860                ));
861            }
862            // Validate batch size
863            if self.audit.export_batch_size == 0 {
864                return Err(ConfigError::Validation(
865                    "audit.export_batch_size must be greater than 0".to_string(),
866                ));
867            }
868            if self.audit.export_batch_size > 10000 {
869                return Err(ConfigError::Validation(
870                    "audit.export_batch_size must be less than or equal to 10000".to_string(),
871                ));
872            }
873            // Validate flush interval
874            if self.audit.export_interval_secs == 0 {
875                return Err(ConfigError::Validation(
876                    "audit.export_interval_secs must be greater than 0".to_string(),
877                ));
878            }
879        }
880        Ok(())
881    }
882
883    /// Validate mTLS configuration.
884    fn validate_mtls(&self) -> Result<(), ConfigError> {
885        if let Some(ref mtls_config) = self.auth.mtls {
886            if mtls_config.enabled && mtls_config.trusted_proxy_ips.is_empty() {
887                // SECURITY: mTLS without trusted proxy IPs allows header spoofing
888                return Err(ConfigError::Validation(
889                    "auth.mtls.trusted_proxy_ips must be configured when mTLS is enabled. \
890                     Without trusted proxy IPs, attackers could spoof client certificate headers."
891                        .to_string(),
892                ));
893            }
894        }
895        Ok(())
896    }
897
898    /// Validate tracing configuration.
899    fn validate_tracing(&self) -> Result<(), ConfigError> {
900        if self.tracing.enabled
901            && (self.tracing.sample_rate < 0.0 || self.tracing.sample_rate > 1.0)
902        {
903            return Err(ConfigError::Validation(
904                "tracing.sample_rate must be between 0.0 and 1.0".to_string(),
905            ));
906        }
907        Ok(())
908    }
909
910    /// Validate upstream configuration.
911    fn validate_upstream(&self) -> Result<(), ConfigError> {
912        // If multi-server routing is configured, validate each server
913        if !self.upstream.servers.is_empty() {
914            for server in &self.upstream.servers {
915                server.validate()?;
916            }
917            return Ok(());
918        }
919
920        // Single-server mode validation
921        match self.upstream.transport {
922            TransportType::Stdio => {
923                if self.upstream.command.is_none() {
924                    return Err(ConfigError::Validation(
925                        "stdio transport requires 'command' to be set".to_string(),
926                    ));
927                }
928            }
929            TransportType::Http | TransportType::Sse => {
930                if self.upstream.url.is_none() {
931                    return Err(ConfigError::Validation(
932                        "http/sse transport requires 'url' to be set".to_string(),
933                    ));
934                }
935            }
936        }
937
938        Ok(())
939    }
940
941    /// Check if multi-server routing is enabled
942    pub fn is_multi_server(&self) -> bool {
943        !self.upstream.servers.is_empty()
944    }
945
946    /// Check if the configuration uses any Pro tier features
947    ///
948    /// Returns true if ANY of these are configured:
949    /// - OAuth 2.1 authentication
950    /// - JWT JWKS mode
951    /// - HTTP or SSE transport
952    pub fn requires_pro_features(&self) -> bool {
953        // OAuth 2.1
954        if self.auth.oauth.is_some() {
955            return true;
956        }
957
958        // JWT JWKS mode
959        if let Some(ref jwt_config) = self.auth.jwt {
960            if matches!(jwt_config.mode, JwtMode::Jwks { .. }) {
961                return true;
962            }
963        }
964
965        // HTTP or SSE transport (single-server mode)
966        if self.upstream.servers.is_empty() {
967            match self.upstream.transport {
968                TransportType::Http | TransportType::Sse => return true,
969                TransportType::Stdio => {}
970            }
971        } else {
972            // Multi-server mode - check if any server uses HTTP/SSE
973            for server in &self.upstream.servers {
974                match server.transport {
975                    TransportType::Http | TransportType::Sse => return true,
976                    TransportType::Stdio => {}
977                }
978            }
979        }
980
981        false
982    }
983
984    /// Check if the configuration uses any Enterprise tier features
985    ///
986    /// Returns true if ANY of these are configured:
987    /// - mTLS client certificate authentication
988    /// - Multi-server routing
989    /// - SIEM audit log shipping
990    /// - OpenTelemetry tracing with OTLP export
991    /// - Per-tool rate limiting
992    pub fn requires_enterprise_features(&self) -> bool {
993        // mTLS authentication
994        if let Some(ref mtls_config) = self.auth.mtls {
995            if mtls_config.enabled {
996                return true;
997            }
998        }
999
1000        // Multi-server routing
1001        if !self.upstream.servers.is_empty() {
1002            return true;
1003        }
1004
1005        // SIEM audit log shipping
1006        if self.audit.export_url.is_some() {
1007            return true;
1008        }
1009
1010        // OpenTelemetry with OTLP export
1011        if self.tracing.enabled && self.tracing.otlp_endpoint.is_some() {
1012            return true;
1013        }
1014
1015        // Per-tool rate limiting
1016        if !self.rate_limit.tool_limits.is_empty() {
1017            return true;
1018        }
1019
1020        false
1021    }
1022}
1023
1024impl ServerRouteConfig {
1025    /// Validate the server route configuration
1026    pub fn validate(&self) -> Result<(), ConfigError> {
1027        if self.name.is_empty() {
1028            return Err(ConfigError::Validation(
1029                "Server route 'name' cannot be empty".to_string(),
1030            ));
1031        }
1032
1033        if self.path_prefix.is_empty() {
1034            return Err(ConfigError::Validation(format!(
1035                "Server route '{}' path_prefix cannot be empty",
1036                self.name
1037            )));
1038        }
1039
1040        if !self.path_prefix.starts_with('/') {
1041            return Err(ConfigError::Validation(format!(
1042                "Server route '{}' path_prefix must start with '/'",
1043                self.name
1044            )));
1045        }
1046
1047        match self.transport {
1048            TransportType::Stdio => {
1049                if self.command.is_none() {
1050                    return Err(ConfigError::Validation(format!(
1051                        "Server route '{}' with stdio transport requires 'command' to be set",
1052                        self.name
1053                    )));
1054                }
1055            }
1056            TransportType::Http | TransportType::Sse => {
1057                if self.url.is_none() {
1058                    return Err(ConfigError::Validation(format!(
1059                        "Server route '{}' with http/sse transport requires 'url' to be set",
1060                        self.name
1061                    )));
1062                }
1063            }
1064        }
1065
1066        Ok(())
1067    }
1068}
1069
1070// ============================================================================
1071// Tests
1072// ============================================================================
1073
// Unit tests for configuration defaults, `Config::validate` rules, multi-server
// detection, `ConfigError` display, and serde naming of transport/OAuth enums.
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline config that passes `validate()` (see `test_config_validation_success`).
    /// It has a single-server stdio upstream running `/bin/echo`, and every other
    /// section at its `Default`.
    fn create_valid_config() -> Config {
        // Use stdio transport which is available in free tier
        Config {
            server: ServerConfig::default(),
            auth: AuthConfig::default(),
            rate_limit: RateLimitConfig::default(),
            audit: AuditConfig::default(),
            tracing: TracingConfig::default(),
            upstream: UpstreamConfig {
                transport: TransportType::Stdio,
                command: Some("/bin/echo".to_string()),
                args: vec![],
                url: None,
                servers: vec![],
            },
        }
    }

    // Create a config with HTTP transport for Pro tier tests
    // NOTE(review): no test in this module calls this helper, so builds with the `pro`
    // feature will emit a dead-code warning. Add a Pro-tier test that uses it, or remove it.
    #[cfg(feature = "pro")]
    fn create_valid_config_http() -> Config {
        Config {
            server: ServerConfig::default(),
            auth: AuthConfig::default(),
            rate_limit: RateLimitConfig::default(),
            audit: AuditConfig::default(),
            tracing: TracingConfig::default(),
            upstream: UpstreamConfig {
                transport: TransportType::Http,
                command: None,
                args: vec![],
                url: Some("http://localhost:8080".to_string()),
                servers: vec![],
            },
        }
    }

    // ------------------------------------------------------------------------
    // Default Tests
    // ------------------------------------------------------------------------

    #[test]
    fn test_server_config_defaults() {
        let config = ServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
        assert!(config.tls.is_none());
    }

    #[test]
    fn test_rate_limit_config_defaults() {
        let config = RateLimitConfig::default();
        assert!(config.enabled);
        // SECURITY: Conservative defaults (25 RPS, burst 10) to limit abuse
        assert_eq!(config.requests_per_second, 25);
        assert_eq!(config.burst_size, 10);
    }

    #[test]
    fn test_audit_config_defaults() {
        let config = AuditConfig::default();
        assert!(config.enabled);
        assert!(config.file.is_none());
        // SECURITY: stdout defaults to false to prevent accidental PII exposure
        assert!(!config.stdout);
        assert!(config.export_url.is_none());
        assert_eq!(config.export_batch_size, 100);
        assert_eq!(config.export_interval_secs, 30);
    }

    #[test]
    fn test_tracing_config_defaults() {
        let config = TracingConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.service_name, "mcp-guard");
        assert!(config.otlp_endpoint.is_none());
        // SECURITY: sample_rate defaults to 0.1 (10%) for production safety
        assert_eq!(config.sample_rate, 0.1);
        assert!(config.propagate_context);
    }

    #[test]
    fn test_mtls_config_defaults() {
        let config = MtlsConfig::default();
        assert!(!config.enabled);
        assert!(matches!(config.identity_source, MtlsIdentitySource::Cn));
        assert!(config.allowed_tools.is_empty());
        assert!(config.rate_limit.is_none());
    }

    // ------------------------------------------------------------------------
    // Validation Tests
    // ------------------------------------------------------------------------
    // Each test starts from `create_valid_config()` and breaks exactly one setting.

    #[test]
    fn test_config_validation_success() {
        let config = create_valid_config();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_validation_invalid_port() {
        let mut config = create_valid_config();
        config.server.port = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_rate_limit_zero_rps() {
        let mut config = create_valid_config();
        config.rate_limit.enabled = true;
        config.rate_limit.requests_per_second = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_rate_limit_zero_burst() {
        let mut config = create_valid_config();
        config.rate_limit.enabled = true;
        config.rate_limit.burst_size = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_stdio_missing_command() {
        let mut config = create_valid_config();
        config.upstream.transport = TransportType::Stdio;
        config.upstream.command = None;
        config.upstream.url = None;
        assert!(config.validate().is_err());
    }

    // NOTE(review): HTTP/SSE are treated as Pro features (`requires_pro_features`), so
    // without the `pro` feature validation may reject this config on tier grounds
    // before reaching the missing-url check. Asserting only `is_err()` cannot tell
    // which rule fired; consider also checking the error message.
    #[test]
    fn test_config_validation_http_missing_url() {
        let mut config = create_valid_config();
        config.upstream.transport = TransportType::Http;
        config.upstream.url = None;
        assert!(config.validate().is_err());
    }

    // NOTE(review): same caveat as the HTTP test above — may fail on tier gating
    // rather than on the missing url.
    #[test]
    fn test_config_validation_sse_missing_url() {
        let mut config = create_valid_config();
        config.upstream.transport = TransportType::Sse;
        config.upstream.url = None;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_jwt_invalid_jwks_url() {
        let mut config = create_valid_config();
        config.auth.jwt = Some(JwtConfig {
            mode: JwtMode::Jwks {
                jwks_url: "invalid-url".to_string(),
                algorithms: default_jwks_algorithms(),
                cache_duration_secs: 3600,
            },
            issuer: "https://issuer.example.com".to_string(),
            audience: "mcp-guard".to_string(),
            user_id_claim: "sub".to_string(),
            scopes_claim: "scope".to_string(),
            scope_tool_mapping: HashMap::new(),
            leeway_secs: 0,
        });
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_oauth_invalid_redirect_uri() {
        let mut config = create_valid_config();
        config.auth.oauth = Some(OAuthConfig {
            provider: OAuthProvider::GitHub,
            client_id: "test".to_string(),
            client_secret: None,
            authorization_url: None,
            token_url: None,
            introspection_url: None,
            userinfo_url: None,
            redirect_uri: "invalid-uri".to_string(),
            scopes: vec![],
            user_id_claim: "sub".to_string(),
            scope_tool_mapping: HashMap::new(),
            token_cache_ttl_secs: 300,
        });
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_audit_invalid_export_url() {
        let mut config = create_valid_config();
        config.audit.export_url = Some("not-a-url".to_string());
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_audit_batch_size_zero() {
        let mut config = create_valid_config();
        config.audit.export_url = Some("http://siem.example.com".to_string());
        config.audit.export_batch_size = 0;
        assert!(config.validate().is_err());
    }

    // 10001 is presumably one past an upper bound of 10000 on export_batch_size —
    // confirm against the audit validation rules.
    #[test]
    fn test_config_validation_audit_batch_size_too_large() {
        let mut config = create_valid_config();
        config.audit.export_url = Some("http://siem.example.com".to_string());
        config.audit.export_batch_size = 10001;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_audit_interval_zero() {
        let mut config = create_valid_config();
        config.audit.export_url = Some("http://siem.example.com".to_string());
        config.audit.export_interval_secs = 0;
        assert!(config.validate().is_err());
    }

    // Checks both sides of the [0.0, 1.0] range with tracing enabled.
    #[test]
    fn test_config_validation_tracing_invalid_sample_rate() {
        let mut config = create_valid_config();
        config.tracing.enabled = true;
        config.tracing.sample_rate = 1.5;
        assert!(config.validate().is_err());

        config.tracing.sample_rate = -0.1;
        assert!(config.validate().is_err());
    }

    // mTLS tests require Enterprise feature
    #[cfg(feature = "enterprise")]
    #[test]
    fn test_config_validation_mtls_requires_trusted_proxy_ips() {
        let mut config = create_valid_config();
        // mTLS enabled without trusted_proxy_ips should fail
        config.auth.mtls = Some(MtlsConfig {
            enabled: true,
            identity_source: MtlsIdentitySource::Cn,
            allowed_tools: vec![],
            rate_limit: None,
            trusted_proxy_ips: vec![], // Empty = security risk
        });
        let result = config.validate();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("trusted_proxy_ips"));

        // mTLS enabled with trusted_proxy_ips should succeed
        config.auth.mtls = Some(MtlsConfig {
            enabled: true,
            identity_source: MtlsIdentitySource::Cn,
            allowed_tools: vec![],
            rate_limit: None,
            trusted_proxy_ips: vec!["10.0.0.0/8".to_string()],
        });
        assert!(config.validate().is_ok());

        // mTLS disabled without trusted_proxy_ips should succeed (not used)
        config.auth.mtls = Some(MtlsConfig {
            enabled: false,
            identity_source: MtlsIdentitySource::Cn,
            allowed_tools: vec![],
            rate_limit: None,
            trusted_proxy_ips: vec![],
        });
        assert!(config.validate().is_ok());
    }

    // Verify mTLS requires Enterprise in free tier
    #[cfg(not(feature = "enterprise"))]
    #[test]
    fn test_config_validation_mtls_requires_enterprise() {
        let mut config = create_valid_config();
        config.auth.mtls = Some(MtlsConfig {
            enabled: true,
            identity_source: MtlsIdentitySource::Cn,
            allowed_tools: vec![],
            rate_limit: None,
            trusted_proxy_ips: vec!["10.0.0.0/8".to_string()],
        });
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Enterprise"));
    }

    // `validate()` is deliberately not called here, so no tier gating of multi-server
    // routing is involved; only the `is_multi_server` predicate is checked.
    #[test]
    fn test_config_is_multi_server() {
        let mut config = create_valid_config();
        assert!(!config.is_multi_server());

        // Just test the method, not validation
        config.upstream.servers.push(ServerRouteConfig {
            name: "test".to_string(),
            path_prefix: "/test".to_string(),
            transport: TransportType::Stdio,
            command: Some("/bin/echo".to_string()),
            args: vec![],
            url: None,
            strip_prefix: false,
        });
        assert!(config.is_multi_server());
    }

    // ------------------------------------------------------------------------
    // ConfigError Tests
    // ------------------------------------------------------------------------

    // The `Display` output (from the `#[error]` attributes) must include the payload text.
    #[test]
    fn test_config_error_display() {
        let err = ConfigError::Parse("invalid TOML".to_string());
        assert!(format!("{}", err).contains("invalid TOML"));

        let err = ConfigError::Validation("port must be > 0".to_string());
        assert!(format!("{}", err).contains("port must be > 0"));
    }

    // ------------------------------------------------------------------------
    // Transport Type Tests
    // ------------------------------------------------------------------------

    // Serialized names are checked with `contains`, so this only confirms the lowercase
    // variant name appears in the JSON, not the exact encoding.
    #[test]
    fn test_transport_type_serialization() {
        let json = serde_json::to_string(&TransportType::Stdio).unwrap();
        assert!(json.contains("stdio"));

        let json = serde_json::to_string(&TransportType::Http).unwrap();
        assert!(json.contains("http"));

        let json = serde_json::to_string(&TransportType::Sse).unwrap();
        assert!(json.contains("sse"));
    }

    // ------------------------------------------------------------------------
    // OAuth Provider Tests
    // ------------------------------------------------------------------------

    #[test]
    fn test_oauth_provider_serialization() {
        let provider = OAuthProvider::GitHub;
        let json = serde_json::to_string(&provider).unwrap();
        assert!(json.contains("github"));

        let provider = OAuthProvider::Google;
        let json = serde_json::to_string(&provider).unwrap();
        assert!(json.contains("google"));
    }
}