Skip to main content

mcp_proxy/
proxy.rs

1//! Core proxy construction and serving.
2
3use std::convert::Infallible;
4use std::time::Duration;
5
6use anyhow::{Context, Result};
7use axum::Router;
8use tokio::process::Command;
9use tower::timeout::TimeoutLayer;
10use tower::util::BoxCloneService;
11use tower_mcp::SessionHandle;
12use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
13use tower_mcp::client::StdioClientTransport;
14use tower_mcp::proxy::McpProxy;
15use tower_mcp::{RouterRequest, RouterResponse};
16
17use crate::admin::BackendMeta;
18use crate::alias;
19use crate::cache;
20use crate::coalesce;
21use crate::config::{AuthConfig, ProxyConfig, TransportType};
22use crate::filter::CapabilityFilterService;
23#[cfg(feature = "oauth")]
24use crate::rbac::{RbacConfig, RbacService};
25use crate::validation::{ValidationConfig, ValidationService};
26
/// A fully constructed MCP proxy ready to serve or embed.
pub struct Proxy {
    // Axum router with the MCP HTTP transport, inbound auth middleware,
    // and the /admin API already mounted.
    router: Router,
    // Handle for observing/managing active MCP sessions.
    session_handle: SessionHandle,
    // Underlying proxy handle, exposed via mcp_proxy() for dynamic
    // backend operations and used by hot reload.
    inner: McpProxy,
    // Retained for serve() (listen address, shutdown timeout) and for
    // hot reload (separator).
    config: ProxyConfig,
    // Shared tool-discovery index; Some only when discovery (or search
    // mode, which implies it) was enabled at build time.
    #[cfg(feature = "discovery")]
    discovery_index: Option<crate::discovery::SharedDiscoveryIndex>,
}
36
impl Proxy {
    /// Build a proxy from a [`ProxyConfig`].
    ///
    /// Connects to all backends, builds the middleware stack, and prepares
    /// the axum router. Call [`serve()`](Self::serve) to run standalone or
    /// [`into_router()`](Self::into_router) to embed in an existing app.
    ///
    /// # Errors
    ///
    /// Fails if a backend cannot be connected, the Prometheus recorder
    /// cannot be installed, the middleware stack cannot be built, or the
    /// auth configuration is invalid.
    pub async fn from_config(config: ProxyConfig) -> Result<Self> {
        let mcp_proxy = build_mcp_proxy(&config).await?;
        // Separate handles for the health checker, discovery/admin-tool
        // registration, and the /admin management router.
        // NOTE(review): registering admin tools on a clone is assumed to be
        // visible through the served proxy, i.e. clones share backend
        // state — confirm against tower_mcp's McpProxy semantics.
        let proxy_for_admin = mcp_proxy.clone();
        let mut proxy_for_caller = mcp_proxy.clone();
        let proxy_for_management = mcp_proxy.clone();

        // Install Prometheus metrics recorder (must happen before middleware)
        #[cfg(feature = "metrics")]
        let metrics_handle = if config.observability.metrics.enabled {
            tracing::info!("Prometheus metrics enabled at /admin/metrics");
            let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
            let handle = builder
                .install_recorder()
                .context("installing Prometheus metrics recorder")?;
            Some(handle)
        } else {
            None
        };
        // Without the metrics feature there is never a recorder handle.
        #[cfg(not(feature = "metrics"))]
        let metrics_handle = None;

        let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;

        // Wrap the boxed MCP service in the HTTP transport; the returned
        // session handle tracks connections made through this router.
        let (router, session_handle) =
            tower_mcp::transport::http::HttpTransport::from_service(service)
                .into_router_with_handle();

        // Inbound authentication (axum-level middleware)
        let router = apply_auth(&config, router).await?;

        // Collect backend metadata for the health checker
        let backend_meta: std::collections::HashMap<String, BackendMeta> = config
            .backends
            .iter()
            .map(|b| {
                (
                    b.name.clone(),
                    BackendMeta {
                        // Debug-render the transport enum, lowercased
                        // (e.g. "stdio", "http") for display in the admin API.
                        transport: format!("{:?}", b.transport).to_lowercase(),
                    },
                )
            })
            .collect();

        // Admin API
        let admin_state = crate::admin::spawn_health_checker(
            proxy_for_admin,
            config.proxy.name.clone(),
            config.proxy.version.clone(),
            config.backends.len(),
            backend_meta,
        );
        let router = router.nest(
            "/admin",
            crate::admin::admin_router(
                admin_state.clone(),
                metrics_handle,
                session_handle.clone(),
                cache_handle,
                proxy_for_management,
                &config,
            ),
        );
        tracing::info!("Admin API enabled at /admin/backends");

        // Build discovery index if enabled (search mode implies discovery)
        #[cfg(feature = "discovery")]
        let discovery_enabled = config.proxy.tool_discovery
            || config.proxy.tool_exposure == crate::config::ToolExposure::Search;
        #[cfg(feature = "discovery")]
        let (discovery_index, discovery_tools) = if discovery_enabled {
            let index =
                crate::discovery::build_index(&mut proxy_for_caller, &config.proxy.separator).await;
            let tools = crate::discovery::build_discovery_tools(index.clone());
            (Some(index), Some(tools))
        } else {
            (None, None)
        };
        #[cfg(not(feature = "discovery"))]
        let discovery_tools: Option<Vec<tower_mcp::Tool>> = None;

        // MCP admin tools (proxy/ namespace)
        // A failure here is deliberately non-fatal: the proxy still serves
        // backend tools without the proxy/ management tools.
        if let Err(e) = crate::admin_tools::register_admin_tools(
            &proxy_for_caller,
            admin_state,
            session_handle.clone(),
            &config,
            discovery_tools,
        )
        .await
        {
            tracing::warn!("Failed to register admin tools: {e}");
        } else {
            tracing::info!("MCP admin tools registered under proxy/ namespace");
        }

        Ok(Self {
            router,
            session_handle,
            inner: proxy_for_caller,
            config,
            #[cfg(feature = "discovery")]
            discovery_index,
        })
    }

    /// Get a reference to the session handle for monitoring active sessions.
    pub fn session_handle(&self) -> &SessionHandle {
        &self.session_handle
    }

    /// Get a reference to the underlying [`McpProxy`] for dynamic operations.
    ///
    /// Use this to add backends dynamically via [`McpProxy::add_backend()`].
    pub fn mcp_proxy(&self) -> &McpProxy {
        &self.inner
    }

    /// Enable hot reload by watching the given config file path.
    ///
    /// New backends added to the config file will be connected dynamically
    /// without restarting the proxy.
    // NOTE(review): spawn_config_watcher presumably runs on a background
    // task (this method returns immediately) — see crate::reload.
    pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
        tracing::info!("Hot reload enabled, watching config file for changes");
        crate::reload::spawn_config_watcher(
            config_path,
            self.inner.clone(),
            // The watcher also refreshes the discovery index (if built),
            // which needs the separator to namespace new tools.
            #[cfg(feature = "discovery")]
            self.discovery_index
                .as_ref()
                .map(|idx| (idx.clone(), self.config.proxy.separator.clone())),
        );
    }

    /// Consume the proxy and return the axum Router and SessionHandle.
    ///
    /// Use this to embed the proxy in an existing axum application:
    ///
    /// ```rust,ignore
    /// let (proxy_router, session_handle) = proxy.into_router();
    ///
    /// let app = Router::new()
    ///     .nest("/mcp", proxy_router)
    ///     .route("/health", get(|| async { "ok" }));
    /// ```
    pub fn into_router(self) -> (Router, SessionHandle) {
        (self.router, self.session_handle)
    }

    /// Serve the proxy on the configured listen address.
    ///
    /// Blocks until a shutdown signal (SIGTERM/SIGINT) is received,
    /// then drains connections for the configured timeout period.
    ///
    /// # Errors
    ///
    /// Fails if the listen address cannot be bound or the server
    /// terminates with an error.
    pub async fn serve(self) -> Result<()> {
        let addr = format!(
            "{}:{}",
            self.config.proxy.listen.host, self.config.proxy.listen.port
        );

        tracing::info!(listen = %addr, "Proxy ready");

        let listener = tokio::net::TcpListener::bind(&addr)
            .await
            .with_context(|| format!("binding to {}", addr))?;

        let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
        // shutdown_signal is defined elsewhere in this file; presumably it
        // resolves on SIGTERM/SIGINT and bounds draining by the timeout.
        axum::serve(listener, self.router)
            .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
            .await
            .context("server error")?;

        tracing::info!("Proxy shut down");
        Ok(())
    }
}
218
219/// Build the McpProxy with all backends and per-backend middleware.
220async fn build_mcp_proxy(config: &ProxyConfig) -> Result<McpProxy> {
221    let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
222        .separator(&config.proxy.separator);
223
224    if let Some(instructions) = &config.proxy.instructions {
225        builder = builder.instructions(instructions);
226    }
227
228    // Create shared outlier detector if any backend has outlier_detection configured.
229    // Use the max of all max_ejection_percent values.
230    let outlier_detector = {
231        let max_pct = config
232            .backends
233            .iter()
234            .filter_map(|b| b.outlier_detection.as_ref())
235            .map(|od| od.max_ejection_percent)
236            .max();
237        max_pct.map(crate::outlier::OutlierDetector::new)
238    };
239
240    for backend in &config.backends {
241        tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");
242
243        match backend.transport {
244            TransportType::Stdio => {
245                let command = backend.command.as_deref().unwrap();
246                let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
247
248                let mut cmd = Command::new(command);
249                cmd.args(&args);
250
251                for (key, value) in &backend.env {
252                    cmd.env(key, value);
253                }
254
255                let transport = StdioClientTransport::spawn_command(&mut cmd)
256                    .await
257                    .with_context(|| format!("spawning backend '{}'", backend.name))?;
258
259                builder = builder.backend(&backend.name, transport).await;
260            }
261            TransportType::Http => {
262                let url = backend.url.as_deref().unwrap();
263                let mut transport = tower_mcp::client::HttpClientTransport::new(url);
264                if let Some(token) = &backend.bearer_token {
265                    transport = transport.bearer_token(token);
266                }
267
268                builder = builder.backend(&backend.name, transport).await;
269            }
270            #[cfg(feature = "websocket")]
271            TransportType::Websocket => {
272                let url = backend.url.as_deref().unwrap();
273                tracing::info!(url = %url, "Connecting to WebSocket backend");
274                let transport = if let Some(token) = &backend.bearer_token {
275                    crate::ws_transport::WebSocketClientTransport::connect_with_bearer_token(
276                        url, token,
277                    )
278                    .await
279                    .with_context(|| {
280                        format!("connecting to WebSocket backend '{}'", backend.name)
281                    })?
282                } else {
283                    crate::ws_transport::WebSocketClientTransport::connect(url)
284                        .await
285                        .with_context(|| {
286                            format!("connecting to WebSocket backend '{}'", backend.name)
287                        })?
288                };
289
290                builder = builder.backend(&backend.name, transport).await;
291            }
292            #[cfg(not(feature = "websocket"))]
293            TransportType::Websocket => {
294                anyhow::bail!(
295                    "WebSocket transport requires the 'websocket' feature. \
296                     Rebuild with: cargo install mcp-proxy --features websocket"
297                );
298            }
299        }
300
301        // Per-backend middleware stack (applied in order: inner -> outer)
302
303        // Retry (innermost -- retries happen before other middleware)
304        if let Some(retry_cfg) = &backend.retry {
305            tracing::info!(
306                backend = %backend.name,
307                max_retries = retry_cfg.max_retries,
308                initial_backoff_ms = retry_cfg.initial_backoff_ms,
309                max_backoff_ms = retry_cfg.max_backoff_ms,
310                "Applying retry policy"
311            );
312            let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
313            builder = builder.backend_layer(layer);
314        }
315
316        // Hedging (after retry, before concurrency -- hedges are separate requests)
317        if let Some(hedge_cfg) = &backend.hedging {
318            let delay = Duration::from_millis(hedge_cfg.delay_ms);
319            let max_attempts = hedge_cfg.max_hedges + 1; // +1 for the primary request
320            tracing::info!(
321                backend = %backend.name,
322                delay_ms = hedge_cfg.delay_ms,
323                max_hedges = hedge_cfg.max_hedges,
324                "Applying request hedging"
325            );
326            let layer = if delay.is_zero() {
327                tower_resilience::hedge::HedgeLayer::builder()
328                    .no_delay()
329                    .max_hedged_attempts(max_attempts)
330                    .name(format!("{}-hedge", backend.name))
331                    .build()
332            } else {
333                tower_resilience::hedge::HedgeLayer::builder()
334                    .delay(delay)
335                    .max_hedged_attempts(max_attempts)
336                    .name(format!("{}-hedge", backend.name))
337                    .build()
338            };
339            builder = builder.backend_layer(layer);
340        }
341
342        // Concurrency limit
343        if let Some(cc) = &backend.concurrency {
344            tracing::info!(
345                backend = %backend.name,
346                max = cc.max_concurrent,
347                "Applying concurrency limit"
348            );
349            builder =
350                builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
351        }
352
353        // Rate limit
354        if let Some(rl) = &backend.rate_limit {
355            tracing::info!(
356                backend = %backend.name,
357                requests = rl.requests,
358                period_seconds = rl.period_seconds,
359                "Applying rate limit"
360            );
361            let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
362                .limit_for_period(rl.requests)
363                .refresh_period(Duration::from_secs(rl.period_seconds))
364                .name(format!("{}-ratelimit", backend.name))
365                .build();
366            builder = builder.backend_layer(layer);
367        }
368
369        // Timeout
370        if let Some(timeout) = &backend.timeout {
371            tracing::info!(
372                backend = %backend.name,
373                seconds = timeout.seconds,
374                "Applying timeout"
375            );
376            builder =
377                builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
378        }
379
380        // Circuit breaker
381        if let Some(cb) = &backend.circuit_breaker {
382            tracing::info!(
383                backend = %backend.name,
384                failure_rate = cb.failure_rate_threshold,
385                wait_seconds = cb.wait_duration_seconds,
386                "Applying circuit breaker"
387            );
388            let layer = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
389                .failure_rate_threshold(cb.failure_rate_threshold)
390                .minimum_number_of_calls(cb.minimum_calls)
391                .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
392                .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
393                .name(format!("{}-cb", backend.name))
394                .build();
395            builder = builder.backend_layer(layer);
396        }
397
398        // Outlier detection (outermost -- observes errors after all other middleware)
399        if let Some(od) = &backend.outlier_detection
400            && let Some(ref detector) = outlier_detector
401        {
402            tracing::info!(
403                backend = %backend.name,
404                consecutive_errors = od.consecutive_errors,
405                base_ejection_seconds = od.base_ejection_seconds,
406                max_ejection_percent = od.max_ejection_percent,
407                "Applying outlier detection"
408            );
409            let layer = crate::outlier::OutlierDetectionLayer::new(
410                backend.name.clone(),
411                od.clone(),
412                detector.clone(),
413            );
414            builder = builder.backend_layer(layer);
415        }
416    }
417
418    let result = builder.build().await?;
419
420    if !result.skipped.is_empty() {
421        for s in &result.skipped {
422            tracing::warn!("Skipped backend: {s}");
423        }
424    }
425
426    Ok(result.proxy)
427}
428
/// Build the MCP-level middleware stack around the proxy.
///
/// Each enabled feature below re-wraps `service`, so sections listed
/// earlier end up innermost (closest to the proxy) and later sections
/// see requests first. Returns the boxed service together with a
/// [`cache::CacheHandle`] when response caching is configured (used by
/// the admin API to inspect/flush the cache).
fn build_middleware_stack(
    config: &ProxyConfig,
    proxy: McpProxy,
) -> Result<(
    BoxCloneService<RouterRequest, RouterResponse, Infallible>,
    Option<cache::CacheHandle>,
)> {
    // Start from the proxy itself; every block below conditionally re-wraps this.
    let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
        BoxCloneService::new(proxy);
    let mut cache_handle: Option<cache::CacheHandle> = None;

    // Argument injection (innermost -- merges default/per-tool args into CallTool requests)
    let injection_rules: Vec<_> = config
        .backends
        .iter()
        .filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
        .map(|b| {
            // Tools are namespaced as "<backend><separator><tool>".
            let namespace = format!("{}{}", b.name, config.proxy.separator);
            tracing::info!(
                backend = %b.name,
                default_args = b.default_args.len(),
                tool_rules = b.inject_args.len(),
                "Applying argument injection"
            );
            crate::inject::InjectionRules::new(
                namespace,
                b.default_args.clone(),
                b.inject_args.clone(),
            )
        })
        .collect();

    if !injection_rules.is_empty() {
        service = BoxCloneService::new(crate::inject::InjectArgsService::new(
            service,
            injection_rules,
        ));
    }

    // Parameter overrides (after inject, before filter -- hides/renames tool params)
    let param_overrides: Vec<_> = config
        .backends
        .iter()
        .filter(|b| !b.param_overrides.is_empty())
        .flat_map(|b| {
            let namespace = format!("{}{}", b.name, config.proxy.separator);
            // Logged once per backend; the inner map emits one override per rule.
            tracing::info!(
                backend = %b.name,
                overrides = b.param_overrides.len(),
                "Applying parameter overrides"
            );
            b.param_overrides
                .iter()
                .map(move |c| crate::param_override::ToolOverride::new(&namespace, c))
        })
        .collect();

    if !param_overrides.is_empty() {
        service = BoxCloneService::new(crate::param_override::ParamOverrideService::new(
            service,
            param_overrides,
        ));
    }

    // Canary routing (rewrites requests from primary to canary namespace based on weight)
    // Map: primary backend name -> (canary name, primary weight, canary weight).
    let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.canary_of.as_ref().map(|primary_name| {
                // Find the primary backend's weight
                let primary_weight = config
                    .backends
                    .iter()
                    .find(|p| p.name == *primary_name)
                    .map(|p| p.weight)
                    .unwrap_or(100);
                (
                    primary_name.clone(),
                    (b.name.clone(), primary_weight, b.weight),
                )
            })
        })
        .collect();

    if !canary_mappings.is_empty() {
        for (primary, (canary, pw, cw)) in &canary_mappings {
            tracing::info!(
                primary = %primary,
                canary = %canary,
                primary_weight = pw,
                canary_weight = cw,
                "Enabling canary routing"
            );
        }
        service = BoxCloneService::new(crate::canary::CanaryService::new(
            service,
            canary_mappings,
            &config.proxy.separator,
        ));
    }

    // Failover routing (deterministic fallback on primary error)
    // Collect failover backends grouped by primary, sorted by priority (ascending).
    let mut failover_groups: std::collections::HashMap<String, Vec<(u32, String)>> =
        std::collections::HashMap::new();
    for b in &config.backends {
        if let Some(ref primary) = b.failover_for {
            failover_groups
                .entry(primary.clone())
                .or_default()
                .push((b.priority, b.name.clone()));
        }
    }
    // Sort each group by priority (lower = preferred)
    let failover_mappings: std::collections::HashMap<String, Vec<String>> = failover_groups
        .into_iter()
        .map(|(primary, mut backends)| {
            backends.sort_by_key(|(priority, _)| *priority);
            let names: Vec<String> = backends.into_iter().map(|(_, name)| name).collect();
            (primary, names)
        })
        .collect();

    if !failover_mappings.is_empty() {
        for (primary, failovers) in &failover_mappings {
            tracing::info!(
                primary = %primary,
                failovers = ?failovers,
                "Enabling failover routing"
            );
        }
        service = BoxCloneService::new(crate::failover::FailoverService::new(
            service,
            failover_mappings,
            &config.proxy.separator,
        ));
    }

    // Traffic mirroring (sends cloned requests through the proxy)
    // Map: source backend name -> (mirror backend name, sample percent).
    let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.mirror_of
                .as_ref()
                .map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
        })
        .collect();

    if !mirror_mappings.is_empty() {
        for (source, (mirror, pct)) in &mirror_mappings {
            tracing::info!(
                source = %source,
                mirror = %mirror,
                percent = pct,
                "Enabling traffic mirroring"
            );
        }
        service = BoxCloneService::new(crate::mirror::MirrorService::new(
            service,
            mirror_mappings,
            &config.proxy.separator,
        ));
    }

    // Response caching
    let cache_configs: Vec<_> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.cache
                .as_ref()
                .map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
        })
        .collect();

    if !cache_configs.is_empty() {
        for (ns, cfg) in &cache_configs {
            tracing::info!(
                // Strip the trailing separator back off for display.
                backend = %ns.trim_end_matches(&config.proxy.separator),
                resource_ttl = cfg.resource_ttl_seconds,
                tool_ttl = cfg.tool_ttl_seconds,
                max_entries = cfg.max_entries,
                "Applying response cache"
            );
        }
        let (cache_svc, handle) = cache::CacheService::new(service, cache_configs, &config.cache);
        service = BoxCloneService::new(cache_svc);
        cache_handle = Some(handle);
    }

    // Request coalescing
    if config.performance.coalesce_requests {
        tracing::info!("Request coalescing enabled");
        service = BoxCloneService::new(coalesce::CoalesceService::new(service));
    }

    // Request validation
    if config.security.max_argument_size.is_some() {
        let validation = ValidationConfig {
            max_argument_size: config.security.max_argument_size,
        };
        if let Some(max) = validation.max_argument_size {
            tracing::info!(max_argument_size = max, "Applying request validation");
        }
        service = BoxCloneService::new(ValidationService::new(service, validation));
    }

    // Static capability filtering
    // transpose() turns Result<Option<_>> into Option<Result<_>> so
    // filter_map drops backends with no filter and collect surfaces errors.
    let filters: Vec<_> = config
        .backends
        .iter()
        .filter_map(|b| b.build_filter(&config.proxy.separator).transpose())
        .collect::<anyhow::Result<Vec<_>>>()?;

    if !filters.is_empty() {
        for f in &filters {
            tracing::info!(
                backend = %f.namespace.trim_end_matches(&config.proxy.separator),
                tool_filter = ?f.tool_filter,
                resource_filter = ?f.resource_filter,
                prompt_filter = ?f.prompt_filter,
                "Applying capability filter"
            );
        }
        service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
    }

    // Search-mode filtering: hide all tools except proxy/ namespace
    if config.proxy.tool_exposure == crate::config::ToolExposure::Search {
        let prefix = format!("proxy{}", config.proxy.separator);
        tracing::info!(
            prefix = %prefix,
            "Search mode: ListTools will only show proxy/ namespace tools"
        );
        service =
            BoxCloneService::new(crate::filter::SearchModeFilterService::new(service, prefix));
    }

    // Tool aliasing
    let alias_mappings: Vec<_> = config
        .backends
        .iter()
        .flat_map(|b| {
            let ns = format!("{}{}", b.name, config.proxy.separator);
            b.aliases
                .iter()
                .map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
        })
        .collect();

    // AliasMap::new returns None when there are no mappings, so this
    // layer is only added when at least one alias is configured.
    if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
        let count = alias_map.forward.len();
        tracing::info!(aliases = count, "Applying tool aliases");
        service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
    }

    // Composite tools (fan-out to multiple backend tools)
    if !config.composite_tools.is_empty() {
        let count = config.composite_tools.len();
        tracing::info!(composite_tools = count, "Applying composite tool fan-out");
        service = BoxCloneService::new(crate::composite::CompositeService::new(
            service,
            config.composite_tools.clone(),
        ));
    }

    // Bearer token scoping (per-token allow/deny lists)
    #[cfg(feature = "oauth")]
    if matches!(
        &config.auth,
        Some(AuthConfig::Bearer {
            scoped_tokens,
            ..
        }) if !scoped_tokens.is_empty()
    ) {
        tracing::info!("Enabling bearer token scoping middleware");
        service = BoxCloneService::new(crate::bearer_scope::BearerScopingService::new(service));
    }

    // RBAC (JWT auth only)
    #[cfg(feature = "oauth")]
    {
        // RBAC requires both a non-empty role list and a claim mapping,
        // on either JWT or OAuth auth config.
        let rbac_config = match &config.auth {
            Some(
                AuthConfig::Jwt {
                    roles,
                    role_mapping: Some(mapping),
                    ..
                }
                | AuthConfig::OAuth {
                    roles,
                    role_mapping: Some(mapping),
                    ..
                },
            ) if !roles.is_empty() => {
                tracing::info!(
                    roles = roles.len(),
                    claim = %mapping.claim,
                    "Enabling RBAC"
                );
                Some(RbacConfig::new(roles, mapping))
            }
            _ => None,
        };

        if let Some(rbac) = rbac_config {
            service = BoxCloneService::new(RbacService::new(service, rbac));
        }

        // Token passthrough (inject ClientToken for forward_auth backends)
        let forward_namespaces: std::collections::HashSet<String> = config
            .backends
            .iter()
            .filter(|b| b.forward_auth)
            .map(|b| format!("{}{}", b.name, config.proxy.separator))
            .collect();

        if !forward_namespaces.is_empty() {
            tracing::info!(
                backends = ?forward_namespaces,
                "Enabling token passthrough for forward_auth backends"
            );
            service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
                service,
                forward_namespaces,
            ));
        }
    }

    // Metrics
    #[cfg(feature = "metrics")]
    if config.observability.metrics.enabled {
        service = BoxCloneService::new(crate::metrics::MetricsService::new(service));
    }

    // Structured access logging
    if config.observability.access_log.enabled {
        tracing::info!("Access logging enabled (target: mcp::access)");
        service = BoxCloneService::new(crate::access_log::AccessLogService::new(
            service,
            &config.proxy.separator,
        ));
    }

    // Audit logging
    if config.observability.audit {
        tracing::info!("Audit logging enabled (target: mcp::audit)");
        // NOTE(review): CatchError presumably maps the layer's error type
        // back into responses so the service stays Infallible — confirm
        // against tower_mcp::CatchError.
        let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
        service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
    }

    // Global rate limit (outermost -- protects entire proxy)
    if let Some(ref rl) = config.proxy.rate_limit {
        tracing::info!(
            requests = rl.requests,
            period_seconds = rl.period_seconds,
            "Applying global rate limit"
        );
        let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
            .limit_for_period(rl.requests)
            .refresh_period(Duration::from_secs(rl.period_seconds))
            .name("global-ratelimit")
            .build();
        let limited = tower::Layer::layer(&layer, service);
        service = BoxCloneService::new(tower_mcp::CatchError::new(limited));
    }

    Ok((service, cache_handle))
}
801
802/// Apply inbound authentication middleware to the router.
803async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
804    let router = if let Some(auth) = &config.auth {
805        match auth {
806            AuthConfig::Bearer {
807                tokens,
808                scoped_tokens,
809            } => {
810                let total = tokens.len() + scoped_tokens.len();
811                if scoped_tokens.is_empty() {
812                    // Simple bearer auth: use StaticBearerValidator
813                    tracing::info!(token_count = total, "Enabling bearer token auth");
814                    let validator = StaticBearerValidator::new(tokens.iter().cloned());
815                    let layer = AuthLayer::new(validator);
816                    router.layer(layer)
817                } else {
818                    // Scoped bearer auth: use custom layer that injects TokenClaims
819                    #[cfg(feature = "oauth")]
820                    {
821                        tracing::info!(
822                            token_count = total,
823                            scoped = scoped_tokens.len(),
824                            "Enabling bearer token auth with per-token scoping"
825                        );
826                        let layer =
827                            crate::bearer_scope::ScopedBearerAuthLayer::new(tokens, scoped_tokens);
828                        router.layer(layer)
829                    }
830                    #[cfg(not(feature = "oauth"))]
831                    {
832                        anyhow::bail!(
833                            "Per-token tool scoping requires the 'oauth' feature. \
834                             Rebuild with: cargo install mcp-proxy --features oauth"
835                        );
836                    }
837                }
838            }
839            #[cfg(feature = "oauth")]
840            AuthConfig::Jwt {
841                issuer,
842                audience,
843                jwks_uri,
844                ..
845            } => {
846                tracing::info!(
847                    issuer = %issuer,
848                    audience = %audience,
849                    jwks_uri = %jwks_uri,
850                    "Enabling JWT auth (JWKS)"
851                );
852                let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
853                    .expected_audience(audience)
854                    .expected_issuer(issuer)
855                    .build()
856                    .await
857                    .context("building JWKS validator")?;
858
859                let addr = format!(
860                    "http://{}:{}",
861                    config.proxy.listen.host, config.proxy.listen.port
862                );
863                let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
864                    .authorization_server(issuer);
865
866                let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
867                router.layer(layer)
868            }
869            #[cfg(not(feature = "oauth"))]
870            AuthConfig::Jwt { .. } => {
871                anyhow::bail!(
872                    "JWT auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
873                );
874            }
875            #[cfg(feature = "oauth")]
876            AuthConfig::OAuth {
877                issuer,
878                audience,
879                token_validation,
880                jwks_uri,
881                introspection_endpoint,
882                client_id,
883                client_secret,
884                ..
885            } => {
886                use crate::config::TokenValidationStrategy;
887
888                tracing::info!(
889                    issuer = %issuer,
890                    audience = %audience,
891                    strategy = ?token_validation,
892                    "Enabling OAuth 2.1 auth"
893                );
894
895                // Auto-discover endpoints from issuer if not overridden
896                let discovered = crate::introspection::discover_auth_server(issuer)
897                    .await
898                    .context("discovering OAuth authorization server")?;
899
900                let effective_jwks_uri = jwks_uri
901                    .as_deref()
902                    .or(discovered.jwks_uri.as_deref())
903                    .ok_or_else(|| {
904                        anyhow::anyhow!(
905                            "JWKS URI not found via discovery and not configured manually"
906                        )
907                    })?;
908
909                let effective_introspection = introspection_endpoint
910                    .as_deref()
911                    .or(discovered.introspection_endpoint.as_deref());
912
913                let addr = format!(
914                    "http://{}:{}",
915                    config.proxy.listen.host, config.proxy.listen.port
916                );
917                let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
918                    .authorization_server(issuer);
919
920                match token_validation {
921                    TokenValidationStrategy::Jwt => {
922                        let validator =
923                            tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
924                                .expected_audience(audience)
925                                .expected_issuer(issuer)
926                                .build()
927                                .await
928                                .context("building JWKS validator")?;
929                        let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
930                        router.layer(layer)
931                    }
932                    TokenValidationStrategy::Introspection => {
933                        let endpoint = effective_introspection.ok_or_else(|| {
934                            anyhow::anyhow!(
935                                "introspection endpoint not found via discovery and not configured"
936                            )
937                        })?;
938                        let validator = crate::introspection::IntrospectionValidator::new(
939                            endpoint,
940                            client_id.as_deref().unwrap(),
941                            client_secret.as_deref().unwrap(),
942                        )
943                        .expected_audience(audience);
944                        let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
945                        router.layer(layer)
946                    }
947                    TokenValidationStrategy::Both => {
948                        let endpoint = effective_introspection.ok_or_else(|| {
949                            anyhow::anyhow!(
950                                "introspection endpoint not found via discovery and not configured"
951                            )
952                        })?;
953                        let jwt_validator =
954                            tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
955                                .expected_audience(audience)
956                                .expected_issuer(issuer)
957                                .build()
958                                .await
959                                .context("building JWKS validator")?;
960                        let introspection_validator =
961                            crate::introspection::IntrospectionValidator::new(
962                                endpoint,
963                                client_id.as_deref().unwrap(),
964                                client_secret.as_deref().unwrap(),
965                            )
966                            .expected_audience(audience);
967                        let fallback = crate::introspection::FallbackValidator::new(
968                            jwt_validator,
969                            introspection_validator,
970                        );
971                        let layer = tower_mcp::oauth::OAuthLayer::new(fallback, metadata);
972                        router.layer(layer)
973                    }
974                }
975            }
976            #[cfg(not(feature = "oauth"))]
977            AuthConfig::OAuth { .. } => {
978                anyhow::bail!(
979                    "OAuth auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
980                );
981            }
982        }
983    } else {
984        router
985    };
986    Ok(router)
987}
988
989/// Wait for SIGTERM or SIGINT, then log and return.
990pub async fn shutdown_signal(timeout: Duration) {
991    let ctrl_c = tokio::signal::ctrl_c();
992    #[cfg(unix)]
993    {
994        let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
995            .expect("SIGTERM handler");
996        tokio::select! {
997            _ = ctrl_c => {},
998            _ = sigterm.recv() => {},
999        }
1000    }
1001    #[cfg(not(unix))]
1002    {
1003        ctrl_c.await.ok();
1004    }
1005    tracing::info!(
1006        timeout_seconds = timeout.as_secs(),
1007        "Shutdown signal received, draining connections"
1008    );
1009}