// mcp_proxy/proxy.rs
1//! Core proxy construction and serving.
2
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::time::Duration;
6
7use anyhow::{Context, Result};
8use axum::Router;
9use tokio::process::Command;
10use tower::timeout::TimeoutLayer;
11use tower::util::BoxCloneService;
12use tower_mcp::SessionHandle;
13use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
14use tower_mcp::client::StdioClientTransport;
15use tower_mcp::proxy::McpProxy;
16use tower_mcp::{RouterRequest, RouterResponse};
17
18use crate::admin::BackendMeta;
19use crate::alias;
20use crate::cache;
21use crate::coalesce;
22use crate::config::{AuthConfig, ProxyConfig, TransportType};
23use crate::filter::CapabilityFilterService;
24#[cfg(feature = "oauth")]
25use crate::rbac::{RbacConfig, RbacService};
26use crate::validation::{ValidationConfig, ValidationService};
27
28/// A fully constructed MCP proxy ready to serve or embed.
29pub struct Proxy {
30    router: Router,
31    session_handle: SessionHandle,
32    inner: McpProxy,
33    config: ProxyConfig,
34    #[cfg(feature = "discovery")]
35    discovery_index: Option<crate::discovery::SharedDiscoveryIndex>,
36}
37
38impl Proxy {
39    /// Build a proxy from a [`ProxyConfig`].
40    ///
41    /// Connects to all backends, builds the middleware stack, and prepares
42    /// the axum router. Call [`serve()`](Self::serve) to run standalone or
43    /// [`into_router()`](Self::into_router) to embed in an existing app.
44    pub async fn from_config(config: ProxyConfig) -> Result<Self> {
45        let (mcp_proxy, cb_handles) = build_mcp_proxy(&config).await?;
46        let proxy_for_admin = mcp_proxy.clone();
47        let mut proxy_for_caller = mcp_proxy.clone();
48        let proxy_for_management = mcp_proxy.clone();
49
50        // Install Prometheus metrics recorder (must happen before middleware)
51        #[cfg(feature = "metrics")]
52        let metrics_handle = if config.observability.metrics.enabled {
53            tracing::info!("Prometheus metrics enabled at /admin/metrics");
54            let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
55            let handle = builder
56                .install_recorder()
57                .context("installing Prometheus metrics recorder")?;
58            Some(handle)
59        } else {
60            None
61        };
62        #[cfg(not(feature = "metrics"))]
63        let metrics_handle = None;
64
65        let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;
66
67        let (router, session_handle) =
68            tower_mcp::transport::http::HttpTransport::from_service(service)
69                .into_router_with_handle();
70
71        // Inbound authentication (axum-level middleware)
72        let router = apply_auth(&config, router).await?;
73
74        // Collect backend metadata for the health checker
75        let backend_meta: std::collections::HashMap<String, BackendMeta> = config
76            .backends
77            .iter()
78            .map(|b| {
79                (
80                    b.name.clone(),
81                    BackendMeta {
82                        transport: format!("{:?}", b.transport).to_lowercase(),
83                    },
84                )
85            })
86            .collect();
87
88        // Admin API
89        let admin_state = crate::admin::spawn_health_checker(
90            proxy_for_admin,
91            config.proxy.name.clone(),
92            config.proxy.version.clone(),
93            config.backends.len(),
94            backend_meta,
95        );
96        let router = router.nest(
97            "/admin",
98            crate::admin::admin_router(
99                admin_state.clone(),
100                metrics_handle,
101                session_handle.clone(),
102                cache_handle,
103                proxy_for_management,
104                &config,
105                config.source_path.clone(),
106                cb_handles,
107            ),
108        );
109        tracing::info!("Admin API enabled at /admin/backends");
110
111        // Build discovery index if enabled (search mode implies discovery)
112        #[cfg(feature = "discovery")]
113        let discovery_enabled = config.proxy.tool_discovery
114            || config.proxy.tool_exposure == crate::config::ToolExposure::Search;
115        #[cfg(feature = "discovery")]
116        let (discovery_index, discovery_tools) = if discovery_enabled {
117            let index =
118                crate::discovery::build_index(&mut proxy_for_caller, &config.proxy.separator).await;
119            let tools = crate::discovery::build_discovery_tools(index.clone());
120            (Some(index), Some(tools))
121        } else {
122            (None, None)
123        };
124        #[cfg(not(feature = "discovery"))]
125        let discovery_tools: Option<Vec<tower_mcp::Tool>> = None;
126
127        // MCP admin tools (proxy/ namespace)
128        if let Err(e) = crate::admin_tools::register_admin_tools(
129            &proxy_for_caller,
130            admin_state,
131            session_handle.clone(),
132            &config,
133            discovery_tools,
134        )
135        .await
136        {
137            tracing::warn!("Failed to register admin tools: {e}");
138        } else {
139            tracing::info!("MCP admin tools registered under proxy/ namespace");
140        }
141
142        Ok(Self {
143            router,
144            session_handle,
145            inner: proxy_for_caller,
146            config,
147            #[cfg(feature = "discovery")]
148            discovery_index,
149        })
150    }
151
152    /// Get a reference to the session handle for monitoring active sessions.
153    pub fn session_handle(&self) -> &SessionHandle {
154        &self.session_handle
155    }
156
157    /// Get a reference to the underlying [`McpProxy`] for dynamic operations.
158    ///
159    /// Use this to add backends dynamically via [`McpProxy::add_backend()`].
160    pub fn mcp_proxy(&self) -> &McpProxy {
161        &self.inner
162    }
163
164    /// Enable hot reload by watching the given config file path.
165    ///
166    /// New backends added to the config file will be connected dynamically
167    /// without restarting the proxy.
168    pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
169        tracing::info!("Hot reload enabled, watching config file for changes");
170        crate::reload::spawn_config_watcher(
171            config_path,
172            self.inner.clone(),
173            #[cfg(feature = "discovery")]
174            self.discovery_index
175                .as_ref()
176                .map(|idx| (idx.clone(), self.config.proxy.separator.clone())),
177        );
178    }
179
180    /// Consume the proxy and return the axum Router and SessionHandle.
181    ///
182    /// Use this to embed the proxy in an existing axum application:
183    ///
184    /// ```rust,ignore
185    /// let (proxy_router, session_handle) = proxy.into_router();
186    ///
187    /// let app = Router::new()
188    ///     .nest("/mcp", proxy_router)
189    ///     .route("/health", get(|| async { "ok" }));
190    /// ```
191    pub fn into_router(self) -> (Router, SessionHandle) {
192        (self.router, self.session_handle)
193    }
194
195    /// Serve the proxy on the configured listen address.
196    ///
197    /// Blocks until a shutdown signal (SIGTERM/SIGINT) is received,
198    /// then drains connections for the configured timeout period.
199    pub async fn serve(self) -> Result<()> {
200        let addr = format!(
201            "{}:{}",
202            self.config.proxy.listen.host, self.config.proxy.listen.port
203        );
204
205        tracing::info!(listen = %addr, "Proxy ready");
206
207        let listener = tokio::net::TcpListener::bind(&addr)
208            .await
209            .with_context(|| format!("binding to {}", addr))?;
210
211        let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
212        axum::serve(listener, self.router)
213            .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
214            .await
215            .context("server error")?;
216
217        tracing::info!("Proxy shut down");
218        Ok(())
219    }
220}
221
222/// Circuit breaker handle type alias.
223pub type CbHandle = tower_resilience::circuitbreaker::CircuitBreakerHandle;
224
225/// Build the McpProxy with all backends and per-backend middleware.
226/// Returns the proxy and a map of backend name -> circuit breaker handle.
227async fn build_mcp_proxy(config: &ProxyConfig) -> Result<(McpProxy, HashMap<String, CbHandle>)> {
228    let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
229        .separator(&config.proxy.separator);
230    let mut cb_handles: HashMap<String, CbHandle> = HashMap::new();
231
232    if let Some(instructions) = &config.proxy.instructions {
233        builder = builder.instructions(instructions);
234    }
235
236    // Create shared outlier detector if any backend has outlier_detection configured.
237    // Use the max of all max_ejection_percent values.
238    let outlier_detector = {
239        let max_pct = config
240            .backends
241            .iter()
242            .filter_map(|b| b.outlier_detection.as_ref())
243            .map(|od| od.max_ejection_percent)
244            .max();
245        max_pct.map(crate::outlier::OutlierDetector::new)
246    };
247
248    for backend in &config.backends {
249        tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");
250
251        match backend.transport {
252            TransportType::Stdio => {
253                let command = backend.command.as_deref().unwrap();
254                let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
255
256                let mut cmd = Command::new(command);
257                cmd.args(&args);
258
259                for (key, value) in &backend.env {
260                    cmd.env(key, value);
261                }
262
263                let transport = StdioClientTransport::spawn_command(&mut cmd)
264                    .await
265                    .with_context(|| format!("spawning backend '{}'", backend.name))?;
266
267                builder = builder.backend(&backend.name, transport).await;
268            }
269            TransportType::Http => {
270                let url = backend.url.as_deref().unwrap();
271                let mut transport = tower_mcp::client::HttpClientTransport::new(url);
272                if let Some(token) = &backend.bearer_token {
273                    transport = transport.bearer_token(token);
274                }
275
276                builder = builder.backend(&backend.name, transport).await;
277            }
278            #[cfg(feature = "websocket")]
279            TransportType::Websocket => {
280                let url = backend.url.as_deref().unwrap();
281                tracing::info!(url = %url, "Connecting to WebSocket backend");
282                let transport = if let Some(token) = &backend.bearer_token {
283                    crate::ws_transport::WebSocketClientTransport::connect_with_bearer_token(
284                        url, token,
285                    )
286                    .await
287                    .with_context(|| {
288                        format!("connecting to WebSocket backend '{}'", backend.name)
289                    })?
290                } else {
291                    crate::ws_transport::WebSocketClientTransport::connect(url)
292                        .await
293                        .with_context(|| {
294                            format!("connecting to WebSocket backend '{}'", backend.name)
295                        })?
296                };
297
298                builder = builder.backend(&backend.name, transport).await;
299            }
300            #[cfg(not(feature = "websocket"))]
301            TransportType::Websocket => {
302                anyhow::bail!(
303                    "WebSocket transport requires the 'websocket' feature. \
304                     Rebuild with: cargo install mcp-proxy --features websocket"
305                );
306            }
307        }
308
309        // Per-backend middleware stack (applied in order: inner -> outer)
310
311        // Retry (innermost -- retries happen before other middleware)
312        if let Some(retry_cfg) = &backend.retry {
313            tracing::info!(
314                backend = %backend.name,
315                max_retries = retry_cfg.max_retries,
316                initial_backoff_ms = retry_cfg.initial_backoff_ms,
317                max_backoff_ms = retry_cfg.max_backoff_ms,
318                "Applying retry policy"
319            );
320            let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
321            builder = builder.backend_layer(layer);
322        }
323
324        // Hedging (after retry, before concurrency -- hedges are separate requests)
325        if let Some(hedge_cfg) = &backend.hedging {
326            let delay = Duration::from_millis(hedge_cfg.delay_ms);
327            let max_attempts = hedge_cfg.max_hedges + 1; // +1 for the primary request
328            tracing::info!(
329                backend = %backend.name,
330                delay_ms = hedge_cfg.delay_ms,
331                max_hedges = hedge_cfg.max_hedges,
332                "Applying request hedging"
333            );
334            let layer = if delay.is_zero() {
335                tower_resilience::hedge::HedgeLayer::builder()
336                    .no_delay()
337                    .max_hedged_attempts(max_attempts)
338                    .name(format!("{}-hedge", backend.name))
339                    .build()
340            } else {
341                tower_resilience::hedge::HedgeLayer::builder()
342                    .delay(delay)
343                    .max_hedged_attempts(max_attempts)
344                    .name(format!("{}-hedge", backend.name))
345                    .build()
346            };
347            builder = builder.backend_layer(layer);
348        }
349
350        // Concurrency limit
351        if let Some(cc) = &backend.concurrency {
352            tracing::info!(
353                backend = %backend.name,
354                max = cc.max_concurrent,
355                "Applying concurrency limit"
356            );
357            builder =
358                builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
359        }
360
361        // Rate limit
362        if let Some(rl) = &backend.rate_limit {
363            tracing::info!(
364                backend = %backend.name,
365                requests = rl.requests,
366                period_seconds = rl.period_seconds,
367                "Applying rate limit"
368            );
369            let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
370                .limit_for_period(rl.requests)
371                .refresh_period(Duration::from_secs(rl.period_seconds))
372                .name(format!("{}-ratelimit", backend.name))
373                .build();
374            builder = builder.backend_layer(layer);
375        }
376
377        // Timeout
378        if let Some(timeout) = &backend.timeout {
379            tracing::info!(
380                backend = %backend.name,
381                seconds = timeout.seconds,
382                "Applying timeout"
383            );
384            builder =
385                builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
386        }
387
388        // Circuit breaker
389        if let Some(cb) = &backend.circuit_breaker {
390            tracing::info!(
391                backend = %backend.name,
392                failure_rate = cb.failure_rate_threshold,
393                wait_seconds = cb.wait_duration_seconds,
394                "Applying circuit breaker"
395            );
396            let (layer, handle) = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
397                .failure_rate_threshold(cb.failure_rate_threshold)
398                .minimum_number_of_calls(cb.minimum_calls)
399                .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
400                .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
401                .name(format!("{}-cb", backend.name))
402                .build_with_handle();
403            cb_handles.insert(backend.name.clone(), handle);
404            builder = builder.backend_layer(layer);
405        }
406
407        // Outlier detection (outermost -- observes errors after all other middleware)
408        if let Some(od) = &backend.outlier_detection
409            && let Some(ref detector) = outlier_detector
410        {
411            tracing::info!(
412                backend = %backend.name,
413                consecutive_errors = od.consecutive_errors,
414                base_ejection_seconds = od.base_ejection_seconds,
415                max_ejection_percent = od.max_ejection_percent,
416                "Applying outlier detection"
417            );
418            let layer = crate::outlier::OutlierDetectionLayer::new(
419                backend.name.clone(),
420                od.clone(),
421                detector.clone(),
422            );
423            builder = builder.backend_layer(layer);
424        }
425    }
426
427    let result = builder.build().await?;
428
429    if !result.skipped.is_empty() {
430        for s in &result.skipped {
431            tracing::warn!("Skipped backend: {s}");
432        }
433    }
434
435    Ok((result.proxy, cb_handles))
436}
437
438/// Build the MCP-level middleware stack around the proxy.
439fn build_middleware_stack(
440    config: &ProxyConfig,
441    proxy: McpProxy,
442) -> Result<(
443    BoxCloneService<RouterRequest, RouterResponse, Infallible>,
444    Option<cache::CacheHandle>,
445)> {
446    let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
447        BoxCloneService::new(proxy);
448    let mut cache_handle: Option<cache::CacheHandle> = None;
449
450    // Argument injection (innermost -- merges default/per-tool args into CallTool requests)
451    let injection_rules: Vec<_> = config
452        .backends
453        .iter()
454        .filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
455        .map(|b| {
456            let namespace = format!("{}{}", b.name, config.proxy.separator);
457            tracing::info!(
458                backend = %b.name,
459                default_args = b.default_args.len(),
460                tool_rules = b.inject_args.len(),
461                "Applying argument injection"
462            );
463            crate::inject::InjectionRules::new(
464                namespace,
465                b.default_args.clone(),
466                b.inject_args.clone(),
467            )
468        })
469        .collect();
470
471    if !injection_rules.is_empty() {
472        service = BoxCloneService::new(crate::inject::InjectArgsService::new(
473            service,
474            injection_rules,
475        ));
476    }
477
478    // Parameter overrides (after inject, before filter -- hides/renames tool params)
479    let param_overrides: Vec<_> = config
480        .backends
481        .iter()
482        .filter(|b| !b.param_overrides.is_empty())
483        .flat_map(|b| {
484            let namespace = format!("{}{}", b.name, config.proxy.separator);
485            tracing::info!(
486                backend = %b.name,
487                overrides = b.param_overrides.len(),
488                "Applying parameter overrides"
489            );
490            b.param_overrides
491                .iter()
492                .map(move |c| crate::param_override::ToolOverride::new(&namespace, c))
493        })
494        .collect();
495
496    if !param_overrides.is_empty() {
497        service = BoxCloneService::new(crate::param_override::ParamOverrideService::new(
498            service,
499            param_overrides,
500        ));
501    }
502
503    // Canary routing (rewrites requests from primary to canary namespace based on weight)
504    let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
505        .backends
506        .iter()
507        .filter_map(|b| {
508            b.canary_of.as_ref().map(|primary_name| {
509                // Find the primary backend's weight
510                let primary_weight = config
511                    .backends
512                    .iter()
513                    .find(|p| p.name == *primary_name)
514                    .map(|p| p.weight)
515                    .unwrap_or(100);
516                (
517                    primary_name.clone(),
518                    (b.name.clone(), primary_weight, b.weight),
519                )
520            })
521        })
522        .collect();
523
524    if !canary_mappings.is_empty() {
525        for (primary, (canary, pw, cw)) in &canary_mappings {
526            tracing::info!(
527                primary = %primary,
528                canary = %canary,
529                primary_weight = pw,
530                canary_weight = cw,
531                "Enabling canary routing"
532            );
533        }
534        service = BoxCloneService::new(crate::canary::CanaryService::new(
535            service,
536            canary_mappings,
537            &config.proxy.separator,
538        ));
539    }
540
541    // Failover routing (deterministic fallback on primary error)
542    // Collect failover backends grouped by primary, sorted by priority (ascending).
543    let mut failover_groups: std::collections::HashMap<String, Vec<(u32, String)>> =
544        std::collections::HashMap::new();
545    for b in &config.backends {
546        if let Some(ref primary) = b.failover_for {
547            failover_groups
548                .entry(primary.clone())
549                .or_default()
550                .push((b.priority, b.name.clone()));
551        }
552    }
553    // Sort each group by priority (lower = preferred)
554    let failover_mappings: std::collections::HashMap<String, Vec<String>> = failover_groups
555        .into_iter()
556        .map(|(primary, mut backends)| {
557            backends.sort_by_key(|(priority, _)| *priority);
558            let names: Vec<String> = backends.into_iter().map(|(_, name)| name).collect();
559            (primary, names)
560        })
561        .collect();
562
563    if !failover_mappings.is_empty() {
564        for (primary, failovers) in &failover_mappings {
565            tracing::info!(
566                primary = %primary,
567                failovers = ?failovers,
568                "Enabling failover routing"
569            );
570        }
571        service = BoxCloneService::new(crate::failover::FailoverService::new(
572            service,
573            failover_mappings,
574            &config.proxy.separator,
575        ));
576    }
577
578    // Traffic mirroring (sends cloned requests through the proxy)
579    let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
580        .backends
581        .iter()
582        .filter_map(|b| {
583            b.mirror_of
584                .as_ref()
585                .map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
586        })
587        .collect();
588
589    if !mirror_mappings.is_empty() {
590        for (source, (mirror, pct)) in &mirror_mappings {
591            tracing::info!(
592                source = %source,
593                mirror = %mirror,
594                percent = pct,
595                "Enabling traffic mirroring"
596            );
597        }
598        service = BoxCloneService::new(crate::mirror::MirrorService::new(
599            service,
600            mirror_mappings,
601            &config.proxy.separator,
602        ));
603    }
604
605    // Response caching
606    let cache_configs: Vec<_> = config
607        .backends
608        .iter()
609        .filter_map(|b| {
610            b.cache
611                .as_ref()
612                .map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
613        })
614        .collect();
615
616    if !cache_configs.is_empty() {
617        for (ns, cfg) in &cache_configs {
618            tracing::info!(
619                backend = %ns.trim_end_matches(&config.proxy.separator),
620                resource_ttl = cfg.resource_ttl_seconds,
621                tool_ttl = cfg.tool_ttl_seconds,
622                max_entries = cfg.max_entries,
623                "Applying response cache"
624            );
625        }
626        let (cache_svc, handle) = cache::CacheService::new(service, cache_configs, &config.cache);
627        service = BoxCloneService::new(cache_svc);
628        cache_handle = Some(handle);
629    }
630
631    // Request coalescing
632    if config.performance.coalesce_requests {
633        tracing::info!("Request coalescing enabled");
634        service = BoxCloneService::new(coalesce::CoalesceService::new(service));
635    }
636
637    // Request validation
638    if config.security.max_argument_size.is_some() {
639        let validation = ValidationConfig {
640            max_argument_size: config.security.max_argument_size,
641        };
642        if let Some(max) = validation.max_argument_size {
643            tracing::info!(max_argument_size = max, "Applying request validation");
644        }
645        service = BoxCloneService::new(ValidationService::new(service, validation));
646    }
647
648    // Static capability filtering
649    let filters: Vec<_> = config
650        .backends
651        .iter()
652        .filter_map(|b| b.build_filter(&config.proxy.separator).transpose())
653        .collect::<anyhow::Result<Vec<_>>>()?;
654
655    if !filters.is_empty() {
656        for f in &filters {
657            tracing::info!(
658                backend = %f.namespace.trim_end_matches(&config.proxy.separator),
659                tool_filter = ?f.tool_filter,
660                resource_filter = ?f.resource_filter,
661                prompt_filter = ?f.prompt_filter,
662                "Applying capability filter"
663            );
664        }
665        service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
666    }
667
668    // Search-mode filtering: hide all tools except proxy/ namespace
669    if config.proxy.tool_exposure == crate::config::ToolExposure::Search {
670        let prefix = format!("proxy{}", config.proxy.separator);
671        tracing::info!(
672            prefix = %prefix,
673            "Search mode: ListTools will only show proxy/ namespace tools"
674        );
675        service =
676            BoxCloneService::new(crate::filter::SearchModeFilterService::new(service, prefix));
677    }
678
679    // Tool aliasing
680    let alias_mappings: Vec<_> = config
681        .backends
682        .iter()
683        .flat_map(|b| {
684            let ns = format!("{}{}", b.name, config.proxy.separator);
685            b.aliases
686                .iter()
687                .map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
688        })
689        .collect();
690
691    if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
692        let count = alias_map.forward.len();
693        tracing::info!(aliases = count, "Applying tool aliases");
694        service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
695    }
696
697    // Composite tools (fan-out to multiple backend tools)
698    if !config.composite_tools.is_empty() {
699        let count = config.composite_tools.len();
700        tracing::info!(composite_tools = count, "Applying composite tool fan-out");
701        service = BoxCloneService::new(crate::composite::CompositeService::new(
702            service,
703            config.composite_tools.clone(),
704        ));
705    }
706
707    // Bearer token scoping (per-token allow/deny lists)
708    #[cfg(feature = "oauth")]
709    if matches!(
710        &config.auth,
711        Some(AuthConfig::Bearer {
712            scoped_tokens,
713            ..
714        }) if !scoped_tokens.is_empty()
715    ) {
716        tracing::info!("Enabling bearer token scoping middleware");
717        service = BoxCloneService::new(crate::bearer_scope::BearerScopingService::new(service));
718    }
719
720    // RBAC (JWT auth only)
721    #[cfg(feature = "oauth")]
722    {
723        let rbac_config = match &config.auth {
724            Some(
725                AuthConfig::Jwt {
726                    roles,
727                    role_mapping: Some(mapping),
728                    ..
729                }
730                | AuthConfig::OAuth {
731                    roles,
732                    role_mapping: Some(mapping),
733                    ..
734                },
735            ) if !roles.is_empty() => {
736                tracing::info!(
737                    roles = roles.len(),
738                    claim = %mapping.claim,
739                    "Enabling RBAC"
740                );
741                Some(RbacConfig::new(roles, mapping))
742            }
743            _ => None,
744        };
745
746        if let Some(rbac) = rbac_config {
747            service = BoxCloneService::new(RbacService::new(service, rbac));
748        }
749
750        // Token passthrough (inject ClientToken for forward_auth backends)
751        let forward_namespaces: std::collections::HashSet<String> = config
752            .backends
753            .iter()
754            .filter(|b| b.forward_auth)
755            .map(|b| format!("{}{}", b.name, config.proxy.separator))
756            .collect();
757
758        if !forward_namespaces.is_empty() {
759            tracing::info!(
760                backends = ?forward_namespaces,
761                "Enabling token passthrough for forward_auth backends"
762            );
763            service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
764                service,
765                forward_namespaces,
766            ));
767        }
768    }
769
770    // Metrics
771    #[cfg(feature = "metrics")]
772    if config.observability.metrics.enabled {
773        service = BoxCloneService::new(crate::metrics::MetricsService::new(service));
774    }
775
776    // Structured access logging
777    if config.observability.access_log.enabled {
778        tracing::info!("Access logging enabled (target: mcp::access)");
779        service = BoxCloneService::new(crate::access_log::AccessLogService::new(
780            service,
781            &config.proxy.separator,
782        ));
783    }
784
785    // Audit logging
786    if config.observability.audit {
787        tracing::info!("Audit logging enabled (target: mcp::audit)");
788        let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
789        service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
790    }
791
792    // Global rate limit (outermost -- protects entire proxy)
793    if let Some(ref rl) = config.proxy.rate_limit {
794        tracing::info!(
795            requests = rl.requests,
796            period_seconds = rl.period_seconds,
797            "Applying global rate limit"
798        );
799        let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
800            .limit_for_period(rl.requests)
801            .refresh_period(Duration::from_secs(rl.period_seconds))
802            .name("global-ratelimit")
803            .build();
804        let limited = tower::Layer::layer(&layer, service);
805        service = BoxCloneService::new(tower_mcp::CatchError::new(limited));
806    }
807
808    Ok((service, cache_handle))
809}
810
811/// Apply inbound authentication middleware to the router.
812async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
813    let router = if let Some(auth) = &config.auth {
814        match auth {
815            AuthConfig::Bearer {
816                tokens,
817                scoped_tokens,
818            } => {
819                let total = tokens.len() + scoped_tokens.len();
820                if scoped_tokens.is_empty() {
821                    // Simple bearer auth: use StaticBearerValidator
822                    tracing::info!(token_count = total, "Enabling bearer token auth");
823                    let validator = StaticBearerValidator::new(tokens.iter().cloned());
824                    let layer = AuthLayer::new(validator);
825                    router.layer(layer)
826                } else {
827                    // Scoped bearer auth: use custom layer that injects TokenClaims
828                    #[cfg(feature = "oauth")]
829                    {
830                        tracing::info!(
831                            token_count = total,
832                            scoped = scoped_tokens.len(),
833                            "Enabling bearer token auth with per-token scoping"
834                        );
835                        let layer =
836                            crate::bearer_scope::ScopedBearerAuthLayer::new(tokens, scoped_tokens);
837                        router.layer(layer)
838                    }
839                    #[cfg(not(feature = "oauth"))]
840                    {
841                        anyhow::bail!(
842                            "Per-token tool scoping requires the 'oauth' feature. \
843                             Rebuild with: cargo install mcp-proxy --features oauth"
844                        );
845                    }
846                }
847            }
848            #[cfg(feature = "oauth")]
849            AuthConfig::Jwt {
850                issuer,
851                audience,
852                jwks_uri,
853                ..
854            } => {
855                tracing::info!(
856                    issuer = %issuer,
857                    audience = %audience,
858                    jwks_uri = %jwks_uri,
859                    "Enabling JWT auth (JWKS)"
860                );
861                let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
862                    .expected_audience(audience)
863                    .expected_issuer(issuer)
864                    .build()
865                    .await
866                    .context("building JWKS validator")?;
867
868                let addr = format!(
869                    "http://{}:{}",
870                    config.proxy.listen.host, config.proxy.listen.port
871                );
872                let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
873                    .authorization_server(issuer);
874
875                let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
876                router.layer(layer)
877            }
878            #[cfg(not(feature = "oauth"))]
879            AuthConfig::Jwt { .. } => {
880                anyhow::bail!(
881                    "JWT auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
882                );
883            }
884            #[cfg(feature = "oauth")]
885            AuthConfig::OAuth {
886                issuer,
887                audience,
888                token_validation,
889                jwks_uri,
890                introspection_endpoint,
891                client_id,
892                client_secret,
893                ..
894            } => {
895                use crate::config::TokenValidationStrategy;
896
897                tracing::info!(
898                    issuer = %issuer,
899                    audience = %audience,
900                    strategy = ?token_validation,
901                    "Enabling OAuth 2.1 auth"
902                );
903
904                // Auto-discover endpoints from issuer if not overridden
905                let discovered = crate::introspection::discover_auth_server(issuer)
906                    .await
907                    .context("discovering OAuth authorization server")?;
908
909                let effective_jwks_uri = jwks_uri
910                    .as_deref()
911                    .or(discovered.jwks_uri.as_deref())
912                    .ok_or_else(|| {
913                        anyhow::anyhow!(
914                            "JWKS URI not found via discovery and not configured manually"
915                        )
916                    })?;
917
918                let effective_introspection = introspection_endpoint
919                    .as_deref()
920                    .or(discovered.introspection_endpoint.as_deref());
921
922                let addr = format!(
923                    "http://{}:{}",
924                    config.proxy.listen.host, config.proxy.listen.port
925                );
926                let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
927                    .authorization_server(issuer);
928
929                match token_validation {
930                    TokenValidationStrategy::Jwt => {
931                        let validator =
932                            tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
933                                .expected_audience(audience)
934                                .expected_issuer(issuer)
935                                .build()
936                                .await
937                                .context("building JWKS validator")?;
938                        let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
939                        router.layer(layer)
940                    }
941                    TokenValidationStrategy::Introspection => {
942                        let endpoint = effective_introspection.ok_or_else(|| {
943                            anyhow::anyhow!(
944                                "introspection endpoint not found via discovery and not configured"
945                            )
946                        })?;
947                        let validator = crate::introspection::IntrospectionValidator::new(
948                            endpoint,
949                            client_id.as_deref().unwrap(),
950                            client_secret.as_deref().unwrap(),
951                        )
952                        .expected_audience(audience);
953                        let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
954                        router.layer(layer)
955                    }
956                    TokenValidationStrategy::Both => {
957                        let endpoint = effective_introspection.ok_or_else(|| {
958                            anyhow::anyhow!(
959                                "introspection endpoint not found via discovery and not configured"
960                            )
961                        })?;
962                        let jwt_validator =
963                            tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
964                                .expected_audience(audience)
965                                .expected_issuer(issuer)
966                                .build()
967                                .await
968                                .context("building JWKS validator")?;
969                        let introspection_validator =
970                            crate::introspection::IntrospectionValidator::new(
971                                endpoint,
972                                client_id.as_deref().unwrap(),
973                                client_secret.as_deref().unwrap(),
974                            )
975                            .expected_audience(audience);
976                        let fallback = crate::introspection::FallbackValidator::new(
977                            jwt_validator,
978                            introspection_validator,
979                        );
980                        let layer = tower_mcp::oauth::OAuthLayer::new(fallback, metadata);
981                        router.layer(layer)
982                    }
983                }
984            }
985            #[cfg(not(feature = "oauth"))]
986            AuthConfig::OAuth { .. } => {
987                anyhow::bail!(
988                    "OAuth auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
989                );
990            }
991        }
992    } else {
993        router
994    };
995    Ok(router)
996}
997
998/// Wait for SIGTERM or SIGINT, then log and return.
999pub async fn shutdown_signal(timeout: Duration) {
1000    let ctrl_c = tokio::signal::ctrl_c();
1001    #[cfg(unix)]
1002    {
1003        let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
1004            .expect("SIGTERM handler");
1005        tokio::select! {
1006            _ = ctrl_c => {},
1007            _ = sigterm.recv() => {},
1008        }
1009    }
1010    #[cfg(not(unix))]
1011    {
1012        ctrl_c.await.ok();
1013    }
1014    tracing::info!(
1015        timeout_seconds = timeout.as_secs(),
1016        "Shutdown signal received, draining connections"
1017    );
1018}