Skip to main content

mcp_proxy/
proxy.rs

1//! Core proxy construction and serving.
2
3use std::convert::Infallible;
4use std::time::Duration;
5
6use anyhow::{Context, Result};
7use axum::Router;
8use tokio::process::Command;
9use tower::timeout::TimeoutLayer;
10use tower::util::BoxCloneService;
11use tower_mcp::SessionHandle;
12use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
13use tower_mcp::client::StdioClientTransport;
14use tower_mcp::proxy::McpProxy;
15use tower_mcp::{RouterRequest, RouterResponse};
16
17use crate::admin::BackendMeta;
18use crate::alias;
19use crate::cache;
20use crate::coalesce;
21use crate::config::{AuthConfig, ProxyConfig, TransportType};
22use crate::filter::CapabilityFilterService;
23use crate::metrics;
24use crate::rbac::{RbacConfig, RbacService};
25use crate::validation::{ValidationConfig, ValidationService};
26
/// A fully constructed MCP proxy ready to serve or embed.
pub struct Proxy {
    // Axum router with the MCP HTTP transport, inbound auth middleware,
    // and the /admin API already mounted.
    router: Router,
    // Handle for observing/managing active MCP sessions.
    session_handle: SessionHandle,
    // The underlying proxy; exposed via `mcp_proxy()` for dynamic backend
    // operations and reused by hot reload.
    inner: McpProxy,
    // Retained for `serve()` (listen address, shutdown timeout).
    config: ProxyConfig,
}
34
impl Proxy {
    /// Build a proxy from a [`ProxyConfig`].
    ///
    /// Connects to all backends, builds the middleware stack, and prepares
    /// the axum router. Call [`serve()`](Self::serve) to run standalone or
    /// [`into_router()`](Self::into_router) to embed in an existing app.
    ///
    /// # Errors
    ///
    /// Fails if backend construction fails, the Prometheus recorder cannot
    /// be installed, or the inbound auth layer cannot be built. Admin-tool
    /// registration failure is deliberately non-fatal (logged as a warning).
    pub async fn from_config(config: ProxyConfig) -> Result<Self> {
        let mcp_proxy = build_mcp_proxy(&config).await?;
        // Keep extra handles (McpProxy is Clone): one for the admin health
        // checker, one for `Self.inner`, before the middleware stack
        // consumes the original below.
        let proxy_for_admin = mcp_proxy.clone();
        let proxy_for_caller = mcp_proxy.clone();

        // Install Prometheus metrics recorder (must happen before middleware)
        let metrics_handle = if config.observability.metrics.enabled {
            tracing::info!("Prometheus metrics enabled at /admin/metrics");
            let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
            let handle = builder
                .install_recorder()
                .context("installing Prometheus metrics recorder")?;
            Some(handle)
        } else {
            None
        };

        let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;

        // Expose the boxed MCP service over HTTP and grab the session
        // handle for the admin API and admin tools.
        let (router, session_handle) =
            tower_mcp::transport::http::HttpTransport::from_service(service)
                .into_router_with_handle();

        // Inbound authentication (axum-level middleware)
        let router = apply_auth(&config, router).await?;

        // Collect backend metadata for the health checker
        let backend_meta: std::collections::HashMap<String, BackendMeta> = config
            .backends
            .iter()
            .map(|b| {
                (
                    b.name.clone(),
                    BackendMeta {
                        // Derive the label from the Debug name, lowercased
                        // (e.g. "stdio" / "http").
                        transport: format!("{:?}", b.transport).to_lowercase(),
                    },
                )
            })
            .collect();

        // Admin API
        let admin_state = crate::admin::spawn_health_checker(
            proxy_for_admin,
            config.proxy.name.clone(),
            config.proxy.version.clone(),
            config.backends.len(),
            backend_meta,
        );
        let router = router.nest(
            "/admin",
            crate::admin::admin_router(
                admin_state.clone(),
                metrics_handle,
                session_handle.clone(),
                cache_handle,
            ),
        );
        tracing::info!("Admin API enabled at /admin/backends");

        // MCP admin tools (proxy/ namespace). Failure here is non-fatal:
        // the proxy still serves its backends without the admin tools.
        if let Err(e) = crate::admin_tools::register_admin_tools(
            &proxy_for_caller,
            admin_state,
            session_handle.clone(),
            &config,
        )
        .await
        {
            tracing::warn!("Failed to register admin tools: {e}");
        } else {
            tracing::info!("MCP admin tools registered under proxy/ namespace");
        }

        Ok(Self {
            router,
            session_handle,
            inner: proxy_for_caller,
            config,
        })
    }

    /// Get a reference to the session handle for monitoring active sessions.
    pub fn session_handle(&self) -> &SessionHandle {
        &self.session_handle
    }

    /// Get a reference to the underlying [`McpProxy`] for dynamic operations.
    ///
    /// Use this to add backends dynamically via [`McpProxy::add_backend()`].
    pub fn mcp_proxy(&self) -> &McpProxy {
        &self.inner
    }

    /// Enable hot reload by watching the given config file path.
    ///
    /// New backends added to the config file will be connected dynamically
    /// without restarting the proxy.
    pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
        tracing::info!("Hot reload enabled, watching config file for changes");
        // The watcher runs detached; it only needs a clone of the proxy.
        crate::reload::spawn_config_watcher(config_path, self.inner.clone());
    }

    /// Consume the proxy and return the axum Router and SessionHandle.
    ///
    /// Use this to embed the proxy in an existing axum application:
    ///
    /// ```rust,ignore
    /// let (proxy_router, session_handle) = proxy.into_router();
    ///
    /// let app = Router::new()
    ///     .nest("/mcp", proxy_router)
    ///     .route("/health", get(|| async { "ok" }));
    /// ```
    pub fn into_router(self) -> (Router, SessionHandle) {
        (self.router, self.session_handle)
    }

    /// Serve the proxy on the configured listen address.
    ///
    /// Blocks until a shutdown signal (SIGTERM/SIGINT) is received,
    /// then drains connections for the configured timeout period.
    ///
    /// # Errors
    ///
    /// Fails if the listen address cannot be bound or the server errors.
    pub async fn serve(self) -> Result<()> {
        let addr = format!(
            "{}:{}",
            self.config.proxy.listen.host, self.config.proxy.listen.port
        );

        tracing::info!(listen = %addr, "Proxy ready");

        let listener = tokio::net::TcpListener::bind(&addr)
            .await
            .with_context(|| format!("binding to {}", addr))?;

        // NOTE(review): the timeout is passed to `shutdown_signal`, which
        // only logs it -- axum drains until connections close on their own.
        let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
        axum::serve(listener, self.router)
            .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
            .await
            .context("server error")?;

        tracing::info!("Proxy shut down");
        Ok(())
    }
}
184
185/// Build the McpProxy with all backends and per-backend middleware.
186async fn build_mcp_proxy(config: &ProxyConfig) -> Result<McpProxy> {
187    let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
188        .separator(&config.proxy.separator);
189
190    if let Some(instructions) = &config.proxy.instructions {
191        builder = builder.instructions(instructions);
192    }
193
194    // Create shared outlier detector if any backend has outlier_detection configured.
195    // Use the max of all max_ejection_percent values.
196    let outlier_detector = {
197        let max_pct = config
198            .backends
199            .iter()
200            .filter_map(|b| b.outlier_detection.as_ref())
201            .map(|od| od.max_ejection_percent)
202            .max();
203        max_pct.map(crate::outlier::OutlierDetector::new)
204    };
205
206    for backend in &config.backends {
207        tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");
208
209        match backend.transport {
210            TransportType::Stdio => {
211                let command = backend.command.as_deref().unwrap();
212                let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
213
214                let mut cmd = Command::new(command);
215                cmd.args(&args);
216
217                for (key, value) in &backend.env {
218                    cmd.env(key, value);
219                }
220
221                let transport = StdioClientTransport::spawn_command(&mut cmd)
222                    .await
223                    .with_context(|| format!("spawning backend '{}'", backend.name))?;
224
225                builder = builder.backend(&backend.name, transport).await;
226            }
227            TransportType::Http => {
228                let url = backend.url.as_deref().unwrap();
229                let mut transport = tower_mcp::client::HttpClientTransport::new(url);
230                if let Some(token) = &backend.bearer_token {
231                    transport = transport.bearer_token(token);
232                }
233
234                builder = builder.backend(&backend.name, transport).await;
235            }
236        }
237
238        // Per-backend middleware stack (applied in order: inner -> outer)
239
240        // Retry (innermost -- retries happen before other middleware)
241        if let Some(retry_cfg) = &backend.retry {
242            tracing::info!(
243                backend = %backend.name,
244                max_retries = retry_cfg.max_retries,
245                initial_backoff_ms = retry_cfg.initial_backoff_ms,
246                max_backoff_ms = retry_cfg.max_backoff_ms,
247                "Applying retry policy"
248            );
249            let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
250            builder = builder.backend_layer(layer);
251        }
252
253        // Hedging (after retry, before concurrency -- hedges are separate requests)
254        if let Some(hedge_cfg) = &backend.hedging {
255            let delay = Duration::from_millis(hedge_cfg.delay_ms);
256            let max_attempts = hedge_cfg.max_hedges + 1; // +1 for the primary request
257            tracing::info!(
258                backend = %backend.name,
259                delay_ms = hedge_cfg.delay_ms,
260                max_hedges = hedge_cfg.max_hedges,
261                "Applying request hedging"
262            );
263            let layer = if delay.is_zero() {
264                tower_resilience::hedge::HedgeLayer::builder()
265                    .no_delay()
266                    .max_hedged_attempts(max_attempts)
267                    .name(format!("{}-hedge", backend.name))
268                    .build()
269            } else {
270                tower_resilience::hedge::HedgeLayer::builder()
271                    .delay(delay)
272                    .max_hedged_attempts(max_attempts)
273                    .name(format!("{}-hedge", backend.name))
274                    .build()
275            };
276            builder = builder.backend_layer(layer);
277        }
278
279        // Concurrency limit
280        if let Some(cc) = &backend.concurrency {
281            tracing::info!(
282                backend = %backend.name,
283                max = cc.max_concurrent,
284                "Applying concurrency limit"
285            );
286            builder =
287                builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
288        }
289
290        // Rate limit
291        if let Some(rl) = &backend.rate_limit {
292            tracing::info!(
293                backend = %backend.name,
294                requests = rl.requests,
295                period_seconds = rl.period_seconds,
296                "Applying rate limit"
297            );
298            let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
299                .limit_for_period(rl.requests)
300                .refresh_period(Duration::from_secs(rl.period_seconds))
301                .name(format!("{}-ratelimit", backend.name))
302                .build();
303            builder = builder.backend_layer(layer);
304        }
305
306        // Timeout
307        if let Some(timeout) = &backend.timeout {
308            tracing::info!(
309                backend = %backend.name,
310                seconds = timeout.seconds,
311                "Applying timeout"
312            );
313            builder =
314                builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
315        }
316
317        // Circuit breaker
318        if let Some(cb) = &backend.circuit_breaker {
319            tracing::info!(
320                backend = %backend.name,
321                failure_rate = cb.failure_rate_threshold,
322                wait_seconds = cb.wait_duration_seconds,
323                "Applying circuit breaker"
324            );
325            let layer = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
326                .failure_rate_threshold(cb.failure_rate_threshold)
327                .minimum_number_of_calls(cb.minimum_calls)
328                .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
329                .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
330                .name(format!("{}-cb", backend.name))
331                .build();
332            builder = builder.backend_layer(layer);
333        }
334
335        // Outlier detection (outermost -- observes errors after all other middleware)
336        if let Some(od) = &backend.outlier_detection
337            && let Some(ref detector) = outlier_detector
338        {
339            tracing::info!(
340                backend = %backend.name,
341                consecutive_errors = od.consecutive_errors,
342                base_ejection_seconds = od.base_ejection_seconds,
343                max_ejection_percent = od.max_ejection_percent,
344                "Applying outlier detection"
345            );
346            let layer = crate::outlier::OutlierDetectionLayer::new(
347                backend.name.clone(),
348                od.clone(),
349                detector.clone(),
350            );
351            builder = builder.backend_layer(layer);
352        }
353    }
354
355    let result = builder.build().await?;
356
357    if !result.skipped.is_empty() {
358        for s in &result.skipped {
359            tracing::warn!("Skipped backend: {s}");
360        }
361    }
362
363    Ok(result.proxy)
364}
365
/// Build the MCP-level middleware stack around the proxy.
///
/// Services are wrapped in order, so earlier wrappers sit closest to the
/// proxy (innermost) and later ones see requests first (outermost):
/// argument injection -> canary -> mirroring -> cache -> coalescing ->
/// validation -> capability filter -> aliasing -> RBAC -> token
/// passthrough -> metrics -> audit.
///
/// Returns the boxed service plus a [`cache::CacheHandle`] when response
/// caching is configured (`None` otherwise); the handle is consumed by the
/// admin router.
fn build_middleware_stack(
    config: &ProxyConfig,
    proxy: McpProxy,
) -> Result<(
    BoxCloneService<RouterRequest, RouterResponse, Infallible>,
    Option<cache::CacheHandle>,
)> {
    let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
        BoxCloneService::new(proxy);
    let mut cache_handle: Option<cache::CacheHandle> = None;

    // Argument injection (innermost -- merges default/per-tool args into CallTool requests)
    let injection_rules: Vec<_> = config
        .backends
        .iter()
        .filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
        .map(|b| {
            // Namespace is "<backend><separator>", matching proxied tool names.
            let namespace = format!("{}{}", b.name, config.proxy.separator);
            tracing::info!(
                backend = %b.name,
                default_args = b.default_args.len(),
                tool_rules = b.inject_args.len(),
                "Applying argument injection"
            );
            crate::inject::InjectionRules::new(
                namespace,
                b.default_args.clone(),
                b.inject_args.clone(),
            )
        })
        .collect();

    if !injection_rules.is_empty() {
        service = BoxCloneService::new(crate::inject::InjectArgsService::new(
            service,
            injection_rules,
        ));
    }

    // Canary routing (rewrites requests from primary to canary namespace based on weight)
    // Map: primary backend name -> (canary name, primary weight, canary weight).
    let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.canary_of.as_ref().map(|primary_name| {
                // Find the primary backend's weight
                let primary_weight = config
                    .backends
                    .iter()
                    .find(|p| p.name == *primary_name)
                    .map(|p| p.weight)
                    .unwrap_or(100);
                (
                    primary_name.clone(),
                    (b.name.clone(), primary_weight, b.weight),
                )
            })
        })
        .collect();

    if !canary_mappings.is_empty() {
        for (primary, (canary, pw, cw)) in &canary_mappings {
            tracing::info!(
                primary = %primary,
                canary = %canary,
                primary_weight = pw,
                canary_weight = cw,
                "Enabling canary routing"
            );
        }
        service = BoxCloneService::new(crate::canary::CanaryService::new(
            service,
            canary_mappings,
            &config.proxy.separator,
        ));
    }

    // Traffic mirroring (sends cloned requests through the proxy)
    // Map: source backend name -> (mirror backend name, percent of traffic).
    let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.mirror_of
                .as_ref()
                .map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
        })
        .collect();

    if !mirror_mappings.is_empty() {
        for (source, (mirror, pct)) in &mirror_mappings {
            tracing::info!(
                source = %source,
                mirror = %mirror,
                percent = pct,
                "Enabling traffic mirroring"
            );
        }
        service = BoxCloneService::new(crate::mirror::MirrorService::new(
            service,
            mirror_mappings,
            &config.proxy.separator,
        ));
    }

    // Response caching
    let cache_configs: Vec<_> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.cache
                .as_ref()
                .map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
        })
        .collect();

    if !cache_configs.is_empty() {
        for (ns, cfg) in &cache_configs {
            tracing::info!(
                // Strip the separator back off for a readable backend label.
                backend = %ns.trim_end_matches(&config.proxy.separator),
                resource_ttl = cfg.resource_ttl_seconds,
                tool_ttl = cfg.tool_ttl_seconds,
                max_entries = cfg.max_entries,
                "Applying response cache"
            );
        }
        let (cache_svc, handle) = cache::CacheService::new(service, cache_configs);
        service = BoxCloneService::new(cache_svc);
        cache_handle = Some(handle);
    }

    // Request coalescing
    if config.performance.coalesce_requests {
        tracing::info!("Request coalescing enabled");
        service = BoxCloneService::new(coalesce::CoalesceService::new(service));
    }

    // Request validation
    if config.security.max_argument_size.is_some() {
        let validation = ValidationConfig {
            max_argument_size: config.security.max_argument_size,
        };
        if let Some(max) = validation.max_argument_size {
            tracing::info!(max_argument_size = max, "Applying request validation");
        }
        service = BoxCloneService::new(ValidationService::new(service, validation));
    }

    // Static capability filtering
    let filters: Vec<_> = config
        .backends
        .iter()
        .filter_map(|b| b.build_filter(&config.proxy.separator))
        .collect();

    if !filters.is_empty() {
        for f in &filters {
            tracing::info!(
                backend = %f.namespace.trim_end_matches(&config.proxy.separator),
                tool_filter = ?f.tool_filter,
                resource_filter = ?f.resource_filter,
                prompt_filter = ?f.prompt_filter,
                "Applying capability filter"
            );
        }
        service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
    }

    // Tool aliasing
    let alias_mappings: Vec<_> = config
        .backends
        .iter()
        .flat_map(|b| {
            let ns = format!("{}{}", b.name, config.proxy.separator);
            b.aliases
                .iter()
                .map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
        })
        .collect();

    // AliasMap::new returning None means no aliases were configured.
    if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
        let count = alias_map.forward.len();
        tracing::info!(aliases = count, "Applying tool aliases");
        service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
    }

    // RBAC (JWT auth only): requires both roles and a role_mapping claim.
    let rbac_config = match &config.auth {
        Some(AuthConfig::Jwt {
            roles,
            role_mapping: Some(mapping),
            ..
        }) if !roles.is_empty() => {
            tracing::info!(
                roles = roles.len(),
                claim = %mapping.claim,
                "Enabling RBAC"
            );
            Some(RbacConfig::new(roles, mapping))
        }
        _ => None,
    };

    if let Some(rbac) = rbac_config {
        service = BoxCloneService::new(RbacService::new(service, rbac));
    }

    // Token passthrough (inject ClientToken for forward_auth backends)
    let forward_namespaces: std::collections::HashSet<String> = config
        .backends
        .iter()
        .filter(|b| b.forward_auth)
        .map(|b| format!("{}{}", b.name, config.proxy.separator))
        .collect();

    if !forward_namespaces.is_empty() {
        tracing::info!(
            backends = ?forward_namespaces,
            "Enabling token passthrough for forward_auth backends"
        );
        service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
            service,
            forward_namespaces,
        ));
    }

    // Metrics
    if config.observability.metrics.enabled {
        service = BoxCloneService::new(metrics::MetricsService::new(service));
    }

    // Audit logging (outermost). CatchError restores the Infallible error
    // type that BoxCloneService requires after the audit layer wraps it.
    if config.observability.audit {
        tracing::info!("Audit logging enabled (target: mcp::audit)");
        let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
        service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
    }

    Ok((service, cache_handle))
}
606
607/// Apply inbound authentication middleware to the router.
608async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
609    let router = if let Some(auth) = &config.auth {
610        match auth {
611            AuthConfig::Bearer { tokens } => {
612                tracing::info!(token_count = tokens.len(), "Enabling bearer token auth");
613                let validator = StaticBearerValidator::new(tokens.iter().cloned());
614                let layer = AuthLayer::new(validator);
615                router.layer(layer)
616            }
617            AuthConfig::Jwt {
618                issuer,
619                audience,
620                jwks_uri,
621                ..
622            } => {
623                tracing::info!(
624                    issuer = %issuer,
625                    audience = %audience,
626                    jwks_uri = %jwks_uri,
627                    "Enabling JWT auth (JWKS)"
628                );
629                let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
630                    .expected_audience(audience)
631                    .expected_issuer(issuer)
632                    .build()
633                    .await
634                    .context("building JWKS validator")?;
635
636                let addr = format!(
637                    "http://{}:{}",
638                    config.proxy.listen.host, config.proxy.listen.port
639                );
640                let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
641                    .authorization_server(issuer);
642
643                let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
644                router.layer(layer)
645            }
646        }
647    } else {
648        router
649    };
650    Ok(router)
651}
652
653/// Wait for SIGTERM or SIGINT, then log and return.
654pub async fn shutdown_signal(timeout: Duration) {
655    let ctrl_c = tokio::signal::ctrl_c();
656    #[cfg(unix)]
657    {
658        let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
659            .expect("SIGTERM handler");
660        tokio::select! {
661            _ = ctrl_c => {},
662            _ = sigterm.recv() => {},
663        }
664    }
665    #[cfg(not(unix))]
666    {
667        ctrl_c.await.ok();
668    }
669    tracing::info!(
670        timeout_seconds = timeout.as_secs(),
671        "Shutdown signal received, draining connections"
672    );
673}