1use std::collections::HashMap;
4use std::collections::HashSet;
5use std::path::Path;
6
7use anyhow::{Context, Result};
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Deserialize, Serialize)]
12pub struct ProxyConfig {
13 pub proxy: ProxySettings,
15 #[serde(default)]
17 pub backends: Vec<BackendConfig>,
18 pub auth: Option<AuthConfig>,
20 #[serde(default)]
22 pub performance: PerformanceConfig,
23 #[serde(default)]
25 pub security: SecurityConfig,
26 #[serde(default)]
28 pub observability: ObservabilityConfig,
29}
30
31#[derive(Debug, Deserialize, Serialize)]
33pub struct ProxySettings {
34 pub name: String,
36 #[serde(default = "default_version")]
38 pub version: String,
39 #[serde(default = "default_separator")]
41 pub separator: String,
42 pub listen: ListenConfig,
44 pub instructions: Option<String>,
46 #[serde(default = "default_shutdown_timeout")]
48 pub shutdown_timeout_seconds: u64,
49 #[serde(default)]
51 pub hot_reload: bool,
52}
53
54#[derive(Debug, Deserialize, Serialize)]
56pub struct ListenConfig {
57 #[serde(default = "default_host")]
59 pub host: String,
60 #[serde(default = "default_port")]
62 pub port: u16,
63}
64
65#[derive(Debug, Deserialize, Serialize)]
67pub struct BackendConfig {
68 pub name: String,
70 pub transport: TransportType,
72 pub command: Option<String>,
74 #[serde(default)]
76 pub args: Vec<String>,
77 pub url: Option<String>,
79 #[serde(default)]
81 pub env: HashMap<String, String>,
82 pub timeout: Option<TimeoutConfig>,
84 pub circuit_breaker: Option<CircuitBreakerConfig>,
86 pub rate_limit: Option<RateLimitConfig>,
88 pub concurrency: Option<ConcurrencyConfig>,
90 pub retry: Option<RetryConfig>,
92 pub outlier_detection: Option<OutlierDetectionConfig>,
94 pub hedging: Option<HedgingConfig>,
96 pub mirror_of: Option<String>,
99 #[serde(default = "default_mirror_percent")]
101 pub mirror_percent: u32,
102 pub cache: Option<BackendCacheConfig>,
104 pub bearer_token: Option<String>,
107 #[serde(default)]
110 pub forward_auth: bool,
111 #[serde(default)]
113 pub aliases: Vec<AliasConfig>,
114 #[serde(default)]
117 pub default_args: serde_json::Map<String, serde_json::Value>,
118 #[serde(default)]
120 pub inject_args: Vec<InjectArgsConfig>,
121 #[serde(default)]
123 pub expose_tools: Vec<String>,
124 #[serde(default)]
126 pub hide_tools: Vec<String>,
127 #[serde(default)]
129 pub expose_resources: Vec<String>,
130 #[serde(default)]
132 pub hide_resources: Vec<String>,
133 #[serde(default)]
135 pub expose_prompts: Vec<String>,
136 #[serde(default)]
138 pub hide_prompts: Vec<String>,
139 pub canary_of: Option<String>,
143 #[serde(default = "default_weight")]
146 pub weight: u32,
147}
148
149#[derive(Debug, Deserialize, Serialize)]
151#[serde(rename_all = "lowercase")]
152pub enum TransportType {
153 Stdio,
155 Http,
157}
158
159#[derive(Debug, Deserialize, Serialize)]
161pub struct TimeoutConfig {
162 pub seconds: u64,
164}
165
166#[derive(Debug, Deserialize, Serialize)]
168pub struct CircuitBreakerConfig {
169 #[serde(default = "default_failure_rate")]
171 pub failure_rate_threshold: f64,
172 #[serde(default = "default_min_calls")]
174 pub minimum_calls: usize,
175 #[serde(default = "default_wait_duration")]
177 pub wait_duration_seconds: u64,
178 #[serde(default = "default_half_open_calls")]
180 pub permitted_calls_in_half_open: usize,
181}
182
183#[derive(Debug, Deserialize, Serialize)]
185pub struct RateLimitConfig {
186 pub requests: usize,
188 #[serde(default = "default_rate_period")]
190 pub period_seconds: u64,
191}
192
193#[derive(Debug, Deserialize, Serialize)]
195pub struct ConcurrencyConfig {
196 pub max_concurrent: usize,
198}
199
200#[derive(Debug, Clone, Deserialize, Serialize)]
202pub struct RetryConfig {
203 #[serde(default = "default_max_retries")]
205 pub max_retries: u32,
206 #[serde(default = "default_initial_backoff_ms")]
208 pub initial_backoff_ms: u64,
209 #[serde(default = "default_max_backoff_ms")]
211 pub max_backoff_ms: u64,
212 pub budget_percent: Option<f64>,
217 #[serde(default = "default_min_retries_per_sec")]
220 pub min_retries_per_sec: u32,
221}
222
223#[derive(Debug, Clone, Deserialize, Serialize)]
227pub struct OutlierDetectionConfig {
228 #[serde(default = "default_consecutive_errors")]
230 pub consecutive_errors: u32,
231 #[serde(default = "default_interval_seconds")]
233 pub interval_seconds: u64,
234 #[serde(default = "default_base_ejection_seconds")]
236 pub base_ejection_seconds: u64,
237 #[serde(default = "default_max_ejection_percent")]
239 pub max_ejection_percent: u32,
240}
241
242#[derive(Debug, Clone, Deserialize, Serialize)]
244pub struct InjectArgsConfig {
245 pub tool: String,
247 pub args: serde_json::Map<String, serde_json::Value>,
250 #[serde(default)]
252 pub overwrite: bool,
253}
254
255#[derive(Debug, Clone, Deserialize, Serialize)]
261pub struct HedgingConfig {
262 #[serde(default = "default_hedge_delay_ms")]
265 pub delay_ms: u64,
266 #[serde(default = "default_max_hedges")]
268 pub max_hedges: usize,
269}
270
271#[derive(Debug, Deserialize, Serialize)]
273#[serde(tag = "type", rename_all = "lowercase")]
274pub enum AuthConfig {
275 Bearer {
277 tokens: Vec<String>,
279 },
280 Jwt {
282 issuer: String,
284 audience: String,
286 jwks_uri: String,
288 #[serde(default)]
290 roles: Vec<RoleConfig>,
291 role_mapping: Option<RoleMappingConfig>,
293 },
294}
295
296#[derive(Debug, Deserialize, Serialize)]
298pub struct RoleConfig {
299 pub name: String,
301 #[serde(default)]
303 pub allow_tools: Vec<String>,
304 #[serde(default)]
306 pub deny_tools: Vec<String>,
307}
308
309#[derive(Debug, Deserialize, Serialize)]
311pub struct RoleMappingConfig {
312 pub claim: String,
314 pub mapping: HashMap<String, String>,
316}
317
318#[derive(Debug, Deserialize, Serialize)]
320pub struct AliasConfig {
321 pub from: String,
323 pub to: String,
325}
326
327#[derive(Debug, Deserialize, Serialize)]
329pub struct BackendCacheConfig {
330 #[serde(default)]
332 pub resource_ttl_seconds: u64,
333 #[serde(default)]
335 pub tool_ttl_seconds: u64,
336 #[serde(default = "default_max_cache_entries")]
338 pub max_entries: u64,
339}
340
341#[derive(Debug, Default, Deserialize, Serialize)]
343pub struct PerformanceConfig {
344 #[serde(default)]
346 pub coalesce_requests: bool,
347}
348
349#[derive(Debug, Default, Deserialize, Serialize)]
351pub struct SecurityConfig {
352 pub max_argument_size: Option<usize>,
354}
355
356#[derive(Debug, Default, Deserialize, Serialize)]
358pub struct ObservabilityConfig {
359 #[serde(default)]
361 pub audit: bool,
362 #[serde(default = "default_log_level")]
364 pub log_level: String,
365 #[serde(default)]
367 pub json_logs: bool,
368 #[serde(default)]
370 pub metrics: MetricsConfig,
371 #[serde(default)]
373 pub tracing: TracingConfig,
374}
375
376#[derive(Debug, Default, Deserialize, Serialize)]
378pub struct MetricsConfig {
379 #[serde(default)]
381 pub enabled: bool,
382}
383
384#[derive(Debug, Default, Deserialize, Serialize)]
386pub struct TracingConfig {
387 #[serde(default)]
389 pub enabled: bool,
390 #[serde(default = "default_otlp_endpoint")]
392 pub endpoint: String,
393 #[serde(default = "default_service_name")]
395 pub service_name: String,
396}
397
/// Serde default for `proxy.version`.
fn default_version() -> String {
    String::from("0.1.0")
}

/// Serde default for `proxy.separator`.
fn default_separator() -> String {
    String::from("/")
}

/// Serde default for `proxy.listen.host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `proxy.listen.port`.
fn default_port() -> u16 {
    8_080
}

/// Serde default for `observability.log_level`.
fn default_log_level() -> String {
    String::from("info")
}
419
/// Serde default for `circuit_breaker.failure_rate_threshold`.
fn default_failure_rate() -> f64 {
    0.5
}

/// Serde default for `circuit_breaker.minimum_calls`.
fn default_min_calls() -> usize {
    5
}

/// Serde default for `circuit_breaker.wait_duration_seconds`.
fn default_wait_duration() -> u64 {
    30
}

/// Serde default for `circuit_breaker.permitted_calls_in_half_open`.
fn default_half_open_calls() -> usize {
    3
}

/// Serde default for `rate_limit.period_seconds`.
fn default_rate_period() -> u64 {
    1
}

/// Serde default for `retry.max_retries`.
fn default_max_retries() -> u32 {
    3
}

/// Serde default for `retry.initial_backoff_ms`.
fn default_initial_backoff_ms() -> u64 {
    100
}

/// Serde default for `retry.max_backoff_ms`.
fn default_max_backoff_ms() -> u64 {
    5_000
}

/// Serde default for `retry.min_retries_per_sec`.
fn default_min_retries_per_sec() -> u32 {
    10
}
455
/// Serde default for `outlier_detection.consecutive_errors`.
fn default_consecutive_errors() -> u32 {
    5
}

/// Serde default for `outlier_detection.interval_seconds`.
fn default_interval_seconds() -> u64 {
    10
}

/// Serde default for `outlier_detection.base_ejection_seconds`.
fn default_base_ejection_seconds() -> u64 {
    30
}

/// Serde default for `outlier_detection.max_ejection_percent`.
fn default_max_ejection_percent() -> u32 {
    50
}

/// Serde default for `hedging.delay_ms`.
fn default_hedge_delay_ms() -> u64 {
    200
}

/// Serde default for `hedging.max_hedges`.
fn default_max_hedges() -> usize {
    1
}
479
/// Serde default for a backend's `mirror_percent`.
fn default_mirror_percent() -> u32 {
    100
}

/// Serde default for a backend's `weight`.
fn default_weight() -> u32 {
    100
}

/// Serde default for `cache.max_entries`.
fn default_max_cache_entries() -> u64 {
    1_000
}

/// Serde default for `proxy.shutdown_timeout_seconds`.
fn default_shutdown_timeout() -> u64 {
    30
}

/// Serde default for `observability.tracing.endpoint`.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4317")
}

/// Serde default for `observability.tracing.service_name`.
fn default_service_name() -> String {
    String::from("mcp-proxy")
}
503
/// Visibility rules for one backend's tools, resources and prompts.
#[derive(Clone, Debug)]
pub struct BackendFilter {
    /// `<backend name><separator>` prefix identifying this backend.
    pub namespace: String,
    /// Filter applied to tool names.
    pub tool_filter: NameFilter,
    /// Filter applied to resource names.
    pub resource_filter: NameFilter,
    /// Filter applied to prompt names.
    pub prompt_filter: NameFilter,
}

/// A name-level visibility rule.
#[derive(Clone, Debug)]
pub enum NameFilter {
    /// Every name is visible.
    PassAll,
    /// Only names in the set are visible (an empty set hides everything).
    AllowList(HashSet<String>),
    /// Every name except those in the set is visible.
    DenyList(HashSet<String>),
}

impl NameFilter {
    /// Returns `true` when `name` is visible under this filter.
    ///
    /// Names are compared exactly against the set contents.
    pub fn allows(&self, name: &str) -> bool {
        // Allow-lists want membership; deny-lists want non-membership.
        let (set, must_contain) = match self {
            Self::PassAll => return true,
            Self::AllowList(set) => (set, true),
            Self::DenyList(set) => (set, false),
        };
        set.contains(name) == must_contain
    }
}
555
556impl BackendConfig {
557 pub fn build_filter(&self, separator: &str) -> Option<BackendFilter> {
564 if self.canary_of.is_some() {
567 return Some(BackendFilter {
568 namespace: format!("{}{}", self.name, separator),
569 tool_filter: NameFilter::AllowList(HashSet::new()),
570 resource_filter: NameFilter::AllowList(HashSet::new()),
571 prompt_filter: NameFilter::AllowList(HashSet::new()),
572 });
573 }
574
575 let tool_filter = if !self.expose_tools.is_empty() {
576 NameFilter::AllowList(self.expose_tools.iter().cloned().collect())
577 } else if !self.hide_tools.is_empty() {
578 NameFilter::DenyList(self.hide_tools.iter().cloned().collect())
579 } else {
580 NameFilter::PassAll
581 };
582
583 let resource_filter = if !self.expose_resources.is_empty() {
584 NameFilter::AllowList(self.expose_resources.iter().cloned().collect())
585 } else if !self.hide_resources.is_empty() {
586 NameFilter::DenyList(self.hide_resources.iter().cloned().collect())
587 } else {
588 NameFilter::PassAll
589 };
590
591 let prompt_filter = if !self.expose_prompts.is_empty() {
592 NameFilter::AllowList(self.expose_prompts.iter().cloned().collect())
593 } else if !self.hide_prompts.is_empty() {
594 NameFilter::DenyList(self.hide_prompts.iter().cloned().collect())
595 } else {
596 NameFilter::PassAll
597 };
598
599 if matches!(tool_filter, NameFilter::PassAll)
601 && matches!(resource_filter, NameFilter::PassAll)
602 && matches!(prompt_filter, NameFilter::PassAll)
603 {
604 return None;
605 }
606
607 Some(BackendFilter {
608 namespace: format!("{}{}", self.name, separator),
609 tool_filter,
610 resource_filter,
611 prompt_filter,
612 })
613 }
614}
615
616impl ProxyConfig {
617 pub fn load(path: &Path) -> Result<Self> {
619 let content =
620 std::fs::read_to_string(path).with_context(|| format!("reading {}", path.display()))?;
621 let config: Self =
622 toml::from_str(&content).with_context(|| format!("parsing {}", path.display()))?;
623 config.validate()?;
624 Ok(config)
625 }
626
627 pub fn parse(toml: &str) -> Result<Self> {
649 let config: Self = toml::from_str(toml).context("parsing config")?;
650 config.validate()?;
651 Ok(config)
652 }
653
654 fn validate(&self) -> Result<()> {
655 if self.backends.is_empty() {
656 anyhow::bail!("at least one backend is required");
657 }
658 for backend in &self.backends {
659 match backend.transport {
660 TransportType::Stdio => {
661 if backend.command.is_none() {
662 anyhow::bail!(
663 "backend '{}': stdio transport requires 'command'",
664 backend.name
665 );
666 }
667 }
668 TransportType::Http => {
669 if backend.url.is_none() {
670 anyhow::bail!("backend '{}': http transport requires 'url'", backend.name);
671 }
672 }
673 }
674
675 if let Some(cb) = &backend.circuit_breaker
676 && (cb.failure_rate_threshold <= 0.0 || cb.failure_rate_threshold > 1.0)
677 {
678 anyhow::bail!(
679 "backend '{}': circuit_breaker.failure_rate_threshold must be in (0.0, 1.0]",
680 backend.name
681 );
682 }
683
684 if let Some(rl) = &backend.rate_limit
685 && rl.requests == 0
686 {
687 anyhow::bail!(
688 "backend '{}': rate_limit.requests must be > 0",
689 backend.name
690 );
691 }
692
693 if let Some(cc) = &backend.concurrency
694 && cc.max_concurrent == 0
695 {
696 anyhow::bail!(
697 "backend '{}': concurrency.max_concurrent must be > 0",
698 backend.name
699 );
700 }
701
702 if !backend.expose_tools.is_empty() && !backend.hide_tools.is_empty() {
703 anyhow::bail!(
704 "backend '{}': cannot specify both expose_tools and hide_tools",
705 backend.name
706 );
707 }
708 if !backend.expose_resources.is_empty() && !backend.hide_resources.is_empty() {
709 anyhow::bail!(
710 "backend '{}': cannot specify both expose_resources and hide_resources",
711 backend.name
712 );
713 }
714 if !backend.expose_prompts.is_empty() && !backend.hide_prompts.is_empty() {
715 anyhow::bail!(
716 "backend '{}': cannot specify both expose_prompts and hide_prompts",
717 backend.name
718 );
719 }
720 }
721
722 let backend_names: HashSet<&str> = self.backends.iter().map(|b| b.name.as_str()).collect();
724 for backend in &self.backends {
725 if let Some(ref source) = backend.mirror_of {
726 if !backend_names.contains(source.as_str()) {
727 anyhow::bail!(
728 "backend '{}': mirror_of references unknown backend '{}'",
729 backend.name,
730 source
731 );
732 }
733 if source == &backend.name {
734 anyhow::bail!(
735 "backend '{}': mirror_of cannot reference itself",
736 backend.name
737 );
738 }
739 }
740 }
741
742 for backend in &self.backends {
744 if let Some(ref primary) = backend.canary_of {
745 if !backend_names.contains(primary.as_str()) {
746 anyhow::bail!(
747 "backend '{}': canary_of references unknown backend '{}'",
748 backend.name,
749 primary
750 );
751 }
752 if primary == &backend.name {
753 anyhow::bail!(
754 "backend '{}': canary_of cannot reference itself",
755 backend.name
756 );
757 }
758 if backend.weight == 0 {
759 anyhow::bail!("backend '{}': weight must be > 0", backend.name);
760 }
761 }
762 }
763
764 Ok(())
765 }
766
767 pub fn resolve_env_vars(&mut self) {
770 for backend in &mut self.backends {
771 for value in backend.env.values_mut() {
772 if let Some(var_name) = value.strip_prefix("${").and_then(|s| s.strip_suffix('}'))
773 && let Ok(env_val) = std::env::var(var_name)
774 {
775 *value = env_val;
776 }
777 }
778 if let Some(ref mut token) = backend.bearer_token
779 && let Some(var_name) = token.strip_prefix("${").and_then(|s| s.strip_suffix('}'))
780 && let Ok(env_val) = std::env::var(var_name)
781 {
782 *token = env_val;
783 }
784 }
785 }
786}
787
#[cfg(test)]
mod tests {
    //! Unit tests for parsing, defaults, validation, env-var resolution and
    //! filter construction.

    use super::*;

    /// Smallest config that passes validation: one stdio backend, all defaults.
    fn minimal_config() -> &'static str {
        r#"
            [proxy]
            name = "test"
            [proxy.listen]

            [[backends]]
            name = "echo"
            transport = "stdio"
            command = "echo"
        "#
    }

    #[test]
    fn test_parse_minimal_config() {
        let config = ProxyConfig::parse(minimal_config()).unwrap();
        assert_eq!(config.proxy.name, "test");
        assert_eq!(config.proxy.version, "0.1.0");
        assert_eq!(config.proxy.separator, "/");
        assert_eq!(config.proxy.listen.host, "127.0.0.1");
        assert_eq!(config.proxy.listen.port, 8080);
        assert_eq!(config.proxy.shutdown_timeout_seconds, 30);
        assert!(!config.proxy.hot_reload);
        assert_eq!(config.backends.len(), 1);
        assert_eq!(config.backends[0].name, "echo");
        assert!(config.auth.is_none());
        assert!(!config.observability.audit);
        assert!(!config.observability.metrics.enabled);
    }

    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [proxy]
            name = "full-gw"
            version = "2.0.0"
            separator = "."
            shutdown_timeout_seconds = 60
            hot_reload = true
            instructions = "A test proxy"
            [proxy.listen]
            host = "0.0.0.0"
            port = 9090

            [[backends]]
            name = "files"
            transport = "stdio"
            command = "file-server"
            args = ["--root", "/tmp"]
            expose_tools = ["read_file"]

            [backends.env]
            LOG_LEVEL = "debug"

            [backends.timeout]
            seconds = 30

            [backends.concurrency]
            max_concurrent = 5

            [backends.rate_limit]
            requests = 100
            period_seconds = 10

            [backends.circuit_breaker]
            failure_rate_threshold = 0.5
            minimum_calls = 10
            wait_duration_seconds = 60
            permitted_calls_in_half_open = 2

            [backends.cache]
            resource_ttl_seconds = 300
            tool_ttl_seconds = 60
            max_entries = 500

            [[backends.aliases]]
            from = "read_file"
            to = "read"

            [[backends]]
            name = "remote"
            transport = "http"
            url = "http://localhost:3000"

            [observability]
            audit = true
            log_level = "debug"
            json_logs = true

            [observability.metrics]
            enabled = true

            [observability.tracing]
            enabled = true
            endpoint = "http://jaeger:4317"
            service_name = "test-gw"

            [performance]
            coalesce_requests = true

            [security]
            max_argument_size = 1048576
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        assert_eq!(config.proxy.name, "full-gw");
        assert_eq!(config.proxy.version, "2.0.0");
        assert_eq!(config.proxy.separator, ".");
        assert_eq!(config.proxy.shutdown_timeout_seconds, 60);
        assert!(config.proxy.hot_reload);
        assert_eq!(config.proxy.instructions.as_deref(), Some("A test proxy"));
        assert_eq!(config.proxy.listen.host, "0.0.0.0");
        assert_eq!(config.proxy.listen.port, 9090);

        assert_eq!(config.backends.len(), 2);

        let files = &config.backends[0];
        assert_eq!(files.command.as_deref(), Some("file-server"));
        assert_eq!(files.args, vec!["--root", "/tmp"]);
        assert_eq!(files.expose_tools, vec!["read_file"]);
        assert_eq!(files.env.get("LOG_LEVEL").unwrap(), "debug");
        assert_eq!(files.timeout.as_ref().unwrap().seconds, 30);
        assert_eq!(files.concurrency.as_ref().unwrap().max_concurrent, 5);
        assert_eq!(files.rate_limit.as_ref().unwrap().requests, 100);
        assert_eq!(files.cache.as_ref().unwrap().resource_ttl_seconds, 300);
        assert_eq!(files.cache.as_ref().unwrap().tool_ttl_seconds, 60);
        assert_eq!(files.cache.as_ref().unwrap().max_entries, 500);
        assert_eq!(files.aliases.len(), 1);
        assert_eq!(files.aliases[0].from, "read_file");
        assert_eq!(files.aliases[0].to, "read");

        let cb = files.circuit_breaker.as_ref().unwrap();
        assert_eq!(cb.failure_rate_threshold, 0.5);
        assert_eq!(cb.minimum_calls, 10);
        assert_eq!(cb.wait_duration_seconds, 60);
        assert_eq!(cb.permitted_calls_in_half_open, 2);

        let remote = &config.backends[1];
        assert_eq!(remote.url.as_deref(), Some("http://localhost:3000"));

        assert!(config.observability.audit);
        assert_eq!(config.observability.log_level, "debug");
        assert!(config.observability.json_logs);
        assert!(config.observability.metrics.enabled);
        assert!(config.observability.tracing.enabled);
        assert_eq!(config.observability.tracing.endpoint, "http://jaeger:4317");

        assert!(config.performance.coalesce_requests);
        assert_eq!(config.security.max_argument_size, Some(1048576));
    }

    #[test]
    fn test_parse_bearer_auth() {
        let toml = r#"
            [proxy]
            name = "auth-gw"
            [proxy.listen]

            [[backends]]
            name = "echo"
            transport = "stdio"
            command = "echo"

            [auth]
            type = "bearer"
            tokens = ["token-1", "token-2"]
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        match &config.auth {
            Some(AuthConfig::Bearer { tokens }) => {
                assert_eq!(tokens, &["token-1", "token-2"]);
            }
            other => panic!("expected Bearer auth, got: {:?}", other),
        }
    }

    #[test]
    fn test_parse_jwt_auth_with_rbac() {
        let toml = r#"
            [proxy]
            name = "jwt-gw"
            [proxy.listen]

            [[backends]]
            name = "echo"
            transport = "stdio"
            command = "echo"

            [auth]
            type = "jwt"
            issuer = "https://auth.example.com"
            audience = "mcp-proxy"
            jwks_uri = "https://auth.example.com/.well-known/jwks.json"

            [[auth.roles]]
            name = "reader"
            allow_tools = ["echo/read"]

            [[auth.roles]]
            name = "admin"

            [auth.role_mapping]
            claim = "scope"
            mapping = { "mcp:read" = "reader", "mcp:admin" = "admin" }
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        match &config.auth {
            Some(AuthConfig::Jwt {
                issuer,
                audience,
                jwks_uri,
                roles,
                role_mapping,
            }) => {
                assert_eq!(issuer, "https://auth.example.com");
                assert_eq!(audience, "mcp-proxy");
                assert_eq!(jwks_uri, "https://auth.example.com/.well-known/jwks.json");
                assert_eq!(roles.len(), 2);
                assert_eq!(roles[0].name, "reader");
                assert_eq!(roles[0].allow_tools, vec!["echo/read"]);
                let mapping = role_mapping.as_ref().unwrap();
                assert_eq!(mapping.claim, "scope");
                assert_eq!(mapping.mapping.get("mcp:read").unwrap(), "reader");
            }
            other => panic!("expected Jwt auth, got: {:?}", other),
        }
    }

    #[test]
    fn test_reject_no_backends() {
        let toml = r#"
            [proxy]
            name = "empty"
            [proxy.listen]
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("at least one backend"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_reject_stdio_without_command() {
        let toml = r#"
            [proxy]
            name = "bad"
            [proxy.listen]

            [[backends]]
            name = "broken"
            transport = "stdio"
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("stdio transport requires 'command'"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_reject_http_without_url() {
        let toml = r#"
            [proxy]
            name = "bad"
            [proxy.listen]

            [[backends]]
            name = "broken"
            transport = "http"
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("http transport requires 'url'"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_reject_invalid_circuit_breaker_threshold() {
        let toml = r#"
            [proxy]
            name = "bad"
            [proxy.listen]

            [[backends]]
            name = "svc"
            transport = "stdio"
            command = "echo"

            [backends.circuit_breaker]
            failure_rate_threshold = 1.5
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("failure_rate_threshold must be in (0.0, 1.0]"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_reject_zero_rate_limit() {
        let toml = r#"
            [proxy]
            name = "bad"
            [proxy.listen]

            [[backends]]
            name = "svc"
            transport = "stdio"
            command = "echo"

            [backends.rate_limit]
            requests = 0
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("rate_limit.requests must be > 0"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_reject_zero_concurrency() {
        let toml = r#"
            [proxy]
            name = "bad"
            [proxy.listen]

            [[backends]]
            name = "svc"
            transport = "stdio"
            command = "echo"

            [backends.concurrency]
            max_concurrent = 0
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("concurrency.max_concurrent must be > 0"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_reject_expose_and_hide_tools() {
        let toml = r#"
            [proxy]
            name = "bad"
            [proxy.listen]

            [[backends]]
            name = "svc"
            transport = "stdio"
            command = "echo"
            expose_tools = ["read"]
            hide_tools = ["write"]
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("cannot specify both expose_tools and hide_tools"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_reject_expose_and_hide_resources() {
        let toml = r#"
            [proxy]
            name = "bad"
            [proxy.listen]

            [[backends]]
            name = "svc"
            transport = "stdio"
            command = "echo"
            expose_resources = ["file:///a"]
            hide_resources = ["file:///b"]
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("cannot specify both expose_resources and hide_resources"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_reject_expose_and_hide_prompts() {
        let toml = r#"
            [proxy]
            name = "bad"
            [proxy.listen]

            [[backends]]
            name = "svc"
            transport = "stdio"
            command = "echo"
            expose_prompts = ["help"]
            hide_prompts = ["admin"]
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("cannot specify both expose_prompts and hide_prompts"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_resolve_env_vars() {
        // SAFETY: `set_var` is unsafe because it can race with any concurrent
        // read or write of the process environment. No other test uses this
        // variable name. NOTE(review): the default harness runs tests on
        // several threads, so soundness also relies on no other test touching
        // the environment at the same moment.
        unsafe { std::env::set_var("MCP_GW_TEST_TOKEN", "secret-123") };

        let toml = r#"
            [proxy]
            name = "env-test"
            [proxy.listen]

            [[backends]]
            name = "svc"
            transport = "stdio"
            command = "echo"

            [backends.env]
            API_TOKEN = "${MCP_GW_TEST_TOKEN}"
            STATIC_VAL = "unchanged"
        "#;

        let mut config = ProxyConfig::parse(toml).unwrap();
        config.resolve_env_vars();

        assert_eq!(
            config.backends[0].env.get("API_TOKEN").unwrap(),
            "secret-123"
        );
        assert_eq!(
            config.backends[0].env.get("STATIC_VAL").unwrap(),
            "unchanged"
        );

        // SAFETY: same reasoning as the `set_var` above.
        unsafe { std::env::remove_var("MCP_GW_TEST_TOKEN") };
    }

    #[test]
    fn test_parse_bearer_token_and_forward_auth() {
        let toml = r#"
            [proxy]
            name = "token-gw"
            [proxy.listen]

            [[backends]]
            name = "github"
            transport = "http"
            url = "http://localhost:3000"
            bearer_token = "ghp_abc123"
            forward_auth = true

            [[backends]]
            name = "db"
            transport = "http"
            url = "http://localhost:5432"
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        assert_eq!(
            config.backends[0].bearer_token.as_deref(),
            Some("ghp_abc123")
        );
        assert!(config.backends[0].forward_auth);
        assert!(config.backends[1].bearer_token.is_none());
        assert!(!config.backends[1].forward_auth);
    }

    #[test]
    fn test_resolve_bearer_token_env_var() {
        // SAFETY: `set_var` is unsafe because it can race with any concurrent
        // read or write of the process environment. No other test uses this
        // variable name. NOTE(review): the default harness runs tests on
        // several threads, so soundness also relies on no other test touching
        // the environment at the same moment.
        unsafe { std::env::set_var("MCP_GW_TEST_BEARER", "resolved-token") };

        let toml = r#"
            [proxy]
            name = "env-token"
            [proxy.listen]

            [[backends]]
            name = "api"
            transport = "http"
            url = "http://localhost:3000"
            bearer_token = "${MCP_GW_TEST_BEARER}"
        "#;

        let mut config = ProxyConfig::parse(toml).unwrap();
        config.resolve_env_vars();

        assert_eq!(
            config.backends[0].bearer_token.as_deref(),
            Some("resolved-token")
        );

        // SAFETY: same reasoning as the `set_var` above.
        unsafe { std::env::remove_var("MCP_GW_TEST_BEARER") };
    }

    #[test]
    fn test_parse_outlier_detection() {
        let toml = r#"
            [proxy]
            name = "od-gw"
            [proxy.listen]

            [[backends]]
            name = "flaky"
            transport = "http"
            url = "http://localhost:8080"

            [backends.outlier_detection]
            consecutive_errors = 3
            interval_seconds = 5
            base_ejection_seconds = 60
            max_ejection_percent = 25
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        let od = config.backends[0]
            .outlier_detection
            .as_ref()
            .expect("should have outlier_detection");
        assert_eq!(od.consecutive_errors, 3);
        assert_eq!(od.interval_seconds, 5);
        assert_eq!(od.base_ejection_seconds, 60);
        assert_eq!(od.max_ejection_percent, 25);
    }

    #[test]
    fn test_parse_outlier_detection_defaults() {
        let toml = r#"
            [proxy]
            name = "od-gw"
            [proxy.listen]

            [[backends]]
            name = "flaky"
            transport = "http"
            url = "http://localhost:8080"

            [backends.outlier_detection]
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        let od = config.backends[0]
            .outlier_detection
            .as_ref()
            .expect("should have outlier_detection");
        assert_eq!(od.consecutive_errors, 5);
        assert_eq!(od.interval_seconds, 10);
        assert_eq!(od.base_ejection_seconds, 30);
        assert_eq!(od.max_ejection_percent, 50);
    }

    #[test]
    fn test_parse_mirror_config() {
        let toml = r#"
            [proxy]
            name = "mirror-gw"
            [proxy.listen]

            [[backends]]
            name = "api"
            transport = "http"
            url = "http://localhost:8080"

            [[backends]]
            name = "api-v2"
            transport = "http"
            url = "http://localhost:8081"
            mirror_of = "api"
            mirror_percent = 10
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        assert!(config.backends[0].mirror_of.is_none());
        assert_eq!(config.backends[1].mirror_of.as_deref(), Some("api"));
        assert_eq!(config.backends[1].mirror_percent, 10);
    }

    #[test]
    fn test_mirror_percent_defaults_to_100() {
        let toml = r#"
            [proxy]
            name = "mirror-gw"
            [proxy.listen]

            [[backends]]
            name = "api"
            transport = "http"
            url = "http://localhost:8080"

            [[backends]]
            name = "api-v2"
            transport = "http"
            url = "http://localhost:8081"
            mirror_of = "api"
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        assert_eq!(config.backends[1].mirror_percent, 100);
    }

    #[test]
    fn test_reject_mirror_unknown_backend() {
        let toml = r#"
            [proxy]
            name = "bad"
            [proxy.listen]

            [[backends]]
            name = "api-v2"
            transport = "http"
            url = "http://localhost:8081"
            mirror_of = "nonexistent"
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("mirror_of references unknown backend"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_reject_mirror_self() {
        let toml = r#"
            [proxy]
            name = "bad"
            [proxy.listen]

            [[backends]]
            name = "api"
            transport = "http"
            url = "http://localhost:8080"
            mirror_of = "api"
        "#;

        let err = ProxyConfig::parse(toml).unwrap_err();
        assert!(
            format!("{err}").contains("mirror_of cannot reference itself"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_parse_hedging_config() {
        let toml = r#"
            [proxy]
            name = "hedge-gw"
            [proxy.listen]

            [[backends]]
            name = "api"
            transport = "http"
            url = "http://localhost:8080"

            [backends.hedging]
            delay_ms = 150
            max_hedges = 2
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        let hedge = config.backends[0]
            .hedging
            .as_ref()
            .expect("should have hedging");
        assert_eq!(hedge.delay_ms, 150);
        assert_eq!(hedge.max_hedges, 2);
    }

    #[test]
    fn test_parse_hedging_defaults() {
        let toml = r#"
            [proxy]
            name = "hedge-gw"
            [proxy.listen]

            [[backends]]
            name = "api"
            transport = "http"
            url = "http://localhost:8080"

            [backends.hedging]
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        let hedge = config.backends[0]
            .hedging
            .as_ref()
            .expect("should have hedging");
        assert_eq!(hedge.delay_ms, 200);
        assert_eq!(hedge.max_hedges, 1);
    }

    #[test]
    fn test_build_filter_allowlist() {
        let toml = r#"
            [proxy]
            name = "filter"
            [proxy.listen]

            [[backends]]
            name = "svc"
            transport = "stdio"
            command = "echo"
            expose_tools = ["read", "list"]
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        let filter = config.backends[0]
            .build_filter(&config.proxy.separator)
            .expect("should have filter");
        assert_eq!(filter.namespace, "svc/");
        assert!(filter.tool_filter.allows("read"));
        assert!(filter.tool_filter.allows("list"));
        assert!(!filter.tool_filter.allows("delete"));
    }

    #[test]
    fn test_build_filter_denylist() {
        let toml = r#"
            [proxy]
            name = "filter"
            [proxy.listen]

            [[backends]]
            name = "svc"
            transport = "stdio"
            command = "echo"
            hide_tools = ["delete", "write"]
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        let filter = config.backends[0]
            .build_filter(&config.proxy.separator)
            .expect("should have filter");
        assert!(filter.tool_filter.allows("read"));
        assert!(!filter.tool_filter.allows("delete"));
        assert!(!filter.tool_filter.allows("write"));
    }

    #[test]
    fn test_parse_inject_args() {
        let toml = r#"
            [proxy]
            name = "inject-gw"
            [proxy.listen]

            [[backends]]
            name = "db"
            transport = "http"
            url = "http://localhost:8080"

            [backends.default_args]
            timeout = 30

            [[backends.inject_args]]
            tool = "query"
            args = { read_only = true, max_rows = 1000 }

            [[backends.inject_args]]
            tool = "dangerous_op"
            args = { dry_run = true }
            overwrite = true
        "#;

        let config = ProxyConfig::parse(toml).unwrap();
        let backend = &config.backends[0];

        assert_eq!(backend.default_args.len(), 1);
        assert_eq!(backend.default_args["timeout"], 30);

        assert_eq!(backend.inject_args.len(), 2);
        assert_eq!(backend.inject_args[0].tool, "query");
        assert_eq!(backend.inject_args[0].args["read_only"], true);
        assert_eq!(backend.inject_args[0].args["max_rows"], 1000);
        assert!(!backend.inject_args[0].overwrite);

        assert_eq!(backend.inject_args[1].tool, "dangerous_op");
        assert_eq!(backend.inject_args[1].args["dry_run"], true);
        assert!(backend.inject_args[1].overwrite);
    }

    #[test]
    fn test_parse_inject_args_defaults_to_empty() {
        let config = ProxyConfig::parse(minimal_config()).unwrap();
        assert!(config.backends[0].default_args.is_empty());
        assert!(config.backends[0].inject_args.is_empty());
    }

    #[test]
    fn test_build_filter_none_when_no_filtering() {
        let config = ProxyConfig::parse(minimal_config()).unwrap();
        assert!(
            config.backends[0]
                .build_filter(&config.proxy.separator)
                .is_none()
        );
    }
}