1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::time::Duration;
6
7use anyhow::{Context, Result};
8use axum::Router;
9use tokio::process::Command;
10use tower::timeout::TimeoutLayer;
11use tower::util::BoxCloneService;
12use tower_mcp::SessionHandle;
13use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
14use tower_mcp::client::StdioClientTransport;
15use tower_mcp::proxy::McpProxy;
16use tower_mcp::{RouterRequest, RouterResponse};
17
18use crate::admin::BackendMeta;
19use crate::alias;
20use crate::cache;
21use crate::coalesce;
22use crate::config::{AuthConfig, ProxyConfig, TransportType};
23use crate::filter::CapabilityFilterService;
24#[cfg(feature = "oauth")]
25use crate::rbac::{RbacConfig, RbacService};
26use crate::validation::{ValidationConfig, ValidationService};
27
/// A fully assembled MCP proxy: the HTTP router together with the handles
/// needed to inspect or manage it after construction.
pub struct Proxy {
    /// Axum router handed out by `into_router` and served by `serve`.
    router: Router,
    /// Session handle obtained from the HTTP transport when the router was built.
    session_handle: SessionHandle,
    /// The underlying `McpProxy`, exposed through `mcp_proxy`.
    inner: McpProxy,
    /// Configuration this proxy was built from.
    config: ProxyConfig,
    /// Discovery index; `Some` only when tool discovery was enabled at build time.
    #[cfg(feature = "discovery")]
    discovery_index: Option<crate::discovery::SharedDiscoveryIndex>,
}
37
38impl Proxy {
39 pub async fn from_config(config: ProxyConfig) -> Result<Self> {
45 let (mcp_proxy, cb_handles) = build_mcp_proxy(&config).await?;
46 let proxy_for_admin = mcp_proxy.clone();
47 let mut proxy_for_caller = mcp_proxy.clone();
48 let proxy_for_management = mcp_proxy.clone();
49
50 #[cfg(feature = "metrics")]
52 let metrics_handle = if config.observability.metrics.enabled {
53 tracing::info!("Prometheus metrics enabled at /admin/metrics");
54 let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
55 let handle = builder
56 .install_recorder()
57 .context("installing Prometheus metrics recorder")?;
58 Some(handle)
59 } else {
60 None
61 };
62 #[cfg(not(feature = "metrics"))]
63 let metrics_handle = None;
64
65 let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;
66
67 let (router, session_handle) =
68 tower_mcp::transport::http::HttpTransport::from_service(service)
69 .into_router_with_handle();
70
71 let router = apply_auth(&config, router).await?;
73
74 let backend_meta: std::collections::HashMap<String, BackendMeta> = config
76 .backends
77 .iter()
78 .map(|b| {
79 (
80 b.name.clone(),
81 BackendMeta {
82 transport: format!("{:?}", b.transport).to_lowercase(),
83 },
84 )
85 })
86 .collect();
87
88 let admin_state = crate::admin::spawn_health_checker(
90 proxy_for_admin,
91 config.proxy.name.clone(),
92 config.proxy.version.clone(),
93 config.backends.len(),
94 backend_meta,
95 );
96 let router = router.nest(
97 "/admin",
98 crate::admin::admin_router(
99 admin_state.clone(),
100 metrics_handle,
101 session_handle.clone(),
102 cache_handle,
103 proxy_for_management,
104 &config,
105 config.source_path.clone(),
106 cb_handles,
107 ),
108 );
109 tracing::info!("Admin API enabled at /admin/backends");
110
111 #[cfg(feature = "discovery")]
113 let discovery_enabled = config.proxy.tool_discovery
114 || config.proxy.tool_exposure == crate::config::ToolExposure::Search;
115 #[cfg(feature = "discovery")]
116 let (discovery_index, discovery_tools) = if discovery_enabled {
117 let index =
118 crate::discovery::build_index(&mut proxy_for_caller, &config.proxy.separator).await;
119 let tools = crate::discovery::build_discovery_tools(index.clone());
120 (Some(index), Some(tools))
121 } else {
122 (None, None)
123 };
124 #[cfg(not(feature = "discovery"))]
125 let discovery_tools: Option<Vec<tower_mcp::Tool>> = None;
126
127 if let Err(e) = crate::admin_tools::register_admin_tools(
129 &proxy_for_caller,
130 admin_state,
131 session_handle.clone(),
132 &config,
133 discovery_tools,
134 )
135 .await
136 {
137 tracing::warn!("Failed to register admin tools: {e}");
138 } else {
139 tracing::info!("MCP admin tools registered under proxy/ namespace");
140 }
141
142 Ok(Self {
143 router,
144 session_handle,
145 inner: proxy_for_caller,
146 config,
147 #[cfg(feature = "discovery")]
148 discovery_index,
149 })
150 }
151
152 pub fn session_handle(&self) -> &SessionHandle {
154 &self.session_handle
155 }
156
157 pub fn mcp_proxy(&self) -> &McpProxy {
161 &self.inner
162 }
163
164 pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
169 tracing::info!("Hot reload enabled, watching config file for changes");
170 crate::reload::spawn_config_watcher(
171 config_path,
172 self.inner.clone(),
173 #[cfg(feature = "discovery")]
174 self.discovery_index
175 .as_ref()
176 .map(|idx| (idx.clone(), self.config.proxy.separator.clone())),
177 );
178 }
179
180 pub fn into_router(self) -> (Router, SessionHandle) {
192 (self.router, self.session_handle)
193 }
194
195 pub async fn serve(self) -> Result<()> {
200 let addr = format!(
201 "{}:{}",
202 self.config.proxy.listen.host, self.config.proxy.listen.port
203 );
204
205 tracing::info!(listen = %addr, "Proxy ready");
206
207 let listener = tokio::net::TcpListener::bind(&addr)
208 .await
209 .with_context(|| format!("binding to {}", addr))?;
210
211 let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
212 axum::serve(listener, self.router)
213 .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
214 .await
215 .context("server error")?;
216
217 tracing::info!("Proxy shut down");
218 Ok(())
219 }
220}
221
/// Handle to a backend's circuit breaker, as returned alongside the layer by
/// the circuit-breaker builder's `build_with_handle`.
pub type CbHandle = tower_resilience::circuitbreaker::CircuitBreakerHandle;
224
225async fn build_mcp_proxy(config: &ProxyConfig) -> Result<(McpProxy, HashMap<String, CbHandle>)> {
228 let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
229 .separator(&config.proxy.separator);
230 let mut cb_handles: HashMap<String, CbHandle> = HashMap::new();
231
232 if let Some(instructions) = &config.proxy.instructions {
233 builder = builder.instructions(instructions);
234 }
235
236 let outlier_detector = {
239 let max_pct = config
240 .backends
241 .iter()
242 .filter_map(|b| b.outlier_detection.as_ref())
243 .map(|od| od.max_ejection_percent)
244 .max();
245 max_pct.map(crate::outlier::OutlierDetector::new)
246 };
247
248 for backend in &config.backends {
249 tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");
250
251 match backend.transport {
252 TransportType::Stdio => {
253 let command = backend.command.as_deref().unwrap();
254 let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
255
256 let mut cmd = Command::new(command);
257 cmd.args(&args);
258
259 for (key, value) in &backend.env {
260 cmd.env(key, value);
261 }
262
263 let transport = StdioClientTransport::spawn_command(&mut cmd)
264 .await
265 .with_context(|| format!("spawning backend '{}'", backend.name))?;
266
267 builder = builder.backend(&backend.name, transport).await;
268 }
269 TransportType::Http => {
270 let url = backend.url.as_deref().unwrap();
271 let mut transport = tower_mcp::client::HttpClientTransport::new(url);
272 if let Some(token) = &backend.bearer_token {
273 transport = transport.bearer_token(token);
274 }
275
276 builder = builder.backend(&backend.name, transport).await;
277 }
278 #[cfg(feature = "websocket")]
279 TransportType::Websocket => {
280 let url = backend.url.as_deref().unwrap();
281 tracing::info!(url = %url, "Connecting to WebSocket backend");
282 let transport = if let Some(token) = &backend.bearer_token {
283 crate::ws_transport::WebSocketClientTransport::connect_with_bearer_token(
284 url, token,
285 )
286 .await
287 .with_context(|| {
288 format!("connecting to WebSocket backend '{}'", backend.name)
289 })?
290 } else {
291 crate::ws_transport::WebSocketClientTransport::connect(url)
292 .await
293 .with_context(|| {
294 format!("connecting to WebSocket backend '{}'", backend.name)
295 })?
296 };
297
298 builder = builder.backend(&backend.name, transport).await;
299 }
300 #[cfg(not(feature = "websocket"))]
301 TransportType::Websocket => {
302 anyhow::bail!(
303 "WebSocket transport requires the 'websocket' feature. \
304 Rebuild with: cargo install mcp-proxy --features websocket"
305 );
306 }
307 }
308
309 if let Some(retry_cfg) = &backend.retry {
313 tracing::info!(
314 backend = %backend.name,
315 max_retries = retry_cfg.max_retries,
316 initial_backoff_ms = retry_cfg.initial_backoff_ms,
317 max_backoff_ms = retry_cfg.max_backoff_ms,
318 "Applying retry policy"
319 );
320 let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
321 builder = builder.backend_layer(layer);
322 }
323
324 if let Some(hedge_cfg) = &backend.hedging {
326 let delay = Duration::from_millis(hedge_cfg.delay_ms);
327 let max_attempts = hedge_cfg.max_hedges + 1; tracing::info!(
329 backend = %backend.name,
330 delay_ms = hedge_cfg.delay_ms,
331 max_hedges = hedge_cfg.max_hedges,
332 "Applying request hedging"
333 );
334 let layer = if delay.is_zero() {
335 tower_resilience::hedge::HedgeLayer::builder()
336 .no_delay()
337 .max_hedged_attempts(max_attempts)
338 .name(format!("{}-hedge", backend.name))
339 .build()
340 } else {
341 tower_resilience::hedge::HedgeLayer::builder()
342 .delay(delay)
343 .max_hedged_attempts(max_attempts)
344 .name(format!("{}-hedge", backend.name))
345 .build()
346 };
347 builder = builder.backend_layer(layer);
348 }
349
350 if let Some(cc) = &backend.concurrency {
352 tracing::info!(
353 backend = %backend.name,
354 max = cc.max_concurrent,
355 "Applying concurrency limit"
356 );
357 builder =
358 builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
359 }
360
361 if let Some(rl) = &backend.rate_limit {
363 tracing::info!(
364 backend = %backend.name,
365 requests = rl.requests,
366 period_seconds = rl.period_seconds,
367 "Applying rate limit"
368 );
369 let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
370 .limit_for_period(rl.requests)
371 .refresh_period(Duration::from_secs(rl.period_seconds))
372 .name(format!("{}-ratelimit", backend.name))
373 .build();
374 builder = builder.backend_layer(layer);
375 }
376
377 if let Some(timeout) = &backend.timeout {
379 tracing::info!(
380 backend = %backend.name,
381 seconds = timeout.seconds,
382 "Applying timeout"
383 );
384 builder =
385 builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
386 }
387
388 if let Some(cb) = &backend.circuit_breaker {
390 tracing::info!(
391 backend = %backend.name,
392 failure_rate = cb.failure_rate_threshold,
393 wait_seconds = cb.wait_duration_seconds,
394 "Applying circuit breaker"
395 );
396 let (layer, handle) = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
397 .failure_rate_threshold(cb.failure_rate_threshold)
398 .minimum_number_of_calls(cb.minimum_calls)
399 .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
400 .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
401 .name(format!("{}-cb", backend.name))
402 .build_with_handle();
403 cb_handles.insert(backend.name.clone(), handle);
404 builder = builder.backend_layer(layer);
405 }
406
407 if let Some(od) = &backend.outlier_detection
409 && let Some(ref detector) = outlier_detector
410 {
411 tracing::info!(
412 backend = %backend.name,
413 consecutive_errors = od.consecutive_errors,
414 base_ejection_seconds = od.base_ejection_seconds,
415 max_ejection_percent = od.max_ejection_percent,
416 "Applying outlier detection"
417 );
418 let layer = crate::outlier::OutlierDetectionLayer::new(
419 backend.name.clone(),
420 od.clone(),
421 detector.clone(),
422 );
423 builder = builder.backend_layer(layer);
424 }
425 }
426
427 let result = builder.build().await?;
428
429 if !result.skipped.is_empty() {
430 for s in &result.skipped {
431 tracing::warn!("Skipped backend: {s}");
432 }
433 }
434
435 Ok((result.proxy, cb_handles))
436}
437
/// Wraps `proxy` in the configured request middleware and returns the resulting
/// boxed service, plus the cache handle when response caching is enabled.
///
/// Every step replaces `service` with a wrapper around the previous value, so
/// wrappers applied later in this function sit outside the ones applied
/// earlier. The statement order below therefore defines the layering and must
/// not be changed casually.
///
/// # Errors
///
/// Returns an error when a backend's capability filter cannot be built.
fn build_middleware_stack(
    config: &ProxyConfig,
    proxy: McpProxy,
) -> Result<(
    BoxCloneService<RouterRequest, RouterResponse, Infallible>,
    Option<cache::CacheHandle>,
)> {
    let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
        BoxCloneService::new(proxy);
    let mut cache_handle: Option<cache::CacheHandle> = None;

    // Argument injection: one rule set per backend that configures default or
    // per-tool injected arguments, keyed by "<backend><separator>".
    let injection_rules: Vec<_> = config
        .backends
        .iter()
        .filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
        .map(|b| {
            let namespace = format!("{}{}", b.name, config.proxy.separator);
            tracing::info!(
                backend = %b.name,
                default_args = b.default_args.len(),
                tool_rules = b.inject_args.len(),
                "Applying argument injection"
            );
            crate::inject::InjectionRules::new(
                namespace,
                b.default_args.clone(),
                b.inject_args.clone(),
            )
        })
        .collect();

    if !injection_rules.is_empty() {
        service = BoxCloneService::new(crate::inject::InjectArgsService::new(
            service,
            injection_rules,
        ));
    }

    // Parameter overrides: flattened across all backends that define any.
    let param_overrides: Vec<_> = config
        .backends
        .iter()
        .filter(|b| !b.param_overrides.is_empty())
        .flat_map(|b| {
            let namespace = format!("{}{}", b.name, config.proxy.separator);
            tracing::info!(
                backend = %b.name,
                overrides = b.param_overrides.len(),
                "Applying parameter overrides"
            );
            b.param_overrides
                .iter()
                .map(move |c| crate::param_override::ToolOverride::new(&namespace, c))
        })
        .collect();

    if !param_overrides.is_empty() {
        service = BoxCloneService::new(crate::param_override::ParamOverrideService::new(
            service,
            param_overrides,
        ));
    }

    // Canary routing: primary name -> (canary name, primary weight, canary
    // weight). The primary's weight falls back to 100 when the named primary
    // is not among the configured backends.
    let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.canary_of.as_ref().map(|primary_name| {
                let primary_weight = config
                    .backends
                    .iter()
                    .find(|p| p.name == *primary_name)
                    .map(|p| p.weight)
                    .unwrap_or(100);
                (
                    primary_name.clone(),
                    (b.name.clone(), primary_weight, b.weight),
                )
            })
        })
        .collect();

    if !canary_mappings.is_empty() {
        for (primary, (canary, pw, cw)) in &canary_mappings {
            tracing::info!(
                primary = %primary,
                canary = %canary,
                primary_weight = pw,
                canary_weight = cw,
                "Enabling canary routing"
            );
        }
        service = BoxCloneService::new(crate::canary::CanaryService::new(
            service,
            canary_mappings,
            &config.proxy.separator,
        ));
    }

    // Failover routing: group failover backends by the primary they cover,
    // then order each group by ascending `priority`.
    let mut failover_groups: std::collections::HashMap<String, Vec<(u32, String)>> =
        std::collections::HashMap::new();
    for b in &config.backends {
        if let Some(ref primary) = b.failover_for {
            failover_groups
                .entry(primary.clone())
                .or_default()
                .push((b.priority, b.name.clone()));
        }
    }
    let failover_mappings: std::collections::HashMap<String, Vec<String>> = failover_groups
        .into_iter()
        .map(|(primary, mut backends)| {
            backends.sort_by_key(|(priority, _)| *priority);
            let names: Vec<String> = backends.into_iter().map(|(_, name)| name).collect();
            (primary, names)
        })
        .collect();

    if !failover_mappings.is_empty() {
        for (primary, failovers) in &failover_mappings {
            tracing::info!(
                primary = %primary,
                failovers = ?failovers,
                "Enabling failover routing"
            );
        }
        service = BoxCloneService::new(crate::failover::FailoverService::new(
            service,
            failover_mappings,
            &config.proxy.separator,
        ));
    }

    // Traffic mirroring: source backend name -> (mirror name, mirror percent).
    let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.mirror_of
                .as_ref()
                .map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
        })
        .collect();

    if !mirror_mappings.is_empty() {
        for (source, (mirror, pct)) in &mirror_mappings {
            tracing::info!(
                source = %source,
                mirror = %mirror,
                percent = pct,
                "Enabling traffic mirroring"
            );
        }
        service = BoxCloneService::new(crate::mirror::MirrorService::new(
            service,
            mirror_mappings,
            &config.proxy.separator,
        ));
    }

    // Response cache: per-backend cache settings keyed by namespace prefix.
    let cache_configs: Vec<_> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.cache
                .as_ref()
                .map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
        })
        .collect();

    if !cache_configs.is_empty() {
        for (ns, cfg) in &cache_configs {
            tracing::info!(
                // Strip the separator again so the log shows the bare backend name.
                backend = %ns.trim_end_matches(&config.proxy.separator),
                resource_ttl = cfg.resource_ttl_seconds,
                tool_ttl = cfg.tool_ttl_seconds,
                max_entries = cfg.max_entries,
                "Applying response cache"
            );
        }
        let (cache_svc, handle) = cache::CacheService::new(service, cache_configs, &config.cache);
        service = BoxCloneService::new(cache_svc);
        // The handle is returned to the caller alongside the service.
        cache_handle = Some(handle);
    }

    // Request coalescing.
    if config.performance.coalesce_requests {
        tracing::info!("Request coalescing enabled");
        service = BoxCloneService::new(coalesce::CoalesceService::new(service));
    }

    // Request validation, only when a maximum argument size is configured.
    if config.security.max_argument_size.is_some() {
        let validation = ValidationConfig {
            max_argument_size: config.security.max_argument_size,
        };
        if let Some(max) = validation.max_argument_size {
            tracing::info!(max_argument_size = max, "Applying request validation");
        }
        service = BoxCloneService::new(ValidationService::new(service, validation));
    }

    // Capability filters: `build_filter` yields an optional filter per backend;
    // the first error aborts the whole function.
    let filters: Vec<_> = config
        .backends
        .iter()
        .filter_map(|b| b.build_filter(&config.proxy.separator).transpose())
        .collect::<anyhow::Result<Vec<_>>>()?;

    if !filters.is_empty() {
        for f in &filters {
            tracing::info!(
                backend = %f.namespace.trim_end_matches(&config.proxy.separator),
                tool_filter = ?f.tool_filter,
                resource_filter = ?f.resource_filter,
                prompt_filter = ?f.prompt_filter,
                "Applying capability filter"
            );
        }
        service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
    }

    // Search mode: the filter is given the "proxy<separator>" prefix.
    if config.proxy.tool_exposure == crate::config::ToolExposure::Search {
        let prefix = format!("proxy{}", config.proxy.separator);
        tracing::info!(
            prefix = %prefix,
            "Search mode: ListTools will only show proxy/ namespace tools"
        );
        service =
            BoxCloneService::new(crate::filter::SearchModeFilterService::new(service, prefix));
    }

    // Tool aliases: (namespace, from, to) triples across all backends.
    // `AliasMap::new` returns `None` when there is nothing to apply.
    let alias_mappings: Vec<_> = config
        .backends
        .iter()
        .flat_map(|b| {
            let ns = format!("{}{}", b.name, config.proxy.separator);
            b.aliases
                .iter()
                .map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
        })
        .collect();

    if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
        let count = alias_map.forward.len();
        tracing::info!(aliases = count, "Applying tool aliases");
        service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
    }

    // Composite tools.
    if !config.composite_tools.is_empty() {
        let count = config.composite_tools.len();
        tracing::info!(composite_tools = count, "Applying composite tool fan-out");
        service = BoxCloneService::new(crate::composite::CompositeService::new(
            service,
            config.composite_tools.clone(),
        ));
    }

    // Bearer token scoping, only when bearer auth defines scoped tokens.
    #[cfg(feature = "oauth")]
    if matches!(
        &config.auth,
        Some(AuthConfig::Bearer {
            scoped_tokens,
            ..
        }) if !scoped_tokens.is_empty()
    ) {
        tracing::info!("Enabling bearer token scoping middleware");
        service = BoxCloneService::new(crate::bearer_scope::BearerScopingService::new(service));
    }

    #[cfg(feature = "oauth")]
    {
        // RBAC requires JWT or OAuth auth with both a role mapping and at
        // least one role defined.
        let rbac_config = match &config.auth {
            Some(
                AuthConfig::Jwt {
                    roles,
                    role_mapping: Some(mapping),
                    ..
                }
                | AuthConfig::OAuth {
                    roles,
                    role_mapping: Some(mapping),
                    ..
                },
            ) if !roles.is_empty() => {
                tracing::info!(
                    roles = roles.len(),
                    claim = %mapping.claim,
                    "Enabling RBAC"
                );
                Some(RbacConfig::new(roles, mapping))
            }
            _ => None,
        };

        if let Some(rbac) = rbac_config {
            service = BoxCloneService::new(RbacService::new(service, rbac));
        }

        // Token passthrough for backends flagged with `forward_auth`.
        let forward_namespaces: std::collections::HashSet<String> = config
            .backends
            .iter()
            .filter(|b| b.forward_auth)
            .map(|b| format!("{}{}", b.name, config.proxy.separator))
            .collect();

        if !forward_namespaces.is_empty() {
            tracing::info!(
                backends = ?forward_namespaces,
                "Enabling token passthrough for forward_auth backends"
            );
            service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
                service,
                forward_namespaces,
            ));
        }
    }

    // Metrics.
    #[cfg(feature = "metrics")]
    if config.observability.metrics.enabled {
        service = BoxCloneService::new(crate::metrics::MetricsService::new(service));
    }

    // Access log.
    if config.observability.access_log.enabled {
        tracing::info!("Access logging enabled (target: mcp::access)");
        service = BoxCloneService::new(crate::access_log::AccessLogService::new(
            service,
            &config.proxy.separator,
        ));
    }

    // Audit log; the layered service is wrapped in `CatchError` before boxing.
    if config.observability.audit {
        tracing::info!("Audit logging enabled (target: mcp::audit)");
        let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
        service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
    }

    // Global rate limit, applied last and therefore outermost.
    if let Some(ref rl) = config.proxy.rate_limit {
        tracing::info!(
            requests = rl.requests,
            period_seconds = rl.period_seconds,
            "Applying global rate limit"
        );
        let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
            .limit_for_period(rl.requests)
            .refresh_period(Duration::from_secs(rl.period_seconds))
            .name("global-ratelimit")
            .build();
        let limited = tower::Layer::layer(&layer, service);
        service = BoxCloneService::new(tower_mcp::CatchError::new(limited));
    }

    Ok((service, cache_handle))
}
810
811async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
813 let router = if let Some(auth) = &config.auth {
814 match auth {
815 AuthConfig::Bearer {
816 tokens,
817 scoped_tokens,
818 } => {
819 let total = tokens.len() + scoped_tokens.len();
820 if scoped_tokens.is_empty() {
821 tracing::info!(token_count = total, "Enabling bearer token auth");
823 let validator = StaticBearerValidator::new(tokens.iter().cloned());
824 let layer = AuthLayer::new(validator);
825 router.layer(layer)
826 } else {
827 #[cfg(feature = "oauth")]
829 {
830 tracing::info!(
831 token_count = total,
832 scoped = scoped_tokens.len(),
833 "Enabling bearer token auth with per-token scoping"
834 );
835 let layer =
836 crate::bearer_scope::ScopedBearerAuthLayer::new(tokens, scoped_tokens);
837 router.layer(layer)
838 }
839 #[cfg(not(feature = "oauth"))]
840 {
841 anyhow::bail!(
842 "Per-token tool scoping requires the 'oauth' feature. \
843 Rebuild with: cargo install mcp-proxy --features oauth"
844 );
845 }
846 }
847 }
848 #[cfg(feature = "oauth")]
849 AuthConfig::Jwt {
850 issuer,
851 audience,
852 jwks_uri,
853 ..
854 } => {
855 tracing::info!(
856 issuer = %issuer,
857 audience = %audience,
858 jwks_uri = %jwks_uri,
859 "Enabling JWT auth (JWKS)"
860 );
861 let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
862 .expected_audience(audience)
863 .expected_issuer(issuer)
864 .build()
865 .await
866 .context("building JWKS validator")?;
867
868 let addr = format!(
869 "http://{}:{}",
870 config.proxy.listen.host, config.proxy.listen.port
871 );
872 let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
873 .authorization_server(issuer);
874
875 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
876 router.layer(layer)
877 }
878 #[cfg(not(feature = "oauth"))]
879 AuthConfig::Jwt { .. } => {
880 anyhow::bail!(
881 "JWT auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
882 );
883 }
884 #[cfg(feature = "oauth")]
885 AuthConfig::OAuth {
886 issuer,
887 audience,
888 token_validation,
889 jwks_uri,
890 introspection_endpoint,
891 client_id,
892 client_secret,
893 ..
894 } => {
895 use crate::config::TokenValidationStrategy;
896
897 tracing::info!(
898 issuer = %issuer,
899 audience = %audience,
900 strategy = ?token_validation,
901 "Enabling OAuth 2.1 auth"
902 );
903
904 let discovered = crate::introspection::discover_auth_server(issuer)
906 .await
907 .context("discovering OAuth authorization server")?;
908
909 let effective_jwks_uri = jwks_uri
910 .as_deref()
911 .or(discovered.jwks_uri.as_deref())
912 .ok_or_else(|| {
913 anyhow::anyhow!(
914 "JWKS URI not found via discovery and not configured manually"
915 )
916 })?;
917
918 let effective_introspection = introspection_endpoint
919 .as_deref()
920 .or(discovered.introspection_endpoint.as_deref());
921
922 let addr = format!(
923 "http://{}:{}",
924 config.proxy.listen.host, config.proxy.listen.port
925 );
926 let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
927 .authorization_server(issuer);
928
929 match token_validation {
930 TokenValidationStrategy::Jwt => {
931 let validator =
932 tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
933 .expected_audience(audience)
934 .expected_issuer(issuer)
935 .build()
936 .await
937 .context("building JWKS validator")?;
938 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
939 router.layer(layer)
940 }
941 TokenValidationStrategy::Introspection => {
942 let endpoint = effective_introspection.ok_or_else(|| {
943 anyhow::anyhow!(
944 "introspection endpoint not found via discovery and not configured"
945 )
946 })?;
947 let validator = crate::introspection::IntrospectionValidator::new(
948 endpoint,
949 client_id.as_deref().unwrap(),
950 client_secret.as_deref().unwrap(),
951 )
952 .expected_audience(audience);
953 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
954 router.layer(layer)
955 }
956 TokenValidationStrategy::Both => {
957 let endpoint = effective_introspection.ok_or_else(|| {
958 anyhow::anyhow!(
959 "introspection endpoint not found via discovery and not configured"
960 )
961 })?;
962 let jwt_validator =
963 tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
964 .expected_audience(audience)
965 .expected_issuer(issuer)
966 .build()
967 .await
968 .context("building JWKS validator")?;
969 let introspection_validator =
970 crate::introspection::IntrospectionValidator::new(
971 endpoint,
972 client_id.as_deref().unwrap(),
973 client_secret.as_deref().unwrap(),
974 )
975 .expected_audience(audience);
976 let fallback = crate::introspection::FallbackValidator::new(
977 jwt_validator,
978 introspection_validator,
979 );
980 let layer = tower_mcp::oauth::OAuthLayer::new(fallback, metadata);
981 router.layer(layer)
982 }
983 }
984 }
985 #[cfg(not(feature = "oauth"))]
986 AuthConfig::OAuth { .. } => {
987 anyhow::bail!(
988 "OAuth auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
989 );
990 }
991 }
992 } else {
993 router
994 };
995 Ok(router)
996}
997
998pub async fn shutdown_signal(timeout: Duration) {
1000 let ctrl_c = tokio::signal::ctrl_c();
1001 #[cfg(unix)]
1002 {
1003 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
1004 .expect("SIGTERM handler");
1005 tokio::select! {
1006 _ = ctrl_c => {},
1007 _ = sigterm.recv() => {},
1008 }
1009 }
1010 #[cfg(not(unix))]
1011 {
1012 ctrl_c.await.ok();
1013 }
1014 tracing::info!(
1015 timeout_seconds = timeout.as_secs(),
1016 "Shutdown signal received, draining connections"
1017 );
1018}