1use std::convert::Infallible;
4use std::time::Duration;
5
6use anyhow::{Context, Result};
7use axum::Router;
8use tokio::process::Command;
9use tower::timeout::TimeoutLayer;
10use tower::util::BoxCloneService;
11use tower_mcp::SessionHandle;
12use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
13use tower_mcp::client::StdioClientTransport;
14use tower_mcp::proxy::McpProxy;
15use tower_mcp::{RouterRequest, RouterResponse};
16
17use crate::admin::BackendMeta;
18use crate::alias;
19use crate::cache;
20use crate::coalesce;
21use crate::config::{AuthConfig, ProxyConfig, TransportType};
22use crate::filter::CapabilityFilterService;
23#[cfg(feature = "oauth")]
24use crate::rbac::{RbacConfig, RbacService};
25use crate::validation::{ValidationConfig, ValidationService};
26
27pub struct Proxy {
29 router: Router,
30 session_handle: SessionHandle,
31 inner: McpProxy,
32 config: ProxyConfig,
33}
34
35impl Proxy {
36 pub async fn from_config(config: ProxyConfig) -> Result<Self> {
42 let mcp_proxy = build_mcp_proxy(&config).await?;
43 let proxy_for_admin = mcp_proxy.clone();
44 let proxy_for_caller = mcp_proxy.clone();
45
46 #[cfg(feature = "metrics")]
48 let metrics_handle = if config.observability.metrics.enabled {
49 tracing::info!("Prometheus metrics enabled at /admin/metrics");
50 let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
51 let handle = builder
52 .install_recorder()
53 .context("installing Prometheus metrics recorder")?;
54 Some(handle)
55 } else {
56 None
57 };
58 #[cfg(not(feature = "metrics"))]
59 let metrics_handle = None;
60
61 let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;
62
63 let (router, session_handle) =
64 tower_mcp::transport::http::HttpTransport::from_service(service)
65 .into_router_with_handle();
66
67 let router = apply_auth(&config, router).await?;
69
70 let backend_meta: std::collections::HashMap<String, BackendMeta> = config
72 .backends
73 .iter()
74 .map(|b| {
75 (
76 b.name.clone(),
77 BackendMeta {
78 transport: format!("{:?}", b.transport).to_lowercase(),
79 },
80 )
81 })
82 .collect();
83
84 let admin_state = crate::admin::spawn_health_checker(
86 proxy_for_admin,
87 config.proxy.name.clone(),
88 config.proxy.version.clone(),
89 config.backends.len(),
90 backend_meta,
91 );
92 let router = router.nest(
93 "/admin",
94 crate::admin::admin_router(
95 admin_state.clone(),
96 metrics_handle,
97 session_handle.clone(),
98 cache_handle,
99 ),
100 );
101 tracing::info!("Admin API enabled at /admin/backends");
102
103 if let Err(e) = crate::admin_tools::register_admin_tools(
105 &proxy_for_caller,
106 admin_state,
107 session_handle.clone(),
108 &config,
109 )
110 .await
111 {
112 tracing::warn!("Failed to register admin tools: {e}");
113 } else {
114 tracing::info!("MCP admin tools registered under proxy/ namespace");
115 }
116
117 Ok(Self {
118 router,
119 session_handle,
120 inner: proxy_for_caller,
121 config,
122 })
123 }
124
125 pub fn session_handle(&self) -> &SessionHandle {
127 &self.session_handle
128 }
129
130 pub fn mcp_proxy(&self) -> &McpProxy {
134 &self.inner
135 }
136
137 pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
142 tracing::info!("Hot reload enabled, watching config file for changes");
143 crate::reload::spawn_config_watcher(config_path, self.inner.clone());
144 }
145
146 pub fn into_router(self) -> (Router, SessionHandle) {
158 (self.router, self.session_handle)
159 }
160
161 pub async fn serve(self) -> Result<()> {
166 let addr = format!(
167 "{}:{}",
168 self.config.proxy.listen.host, self.config.proxy.listen.port
169 );
170
171 tracing::info!(listen = %addr, "Proxy ready");
172
173 let listener = tokio::net::TcpListener::bind(&addr)
174 .await
175 .with_context(|| format!("binding to {}", addr))?;
176
177 let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
178 axum::serve(listener, self.router)
179 .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
180 .await
181 .context("server error")?;
182
183 tracing::info!("Proxy shut down");
184 Ok(())
185 }
186}
187
188async fn build_mcp_proxy(config: &ProxyConfig) -> Result<McpProxy> {
190 let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
191 .separator(&config.proxy.separator);
192
193 if let Some(instructions) = &config.proxy.instructions {
194 builder = builder.instructions(instructions);
195 }
196
197 let outlier_detector = {
200 let max_pct = config
201 .backends
202 .iter()
203 .filter_map(|b| b.outlier_detection.as_ref())
204 .map(|od| od.max_ejection_percent)
205 .max();
206 max_pct.map(crate::outlier::OutlierDetector::new)
207 };
208
209 for backend in &config.backends {
210 tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");
211
212 match backend.transport {
213 TransportType::Stdio => {
214 let command = backend.command.as_deref().unwrap();
215 let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
216
217 let mut cmd = Command::new(command);
218 cmd.args(&args);
219
220 for (key, value) in &backend.env {
221 cmd.env(key, value);
222 }
223
224 let transport = StdioClientTransport::spawn_command(&mut cmd)
225 .await
226 .with_context(|| format!("spawning backend '{}'", backend.name))?;
227
228 builder = builder.backend(&backend.name, transport).await;
229 }
230 TransportType::Http => {
231 let url = backend.url.as_deref().unwrap();
232 let mut transport = tower_mcp::client::HttpClientTransport::new(url);
233 if let Some(token) = &backend.bearer_token {
234 transport = transport.bearer_token(token);
235 }
236
237 builder = builder.backend(&backend.name, transport).await;
238 }
239 }
240
241 if let Some(retry_cfg) = &backend.retry {
245 tracing::info!(
246 backend = %backend.name,
247 max_retries = retry_cfg.max_retries,
248 initial_backoff_ms = retry_cfg.initial_backoff_ms,
249 max_backoff_ms = retry_cfg.max_backoff_ms,
250 "Applying retry policy"
251 );
252 let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
253 builder = builder.backend_layer(layer);
254 }
255
256 if let Some(hedge_cfg) = &backend.hedging {
258 let delay = Duration::from_millis(hedge_cfg.delay_ms);
259 let max_attempts = hedge_cfg.max_hedges + 1; tracing::info!(
261 backend = %backend.name,
262 delay_ms = hedge_cfg.delay_ms,
263 max_hedges = hedge_cfg.max_hedges,
264 "Applying request hedging"
265 );
266 let layer = if delay.is_zero() {
267 tower_resilience::hedge::HedgeLayer::builder()
268 .no_delay()
269 .max_hedged_attempts(max_attempts)
270 .name(format!("{}-hedge", backend.name))
271 .build()
272 } else {
273 tower_resilience::hedge::HedgeLayer::builder()
274 .delay(delay)
275 .max_hedged_attempts(max_attempts)
276 .name(format!("{}-hedge", backend.name))
277 .build()
278 };
279 builder = builder.backend_layer(layer);
280 }
281
282 if let Some(cc) = &backend.concurrency {
284 tracing::info!(
285 backend = %backend.name,
286 max = cc.max_concurrent,
287 "Applying concurrency limit"
288 );
289 builder =
290 builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
291 }
292
293 if let Some(rl) = &backend.rate_limit {
295 tracing::info!(
296 backend = %backend.name,
297 requests = rl.requests,
298 period_seconds = rl.period_seconds,
299 "Applying rate limit"
300 );
301 let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
302 .limit_for_period(rl.requests)
303 .refresh_period(Duration::from_secs(rl.period_seconds))
304 .name(format!("{}-ratelimit", backend.name))
305 .build();
306 builder = builder.backend_layer(layer);
307 }
308
309 if let Some(timeout) = &backend.timeout {
311 tracing::info!(
312 backend = %backend.name,
313 seconds = timeout.seconds,
314 "Applying timeout"
315 );
316 builder =
317 builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
318 }
319
320 if let Some(cb) = &backend.circuit_breaker {
322 tracing::info!(
323 backend = %backend.name,
324 failure_rate = cb.failure_rate_threshold,
325 wait_seconds = cb.wait_duration_seconds,
326 "Applying circuit breaker"
327 );
328 let layer = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
329 .failure_rate_threshold(cb.failure_rate_threshold)
330 .minimum_number_of_calls(cb.minimum_calls)
331 .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
332 .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
333 .name(format!("{}-cb", backend.name))
334 .build();
335 builder = builder.backend_layer(layer);
336 }
337
338 if let Some(od) = &backend.outlier_detection
340 && let Some(ref detector) = outlier_detector
341 {
342 tracing::info!(
343 backend = %backend.name,
344 consecutive_errors = od.consecutive_errors,
345 base_ejection_seconds = od.base_ejection_seconds,
346 max_ejection_percent = od.max_ejection_percent,
347 "Applying outlier detection"
348 );
349 let layer = crate::outlier::OutlierDetectionLayer::new(
350 backend.name.clone(),
351 od.clone(),
352 detector.clone(),
353 );
354 builder = builder.backend_layer(layer);
355 }
356 }
357
358 let result = builder.build().await?;
359
360 if !result.skipped.is_empty() {
361 for s in &result.skipped {
362 tracing::warn!("Skipped backend: {s}");
363 }
364 }
365
366 Ok(result.proxy)
367}
368
369fn build_middleware_stack(
371 config: &ProxyConfig,
372 proxy: McpProxy,
373) -> Result<(
374 BoxCloneService<RouterRequest, RouterResponse, Infallible>,
375 Option<cache::CacheHandle>,
376)> {
377 let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
378 BoxCloneService::new(proxy);
379 let mut cache_handle: Option<cache::CacheHandle> = None;
380
381 let injection_rules: Vec<_> = config
383 .backends
384 .iter()
385 .filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
386 .map(|b| {
387 let namespace = format!("{}{}", b.name, config.proxy.separator);
388 tracing::info!(
389 backend = %b.name,
390 default_args = b.default_args.len(),
391 tool_rules = b.inject_args.len(),
392 "Applying argument injection"
393 );
394 crate::inject::InjectionRules::new(
395 namespace,
396 b.default_args.clone(),
397 b.inject_args.clone(),
398 )
399 })
400 .collect();
401
402 if !injection_rules.is_empty() {
403 service = BoxCloneService::new(crate::inject::InjectArgsService::new(
404 service,
405 injection_rules,
406 ));
407 }
408
409 let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
411 .backends
412 .iter()
413 .filter_map(|b| {
414 b.canary_of.as_ref().map(|primary_name| {
415 let primary_weight = config
417 .backends
418 .iter()
419 .find(|p| p.name == *primary_name)
420 .map(|p| p.weight)
421 .unwrap_or(100);
422 (
423 primary_name.clone(),
424 (b.name.clone(), primary_weight, b.weight),
425 )
426 })
427 })
428 .collect();
429
430 if !canary_mappings.is_empty() {
431 for (primary, (canary, pw, cw)) in &canary_mappings {
432 tracing::info!(
433 primary = %primary,
434 canary = %canary,
435 primary_weight = pw,
436 canary_weight = cw,
437 "Enabling canary routing"
438 );
439 }
440 service = BoxCloneService::new(crate::canary::CanaryService::new(
441 service,
442 canary_mappings,
443 &config.proxy.separator,
444 ));
445 }
446
447 let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
449 .backends
450 .iter()
451 .filter_map(|b| {
452 b.mirror_of
453 .as_ref()
454 .map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
455 })
456 .collect();
457
458 if !mirror_mappings.is_empty() {
459 for (source, (mirror, pct)) in &mirror_mappings {
460 tracing::info!(
461 source = %source,
462 mirror = %mirror,
463 percent = pct,
464 "Enabling traffic mirroring"
465 );
466 }
467 service = BoxCloneService::new(crate::mirror::MirrorService::new(
468 service,
469 mirror_mappings,
470 &config.proxy.separator,
471 ));
472 }
473
474 let cache_configs: Vec<_> = config
476 .backends
477 .iter()
478 .filter_map(|b| {
479 b.cache
480 .as_ref()
481 .map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
482 })
483 .collect();
484
485 if !cache_configs.is_empty() {
486 for (ns, cfg) in &cache_configs {
487 tracing::info!(
488 backend = %ns.trim_end_matches(&config.proxy.separator),
489 resource_ttl = cfg.resource_ttl_seconds,
490 tool_ttl = cfg.tool_ttl_seconds,
491 max_entries = cfg.max_entries,
492 "Applying response cache"
493 );
494 }
495 let (cache_svc, handle) = cache::CacheService::new(service, cache_configs);
496 service = BoxCloneService::new(cache_svc);
497 cache_handle = Some(handle);
498 }
499
500 if config.performance.coalesce_requests {
502 tracing::info!("Request coalescing enabled");
503 service = BoxCloneService::new(coalesce::CoalesceService::new(service));
504 }
505
506 if config.security.max_argument_size.is_some() {
508 let validation = ValidationConfig {
509 max_argument_size: config.security.max_argument_size,
510 };
511 if let Some(max) = validation.max_argument_size {
512 tracing::info!(max_argument_size = max, "Applying request validation");
513 }
514 service = BoxCloneService::new(ValidationService::new(service, validation));
515 }
516
517 let filters: Vec<_> = config
519 .backends
520 .iter()
521 .filter_map(|b| b.build_filter(&config.proxy.separator))
522 .collect();
523
524 if !filters.is_empty() {
525 for f in &filters {
526 tracing::info!(
527 backend = %f.namespace.trim_end_matches(&config.proxy.separator),
528 tool_filter = ?f.tool_filter,
529 resource_filter = ?f.resource_filter,
530 prompt_filter = ?f.prompt_filter,
531 "Applying capability filter"
532 );
533 }
534 service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
535 }
536
537 let alias_mappings: Vec<_> = config
539 .backends
540 .iter()
541 .flat_map(|b| {
542 let ns = format!("{}{}", b.name, config.proxy.separator);
543 b.aliases
544 .iter()
545 .map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
546 })
547 .collect();
548
549 if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
550 let count = alias_map.forward.len();
551 tracing::info!(aliases = count, "Applying tool aliases");
552 service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
553 }
554
555 #[cfg(feature = "oauth")]
557 {
558 let rbac_config = match &config.auth {
559 Some(AuthConfig::Jwt {
560 roles,
561 role_mapping: Some(mapping),
562 ..
563 }) if !roles.is_empty() => {
564 tracing::info!(
565 roles = roles.len(),
566 claim = %mapping.claim,
567 "Enabling RBAC"
568 );
569 Some(RbacConfig::new(roles, mapping))
570 }
571 _ => None,
572 };
573
574 if let Some(rbac) = rbac_config {
575 service = BoxCloneService::new(RbacService::new(service, rbac));
576 }
577
578 let forward_namespaces: std::collections::HashSet<String> = config
580 .backends
581 .iter()
582 .filter(|b| b.forward_auth)
583 .map(|b| format!("{}{}", b.name, config.proxy.separator))
584 .collect();
585
586 if !forward_namespaces.is_empty() {
587 tracing::info!(
588 backends = ?forward_namespaces,
589 "Enabling token passthrough for forward_auth backends"
590 );
591 service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
592 service,
593 forward_namespaces,
594 ));
595 }
596 }
597
598 #[cfg(feature = "metrics")]
600 if config.observability.metrics.enabled {
601 service = BoxCloneService::new(crate::metrics::MetricsService::new(service));
602 }
603
604 if config.observability.audit {
606 tracing::info!("Audit logging enabled (target: mcp::audit)");
607 let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
608 service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
609 }
610
611 Ok((service, cache_handle))
612}
613
614async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
616 let router = if let Some(auth) = &config.auth {
617 match auth {
618 AuthConfig::Bearer { tokens } => {
619 tracing::info!(token_count = tokens.len(), "Enabling bearer token auth");
620 let validator = StaticBearerValidator::new(tokens.iter().cloned());
621 let layer = AuthLayer::new(validator);
622 router.layer(layer)
623 }
624 #[cfg(feature = "oauth")]
625 AuthConfig::Jwt {
626 issuer,
627 audience,
628 jwks_uri,
629 ..
630 } => {
631 tracing::info!(
632 issuer = %issuer,
633 audience = %audience,
634 jwks_uri = %jwks_uri,
635 "Enabling JWT auth (JWKS)"
636 );
637 let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
638 .expected_audience(audience)
639 .expected_issuer(issuer)
640 .build()
641 .await
642 .context("building JWKS validator")?;
643
644 let addr = format!(
645 "http://{}:{}",
646 config.proxy.listen.host, config.proxy.listen.port
647 );
648 let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
649 .authorization_server(issuer);
650
651 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
652 router.layer(layer)
653 }
654 #[cfg(not(feature = "oauth"))]
655 AuthConfig::Jwt { .. } => {
656 anyhow::bail!(
657 "JWT auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
658 );
659 }
660 }
661 } else {
662 router
663 };
664 Ok(router)
665}
666
667pub async fn shutdown_signal(timeout: Duration) {
669 let ctrl_c = tokio::signal::ctrl_c();
670 #[cfg(unix)]
671 {
672 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
673 .expect("SIGTERM handler");
674 tokio::select! {
675 _ = ctrl_c => {},
676 _ = sigterm.recv() => {},
677 }
678 }
679 #[cfg(not(unix))]
680 {
681 ctrl_c.await.ok();
682 }
683 tracing::info!(
684 timeout_seconds = timeout.as_secs(),
685 "Shutdown signal received, draining connections"
686 );
687}