1use std::convert::Infallible;
4use std::time::Duration;
5
6use anyhow::{Context, Result};
7use axum::Router;
8use tokio::process::Command;
9use tower::timeout::TimeoutLayer;
10use tower::util::BoxCloneService;
11use tower_mcp::SessionHandle;
12use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
13use tower_mcp::client::StdioClientTransport;
14use tower_mcp::proxy::McpProxy;
15use tower_mcp::{RouterRequest, RouterResponse};
16
17use crate::admin::BackendMeta;
18use crate::alias;
19use crate::cache;
20use crate::coalesce;
21use crate::config::{AuthConfig, ProxyConfig, TransportType};
22use crate::filter::CapabilityFilterService;
23#[cfg(feature = "oauth")]
24use crate::rbac::{RbacConfig, RbacService};
25use crate::validation::{ValidationConfig, ValidationService};
26
/// A fully assembled MCP proxy.
///
/// Holds the HTTP router (the MCP endpoint plus the `/admin` API), the session
/// handle of the HTTP transport, a handle to the underlying [`McpProxy`], and
/// the configuration everything was built from.
///
/// Build one with [`Proxy::from_config`], then either call [`Proxy::serve`] or
/// take the router with [`Proxy::into_router`].
pub struct Proxy {
    /// Router serving the MCP HTTP transport, with auth applied and `/admin` nested.
    router: Router,
    /// Session handle returned by the HTTP transport; also handed to the admin API and admin tools.
    session_handle: SessionHandle,
    /// Proxy handle returned by `mcp_proxy()` and passed to the hot-reload watcher.
    inner: McpProxy,
    /// Configuration the proxy was built from; read by `serve` and `enable_hot_reload`.
    config: ProxyConfig,
    /// Tool-discovery index; `Some` only when tool discovery or search-mode exposure is enabled.
    #[cfg(feature = "discovery")]
    discovery_index: Option<crate::discovery::SharedDiscoveryIndex>,
}
36
37impl Proxy {
38 pub async fn from_config(config: ProxyConfig) -> Result<Self> {
44 let mcp_proxy = build_mcp_proxy(&config).await?;
45 let proxy_for_admin = mcp_proxy.clone();
46 let mut proxy_for_caller = mcp_proxy.clone();
47 let proxy_for_management = mcp_proxy.clone();
48
49 #[cfg(feature = "metrics")]
51 let metrics_handle = if config.observability.metrics.enabled {
52 tracing::info!("Prometheus metrics enabled at /admin/metrics");
53 let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
54 let handle = builder
55 .install_recorder()
56 .context("installing Prometheus metrics recorder")?;
57 Some(handle)
58 } else {
59 None
60 };
61 #[cfg(not(feature = "metrics"))]
62 let metrics_handle = None;
63
64 let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;
65
66 let (router, session_handle) =
67 tower_mcp::transport::http::HttpTransport::from_service(service)
68 .into_router_with_handle();
69
70 let router = apply_auth(&config, router).await?;
72
73 let backend_meta: std::collections::HashMap<String, BackendMeta> = config
75 .backends
76 .iter()
77 .map(|b| {
78 (
79 b.name.clone(),
80 BackendMeta {
81 transport: format!("{:?}", b.transport).to_lowercase(),
82 },
83 )
84 })
85 .collect();
86
87 let admin_state = crate::admin::spawn_health_checker(
89 proxy_for_admin,
90 config.proxy.name.clone(),
91 config.proxy.version.clone(),
92 config.backends.len(),
93 backend_meta,
94 );
95 let router = router.nest(
96 "/admin",
97 crate::admin::admin_router(
98 admin_state.clone(),
99 metrics_handle,
100 session_handle.clone(),
101 cache_handle,
102 proxy_for_management,
103 &config,
104 ),
105 );
106 tracing::info!("Admin API enabled at /admin/backends");
107
108 #[cfg(feature = "discovery")]
110 let discovery_enabled = config.proxy.tool_discovery
111 || config.proxy.tool_exposure == crate::config::ToolExposure::Search;
112 #[cfg(feature = "discovery")]
113 let (discovery_index, discovery_tools) = if discovery_enabled {
114 let index =
115 crate::discovery::build_index(&mut proxy_for_caller, &config.proxy.separator).await;
116 let tools = crate::discovery::build_discovery_tools(index.clone());
117 (Some(index), Some(tools))
118 } else {
119 (None, None)
120 };
121 #[cfg(not(feature = "discovery"))]
122 let discovery_tools: Option<Vec<tower_mcp::Tool>> = None;
123
124 if let Err(e) = crate::admin_tools::register_admin_tools(
126 &proxy_for_caller,
127 admin_state,
128 session_handle.clone(),
129 &config,
130 discovery_tools,
131 )
132 .await
133 {
134 tracing::warn!("Failed to register admin tools: {e}");
135 } else {
136 tracing::info!("MCP admin tools registered under proxy/ namespace");
137 }
138
139 Ok(Self {
140 router,
141 session_handle,
142 inner: proxy_for_caller,
143 config,
144 #[cfg(feature = "discovery")]
145 discovery_index,
146 })
147 }
148
149 pub fn session_handle(&self) -> &SessionHandle {
151 &self.session_handle
152 }
153
154 pub fn mcp_proxy(&self) -> &McpProxy {
158 &self.inner
159 }
160
161 pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
166 tracing::info!("Hot reload enabled, watching config file for changes");
167 crate::reload::spawn_config_watcher(
168 config_path,
169 self.inner.clone(),
170 #[cfg(feature = "discovery")]
171 self.discovery_index
172 .as_ref()
173 .map(|idx| (idx.clone(), self.config.proxy.separator.clone())),
174 );
175 }
176
177 pub fn into_router(self) -> (Router, SessionHandle) {
189 (self.router, self.session_handle)
190 }
191
192 pub async fn serve(self) -> Result<()> {
197 let addr = format!(
198 "{}:{}",
199 self.config.proxy.listen.host, self.config.proxy.listen.port
200 );
201
202 tracing::info!(listen = %addr, "Proxy ready");
203
204 let listener = tokio::net::TcpListener::bind(&addr)
205 .await
206 .with_context(|| format!("binding to {}", addr))?;
207
208 let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
209 axum::serve(listener, self.router)
210 .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
211 .await
212 .context("server error")?;
213
214 tracing::info!("Proxy shut down");
215 Ok(())
216 }
217}
218
219async fn build_mcp_proxy(config: &ProxyConfig) -> Result<McpProxy> {
221 let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
222 .separator(&config.proxy.separator);
223
224 if let Some(instructions) = &config.proxy.instructions {
225 builder = builder.instructions(instructions);
226 }
227
228 let outlier_detector = {
231 let max_pct = config
232 .backends
233 .iter()
234 .filter_map(|b| b.outlier_detection.as_ref())
235 .map(|od| od.max_ejection_percent)
236 .max();
237 max_pct.map(crate::outlier::OutlierDetector::new)
238 };
239
240 for backend in &config.backends {
241 tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");
242
243 match backend.transport {
244 TransportType::Stdio => {
245 let command = backend.command.as_deref().unwrap();
246 let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
247
248 let mut cmd = Command::new(command);
249 cmd.args(&args);
250
251 for (key, value) in &backend.env {
252 cmd.env(key, value);
253 }
254
255 let transport = StdioClientTransport::spawn_command(&mut cmd)
256 .await
257 .with_context(|| format!("spawning backend '{}'", backend.name))?;
258
259 builder = builder.backend(&backend.name, transport).await;
260 }
261 TransportType::Http => {
262 let url = backend.url.as_deref().unwrap();
263 let mut transport = tower_mcp::client::HttpClientTransport::new(url);
264 if let Some(token) = &backend.bearer_token {
265 transport = transport.bearer_token(token);
266 }
267
268 builder = builder.backend(&backend.name, transport).await;
269 }
270 #[cfg(feature = "websocket")]
271 TransportType::Websocket => {
272 let url = backend.url.as_deref().unwrap();
273 tracing::info!(url = %url, "Connecting to WebSocket backend");
274 let transport = if let Some(token) = &backend.bearer_token {
275 crate::ws_transport::WebSocketClientTransport::connect_with_bearer_token(
276 url, token,
277 )
278 .await
279 .with_context(|| {
280 format!("connecting to WebSocket backend '{}'", backend.name)
281 })?
282 } else {
283 crate::ws_transport::WebSocketClientTransport::connect(url)
284 .await
285 .with_context(|| {
286 format!("connecting to WebSocket backend '{}'", backend.name)
287 })?
288 };
289
290 builder = builder.backend(&backend.name, transport).await;
291 }
292 #[cfg(not(feature = "websocket"))]
293 TransportType::Websocket => {
294 anyhow::bail!(
295 "WebSocket transport requires the 'websocket' feature. \
296 Rebuild with: cargo install mcp-proxy --features websocket"
297 );
298 }
299 }
300
301 if let Some(retry_cfg) = &backend.retry {
305 tracing::info!(
306 backend = %backend.name,
307 max_retries = retry_cfg.max_retries,
308 initial_backoff_ms = retry_cfg.initial_backoff_ms,
309 max_backoff_ms = retry_cfg.max_backoff_ms,
310 "Applying retry policy"
311 );
312 let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
313 builder = builder.backend_layer(layer);
314 }
315
316 if let Some(hedge_cfg) = &backend.hedging {
318 let delay = Duration::from_millis(hedge_cfg.delay_ms);
319 let max_attempts = hedge_cfg.max_hedges + 1; tracing::info!(
321 backend = %backend.name,
322 delay_ms = hedge_cfg.delay_ms,
323 max_hedges = hedge_cfg.max_hedges,
324 "Applying request hedging"
325 );
326 let layer = if delay.is_zero() {
327 tower_resilience::hedge::HedgeLayer::builder()
328 .no_delay()
329 .max_hedged_attempts(max_attempts)
330 .name(format!("{}-hedge", backend.name))
331 .build()
332 } else {
333 tower_resilience::hedge::HedgeLayer::builder()
334 .delay(delay)
335 .max_hedged_attempts(max_attempts)
336 .name(format!("{}-hedge", backend.name))
337 .build()
338 };
339 builder = builder.backend_layer(layer);
340 }
341
342 if let Some(cc) = &backend.concurrency {
344 tracing::info!(
345 backend = %backend.name,
346 max = cc.max_concurrent,
347 "Applying concurrency limit"
348 );
349 builder =
350 builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
351 }
352
353 if let Some(rl) = &backend.rate_limit {
355 tracing::info!(
356 backend = %backend.name,
357 requests = rl.requests,
358 period_seconds = rl.period_seconds,
359 "Applying rate limit"
360 );
361 let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
362 .limit_for_period(rl.requests)
363 .refresh_period(Duration::from_secs(rl.period_seconds))
364 .name(format!("{}-ratelimit", backend.name))
365 .build();
366 builder = builder.backend_layer(layer);
367 }
368
369 if let Some(timeout) = &backend.timeout {
371 tracing::info!(
372 backend = %backend.name,
373 seconds = timeout.seconds,
374 "Applying timeout"
375 );
376 builder =
377 builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
378 }
379
380 if let Some(cb) = &backend.circuit_breaker {
382 tracing::info!(
383 backend = %backend.name,
384 failure_rate = cb.failure_rate_threshold,
385 wait_seconds = cb.wait_duration_seconds,
386 "Applying circuit breaker"
387 );
388 let layer = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
389 .failure_rate_threshold(cb.failure_rate_threshold)
390 .minimum_number_of_calls(cb.minimum_calls)
391 .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
392 .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
393 .name(format!("{}-cb", backend.name))
394 .build();
395 builder = builder.backend_layer(layer);
396 }
397
398 if let Some(od) = &backend.outlier_detection
400 && let Some(ref detector) = outlier_detector
401 {
402 tracing::info!(
403 backend = %backend.name,
404 consecutive_errors = od.consecutive_errors,
405 base_ejection_seconds = od.base_ejection_seconds,
406 max_ejection_percent = od.max_ejection_percent,
407 "Applying outlier detection"
408 );
409 let layer = crate::outlier::OutlierDetectionLayer::new(
410 backend.name.clone(),
411 od.clone(),
412 detector.clone(),
413 );
414 builder = builder.backend_layer(layer);
415 }
416 }
417
418 let result = builder.build().await?;
419
420 if !result.skipped.is_empty() {
421 for s in &result.skipped {
422 tracing::warn!("Skipped backend: {s}");
423 }
424 }
425
426 Ok(result.proxy)
427}
428
429fn build_middleware_stack(
431 config: &ProxyConfig,
432 proxy: McpProxy,
433) -> Result<(
434 BoxCloneService<RouterRequest, RouterResponse, Infallible>,
435 Option<cache::CacheHandle>,
436)> {
437 let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
438 BoxCloneService::new(proxy);
439 let mut cache_handle: Option<cache::CacheHandle> = None;
440
441 let injection_rules: Vec<_> = config
443 .backends
444 .iter()
445 .filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
446 .map(|b| {
447 let namespace = format!("{}{}", b.name, config.proxy.separator);
448 tracing::info!(
449 backend = %b.name,
450 default_args = b.default_args.len(),
451 tool_rules = b.inject_args.len(),
452 "Applying argument injection"
453 );
454 crate::inject::InjectionRules::new(
455 namespace,
456 b.default_args.clone(),
457 b.inject_args.clone(),
458 )
459 })
460 .collect();
461
462 if !injection_rules.is_empty() {
463 service = BoxCloneService::new(crate::inject::InjectArgsService::new(
464 service,
465 injection_rules,
466 ));
467 }
468
469 let param_overrides: Vec<_> = config
471 .backends
472 .iter()
473 .filter(|b| !b.param_overrides.is_empty())
474 .flat_map(|b| {
475 let namespace = format!("{}{}", b.name, config.proxy.separator);
476 tracing::info!(
477 backend = %b.name,
478 overrides = b.param_overrides.len(),
479 "Applying parameter overrides"
480 );
481 b.param_overrides
482 .iter()
483 .map(move |c| crate::param_override::ToolOverride::new(&namespace, c))
484 })
485 .collect();
486
487 if !param_overrides.is_empty() {
488 service = BoxCloneService::new(crate::param_override::ParamOverrideService::new(
489 service,
490 param_overrides,
491 ));
492 }
493
494 let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
496 .backends
497 .iter()
498 .filter_map(|b| {
499 b.canary_of.as_ref().map(|primary_name| {
500 let primary_weight = config
502 .backends
503 .iter()
504 .find(|p| p.name == *primary_name)
505 .map(|p| p.weight)
506 .unwrap_or(100);
507 (
508 primary_name.clone(),
509 (b.name.clone(), primary_weight, b.weight),
510 )
511 })
512 })
513 .collect();
514
515 if !canary_mappings.is_empty() {
516 for (primary, (canary, pw, cw)) in &canary_mappings {
517 tracing::info!(
518 primary = %primary,
519 canary = %canary,
520 primary_weight = pw,
521 canary_weight = cw,
522 "Enabling canary routing"
523 );
524 }
525 service = BoxCloneService::new(crate::canary::CanaryService::new(
526 service,
527 canary_mappings,
528 &config.proxy.separator,
529 ));
530 }
531
532 let mut failover_groups: std::collections::HashMap<String, Vec<(u32, String)>> =
535 std::collections::HashMap::new();
536 for b in &config.backends {
537 if let Some(ref primary) = b.failover_for {
538 failover_groups
539 .entry(primary.clone())
540 .or_default()
541 .push((b.priority, b.name.clone()));
542 }
543 }
544 let failover_mappings: std::collections::HashMap<String, Vec<String>> = failover_groups
546 .into_iter()
547 .map(|(primary, mut backends)| {
548 backends.sort_by_key(|(priority, _)| *priority);
549 let names: Vec<String> = backends.into_iter().map(|(_, name)| name).collect();
550 (primary, names)
551 })
552 .collect();
553
554 if !failover_mappings.is_empty() {
555 for (primary, failovers) in &failover_mappings {
556 tracing::info!(
557 primary = %primary,
558 failovers = ?failovers,
559 "Enabling failover routing"
560 );
561 }
562 service = BoxCloneService::new(crate::failover::FailoverService::new(
563 service,
564 failover_mappings,
565 &config.proxy.separator,
566 ));
567 }
568
569 let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
571 .backends
572 .iter()
573 .filter_map(|b| {
574 b.mirror_of
575 .as_ref()
576 .map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
577 })
578 .collect();
579
580 if !mirror_mappings.is_empty() {
581 for (source, (mirror, pct)) in &mirror_mappings {
582 tracing::info!(
583 source = %source,
584 mirror = %mirror,
585 percent = pct,
586 "Enabling traffic mirroring"
587 );
588 }
589 service = BoxCloneService::new(crate::mirror::MirrorService::new(
590 service,
591 mirror_mappings,
592 &config.proxy.separator,
593 ));
594 }
595
596 let cache_configs: Vec<_> = config
598 .backends
599 .iter()
600 .filter_map(|b| {
601 b.cache
602 .as_ref()
603 .map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
604 })
605 .collect();
606
607 if !cache_configs.is_empty() {
608 for (ns, cfg) in &cache_configs {
609 tracing::info!(
610 backend = %ns.trim_end_matches(&config.proxy.separator),
611 resource_ttl = cfg.resource_ttl_seconds,
612 tool_ttl = cfg.tool_ttl_seconds,
613 max_entries = cfg.max_entries,
614 "Applying response cache"
615 );
616 }
617 let (cache_svc, handle) = cache::CacheService::new(service, cache_configs, &config.cache);
618 service = BoxCloneService::new(cache_svc);
619 cache_handle = Some(handle);
620 }
621
622 if config.performance.coalesce_requests {
624 tracing::info!("Request coalescing enabled");
625 service = BoxCloneService::new(coalesce::CoalesceService::new(service));
626 }
627
628 if config.security.max_argument_size.is_some() {
630 let validation = ValidationConfig {
631 max_argument_size: config.security.max_argument_size,
632 };
633 if let Some(max) = validation.max_argument_size {
634 tracing::info!(max_argument_size = max, "Applying request validation");
635 }
636 service = BoxCloneService::new(ValidationService::new(service, validation));
637 }
638
639 let filters: Vec<_> = config
641 .backends
642 .iter()
643 .filter_map(|b| b.build_filter(&config.proxy.separator).transpose())
644 .collect::<anyhow::Result<Vec<_>>>()?;
645
646 if !filters.is_empty() {
647 for f in &filters {
648 tracing::info!(
649 backend = %f.namespace.trim_end_matches(&config.proxy.separator),
650 tool_filter = ?f.tool_filter,
651 resource_filter = ?f.resource_filter,
652 prompt_filter = ?f.prompt_filter,
653 "Applying capability filter"
654 );
655 }
656 service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
657 }
658
659 if config.proxy.tool_exposure == crate::config::ToolExposure::Search {
661 let prefix = format!("proxy{}", config.proxy.separator);
662 tracing::info!(
663 prefix = %prefix,
664 "Search mode: ListTools will only show proxy/ namespace tools"
665 );
666 service =
667 BoxCloneService::new(crate::filter::SearchModeFilterService::new(service, prefix));
668 }
669
670 let alias_mappings: Vec<_> = config
672 .backends
673 .iter()
674 .flat_map(|b| {
675 let ns = format!("{}{}", b.name, config.proxy.separator);
676 b.aliases
677 .iter()
678 .map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
679 })
680 .collect();
681
682 if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
683 let count = alias_map.forward.len();
684 tracing::info!(aliases = count, "Applying tool aliases");
685 service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
686 }
687
688 if !config.composite_tools.is_empty() {
690 let count = config.composite_tools.len();
691 tracing::info!(composite_tools = count, "Applying composite tool fan-out");
692 service = BoxCloneService::new(crate::composite::CompositeService::new(
693 service,
694 config.composite_tools.clone(),
695 ));
696 }
697
698 #[cfg(feature = "oauth")]
700 if matches!(
701 &config.auth,
702 Some(AuthConfig::Bearer {
703 scoped_tokens,
704 ..
705 }) if !scoped_tokens.is_empty()
706 ) {
707 tracing::info!("Enabling bearer token scoping middleware");
708 service = BoxCloneService::new(crate::bearer_scope::BearerScopingService::new(service));
709 }
710
711 #[cfg(feature = "oauth")]
713 {
714 let rbac_config = match &config.auth {
715 Some(
716 AuthConfig::Jwt {
717 roles,
718 role_mapping: Some(mapping),
719 ..
720 }
721 | AuthConfig::OAuth {
722 roles,
723 role_mapping: Some(mapping),
724 ..
725 },
726 ) if !roles.is_empty() => {
727 tracing::info!(
728 roles = roles.len(),
729 claim = %mapping.claim,
730 "Enabling RBAC"
731 );
732 Some(RbacConfig::new(roles, mapping))
733 }
734 _ => None,
735 };
736
737 if let Some(rbac) = rbac_config {
738 service = BoxCloneService::new(RbacService::new(service, rbac));
739 }
740
741 let forward_namespaces: std::collections::HashSet<String> = config
743 .backends
744 .iter()
745 .filter(|b| b.forward_auth)
746 .map(|b| format!("{}{}", b.name, config.proxy.separator))
747 .collect();
748
749 if !forward_namespaces.is_empty() {
750 tracing::info!(
751 backends = ?forward_namespaces,
752 "Enabling token passthrough for forward_auth backends"
753 );
754 service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
755 service,
756 forward_namespaces,
757 ));
758 }
759 }
760
761 #[cfg(feature = "metrics")]
763 if config.observability.metrics.enabled {
764 service = BoxCloneService::new(crate::metrics::MetricsService::new(service));
765 }
766
767 if config.observability.access_log.enabled {
769 tracing::info!("Access logging enabled (target: mcp::access)");
770 service = BoxCloneService::new(crate::access_log::AccessLogService::new(
771 service,
772 &config.proxy.separator,
773 ));
774 }
775
776 if config.observability.audit {
778 tracing::info!("Audit logging enabled (target: mcp::audit)");
779 let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
780 service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
781 }
782
783 if let Some(ref rl) = config.proxy.rate_limit {
785 tracing::info!(
786 requests = rl.requests,
787 period_seconds = rl.period_seconds,
788 "Applying global rate limit"
789 );
790 let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
791 .limit_for_period(rl.requests)
792 .refresh_period(Duration::from_secs(rl.period_seconds))
793 .name("global-ratelimit")
794 .build();
795 let limited = tower::Layer::layer(&layer, service);
796 service = BoxCloneService::new(tower_mcp::CatchError::new(limited));
797 }
798
799 Ok((service, cache_handle))
800}
801
802async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
804 let router = if let Some(auth) = &config.auth {
805 match auth {
806 AuthConfig::Bearer {
807 tokens,
808 scoped_tokens,
809 } => {
810 let total = tokens.len() + scoped_tokens.len();
811 if scoped_tokens.is_empty() {
812 tracing::info!(token_count = total, "Enabling bearer token auth");
814 let validator = StaticBearerValidator::new(tokens.iter().cloned());
815 let layer = AuthLayer::new(validator);
816 router.layer(layer)
817 } else {
818 #[cfg(feature = "oauth")]
820 {
821 tracing::info!(
822 token_count = total,
823 scoped = scoped_tokens.len(),
824 "Enabling bearer token auth with per-token scoping"
825 );
826 let layer =
827 crate::bearer_scope::ScopedBearerAuthLayer::new(tokens, scoped_tokens);
828 router.layer(layer)
829 }
830 #[cfg(not(feature = "oauth"))]
831 {
832 anyhow::bail!(
833 "Per-token tool scoping requires the 'oauth' feature. \
834 Rebuild with: cargo install mcp-proxy --features oauth"
835 );
836 }
837 }
838 }
839 #[cfg(feature = "oauth")]
840 AuthConfig::Jwt {
841 issuer,
842 audience,
843 jwks_uri,
844 ..
845 } => {
846 tracing::info!(
847 issuer = %issuer,
848 audience = %audience,
849 jwks_uri = %jwks_uri,
850 "Enabling JWT auth (JWKS)"
851 );
852 let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
853 .expected_audience(audience)
854 .expected_issuer(issuer)
855 .build()
856 .await
857 .context("building JWKS validator")?;
858
859 let addr = format!(
860 "http://{}:{}",
861 config.proxy.listen.host, config.proxy.listen.port
862 );
863 let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
864 .authorization_server(issuer);
865
866 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
867 router.layer(layer)
868 }
869 #[cfg(not(feature = "oauth"))]
870 AuthConfig::Jwt { .. } => {
871 anyhow::bail!(
872 "JWT auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
873 );
874 }
875 #[cfg(feature = "oauth")]
876 AuthConfig::OAuth {
877 issuer,
878 audience,
879 token_validation,
880 jwks_uri,
881 introspection_endpoint,
882 client_id,
883 client_secret,
884 ..
885 } => {
886 use crate::config::TokenValidationStrategy;
887
888 tracing::info!(
889 issuer = %issuer,
890 audience = %audience,
891 strategy = ?token_validation,
892 "Enabling OAuth 2.1 auth"
893 );
894
895 let discovered = crate::introspection::discover_auth_server(issuer)
897 .await
898 .context("discovering OAuth authorization server")?;
899
900 let effective_jwks_uri = jwks_uri
901 .as_deref()
902 .or(discovered.jwks_uri.as_deref())
903 .ok_or_else(|| {
904 anyhow::anyhow!(
905 "JWKS URI not found via discovery and not configured manually"
906 )
907 })?;
908
909 let effective_introspection = introspection_endpoint
910 .as_deref()
911 .or(discovered.introspection_endpoint.as_deref());
912
913 let addr = format!(
914 "http://{}:{}",
915 config.proxy.listen.host, config.proxy.listen.port
916 );
917 let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
918 .authorization_server(issuer);
919
920 match token_validation {
921 TokenValidationStrategy::Jwt => {
922 let validator =
923 tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
924 .expected_audience(audience)
925 .expected_issuer(issuer)
926 .build()
927 .await
928 .context("building JWKS validator")?;
929 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
930 router.layer(layer)
931 }
932 TokenValidationStrategy::Introspection => {
933 let endpoint = effective_introspection.ok_or_else(|| {
934 anyhow::anyhow!(
935 "introspection endpoint not found via discovery and not configured"
936 )
937 })?;
938 let validator = crate::introspection::IntrospectionValidator::new(
939 endpoint,
940 client_id.as_deref().unwrap(),
941 client_secret.as_deref().unwrap(),
942 )
943 .expected_audience(audience);
944 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
945 router.layer(layer)
946 }
947 TokenValidationStrategy::Both => {
948 let endpoint = effective_introspection.ok_or_else(|| {
949 anyhow::anyhow!(
950 "introspection endpoint not found via discovery and not configured"
951 )
952 })?;
953 let jwt_validator =
954 tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
955 .expected_audience(audience)
956 .expected_issuer(issuer)
957 .build()
958 .await
959 .context("building JWKS validator")?;
960 let introspection_validator =
961 crate::introspection::IntrospectionValidator::new(
962 endpoint,
963 client_id.as_deref().unwrap(),
964 client_secret.as_deref().unwrap(),
965 )
966 .expected_audience(audience);
967 let fallback = crate::introspection::FallbackValidator::new(
968 jwt_validator,
969 introspection_validator,
970 );
971 let layer = tower_mcp::oauth::OAuthLayer::new(fallback, metadata);
972 router.layer(layer)
973 }
974 }
975 }
976 #[cfg(not(feature = "oauth"))]
977 AuthConfig::OAuth { .. } => {
978 anyhow::bail!(
979 "OAuth auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
980 );
981 }
982 }
983 } else {
984 router
985 };
986 Ok(router)
987}
988
989pub async fn shutdown_signal(timeout: Duration) {
991 let ctrl_c = tokio::signal::ctrl_c();
992 #[cfg(unix)]
993 {
994 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
995 .expect("SIGTERM handler");
996 tokio::select! {
997 _ = ctrl_c => {},
998 _ = sigterm.recv() => {},
999 }
1000 }
1001 #[cfg(not(unix))]
1002 {
1003 ctrl_c.await.ok();
1004 }
1005 tracing::info!(
1006 timeout_seconds = timeout.as_secs(),
1007 "Shutdown signal received, draining connections"
1008 );
1009}