Skip to main content

mcp_proxy/
proxy.rs

1//! Core proxy construction and serving.
2
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::time::Duration;
6
7use anyhow::{Context, Result};
8use axum::Router;
9use tokio::process::Command;
10use tower::timeout::TimeoutLayer;
11use tower::util::BoxCloneService;
12use tower_mcp::SessionHandle;
13use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
14use tower_mcp::client::StdioClientTransport;
15use tower_mcp::proxy::McpProxy;
16use tower_mcp::{RouterRequest, RouterResponse};
17
18use crate::admin::BackendMeta;
19use crate::alias;
20use crate::cache;
21use crate::coalesce;
22use crate::config::{AuthConfig, ProxyConfig, TransportType};
23use crate::filter::CapabilityFilterService;
24#[cfg(feature = "oauth")]
25use crate::rbac::{RbacConfig, RbacService};
26use crate::validation::{ValidationConfig, ValidationService};
27
28/// A fully constructed MCP proxy ready to serve or embed.
29pub struct Proxy {
30    router: Router,
31    session_handle: SessionHandle,
32    inner: McpProxy,
33    config: ProxyConfig,
34    #[cfg(feature = "discovery")]
35    discovery_index: Option<crate::discovery::SharedDiscoveryIndex>,
36}
37
38impl Proxy {
39    /// Build a proxy from a [`ProxyConfig`].
40    ///
41    /// Connects to all backends, builds the middleware stack, and prepares
42    /// the axum router. Call [`serve()`](Self::serve) to run standalone or
43    /// [`into_router()`](Self::into_router) to embed in an existing app.
44    pub async fn from_config(config: ProxyConfig) -> Result<Self> {
45        let (mcp_proxy, cb_handles) = build_mcp_proxy(&config).await?;
46        let proxy_for_admin = mcp_proxy.clone();
47        let mut proxy_for_caller = mcp_proxy.clone();
48        let proxy_for_management = mcp_proxy.clone();
49
50        // Install Prometheus metrics recorder (must happen before middleware)
51        #[cfg(feature = "metrics")]
52        let metrics_handle = if config.observability.metrics.enabled {
53            tracing::info!("Prometheus metrics enabled at /admin/metrics");
54            let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
55            let handle = builder
56                .install_recorder()
57                .context("installing Prometheus metrics recorder")?;
58            Some(handle)
59        } else {
60            None
61        };
62        #[cfg(not(feature = "metrics"))]
63        let metrics_handle = None;
64
65        let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;
66
67        let (router, session_handle) =
68            tower_mcp::transport::http::HttpTransport::from_service(service)
69                .into_router_with_handle();
70
71        // Inbound authentication (axum-level middleware)
72        let router = apply_auth(&config, router).await?;
73
74        // Collect backend metadata for the health checker
75        let backend_meta: std::collections::HashMap<String, BackendMeta> = config
76            .backends
77            .iter()
78            .map(|b| {
79                (
80                    b.name.clone(),
81                    BackendMeta {
82                        transport: format!("{:?}", b.transport).to_lowercase(),
83                    },
84                )
85            })
86            .collect();
87
88        // Admin API
89        let admin_state = crate::admin::spawn_health_checker(
90            proxy_for_admin,
91            config.proxy.name.clone(),
92            config.proxy.version.clone(),
93            config.backends.len(),
94            backend_meta,
95        );
96        let router = router.nest(
97            "/admin",
98            crate::admin::admin_router(
99                admin_state.clone(),
100                metrics_handle,
101                session_handle.clone(),
102                cache_handle,
103                proxy_for_management,
104                &config,
105                config.source_path.clone(),
106                cb_handles,
107            ),
108        );
109        tracing::info!("Admin API enabled at /admin/backends");
110
111        // Build discovery index if enabled (search mode implies discovery)
112        #[cfg(feature = "discovery")]
113        let discovery_enabled = config.proxy.tool_discovery
114            || config.proxy.tool_exposure == crate::config::ToolExposure::Search;
115        #[cfg(feature = "discovery")]
116        let (discovery_index, discovery_tools) = if discovery_enabled {
117            let index =
118                crate::discovery::build_index(&mut proxy_for_caller, &config.proxy.separator).await;
119            let tools = crate::discovery::build_discovery_tools(index.clone());
120            (Some(index), Some(tools))
121        } else {
122            (None, None)
123        };
124        #[cfg(not(feature = "discovery"))]
125        let discovery_tools: Option<Vec<tower_mcp::Tool>> = None;
126
127        // MCP admin tools (proxy/ namespace)
128        if let Err(e) = crate::admin_tools::register_admin_tools(
129            &proxy_for_caller,
130            admin_state,
131            session_handle.clone(),
132            &config,
133            discovery_tools,
134        )
135        .await
136        {
137            tracing::warn!("Failed to register admin tools: {e}");
138        } else {
139            tracing::info!("MCP admin tools registered under proxy/ namespace");
140        }
141
142        Ok(Self {
143            router,
144            session_handle,
145            inner: proxy_for_caller,
146            config,
147            #[cfg(feature = "discovery")]
148            discovery_index,
149        })
150    }
151
152    /// Get a reference to the session handle for monitoring active sessions.
153    pub fn session_handle(&self) -> &SessionHandle {
154        &self.session_handle
155    }
156
157    /// Get a reference to the underlying [`McpProxy`] for dynamic operations.
158    ///
159    /// Use this to add backends dynamically via [`McpProxy::add_backend()`].
160    pub fn mcp_proxy(&self) -> &McpProxy {
161        &self.inner
162    }
163
164    /// Enable hot reload by watching the given config file path.
165    ///
166    /// New backends added to the config file will be connected dynamically
167    /// without restarting the proxy.
168    pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
169        tracing::info!("Hot reload enabled, watching config file for changes");
170        crate::reload::spawn_config_watcher(
171            config_path,
172            self.inner.clone(),
173            #[cfg(feature = "discovery")]
174            self.discovery_index
175                .as_ref()
176                .map(|idx| (idx.clone(), self.config.proxy.separator.clone())),
177        );
178    }
179
180    /// Consume the proxy and return the axum Router and SessionHandle.
181    ///
182    /// Use this to embed the proxy in an existing axum application:
183    ///
184    /// ```rust,ignore
185    /// let (proxy_router, session_handle) = proxy.into_router();
186    ///
187    /// let app = Router::new()
188    ///     .nest("/mcp", proxy_router)
189    ///     .route("/health", get(|| async { "ok" }));
190    /// ```
191    pub fn into_router(self) -> (Router, SessionHandle) {
192        (self.router, self.session_handle)
193    }
194
195    /// Serve the proxy on the configured listen address.
196    ///
197    /// Blocks until a shutdown signal (SIGTERM/SIGINT) is received,
198    /// then drains connections for the configured timeout period.
199    pub async fn serve(self) -> Result<()> {
200        let addr = format!(
201            "{}:{}",
202            self.config.proxy.listen.host, self.config.proxy.listen.port
203        );
204
205        tracing::info!(listen = %addr, "Proxy ready");
206
207        let listener = tokio::net::TcpListener::bind(&addr)
208            .await
209            .with_context(|| format!("binding to {}", addr))?;
210
211        let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
212        axum::serve(listener, self.router)
213            .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
214            .await
215            .context("server error")?;
216
217        tracing::info!("Proxy shut down");
218        Ok(())
219    }
220}
221
222/// Circuit breaker handle type alias.
223pub type CbHandle = tower_resilience::circuitbreaker::CircuitBreakerHandle;
224
225/// Build the McpProxy with all backends and per-backend middleware.
226/// Returns the proxy and a map of backend name -> circuit breaker handle.
227async fn build_mcp_proxy(config: &ProxyConfig) -> Result<(McpProxy, HashMap<String, CbHandle>)> {
228    let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
229        .separator(&config.proxy.separator);
230    let mut cb_handles: HashMap<String, CbHandle> = HashMap::new();
231
232    if let Some(instructions) = &config.proxy.instructions {
233        builder = builder.instructions(instructions);
234    }
235
236    // Create shared outlier detector if any backend has outlier_detection configured.
237    // Use the max of all max_ejection_percent values.
238    let outlier_detector = {
239        let max_pct = config
240            .backends
241            .iter()
242            .filter_map(|b| b.outlier_detection.as_ref())
243            .map(|od| od.max_ejection_percent)
244            .max();
245        max_pct.map(crate::outlier::OutlierDetector::new)
246    };
247
248    for backend in &config.backends {
249        tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");
250
251        match backend.transport {
252            TransportType::Stdio => {
253                let command = backend.command.as_deref().unwrap();
254                let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
255
256                let mut cmd = Command::new(command);
257                cmd.args(&args);
258
259                for (key, value) in &backend.env {
260                    cmd.env(key, value);
261                }
262
263                let transport = StdioClientTransport::spawn_command(&mut cmd)
264                    .await
265                    .with_context(|| format!("spawning backend '{}'", backend.name))?;
266
267                builder = builder.backend(&backend.name, transport).await;
268            }
269            TransportType::Http => {
270                let url = backend.url.as_deref().unwrap();
271                let mut transport = tower_mcp::client::HttpClientTransport::new(url);
272                if let Some(token) = &backend.bearer_token {
273                    transport = transport.bearer_token(token);
274                }
275
276                builder = builder.backend(&backend.name, transport).await;
277            }
278            #[cfg(feature = "websocket")]
279            TransportType::Websocket => {
280                let url = backend.url.as_deref().unwrap();
281                tracing::info!(url = %url, "Connecting to WebSocket backend");
282                let transport = if let Some(token) = &backend.bearer_token {
283                    crate::ws_transport::WebSocketClientTransport::connect_with_bearer_token(
284                        url, token,
285                    )
286                    .await
287                    .with_context(|| {
288                        format!("connecting to WebSocket backend '{}'", backend.name)
289                    })?
290                } else {
291                    crate::ws_transport::WebSocketClientTransport::connect(url)
292                        .await
293                        .with_context(|| {
294                            format!("connecting to WebSocket backend '{}'", backend.name)
295                        })?
296                };
297
298                builder = builder.backend(&backend.name, transport).await;
299            }
300            #[cfg(not(feature = "websocket"))]
301            TransportType::Websocket => {
302                anyhow::bail!(
303                    "WebSocket transport requires the 'websocket' feature. \
304                     Rebuild with: cargo install mcp-proxy --features websocket"
305                );
306            }
307        }
308
309        // Per-backend middleware stack (applied in order: inner -> outer)
310
311        // Retry (innermost -- retries happen before other middleware)
312        if let Some(retry_cfg) = &backend.retry {
313            tracing::info!(
314                backend = %backend.name,
315                max_retries = retry_cfg.max_retries,
316                initial_backoff_ms = retry_cfg.initial_backoff_ms,
317                max_backoff_ms = retry_cfg.max_backoff_ms,
318                "Applying retry policy"
319            );
320            let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
321            builder = builder.backend_layer(layer);
322        }
323
324        // Hedging (after retry, before concurrency -- hedges are separate requests)
325        if let Some(hedge_cfg) = &backend.hedging {
326            let delay = Duration::from_millis(hedge_cfg.delay_ms);
327            let max_attempts = hedge_cfg.max_hedges + 1; // +1 for the primary request
328            tracing::info!(
329                backend = %backend.name,
330                delay_ms = hedge_cfg.delay_ms,
331                max_hedges = hedge_cfg.max_hedges,
332                "Applying request hedging"
333            );
334            let layer = if delay.is_zero() {
335                tower_resilience::hedge::HedgeLayer::builder()
336                    .no_delay()
337                    .max_hedged_attempts(max_attempts)
338                    .name(format!("{}-hedge", backend.name))
339                    .build()
340            } else {
341                tower_resilience::hedge::HedgeLayer::builder()
342                    .delay(delay)
343                    .max_hedged_attempts(max_attempts)
344                    .name(format!("{}-hedge", backend.name))
345                    .build()
346            };
347            builder = builder.backend_layer(layer);
348        }
349
350        // Concurrency limit
351        if let Some(cc) = &backend.concurrency {
352            tracing::info!(
353                backend = %backend.name,
354                max = cc.max_concurrent,
355                "Applying concurrency limit"
356            );
357            builder =
358                builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
359        }
360
361        // Rate limit
362        if let Some(rl) = &backend.rate_limit {
363            tracing::info!(
364                backend = %backend.name,
365                requests = rl.requests,
366                period_seconds = rl.period_seconds,
367                "Applying rate limit"
368            );
369            let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
370                .limit_for_period(rl.requests)
371                .refresh_period(Duration::from_secs(rl.period_seconds))
372                .name(format!("{}-ratelimit", backend.name))
373                .build();
374            builder = builder.backend_layer(layer);
375        }
376
377        // Timeout
378        if let Some(timeout) = &backend.timeout {
379            tracing::info!(
380                backend = %backend.name,
381                seconds = timeout.seconds,
382                "Applying timeout"
383            );
384            builder =
385                builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
386        }
387
388        // Circuit breaker
389        if let Some(cb) = &backend.circuit_breaker {
390            tracing::info!(
391                backend = %backend.name,
392                failure_rate = cb.failure_rate_threshold,
393                wait_seconds = cb.wait_duration_seconds,
394                "Applying circuit breaker"
395            );
396            let (layer, handle) = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
397                .failure_rate_threshold(cb.failure_rate_threshold)
398                .minimum_number_of_calls(cb.minimum_calls)
399                .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
400                .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
401                .name(format!("{}-cb", backend.name))
402                .build_with_handle();
403            cb_handles.insert(backend.name.clone(), handle);
404            builder = builder.backend_layer(layer);
405        }
406
407        // Outlier detection (outermost -- observes errors after all other middleware)
408        if let Some(od) = &backend.outlier_detection
409            && let Some(ref detector) = outlier_detector
410        {
411            tracing::info!(
412                backend = %backend.name,
413                consecutive_errors = od.consecutive_errors,
414                base_ejection_seconds = od.base_ejection_seconds,
415                max_ejection_percent = od.max_ejection_percent,
416                "Applying outlier detection"
417            );
418            let layer = crate::outlier::OutlierDetectionLayer::new(
419                backend.name.clone(),
420                od.clone(),
421                detector.clone(),
422            );
423            builder = builder.backend_layer(layer);
424        }
425    }
426
427    let result = builder.build().await?;
428
429    if !result.skipped.is_empty() {
430        for s in &result.skipped {
431            tracing::warn!("Skipped backend: {s}");
432        }
433    }
434
435    Ok((result.proxy, cb_handles))
436}
437
/// Translate configured OAuth `required_scopes` into a scope-enforcement layer.
///
/// An empty scope list produces `None`, because the layer would enforce
/// nothing. A non-empty list produces a
/// [`ScopeEnforcementLayer`](tower_mcp::oauth::ScopeEnforcementLayer) whose
/// default policy demands every listed scope (AND semantics) on each request.
#[cfg(feature = "oauth")]
fn oauth_scope_layer(
    required_scopes: &[String],
) -> Option<tower_mcp::oauth::ScopeEnforcementLayer> {
    (!required_scopes.is_empty()).then(|| {
        let base_policy = tower_mcp::oauth::ScopePolicy::new();
        let every_scope =
            tower_mcp::oauth::ScopeRequirement::all(required_scopes.iter().cloned());
        tower_mcp::oauth::ScopeEnforcementLayer::new(base_policy.default_scopes(every_scope))
    })
}
456
457/// Build the MCP-level middleware stack around the proxy.
458fn build_middleware_stack(
459    config: &ProxyConfig,
460    proxy: McpProxy,
461) -> Result<(
462    BoxCloneService<RouterRequest, RouterResponse, Infallible>,
463    Option<cache::CacheHandle>,
464)> {
465    let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
466        BoxCloneService::new(proxy);
467    let mut cache_handle: Option<cache::CacheHandle> = None;
468
469    // Argument injection (innermost -- merges default/per-tool args into CallTool requests)
470    let injection_rules: Vec<_> = config
471        .backends
472        .iter()
473        .filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
474        .map(|b| {
475            let namespace = format!("{}{}", b.name, config.proxy.separator);
476            tracing::info!(
477                backend = %b.name,
478                default_args = b.default_args.len(),
479                tool_rules = b.inject_args.len(),
480                "Applying argument injection"
481            );
482            crate::inject::InjectionRules::new(
483                namespace,
484                b.default_args.clone(),
485                b.inject_args.clone(),
486            )
487        })
488        .collect();
489
490    if !injection_rules.is_empty() {
491        service = BoxCloneService::new(crate::inject::InjectArgsService::new(
492            service,
493            injection_rules,
494        ));
495    }
496
497    // Parameter overrides (after inject, before filter -- hides/renames tool params)
498    let param_overrides: Vec<_> = config
499        .backends
500        .iter()
501        .filter(|b| !b.param_overrides.is_empty())
502        .flat_map(|b| {
503            let namespace = format!("{}{}", b.name, config.proxy.separator);
504            tracing::info!(
505                backend = %b.name,
506                overrides = b.param_overrides.len(),
507                "Applying parameter overrides"
508            );
509            b.param_overrides
510                .iter()
511                .map(move |c| crate::param_override::ToolOverride::new(&namespace, c))
512        })
513        .collect();
514
515    if !param_overrides.is_empty() {
516        service = BoxCloneService::new(crate::param_override::ParamOverrideService::new(
517            service,
518            param_overrides,
519        ));
520    }
521
522    // Canary routing (rewrites requests from primary to canary namespace based on weight)
523    let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
524        .backends
525        .iter()
526        .filter_map(|b| {
527            b.canary_of.as_ref().map(|primary_name| {
528                // Find the primary backend's weight
529                let primary_weight = config
530                    .backends
531                    .iter()
532                    .find(|p| p.name == *primary_name)
533                    .map(|p| p.weight)
534                    .unwrap_or(100);
535                (
536                    primary_name.clone(),
537                    (b.name.clone(), primary_weight, b.weight),
538                )
539            })
540        })
541        .collect();
542
543    if !canary_mappings.is_empty() {
544        for (primary, (canary, pw, cw)) in &canary_mappings {
545            tracing::info!(
546                primary = %primary,
547                canary = %canary,
548                primary_weight = pw,
549                canary_weight = cw,
550                "Enabling canary routing"
551            );
552        }
553        service = BoxCloneService::new(crate::canary::CanaryService::new(
554            service,
555            canary_mappings,
556            &config.proxy.separator,
557        ));
558    }
559
560    // Failover routing (deterministic fallback on primary error)
561    // Collect failover backends grouped by primary, sorted by priority (ascending).
562    let mut failover_groups: std::collections::HashMap<String, Vec<(u32, String)>> =
563        std::collections::HashMap::new();
564    for b in &config.backends {
565        if let Some(ref primary) = b.failover_for {
566            failover_groups
567                .entry(primary.clone())
568                .or_default()
569                .push((b.priority, b.name.clone()));
570        }
571    }
572    // Sort each group by priority (lower = preferred)
573    let failover_mappings: std::collections::HashMap<String, Vec<String>> = failover_groups
574        .into_iter()
575        .map(|(primary, mut backends)| {
576            backends.sort_by_key(|(priority, _)| *priority);
577            let names: Vec<String> = backends.into_iter().map(|(_, name)| name).collect();
578            (primary, names)
579        })
580        .collect();
581
582    if !failover_mappings.is_empty() {
583        for (primary, failovers) in &failover_mappings {
584            tracing::info!(
585                primary = %primary,
586                failovers = ?failovers,
587                "Enabling failover routing"
588            );
589        }
590        service = BoxCloneService::new(crate::failover::FailoverService::new(
591            service,
592            failover_mappings,
593            &config.proxy.separator,
594        ));
595    }
596
597    // Traffic mirroring (sends cloned requests through the proxy)
598    let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
599        .backends
600        .iter()
601        .filter_map(|b| {
602            b.mirror_of
603                .as_ref()
604                .map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
605        })
606        .collect();
607
608    if !mirror_mappings.is_empty() {
609        for (source, (mirror, pct)) in &mirror_mappings {
610            tracing::info!(
611                source = %source,
612                mirror = %mirror,
613                percent = pct,
614                "Enabling traffic mirroring"
615            );
616        }
617        service = BoxCloneService::new(crate::mirror::MirrorService::new(
618            service,
619            mirror_mappings,
620            &config.proxy.separator,
621        ));
622    }
623
624    // Response caching
625    let cache_configs: Vec<_> = config
626        .backends
627        .iter()
628        .filter_map(|b| {
629            b.cache
630                .as_ref()
631                .map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
632        })
633        .collect();
634
635    if !cache_configs.is_empty() {
636        for (ns, cfg) in &cache_configs {
637            tracing::info!(
638                backend = %ns.trim_end_matches(&config.proxy.separator),
639                resource_ttl = cfg.resource_ttl_seconds,
640                tool_ttl = cfg.tool_ttl_seconds,
641                max_entries = cfg.max_entries,
642                "Applying response cache"
643            );
644        }
645        let (cache_svc, handle) = cache::CacheService::new(service, cache_configs, &config.cache);
646        service = BoxCloneService::new(cache_svc);
647        cache_handle = Some(handle);
648    }
649
650    // Request coalescing
651    if config.performance.coalesce_requests {
652        tracing::info!("Request coalescing enabled");
653        service = BoxCloneService::new(coalesce::CoalesceService::new(service));
654    }
655
656    // Request validation
657    if config.security.max_argument_size.is_some() {
658        let validation = ValidationConfig {
659            max_argument_size: config.security.max_argument_size,
660        };
661        if let Some(max) = validation.max_argument_size {
662            tracing::info!(max_argument_size = max, "Applying request validation");
663        }
664        service = BoxCloneService::new(ValidationService::new(service, validation));
665    }
666
667    // Static capability filtering
668    let filters: Vec<_> = config
669        .backends
670        .iter()
671        .filter_map(|b| b.build_filter(&config.proxy.separator).transpose())
672        .collect::<anyhow::Result<Vec<_>>>()?;
673
674    if !filters.is_empty() {
675        for f in &filters {
676            tracing::info!(
677                backend = %f.namespace.trim_end_matches(&config.proxy.separator),
678                tool_filter = ?f.tool_filter,
679                resource_filter = ?f.resource_filter,
680                prompt_filter = ?f.prompt_filter,
681                "Applying capability filter"
682            );
683        }
684        service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
685    }
686
687    // Search-mode filtering: hide all tools except proxy/ namespace
688    if config.proxy.tool_exposure == crate::config::ToolExposure::Search {
689        let prefix = format!("proxy{}", config.proxy.separator);
690        tracing::info!(
691            prefix = %prefix,
692            "Search mode: ListTools will only show proxy/ namespace tools"
693        );
694        service =
695            BoxCloneService::new(crate::filter::SearchModeFilterService::new(service, prefix));
696    }
697
698    // Tool aliasing
699    let alias_mappings: Vec<_> = config
700        .backends
701        .iter()
702        .flat_map(|b| {
703            let ns = format!("{}{}", b.name, config.proxy.separator);
704            b.aliases
705                .iter()
706                .map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
707        })
708        .collect();
709
710    if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
711        let count = alias_map.forward.len();
712        tracing::info!(aliases = count, "Applying tool aliases");
713        service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
714    }
715
716    // Composite tools (fan-out to multiple backend tools)
717    if !config.composite_tools.is_empty() {
718        let count = config.composite_tools.len();
719        tracing::info!(composite_tools = count, "Applying composite tool fan-out");
720        service = BoxCloneService::new(crate::composite::CompositeService::new(
721            service,
722            config.composite_tools.clone(),
723        ));
724    }
725
726    // Bearer token scoping (per-token allow/deny lists)
727    #[cfg(feature = "oauth")]
728    if matches!(
729        &config.auth,
730        Some(AuthConfig::Bearer {
731            scoped_tokens,
732            ..
733        }) if !scoped_tokens.is_empty()
734    ) {
735        tracing::info!("Enabling bearer token scoping middleware");
736        service = BoxCloneService::new(crate::bearer_scope::BearerScopingService::new(service));
737    }
738
739    // RBAC (JWT auth only)
740    #[cfg(feature = "oauth")]
741    {
742        let rbac_config = match &config.auth {
743            Some(
744                AuthConfig::Jwt {
745                    roles,
746                    role_mapping: Some(mapping),
747                    ..
748                }
749                | AuthConfig::OAuth {
750                    roles,
751                    role_mapping: Some(mapping),
752                    ..
753                },
754            ) if !roles.is_empty() => {
755                tracing::info!(
756                    roles = roles.len(),
757                    claim = %mapping.claim,
758                    "Enabling RBAC"
759                );
760                Some(RbacConfig::new(roles, mapping))
761            }
762            _ => None,
763        };
764
765        if let Some(rbac) = rbac_config {
766            service = BoxCloneService::new(RbacService::new(service, rbac));
767        }
768
769        // OAuth `required_scopes` enforcement: a coarse global gate that rejects
770        // any token missing one of the configured scopes (AND semantics). Runs
771        // outside RBAC so a token lacking the required scopes is denied for every
772        // operation, including `tools/list`. Reads TokenClaims injected by the
773        // OAuth auth layer; requests without claims pass through (already rejected
774        // upstream by the HTTP auth layer when auth is enabled).
775        let required_scopes: &[String] = match &config.auth {
776            Some(AuthConfig::OAuth {
777                required_scopes, ..
778            }) => required_scopes,
779            _ => &[],
780        };
781        if let Some(layer) = oauth_scope_layer(required_scopes) {
782            tracing::info!(
783                scopes = ?required_scopes,
784                "Enabling OAuth required_scopes enforcement"
785            );
786            service = BoxCloneService::new(tower::Layer::layer(&layer, service));
787        }
788
789        // Token passthrough (inject ClientToken for forward_auth backends)
790        let forward_namespaces: std::collections::HashSet<String> = config
791            .backends
792            .iter()
793            .filter(|b| b.forward_auth)
794            .map(|b| format!("{}{}", b.name, config.proxy.separator))
795            .collect();
796
797        if !forward_namespaces.is_empty() {
798            tracing::info!(
799                backends = ?forward_namespaces,
800                "Enabling token passthrough for forward_auth backends"
801            );
802            service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
803                service,
804                forward_namespaces,
805            ));
806        }
807    }
808
809    // Metrics
810    #[cfg(feature = "metrics")]
811    if config.observability.metrics.enabled {
812        service = BoxCloneService::new(crate::metrics::MetricsService::new(service));
813    }
814
815    // Structured access logging
816    if config.observability.access_log.enabled {
817        tracing::info!("Access logging enabled (target: mcp::access)");
818        service = BoxCloneService::new(crate::access_log::AccessLogService::new(
819            service,
820            &config.proxy.separator,
821        ));
822    }
823
824    // Audit logging
825    if config.observability.audit {
826        tracing::info!("Audit logging enabled (target: mcp::audit)");
827        let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
828        service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
829    }
830
831    // Global rate limit (outermost -- protects entire proxy)
832    if let Some(ref rl) = config.proxy.rate_limit {
833        tracing::info!(
834            requests = rl.requests,
835            period_seconds = rl.period_seconds,
836            "Applying global rate limit"
837        );
838        let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
839            .limit_for_period(rl.requests)
840            .refresh_period(Duration::from_secs(rl.period_seconds))
841            .name("global-ratelimit")
842            .build();
843        let limited = tower::Layer::layer(&layer, service);
844        service = BoxCloneService::new(tower_mcp::CatchError::new(limited));
845    }
846
847    Ok((service, cache_handle))
848}
849
850/// Apply inbound authentication middleware to the router.
851async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
852    let router = if let Some(auth) = &config.auth {
853        match auth {
854            AuthConfig::Bearer {
855                tokens,
856                scoped_tokens,
857            } => {
858                let total = tokens.len() + scoped_tokens.len();
859                if scoped_tokens.is_empty() {
860                    // Simple bearer auth: use StaticBearerValidator
861                    tracing::info!(token_count = total, "Enabling bearer token auth");
862                    let validator = StaticBearerValidator::new(tokens.iter().cloned());
863                    let layer = AuthLayer::new(validator);
864                    router.layer(layer)
865                } else {
866                    // Scoped bearer auth: use custom layer that injects TokenClaims
867                    #[cfg(feature = "oauth")]
868                    {
869                        tracing::info!(
870                            token_count = total,
871                            scoped = scoped_tokens.len(),
872                            "Enabling bearer token auth with per-token scoping"
873                        );
874                        let layer =
875                            crate::bearer_scope::ScopedBearerAuthLayer::new(tokens, scoped_tokens);
876                        router.layer(layer)
877                    }
878                    #[cfg(not(feature = "oauth"))]
879                    {
880                        anyhow::bail!(
881                            "Per-token tool scoping requires the 'oauth' feature. \
882                             Rebuild with: cargo install mcp-proxy --features oauth"
883                        );
884                    }
885                }
886            }
887            #[cfg(feature = "oauth")]
888            AuthConfig::Jwt {
889                issuer,
890                audience,
891                jwks_uri,
892                ..
893            } => {
894                tracing::info!(
895                    issuer = %issuer,
896                    audience = %audience,
897                    jwks_uri = %jwks_uri,
898                    "Enabling JWT auth (JWKS)"
899                );
900                let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
901                    .expected_audience(audience)
902                    .expected_issuer(issuer)
903                    .build()
904                    .await
905                    .context("building JWKS validator")?;
906
907                let addr = format!(
908                    "http://{}:{}",
909                    config.proxy.listen.host, config.proxy.listen.port
910                );
911                let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
912                    .authorization_server(issuer);
913
914                let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
915                router.layer(layer)
916            }
917            #[cfg(not(feature = "oauth"))]
918            AuthConfig::Jwt { .. } => {
919                anyhow::bail!(
920                    "JWT auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
921                );
922            }
923            #[cfg(feature = "oauth")]
924            AuthConfig::OAuth {
925                issuer,
926                audience,
927                token_validation,
928                jwks_uri,
929                introspection_endpoint,
930                client_id,
931                client_secret,
932                ..
933            } => {
934                use crate::config::TokenValidationStrategy;
935
936                tracing::info!(
937                    issuer = %issuer,
938                    audience = %audience,
939                    strategy = ?token_validation,
940                    "Enabling OAuth 2.1 auth"
941                );
942
943                // Auto-discover endpoints from issuer if not overridden
944                let discovered = crate::introspection::discover_auth_server(issuer)
945                    .await
946                    .context("discovering OAuth authorization server")?;
947
948                let effective_jwks_uri = jwks_uri
949                    .as_deref()
950                    .or(discovered.jwks_uri.as_deref())
951                    .ok_or_else(|| {
952                        anyhow::anyhow!(
953                            "JWKS URI not found via discovery and not configured manually"
954                        )
955                    })?;
956
957                let effective_introspection = introspection_endpoint
958                    .as_deref()
959                    .or(discovered.introspection_endpoint.as_deref());
960
961                let addr = format!(
962                    "http://{}:{}",
963                    config.proxy.listen.host, config.proxy.listen.port
964                );
965                let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
966                    .authorization_server(issuer);
967
968                match token_validation {
969                    TokenValidationStrategy::Jwt => {
970                        let validator =
971                            tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
972                                .expected_audience(audience)
973                                .expected_issuer(issuer)
974                                .build()
975                                .await
976                                .context("building JWKS validator")?;
977                        let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
978                        router.layer(layer)
979                    }
980                    TokenValidationStrategy::Introspection => {
981                        let endpoint = effective_introspection.ok_or_else(|| {
982                            anyhow::anyhow!(
983                                "introspection endpoint not found via discovery and not configured"
984                            )
985                        })?;
986                        let validator = crate::introspection::IntrospectionValidator::new(
987                            endpoint,
988                            client_id.as_deref().unwrap(),
989                            client_secret.as_deref().unwrap(),
990                        )
991                        .expected_audience(audience);
992                        let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
993                        router.layer(layer)
994                    }
995                    TokenValidationStrategy::Both => {
996                        let endpoint = effective_introspection.ok_or_else(|| {
997                            anyhow::anyhow!(
998                                "introspection endpoint not found via discovery and not configured"
999                            )
1000                        })?;
1001                        let jwt_validator =
1002                            tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
1003                                .expected_audience(audience)
1004                                .expected_issuer(issuer)
1005                                .build()
1006                                .await
1007                                .context("building JWKS validator")?;
1008                        let introspection_validator =
1009                            crate::introspection::IntrospectionValidator::new(
1010                                endpoint,
1011                                client_id.as_deref().unwrap(),
1012                                client_secret.as_deref().unwrap(),
1013                            )
1014                            .expected_audience(audience);
1015                        let fallback = crate::introspection::FallbackValidator::new(
1016                            jwt_validator,
1017                            introspection_validator,
1018                        );
1019                        let layer = tower_mcp::oauth::OAuthLayer::new(fallback, metadata);
1020                        router.layer(layer)
1021                    }
1022                }
1023            }
1024            #[cfg(not(feature = "oauth"))]
1025            AuthConfig::OAuth { .. } => {
1026                anyhow::bail!(
1027                    "OAuth auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
1028                );
1029            }
1030        }
1031    } else {
1032        router
1033    };
1034    Ok(router)
1035}
1036
1037/// Wait for SIGTERM or SIGINT, then log and return.
1038pub async fn shutdown_signal(timeout: Duration) {
1039    let ctrl_c = tokio::signal::ctrl_c();
1040    #[cfg(unix)]
1041    {
1042        let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
1043            .expect("SIGTERM handler");
1044        tokio::select! {
1045            _ = ctrl_c => {},
1046            _ = sigterm.recv() => {},
1047        }
1048    }
1049    #[cfg(not(unix))]
1050    {
1051        ctrl_c.await.ok();
1052    }
1053    tracing::info!(
1054        timeout_seconds = timeout.as_secs(),
1055        "Shutdown signal received, draining connections"
1056    );
1057}
1058
#[cfg(all(test, feature = "oauth"))]
mod scope_enforcement_tests {
    use std::collections::HashMap;

    use tower::{Layer, Service};
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{CallToolParams, McpRequest, RequestId};
    use tower_mcp::router::Extensions;

    use super::oauth_scope_layer;
    use crate::test_util::MockService;

    /// Build a `tools/call` request for `fs/read`.
    ///
    /// With `Some(scope)`, the request carries `TokenClaims` whose `scope` is
    /// the given space-separated string. With `None`, no claims are attached.
    fn call_with_scope(scope: Option<&str>) -> tower_mcp::RouterRequest {
        let mut extensions = Extensions::new();
        if let Some(scope) = scope {
            extensions.insert(TokenClaims {
                sub: Some("user".into()),
                iss: None,
                aud: None,
                exp: None,
                scope: Some(scope.to_string()),
                client_id: None,
                extra: HashMap::new(),
            });
        }
        tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(CallToolParams {
                name: "fs/read".into(),
                arguments: serde_json::json!({}),
                meta: None,
                task: None,
            }),
            extensions,
        }
    }

    #[test]
    fn no_required_scopes_yields_no_layer() {
        assert!(oauth_scope_layer(&[]).is_none());
    }

    #[tokio::test]
    async fn token_missing_required_scope_is_rejected() {
        let required = vec!["mcp:access".to_string()];
        let layer = oauth_scope_layer(&required).expect("layer for non-empty scopes");
        let mut svc = layer.layer(MockService::with_tools(&["fs/read"]));

        // Token carries a different scope -> missing the required one.
        let resp = svc
            .call(call_with_scope(Some("other:scope")))
            .await
            .unwrap();
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.to_lowercase().contains("scope"),
            "expected insufficient-scope error, got: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn token_with_only_some_required_scopes_is_rejected() {
        // AND semantics: holding a subset of the required scopes is not enough.
        let required = vec!["mcp:access".to_string(), "mcp:read".to_string()];
        let layer = oauth_scope_layer(&required).expect("layer for non-empty scopes");
        let mut svc = layer.layer(MockService::with_tools(&["fs/read"]));

        let resp = svc
            .call(call_with_scope(Some("mcp:access")))
            .await
            .unwrap();
        let err = resp.inner.unwrap_err();
        assert!(
            err.message.to_lowercase().contains("scope"),
            "expected insufficient-scope error, got: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn token_with_all_required_scopes_is_allowed() {
        let required = vec!["mcp:access".to_string(), "mcp:read".to_string()];
        let layer = oauth_scope_layer(&required).expect("layer for non-empty scopes");
        let mut svc = layer.layer(MockService::with_tools(&["fs/read"]));

        let resp = svc
            .call(call_with_scope(Some("mcp:access mcp:read mcp:extra")))
            .await
            .unwrap();
        assert!(
            resp.inner.is_ok(),
            "token carrying all required scopes should be allowed"
        );
    }

    #[tokio::test]
    async fn request_without_claims_passes_through() {
        // Requests with no TokenClaims are left to the HTTP auth layer, which
        // rejects them upstream when auth is enabled. The scope gate itself
        // does not block them.
        let required = vec!["mcp:access".to_string()];
        let layer = oauth_scope_layer(&required).expect("layer for non-empty scopes");
        let mut svc = layer.layer(MockService::with_tools(&["fs/read"]));

        let resp = svc.call(call_with_scope(None)).await.unwrap();
        assert!(
            resp.inner.is_ok(),
            "request without token claims should pass through the scope gate"
        );
    }
}