1use std::convert::Infallible;
4use std::time::Duration;
5
6use anyhow::{Context, Result};
7use axum::Router;
8use tokio::process::Command;
9use tower::timeout::TimeoutLayer;
10use tower::util::BoxCloneService;
11use tower_mcp::SessionHandle;
12use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
13use tower_mcp::client::StdioClientTransport;
14use tower_mcp::proxy::McpProxy;
15use tower_mcp::{RouterRequest, RouterResponse};
16
17use crate::admin::BackendMeta;
18use crate::alias;
19use crate::cache;
20use crate::coalesce;
21use crate::config::{AuthConfig, ProxyConfig, TransportType};
22use crate::filter::CapabilityFilterService;
23use crate::metrics;
24use crate::rbac::{RbacConfig, RbacService};
25use crate::validation::{ValidationConfig, ValidationService};
26
/// A fully assembled proxy: the HTTP router plus the handles needed to run
/// and manage it. Constructed by `Proxy::from_config`.
pub struct Proxy {
    /// Router produced from the HTTP transport, with auth applied and the
    /// admin routes nested under `/admin`.
    router: Router,
    /// Session handle returned by the HTTP transport alongside the router.
    session_handle: SessionHandle,
    /// Clone of the backend proxy taken before it was wrapped in middleware;
    /// exposed via `mcp_proxy()` and handed to the config watcher.
    inner: McpProxy,
    /// The configuration this proxy was built from (read again by `serve`).
    config: ProxyConfig,
}
34
35impl Proxy {
36 pub async fn from_config(config: ProxyConfig) -> Result<Self> {
42 let mcp_proxy = build_mcp_proxy(&config).await?;
43 let proxy_for_admin = mcp_proxy.clone();
44 let proxy_for_caller = mcp_proxy.clone();
45
46 let metrics_handle = if config.observability.metrics.enabled {
48 tracing::info!("Prometheus metrics enabled at /admin/metrics");
49 let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
50 let handle = builder
51 .install_recorder()
52 .context("installing Prometheus metrics recorder")?;
53 Some(handle)
54 } else {
55 None
56 };
57
58 let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;
59
60 let (router, session_handle) =
61 tower_mcp::transport::http::HttpTransport::from_service(service)
62 .into_router_with_handle();
63
64 let router = apply_auth(&config, router).await?;
66
67 let backend_meta: std::collections::HashMap<String, BackendMeta> = config
69 .backends
70 .iter()
71 .map(|b| {
72 (
73 b.name.clone(),
74 BackendMeta {
75 transport: format!("{:?}", b.transport).to_lowercase(),
76 },
77 )
78 })
79 .collect();
80
81 let admin_state = crate::admin::spawn_health_checker(
83 proxy_for_admin,
84 config.proxy.name.clone(),
85 config.proxy.version.clone(),
86 config.backends.len(),
87 backend_meta,
88 );
89 let router = router.nest(
90 "/admin",
91 crate::admin::admin_router(
92 admin_state.clone(),
93 metrics_handle,
94 session_handle.clone(),
95 cache_handle,
96 ),
97 );
98 tracing::info!("Admin API enabled at /admin/backends");
99
100 if let Err(e) = crate::admin_tools::register_admin_tools(
102 &proxy_for_caller,
103 admin_state,
104 session_handle.clone(),
105 &config,
106 )
107 .await
108 {
109 tracing::warn!("Failed to register admin tools: {e}");
110 } else {
111 tracing::info!("MCP admin tools registered under proxy/ namespace");
112 }
113
114 Ok(Self {
115 router,
116 session_handle,
117 inner: proxy_for_caller,
118 config,
119 })
120 }
121
122 pub fn session_handle(&self) -> &SessionHandle {
124 &self.session_handle
125 }
126
127 pub fn mcp_proxy(&self) -> &McpProxy {
131 &self.inner
132 }
133
134 pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
139 tracing::info!("Hot reload enabled, watching config file for changes");
140 crate::reload::spawn_config_watcher(config_path, self.inner.clone());
141 }
142
143 pub fn into_router(self) -> (Router, SessionHandle) {
155 (self.router, self.session_handle)
156 }
157
158 pub async fn serve(self) -> Result<()> {
163 let addr = format!(
164 "{}:{}",
165 self.config.proxy.listen.host, self.config.proxy.listen.port
166 );
167
168 tracing::info!(listen = %addr, "Proxy ready");
169
170 let listener = tokio::net::TcpListener::bind(&addr)
171 .await
172 .with_context(|| format!("binding to {}", addr))?;
173
174 let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
175 axum::serve(listener, self.router)
176 .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
177 .await
178 .context("server error")?;
179
180 tracing::info!("Proxy shut down");
181 Ok(())
182 }
183}
184
185async fn build_mcp_proxy(config: &ProxyConfig) -> Result<McpProxy> {
187 let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
188 .separator(&config.proxy.separator);
189
190 if let Some(instructions) = &config.proxy.instructions {
191 builder = builder.instructions(instructions);
192 }
193
194 let outlier_detector = {
197 let max_pct = config
198 .backends
199 .iter()
200 .filter_map(|b| b.outlier_detection.as_ref())
201 .map(|od| od.max_ejection_percent)
202 .max();
203 max_pct.map(crate::outlier::OutlierDetector::new)
204 };
205
206 for backend in &config.backends {
207 tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");
208
209 match backend.transport {
210 TransportType::Stdio => {
211 let command = backend.command.as_deref().unwrap();
212 let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
213
214 let mut cmd = Command::new(command);
215 cmd.args(&args);
216
217 for (key, value) in &backend.env {
218 cmd.env(key, value);
219 }
220
221 let transport = StdioClientTransport::spawn_command(&mut cmd)
222 .await
223 .with_context(|| format!("spawning backend '{}'", backend.name))?;
224
225 builder = builder.backend(&backend.name, transport).await;
226 }
227 TransportType::Http => {
228 let url = backend.url.as_deref().unwrap();
229 let mut transport = tower_mcp::client::HttpClientTransport::new(url);
230 if let Some(token) = &backend.bearer_token {
231 transport = transport.bearer_token(token);
232 }
233
234 builder = builder.backend(&backend.name, transport).await;
235 }
236 }
237
238 if let Some(retry_cfg) = &backend.retry {
242 tracing::info!(
243 backend = %backend.name,
244 max_retries = retry_cfg.max_retries,
245 initial_backoff_ms = retry_cfg.initial_backoff_ms,
246 max_backoff_ms = retry_cfg.max_backoff_ms,
247 "Applying retry policy"
248 );
249 let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
250 builder = builder.backend_layer(layer);
251 }
252
253 if let Some(hedge_cfg) = &backend.hedging {
255 let delay = Duration::from_millis(hedge_cfg.delay_ms);
256 let max_attempts = hedge_cfg.max_hedges + 1; tracing::info!(
258 backend = %backend.name,
259 delay_ms = hedge_cfg.delay_ms,
260 max_hedges = hedge_cfg.max_hedges,
261 "Applying request hedging"
262 );
263 let layer = if delay.is_zero() {
264 tower_resilience::hedge::HedgeLayer::builder()
265 .no_delay()
266 .max_hedged_attempts(max_attempts)
267 .name(format!("{}-hedge", backend.name))
268 .build()
269 } else {
270 tower_resilience::hedge::HedgeLayer::builder()
271 .delay(delay)
272 .max_hedged_attempts(max_attempts)
273 .name(format!("{}-hedge", backend.name))
274 .build()
275 };
276 builder = builder.backend_layer(layer);
277 }
278
279 if let Some(cc) = &backend.concurrency {
281 tracing::info!(
282 backend = %backend.name,
283 max = cc.max_concurrent,
284 "Applying concurrency limit"
285 );
286 builder =
287 builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
288 }
289
290 if let Some(rl) = &backend.rate_limit {
292 tracing::info!(
293 backend = %backend.name,
294 requests = rl.requests,
295 period_seconds = rl.period_seconds,
296 "Applying rate limit"
297 );
298 let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
299 .limit_for_period(rl.requests)
300 .refresh_period(Duration::from_secs(rl.period_seconds))
301 .name(format!("{}-ratelimit", backend.name))
302 .build();
303 builder = builder.backend_layer(layer);
304 }
305
306 if let Some(timeout) = &backend.timeout {
308 tracing::info!(
309 backend = %backend.name,
310 seconds = timeout.seconds,
311 "Applying timeout"
312 );
313 builder =
314 builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
315 }
316
317 if let Some(cb) = &backend.circuit_breaker {
319 tracing::info!(
320 backend = %backend.name,
321 failure_rate = cb.failure_rate_threshold,
322 wait_seconds = cb.wait_duration_seconds,
323 "Applying circuit breaker"
324 );
325 let layer = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
326 .failure_rate_threshold(cb.failure_rate_threshold)
327 .minimum_number_of_calls(cb.minimum_calls)
328 .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
329 .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
330 .name(format!("{}-cb", backend.name))
331 .build();
332 builder = builder.backend_layer(layer);
333 }
334
335 if let Some(od) = &backend.outlier_detection
337 && let Some(ref detector) = outlier_detector
338 {
339 tracing::info!(
340 backend = %backend.name,
341 consecutive_errors = od.consecutive_errors,
342 base_ejection_seconds = od.base_ejection_seconds,
343 max_ejection_percent = od.max_ejection_percent,
344 "Applying outlier detection"
345 );
346 let layer = crate::outlier::OutlierDetectionLayer::new(
347 backend.name.clone(),
348 od.clone(),
349 detector.clone(),
350 );
351 builder = builder.backend_layer(layer);
352 }
353 }
354
355 let result = builder.build().await?;
356
357 if !result.skipped.is_empty() {
358 for s in &result.skipped {
359 tracing::warn!("Skipped backend: {s}");
360 }
361 }
362
363 Ok(result.proxy)
364}
365
366fn build_middleware_stack(
368 config: &ProxyConfig,
369 proxy: McpProxy,
370) -> Result<(
371 BoxCloneService<RouterRequest, RouterResponse, Infallible>,
372 Option<cache::CacheHandle>,
373)> {
374 let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
375 BoxCloneService::new(proxy);
376 let mut cache_handle: Option<cache::CacheHandle> = None;
377
378 let injection_rules: Vec<_> = config
380 .backends
381 .iter()
382 .filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
383 .map(|b| {
384 let namespace = format!("{}{}", b.name, config.proxy.separator);
385 tracing::info!(
386 backend = %b.name,
387 default_args = b.default_args.len(),
388 tool_rules = b.inject_args.len(),
389 "Applying argument injection"
390 );
391 crate::inject::InjectionRules::new(
392 namespace,
393 b.default_args.clone(),
394 b.inject_args.clone(),
395 )
396 })
397 .collect();
398
399 if !injection_rules.is_empty() {
400 service = BoxCloneService::new(crate::inject::InjectArgsService::new(
401 service,
402 injection_rules,
403 ));
404 }
405
406 let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
408 .backends
409 .iter()
410 .filter_map(|b| {
411 b.canary_of.as_ref().map(|primary_name| {
412 let primary_weight = config
414 .backends
415 .iter()
416 .find(|p| p.name == *primary_name)
417 .map(|p| p.weight)
418 .unwrap_or(100);
419 (
420 primary_name.clone(),
421 (b.name.clone(), primary_weight, b.weight),
422 )
423 })
424 })
425 .collect();
426
427 if !canary_mappings.is_empty() {
428 for (primary, (canary, pw, cw)) in &canary_mappings {
429 tracing::info!(
430 primary = %primary,
431 canary = %canary,
432 primary_weight = pw,
433 canary_weight = cw,
434 "Enabling canary routing"
435 );
436 }
437 service = BoxCloneService::new(crate::canary::CanaryService::new(
438 service,
439 canary_mappings,
440 &config.proxy.separator,
441 ));
442 }
443
444 let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
446 .backends
447 .iter()
448 .filter_map(|b| {
449 b.mirror_of
450 .as_ref()
451 .map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
452 })
453 .collect();
454
455 if !mirror_mappings.is_empty() {
456 for (source, (mirror, pct)) in &mirror_mappings {
457 tracing::info!(
458 source = %source,
459 mirror = %mirror,
460 percent = pct,
461 "Enabling traffic mirroring"
462 );
463 }
464 service = BoxCloneService::new(crate::mirror::MirrorService::new(
465 service,
466 mirror_mappings,
467 &config.proxy.separator,
468 ));
469 }
470
471 let cache_configs: Vec<_> = config
473 .backends
474 .iter()
475 .filter_map(|b| {
476 b.cache
477 .as_ref()
478 .map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
479 })
480 .collect();
481
482 if !cache_configs.is_empty() {
483 for (ns, cfg) in &cache_configs {
484 tracing::info!(
485 backend = %ns.trim_end_matches(&config.proxy.separator),
486 resource_ttl = cfg.resource_ttl_seconds,
487 tool_ttl = cfg.tool_ttl_seconds,
488 max_entries = cfg.max_entries,
489 "Applying response cache"
490 );
491 }
492 let (cache_svc, handle) = cache::CacheService::new(service, cache_configs);
493 service = BoxCloneService::new(cache_svc);
494 cache_handle = Some(handle);
495 }
496
497 if config.performance.coalesce_requests {
499 tracing::info!("Request coalescing enabled");
500 service = BoxCloneService::new(coalesce::CoalesceService::new(service));
501 }
502
503 if config.security.max_argument_size.is_some() {
505 let validation = ValidationConfig {
506 max_argument_size: config.security.max_argument_size,
507 };
508 if let Some(max) = validation.max_argument_size {
509 tracing::info!(max_argument_size = max, "Applying request validation");
510 }
511 service = BoxCloneService::new(ValidationService::new(service, validation));
512 }
513
514 let filters: Vec<_> = config
516 .backends
517 .iter()
518 .filter_map(|b| b.build_filter(&config.proxy.separator))
519 .collect();
520
521 if !filters.is_empty() {
522 for f in &filters {
523 tracing::info!(
524 backend = %f.namespace.trim_end_matches(&config.proxy.separator),
525 tool_filter = ?f.tool_filter,
526 resource_filter = ?f.resource_filter,
527 prompt_filter = ?f.prompt_filter,
528 "Applying capability filter"
529 );
530 }
531 service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
532 }
533
534 let alias_mappings: Vec<_> = config
536 .backends
537 .iter()
538 .flat_map(|b| {
539 let ns = format!("{}{}", b.name, config.proxy.separator);
540 b.aliases
541 .iter()
542 .map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
543 })
544 .collect();
545
546 if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
547 let count = alias_map.forward.len();
548 tracing::info!(aliases = count, "Applying tool aliases");
549 service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
550 }
551
552 let rbac_config = match &config.auth {
554 Some(AuthConfig::Jwt {
555 roles,
556 role_mapping: Some(mapping),
557 ..
558 }) if !roles.is_empty() => {
559 tracing::info!(
560 roles = roles.len(),
561 claim = %mapping.claim,
562 "Enabling RBAC"
563 );
564 Some(RbacConfig::new(roles, mapping))
565 }
566 _ => None,
567 };
568
569 if let Some(rbac) = rbac_config {
570 service = BoxCloneService::new(RbacService::new(service, rbac));
571 }
572
573 let forward_namespaces: std::collections::HashSet<String> = config
575 .backends
576 .iter()
577 .filter(|b| b.forward_auth)
578 .map(|b| format!("{}{}", b.name, config.proxy.separator))
579 .collect();
580
581 if !forward_namespaces.is_empty() {
582 tracing::info!(
583 backends = ?forward_namespaces,
584 "Enabling token passthrough for forward_auth backends"
585 );
586 service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
587 service,
588 forward_namespaces,
589 ));
590 }
591
592 if config.observability.metrics.enabled {
594 service = BoxCloneService::new(metrics::MetricsService::new(service));
595 }
596
597 if config.observability.audit {
599 tracing::info!("Audit logging enabled (target: mcp::audit)");
600 let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
601 service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
602 }
603
604 Ok((service, cache_handle))
605}
606
607async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
609 let router = if let Some(auth) = &config.auth {
610 match auth {
611 AuthConfig::Bearer { tokens } => {
612 tracing::info!(token_count = tokens.len(), "Enabling bearer token auth");
613 let validator = StaticBearerValidator::new(tokens.iter().cloned());
614 let layer = AuthLayer::new(validator);
615 router.layer(layer)
616 }
617 AuthConfig::Jwt {
618 issuer,
619 audience,
620 jwks_uri,
621 ..
622 } => {
623 tracing::info!(
624 issuer = %issuer,
625 audience = %audience,
626 jwks_uri = %jwks_uri,
627 "Enabling JWT auth (JWKS)"
628 );
629 let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
630 .expected_audience(audience)
631 .expected_issuer(issuer)
632 .build()
633 .await
634 .context("building JWKS validator")?;
635
636 let addr = format!(
637 "http://{}:{}",
638 config.proxy.listen.host, config.proxy.listen.port
639 );
640 let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
641 .authorization_server(issuer);
642
643 let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
644 router.layer(layer)
645 }
646 }
647 } else {
648 router
649 };
650 Ok(router)
651}
652
653pub async fn shutdown_signal(timeout: Duration) {
655 let ctrl_c = tokio::signal::ctrl_c();
656 #[cfg(unix)]
657 {
658 let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
659 .expect("SIGTERM handler");
660 tokio::select! {
661 _ = ctrl_c => {},
662 _ = sigterm.recv() => {},
663 }
664 }
665 #[cfg(not(unix))]
666 {
667 ctrl_c.await.ok();
668 }
669 tracing::info!(
670 timeout_seconds = timeout.as_secs(),
671 "Shutdown signal received, draining connections"
672 );
673}