// mcp_proxy/proxy.rs
1//! Core proxy construction and serving.
2
3use std::convert::Infallible;
4use std::time::Duration;
5
6use anyhow::{Context, Result};
7use axum::Router;
8use tokio::process::Command;
9use tower::timeout::TimeoutLayer;
10use tower::util::BoxCloneService;
11use tower_mcp::SessionHandle;
12use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
13use tower_mcp::client::StdioClientTransport;
14use tower_mcp::proxy::McpProxy;
15use tower_mcp::{RouterRequest, RouterResponse};
16
17use crate::admin::BackendMeta;
18use crate::alias;
19use crate::cache;
20use crate::coalesce;
21use crate::config::{AuthConfig, ProxyConfig, TransportType};
22use crate::filter::CapabilityFilterService;
23#[cfg(feature = "oauth")]
24use crate::rbac::{RbacConfig, RbacService};
25use crate::validation::{ValidationConfig, ValidationService};
26
/// A fully constructed MCP proxy ready to serve or embed.
///
/// Built via [`Proxy::from_config`]. Holds the assembled axum router
/// (MCP transport + auth middleware + `/admin` routes), a session handle
/// for monitoring, the underlying [`McpProxy`] for dynamic backend
/// operations, and the originating config (used by [`Proxy::serve`] for
/// the listen address and shutdown timeout).
pub struct Proxy {
    // Complete axum router; consumed by serve() or into_router().
    router: Router,
    // Handle for inspecting active MCP sessions; cloned into admin routes.
    session_handle: SessionHandle,
    // Underlying proxy, exposed via mcp_proxy() and used for hot reload.
    inner: McpProxy,
    // Retained so serve() can read listen host/port and shutdown timeout.
    config: ProxyConfig,
}
34
impl Proxy {
    /// Build a proxy from a [`ProxyConfig`].
    ///
    /// Connects to all backends, builds the middleware stack, and prepares
    /// the axum router. Call [`serve()`](Self::serve) to run standalone or
    /// [`into_router()`](Self::into_router) to embed in an existing app.
    ///
    /// # Errors
    ///
    /// Fails if a backend cannot be spawned/connected, the Prometheus
    /// recorder cannot be installed, the middleware stack cannot be built,
    /// or inbound auth setup fails. A failure to register the MCP admin
    /// tools is only logged, not fatal.
    pub async fn from_config(config: ProxyConfig) -> Result<Self> {
        let mcp_proxy = build_mcp_proxy(&config).await?;
        // Two clones of the (cheaply cloneable) proxy handle: one feeds the
        // health checker, the other is kept in `Self` and used to register
        // the admin tools below.
        let proxy_for_admin = mcp_proxy.clone();
        let proxy_for_caller = mcp_proxy.clone();

        // Install Prometheus metrics recorder (must happen before middleware)
        #[cfg(feature = "metrics")]
        let metrics_handle = if config.observability.metrics.enabled {
            tracing::info!("Prometheus metrics enabled at /admin/metrics");
            let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
            let handle = builder
                .install_recorder()
                .context("installing Prometheus metrics recorder")?;
            Some(handle)
        } else {
            None
        };
        // Without the feature there is never a recorder handle.
        #[cfg(not(feature = "metrics"))]
        let metrics_handle = None;

        let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;

        // Wrap the MCP service in the HTTP transport, yielding the axum
        // router plus a handle over its sessions.
        let (router, session_handle) =
            tower_mcp::transport::http::HttpTransport::from_service(service)
                .into_router_with_handle();

        // Inbound authentication (axum-level middleware)
        let router = apply_auth(&config, router).await?;

        // Collect backend metadata for the health checker
        let backend_meta: std::collections::HashMap<String, BackendMeta> = config
            .backends
            .iter()
            .map(|b| {
                (
                    b.name.clone(),
                    BackendMeta {
                        // Debug-format the transport enum, e.g. "stdio"/"http".
                        transport: format!("{:?}", b.transport).to_lowercase(),
                    },
                )
            })
            .collect();

        // Admin API: background health checker plus /admin HTTP routes.
        let admin_state = crate::admin::spawn_health_checker(
            proxy_for_admin,
            config.proxy.name.clone(),
            config.proxy.version.clone(),
            config.backends.len(),
            backend_meta,
        );
        let router = router.nest(
            "/admin",
            crate::admin::admin_router(
                admin_state.clone(),
                metrics_handle,
                session_handle.clone(),
                cache_handle,
            ),
        );
        tracing::info!("Admin API enabled at /admin/backends");

        // MCP admin tools (proxy/ namespace) — best effort: a failure here
        // is logged but does not abort construction.
        if let Err(e) = crate::admin_tools::register_admin_tools(
            &proxy_for_caller,
            admin_state,
            session_handle.clone(),
            &config,
        )
        .await
        {
            tracing::warn!("Failed to register admin tools: {e}");
        } else {
            tracing::info!("MCP admin tools registered under proxy/ namespace");
        }

        Ok(Self {
            router,
            session_handle,
            inner: proxy_for_caller,
            config,
        })
    }

    /// Get a reference to the session handle for monitoring active sessions.
    pub fn session_handle(&self) -> &SessionHandle {
        &self.session_handle
    }

    /// Get a reference to the underlying [`McpProxy`] for dynamic operations.
    ///
    /// Use this to add backends dynamically via [`McpProxy::add_backend()`].
    pub fn mcp_proxy(&self) -> &McpProxy {
        &self.inner
    }

    /// Enable hot reload by watching the given config file path.
    ///
    /// New backends added to the config file will be connected dynamically
    /// without restarting the proxy.
    pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
        tracing::info!("Hot reload enabled, watching config file for changes");
        // The watcher runs on its own task and holds a clone of the proxy.
        crate::reload::spawn_config_watcher(config_path, self.inner.clone());
    }

    /// Consume the proxy and return the axum Router and SessionHandle.
    ///
    /// Use this to embed the proxy in an existing axum application:
    ///
    /// ```rust,ignore
    /// let (proxy_router, session_handle) = proxy.into_router();
    ///
    /// let app = Router::new()
    ///     .nest("/mcp", proxy_router)
    ///     .route("/health", get(|| async { "ok" }));
    /// ```
    pub fn into_router(self) -> (Router, SessionHandle) {
        (self.router, self.session_handle)
    }

    /// Serve the proxy on the configured listen address.
    ///
    /// Blocks until a shutdown signal (SIGTERM/SIGINT) is received,
    /// then drains connections for the configured timeout period.
    ///
    /// # Errors
    ///
    /// Fails if binding the listen address fails or the server errors.
    pub async fn serve(self) -> Result<()> {
        let addr = format!(
            "{}:{}",
            self.config.proxy.listen.host, self.config.proxy.listen.port
        );

        tracing::info!(listen = %addr, "Proxy ready");

        let listener = tokio::net::TcpListener::bind(&addr)
            .await
            .with_context(|| format!("binding to {}", addr))?;

        let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
        axum::serve(listener, self.router)
            .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
            .await
            .context("server error")?;

        tracing::info!("Proxy shut down");
        Ok(())
    }
}
187
188/// Build the McpProxy with all backends and per-backend middleware.
189async fn build_mcp_proxy(config: &ProxyConfig) -> Result<McpProxy> {
190    let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
191        .separator(&config.proxy.separator);
192
193    if let Some(instructions) = &config.proxy.instructions {
194        builder = builder.instructions(instructions);
195    }
196
197    // Create shared outlier detector if any backend has outlier_detection configured.
198    // Use the max of all max_ejection_percent values.
199    let outlier_detector = {
200        let max_pct = config
201            .backends
202            .iter()
203            .filter_map(|b| b.outlier_detection.as_ref())
204            .map(|od| od.max_ejection_percent)
205            .max();
206        max_pct.map(crate::outlier::OutlierDetector::new)
207    };
208
209    for backend in &config.backends {
210        tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");
211
212        match backend.transport {
213            TransportType::Stdio => {
214                let command = backend.command.as_deref().unwrap();
215                let args: Vec<&str> = backend.args.iter().map(|s| s.as_str()).collect();
216
217                let mut cmd = Command::new(command);
218                cmd.args(&args);
219
220                for (key, value) in &backend.env {
221                    cmd.env(key, value);
222                }
223
224                let transport = StdioClientTransport::spawn_command(&mut cmd)
225                    .await
226                    .with_context(|| format!("spawning backend '{}'", backend.name))?;
227
228                builder = builder.backend(&backend.name, transport).await;
229            }
230            TransportType::Http => {
231                let url = backend.url.as_deref().unwrap();
232                let mut transport = tower_mcp::client::HttpClientTransport::new(url);
233                if let Some(token) = &backend.bearer_token {
234                    transport = transport.bearer_token(token);
235                }
236
237                builder = builder.backend(&backend.name, transport).await;
238            }
239        }
240
241        // Per-backend middleware stack (applied in order: inner -> outer)
242
243        // Retry (innermost -- retries happen before other middleware)
244        if let Some(retry_cfg) = &backend.retry {
245            tracing::info!(
246                backend = %backend.name,
247                max_retries = retry_cfg.max_retries,
248                initial_backoff_ms = retry_cfg.initial_backoff_ms,
249                max_backoff_ms = retry_cfg.max_backoff_ms,
250                "Applying retry policy"
251            );
252            let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
253            builder = builder.backend_layer(layer);
254        }
255
256        // Hedging (after retry, before concurrency -- hedges are separate requests)
257        if let Some(hedge_cfg) = &backend.hedging {
258            let delay = Duration::from_millis(hedge_cfg.delay_ms);
259            let max_attempts = hedge_cfg.max_hedges + 1; // +1 for the primary request
260            tracing::info!(
261                backend = %backend.name,
262                delay_ms = hedge_cfg.delay_ms,
263                max_hedges = hedge_cfg.max_hedges,
264                "Applying request hedging"
265            );
266            let layer = if delay.is_zero() {
267                tower_resilience::hedge::HedgeLayer::builder()
268                    .no_delay()
269                    .max_hedged_attempts(max_attempts)
270                    .name(format!("{}-hedge", backend.name))
271                    .build()
272            } else {
273                tower_resilience::hedge::HedgeLayer::builder()
274                    .delay(delay)
275                    .max_hedged_attempts(max_attempts)
276                    .name(format!("{}-hedge", backend.name))
277                    .build()
278            };
279            builder = builder.backend_layer(layer);
280        }
281
282        // Concurrency limit
283        if let Some(cc) = &backend.concurrency {
284            tracing::info!(
285                backend = %backend.name,
286                max = cc.max_concurrent,
287                "Applying concurrency limit"
288            );
289            builder =
290                builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
291        }
292
293        // Rate limit
294        if let Some(rl) = &backend.rate_limit {
295            tracing::info!(
296                backend = %backend.name,
297                requests = rl.requests,
298                period_seconds = rl.period_seconds,
299                "Applying rate limit"
300            );
301            let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
302                .limit_for_period(rl.requests)
303                .refresh_period(Duration::from_secs(rl.period_seconds))
304                .name(format!("{}-ratelimit", backend.name))
305                .build();
306            builder = builder.backend_layer(layer);
307        }
308
309        // Timeout
310        if let Some(timeout) = &backend.timeout {
311            tracing::info!(
312                backend = %backend.name,
313                seconds = timeout.seconds,
314                "Applying timeout"
315            );
316            builder =
317                builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
318        }
319
320        // Circuit breaker
321        if let Some(cb) = &backend.circuit_breaker {
322            tracing::info!(
323                backend = %backend.name,
324                failure_rate = cb.failure_rate_threshold,
325                wait_seconds = cb.wait_duration_seconds,
326                "Applying circuit breaker"
327            );
328            let layer = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
329                .failure_rate_threshold(cb.failure_rate_threshold)
330                .minimum_number_of_calls(cb.minimum_calls)
331                .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
332                .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
333                .name(format!("{}-cb", backend.name))
334                .build();
335            builder = builder.backend_layer(layer);
336        }
337
338        // Outlier detection (outermost -- observes errors after all other middleware)
339        if let Some(od) = &backend.outlier_detection
340            && let Some(ref detector) = outlier_detector
341        {
342            tracing::info!(
343                backend = %backend.name,
344                consecutive_errors = od.consecutive_errors,
345                base_ejection_seconds = od.base_ejection_seconds,
346                max_ejection_percent = od.max_ejection_percent,
347                "Applying outlier detection"
348            );
349            let layer = crate::outlier::OutlierDetectionLayer::new(
350                backend.name.clone(),
351                od.clone(),
352                detector.clone(),
353            );
354            builder = builder.backend_layer(layer);
355        }
356    }
357
358    let result = builder.build().await?;
359
360    if !result.skipped.is_empty() {
361        for s in &result.skipped {
362            tracing::warn!("Skipped backend: {s}");
363        }
364    }
365
366    Ok(result.proxy)
367}
368
/// Build the MCP-level middleware stack around the proxy.
///
/// Services are wrapped innermost-to-outermost in this order: argument
/// injection, canary routing, traffic mirroring, response cache, request
/// coalescing, request validation, capability filtering, tool aliasing,
/// RBAC + token passthrough (oauth feature), metrics (metrics feature),
/// and audit logging (outermost).
///
/// Returns the boxed service plus a [`cache::CacheHandle`] when at least
/// one backend configures a response cache (the handle is wired into the
/// admin API by the caller).
fn build_middleware_stack(
    config: &ProxyConfig,
    proxy: McpProxy,
) -> Result<(
    BoxCloneService<RouterRequest, RouterResponse, Infallible>,
    Option<cache::CacheHandle>,
)> {
    // Start from the bare proxy; each enabled feature re-wraps `service`.
    let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
        BoxCloneService::new(proxy);
    let mut cache_handle: Option<cache::CacheHandle> = None;

    // Argument injection (innermost -- merges default/per-tool args into CallTool requests)
    let injection_rules: Vec<_> = config
        .backends
        .iter()
        .filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
        .map(|b| {
            // Namespace is "<backend><separator>", matching tool prefixes.
            let namespace = format!("{}{}", b.name, config.proxy.separator);
            tracing::info!(
                backend = %b.name,
                default_args = b.default_args.len(),
                tool_rules = b.inject_args.len(),
                "Applying argument injection"
            );
            crate::inject::InjectionRules::new(
                namespace,
                b.default_args.clone(),
                b.inject_args.clone(),
            )
        })
        .collect();

    if !injection_rules.is_empty() {
        service = BoxCloneService::new(crate::inject::InjectArgsService::new(
            service,
            injection_rules,
        ));
    }

    // Canary routing (rewrites requests from primary to canary namespace based on weight)
    // Map: primary backend name -> (canary name, primary weight, canary weight).
    let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.canary_of.as_ref().map(|primary_name| {
                // Find the primary backend's weight
                let primary_weight = config
                    .backends
                    .iter()
                    .find(|p| p.name == *primary_name)
                    .map(|p| p.weight)
                    // Missing primary: default its weight to 100.
                    .unwrap_or(100);
                (
                    primary_name.clone(),
                    (b.name.clone(), primary_weight, b.weight),
                )
            })
        })
        .collect();

    if !canary_mappings.is_empty() {
        for (primary, (canary, pw, cw)) in &canary_mappings {
            tracing::info!(
                primary = %primary,
                canary = %canary,
                primary_weight = pw,
                canary_weight = cw,
                "Enabling canary routing"
            );
        }
        service = BoxCloneService::new(crate::canary::CanaryService::new(
            service,
            canary_mappings,
            &config.proxy.separator,
        ));
    }

    // Traffic mirroring (sends cloned requests through the proxy)
    // Map: source backend name -> (mirror name, percent of traffic mirrored).
    let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.mirror_of
                .as_ref()
                .map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
        })
        .collect();

    if !mirror_mappings.is_empty() {
        for (source, (mirror, pct)) in &mirror_mappings {
            tracing::info!(
                source = %source,
                mirror = %mirror,
                percent = pct,
                "Enabling traffic mirroring"
            );
        }
        service = BoxCloneService::new(crate::mirror::MirrorService::new(
            service,
            mirror_mappings,
            &config.proxy.separator,
        ));
    }

    // Response caching: one (namespace, config) pair per backend with a cache.
    let cache_configs: Vec<_> = config
        .backends
        .iter()
        .filter_map(|b| {
            b.cache
                .as_ref()
                .map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
        })
        .collect();

    if !cache_configs.is_empty() {
        for (ns, cfg) in &cache_configs {
            tracing::info!(
                // Strip the trailing separator to log the bare backend name.
                backend = %ns.trim_end_matches(&config.proxy.separator),
                resource_ttl = cfg.resource_ttl_seconds,
                tool_ttl = cfg.tool_ttl_seconds,
                max_entries = cfg.max_entries,
                "Applying response cache"
            );
        }
        let (cache_svc, handle) = cache::CacheService::new(service, cache_configs);
        service = BoxCloneService::new(cache_svc);
        // Exposed to the admin API for cache inspection/purging.
        cache_handle = Some(handle);
    }

    // Request coalescing
    if config.performance.coalesce_requests {
        tracing::info!("Request coalescing enabled");
        service = BoxCloneService::new(coalesce::CoalesceService::new(service));
    }

    // Request validation (currently only max argument size)
    if config.security.max_argument_size.is_some() {
        let validation = ValidationConfig {
            max_argument_size: config.security.max_argument_size,
        };
        if let Some(max) = validation.max_argument_size {
            tracing::info!(max_argument_size = max, "Applying request validation");
        }
        service = BoxCloneService::new(ValidationService::new(service, validation));
    }

    // Static capability filtering
    let filters: Vec<_> = config
        .backends
        .iter()
        .filter_map(|b| b.build_filter(&config.proxy.separator))
        .collect();

    if !filters.is_empty() {
        for f in &filters {
            tracing::info!(
                backend = %f.namespace.trim_end_matches(&config.proxy.separator),
                tool_filter = ?f.tool_filter,
                resource_filter = ?f.resource_filter,
                prompt_filter = ?f.prompt_filter,
                "Applying capability filter"
            );
        }
        service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
    }

    // Tool aliasing: flatten per-backend alias lists into
    // (namespace, from, to) triples.
    let alias_mappings: Vec<_> = config
        .backends
        .iter()
        .flat_map(|b| {
            let ns = format!("{}{}", b.name, config.proxy.separator);
            b.aliases
                .iter()
                .map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
        })
        .collect();

    // AliasMap::new returns None when there are no mappings.
    if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
        let count = alias_map.forward.len();
        tracing::info!(aliases = count, "Applying tool aliases");
        service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
    }

    // RBAC (JWT auth only)
    #[cfg(feature = "oauth")]
    {
        // RBAC is active only for JWT auth with both roles and a role
        // mapping configured.
        let rbac_config = match &config.auth {
            Some(AuthConfig::Jwt {
                roles,
                role_mapping: Some(mapping),
                ..
            }) if !roles.is_empty() => {
                tracing::info!(
                    roles = roles.len(),
                    claim = %mapping.claim,
                    "Enabling RBAC"
                );
                Some(RbacConfig::new(roles, mapping))
            }
            _ => None,
        };

        if let Some(rbac) = rbac_config {
            service = BoxCloneService::new(RbacService::new(service, rbac));
        }

        // Token passthrough (inject ClientToken for forward_auth backends)
        let forward_namespaces: std::collections::HashSet<String> = config
            .backends
            .iter()
            .filter(|b| b.forward_auth)
            .map(|b| format!("{}{}", b.name, config.proxy.separator))
            .collect();

        if !forward_namespaces.is_empty() {
            tracing::info!(
                backends = ?forward_namespaces,
                "Enabling token passthrough for forward_auth backends"
            );
            service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
                service,
                forward_namespaces,
            ));
        }
    }

    // Metrics
    #[cfg(feature = "metrics")]
    if config.observability.metrics.enabled {
        service = BoxCloneService::new(crate::metrics::MetricsService::new(service));
    }

    // Audit logging (outermost, wrapped in CatchError so the service's
    // error type stays Infallible)
    if config.observability.audit {
        tracing::info!("Audit logging enabled (target: mcp::audit)");
        let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
        service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
    }

    Ok((service, cache_handle))
}
613
614/// Apply inbound authentication middleware to the router.
615async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
616    let router = if let Some(auth) = &config.auth {
617        match auth {
618            AuthConfig::Bearer { tokens } => {
619                tracing::info!(token_count = tokens.len(), "Enabling bearer token auth");
620                let validator = StaticBearerValidator::new(tokens.iter().cloned());
621                let layer = AuthLayer::new(validator);
622                router.layer(layer)
623            }
624            #[cfg(feature = "oauth")]
625            AuthConfig::Jwt {
626                issuer,
627                audience,
628                jwks_uri,
629                ..
630            } => {
631                tracing::info!(
632                    issuer = %issuer,
633                    audience = %audience,
634                    jwks_uri = %jwks_uri,
635                    "Enabling JWT auth (JWKS)"
636                );
637                let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
638                    .expected_audience(audience)
639                    .expected_issuer(issuer)
640                    .build()
641                    .await
642                    .context("building JWKS validator")?;
643
644                let addr = format!(
645                    "http://{}:{}",
646                    config.proxy.listen.host, config.proxy.listen.port
647                );
648                let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
649                    .authorization_server(issuer);
650
651                let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
652                router.layer(layer)
653            }
654            #[cfg(not(feature = "oauth"))]
655            AuthConfig::Jwt { .. } => {
656                anyhow::bail!(
657                    "JWT auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
658                );
659            }
660        }
661    } else {
662        router
663    };
664    Ok(router)
665}
666
667/// Wait for SIGTERM or SIGINT, then log and return.
668pub async fn shutdown_signal(timeout: Duration) {
669    let ctrl_c = tokio::signal::ctrl_c();
670    #[cfg(unix)]
671    {
672        let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
673            .expect("SIGTERM handler");
674        tokio::select! {
675            _ = ctrl_c => {},
676            _ = sigterm.recv() => {},
677        }
678    }
679    #[cfg(not(unix))]
680    {
681        ctrl_c.await.ok();
682    }
683    tracing::info!(
684        timeout_seconds = timeout.as_secs(),
685        "Shutdown signal received, draining connections"
686    );
687}