1use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::path::PathBuf;
33
/// Errors produced while loading or validating a [`Config`].
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The config file could not be read. Created from [`std::io::Error`] via
    /// `?` in [`Config::from_file`].
    #[error("Failed to read config file: {0}")]
    Read(#[from] std::io::Error),

    /// The file contents could not be deserialized as YAML or TOML; carries
    /// the parser's error message.
    #[error("Failed to parse config: {0}")]
    Parse(String),

    /// The configuration parsed but violates a rule checked by
    /// [`Config::validate`]; carries a human-readable reason.
    #[error("Validation error: {0}")]
    Validation(String),
}
50
/// Top-level configuration, usually loaded with [`Config::from_file`].
///
/// Every section except `upstream` may be omitted from the file, in which case
/// it takes its `Default` value. `upstream` has no default and is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Listener settings (`[server]`).
    #[serde(default)]
    pub server: ServerConfig,

    /// Authentication providers (`[auth]`).
    #[serde(default)]
    pub auth: AuthConfig,

    /// Request rate limiting (`[rate_limit]`).
    #[serde(default)]
    pub rate_limit: RateLimitConfig,

    /// Audit logging (`[audit]`).
    #[serde(default)]
    pub audit: AuditConfig,

    /// Tracing export (`[tracing]`).
    #[serde(default)]
    pub tracing: TracingConfig,

    /// Upstream server definition(s) (`[upstream]`). Required.
    pub upstream: UpstreamConfig,
}
81
/// Settings for the listening server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Host address for the listener. Defaults to `127.0.0.1` (loopback).
    #[serde(default = "default_host")]
    pub host: String,

    /// Listening port. Defaults to `3000`; `0` is rejected by validation.
    #[serde(default = "default_port")]
    pub port: u16,

    /// Maximum request size. Defaults to `1024 * 1024`
    /// (presumably bytes, i.e. 1 MiB — confirm against the request handler).
    #[serde(default = "default_max_request_size")]
    pub max_request_size: usize,

    /// CORS settings; disabled by default.
    #[serde(default)]
    pub cors: CorsConfig,

    /// Optional TLS certificate/key settings. `None` by default.
    #[serde(default)]
    pub tls: Option<TlsConfig>,
}
106
107impl Default for ServerConfig {
108 fn default() -> Self {
109 Self {
110 host: default_host(),
111 port: default_port(),
112 max_request_size: default_max_request_size(),
113 cors: CorsConfig::default(),
114 tls: None,
115 }
116 }
117}
118
/// Default listener host: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Default listener port.
fn default_port() -> u16 {
    3_000
}

/// Default request-size limit: 1,048,576 (1 MiB if the unit is bytes).
fn default_max_request_size() -> usize {
    1 << 20
}
130
/// Cross-Origin Resource Sharing settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsConfig {
    /// Whether CORS handling is active. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,

    /// Origins allowed to make cross-origin requests. Defaults to empty.
    #[serde(default)]
    pub allowed_origins: Vec<String>,

    /// Allowed HTTP methods. Defaults to `GET`, `POST`, `OPTIONS`.
    #[serde(default = "default_cors_methods")]
    pub allowed_methods: Vec<String>,

    /// Allowed request headers. Defaults to `Authorization`, `Content-Type`.
    #[serde(default = "default_cors_headers")]
    pub allowed_headers: Vec<String>,

    /// Preflight cache lifetime. Defaults to `3600` (presumably seconds — confirm).
    #[serde(default = "default_cors_max_age")]
    pub max_age: u64,
}
155
156impl Default for CorsConfig {
157 fn default() -> Self {
158 Self {
159 enabled: false,
160 allowed_origins: vec![],
161 allowed_methods: default_cors_methods(),
162 allowed_headers: default_cors_headers(),
163 max_age: default_cors_max_age(),
164 }
165 }
166}
167
/// HTTP methods allowed for CORS requests unless configured otherwise.
fn default_cors_methods() -> Vec<String> {
    ["GET", "POST", "OPTIONS"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}

/// Request headers allowed for CORS requests unless configured otherwise.
fn default_cors_headers() -> Vec<String> {
    ["Authorization", "Content-Type"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}

/// Default CORS preflight max-age: one hour, if the unit is seconds.
fn default_cors_max_age() -> u64 {
    60 * 60
}
179
/// TLS settings for the listener (see `ServerConfig::tls`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// Path to the server certificate file (format decided by the TLS loader — confirm).
    pub cert_path: PathBuf,
    /// Path to the private key file for `cert_path`.
    pub key_path: PathBuf,
    /// Optional CA file; presumably used to verify client certificates when
    /// set — confirm with the TLS loader.
    pub client_ca_path: Option<PathBuf>,
}
191
/// Client-certificate (mTLS) authentication settings.
///
/// Enabling it requires the Enterprise tier (`Config::requires_enterprise_features`)
/// and a non-empty `trusted_proxy_ips` list (checked by `Config::validate`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MtlsConfig {
    /// Whether mTLS authentication is active. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Which certificate field provides the client identity. Defaults to `cn`.
    #[serde(default = "default_mtls_identity_source")]
    pub identity_source: MtlsIdentitySource,
    /// Tools mTLS clients may call. Defaults to empty.
    /// NOTE(review): whether empty means "none" or "all" is decided elsewhere — confirm.
    #[serde(default)]
    pub allowed_tools: Vec<String>,
    /// Optional rate-limit override for mTLS clients; `None` by default.
    #[serde(default)]
    pub rate_limit: Option<u32>,
    /// Proxy addresses trusted to forward client-certificate headers. Must be
    /// non-empty when `enabled` is true: without it, attackers could spoof
    /// those headers. Entries such as the CIDR `10.0.0.0/8` are used in the tests.
    #[serde(default)]
    pub trusted_proxy_ips: Vec<String>,
}
215
216impl Default for MtlsConfig {
217 fn default() -> Self {
218 Self {
219 enabled: false,
220 identity_source: default_mtls_identity_source(),
221 allowed_tools: vec![],
222 rate_limit: None,
223 trusted_proxy_ips: vec![],
224 }
225 }
226}
227
/// Certificate field used as the authenticated client identity.
///
/// Serialized in lowercase: `"cn"`, `"sandns"`, `"sanemail"`.
/// NOTE(review): `rename_all = "lowercase"` does not insert underscores, so the
/// SAN variants are spelled `sandns` / `sanemail` in config files — confirm
/// that is the intended spelling.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MtlsIdentitySource {
    /// Certificate Common Name (per the variant name; extraction happens elsewhere).
    Cn,
    /// Subject Alternative Name DNS entry (per the variant name).
    SanDns,
    /// Subject Alternative Name email entry (per the variant name).
    SanEmail,
}

/// Default identity source for [`MtlsConfig`]: the Common Name.
fn default_mtls_identity_source() -> MtlsIdentitySource {
    MtlsIdentitySource::Cn
}
243
/// Authentication providers. Each one is optional, and several may be
/// configured at once.
///
/// Tier notes (`Config::requires_pro_features` / `requires_enterprise_features`):
/// OAuth and JWKS-mode JWT need Pro; an enabled mTLS section needs Enterprise.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Static API keys. Defaults to none.
    #[serde(default)]
    pub api_keys: Vec<ApiKeyConfig>,

    /// JWT authentication, if configured.
    #[serde(default)]
    pub jwt: Option<JwtConfig>,

    /// OAuth authentication, if configured.
    #[serde(default)]
    pub oauth: Option<OAuthConfig>,

    /// Client-certificate authentication, if configured.
    #[serde(default)]
    pub mtls: Option<MtlsConfig>,
}
267
/// One API key entry. Only a hash of the key is stored in configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyConfig {
    /// Identifier for this key (presumably reported as the client identity — confirm).
    pub id: String,

    /// Hash of the API key.
    /// NOTE(review): hash algorithm and encoding are defined by the auth code — confirm.
    pub key_hash: String,

    /// Tools this key may call. Defaults to empty (meaning of empty decided elsewhere).
    #[serde(default)]
    pub allowed_tools: Vec<String>,

    /// Optional per-key rate-limit override; `None` by default.
    #[serde(default)]
    pub rate_limit: Option<u32>,
}
285
/// How JWT signatures are verified.
///
/// Internally tagged by `mode` (`"simple"` or `"jwks"`). [`JwtConfig`] flattens
/// this enum, so `mode` and the variant's fields sit directly in the `jwt` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "mode", rename_all = "lowercase")]
pub enum JwtMode {
    /// Verify with a secret configured inline (presumably an HMAC algorithm — confirm).
    Simple {
        /// Shared signing secret.
        secret: String,
    },
    /// Verify with keys from a JWKS endpoint. Requires the Pro tier.
    Jwks {
        /// JWKS endpoint; must be http(s), and HTTPS in release builds.
        jwks_url: String,
        /// Accepted signing algorithms. Defaults to `RS256` and `ES256`.
        #[serde(default = "default_jwks_algorithms")]
        algorithms: Vec<String>,
        /// Key cache lifetime. Defaults to `3600` (seconds, per the name).
        #[serde(default = "default_cache_duration")]
        cache_duration_secs: u64,
    },
}
307
/// JWT authentication settings.
///
/// `mode` is flattened, so the `jwt` table holds `mode = "simple"` / `"jwks"`
/// and that variant's fields alongside the keys below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtConfig {
    /// Verification mode and its key material.
    #[serde(flatten)]
    pub mode: JwtMode,

    /// Expected token issuer (presumably checked against `iss` — confirm).
    pub issuer: String,

    /// Expected audience (presumably checked against `aud` — confirm).
    pub audience: String,

    /// Claim holding the user identifier. Defaults to `"sub"`.
    #[serde(default = "default_user_id_claim")]
    pub user_id_claim: String,

    /// Claim holding the token's scopes. Defaults to `"scope"`.
    #[serde(default = "default_scopes_claim")]
    pub scopes_claim: String,

    /// Presumably maps a scope name to the tools it grants. Defaults to empty.
    #[serde(default)]
    pub scope_tool_mapping: HashMap<String, Vec<String>>,

    /// Clock-skew allowance (seconds, per the name). Defaults to `0`.
    #[serde(default)]
    pub leeway_secs: u64,
}
338
/// Signing algorithms accepted in JWKS mode unless configured otherwise.
fn default_jwks_algorithms() -> Vec<String> {
    ["RS256", "ES256"].iter().copied().map(String::from).collect()
}

/// Default JWKS cache lifetime: one hour, if the unit is seconds.
fn default_cache_duration() -> u64 {
    60 * 60
}

/// Default claim carrying the user identifier.
fn default_user_id_claim() -> String {
    String::from("sub")
}

/// Default claim carrying the token's scopes.
fn default_scopes_claim() -> String {
    "scope".into()
}
354
/// OAuth identity provider. Serialized in lowercase:
/// `"github"`, `"google"`, `"okta"`, `"custom"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OAuthProvider {
    /// GitHub.
    GitHub,
    /// Google.
    Google,
    /// Okta.
    Okta,
    /// Any other provider (presumably configured through the explicit URL
    /// fields of [`OAuthConfig`] — confirm).
    Custom,
}
368
/// OAuth authentication settings. Configuring this section requires the Pro tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
    /// Identity provider.
    pub provider: OAuthProvider,

    /// OAuth client ID.
    pub client_id: String,

    /// OAuth client secret, if any.
    pub client_secret: Option<String>,

    /// Authorization endpoint, if set.
    pub authorization_url: Option<String>,

    /// Token endpoint, if set.
    pub token_url: Option<String>,

    /// Token introspection endpoint, if set.
    pub introspection_url: Option<String>,

    /// Userinfo endpoint, if set.
    pub userinfo_url: Option<String>,

    /// Callback URL. Defaults to `http://localhost:3000/oauth/callback`.
    /// Must start with `http://` or `https://`; release builds log a security
    /// warning for plain `http://`.
    #[serde(default = "default_redirect_uri")]
    pub redirect_uri: String,

    /// OAuth scopes. Defaults to `openid` and `profile`.
    #[serde(default = "default_oauth_scopes")]
    pub scopes: Vec<String>,

    /// Claim holding the user identifier. Defaults to `"sub"`.
    #[serde(default = "default_user_id_claim")]
    pub user_id_claim: String,

    /// Presumably maps a scope name to the tools it grants. Defaults to empty.
    #[serde(default)]
    pub scope_tool_mapping: HashMap<String, Vec<String>>,

    /// Token cache TTL. Defaults to `300` (seconds, per the name).
    #[serde(default = "default_token_cache_ttl")]
    pub token_cache_ttl_secs: u64,
}
417
/// Default OAuth token cache TTL: five minutes, if the unit is seconds.
fn default_token_cache_ttl() -> u64 {
    5 * 60
}

/// Default OAuth callback URL; its port matches the default server port.
fn default_redirect_uri() -> String {
    String::from("http://localhost:3000/oauth/callback")
}

/// Default OAuth scopes.
fn default_oauth_scopes() -> Vec<String> {
    ["openid", "profile"].iter().copied().map(String::from).collect()
}
429
/// Request rate-limiting settings.
///
/// NOTE(review): whether these limits apply per client or globally is decided
/// by the limiter implementation — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Whether rate limiting is active. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Sustained request rate. Defaults to `25`; must be > 0 when enabled.
    #[serde(default = "default_rps")]
    pub requests_per_second: u32,

    /// Burst size. Defaults to `10`; must be > 0 when enabled.
    #[serde(default = "default_burst")]
    pub burst_size: u32,

    /// Per-tool overrides. Any entry requires the Enterprise tier.
    #[serde(default)]
    pub tool_limits: Vec<ToolRateLimitConfig>,
}
454
/// Rate limit applied to tools matching `tool_pattern`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolRateLimitConfig {
    /// Pattern matched against tool names (matching syntax is defined by the
    /// limiter — confirm).
    pub tool_pattern: String,

    /// Sustained request rate for matching tools. Required.
    pub requests_per_second: u32,

    /// Burst size for matching tools. Defaults to `5`.
    #[serde(default = "default_tool_burst")]
    pub burst_size: u32,
}
471
/// Default burst size for a per-tool rate limit.
fn default_tool_burst() -> u32 {
    const TOOL_BURST: u32 = 5;
    TOOL_BURST
}
475
476impl Default for RateLimitConfig {
477 fn default() -> Self {
478 Self {
479 enabled: true,
480 requests_per_second: default_rps(),
481 burst_size: default_burst(),
482 tool_limits: Vec::new(),
483 }
484 }
485}
486
/// serde default helper for boolean options that are on unless turned off.
fn default_true() -> bool {
    true
}

/// Default sustained request rate.
fn default_rps() -> u32 {
    25_u32
}

/// Default burst size.
fn default_burst() -> u32 {
    10_u32
}
498
/// Audit logging settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditConfig {
    /// Whether auditing is active. Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Optional file for audit records. `None` by default.
    #[serde(default)]
    pub file: Option<PathBuf>,

    /// Whether to write audit records to stdout. Defaults to `false`.
    #[serde(default)]
    pub stdout: bool,

    /// Optional export endpoint. Must start with `http://` or `https://`;
    /// setting it requires the Enterprise tier.
    #[serde(default)]
    pub export_url: Option<String>,

    /// Records per export batch. Defaults to `100`; must be within 1..=10000
    /// when `export_url` is set.
    #[serde(default = "default_export_batch_size")]
    pub export_batch_size: usize,

    /// Export interval. Defaults to `30` (seconds, per the name); must be > 0
    /// when `export_url` is set.
    #[serde(default = "default_export_interval_secs")]
    pub export_interval_secs: u64,

    /// Extra headers for export requests (presumably — confirm). Defaults to empty.
    #[serde(default)]
    pub export_headers: HashMap<String, String>,

    /// Rules for masking sensitive data. Defaults to none.
    #[serde(default)]
    pub redaction_rules: Vec<RedactionRule>,

    /// Optional log rotation; `None` by default.
    #[serde(default)]
    pub rotation: Option<LogRotationConfig>,
}
544
/// A rule that masks matching text in audit output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedactionRule {
    /// Rule name.
    pub name: String,

    /// Pattern to match (presumably a regular expression — confirm with the
    /// redaction code).
    pub pattern: String,

    /// Text substituted for each match. Defaults to `[REDACTED]`.
    #[serde(default = "default_redaction_replacement")]
    pub replacement: String,
}
562
/// Replacement text used by a redaction rule that does not set its own.
fn default_redaction_replacement() -> String {
    "[REDACTED]".to_owned()
}
566
/// Rotation settings for the audit log file (see `AuditConfig::rotation`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogRotationConfig {
    /// Whether rotation is active. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,

    /// Optional size threshold, in bytes per the name. `None` by default.
    #[serde(default)]
    pub max_size_bytes: Option<u64>,

    /// Optional age threshold, in seconds per the name. `None` by default.
    #[serde(default)]
    pub max_age_secs: Option<u64>,

    /// Number of rotated files to keep. Defaults to `10`.
    #[serde(default = "default_max_backups")]
    pub max_backups: usize,

    /// Whether rotated files are compressed. Defaults to `false`.
    #[serde(default)]
    pub compress: bool,
}
593
/// Default number of rotated audit log files to keep.
fn default_max_backups() -> usize {
    10_usize
}

/// Default number of audit records per export batch.
fn default_export_batch_size() -> usize {
    100_usize
}

/// Default delay between audit exports, in seconds.
fn default_export_interval_secs() -> u64 {
    30_u64
}
605
606impl Default for AuditConfig {
607 fn default() -> Self {
608 Self {
609 enabled: true,
610 file: None,
611 stdout: false,
614 export_url: None,
615 export_batch_size: default_export_batch_size(),
616 export_interval_secs: default_export_interval_secs(),
617 export_headers: HashMap::new(),
618 redaction_rules: Vec::new(),
619 rotation: None,
620 }
621 }
622}
623
/// Trace export settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracingConfig {
    /// Whether tracing is active. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,

    /// Service name attached to traces. Defaults to `"mcp-guard"`.
    #[serde(default = "default_service_name")]
    pub service_name: String,

    /// Optional OTLP endpoint (a missing key deserializes as `None`). Enabled
    /// tracing with an endpoint requires the Enterprise tier.
    pub otlp_endpoint: Option<String>,

    /// Sampling fraction. Defaults to `0.1`; must lie in [0.0, 1.0] when enabled.
    #[serde(default = "default_sample_rate")]
    pub sample_rate: f64,

    /// Whether trace context is propagated (presumably across upstream calls —
    /// confirm). Defaults to `true`.
    #[serde(default = "default_true")]
    pub propagate_context: bool,
}
651
652impl Default for TracingConfig {
653 fn default() -> Self {
654 Self {
655 enabled: false,
656 service_name: default_service_name(),
657 otlp_endpoint: None,
658 sample_rate: default_sample_rate(),
659 propagate_context: true,
660 }
661 }
662}
663
/// Service name reported with traces unless configured otherwise.
fn default_service_name() -> String {
    String::from("mcp-guard")
}

/// Default trace sampling fraction (10%).
fn default_sample_rate() -> f64 {
    0.1_f64
}
673
/// Upstream server definition.
///
/// Either a single upstream (`transport` plus `command` or `url`), or, when
/// `servers` is non-empty, multi-server mode (Enterprise tier). In
/// multi-server mode `Config::validate` checks only the routes and skips the
/// top-level `command`/`url` requirements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpstreamConfig {
    /// Transport for the single-upstream case. Required even in multi-server
    /// mode, since it has no serde default. `http`/`sse` require the Pro tier.
    pub transport: TransportType,

    /// Command for `stdio` transport; required for stdio in single-upstream mode.
    pub command: Option<String>,

    /// Arguments for `command`. Defaults to empty.
    #[serde(default)]
    pub args: Vec<String>,

    /// URL for `http`/`sse` transport; required for them in single-upstream mode.
    pub url: Option<String>,

    /// Routed upstreams; a non-empty list enables multi-server mode.
    #[serde(default)]
    pub servers: Vec<ServerRouteConfig>,
}
699
/// One upstream in multi-server mode (presumably selected by request path prefix).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerRouteConfig {
    /// Route name; must be non-empty.
    pub name: String,

    /// Path prefix for this route; must be non-empty and start with `/`.
    pub path_prefix: String,

    /// Transport for this route; `http`/`sse` require the Pro tier.
    pub transport: TransportType,

    /// Command for `stdio` transport; required when transport is stdio.
    pub command: Option<String>,

    /// Arguments for `command`. Defaults to empty.
    #[serde(default)]
    pub args: Vec<String>,

    /// URL for `http`/`sse` transport; required for those transports.
    pub url: Option<String>,

    /// Whether the prefix is stripped before forwarding (applied by the routing
    /// code, presumably — confirm). Defaults to `false`.
    #[serde(default)]
    pub strip_prefix: bool,
}
728
/// How an upstream server is reached. Serialized in lowercase:
/// `"stdio"`, `"http"`, `"sse"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TransportType {
    /// Local process described by `command`/`args`.
    Stdio,
    /// Remote server at `url` (HTTP).
    Http,
    /// Remote server at `url` (SSE).
    Sse,
}
737
738impl Config {
743 pub fn from_file(path: &PathBuf) -> Result<Self, ConfigError> {
745 let content = std::fs::read_to_string(path)?;
746
747 let config: Config = if path
748 .extension()
749 .map(|e| e == "yaml" || e == "yml")
750 .unwrap_or(false)
751 {
752 serde_yaml::from_str(&content).map_err(|e| ConfigError::Parse(e.to_string()))?
753 } else {
754 toml::from_str(&content).map_err(|e| ConfigError::Parse(e.to_string()))?
755 };
756
757 config.validate()?;
758 Ok(config)
759 }
760
761 pub fn validate(&self) -> Result<(), ConfigError> {
763 crate::tier::validate_tier(self)?;
765
766 self.validate_server()?;
768 self.validate_rate_limit()?;
769 self.validate_jwt()?;
770 self.validate_oauth()?;
771 self.validate_audit()?;
772 self.validate_mtls()?;
773 self.validate_tracing()?;
774 self.validate_upstream()
775 }
776
777 fn validate_server(&self) -> Result<(), ConfigError> {
783 if self.server.port == 0 {
784 return Err(ConfigError::Validation(
785 "server.port must be between 1 and 65535".to_string(),
786 ));
787 }
788 Ok(())
789 }
790
791 fn validate_rate_limit(&self) -> Result<(), ConfigError> {
793 if self.rate_limit.enabled {
794 if self.rate_limit.requests_per_second == 0 {
795 return Err(ConfigError::Validation(
796 "rate_limit.requests_per_second must be greater than 0".to_string(),
797 ));
798 }
799 if self.rate_limit.burst_size == 0 {
800 return Err(ConfigError::Validation(
801 "rate_limit.burst_size must be greater than 0".to_string(),
802 ));
803 }
804 }
805 Ok(())
806 }
807
808 fn validate_jwt(&self) -> Result<(), ConfigError> {
810 if let Some(ref jwt_config) = self.auth.jwt {
811 if let JwtMode::Jwks { ref jwks_url, .. } = jwt_config.mode {
812 #[cfg(not(debug_assertions))]
814 if !jwks_url.starts_with("https://") {
815 return Err(ConfigError::Validation(
816 "jwt.jwks_url must use HTTPS in production".to_string(),
817 ));
818 }
819 if !jwks_url.starts_with("http://") && !jwks_url.starts_with("https://") {
821 return Err(ConfigError::Validation(
822 "jwt.jwks_url must be a valid HTTP(S) URL".to_string(),
823 ));
824 }
825 }
826 }
827 Ok(())
828 }
829
830 fn validate_oauth(&self) -> Result<(), ConfigError> {
832 if let Some(ref oauth_config) = self.auth.oauth {
833 if !oauth_config.redirect_uri.starts_with("http://")
835 && !oauth_config.redirect_uri.starts_with("https://")
836 {
837 return Err(ConfigError::Validation(
838 "oauth.redirect_uri must be a valid HTTP(S) URL".to_string(),
839 ));
840 }
841 #[cfg(not(debug_assertions))]
843 if oauth_config.redirect_uri.starts_with("http://") {
844 tracing::warn!(
845 "SECURITY WARNING: oauth.redirect_uri uses HTTP instead of HTTPS. \
846 This is insecure in production and may allow authorization code interception."
847 );
848 }
849 }
850 Ok(())
851 }
852
853 fn validate_audit(&self) -> Result<(), ConfigError> {
855 if let Some(ref export_url) = self.audit.export_url {
856 if !export_url.starts_with("http://") && !export_url.starts_with("https://") {
858 return Err(ConfigError::Validation(
859 "audit.export_url must be a valid HTTP(S) URL".to_string(),
860 ));
861 }
862 if self.audit.export_batch_size == 0 {
864 return Err(ConfigError::Validation(
865 "audit.export_batch_size must be greater than 0".to_string(),
866 ));
867 }
868 if self.audit.export_batch_size > 10000 {
869 return Err(ConfigError::Validation(
870 "audit.export_batch_size must be less than or equal to 10000".to_string(),
871 ));
872 }
873 if self.audit.export_interval_secs == 0 {
875 return Err(ConfigError::Validation(
876 "audit.export_interval_secs must be greater than 0".to_string(),
877 ));
878 }
879 }
880 Ok(())
881 }
882
883 fn validate_mtls(&self) -> Result<(), ConfigError> {
885 if let Some(ref mtls_config) = self.auth.mtls {
886 if mtls_config.enabled && mtls_config.trusted_proxy_ips.is_empty() {
887 return Err(ConfigError::Validation(
889 "auth.mtls.trusted_proxy_ips must be configured when mTLS is enabled. \
890 Without trusted proxy IPs, attackers could spoof client certificate headers."
891 .to_string(),
892 ));
893 }
894 }
895 Ok(())
896 }
897
898 fn validate_tracing(&self) -> Result<(), ConfigError> {
900 if self.tracing.enabled
901 && (self.tracing.sample_rate < 0.0 || self.tracing.sample_rate > 1.0)
902 {
903 return Err(ConfigError::Validation(
904 "tracing.sample_rate must be between 0.0 and 1.0".to_string(),
905 ));
906 }
907 Ok(())
908 }
909
910 fn validate_upstream(&self) -> Result<(), ConfigError> {
912 if !self.upstream.servers.is_empty() {
914 for server in &self.upstream.servers {
915 server.validate()?;
916 }
917 return Ok(());
918 }
919
920 match self.upstream.transport {
922 TransportType::Stdio => {
923 if self.upstream.command.is_none() {
924 return Err(ConfigError::Validation(
925 "stdio transport requires 'command' to be set".to_string(),
926 ));
927 }
928 }
929 TransportType::Http | TransportType::Sse => {
930 if self.upstream.url.is_none() {
931 return Err(ConfigError::Validation(
932 "http/sse transport requires 'url' to be set".to_string(),
933 ));
934 }
935 }
936 }
937
938 Ok(())
939 }
940
941 pub fn is_multi_server(&self) -> bool {
943 !self.upstream.servers.is_empty()
944 }
945
946 pub fn requires_pro_features(&self) -> bool {
953 if self.auth.oauth.is_some() {
955 return true;
956 }
957
958 if let Some(ref jwt_config) = self.auth.jwt {
960 if matches!(jwt_config.mode, JwtMode::Jwks { .. }) {
961 return true;
962 }
963 }
964
965 if self.upstream.servers.is_empty() {
967 match self.upstream.transport {
968 TransportType::Http | TransportType::Sse => return true,
969 TransportType::Stdio => {}
970 }
971 } else {
972 for server in &self.upstream.servers {
974 match server.transport {
975 TransportType::Http | TransportType::Sse => return true,
976 TransportType::Stdio => {}
977 }
978 }
979 }
980
981 false
982 }
983
984 pub fn requires_enterprise_features(&self) -> bool {
993 if let Some(ref mtls_config) = self.auth.mtls {
995 if mtls_config.enabled {
996 return true;
997 }
998 }
999
1000 if !self.upstream.servers.is_empty() {
1002 return true;
1003 }
1004
1005 if self.audit.export_url.is_some() {
1007 return true;
1008 }
1009
1010 if self.tracing.enabled && self.tracing.otlp_endpoint.is_some() {
1012 return true;
1013 }
1014
1015 if !self.rate_limit.tool_limits.is_empty() {
1017 return true;
1018 }
1019
1020 false
1021 }
1022}
1023
1024impl ServerRouteConfig {
1025 pub fn validate(&self) -> Result<(), ConfigError> {
1027 if self.name.is_empty() {
1028 return Err(ConfigError::Validation(
1029 "Server route 'name' cannot be empty".to_string(),
1030 ));
1031 }
1032
1033 if self.path_prefix.is_empty() {
1034 return Err(ConfigError::Validation(format!(
1035 "Server route '{}' path_prefix cannot be empty",
1036 self.name
1037 )));
1038 }
1039
1040 if !self.path_prefix.starts_with('/') {
1041 return Err(ConfigError::Validation(format!(
1042 "Server route '{}' path_prefix must start with '/'",
1043 self.name
1044 )));
1045 }
1046
1047 match self.transport {
1048 TransportType::Stdio => {
1049 if self.command.is_none() {
1050 return Err(ConfigError::Validation(format!(
1051 "Server route '{}' with stdio transport requires 'command' to be set",
1052 self.name
1053 )));
1054 }
1055 }
1056 TransportType::Http | TransportType::Sse => {
1057 if self.url.is_none() {
1058 return Err(ConfigError::Validation(format!(
1059 "Server route '{}' with http/sse transport requires 'url' to be set",
1060 self.name
1061 )));
1062 }
1063 }
1064 }
1065
1066 Ok(())
1067 }
1068}
1069
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal valid config: every section at its default plus a stdio upstream,
    /// so no paid-tier feature is required.
    fn create_valid_config() -> Config {
        Config {
            server: ServerConfig::default(),
            auth: AuthConfig::default(),
            rate_limit: RateLimitConfig::default(),
            audit: AuditConfig::default(),
            tracing: TracingConfig::default(),
            upstream: UpstreamConfig {
                transport: TransportType::Stdio,
                command: Some("/bin/echo".to_string()),
                args: vec![],
                url: None,
                servers: vec![],
            },
        }
    }

    /// Like `create_valid_config`, but with an HTTP upstream (a Pro feature).
    #[cfg(feature = "pro")]
    fn create_valid_config_http() -> Config {
        Config {
            server: ServerConfig::default(),
            auth: AuthConfig::default(),
            rate_limit: RateLimitConfig::default(),
            audit: AuditConfig::default(),
            tracing: TracingConfig::default(),
            upstream: UpstreamConfig {
                transport: TransportType::Http,
                command: None,
                args: vec![],
                url: Some("http://localhost:8080".to_string()),
                servers: vec![],
            },
        }
    }

    /// A stdio route that passes `ServerRouteConfig::validate` for a
    /// non-empty name and a `/`-prefixed path.
    fn stdio_route(name: &str, path_prefix: &str) -> ServerRouteConfig {
        ServerRouteConfig {
            name: name.to_string(),
            path_prefix: path_prefix.to_string(),
            transport: TransportType::Stdio,
            command: Some("/bin/echo".to_string()),
            args: vec![],
            url: None,
            strip_prefix: false,
        }
    }

    /// A JWT config around `mode` with fixed issuer/audience and default claims.
    fn jwt_config(mode: JwtMode) -> JwtConfig {
        JwtConfig {
            mode,
            issuer: "https://issuer.example.com".to_string(),
            audience: "mcp-guard".to_string(),
            user_id_claim: "sub".to_string(),
            scopes_claim: "scope".to_string(),
            scope_tool_mapping: HashMap::new(),
            leeway_secs: 0,
        }
    }

    // ------------------------------------------------------------------
    // Defaults
    // ------------------------------------------------------------------

    #[test]
    fn test_server_config_defaults() {
        let config = ServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
        assert!(config.tls.is_none());
    }

    #[test]
    fn test_rate_limit_config_defaults() {
        let config = RateLimitConfig::default();
        assert!(config.enabled);
        assert_eq!(config.requests_per_second, 25);
        assert_eq!(config.burst_size, 10);
    }

    #[test]
    fn test_audit_config_defaults() {
        let config = AuditConfig::default();
        assert!(config.enabled);
        assert!(config.file.is_none());
        assert!(!config.stdout);
        assert!(config.export_url.is_none());
        assert_eq!(config.export_batch_size, 100);
        assert_eq!(config.export_interval_secs, 30);
    }

    #[test]
    fn test_tracing_config_defaults() {
        let config = TracingConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.service_name, "mcp-guard");
        assert!(config.otlp_endpoint.is_none());
        assert_eq!(config.sample_rate, 0.1);
        assert!(config.propagate_context);
    }

    #[test]
    fn test_mtls_config_defaults() {
        let config = MtlsConfig::default();
        assert!(!config.enabled);
        assert!(matches!(config.identity_source, MtlsIdentitySource::Cn));
        assert!(config.allowed_tools.is_empty());
        assert!(config.rate_limit.is_none());
        assert!(config.trusted_proxy_ips.is_empty());
    }

    #[test]
    fn test_minimal_toml_config_uses_defaults() {
        let source = r#"
[upstream]
transport = "stdio"
command = "/bin/echo"
"#;
        let config: Config = toml::from_str(source).expect("minimal TOML config should parse");
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 3000);
        assert!(config.rate_limit.enabled);
        assert!(config.audit.enabled);
        assert!(!config.tracing.enabled);
        assert!(matches!(config.upstream.transport, TransportType::Stdio));
        assert_eq!(config.upstream.command.as_deref(), Some("/bin/echo"));
        assert!(config.upstream.args.is_empty());
        assert!(!config.is_multi_server());
    }

    // ------------------------------------------------------------------
    // Validation
    // ------------------------------------------------------------------

    #[test]
    fn test_config_validation_success() {
        let config = create_valid_config();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_validation_invalid_port() {
        let mut config = create_valid_config();
        config.server.port = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_rate_limit_zero_rps() {
        let mut config = create_valid_config();
        config.rate_limit.enabled = true;
        config.rate_limit.requests_per_second = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_rate_limit_zero_burst() {
        let mut config = create_valid_config();
        config.rate_limit.enabled = true;
        config.rate_limit.burst_size = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_stdio_missing_command() {
        let mut config = create_valid_config();
        config.upstream.transport = TransportType::Stdio;
        config.upstream.command = None;
        config.upstream.url = None;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_http_missing_url() {
        let mut config = create_valid_config();
        config.upstream.transport = TransportType::Http;
        config.upstream.url = None;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_sse_missing_url() {
        let mut config = create_valid_config();
        config.upstream.transport = TransportType::Sse;
        config.upstream.url = None;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_jwt_invalid_jwks_url() {
        let mut config = create_valid_config();
        config.auth.jwt = Some(jwt_config(JwtMode::Jwks {
            jwks_url: "invalid-url".to_string(),
            algorithms: default_jwks_algorithms(),
            cache_duration_secs: 3600,
        }));
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_oauth_invalid_redirect_uri() {
        let mut config = create_valid_config();
        config.auth.oauth = Some(OAuthConfig {
            provider: OAuthProvider::GitHub,
            client_id: "test".to_string(),
            client_secret: None,
            authorization_url: None,
            token_url: None,
            introspection_url: None,
            userinfo_url: None,
            redirect_uri: "invalid-uri".to_string(),
            scopes: vec![],
            user_id_claim: "sub".to_string(),
            scope_tool_mapping: HashMap::new(),
            token_cache_ttl_secs: 300,
        });
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_audit_invalid_export_url() {
        let mut config = create_valid_config();
        config.audit.export_url = Some("not-a-url".to_string());
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_audit_batch_size_zero() {
        let mut config = create_valid_config();
        config.audit.export_url = Some("http://siem.example.com".to_string());
        config.audit.export_batch_size = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_audit_batch_size_too_large() {
        let mut config = create_valid_config();
        config.audit.export_url = Some("http://siem.example.com".to_string());
        config.audit.export_batch_size = 10001;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_audit_interval_zero() {
        let mut config = create_valid_config();
        config.audit.export_url = Some("http://siem.example.com".to_string());
        config.audit.export_interval_secs = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_tracing_invalid_sample_rate() {
        let mut config = create_valid_config();
        config.tracing.enabled = true;
        config.tracing.sample_rate = 1.5;
        assert!(config.validate().is_err());

        config.tracing.sample_rate = -0.1;
        assert!(config.validate().is_err());
    }

    #[cfg(feature = "enterprise")]
    #[test]
    fn test_config_validation_mtls_requires_trusted_proxy_ips() {
        let mut config = create_valid_config();

        // Enabled without trusted proxies: rejected.
        config.auth.mtls = Some(MtlsConfig {
            enabled: true,
            identity_source: MtlsIdentitySource::Cn,
            allowed_tools: vec![],
            rate_limit: None,
            trusted_proxy_ips: vec![],
        });
        let result = config.validate();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("trusted_proxy_ips"));

        // Enabled with a trusted proxy range: accepted.
        config.auth.mtls = Some(MtlsConfig {
            enabled: true,
            identity_source: MtlsIdentitySource::Cn,
            allowed_tools: vec![],
            rate_limit: None,
            trusted_proxy_ips: vec!["10.0.0.0/8".to_string()],
        });
        assert!(config.validate().is_ok());

        // Disabled: proxies are not required.
        config.auth.mtls = Some(MtlsConfig {
            enabled: false,
            identity_source: MtlsIdentitySource::Cn,
            allowed_tools: vec![],
            rate_limit: None,
            trusted_proxy_ips: vec![],
        });
        assert!(config.validate().is_ok());
    }

    #[cfg(not(feature = "enterprise"))]
    #[test]
    fn test_config_validation_mtls_requires_enterprise() {
        let mut config = create_valid_config();
        config.auth.mtls = Some(MtlsConfig {
            enabled: true,
            identity_source: MtlsIdentitySource::Cn,
            allowed_tools: vec![],
            rate_limit: None,
            trusted_proxy_ips: vec!["10.0.0.0/8".to_string()],
        });
        let result = config.validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Enterprise"));
    }

    #[test]
    fn test_server_route_validation() {
        assert!(stdio_route("test", "/test").validate().is_ok());

        let err = stdio_route("", "/test").validate().unwrap_err();
        assert!(err.to_string().contains("'name' cannot be empty"));

        let err = stdio_route("test", "").validate().unwrap_err();
        assert!(err.to_string().contains("path_prefix cannot be empty"));

        let err = stdio_route("test", "test").validate().unwrap_err();
        assert!(err.to_string().contains("must start with '/'"));

        let mut route = stdio_route("test", "/test");
        route.command = None;
        assert!(route
            .validate()
            .unwrap_err()
            .to_string()
            .contains("requires 'command'"));

        let mut route = stdio_route("test", "/test");
        route.transport = TransportType::Http;
        assert!(route
            .validate()
            .unwrap_err()
            .to_string()
            .contains("requires 'url'"));
        route.url = Some("http://localhost:8080".to_string());
        assert!(route.validate().is_ok());
    }

    // ------------------------------------------------------------------
    // Mode and tier detection
    // ------------------------------------------------------------------

    #[test]
    fn test_config_is_multi_server() {
        let mut config = create_valid_config();
        assert!(!config.is_multi_server());

        config.upstream.servers.push(stdio_route("test", "/test"));
        assert!(config.is_multi_server());
    }

    #[test]
    fn test_stdio_config_requires_no_paid_features() {
        let config = create_valid_config();
        assert!(!config.requires_pro_features());
        assert!(!config.requires_enterprise_features());
    }

    #[cfg(feature = "pro")]
    #[test]
    fn test_http_upstream_requires_pro_features() {
        let config = create_valid_config_http();
        assert!(config.requires_pro_features());
        assert!(!config.requires_enterprise_features());
    }

    #[test]
    fn test_jwt_mode_determines_pro_requirement() {
        let mut config = create_valid_config();
        config.auth.jwt = Some(jwt_config(JwtMode::Simple {
            secret: "test-secret".to_string(),
        }));
        assert!(!config.requires_pro_features());

        config.auth.jwt = Some(jwt_config(JwtMode::Jwks {
            jwks_url: "https://issuer.example.com/.well-known/jwks.json".to_string(),
            algorithms: default_jwks_algorithms(),
            cache_duration_secs: default_cache_duration(),
        }));
        assert!(config.requires_pro_features());
    }

    #[test]
    fn test_requires_enterprise_features_triggers() {
        let mut config = create_valid_config();
        config.audit.export_url = Some("https://siem.example.com".to_string());
        assert!(config.requires_enterprise_features());

        let mut config = create_valid_config();
        config.tracing.enabled = true;
        assert!(!config.requires_enterprise_features());
        config.tracing.otlp_endpoint = Some("http://localhost:4317".to_string());
        assert!(config.requires_enterprise_features());

        let mut config = create_valid_config();
        config.rate_limit.tool_limits.push(ToolRateLimitConfig {
            tool_pattern: "write_*".to_string(),
            requests_per_second: 1,
            burst_size: 1,
        });
        assert!(config.requires_enterprise_features());

        let mut config = create_valid_config();
        config.upstream.servers.push(stdio_route("test", "/test"));
        assert!(config.requires_enterprise_features());
    }

    // ------------------------------------------------------------------
    // Errors and serialization
    // ------------------------------------------------------------------

    #[test]
    fn test_config_error_display() {
        let err = ConfigError::Parse("invalid TOML".to_string());
        assert!(format!("{}", err).contains("invalid TOML"));

        let err = ConfigError::Validation("port must be > 0".to_string());
        assert!(format!("{}", err).contains("port must be > 0"));
    }

    #[test]
    fn test_transport_type_serialization() {
        let json = serde_json::to_string(&TransportType::Stdio).unwrap();
        assert!(json.contains("stdio"));

        let json = serde_json::to_string(&TransportType::Http).unwrap();
        assert!(json.contains("http"));

        let json = serde_json::to_string(&TransportType::Sse).unwrap();
        assert!(json.contains("sse"));
    }

    #[test]
    fn test_oauth_provider_serialization() {
        let provider = OAuthProvider::GitHub;
        let json = serde_json::to_string(&provider).unwrap();
        assert!(json.contains("github"));

        let provider = OAuthProvider::Google;
        let json = serde_json::to_string(&provider).unwrap();
        assert!(json.contains("google"));
    }
}