Skip to main content

mcp_proxy/
config.rs

1//! Proxy configuration types and TOML parsing.
2
3use std::collections::HashMap;
4use std::collections::HashSet;
5use std::path::Path;
6
7use anyhow::{Context, Result};
8use serde::{Deserialize, Serialize};
9
10/// Top-level proxy configuration, typically loaded from a TOML file.
11#[derive(Debug, Deserialize, Serialize)]
12pub struct ProxyConfig {
13    /// Core proxy settings (name, version, listen address).
14    pub proxy: ProxySettings,
15    /// Backend MCP servers to proxy.
16    #[serde(default)]
17    pub backends: Vec<BackendConfig>,
18    /// Inbound authentication configuration.
19    pub auth: Option<AuthConfig>,
20    /// Performance tuning options.
21    #[serde(default)]
22    pub performance: PerformanceConfig,
23    /// Security policies.
24    #[serde(default)]
25    pub security: SecurityConfig,
26    /// Logging, metrics, and tracing configuration.
27    #[serde(default)]
28    pub observability: ObservabilityConfig,
29}
30
31/// Core proxy identity and server settings.
32#[derive(Debug, Deserialize, Serialize)]
33pub struct ProxySettings {
34    /// Proxy name, used in MCP server info.
35    pub name: String,
36    /// Proxy version, used in MCP server info (default: "0.1.0").
37    #[serde(default = "default_version")]
38    pub version: String,
39    /// Namespace separator between backend name and tool/resource name (default: "/").
40    #[serde(default = "default_separator")]
41    pub separator: String,
42    /// HTTP listen address.
43    pub listen: ListenConfig,
44    /// Optional instructions text sent to MCP clients.
45    pub instructions: Option<String>,
46    /// Graceful shutdown timeout in seconds (default: 30)
47    #[serde(default = "default_shutdown_timeout")]
48    pub shutdown_timeout_seconds: u64,
49    /// Enable hot reload: watch config file for new backends
50    #[serde(default)]
51    pub hot_reload: bool,
52}
53
54/// HTTP server listen address.
55#[derive(Debug, Deserialize, Serialize)]
56pub struct ListenConfig {
57    /// Bind host (default: "127.0.0.1").
58    #[serde(default = "default_host")]
59    pub host: String,
60    /// Bind port (default: 8080).
61    #[serde(default = "default_port")]
62    pub port: u16,
63}
64
65/// Configuration for a single backend MCP server.
66#[derive(Debug, Deserialize, Serialize)]
67pub struct BackendConfig {
68    /// Unique backend name, used as the namespace prefix for its tools/resources.
69    pub name: String,
70    /// Transport protocol to use when connecting to this backend.
71    pub transport: TransportType,
72    /// Command for stdio backends
73    pub command: Option<String>,
74    /// Arguments for stdio backends
75    #[serde(default)]
76    pub args: Vec<String>,
77    /// URL for HTTP backends
78    pub url: Option<String>,
79    /// Environment variables for subprocess backends
80    #[serde(default)]
81    pub env: HashMap<String, String>,
82    /// Per-backend timeout
83    pub timeout: Option<TimeoutConfig>,
84    /// Per-backend circuit breaker
85    pub circuit_breaker: Option<CircuitBreakerConfig>,
86    /// Per-backend rate limit
87    pub rate_limit: Option<RateLimitConfig>,
88    /// Per-backend concurrency limit
89    pub concurrency: Option<ConcurrencyConfig>,
90    /// Per-backend retry policy
91    pub retry: Option<RetryConfig>,
92    /// Per-backend outlier detection (passive health checks)
93    pub outlier_detection: Option<OutlierDetectionConfig>,
94    /// Per-backend request hedging (parallel redundant requests)
95    pub hedging: Option<HedgingConfig>,
96    /// Mirror traffic from another backend (fire-and-forget).
97    /// Set to the name of the source backend to mirror.
98    pub mirror_of: Option<String>,
99    /// Percentage of requests to mirror (1-100, default: 100).
100    #[serde(default = "default_mirror_percent")]
101    pub mirror_percent: u32,
102    /// Per-backend cache policy
103    pub cache: Option<BackendCacheConfig>,
104    /// Static bearer token for authenticating to this backend (HTTP only).
105    /// Supports `${ENV_VAR}` syntax for env var resolution.
106    pub bearer_token: Option<String>,
107    /// Forward the client's inbound auth token to this backend.
108    /// Only works with HTTP backends when the proxy has auth enabled.
109    #[serde(default)]
110    pub forward_auth: bool,
111    /// Tool aliases: rename tools exposed by this backend
112    #[serde(default)]
113    pub aliases: Vec<AliasConfig>,
114    /// Default arguments injected into all tool calls for this backend.
115    /// Merged into tool call arguments (does not overwrite existing keys).
116    #[serde(default)]
117    pub default_args: serde_json::Map<String, serde_json::Value>,
118    /// Per-tool argument injection rules.
119    #[serde(default)]
120    pub inject_args: Vec<InjectArgsConfig>,
121    /// Capability filtering: only expose these tools (allowlist)
122    #[serde(default)]
123    pub expose_tools: Vec<String>,
124    /// Capability filtering: hide these tools (denylist)
125    #[serde(default)]
126    pub hide_tools: Vec<String>,
127    /// Capability filtering: only expose these resources (allowlist, by URI)
128    #[serde(default)]
129    pub expose_resources: Vec<String>,
130    /// Capability filtering: hide these resources (denylist, by URI)
131    #[serde(default)]
132    pub hide_resources: Vec<String>,
133    /// Capability filtering: only expose these prompts (allowlist)
134    #[serde(default)]
135    pub expose_prompts: Vec<String>,
136    /// Capability filtering: hide these prompts (denylist)
137    #[serde(default)]
138    pub hide_prompts: Vec<String>,
139    /// Canary routing: name of the primary backend this is a canary for.
140    /// When set, this backend's tools are hidden and requests targeting
141    /// the primary are probabilistically routed here based on weight.
142    pub canary_of: Option<String>,
143    /// Routing weight for canary deployments (default: 100).
144    /// Higher values receive proportionally more traffic.
145    #[serde(default = "default_weight")]
146    pub weight: u32,
147}
148
149/// Backend transport protocol.
150#[derive(Debug, Deserialize, Serialize)]
151#[serde(rename_all = "lowercase")]
152pub enum TransportType {
153    /// Subprocess communicating via stdin/stdout.
154    Stdio,
155    /// HTTP+SSE remote server.
156    Http,
157}
158
159/// Per-backend request timeout.
160#[derive(Debug, Deserialize, Serialize)]
161pub struct TimeoutConfig {
162    /// Timeout duration in seconds.
163    pub seconds: u64,
164}
165
166/// Per-backend circuit breaker configuration.
167#[derive(Debug, Deserialize, Serialize)]
168pub struct CircuitBreakerConfig {
169    /// Failure rate threshold (0.0-1.0) to trip open (default: 0.5)
170    #[serde(default = "default_failure_rate")]
171    pub failure_rate_threshold: f64,
172    /// Minimum number of calls before evaluating failure rate (default: 5)
173    #[serde(default = "default_min_calls")]
174    pub minimum_calls: usize,
175    /// Seconds to wait in open state before half-open (default: 30)
176    #[serde(default = "default_wait_duration")]
177    pub wait_duration_seconds: u64,
178    /// Number of permitted calls in half-open state (default: 3)
179    #[serde(default = "default_half_open_calls")]
180    pub permitted_calls_in_half_open: usize,
181}
182
183/// Per-backend rate limiting configuration.
184#[derive(Debug, Deserialize, Serialize)]
185pub struct RateLimitConfig {
186    /// Maximum requests per period
187    pub requests: usize,
188    /// Period in seconds (default: 1)
189    #[serde(default = "default_rate_period")]
190    pub period_seconds: u64,
191}
192
193/// Per-backend concurrency limit configuration.
194#[derive(Debug, Deserialize, Serialize)]
195pub struct ConcurrencyConfig {
196    /// Maximum concurrent requests.
197    pub max_concurrent: usize,
198}
199
200/// Per-backend retry policy with exponential backoff.
201#[derive(Debug, Clone, Deserialize, Serialize)]
202pub struct RetryConfig {
203    /// Maximum number of retry attempts (default: 3)
204    #[serde(default = "default_max_retries")]
205    pub max_retries: u32,
206    /// Initial backoff in milliseconds (default: 100)
207    #[serde(default = "default_initial_backoff_ms")]
208    pub initial_backoff_ms: u64,
209    /// Maximum backoff in milliseconds (default: 5000)
210    #[serde(default = "default_max_backoff_ms")]
211    pub max_backoff_ms: u64,
212    /// Maximum percentage of requests that can be retries (default: none / unlimited).
213    /// When set, prevents retry storms by capping retries as a fraction of total
214    /// request volume. Envoy uses 20% as a default. Evaluated over a 10-second
215    /// rolling window.
216    pub budget_percent: Option<f64>,
217    /// Minimum retries per second allowed regardless of budget (default: 10).
218    /// Ensures low-traffic backends can still retry.
219    #[serde(default = "default_min_retries_per_sec")]
220    pub min_retries_per_sec: u32,
221}
222
223/// Passive health check / outlier detection configuration.
224///
225/// Tracks consecutive errors on live traffic and ejects unhealthy backends.
226#[derive(Debug, Clone, Deserialize, Serialize)]
227pub struct OutlierDetectionConfig {
228    /// Number of consecutive errors before ejecting (default: 5)
229    #[serde(default = "default_consecutive_errors")]
230    pub consecutive_errors: u32,
231    /// Evaluation interval in seconds (default: 10)
232    #[serde(default = "default_interval_seconds")]
233    pub interval_seconds: u64,
234    /// How long to eject in seconds (default: 30)
235    #[serde(default = "default_base_ejection_seconds")]
236    pub base_ejection_seconds: u64,
237    /// Maximum percentage of backends that can be ejected (default: 50)
238    #[serde(default = "default_max_ejection_percent")]
239    pub max_ejection_percent: u32,
240}
241
242/// Per-tool argument injection configuration.
243#[derive(Debug, Clone, Deserialize, Serialize)]
244pub struct InjectArgsConfig {
245    /// Tool name (backend-local, without namespace prefix).
246    pub tool: String,
247    /// Arguments to inject. Merged into the tool call arguments.
248    /// Does not overwrite existing keys unless `overwrite` is true.
249    pub args: serde_json::Map<String, serde_json::Value>,
250    /// Whether injected args should overwrite existing values (default: false).
251    #[serde(default)]
252    pub overwrite: bool,
253}
254
255/// Request hedging configuration.
256///
257/// Sends parallel redundant requests to reduce tail latency. If the primary
258/// request hasn't completed after `delay_ms`, a hedge request is fired.
259/// The first successful response wins.
260#[derive(Debug, Clone, Deserialize, Serialize)]
261pub struct HedgingConfig {
262    /// Delay in milliseconds before sending a hedge request (default: 200).
263    /// Set to 0 for parallel mode (all requests fire immediately).
264    #[serde(default = "default_hedge_delay_ms")]
265    pub delay_ms: u64,
266    /// Maximum number of additional hedge requests (default: 1)
267    #[serde(default = "default_max_hedges")]
268    pub max_hedges: usize,
269}
270
271/// Inbound authentication configuration.
272#[derive(Debug, Deserialize, Serialize)]
273#[serde(tag = "type", rename_all = "lowercase")]
274pub enum AuthConfig {
275    /// Static bearer token authentication.
276    Bearer {
277        /// Accepted bearer tokens.
278        tokens: Vec<String>,
279    },
280    /// JWT authentication via JWKS endpoint.
281    Jwt {
282        /// Expected token issuer (`iss` claim).
283        issuer: String,
284        /// Expected token audience (`aud` claim).
285        audience: String,
286        /// URL to fetch the JSON Web Key Set for token verification.
287        jwks_uri: String,
288        /// RBAC role definitions
289        #[serde(default)]
290        roles: Vec<RoleConfig>,
291        /// Map JWT claims to roles
292        role_mapping: Option<RoleMappingConfig>,
293    },
294}
295
296/// RBAC role definition.
297#[derive(Debug, Deserialize, Serialize)]
298pub struct RoleConfig {
299    /// Role name, referenced by `RoleMappingConfig`.
300    pub name: String,
301    /// Tools this role can access (namespaced, e.g. "files/read_file")
302    #[serde(default)]
303    pub allow_tools: Vec<String>,
304    /// Tools this role cannot access
305    #[serde(default)]
306    pub deny_tools: Vec<String>,
307}
308
309/// Maps JWT claim values to RBAC role names.
310#[derive(Debug, Deserialize, Serialize)]
311pub struct RoleMappingConfig {
312    /// JWT claim to read for role resolution (e.g. "scope", "role", "groups")
313    pub claim: String,
314    /// Map claim values to role names
315    pub mapping: HashMap<String, String>,
316}
317
318/// Tool alias: exposes a backend tool under a different name.
319#[derive(Debug, Deserialize, Serialize)]
320pub struct AliasConfig {
321    /// Original tool name (backend-local, without namespace prefix)
322    pub from: String,
323    /// New tool name to expose (will be namespaced as backend/to)
324    pub to: String,
325}
326
327/// Per-backend response cache configuration.
328#[derive(Debug, Deserialize, Serialize)]
329pub struct BackendCacheConfig {
330    /// TTL for cached resource reads in seconds (0 = disabled)
331    #[serde(default)]
332    pub resource_ttl_seconds: u64,
333    /// TTL for cached tool call results in seconds (0 = disabled)
334    #[serde(default)]
335    pub tool_ttl_seconds: u64,
336    /// Maximum number of cached entries per backend (default: 1000)
337    #[serde(default = "default_max_cache_entries")]
338    pub max_entries: u64,
339}
340
341/// Performance tuning options.
342#[derive(Debug, Default, Deserialize, Serialize)]
343pub struct PerformanceConfig {
344    /// Deduplicate identical concurrent tool calls and resource reads
345    #[serde(default)]
346    pub coalesce_requests: bool,
347}
348
349/// Security policies.
350#[derive(Debug, Default, Deserialize, Serialize)]
351pub struct SecurityConfig {
352    /// Maximum size of tool call arguments in bytes (default: unlimited)
353    pub max_argument_size: Option<usize>,
354}
355
356/// Logging, metrics, and distributed tracing configuration.
357#[derive(Debug, Default, Deserialize, Serialize)]
358pub struct ObservabilityConfig {
359    /// Enable audit logging of all MCP requests (default: false).
360    #[serde(default)]
361    pub audit: bool,
362    /// Log level filter (default: "info").
363    #[serde(default = "default_log_level")]
364    pub log_level: String,
365    /// Emit structured JSON logs (default: false).
366    #[serde(default)]
367    pub json_logs: bool,
368    /// Prometheus metrics configuration.
369    #[serde(default)]
370    pub metrics: MetricsConfig,
371    /// OpenTelemetry distributed tracing configuration.
372    #[serde(default)]
373    pub tracing: TracingConfig,
374}
375
376/// Prometheus metrics configuration.
377#[derive(Debug, Default, Deserialize, Serialize)]
378pub struct MetricsConfig {
379    /// Enable Prometheus metrics at `/admin/metrics` (default: false).
380    #[serde(default)]
381    pub enabled: bool,
382}
383
384/// OpenTelemetry distributed tracing configuration.
385#[derive(Debug, Default, Deserialize, Serialize)]
386pub struct TracingConfig {
387    /// Enable OTLP trace export (default: false).
388    #[serde(default)]
389    pub enabled: bool,
390    /// OTLP endpoint (default: http://localhost:4317)
391    #[serde(default = "default_otlp_endpoint")]
392    pub endpoint: String,
393    /// Service name for traces (default: "mcp-proxy")
394    #[serde(default = "default_service_name")]
395    pub service_name: String,
396}
397
// Defaults
//
// Each function below is named in a `#[serde(default = "...")]` attribute;
// serde calls it when the corresponding key is missing from the TOML.

/// `proxy.version`
fn default_version() -> String {
    String::from("0.1.0")
}

/// `proxy.separator`
fn default_separator() -> String {
    String::from("/")
}

/// `proxy.listen.host`
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// `proxy.listen.port`
fn default_port() -> u16 {
    8080
}

/// `observability.log_level`
fn default_log_level() -> String {
    String::from("info")
}

/// `circuit_breaker.failure_rate_threshold`
fn default_failure_rate() -> f64 {
    0.5
}

/// `circuit_breaker.minimum_calls`
fn default_min_calls() -> usize {
    5
}

/// `circuit_breaker.wait_duration_seconds`
fn default_wait_duration() -> u64 {
    30
}

/// `circuit_breaker.permitted_calls_in_half_open`
fn default_half_open_calls() -> usize {
    3
}

/// `rate_limit.period_seconds`
fn default_rate_period() -> u64 {
    1
}

/// `retry.max_retries`
fn default_max_retries() -> u32 {
    3
}

/// `retry.initial_backoff_ms`
fn default_initial_backoff_ms() -> u64 {
    100
}

/// `retry.max_backoff_ms`
fn default_max_backoff_ms() -> u64 {
    5000
}

/// `retry.min_retries_per_sec`
fn default_min_retries_per_sec() -> u32 {
    10
}

/// `outlier_detection.consecutive_errors`
fn default_consecutive_errors() -> u32 {
    5
}

/// `outlier_detection.interval_seconds`
fn default_interval_seconds() -> u64 {
    10
}

/// `outlier_detection.base_ejection_seconds`
fn default_base_ejection_seconds() -> u64 {
    30
}

/// `outlier_detection.max_ejection_percent`
fn default_max_ejection_percent() -> u32 {
    50
}

/// `hedging.delay_ms`
fn default_hedge_delay_ms() -> u64 {
    200
}

/// `hedging.max_hedges`
fn default_max_hedges() -> usize {
    1
}

/// `backends.mirror_percent`
fn default_mirror_percent() -> u32 {
    100
}

/// `backends.weight`
fn default_weight() -> u32 {
    100
}

/// `cache.max_entries`
fn default_max_cache_entries() -> u64 {
    1000
}

/// `proxy.shutdown_timeout_seconds`
fn default_shutdown_timeout() -> u64 {
    30
}

/// `observability.tracing.endpoint`
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4317")
}

/// `observability.tracing.service_name`
fn default_service_name() -> String {
    String::from("mcp-proxy")
}
503
/// Capability filters for one backend, resolved from its expose/hide lists.
#[derive(Debug, Clone)]
pub struct BackendFilter {
    /// Backend name followed by the separator (e.g. "db/"); identifies which
    /// capabilities this filter governs.
    pub namespace: String,
    /// Applied to tool names.
    pub tool_filter: NameFilter,
    /// Applied to resource URIs.
    pub resource_filter: NameFilter,
    /// Applied to prompt names.
    pub prompt_filter: NameFilter,
}

/// Allow/deny decision over a set of capability names.
#[derive(Debug, Clone)]
pub enum NameFilter {
    /// Every name is allowed.
    PassAll,
    /// Only names in the set are allowed; an empty set allows nothing.
    AllowList(HashSet<String>),
    /// Every name except those in the set is allowed.
    DenyList(HashSet<String>),
}

impl NameFilter {
    /// Return `true` when `name` passes this filter.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::collections::HashSet;
    /// use mcp_proxy::config::NameFilter;
    ///
    /// let filter = NameFilter::DenyList(["delete".to_string()].into());
    /// assert!(filter.allows("read"));
    /// assert!(!filter.allows("delete"));
    ///
    /// let filter = NameFilter::AllowList(["read".to_string()].into());
    /// assert!(filter.allows("read"));
    /// assert!(!filter.allows("write"));
    ///
    /// assert!(NameFilter::PassAll.allows("anything"));
    /// ```
    pub fn allows(&self, name: &str) -> bool {
        // For list variants, set membership decides; the variant only says
        // whether being listed means "allowed" or "denied".
        let (listed, listed_means_allowed) = match self {
            Self::PassAll => return true,
            Self::AllowList(listed) => (listed, true),
            Self::DenyList(listed) => (listed, false),
        };
        listed.contains(name) == listed_means_allowed
    }
}
555
556impl BackendConfig {
557    /// Build a [`BackendFilter`] from this backend's expose/hide lists.
558    /// Returns `None` if no filtering is configured.
559    ///
560    /// Canary backends automatically hide all capabilities so their tools
561    /// don't appear in `ListTools` responses (traffic reaches them via the
562    /// canary routing middleware, not direct tool calls).
563    pub fn build_filter(&self, separator: &str) -> Option<BackendFilter> {
564        // Canary backends hide all capabilities -- tools are accessed via
565        // the canary routing middleware rewriting the primary namespace.
566        if self.canary_of.is_some() {
567            return Some(BackendFilter {
568                namespace: format!("{}{}", self.name, separator),
569                tool_filter: NameFilter::AllowList(HashSet::new()),
570                resource_filter: NameFilter::AllowList(HashSet::new()),
571                prompt_filter: NameFilter::AllowList(HashSet::new()),
572            });
573        }
574
575        let tool_filter = if !self.expose_tools.is_empty() {
576            NameFilter::AllowList(self.expose_tools.iter().cloned().collect())
577        } else if !self.hide_tools.is_empty() {
578            NameFilter::DenyList(self.hide_tools.iter().cloned().collect())
579        } else {
580            NameFilter::PassAll
581        };
582
583        let resource_filter = if !self.expose_resources.is_empty() {
584            NameFilter::AllowList(self.expose_resources.iter().cloned().collect())
585        } else if !self.hide_resources.is_empty() {
586            NameFilter::DenyList(self.hide_resources.iter().cloned().collect())
587        } else {
588            NameFilter::PassAll
589        };
590
591        let prompt_filter = if !self.expose_prompts.is_empty() {
592            NameFilter::AllowList(self.expose_prompts.iter().cloned().collect())
593        } else if !self.hide_prompts.is_empty() {
594            NameFilter::DenyList(self.hide_prompts.iter().cloned().collect())
595        } else {
596            NameFilter::PassAll
597        };
598
599        // Only create a filter if at least one dimension has filtering
600        if matches!(tool_filter, NameFilter::PassAll)
601            && matches!(resource_filter, NameFilter::PassAll)
602            && matches!(prompt_filter, NameFilter::PassAll)
603        {
604            return None;
605        }
606
607        Some(BackendFilter {
608            namespace: format!("{}{}", self.name, separator),
609            tool_filter,
610            resource_filter,
611            prompt_filter,
612        })
613    }
614}
615
616impl ProxyConfig {
617    /// Load and validate a config from a file path.
618    pub fn load(path: &Path) -> Result<Self> {
619        let content =
620            std::fs::read_to_string(path).with_context(|| format!("reading {}", path.display()))?;
621        let config: Self =
622            toml::from_str(&content).with_context(|| format!("parsing {}", path.display()))?;
623        config.validate()?;
624        Ok(config)
625    }
626
627    /// Parse and validate a config from a TOML string.
628    ///
629    /// # Examples
630    ///
631    /// ```
632    /// use mcp_proxy::ProxyConfig;
633    ///
634    /// let config = ProxyConfig::parse(r#"
635    ///     [proxy]
636    ///     name = "my-proxy"
637    ///     [proxy.listen]
638    ///
639    ///     [[backends]]
640    ///     name = "echo"
641    ///     transport = "stdio"
642    ///     command = "echo"
643    /// "#).unwrap();
644    ///
645    /// assert_eq!(config.proxy.name, "my-proxy");
646    /// assert_eq!(config.backends.len(), 1);
647    /// ```
648    pub fn parse(toml: &str) -> Result<Self> {
649        let config: Self = toml::from_str(toml).context("parsing config")?;
650        config.validate()?;
651        Ok(config)
652    }
653
654    fn validate(&self) -> Result<()> {
655        if self.backends.is_empty() {
656            anyhow::bail!("at least one backend is required");
657        }
658        for backend in &self.backends {
659            match backend.transport {
660                TransportType::Stdio => {
661                    if backend.command.is_none() {
662                        anyhow::bail!(
663                            "backend '{}': stdio transport requires 'command'",
664                            backend.name
665                        );
666                    }
667                }
668                TransportType::Http => {
669                    if backend.url.is_none() {
670                        anyhow::bail!("backend '{}': http transport requires 'url'", backend.name);
671                    }
672                }
673            }
674
675            if let Some(cb) = &backend.circuit_breaker
676                && (cb.failure_rate_threshold <= 0.0 || cb.failure_rate_threshold > 1.0)
677            {
678                anyhow::bail!(
679                    "backend '{}': circuit_breaker.failure_rate_threshold must be in (0.0, 1.0]",
680                    backend.name
681                );
682            }
683
684            if let Some(rl) = &backend.rate_limit
685                && rl.requests == 0
686            {
687                anyhow::bail!(
688                    "backend '{}': rate_limit.requests must be > 0",
689                    backend.name
690                );
691            }
692
693            if let Some(cc) = &backend.concurrency
694                && cc.max_concurrent == 0
695            {
696                anyhow::bail!(
697                    "backend '{}': concurrency.max_concurrent must be > 0",
698                    backend.name
699                );
700            }
701
702            if !backend.expose_tools.is_empty() && !backend.hide_tools.is_empty() {
703                anyhow::bail!(
704                    "backend '{}': cannot specify both expose_tools and hide_tools",
705                    backend.name
706                );
707            }
708            if !backend.expose_resources.is_empty() && !backend.hide_resources.is_empty() {
709                anyhow::bail!(
710                    "backend '{}': cannot specify both expose_resources and hide_resources",
711                    backend.name
712                );
713            }
714            if !backend.expose_prompts.is_empty() && !backend.hide_prompts.is_empty() {
715                anyhow::bail!(
716                    "backend '{}': cannot specify both expose_prompts and hide_prompts",
717                    backend.name
718                );
719            }
720        }
721
722        // Validate mirror_of references
723        let backend_names: HashSet<&str> = self.backends.iter().map(|b| b.name.as_str()).collect();
724        for backend in &self.backends {
725            if let Some(ref source) = backend.mirror_of {
726                if !backend_names.contains(source.as_str()) {
727                    anyhow::bail!(
728                        "backend '{}': mirror_of references unknown backend '{}'",
729                        backend.name,
730                        source
731                    );
732                }
733                if source == &backend.name {
734                    anyhow::bail!(
735                        "backend '{}': mirror_of cannot reference itself",
736                        backend.name
737                    );
738                }
739            }
740        }
741
742        // Validate canary_of references
743        for backend in &self.backends {
744            if let Some(ref primary) = backend.canary_of {
745                if !backend_names.contains(primary.as_str()) {
746                    anyhow::bail!(
747                        "backend '{}': canary_of references unknown backend '{}'",
748                        backend.name,
749                        primary
750                    );
751                }
752                if primary == &backend.name {
753                    anyhow::bail!(
754                        "backend '{}': canary_of cannot reference itself",
755                        backend.name
756                    );
757                }
758                if backend.weight == 0 {
759                    anyhow::bail!("backend '{}': weight must be > 0", backend.name);
760                }
761            }
762        }
763
764        Ok(())
765    }
766
767    /// Resolve environment variable references in config values.
768    /// Replaces `${VAR_NAME}` with the value of the environment variable.
769    pub fn resolve_env_vars(&mut self) {
770        for backend in &mut self.backends {
771            for value in backend.env.values_mut() {
772                if let Some(var_name) = value.strip_prefix("${").and_then(|s| s.strip_suffix('}'))
773                    && let Ok(env_val) = std::env::var(var_name)
774                {
775                    *value = env_val;
776                }
777            }
778            if let Some(ref mut token) = backend.bearer_token
779                && let Some(var_name) = token.strip_prefix("${").and_then(|s| s.strip_suffix('}'))
780                && let Ok(env_val) = std::env::var(var_name)
781            {
782                *token = env_val;
783            }
784        }
785    }
786}
787
788#[cfg(test)]
789mod tests {
790    use super::*;
791
792    fn minimal_config() -> &'static str {
793        r#"
794        [proxy]
795        name = "test"
796        [proxy.listen]
797
798        [[backends]]
799        name = "echo"
800        transport = "stdio"
801        command = "echo"
802        "#
803    }
804
805    #[test]
806    fn test_parse_minimal_config() {
807        let config = ProxyConfig::parse(minimal_config()).unwrap();
808        assert_eq!(config.proxy.name, "test");
809        assert_eq!(config.proxy.version, "0.1.0"); // default
810        assert_eq!(config.proxy.separator, "/"); // default
811        assert_eq!(config.proxy.listen.host, "127.0.0.1"); // default
812        assert_eq!(config.proxy.listen.port, 8080); // default
813        assert_eq!(config.proxy.shutdown_timeout_seconds, 30); // default
814        assert!(!config.proxy.hot_reload); // default false
815        assert_eq!(config.backends.len(), 1);
816        assert_eq!(config.backends[0].name, "echo");
817        assert!(config.auth.is_none());
818        assert!(!config.observability.audit);
819        assert!(!config.observability.metrics.enabled);
820    }
821
822    #[test]
823    fn test_parse_full_config() {
824        let toml = r#"
825        [proxy]
826        name = "full-gw"
827        version = "2.0.0"
828        separator = "."
829        shutdown_timeout_seconds = 60
830        hot_reload = true
831        instructions = "A test proxy"
832        [proxy.listen]
833        host = "0.0.0.0"
834        port = 9090
835
836        [[backends]]
837        name = "files"
838        transport = "stdio"
839        command = "file-server"
840        args = ["--root", "/tmp"]
841        expose_tools = ["read_file"]
842
843        [backends.env]
844        LOG_LEVEL = "debug"
845
846        [backends.timeout]
847        seconds = 30
848
849        [backends.concurrency]
850        max_concurrent = 5
851
852        [backends.rate_limit]
853        requests = 100
854        period_seconds = 10
855
856        [backends.circuit_breaker]
857        failure_rate_threshold = 0.5
858        minimum_calls = 10
859        wait_duration_seconds = 60
860        permitted_calls_in_half_open = 2
861
862        [backends.cache]
863        resource_ttl_seconds = 300
864        tool_ttl_seconds = 60
865        max_entries = 500
866
867        [[backends.aliases]]
868        from = "read_file"
869        to = "read"
870
871        [[backends]]
872        name = "remote"
873        transport = "http"
874        url = "http://localhost:3000"
875
876        [observability]
877        audit = true
878        log_level = "debug"
879        json_logs = true
880
881        [observability.metrics]
882        enabled = true
883
884        [observability.tracing]
885        enabled = true
886        endpoint = "http://jaeger:4317"
887        service_name = "test-gw"
888
889        [performance]
890        coalesce_requests = true
891
892        [security]
893        max_argument_size = 1048576
894        "#;
895
896        let config = ProxyConfig::parse(toml).unwrap();
897        assert_eq!(config.proxy.name, "full-gw");
898        assert_eq!(config.proxy.version, "2.0.0");
899        assert_eq!(config.proxy.separator, ".");
900        assert_eq!(config.proxy.shutdown_timeout_seconds, 60);
901        assert!(config.proxy.hot_reload);
902        assert_eq!(config.proxy.instructions.as_deref(), Some("A test proxy"));
903        assert_eq!(config.proxy.listen.host, "0.0.0.0");
904        assert_eq!(config.proxy.listen.port, 9090);
905
906        assert_eq!(config.backends.len(), 2);
907
908        let files = &config.backends[0];
909        assert_eq!(files.command.as_deref(), Some("file-server"));
910        assert_eq!(files.args, vec!["--root", "/tmp"]);
911        assert_eq!(files.expose_tools, vec!["read_file"]);
912        assert_eq!(files.env.get("LOG_LEVEL").unwrap(), "debug");
913        assert_eq!(files.timeout.as_ref().unwrap().seconds, 30);
914        assert_eq!(files.concurrency.as_ref().unwrap().max_concurrent, 5);
915        assert_eq!(files.rate_limit.as_ref().unwrap().requests, 100);
916        assert_eq!(files.cache.as_ref().unwrap().resource_ttl_seconds, 300);
917        assert_eq!(files.cache.as_ref().unwrap().tool_ttl_seconds, 60);
918        assert_eq!(files.cache.as_ref().unwrap().max_entries, 500);
919        assert_eq!(files.aliases.len(), 1);
920        assert_eq!(files.aliases[0].from, "read_file");
921        assert_eq!(files.aliases[0].to, "read");
922
923        let cb = files.circuit_breaker.as_ref().unwrap();
924        assert_eq!(cb.failure_rate_threshold, 0.5);
925        assert_eq!(cb.minimum_calls, 10);
926        assert_eq!(cb.wait_duration_seconds, 60);
927        assert_eq!(cb.permitted_calls_in_half_open, 2);
928
929        let remote = &config.backends[1];
930        assert_eq!(remote.url.as_deref(), Some("http://localhost:3000"));
931
932        assert!(config.observability.audit);
933        assert_eq!(config.observability.log_level, "debug");
934        assert!(config.observability.json_logs);
935        assert!(config.observability.metrics.enabled);
936        assert!(config.observability.tracing.enabled);
937        assert_eq!(config.observability.tracing.endpoint, "http://jaeger:4317");
938
939        assert!(config.performance.coalesce_requests);
940        assert_eq!(config.security.max_argument_size, Some(1048576));
941    }
942
943    #[test]
944    fn test_parse_bearer_auth() {
945        let toml = r#"
946        [proxy]
947        name = "auth-gw"
948        [proxy.listen]
949
950        [[backends]]
951        name = "echo"
952        transport = "stdio"
953        command = "echo"
954
955        [auth]
956        type = "bearer"
957        tokens = ["token-1", "token-2"]
958        "#;
959
960        let config = ProxyConfig::parse(toml).unwrap();
961        match &config.auth {
962            Some(AuthConfig::Bearer { tokens }) => {
963                assert_eq!(tokens, &["token-1", "token-2"]);
964            }
965            other => panic!("expected Bearer auth, got: {:?}", other),
966        }
967    }
968
969    #[test]
970    fn test_parse_jwt_auth_with_rbac() {
971        let toml = r#"
972        [proxy]
973        name = "jwt-gw"
974        [proxy.listen]
975
976        [[backends]]
977        name = "echo"
978        transport = "stdio"
979        command = "echo"
980
981        [auth]
982        type = "jwt"
983        issuer = "https://auth.example.com"
984        audience = "mcp-proxy"
985        jwks_uri = "https://auth.example.com/.well-known/jwks.json"
986
987        [[auth.roles]]
988        name = "reader"
989        allow_tools = ["echo/read"]
990
991        [[auth.roles]]
992        name = "admin"
993
994        [auth.role_mapping]
995        claim = "scope"
996        mapping = { "mcp:read" = "reader", "mcp:admin" = "admin" }
997        "#;
998
999        let config = ProxyConfig::parse(toml).unwrap();
1000        match &config.auth {
1001            Some(AuthConfig::Jwt {
1002                issuer,
1003                audience,
1004                jwks_uri,
1005                roles,
1006                role_mapping,
1007            }) => {
1008                assert_eq!(issuer, "https://auth.example.com");
1009                assert_eq!(audience, "mcp-proxy");
1010                assert_eq!(jwks_uri, "https://auth.example.com/.well-known/jwks.json");
1011                assert_eq!(roles.len(), 2);
1012                assert_eq!(roles[0].name, "reader");
1013                assert_eq!(roles[0].allow_tools, vec!["echo/read"]);
1014                let mapping = role_mapping.as_ref().unwrap();
1015                assert_eq!(mapping.claim, "scope");
1016                assert_eq!(mapping.mapping.get("mcp:read").unwrap(), "reader");
1017            }
1018            other => panic!("expected Jwt auth, got: {:?}", other),
1019        }
1020    }
1021
1022    // ========================================================================
1023    // Validation errors
1024    // ========================================================================
1025
1026    #[test]
1027    fn test_reject_no_backends() {
1028        let toml = r#"
1029        [proxy]
1030        name = "empty"
1031        [proxy.listen]
1032        "#;
1033
1034        let err = ProxyConfig::parse(toml).unwrap_err();
1035        assert!(
1036            format!("{err}").contains("at least one backend"),
1037            "unexpected error: {err}"
1038        );
1039    }
1040
1041    #[test]
1042    fn test_reject_stdio_without_command() {
1043        let toml = r#"
1044        [proxy]
1045        name = "bad"
1046        [proxy.listen]
1047
1048        [[backends]]
1049        name = "broken"
1050        transport = "stdio"
1051        "#;
1052
1053        let err = ProxyConfig::parse(toml).unwrap_err();
1054        assert!(
1055            format!("{err}").contains("stdio transport requires 'command'"),
1056            "unexpected error: {err}"
1057        );
1058    }
1059
1060    #[test]
1061    fn test_reject_http_without_url() {
1062        let toml = r#"
1063        [proxy]
1064        name = "bad"
1065        [proxy.listen]
1066
1067        [[backends]]
1068        name = "broken"
1069        transport = "http"
1070        "#;
1071
1072        let err = ProxyConfig::parse(toml).unwrap_err();
1073        assert!(
1074            format!("{err}").contains("http transport requires 'url'"),
1075            "unexpected error: {err}"
1076        );
1077    }
1078
1079    #[test]
1080    fn test_reject_invalid_circuit_breaker_threshold() {
1081        let toml = r#"
1082        [proxy]
1083        name = "bad"
1084        [proxy.listen]
1085
1086        [[backends]]
1087        name = "svc"
1088        transport = "stdio"
1089        command = "echo"
1090
1091        [backends.circuit_breaker]
1092        failure_rate_threshold = 1.5
1093        "#;
1094
1095        let err = ProxyConfig::parse(toml).unwrap_err();
1096        assert!(
1097            format!("{err}").contains("failure_rate_threshold must be in (0.0, 1.0]"),
1098            "unexpected error: {err}"
1099        );
1100    }
1101
1102    #[test]
1103    fn test_reject_zero_rate_limit() {
1104        let toml = r#"
1105        [proxy]
1106        name = "bad"
1107        [proxy.listen]
1108
1109        [[backends]]
1110        name = "svc"
1111        transport = "stdio"
1112        command = "echo"
1113
1114        [backends.rate_limit]
1115        requests = 0
1116        "#;
1117
1118        let err = ProxyConfig::parse(toml).unwrap_err();
1119        assert!(
1120            format!("{err}").contains("rate_limit.requests must be > 0"),
1121            "unexpected error: {err}"
1122        );
1123    }
1124
1125    #[test]
1126    fn test_reject_zero_concurrency() {
1127        let toml = r#"
1128        [proxy]
1129        name = "bad"
1130        [proxy.listen]
1131
1132        [[backends]]
1133        name = "svc"
1134        transport = "stdio"
1135        command = "echo"
1136
1137        [backends.concurrency]
1138        max_concurrent = 0
1139        "#;
1140
1141        let err = ProxyConfig::parse(toml).unwrap_err();
1142        assert!(
1143            format!("{err}").contains("concurrency.max_concurrent must be > 0"),
1144            "unexpected error: {err}"
1145        );
1146    }
1147
1148    #[test]
1149    fn test_reject_expose_and_hide_tools() {
1150        let toml = r#"
1151        [proxy]
1152        name = "bad"
1153        [proxy.listen]
1154
1155        [[backends]]
1156        name = "svc"
1157        transport = "stdio"
1158        command = "echo"
1159        expose_tools = ["read"]
1160        hide_tools = ["write"]
1161        "#;
1162
1163        let err = ProxyConfig::parse(toml).unwrap_err();
1164        assert!(
1165            format!("{err}").contains("cannot specify both expose_tools and hide_tools"),
1166            "unexpected error: {err}"
1167        );
1168    }
1169
1170    #[test]
1171    fn test_reject_expose_and_hide_resources() {
1172        let toml = r#"
1173        [proxy]
1174        name = "bad"
1175        [proxy.listen]
1176
1177        [[backends]]
1178        name = "svc"
1179        transport = "stdio"
1180        command = "echo"
1181        expose_resources = ["file:///a"]
1182        hide_resources = ["file:///b"]
1183        "#;
1184
1185        let err = ProxyConfig::parse(toml).unwrap_err();
1186        assert!(
1187            format!("{err}").contains("cannot specify both expose_resources and hide_resources"),
1188            "unexpected error: {err}"
1189        );
1190    }
1191
1192    #[test]
1193    fn test_reject_expose_and_hide_prompts() {
1194        let toml = r#"
1195        [proxy]
1196        name = "bad"
1197        [proxy.listen]
1198
1199        [[backends]]
1200        name = "svc"
1201        transport = "stdio"
1202        command = "echo"
1203        expose_prompts = ["help"]
1204        hide_prompts = ["admin"]
1205        "#;
1206
1207        let err = ProxyConfig::parse(toml).unwrap_err();
1208        assert!(
1209            format!("{err}").contains("cannot specify both expose_prompts and hide_prompts"),
1210            "unexpected error: {err}"
1211        );
1212    }
1213
1214    // ========================================================================
1215    // Env var resolution
1216    // ========================================================================
1217
1218    #[test]
1219    fn test_resolve_env_vars() {
1220        // SAFETY: test runs single-threaded, no other threads reading this var
1221        unsafe { std::env::set_var("MCP_GW_TEST_TOKEN", "secret-123") };
1222
1223        let toml = r#"
1224        [proxy]
1225        name = "env-test"
1226        [proxy.listen]
1227
1228        [[backends]]
1229        name = "svc"
1230        transport = "stdio"
1231        command = "echo"
1232
1233        [backends.env]
1234        API_TOKEN = "${MCP_GW_TEST_TOKEN}"
1235        STATIC_VAL = "unchanged"
1236        "#;
1237
1238        let mut config = ProxyConfig::parse(toml).unwrap();
1239        config.resolve_env_vars();
1240
1241        assert_eq!(
1242            config.backends[0].env.get("API_TOKEN").unwrap(),
1243            "secret-123"
1244        );
1245        assert_eq!(
1246            config.backends[0].env.get("STATIC_VAL").unwrap(),
1247            "unchanged"
1248        );
1249
1250        // SAFETY: same as above
1251        unsafe { std::env::remove_var("MCP_GW_TEST_TOKEN") };
1252    }
1253
1254    #[test]
1255    fn test_parse_bearer_token_and_forward_auth() {
1256        let toml = r#"
1257        [proxy]
1258        name = "token-gw"
1259        [proxy.listen]
1260
1261        [[backends]]
1262        name = "github"
1263        transport = "http"
1264        url = "http://localhost:3000"
1265        bearer_token = "ghp_abc123"
1266        forward_auth = true
1267
1268        [[backends]]
1269        name = "db"
1270        transport = "http"
1271        url = "http://localhost:5432"
1272        "#;
1273
1274        let config = ProxyConfig::parse(toml).unwrap();
1275        assert_eq!(
1276            config.backends[0].bearer_token.as_deref(),
1277            Some("ghp_abc123")
1278        );
1279        assert!(config.backends[0].forward_auth);
1280        assert!(config.backends[1].bearer_token.is_none());
1281        assert!(!config.backends[1].forward_auth);
1282    }
1283
1284    #[test]
1285    fn test_resolve_bearer_token_env_var() {
1286        unsafe { std::env::set_var("MCP_GW_TEST_BEARER", "resolved-token") };
1287
1288        let toml = r#"
1289        [proxy]
1290        name = "env-token"
1291        [proxy.listen]
1292
1293        [[backends]]
1294        name = "api"
1295        transport = "http"
1296        url = "http://localhost:3000"
1297        bearer_token = "${MCP_GW_TEST_BEARER}"
1298        "#;
1299
1300        let mut config = ProxyConfig::parse(toml).unwrap();
1301        config.resolve_env_vars();
1302
1303        assert_eq!(
1304            config.backends[0].bearer_token.as_deref(),
1305            Some("resolved-token")
1306        );
1307
1308        unsafe { std::env::remove_var("MCP_GW_TEST_BEARER") };
1309    }
1310
1311    #[test]
1312    fn test_parse_outlier_detection() {
1313        let toml = r#"
1314        [proxy]
1315        name = "od-gw"
1316        [proxy.listen]
1317
1318        [[backends]]
1319        name = "flaky"
1320        transport = "http"
1321        url = "http://localhost:8080"
1322
1323        [backends.outlier_detection]
1324        consecutive_errors = 3
1325        interval_seconds = 5
1326        base_ejection_seconds = 60
1327        max_ejection_percent = 25
1328        "#;
1329
1330        let config = ProxyConfig::parse(toml).unwrap();
1331        let od = config.backends[0]
1332            .outlier_detection
1333            .as_ref()
1334            .expect("should have outlier_detection");
1335        assert_eq!(od.consecutive_errors, 3);
1336        assert_eq!(od.interval_seconds, 5);
1337        assert_eq!(od.base_ejection_seconds, 60);
1338        assert_eq!(od.max_ejection_percent, 25);
1339    }
1340
1341    #[test]
1342    fn test_parse_outlier_detection_defaults() {
1343        let toml = r#"
1344        [proxy]
1345        name = "od-gw"
1346        [proxy.listen]
1347
1348        [[backends]]
1349        name = "flaky"
1350        transport = "http"
1351        url = "http://localhost:8080"
1352
1353        [backends.outlier_detection]
1354        "#;
1355
1356        let config = ProxyConfig::parse(toml).unwrap();
1357        let od = config.backends[0]
1358            .outlier_detection
1359            .as_ref()
1360            .expect("should have outlier_detection");
1361        assert_eq!(od.consecutive_errors, 5);
1362        assert_eq!(od.interval_seconds, 10);
1363        assert_eq!(od.base_ejection_seconds, 30);
1364        assert_eq!(od.max_ejection_percent, 50);
1365    }
1366
1367    #[test]
1368    fn test_parse_mirror_config() {
1369        let toml = r#"
1370        [proxy]
1371        name = "mirror-gw"
1372        [proxy.listen]
1373
1374        [[backends]]
1375        name = "api"
1376        transport = "http"
1377        url = "http://localhost:8080"
1378
1379        [[backends]]
1380        name = "api-v2"
1381        transport = "http"
1382        url = "http://localhost:8081"
1383        mirror_of = "api"
1384        mirror_percent = 10
1385        "#;
1386
1387        let config = ProxyConfig::parse(toml).unwrap();
1388        assert!(config.backends[0].mirror_of.is_none());
1389        assert_eq!(config.backends[1].mirror_of.as_deref(), Some("api"));
1390        assert_eq!(config.backends[1].mirror_percent, 10);
1391    }
1392
1393    #[test]
1394    fn test_mirror_percent_defaults_to_100() {
1395        let toml = r#"
1396        [proxy]
1397        name = "mirror-gw"
1398        [proxy.listen]
1399
1400        [[backends]]
1401        name = "api"
1402        transport = "http"
1403        url = "http://localhost:8080"
1404
1405        [[backends]]
1406        name = "api-v2"
1407        transport = "http"
1408        url = "http://localhost:8081"
1409        mirror_of = "api"
1410        "#;
1411
1412        let config = ProxyConfig::parse(toml).unwrap();
1413        assert_eq!(config.backends[1].mirror_percent, 100);
1414    }
1415
1416    #[test]
1417    fn test_reject_mirror_unknown_backend() {
1418        let toml = r#"
1419        [proxy]
1420        name = "bad"
1421        [proxy.listen]
1422
1423        [[backends]]
1424        name = "api-v2"
1425        transport = "http"
1426        url = "http://localhost:8081"
1427        mirror_of = "nonexistent"
1428        "#;
1429
1430        let err = ProxyConfig::parse(toml).unwrap_err();
1431        assert!(
1432            format!("{err}").contains("mirror_of references unknown backend"),
1433            "unexpected error: {err}"
1434        );
1435    }
1436
1437    #[test]
1438    fn test_reject_mirror_self() {
1439        let toml = r#"
1440        [proxy]
1441        name = "bad"
1442        [proxy.listen]
1443
1444        [[backends]]
1445        name = "api"
1446        transport = "http"
1447        url = "http://localhost:8080"
1448        mirror_of = "api"
1449        "#;
1450
1451        let err = ProxyConfig::parse(toml).unwrap_err();
1452        assert!(
1453            format!("{err}").contains("mirror_of cannot reference itself"),
1454            "unexpected error: {err}"
1455        );
1456    }
1457
1458    #[test]
1459    fn test_parse_hedging_config() {
1460        let toml = r#"
1461        [proxy]
1462        name = "hedge-gw"
1463        [proxy.listen]
1464
1465        [[backends]]
1466        name = "api"
1467        transport = "http"
1468        url = "http://localhost:8080"
1469
1470        [backends.hedging]
1471        delay_ms = 150
1472        max_hedges = 2
1473        "#;
1474
1475        let config = ProxyConfig::parse(toml).unwrap();
1476        let hedge = config.backends[0]
1477            .hedging
1478            .as_ref()
1479            .expect("should have hedging");
1480        assert_eq!(hedge.delay_ms, 150);
1481        assert_eq!(hedge.max_hedges, 2);
1482    }
1483
1484    #[test]
1485    fn test_parse_hedging_defaults() {
1486        let toml = r#"
1487        [proxy]
1488        name = "hedge-gw"
1489        [proxy.listen]
1490
1491        [[backends]]
1492        name = "api"
1493        transport = "http"
1494        url = "http://localhost:8080"
1495
1496        [backends.hedging]
1497        "#;
1498
1499        let config = ProxyConfig::parse(toml).unwrap();
1500        let hedge = config.backends[0]
1501            .hedging
1502            .as_ref()
1503            .expect("should have hedging");
1504        assert_eq!(hedge.delay_ms, 200);
1505        assert_eq!(hedge.max_hedges, 1);
1506    }
1507
1508    // ========================================================================
1509    // Capability filter building
1510    // ========================================================================
1511
1512    #[test]
1513    fn test_build_filter_allowlist() {
1514        let toml = r#"
1515        [proxy]
1516        name = "filter"
1517        [proxy.listen]
1518
1519        [[backends]]
1520        name = "svc"
1521        transport = "stdio"
1522        command = "echo"
1523        expose_tools = ["read", "list"]
1524        "#;
1525
1526        let config = ProxyConfig::parse(toml).unwrap();
1527        let filter = config.backends[0]
1528            .build_filter(&config.proxy.separator)
1529            .expect("should have filter");
1530        assert_eq!(filter.namespace, "svc/");
1531        assert!(filter.tool_filter.allows("read"));
1532        assert!(filter.tool_filter.allows("list"));
1533        assert!(!filter.tool_filter.allows("delete"));
1534    }
1535
1536    #[test]
1537    fn test_build_filter_denylist() {
1538        let toml = r#"
1539        [proxy]
1540        name = "filter"
1541        [proxy.listen]
1542
1543        [[backends]]
1544        name = "svc"
1545        transport = "stdio"
1546        command = "echo"
1547        hide_tools = ["delete", "write"]
1548        "#;
1549
1550        let config = ProxyConfig::parse(toml).unwrap();
1551        let filter = config.backends[0]
1552            .build_filter(&config.proxy.separator)
1553            .expect("should have filter");
1554        assert!(filter.tool_filter.allows("read"));
1555        assert!(!filter.tool_filter.allows("delete"));
1556        assert!(!filter.tool_filter.allows("write"));
1557    }
1558
1559    #[test]
1560    fn test_parse_inject_args() {
1561        let toml = r#"
1562        [proxy]
1563        name = "inject-gw"
1564        [proxy.listen]
1565
1566        [[backends]]
1567        name = "db"
1568        transport = "http"
1569        url = "http://localhost:8080"
1570
1571        [backends.default_args]
1572        timeout = 30
1573
1574        [[backends.inject_args]]
1575        tool = "query"
1576        args = { read_only = true, max_rows = 1000 }
1577
1578        [[backends.inject_args]]
1579        tool = "dangerous_op"
1580        args = { dry_run = true }
1581        overwrite = true
1582        "#;
1583
1584        let config = ProxyConfig::parse(toml).unwrap();
1585        let backend = &config.backends[0];
1586
1587        assert_eq!(backend.default_args.len(), 1);
1588        assert_eq!(backend.default_args["timeout"], 30);
1589
1590        assert_eq!(backend.inject_args.len(), 2);
1591        assert_eq!(backend.inject_args[0].tool, "query");
1592        assert_eq!(backend.inject_args[0].args["read_only"], true);
1593        assert_eq!(backend.inject_args[0].args["max_rows"], 1000);
1594        assert!(!backend.inject_args[0].overwrite);
1595
1596        assert_eq!(backend.inject_args[1].tool, "dangerous_op");
1597        assert_eq!(backend.inject_args[1].args["dry_run"], true);
1598        assert!(backend.inject_args[1].overwrite);
1599    }
1600
1601    #[test]
1602    fn test_parse_inject_args_defaults_to_empty() {
1603        let config = ProxyConfig::parse(minimal_config()).unwrap();
1604        assert!(config.backends[0].default_args.is_empty());
1605        assert!(config.backends[0].inject_args.is_empty());
1606    }
1607
1608    #[test]
1609    fn test_build_filter_none_when_no_filtering() {
1610        let config = ProxyConfig::parse(minimal_config()).unwrap();
1611        assert!(
1612            config.backends[0]
1613                .build_filter(&config.proxy.separator)
1614                .is_none()
1615        );
1616    }
1617}