1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::time::Duration;
6
7use anyhow::{Context, Result};
8use axum::Router;
9use tokio::process::Command;
10use tower::timeout::TimeoutLayer;
11use tower::util::BoxCloneService;
12use tower_mcp::SessionHandle;
13use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
14use tower_mcp::client::StdioClientTransport;
15use tower_mcp::proxy::McpProxy;
16use tower_mcp::{RouterRequest, RouterResponse};
17
18use crate::admin::BackendMeta;
19use crate::alias;
20use crate::cache;
21use crate::coalesce;
22use crate::config::{AuthConfig, ProxyConfig, TransportType};
23use crate::filter::CapabilityFilterService;
24#[cfg(feature = "oauth")]
25use crate::rbac::{RbacConfig, RbacService};
26use crate::validation::{ValidationConfig, ValidationService};
27
/// A fully assembled MCP proxy: the HTTP router serving the MCP endpoint and the `/admin`
/// API, plus the handles needed to inspect and manage it at runtime.
pub struct Proxy {
    /// Axum router with the MCP HTTP transport, the configured auth layer and `/admin` mounted.
    router: Router,
    /// Session handle of the HTTP transport; also shared with the admin API and admin tools.
    session_handle: SessionHandle,
    /// Proxy handle returned by [`Proxy::mcp_proxy`] and handed to the hot-reload watcher.
    inner: McpProxy,
    /// Configuration the proxy was built from (listen address, shutdown timeout, separator, ...).
    config: ProxyConfig,
    /// Tool discovery index; `Some` only when tool discovery or search-mode exposure is enabled.
    #[cfg(feature = "discovery")]
    discovery_index: Option<crate::discovery::SharedDiscoveryIndex>,
}
37
38impl Proxy {
39 pub async fn from_config(config: ProxyConfig) -> Result<Self> {
45 let (mcp_proxy, cb_handles) = build_mcp_proxy(&config).await?;
46 let proxy_for_admin = mcp_proxy.clone();
47 let mut proxy_for_caller = mcp_proxy.clone();
48 let proxy_for_management = mcp_proxy.clone();
49
50 #[cfg(feature = "metrics")]
52 let metrics_handle = if config.observability.metrics.enabled {
53 tracing::info!("Prometheus metrics enabled at /admin/metrics");
54 let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
55 let handle = builder
56 .install_recorder()
57 .context("installing Prometheus metrics recorder")?;
58 Some(handle)
59 } else {
60 None
61 };
62 #[cfg(not(feature = "metrics"))]
63 let metrics_handle = None;
64
65 let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;
66
67 let (router, session_handle) =
68 tower_mcp::transport::http::HttpTransport::from_service(service)
69 .into_router_with_handle();
70
71 let router = apply_auth(&config, router).await?;
73
74 let backend_meta: std::collections::HashMap<String, BackendMeta> = config
76 .backends
77 .iter()
78 .map(|b| {
79 (
80 b.name.clone(),
81 BackendMeta {
82 transport: format!("{:?}", b.transport).to_lowercase(),
83 },
84 )
85 })
86 .collect();
87
88 let admin_state = crate::admin::spawn_health_checker(
90 proxy_for_admin,
91 config.proxy.name.clone(),
92 config.proxy.version.clone(),
93 config.backends.len(),
94 backend_meta,
95 );
96 let router = router.nest(
97 "/admin",
98 crate::admin::admin_router(
99 admin_state.clone(),
100 metrics_handle,
101 session_handle.clone(),
102 cache_handle,
103 proxy_for_management,
104 &config,
105 config.source_path.clone(),
106 cb_handles,
107 ),
108 );
109 tracing::info!("Admin API enabled at /admin/backends");
110
111 #[cfg(feature = "discovery")]
113 let discovery_enabled = config.proxy.tool_discovery
114 || config.proxy.tool_exposure == crate::config::ToolExposure::Search;
115 #[cfg(feature = "discovery")]
116 let (discovery_index, discovery_tools) = if discovery_enabled {
117 let index =
118 crate::discovery::build_index(&mut proxy_for_caller, &config.proxy.separator).await;
119 let tools = crate::discovery::build_discovery_tools(index.clone());
120 (Some(index), Some(tools))
121 } else {
122 (None, None)
123 };
124 #[cfg(not(feature = "discovery"))]
125 let discovery_tools: Option<Vec<tower_mcp::Tool>> = None;
126
127 if let Err(e) = crate::admin_tools::register_admin_tools(
129 &proxy_for_caller,
130 admin_state,
131 session_handle.clone(),
132 &config,
133 discovery_tools,
134 )
135 .await
136 {
137 tracing::warn!("Failed to register admin tools: {e}");
138 } else {
139 tracing::info!("MCP admin tools registered under proxy/ namespace");
140 }
141
142 Ok(Self {
143 router,
144 session_handle,
145 inner: proxy_for_caller,
146 config,
147 #[cfg(feature = "discovery")]
148 discovery_index,
149 })
150 }
151
152 pub fn session_handle(&self) -> &SessionHandle {
154 &self.session_handle
155 }
156
157 pub fn mcp_proxy(&self) -> &McpProxy {
161 &self.inner
162 }
163
164 pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
169 tracing::info!("Hot reload enabled, watching config file for changes");
170 crate::reload::spawn_config_watcher(
171 config_path,
172 self.inner.clone(),
173 #[cfg(feature = "discovery")]
174 self.discovery_index
175 .as_ref()
176 .map(|idx| (idx.clone(), self.config.proxy.separator.clone())),
177 );
178 }
179
180 pub fn into_router(self) -> (Router, SessionHandle) {
192 (self.router, self.session_handle)
193 }
194
195 pub async fn serve(self) -> Result<()> {
200 let addr = format!(
201 "{}:{}",
202 self.config.proxy.listen.host, self.config.proxy.listen.port
203 );
204
205 tracing::info!(listen = %addr, "Proxy ready");
206
207 let listener = tokio::net::TcpListener::bind(&addr)
208 .await
209 .with_context(|| format!("binding to {}", addr))?;
210
211 let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
212 axum::serve(listener, self.router)
213 .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
214 .await
215 .context("server error")?;
216
217 tracing::info!("Proxy shut down");
218 Ok(())
219 }
220}
221
/// Handle to a backend's circuit breaker; `build_mcp_proxy` collects one per backend that
/// configures a circuit breaker, and they are handed to the admin API.
pub type CbHandle = tower_resilience::circuitbreaker::CircuitBreakerHandle;
224
225async fn build_mcp_proxy(config: &ProxyConfig) -> Result<(McpProxy, HashMap<String, CbHandle>)> {
228 let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
229 .separator(&config.proxy.separator);
230 let mut cb_handles: HashMap<String, CbHandle> = HashMap::new();
231
232 if let Some(instructions) = &config.proxy.instructions {
233 builder = builder.instructions(instructions);
234 }
235
236 let outlier_detector = {
239 let max_pct = config
240 .backends
241 .iter()
242 .filter_map(|b| b.outlier_detection.as_ref())
243 .map(|od| od.max_ejection_percent)
244 .max();
245 max_pct.map(crate::outlier::OutlierDetector::new)
246 };
247
248 for backend in &config.backends {
249 tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");
250
251 match backend.transport {
252 TransportType::Stdio => {
253 let command = backend.command.as_deref().unwrap();
254 let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
255
256 let mut cmd = Command::new(command);
257 cmd.args(&args);
258
259 for (key, value) in &backend.env {
260 cmd.env(key, value);
261 }
262
263 let transport = StdioClientTransport::spawn_command(&mut cmd)
264 .await
265 .with_context(|| format!("spawning backend '{}'", backend.name))?;
266
267 builder = builder.backend(&backend.name, transport).await;
268 }
269 TransportType::Http => {
270 let url = backend.url.as_deref().unwrap();
271 let mut transport = tower_mcp::client::HttpClientTransport::new(url);
272 if let Some(token) = &backend.bearer_token {
273 transport = transport.bearer_token(token);
274 }
275
276 builder = builder.backend(&backend.name, transport).await;
277 }
278 #[cfg(feature = "websocket")]
279 TransportType::Websocket => {
280 let url = backend.url.as_deref().unwrap();
281 tracing::info!(url = %url, "Connecting to WebSocket backend");
282 let transport = if let Some(token) = &backend.bearer_token {
283 crate::ws_transport::WebSocketClientTransport::connect_with_bearer_token(
284 url, token,
285 )
286 .await
287 .with_context(|| {
288 format!("connecting to WebSocket backend '{}'", backend.name)
289 })?
290 } else {
291 crate::ws_transport::WebSocketClientTransport::connect(url)
292 .await
293 .with_context(|| {
294 format!("connecting to WebSocket backend '{}'", backend.name)
295 })?
296 };
297
298 builder = builder.backend(&backend.name, transport).await;
299 }
300 #[cfg(not(feature = "websocket"))]
301 TransportType::Websocket => {
302 anyhow::bail!(
303 "WebSocket transport requires the 'websocket' feature. \
304 Rebuild with: cargo install mcp-proxy --features websocket"
305 );
306 }
307 }
308
309 if let Some(retry_cfg) = &backend.retry {
313 tracing::info!(
314 backend = %backend.name,
315 max_retries = retry_cfg.max_retries,
316 initial_backoff_ms = retry_cfg.initial_backoff_ms,
317 max_backoff_ms = retry_cfg.max_backoff_ms,
318 "Applying retry policy"
319 );
320 let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
321 builder = builder.backend_layer(layer);
322 }
323
324 if let Some(hedge_cfg) = &backend.hedging {
326 let delay = Duration::from_millis(hedge_cfg.delay_ms);
327 let max_attempts = hedge_cfg.max_hedges + 1; tracing::info!(
329 backend = %backend.name,
330 delay_ms = hedge_cfg.delay_ms,
331 max_hedges = hedge_cfg.max_hedges,
332 "Applying request hedging"
333 );
334 let layer = if delay.is_zero() {
335 tower_resilience::hedge::HedgeLayer::builder()
336 .no_delay()
337 .max_hedged_attempts(max_attempts)
338 .name(format!("{}-hedge", backend.name))
339 .build()
340 } else {
341 tower_resilience::hedge::HedgeLayer::builder()
342 .delay(delay)
343 .max_hedged_attempts(max_attempts)
344 .name(format!("{}-hedge", backend.name))
345 .build()
346 };
347 builder = builder.backend_layer(layer);
348 }
349
350 if let Some(cc) = &backend.concurrency {
352 tracing::info!(
353 backend = %backend.name,
354 max = cc.max_concurrent,
355 "Applying concurrency limit"
356 );
357 builder =
358 builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
359 }
360
361 if let Some(rl) = &backend.rate_limit {
363 tracing::info!(
364 backend = %backend.name,
365 requests = rl.requests,
366 period_seconds = rl.period_seconds,
367 "Applying rate limit"
368 );
369 let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
370 .limit_for_period(rl.requests)
371 .refresh_period(Duration::from_secs(rl.period_seconds))
372 .name(format!("{}-ratelimit", backend.name))
373 .build();
374 builder = builder.backend_layer(layer);
375 }
376
377 if let Some(timeout) = &backend.timeout {
379 tracing::info!(
380 backend = %backend.name,
381 seconds = timeout.seconds,
382 "Applying timeout"
383 );
384 builder =
385 builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
386 }
387
388 if let Some(cb) = &backend.circuit_breaker {
390 tracing::info!(
391 backend = %backend.name,
392 failure_rate = cb.failure_rate_threshold,
393 wait_seconds = cb.wait_duration_seconds,
394 "Applying circuit breaker"
395 );
396 let (layer, handle) = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
397 .failure_rate_threshold(cb.failure_rate_threshold)
398 .minimum_number_of_calls(cb.minimum_calls)
399 .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
400 .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
401 .name(format!("{}-cb", backend.name))
402 .build_with_handle();
403 cb_handles.insert(backend.name.clone(), handle);
404 builder = builder.backend_layer(layer);
405 }
406
407 if let Some(od) = &backend.outlier_detection
409 && let Some(ref detector) = outlier_detector
410 {
411 tracing::info!(
412 backend = %backend.name,
413 consecutive_errors = od.consecutive_errors,
414 base_ejection_seconds = od.base_ejection_seconds,
415 max_ejection_percent = od.max_ejection_percent,
416 "Applying outlier detection"
417 );
418 let layer = crate::outlier::OutlierDetectionLayer::new(
419 backend.name.clone(),
420 od.clone(),
421 detector.clone(),
422 );
423 builder = builder.backend_layer(layer);
424 }
425 }
426
427 let result = builder.build().await?;
428
429 if !result.skipped.is_empty() {
430 for s in &result.skipped {
431 tracing::warn!("Skipped backend: {s}");
432 }
433 }
434
435 Ok((result.proxy, cb_handles))
436}
437
#[cfg(feature = "oauth")]
/// Build a layer that requires every request's token to carry all of `required_scopes`.
///
/// Returns `None` when no scopes are required, so callers can skip the layer entirely.
fn oauth_scope_layer(
    required_scopes: &[String],
) -> Option<tower_mcp::oauth::ScopeEnforcementLayer> {
    use tower_mcp::oauth::{ScopeEnforcementLayer, ScopePolicy, ScopeRequirement};

    (!required_scopes.is_empty()).then(|| {
        let policy = ScopePolicy::new()
            .default_scopes(ScopeRequirement::all(required_scopes.iter().cloned()));
        ScopeEnforcementLayer::new(policy)
    })
}
456
457fn build_middleware_stack(
459 config: &ProxyConfig,
460 proxy: McpProxy,
461) -> Result<(
462 BoxCloneService<RouterRequest, RouterResponse, Infallible>,
463 Option<cache::CacheHandle>,
464)> {
465 let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
466 BoxCloneService::new(proxy);
467 let mut cache_handle: Option<cache::CacheHandle> = None;
468
469 let injection_rules: Vec<_> = config
471 .backends
472 .iter()
473 .filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
474 .map(|b| {
475 let namespace = format!("{}{}", b.name, config.proxy.separator);
476 tracing::info!(
477 backend = %b.name,
478 default_args = b.default_args.len(),
479 tool_rules = b.inject_args.len(),
480 "Applying argument injection"
481 );
482 crate::inject::InjectionRules::new(
483 namespace,
484 b.default_args.clone(),
485 b.inject_args.clone(),
486 )
487 })
488 .collect();
489
490 if !injection_rules.is_empty() {
491 service = BoxCloneService::new(crate::inject::InjectArgsService::new(
492 service,
493 injection_rules,
494 ));
495 }
496
497 let param_overrides: Vec<_> = config
499 .backends
500 .iter()
501 .filter(|b| !b.param_overrides.is_empty())
502 .flat_map(|b| {
503 let namespace = format!("{}{}", b.name, config.proxy.separator);
504 tracing::info!(
505 backend = %b.name,
506 overrides = b.param_overrides.len(),
507 "Applying parameter overrides"
508 );
509 b.param_overrides
510 .iter()
511 .map(move |c| crate::param_override::ToolOverride::new(&namespace, c))
512 })
513 .collect();
514
515 if !param_overrides.is_empty() {
516 service = BoxCloneService::new(crate::param_override::ParamOverrideService::new(
517 service,
518 param_overrides,
519 ));
520 }
521
522 let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
524 .backends
525 .iter()
526 .filter_map(|b| {
527 b.canary_of.as_ref().map(|primary_name| {
528 let primary_weight = config
530 .backends
531 .iter()
532 .find(|p| p.name == *primary_name)
533 .map(|p| p.weight)
534 .unwrap_or(100);
535 (
536 primary_name.clone(),
537 (b.name.clone(), primary_weight, b.weight),
538 )
539 })
540 })
541 .collect();
542
543 if !canary_mappings.is_empty() {
544 for (primary, (canary, pw, cw)) in &canary_mappings {
545 tracing::info!(
546 primary = %primary,
547 canary = %canary,
548 primary_weight = pw,
549 canary_weight = cw,
550 "Enabling canary routing"
551 );
552 }
553 service = BoxCloneService::new(crate::canary::CanaryService::new(
554 service,
555 canary_mappings,
556 &config.proxy.separator,
557 ));
558 }
559
560 let mut failover_groups: std::collections::HashMap<String, Vec<(u32, String)>> =
563 std::collections::HashMap::new();
564 for b in &config.backends {
565 if let Some(ref primary) = b.failover_for {
566 failover_groups
567 .entry(primary.clone())
568 .or_default()
569 .push((b.priority, b.name.clone()));
570 }
571 }
572 let failover_mappings: std::collections::HashMap<String, Vec<String>> = failover_groups
574 .into_iter()
575 .map(|(primary, mut backends)| {
576 backends.sort_by_key(|(priority, _)| *priority);
577 let names: Vec<String> = backends.into_iter().map(|(_, name)| name).collect();
578 (primary, names)
579 })
580 .collect();
581
582 if !failover_mappings.is_empty() {
583 for (primary, failovers) in &failover_mappings {
584 tracing::info!(
585 primary = %primary,
586 failovers = ?failovers,
587 "Enabling failover routing"
588 );
589 }
590 service = BoxCloneService::new(crate::failover::FailoverService::new(
591 service,
592 failover_mappings,
593 &config.proxy.separator,
594 ));
595 }
596
597 let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
599 .backends
600 .iter()
601 .filter_map(|b| {
602 b.mirror_of
603 .as_ref()
604 .map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
605 })
606 .collect();
607
608 if !mirror_mappings.is_empty() {
609 for (source, (mirror, pct)) in &mirror_mappings {
610 tracing::info!(
611 source = %source,
612 mirror = %mirror,
613 percent = pct,
614 "Enabling traffic mirroring"
615 );
616 }
617 service = BoxCloneService::new(crate::mirror::MirrorService::new(
618 service,
619 mirror_mappings,
620 &config.proxy.separator,
621 ));
622 }
623
624 let cache_configs: Vec<_> = config
626 .backends
627 .iter()
628 .filter_map(|b| {
629 b.cache
630 .as_ref()
631 .map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
632 })
633 .collect();
634
635 if !cache_configs.is_empty() {
636 for (ns, cfg) in &cache_configs {
637 tracing::info!(
638 backend = %ns.trim_end_matches(&config.proxy.separator),
639 resource_ttl = cfg.resource_ttl_seconds,
640 tool_ttl = cfg.tool_ttl_seconds,
641 max_entries = cfg.max_entries,
642 "Applying response cache"
643 );
644 }
645 let (cache_svc, handle) = cache::CacheService::new(service, cache_configs, &config.cache);
646 service = BoxCloneService::new(cache_svc);
647 cache_handle = Some(handle);
648 }
649
650 if config.performance.coalesce_requests {
652 tracing::info!("Request coalescing enabled");
653 service = BoxCloneService::new(coalesce::CoalesceService::new(service));
654 }
655
656 if config.security.max_argument_size.is_some() {
658 let validation = ValidationConfig {
659 max_argument_size: config.security.max_argument_size,
660 };
661 if let Some(max) = validation.max_argument_size {
662 tracing::info!(max_argument_size = max, "Applying request validation");
663 }
664 service = BoxCloneService::new(ValidationService::new(service, validation));
665 }
666
667 let filters: Vec<_> = config
669 .backends
670 .iter()
671 .filter_map(|b| b.build_filter(&config.proxy.separator).transpose())
672 .collect::<anyhow::Result<Vec<_>>>()?;
673
674 if !filters.is_empty() {
675 for f in &filters {
676 tracing::info!(
677 backend = %f.namespace.trim_end_matches(&config.proxy.separator),
678 tool_filter = ?f.tool_filter,
679 resource_filter = ?f.resource_filter,
680 prompt_filter = ?f.prompt_filter,
681 "Applying capability filter"
682 );
683 }
684 service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
685 }
686
687 if config.proxy.tool_exposure == crate::config::ToolExposure::Search {
689 let prefix = format!("proxy{}", config.proxy.separator);
690 tracing::info!(
691 prefix = %prefix,
692 "Search mode: ListTools will only show proxy/ namespace tools"
693 );
694 service =
695 BoxCloneService::new(crate::filter::SearchModeFilterService::new(service, prefix));
696 }
697
698 let alias_mappings: Vec<_> = config
700 .backends
701 .iter()
702 .flat_map(|b| {
703 let ns = format!("{}{}", b.name, config.proxy.separator);
704 b.aliases
705 .iter()
706 .map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
707 })
708 .collect();
709
710 if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
711 let count = alias_map.forward.len();
712 tracing::info!(aliases = count, "Applying tool aliases");
713 service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
714 }
715
716 if !config.composite_tools.is_empty() {
718 let count = config.composite_tools.len();
719 tracing::info!(composite_tools = count, "Applying composite tool fan-out");
720 service = BoxCloneService::new(crate::composite::CompositeService::new(
721 service,
722 config.composite_tools.clone(),
723 ));
724 }
725
726 #[cfg(feature = "oauth")]
728 if matches!(
729 &config.auth,
730 Some(AuthConfig::Bearer {
731 scoped_tokens,
732 ..
733 }) if !scoped_tokens.is_empty()
734 ) {
735 tracing::info!("Enabling bearer token scoping middleware");
736 service = BoxCloneService::new(crate::bearer_scope::BearerScopingService::new(service));
737 }
738
739 #[cfg(feature = "oauth")]
741 {
742 let rbac_config = match &config.auth {
743 Some(
744 AuthConfig::Jwt {
745 roles,
746 role_mapping: Some(mapping),
747 ..
748 }
749 | AuthConfig::OAuth {
750 roles,
751 role_mapping: Some(mapping),
752 ..
753 },
754 ) if !roles.is_empty() => {
755 tracing::info!(
756 roles = roles.len(),
757 claim = %mapping.claim,
758 "Enabling RBAC"
759 );
760 Some(RbacConfig::new(roles, mapping))
761 }
762 _ => None,
763 };
764
765 if let Some(rbac) = rbac_config {
766 service = BoxCloneService::new(RbacService::new(service, rbac));
767 }
768
769 let required_scopes: &[String] = match &config.auth {
776 Some(AuthConfig::OAuth {
777 required_scopes, ..
778 }) => required_scopes,
779 _ => &[],
780 };
781 if let Some(layer) = oauth_scope_layer(required_scopes) {
782 tracing::info!(
783 scopes = ?required_scopes,
784 "Enabling OAuth required_scopes enforcement"
785 );
786 service = BoxCloneService::new(tower::Layer::layer(&layer, service));
787 }
788
789 let forward_namespaces: std::collections::HashSet<String> = config
791 .backends
792 .iter()
793 .filter(|b| b.forward_auth)
794 .map(|b| format!("{}{}", b.name, config.proxy.separator))
795 .collect();
796
797 if !forward_namespaces.is_empty() {
798 tracing::info!(
799 backends = ?forward_namespaces,
800 "Enabling token passthrough for forward_auth backends"
801 );
802 service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
803 service,
804 forward_namespaces,
805 ));
806 }
807 }
808
809 #[cfg(feature = "metrics")]
811 if config.observability.metrics.enabled {
812 service = BoxCloneService::new(crate::metrics::MetricsService::new(service));
813 }
814
815 if config.observability.access_log.enabled {
817 tracing::info!("Access logging enabled (target: mcp::access)");
818 service = BoxCloneService::new(crate::access_log::AccessLogService::new(
819 service,
820 &config.proxy.separator,
821 ));
822 }
823
824 if config.observability.audit {
826 tracing::info!("Audit logging enabled (target: mcp::audit)");
827 let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
828 service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
829 }
830
831 if let Some(ref rl) = config.proxy.rate_limit {
833 tracing::info!(
834 requests = rl.requests,
835 period_seconds = rl.period_seconds,
836 "Applying global rate limit"
837 );
838 let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
839 .limit_for_period(rl.requests)
840 .refresh_period(Duration::from_secs(rl.period_seconds))
841 .name("global-ratelimit")
842 .build();
843 let limited = tower::Layer::layer(&layer, service);
844 service = BoxCloneService::new(tower_mcp::CatchError::new(limited));
845 }
846
847 Ok((service, cache_handle))
848}
849
850async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
852 let router = if let Some(auth) = &config.auth {
853 match auth {
854 AuthConfig::Bearer {
855 tokens,
856 scoped_tokens,
857 } => {
858 let total = tokens.len() + scoped_tokens.len();
859 if scoped_tokens.is_empty() {
860 tracing::info!(token_count = total, "Enabling bearer token auth");
862 let validator = StaticBearerValidator::new(tokens.iter().cloned());
863 let layer = AuthLayer::new(validator);
864 router.layer(layer)
865 } else {
866 #[cfg(feature = "oauth")]
868 {
869 tracing::info!(
870 token_count = total,
871 scoped = scoped_tokens.len(),
872 "Enabling bearer token auth with per-token scoping"
873 );
874 let layer =
875 crate::bearer_scope::ScopedBearerAuthLayer::new(tokens, scoped_tokens);
876 router.layer(layer)
877 }
878 #[cfg(not(feature = "oauth"))]
879 {
880 anyhow::bail!(
881 "Per-token tool scoping requires the 'oauth' feature. \
882 Rebuild with: cargo install mcp-proxy --features oauth"
883 );
884 }
885 }
886 }
887 #[cfg(feature = "oauth")]
888 AuthConfig::Jwt {
889 issuer,
890 audience,
891 jwks_uri,
892 ..
893 } => {
894 tracing::info!(
895 issuer = %issuer,
896 audience = %audience,
897 jwks_uri = %jwks_uri,
898 "Enabling JWT auth (JWKS)"
899 );
900 let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
901 .expected_audience(audience)
902 .expected_issuer(issuer)
903 .build()
904 .await
905 .context("building JWKS validator")?;
906
907 let addr = format!(
908 "http://{}:{}",
909 config.proxy.listen.host, config.proxy.listen.port
910 );
911 let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
912 .authorization_server(issuer);
913
914 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
915 router.layer(layer)
916 }
917 #[cfg(not(feature = "oauth"))]
918 AuthConfig::Jwt { .. } => {
919 anyhow::bail!(
920 "JWT auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
921 );
922 }
923 #[cfg(feature = "oauth")]
924 AuthConfig::OAuth {
925 issuer,
926 audience,
927 token_validation,
928 jwks_uri,
929 introspection_endpoint,
930 client_id,
931 client_secret,
932 ..
933 } => {
934 use crate::config::TokenValidationStrategy;
935
936 tracing::info!(
937 issuer = %issuer,
938 audience = %audience,
939 strategy = ?token_validation,
940 "Enabling OAuth 2.1 auth"
941 );
942
943 let discovered = crate::introspection::discover_auth_server(issuer)
945 .await
946 .context("discovering OAuth authorization server")?;
947
948 let effective_jwks_uri = jwks_uri
949 .as_deref()
950 .or(discovered.jwks_uri.as_deref())
951 .ok_or_else(|| {
952 anyhow::anyhow!(
953 "JWKS URI not found via discovery and not configured manually"
954 )
955 })?;
956
957 let effective_introspection = introspection_endpoint
958 .as_deref()
959 .or(discovered.introspection_endpoint.as_deref());
960
961 let addr = format!(
962 "http://{}:{}",
963 config.proxy.listen.host, config.proxy.listen.port
964 );
965 let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
966 .authorization_server(issuer);
967
968 match token_validation {
969 TokenValidationStrategy::Jwt => {
970 let validator =
971 tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
972 .expected_audience(audience)
973 .expected_issuer(issuer)
974 .build()
975 .await
976 .context("building JWKS validator")?;
977 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
978 router.layer(layer)
979 }
980 TokenValidationStrategy::Introspection => {
981 let endpoint = effective_introspection.ok_or_else(|| {
982 anyhow::anyhow!(
983 "introspection endpoint not found via discovery and not configured"
984 )
985 })?;
986 let validator = crate::introspection::IntrospectionValidator::new(
987 endpoint,
988 client_id.as_deref().unwrap(),
989 client_secret.as_deref().unwrap(),
990 )
991 .expected_audience(audience);
992 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
993 router.layer(layer)
994 }
995 TokenValidationStrategy::Both => {
996 let endpoint = effective_introspection.ok_or_else(|| {
997 anyhow::anyhow!(
998 "introspection endpoint not found via discovery and not configured"
999 )
1000 })?;
1001 let jwt_validator =
1002 tower_mcp::oauth::JwksValidator::builder(effective_jwks_uri)
1003 .expected_audience(audience)
1004 .expected_issuer(issuer)
1005 .build()
1006 .await
1007 .context("building JWKS validator")?;
1008 let introspection_validator =
1009 crate::introspection::IntrospectionValidator::new(
1010 endpoint,
1011 client_id.as_deref().unwrap(),
1012 client_secret.as_deref().unwrap(),
1013 )
1014 .expected_audience(audience);
1015 let fallback = crate::introspection::FallbackValidator::new(
1016 jwt_validator,
1017 introspection_validator,
1018 );
1019 let layer = tower_mcp::oauth::OAuthLayer::new(fallback, metadata);
1020 router.layer(layer)
1021 }
1022 }
1023 }
1024 #[cfg(not(feature = "oauth"))]
1025 AuthConfig::OAuth { .. } => {
1026 anyhow::bail!(
1027 "OAuth auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
1028 );
1029 }
1030 }
1031 } else {
1032 router
1033 };
1034 Ok(router)
1035}
1036
1037pub async fn shutdown_signal(timeout: Duration) {
1039 let ctrl_c = tokio::signal::ctrl_c();
1040 #[cfg(unix)]
1041 {
1042 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
1043 .expect("SIGTERM handler");
1044 tokio::select! {
1045 _ = ctrl_c => {},
1046 _ = sigterm.recv() => {},
1047 }
1048 }
1049 #[cfg(not(unix))]
1050 {
1051 ctrl_c.await.ok();
1052 }
1053 tracing::info!(
1054 timeout_seconds = timeout.as_secs(),
1055 "Shutdown signal received, draining connections"
1056 );
1057}
1058
#[cfg(all(test, feature = "oauth"))]
mod scope_enforcement_tests {
    //! Tests for `oauth_scope_layer`: when the layer is built and how it enforces scopes.

    use std::collections::HashMap;

    use tower::{Layer, Service};
    use tower_mcp::oauth::token::TokenClaims;
    use tower_mcp::protocol::{CallToolParams, McpRequest, RequestId};
    use tower_mcp::router::Extensions;

    use super::oauth_scope_layer;
    use crate::test_util::MockService;

    /// Claims for a fixed `user` subject whose `scope` is the given space-separated list.
    fn claims_with_scope(scope: &str) -> TokenClaims {
        TokenClaims {
            sub: Some("user".into()),
            iss: None,
            aud: None,
            exp: None,
            scope: Some(scope.to_string()),
            client_id: None,
            extra: HashMap::new(),
        }
    }

    /// A `CallTool` request for `fs/read`; token claims are attached only when `scope` is set.
    fn call_with_scope(scope: Option<&str>) -> tower_mcp::RouterRequest {
        let mut extensions = Extensions::new();
        if let Some(claims) = scope.map(claims_with_scope) {
            extensions.insert(claims);
        }
        let params = CallToolParams {
            name: "fs/read".into(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        };
        tower_mcp::RouterRequest {
            id: RequestId::Number(1),
            inner: McpRequest::CallTool(params),
            extensions,
        }
    }

    #[test]
    fn no_required_scopes_yields_no_layer() {
        let nothing_required: &[String] = &[];
        assert!(oauth_scope_layer(nothing_required).is_none());
    }

    #[tokio::test]
    async fn token_missing_required_scope_is_rejected() {
        let required = ["mcp:access".to_string()];
        let layer = oauth_scope_layer(&required).expect("layer for non-empty scopes");
        let mut svc = layer.layer(MockService::with_tools(&["fs/read"]));

        let response = svc
            .call(call_with_scope(Some("other:scope")))
            .await
            .unwrap();
        let err = response.inner.unwrap_err();
        let mentions_scope = err.message.to_lowercase().contains("scope");
        assert!(
            mentions_scope,
            "expected insufficient-scope error, got: {}",
            err.message
        );
    }

    #[tokio::test]
    async fn token_with_all_required_scopes_is_allowed() {
        let required = ["mcp:access".to_string(), "mcp:read".to_string()];
        let layer = oauth_scope_layer(&required).expect("layer for non-empty scopes");
        let mut svc = layer.layer(MockService::with_tools(&["fs/read"]));

        let response = svc
            .call(call_with_scope(Some("mcp:access mcp:read mcp:extra")))
            .await
            .unwrap();
        assert!(
            response.inner.is_ok(),
            "token carrying all required scopes should be allowed"
        );
    }
}