use std::collections::HashMap;
use std::convert::Infallible;
use std::time::Duration;
use anyhow::{Context, Result};
use axum::Router;
use tokio::process::Command;
use tower::timeout::TimeoutLayer;
use tower::util::BoxCloneService;
use tower_mcp::SessionHandle;
use tower_mcp::auth::{AuthLayer, StaticBearerValidator};
use tower_mcp::client::StdioClientTransport;
use tower_mcp::proxy::McpProxy;
use tower_mcp::{RouterRequest, RouterResponse};
use crate::admin::BackendMeta;
use crate::alias;
use crate::cache;
use crate::coalesce;
use crate::config::{AuthConfig, ProxyConfig, TransportType};
use crate::filter::CapabilityFilterService;
#[cfg(feature = "oauth")]
use crate::rbac::{RbacConfig, RbacService};
use crate::validation::{ValidationConfig, ValidationService};
/// A fully assembled proxy: the HTTP router together with the handles and
/// configuration needed to serve and manage it.
///
/// Instances are created by `Proxy::from_config`.
pub struct Proxy {
/// Router exposing the MCP service (with auth applied when configured) and the
/// admin API nested under `/admin`.
router: Router,
/// Session handle obtained when the HTTP transport was turned into a router.
session_handle: SessionHandle,
/// Clone of the underlying `McpProxy`, returned by `mcp_proxy()` and handed to the
/// config watcher by `enable_hot_reload()`.
inner: McpProxy,
/// The configuration this proxy was built from.
config: ProxyConfig,
/// Discovery index; `Some` only when tool discovery (or search-mode tool exposure)
/// was enabled in the configuration at construction time.
#[cfg(feature = "discovery")]
discovery_index: Option<crate::discovery::SharedDiscoveryIndex>,
}
impl Proxy {
    /// Builds a complete proxy from `config`.
    ///
    /// Steps, in order: connect the backends (`build_mcp_proxy`), optionally install the
    /// Prometheus recorder, wrap the proxy in the middleware stack
    /// (`build_middleware_stack`), expose it through the HTTP transport, apply inbound
    /// auth (`apply_auth`), start the health checker and mount the admin API under
    /// `/admin`, optionally build the discovery index, and finally register the MCP
    /// admin tools.
    ///
    /// # Errors
    ///
    /// Returns an error if backend setup, installing the metrics recorder, building the
    /// middleware stack, or applying auth fails. A failure to register the admin tools
    /// is only logged as a warning and does not abort construction.
    pub async fn from_config(config: ProxyConfig) -> Result<Self> {
        let (mcp_proxy, cb_handles) = build_mcp_proxy(&config).await?;

        // `mcp_proxy` itself is consumed by the middleware stack below, so every other
        // consumer gets its own clone up front.
        let proxy_for_admin = mcp_proxy.clone();
        // Only mutated when the `discovery` feature is compiled in (index building takes
        // `&mut`); silence the lint for builds without it.
        #[cfg_attr(not(feature = "discovery"), allow(unused_mut))]
        let mut proxy_for_caller = mcp_proxy.clone();
        let proxy_for_management = mcp_proxy.clone();

        #[cfg(feature = "metrics")]
        let metrics_handle = if config.observability.metrics.enabled {
            tracing::info!("Prometheus metrics enabled at /admin/metrics");
            let builder = metrics_exporter_prometheus::PrometheusBuilder::new();
            let handle = builder
                .install_recorder()
                .context("installing Prometheus metrics recorder")?;
            Some(handle)
        } else {
            None
        };
        #[cfg(not(feature = "metrics"))]
        let metrics_handle = None;

        let (service, cache_handle) = build_middleware_stack(&config, mcp_proxy)?;
        let (router, session_handle) =
            tower_mcp::transport::http::HttpTransport::from_service(service)
                .into_router_with_handle();

        // Auth is layered onto the MCP router before the admin routes are nested below.
        let router = apply_auth(&config, router).await?;

        // Per-backend metadata for the admin API: the transport's `Debug` name, lowercased.
        let backend_meta: HashMap<String, BackendMeta> = config
            .backends
            .iter()
            .map(|b| {
                (
                    b.name.clone(),
                    BackendMeta {
                        transport: format!("{:?}", b.transport).to_lowercase(),
                    },
                )
            })
            .collect();

        let admin_state = crate::admin::spawn_health_checker(
            proxy_for_admin,
            config.proxy.name.clone(),
            config.proxy.version.clone(),
            config.backends.len(),
            backend_meta,
        );
        let router = router.nest(
            "/admin",
            crate::admin::admin_router(
                admin_state.clone(),
                metrics_handle,
                session_handle.clone(),
                cache_handle,
                proxy_for_management,
                &config,
                config.source_path.clone(),
                cb_handles,
            ),
        );
        tracing::info!("Admin API enabled at /admin/backends");

        // Discovery is switched on either explicitly or implicitly by search-mode exposure.
        #[cfg(feature = "discovery")]
        let discovery_enabled = config.proxy.tool_discovery
            || config.proxy.tool_exposure == crate::config::ToolExposure::Search;
        #[cfg(feature = "discovery")]
        let (discovery_index, discovery_tools) = if discovery_enabled {
            let index =
                crate::discovery::build_index(&mut proxy_for_caller, &config.proxy.separator).await;
            let tools = crate::discovery::build_discovery_tools(index.clone());
            (Some(index), Some(tools))
        } else {
            (None, None)
        };
        #[cfg(not(feature = "discovery"))]
        let discovery_tools: Option<Vec<tower_mcp::Tool>> = None;

        // Admin tools are best-effort: the proxy still starts without them.
        if let Err(e) = crate::admin_tools::register_admin_tools(
            &proxy_for_caller,
            admin_state,
            session_handle.clone(),
            &config,
            discovery_tools,
        )
        .await
        {
            tracing::warn!("Failed to register admin tools: {e}");
        } else {
            tracing::info!("MCP admin tools registered under proxy/ namespace");
        }

        Ok(Self {
            router,
            session_handle,
            inner: proxy_for_caller,
            config,
            #[cfg(feature = "discovery")]
            discovery_index,
        })
    }

    /// Returns the session handle of the HTTP transport.
    pub fn session_handle(&self) -> &SessionHandle {
        &self.session_handle
    }

    /// Returns the underlying `McpProxy`.
    pub fn mcp_proxy(&self) -> &McpProxy {
        &self.inner
    }

    /// Spawns the config-file watcher for `config_path`.
    ///
    /// The watcher receives a clone of the underlying proxy and, when the `discovery`
    /// feature is compiled in and an index exists, the discovery index together with the
    /// namespace separator.
    pub fn enable_hot_reload(&self, config_path: std::path::PathBuf) {
        tracing::info!("Hot reload enabled, watching config file for changes");
        crate::reload::spawn_config_watcher(
            config_path,
            self.inner.clone(),
            #[cfg(feature = "discovery")]
            self.discovery_index
                .as_ref()
                .map(|idx| (idx.clone(), self.config.proxy.separator.clone())),
        );
    }

    /// Consumes the proxy and returns its router and session handle, for callers that
    /// want to serve the router themselves.
    pub fn into_router(self) -> (Router, SessionHandle) {
        (self.router, self.session_handle)
    }

    /// Binds the configured `host:port` and serves the router until a shutdown signal
    /// (see `shutdown_signal`) arrives and axum's graceful shutdown completes.
    ///
    /// # Errors
    ///
    /// Returns an error if the address cannot be bound or the server fails.
    pub async fn serve(self) -> Result<()> {
        let addr = format!(
            "{}:{}",
            self.config.proxy.listen.host, self.config.proxy.listen.port
        );
        let listener = tokio::net::TcpListener::bind(&addr)
            .await
            .with_context(|| format!("binding to {}", addr))?;
        // Announce readiness only once the socket is actually bound, so a failed bind
        // is never preceded by a misleading "ready" log line.
        tracing::info!(listen = %addr, "Proxy ready");

        let shutdown_timeout = Duration::from_secs(self.config.proxy.shutdown_timeout_seconds);
        axum::serve(listener, self.router)
            .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
            .await
            .context("server error")?;
        tracing::info!("Proxy shut down");
        Ok(())
    }
}
/// Handle to a backend's circuit breaker, as produced alongside the layer by
/// `CircuitBreakerLayer::builder()...build_with_handle()`.
pub type CbHandle = tower_resilience::circuitbreaker::CircuitBreakerHandle;
/// Connects every configured backend and assembles the `McpProxy`.
///
/// For each backend the transport is created first (stdio child process, HTTP, or
/// WebSocket) and registered with the builder; then the optional per-backend layers are
/// attached in this order: retry, hedging, concurrency limit, rate limit, timeout,
/// circuit breaker, outlier detection.
///
/// Returns the proxy together with the circuit-breaker handles, keyed by backend name
/// (only backends with a `circuit_breaker` section have an entry). Backends the builder
/// reports as skipped are logged as warnings.
///
/// # Errors
///
/// Returns an error if a stdio backend has no `command`, an HTTP/WebSocket backend has
/// no `url`, a child process cannot be spawned, a WebSocket connection fails, a
/// WebSocket backend is configured without the `websocket` feature, or the final
/// `build()` fails.
async fn build_mcp_proxy(config: &ProxyConfig) -> Result<(McpProxy, HashMap<String, CbHandle>)> {
    let mut builder = McpProxy::builder(&config.proxy.name, &config.proxy.version)
        .separator(&config.proxy.separator);
    let mut cb_handles: HashMap<String, CbHandle> = HashMap::new();

    if let Some(instructions) = &config.proxy.instructions {
        builder = builder.instructions(instructions);
    }

    // One detector is shared by every backend that enables outlier detection; it is
    // created with the largest `max_ejection_percent` found among them.
    let outlier_detector = config
        .backends
        .iter()
        .filter_map(|b| b.outlier_detection.as_ref())
        .map(|od| od.max_ejection_percent)
        .max()
        .map(crate::outlier::OutlierDetector::new);

    for backend in &config.backends {
        tracing::info!(name = %backend.name, transport = ?backend.transport, "Adding backend");

        match backend.transport {
            TransportType::Stdio => {
                // Missing config is reported as an error instead of panicking.
                let command = backend.command.as_deref().with_context(|| {
                    format!(
                        "backend '{}': stdio transport requires 'command'",
                        backend.name
                    )
                })?;
                let mut cmd = Command::new(command);
                cmd.args(backend.args.iter().map(|s| s.as_str()));
                for (key, value) in &backend.env {
                    cmd.env(key, value);
                }
                let transport = StdioClientTransport::spawn_command(&mut cmd)
                    .await
                    .with_context(|| format!("spawning backend '{}'", backend.name))?;
                builder = builder.backend(&backend.name, transport).await;
            }
            TransportType::Http => {
                let url = backend.url.as_deref().with_context(|| {
                    format!("backend '{}': http transport requires 'url'", backend.name)
                })?;
                let mut transport = tower_mcp::client::HttpClientTransport::new(url);
                if let Some(token) = &backend.bearer_token {
                    transport = transport.bearer_token(token);
                }
                builder = builder.backend(&backend.name, transport).await;
            }
            #[cfg(feature = "websocket")]
            TransportType::Websocket => {
                let url = backend.url.as_deref().with_context(|| {
                    format!(
                        "backend '{}': websocket transport requires 'url'",
                        backend.name
                    )
                })?;
                tracing::info!(url = %url, "Connecting to WebSocket backend");
                let transport = if let Some(token) = &backend.bearer_token {
                    crate::ws_transport::WebSocketClientTransport::connect_with_bearer_token(
                        url, token,
                    )
                    .await
                    .with_context(|| {
                        format!("connecting to WebSocket backend '{}'", backend.name)
                    })?
                } else {
                    crate::ws_transport::WebSocketClientTransport::connect(url)
                        .await
                        .with_context(|| {
                            format!("connecting to WebSocket backend '{}'", backend.name)
                        })?
                };
                builder = builder.backend(&backend.name, transport).await;
            }
            #[cfg(not(feature = "websocket"))]
            TransportType::Websocket => {
                anyhow::bail!(
                    "WebSocket transport requires the 'websocket' feature. \
                     Rebuild with: cargo install mcp-proxy --features websocket"
                );
            }
        }

        if let Some(retry_cfg) = &backend.retry {
            tracing::info!(
                backend = %backend.name,
                max_retries = retry_cfg.max_retries,
                initial_backoff_ms = retry_cfg.initial_backoff_ms,
                max_backoff_ms = retry_cfg.max_backoff_ms,
                "Applying retry policy"
            );
            let layer = crate::retry::build_retry_layer(retry_cfg, &backend.name);
            builder = builder.backend_layer(layer);
        }

        if let Some(hedge_cfg) = &backend.hedging {
            let delay = Duration::from_millis(hedge_cfg.delay_ms);
            // Total attempts = the original request plus the configured hedges.
            // Saturating so an extreme `max_hedges` cannot overflow.
            let max_attempts = hedge_cfg.max_hedges.saturating_add(1);
            tracing::info!(
                backend = %backend.name,
                delay_ms = hedge_cfg.delay_ms,
                max_hedges = hedge_cfg.max_hedges,
                "Applying request hedging"
            );
            // A zero delay selects the builder's `no_delay()` mode instead of `delay(0)`.
            let layer = if delay.is_zero() {
                tower_resilience::hedge::HedgeLayer::builder()
                    .no_delay()
                    .max_hedged_attempts(max_attempts)
                    .name(format!("{}-hedge", backend.name))
                    .build()
            } else {
                tower_resilience::hedge::HedgeLayer::builder()
                    .delay(delay)
                    .max_hedged_attempts(max_attempts)
                    .name(format!("{}-hedge", backend.name))
                    .build()
            };
            builder = builder.backend_layer(layer);
        }

        if let Some(cc) = &backend.concurrency {
            tracing::info!(
                backend = %backend.name,
                max = cc.max_concurrent,
                "Applying concurrency limit"
            );
            builder =
                builder.backend_layer(tower::limit::ConcurrencyLimitLayer::new(cc.max_concurrent));
        }

        if let Some(rl) = &backend.rate_limit {
            tracing::info!(
                backend = %backend.name,
                requests = rl.requests,
                period_seconds = rl.period_seconds,
                "Applying rate limit"
            );
            let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
                .limit_for_period(rl.requests)
                .refresh_period(Duration::from_secs(rl.period_seconds))
                .name(format!("{}-ratelimit", backend.name))
                .build();
            builder = builder.backend_layer(layer);
        }

        if let Some(timeout) = &backend.timeout {
            tracing::info!(
                backend = %backend.name,
                seconds = timeout.seconds,
                "Applying timeout"
            );
            builder =
                builder.backend_layer(TimeoutLayer::new(Duration::from_secs(timeout.seconds)));
        }

        if let Some(cb) = &backend.circuit_breaker {
            tracing::info!(
                backend = %backend.name,
                failure_rate = cb.failure_rate_threshold,
                wait_seconds = cb.wait_duration_seconds,
                "Applying circuit breaker"
            );
            let (layer, handle) = tower_resilience::circuitbreaker::CircuitBreakerLayer::builder()
                .failure_rate_threshold(cb.failure_rate_threshold)
                .minimum_number_of_calls(cb.minimum_calls)
                .wait_duration_in_open(Duration::from_secs(cb.wait_duration_seconds))
                .permitted_calls_in_half_open(cb.permitted_calls_in_half_open)
                .name(format!("{}-cb", backend.name))
                .build_with_handle();
            // The handle is kept so the caller can hand it to the admin API.
            cb_handles.insert(backend.name.clone(), handle);
            builder = builder.backend_layer(layer);
        }

        if let Some(od) = &backend.outlier_detection
            && let Some(ref detector) = outlier_detector
        {
            tracing::info!(
                backend = %backend.name,
                consecutive_errors = od.consecutive_errors,
                base_ejection_seconds = od.base_ejection_seconds,
                max_ejection_percent = od.max_ejection_percent,
                "Applying outlier detection"
            );
            let layer = crate::outlier::OutlierDetectionLayer::new(
                backend.name.clone(),
                od.clone(),
                detector.clone(),
            );
            builder = builder.backend_layer(layer);
        }
    }

    let result = builder.build().await?;
    for s in &result.skipped {
        tracing::warn!("Skipped backend: {s}");
    }
    Ok((result.proxy, cb_handles))
}
/// Wraps `proxy` in the configured middleware and returns the boxed service together
/// with the cache handle (`Some` only when at least one backend configures a cache).
///
/// Every enabled feature re-assigns `service` to a new wrapper around the previous
/// value, so wrappers applied later in this function enclose the ones applied earlier.
/// The statement order below therefore defines the layering order and must not be
/// changed casually.
///
/// # Errors
///
/// Returns an error if building a backend's capability filter fails.
fn build_middleware_stack(
config: &ProxyConfig,
proxy: McpProxy,
) -> Result<(
BoxCloneService<RouterRequest, RouterResponse, Infallible>,
Option<cache::CacheHandle>,
)> {
let mut service: BoxCloneService<RouterRequest, RouterResponse, Infallible> =
BoxCloneService::new(proxy);
let mut cache_handle: Option<cache::CacheHandle> = None;
// Argument injection: one rule set per backend that has default or per-tool
// injected arguments, keyed by the backend's namespace prefix (name + separator).
let injection_rules: Vec<_> = config
.backends
.iter()
.filter(|b| !b.default_args.is_empty() || !b.inject_args.is_empty())
.map(|b| {
let namespace = format!("{}{}", b.name, config.proxy.separator);
tracing::info!(
backend = %b.name,
default_args = b.default_args.len(),
tool_rules = b.inject_args.len(),
"Applying argument injection"
);
crate::inject::InjectionRules::new(
namespace,
b.default_args.clone(),
b.inject_args.clone(),
)
})
.collect();
if !injection_rules.is_empty() {
service = BoxCloneService::new(crate::inject::InjectArgsService::new(
service,
injection_rules,
));
}
// Parameter overrides, flattened across all backends into a single list.
let param_overrides: Vec<_> = config
.backends
.iter()
.filter(|b| !b.param_overrides.is_empty())
.flat_map(|b| {
let namespace = format!("{}{}", b.name, config.proxy.separator);
tracing::info!(
backend = %b.name,
overrides = b.param_overrides.len(),
"Applying parameter overrides"
);
b.param_overrides
.iter()
.map(move |c| crate::param_override::ToolOverride::new(&namespace, c))
})
.collect();
if !param_overrides.is_empty() {
service = BoxCloneService::new(crate::param_override::ParamOverrideService::new(
service,
param_overrides,
));
}
// Canary routing: primary name -> (canary name, primary weight, canary weight).
// A primary that is not found among the backends falls back to weight 100.
// The map is keyed by primary name, so if several backends declare the same
// `canary_of`, the last one collected replaces the earlier ones.
let canary_mappings: std::collections::HashMap<String, (String, u32, u32)> = config
.backends
.iter()
.filter_map(|b| {
b.canary_of.as_ref().map(|primary_name| {
let primary_weight = config
.backends
.iter()
.find(|p| p.name == *primary_name)
.map(|p| p.weight)
.unwrap_or(100);
(
primary_name.clone(),
(b.name.clone(), primary_weight, b.weight),
)
})
})
.collect();
if !canary_mappings.is_empty() {
for (primary, (canary, pw, cw)) in &canary_mappings {
tracing::info!(
primary = %primary,
canary = %canary,
primary_weight = pw,
canary_weight = cw,
"Enabling canary routing"
);
}
service = BoxCloneService::new(crate::canary::CanaryService::new(
service,
canary_mappings,
&config.proxy.separator,
));
}
// Failover routing: group failover backends by the primary they cover, then order
// each group by ascending `priority`.
let mut failover_groups: std::collections::HashMap<String, Vec<(u32, String)>> =
std::collections::HashMap::new();
for b in &config.backends {
if let Some(ref primary) = b.failover_for {
failover_groups
.entry(primary.clone())
.or_default()
.push((b.priority, b.name.clone()));
}
}
let failover_mappings: std::collections::HashMap<String, Vec<String>> = failover_groups
.into_iter()
.map(|(primary, mut backends)| {
backends.sort_by_key(|(priority, _)| *priority);
let names: Vec<String> = backends.into_iter().map(|(_, name)| name).collect();
(primary, names)
})
.collect();
if !failover_mappings.is_empty() {
for (primary, failovers) in &failover_mappings {
tracing::info!(
primary = %primary,
failovers = ?failovers,
"Enabling failover routing"
);
}
service = BoxCloneService::new(crate::failover::FailoverService::new(
service,
failover_mappings,
&config.proxy.separator,
));
}
// Traffic mirroring: source name -> (mirror name, mirror percent). Keyed by source,
// so a later mirror of the same source replaces an earlier one.
let mirror_mappings: std::collections::HashMap<String, (String, u32)> = config
.backends
.iter()
.filter_map(|b| {
b.mirror_of
.as_ref()
.map(|source| (source.clone(), (b.name.clone(), b.mirror_percent)))
})
.collect();
if !mirror_mappings.is_empty() {
for (source, (mirror, pct)) in &mirror_mappings {
tracing::info!(
source = %source,
mirror = %mirror,
percent = pct,
"Enabling traffic mirroring"
);
}
service = BoxCloneService::new(crate::mirror::MirrorService::new(
service,
mirror_mappings,
&config.proxy.separator,
));
}
// Response cache: (namespace prefix, cache config) for every backend with a cache
// section. This is the only place `cache_handle` is set.
let cache_configs: Vec<_> = config
.backends
.iter()
.filter_map(|b| {
b.cache
.as_ref()
.map(|c| (format!("{}{}", b.name, config.proxy.separator), c))
})
.collect();
if !cache_configs.is_empty() {
for (ns, cfg) in &cache_configs {
tracing::info!(
backend = %ns.trim_end_matches(&config.proxy.separator),
resource_ttl = cfg.resource_ttl_seconds,
tool_ttl = cfg.tool_ttl_seconds,
max_entries = cfg.max_entries,
"Applying response cache"
);
}
let (cache_svc, handle) = cache::CacheService::new(service, cache_configs, &config.cache);
service = BoxCloneService::new(cache_svc);
cache_handle = Some(handle);
}
// Request coalescing.
if config.performance.coalesce_requests {
tracing::info!("Request coalescing enabled");
service = BoxCloneService::new(coalesce::CoalesceService::new(service));
}
// Request validation, enabled only when a maximum argument size is configured.
if config.security.max_argument_size.is_some() {
let validation = ValidationConfig {
max_argument_size: config.security.max_argument_size,
};
if let Some(max) = validation.max_argument_size {
tracing::info!(max_argument_size = max, "Applying request validation");
}
service = BoxCloneService::new(ValidationService::new(service, validation));
}
// Capability filters: `build_filter` yields `Result<Option<_>>`; `transpose` drops
// backends without a filter and the first error aborts the whole stack.
let filters: Vec<_> = config
.backends
.iter()
.filter_map(|b| b.build_filter(&config.proxy.separator).transpose())
.collect::<anyhow::Result<Vec<_>>>()?;
if !filters.is_empty() {
for f in &filters {
tracing::info!(
backend = %f.namespace.trim_end_matches(&config.proxy.separator),
tool_filter = ?f.tool_filter,
resource_filter = ?f.resource_filter,
prompt_filter = ?f.prompt_filter,
"Applying capability filter"
);
}
service = BoxCloneService::new(CapabilityFilterService::new(service, filters));
}
// Search-mode tool exposure: filter on the `proxy` + separator prefix.
if config.proxy.tool_exposure == crate::config::ToolExposure::Search {
let prefix = format!("proxy{}", config.proxy.separator);
tracing::info!(
prefix = %prefix,
"Search mode: ListTools will only show proxy/ namespace tools"
);
service =
BoxCloneService::new(crate::filter::SearchModeFilterService::new(service, prefix));
}
// Tool aliases: (namespace prefix, from, to) triples. The layer is skipped when
// `AliasMap::new` returns `None`.
let alias_mappings: Vec<_> = config
.backends
.iter()
.flat_map(|b| {
let ns = format!("{}{}", b.name, config.proxy.separator);
b.aliases
.iter()
.map(move |a| (ns.clone(), a.from.clone(), a.to.clone()))
})
.collect();
if let Some(alias_map) = alias::AliasMap::new(alias_mappings) {
let count = alias_map.forward.len();
tracing::info!(aliases = count, "Applying tool aliases");
service = BoxCloneService::new(alias::AliasService::new(service, alias_map));
}
// Composite tools.
if !config.composite_tools.is_empty() {
let count = config.composite_tools.len();
tracing::info!(composite_tools = count, "Applying composite tool fan-out");
service = BoxCloneService::new(crate::composite::CompositeService::new(
service,
config.composite_tools.clone(),
));
}
// Bearer token scoping: only when bearer auth is configured with at least one
// scoped token.
#[cfg(feature = "oauth")]
if matches!(
&config.auth,
Some(AuthConfig::Bearer {
scoped_tokens,
..
}) if !scoped_tokens.is_empty()
) {
tracing::info!("Enabling bearer token scoping middleware");
service = BoxCloneService::new(crate::bearer_scope::BearerScopingService::new(service));
}
#[cfg(feature = "oauth")]
{
// RBAC: requires JWT or OAuth auth with a role mapping and a non-empty role list.
let rbac_config = match &config.auth {
Some(
AuthConfig::Jwt {
roles,
role_mapping: Some(mapping),
..
}
| AuthConfig::OAuth {
roles,
role_mapping: Some(mapping),
..
},
) if !roles.is_empty() => {
tracing::info!(
roles = roles.len(),
claim = %mapping.claim,
"Enabling RBAC"
);
Some(RbacConfig::new(roles, mapping))
}
_ => None,
};
if let Some(rbac) = rbac_config {
service = BoxCloneService::new(RbacService::new(service, rbac));
}
// Token passthrough: namespace prefixes of the backends with `forward_auth` set.
let forward_namespaces: std::collections::HashSet<String> = config
.backends
.iter()
.filter(|b| b.forward_auth)
.map(|b| format!("{}{}", b.name, config.proxy.separator))
.collect();
if !forward_namespaces.is_empty() {
tracing::info!(
backends = ?forward_namespaces,
"Enabling token passthrough for forward_auth backends"
);
service = BoxCloneService::new(crate::token::TokenPassthroughService::new(
service,
forward_namespaces,
));
}
}
// Observability wrappers: metrics, access log, audit.
#[cfg(feature = "metrics")]
if config.observability.metrics.enabled {
service = BoxCloneService::new(crate::metrics::MetricsService::new(service));
}
if config.observability.access_log.enabled {
tracing::info!("Access logging enabled (target: mcp::access)");
service = BoxCloneService::new(crate::access_log::AccessLogService::new(
service,
&config.proxy.separator,
));
}
if config.observability.audit {
tracing::info!("Audit logging enabled (target: mcp::audit)");
let audited = tower::Layer::layer(&tower_mcp::AuditLayer::new(), service);
// NOTE(review): `CatchError` presumably converts the layer's error type back to
// `Infallible` so the result still fits the boxed service type — confirm.
service = BoxCloneService::new(tower_mcp::CatchError::new(audited));
}
// Global rate limit: applied last, so it encloses every wrapper above.
if let Some(ref rl) = config.proxy.rate_limit {
tracing::info!(
requests = rl.requests,
period_seconds = rl.period_seconds,
"Applying global rate limit"
);
let layer = tower_resilience::ratelimiter::RateLimiterLayer::builder()
.limit_for_period(rl.requests)
.refresh_period(Duration::from_secs(rl.period_seconds))
.name("global-ratelimit")
.build();
let limited = tower::Layer::layer(&layer, service);
service = BoxCloneService::new(tower_mcp::CatchError::new(limited));
}
Ok((service, cache_handle))
}
/// Layers the configured inbound authentication onto `router`.
///
/// * No `auth` section: the router is returned unchanged.
/// * `Bearer`: static bearer tokens; when scoped tokens are present the scoped auth
///   layer is used instead (requires the `oauth` feature).
/// * `Jwt`: JWKS-based JWT validation (requires the `oauth` feature).
/// * `OAuth`: discovers the authorization server, then validates tokens by JWT,
///   introspection, or JWT with introspection fallback, depending on
///   `token_validation` (requires the `oauth` feature).
///
/// Each OAuth strategy only requires the settings it actually uses: the JWKS URI is
/// needed for `Jwt` and `Both`; the introspection endpoint and client credentials are
/// needed for `Introspection` and `Both`.
///
/// # Errors
///
/// Returns an error if a required feature is not compiled in, authorization-server
/// discovery fails, a validator cannot be built, or a setting required by the selected
/// strategy (JWKS URI, introspection endpoint, `client_id`, `client_secret`) is missing.
async fn apply_auth(config: &ProxyConfig, router: Router) -> Result<Router> {
    let router = if let Some(auth) = &config.auth {
        match auth {
            AuthConfig::Bearer {
                tokens,
                scoped_tokens,
            } => {
                let total = tokens.len() + scoped_tokens.len();
                if scoped_tokens.is_empty() {
                    tracing::info!(token_count = total, "Enabling bearer token auth");
                    let validator = StaticBearerValidator::new(tokens.iter().cloned());
                    let layer = AuthLayer::new(validator);
                    router.layer(layer)
                } else {
                    #[cfg(feature = "oauth")]
                    {
                        tracing::info!(
                            token_count = total,
                            scoped = scoped_tokens.len(),
                            "Enabling bearer token auth with per-token scoping"
                        );
                        let layer =
                            crate::bearer_scope::ScopedBearerAuthLayer::new(tokens, scoped_tokens);
                        router.layer(layer)
                    }
                    #[cfg(not(feature = "oauth"))]
                    {
                        anyhow::bail!(
                            "Per-token tool scoping requires the 'oauth' feature. \
                             Rebuild with: cargo install mcp-proxy --features oauth"
                        );
                    }
                }
            }
            #[cfg(feature = "oauth")]
            AuthConfig::Jwt {
                issuer,
                audience,
                jwks_uri,
                ..
            } => {
                tracing::info!(
                    issuer = %issuer,
                    audience = %audience,
                    jwks_uri = %jwks_uri,
                    "Enabling JWT auth (JWKS)"
                );
                let validator = tower_mcp::oauth::JwksValidator::builder(jwks_uri)
                    .expected_audience(audience)
                    .expected_issuer(issuer)
                    .build()
                    .await
                    .context("building JWKS validator")?;
                let addr = format!(
                    "http://{}:{}",
                    config.proxy.listen.host, config.proxy.listen.port
                );
                let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
                    .authorization_server(issuer);
                let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
                router.layer(layer)
            }
            #[cfg(not(feature = "oauth"))]
            AuthConfig::Jwt { .. } => {
                anyhow::bail!(
                    "JWT auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
                );
            }
            #[cfg(feature = "oauth")]
            AuthConfig::OAuth {
                issuer,
                audience,
                token_validation,
                jwks_uri,
                introspection_endpoint,
                client_id,
                client_secret,
                ..
            } => {
                use crate::config::TokenValidationStrategy;
                tracing::info!(
                    issuer = %issuer,
                    audience = %audience,
                    strategy = ?token_validation,
                    "Enabling OAuth 2.1 auth"
                );
                let discovered = crate::introspection::discover_auth_server(issuer)
                    .await
                    .context("discovering OAuth authorization server")?;

                // Explicit configuration wins over discovered metadata. Both values stay
                // optional here: whether they are required depends on the strategy, so
                // an introspection-only setup must not fail for lack of a JWKS URI.
                let effective_jwks_uri = jwks_uri
                    .as_deref()
                    .or(discovered.jwks_uri.as_deref());
                let effective_introspection = introspection_endpoint
                    .as_deref()
                    .or(discovered.introspection_endpoint.as_deref());
                let missing_jwks_uri = || {
                    anyhow::anyhow!("JWKS URI not found via discovery and not configured manually")
                };

                // Shared by the `Introspection` and `Both` strategies. Missing settings
                // are reported as errors rather than panicking.
                let build_introspection_validator = || -> Result<_> {
                    let endpoint = effective_introspection.ok_or_else(|| {
                        anyhow::anyhow!(
                            "introspection endpoint not found via discovery and not configured"
                        )
                    })?;
                    let id = client_id
                        .as_deref()
                        .context("token introspection requires 'client_id' to be configured")?;
                    let secret = client_secret
                        .as_deref()
                        .context("token introspection requires 'client_secret' to be configured")?;
                    Ok(
                        crate::introspection::IntrospectionValidator::new(endpoint, id, secret)
                            .expected_audience(audience),
                    )
                };

                let addr = format!(
                    "http://{}:{}",
                    config.proxy.listen.host, config.proxy.listen.port
                );
                let metadata = tower_mcp::oauth::ProtectedResourceMetadata::new(&addr)
                    .authorization_server(issuer);

                match token_validation {
                    TokenValidationStrategy::Jwt => {
                        let jwks = effective_jwks_uri.ok_or_else(missing_jwks_uri)?;
                        let validator = tower_mcp::oauth::JwksValidator::builder(jwks)
                            .expected_audience(audience)
                            .expected_issuer(issuer)
                            .build()
                            .await
                            .context("building JWKS validator")?;
                        let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
                        router.layer(layer)
                    }
                    TokenValidationStrategy::Introspection => {
                        let validator = build_introspection_validator()?;
                        let layer = tower_mcp::oauth::OAuthLayer::new(validator, metadata);
                        router.layer(layer)
                    }
                    TokenValidationStrategy::Both => {
                        // Check all configuration before the JWKS validator is built, so
                        // a misconfiguration fails fast.
                        let jwks = effective_jwks_uri.ok_or_else(missing_jwks_uri)?;
                        let introspection_validator = build_introspection_validator()?;
                        let jwt_validator = tower_mcp::oauth::JwksValidator::builder(jwks)
                            .expected_audience(audience)
                            .expected_issuer(issuer)
                            .build()
                            .await
                            .context("building JWKS validator")?;
                        let fallback = crate::introspection::FallbackValidator::new(
                            jwt_validator,
                            introspection_validator,
                        );
                        let layer = tower_mcp::oauth::OAuthLayer::new(fallback, metadata);
                        router.layer(layer)
                    }
                }
            }
            #[cfg(not(feature = "oauth"))]
            AuthConfig::OAuth { .. } => {
                anyhow::bail!(
                    "OAuth auth requires the 'oauth' feature. Rebuild with: cargo install mcp-proxy --features oauth"
                );
            }
        }
    } else {
        router
    };
    Ok(router)
}
/// Completes once the process has been asked to stop.
///
/// On Unix this is whichever comes first of Ctrl-C and SIGTERM; on other platforms
/// only Ctrl-C is awaited. `timeout` is only reported in the log line emitted here;
/// this function does not enforce it.
///
/// # Panics
///
/// On Unix, panics if the SIGTERM handler cannot be installed.
pub async fn shutdown_signal(timeout: Duration) {
    let interrupt = tokio::signal::ctrl_c();

    #[cfg(unix)]
    {
        use tokio::signal::unix::{SignalKind, signal};

        let mut terminate = signal(SignalKind::terminate()).expect("SIGTERM handler");
        tokio::select! {
            _ = interrupt => {},
            _ = terminate.recv() => {},
        }
    }

    #[cfg(not(unix))]
    {
        interrupt.await.ok();
    }

    let timeout_seconds = timeout.as_secs();
    tracing::info!(
        timeout_seconds,
        "Shutdown signal received, draining connections"
    );
}