use std::{
future::Future,
net::SocketAddr,
path::{Path, PathBuf},
pin::Pin,
sync::Arc,
time::Duration,
};
use arc_swap::ArcSwap;
use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
use rmcp::{
ServerHandler,
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use rustls::RootCertStore;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use crate::{
auth::{
AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
build_rate_limiter, extract_mtls_identity,
},
error::McpxError,
mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
};
#[allow(
clippy::needless_pass_by_value,
reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
)]
fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
McpxError::Startup(format!("{e:#}"))
}
#[allow(
clippy::needless_pass_by_value,
reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
)]
fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
McpxError::Startup(format!("{op}: {e}"))
}
/// Shared async callback backing the `/readyz` endpoint: each invocation
/// produces a JSON payload describing readiness. `Arc` + boxed/pinned future
/// so it can be stored in [`McpServerConfig`] and cloned into route handlers.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
/// Top-level configuration for the MCP HTTP server.
///
/// Construct with [`McpServerConfig::new`], customize via the `with_*` /
/// `enable_*` builder methods, then call [`McpServerConfig::validate`] to
/// obtain the [`Validated`] wrapper required by `serve` /
/// `serve_with_listener`. All fields remain `pub` for backward compatibility
/// but are deprecated and slated to become `pub(crate)` in 1.0.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Socket address to bind, e.g. `127.0.0.1:8080` (validated in `check()`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in 1.0"
    )]
    pub bind_addr: String,
    /// Server name reported via `/version` and startup logs.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
    )]
    pub name: String,
    /// Server version reported via `/version`.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
    )]
    pub version: String,
    /// PEM certificate path; must be set together with `tls_key_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// PEM private key path; must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings; `None` (or `enabled == false`) disables auth.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in 1.0"
    )]
    pub auth: Option<AuthConfig>,
    /// Initial RBAC policy; defaults to `RbacPolicy::disabled()` when `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in 1.0"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// CORS/origin-check allow-list; entries must start with http:// or https://.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in 1.0"
    )]
    pub allowed_origins: Vec<String>,
    /// Per-IP tool-call rate limit (calls per minute), when set.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in 1.0"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Custom `/readyz` callback; falls back to the plain health handler.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in 1.0"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Request body size cap in bytes; must be non-zero.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in 1.0"
    )]
    pub max_request_body: usize,
    /// Per-request timeout applied to the `/mcp` subtree.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub request_timeout: Duration,
    /// Grace period for graceful shutdown before the server force-exits.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub shutdown_timeout: Duration,
    /// Keep-alive window for idle MCP sessions (local session manager).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub session_idle_timeout: Duration,
    /// Interval between SSE keep-alive pings.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in 1.0"
    )]
    pub sse_keep_alive: Duration,
    /// Called once with a [`ReloadHandle`] when the server is set up.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in 1.0"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Extra application routes merged into the assembled router.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in 1.0"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL; also seeds allowed hosts/origins.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in 1.0"
    )]
    pub public_url: Option<String>,
    /// When true, request headers are logged by the origin-check middleware.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in 1.0"
    )]
    pub log_request_headers: bool,
    /// Toggle for gzip/brotli response compression.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
    )]
    pub compression_enabled: bool,
    /// Minimum response size (bytes) before compression kicks in.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
    )]
    pub compression_min_size: u16,
    /// Global in-flight request cap (load-shedding) when set.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in 1.0"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Toggle for the `/admin/*` endpoints (requires auth, see `check()`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
    )]
    pub admin_enabled: bool,
    /// Role required to access the admin endpoints.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
    )]
    pub admin_role: String,
    /// Toggle for the separate metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
    )]
    pub metrics_enabled: bool,
    /// Bind address for the metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
    )]
    pub metrics_bind: String,
}
/// Marker newtype proving a configuration passed [`McpServerConfig::validate`].
/// The field is private, so code outside this module cannot forge one without
/// going through validation.
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);
impl<T> std::fmt::Debug for Validated<T> {
    /// Deliberately opaque: renders as `Validated { .. }` so the wrapped
    /// value can never leak into logs, regardless of whether `T: Debug`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.debug_struct("Validated").finish_non_exhaustive()
    }
}
impl<T> Validated<T> {
#[must_use]
pub fn as_inner(&self) -> &T {
&self.0
}
#[must_use]
pub fn into_inner(self) -> T {
self.0
}
}
impl<T> std::ops::Deref for Validated<T> {
type Target = T;
fn deref(&self) -> &T {
&self.0
}
}
#[allow(
    deprecated,
    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
)]
impl McpServerConfig {
    /// Create a configuration with conservative defaults: 1 MiB body limit,
    /// 2 min request timeout, 30 s shutdown grace, 20 min session idle
    /// timeout, 15 s SSE keep-alive; TLS, auth, RBAC, compression, admin, and
    /// metrics all disabled.
    #[must_use]
    pub fn new(
        bind_addr: impl Into<String>,
        name: impl Into<String>,
        version: impl Into<String>,
    ) -> Self {
        Self {
            bind_addr: bind_addr.into(),
            name: name.into(),
            version: version.into(),
            tls_cert_path: None,
            tls_key_path: None,
            auth: None,
            rbac: None,
            allowed_origins: Vec::new(),
            tool_rate_limit: None,
            readiness_check: None,
            max_request_body: 1024 * 1024,
            request_timeout: Duration::from_mins(2),
            shutdown_timeout: Duration::from_secs(30),
            session_idle_timeout: Duration::from_mins(20),
            sse_keep_alive: Duration::from_secs(15),
            on_reload_ready: None,
            extra_router: None,
            public_url: None,
            log_request_headers: false,
            compression_enabled: false,
            compression_min_size: 1024,
            max_concurrent_requests: None,
            admin_enabled: false,
            admin_role: "admin".to_owned(),
            #[cfg(feature = "metrics")]
            metrics_enabled: false,
            #[cfg(feature = "metrics")]
            metrics_bind: "127.0.0.1:9090".into(),
        }
    }

    /// Enable authentication with the given configuration.
    #[must_use]
    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
        self.auth = Some(auth);
        self
    }

    /// Override the bind address (e.g. `127.0.0.1:8080`).
    #[must_use]
    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
        self.bind_addr = addr.into();
        self
    }

    /// Install an initial RBAC policy (hot-swappable via [`ReloadHandle`]).
    #[must_use]
    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
        self.rbac = Some(rbac);
        self
    }

    /// Enable TLS with the given PEM certificate and private key paths.
    #[must_use]
    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
        self.tls_cert_path = Some(cert_path.into());
        self.tls_key_path = Some(key_path.into());
        self
    }

    /// Set the externally visible base URL (must start with http:// or
    /// https://; also seeds allowed hosts and, when origins are empty, the
    /// derived allowed origin).
    #[must_use]
    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
        self.public_url = Some(url.into());
        self
    }

    /// Replace the CORS/origin allow-list.
    #[must_use]
    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allowed_origins = origins.into_iter().map(Into::into).collect();
        self
    }

    /// Merge additional application routes into the assembled router.
    #[must_use]
    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
        self.extra_router = Some(router);
        self
    }

    /// Install a custom `/readyz` callback.
    #[must_use]
    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
        self.readiness_check = Some(check);
        self
    }

    /// Set the request body size cap in bytes (must be non-zero).
    #[must_use]
    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
        self.max_request_body = bytes;
        self
    }

    /// Set the per-request timeout for the `/mcp` subtree.
    #[must_use]
    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Set the graceful-shutdown grace period.
    #[must_use]
    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
        self.shutdown_timeout = timeout;
        self
    }

    /// Set how long an idle MCP session is kept alive.
    #[must_use]
    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
        self.session_idle_timeout = timeout;
        self
    }

    /// Set the SSE keep-alive ping interval.
    #[must_use]
    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
        self.sse_keep_alive = interval;
        self
    }

    /// Cap the number of concurrently processed requests (load-shedding).
    #[must_use]
    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
        self.max_concurrent_requests = Some(limit);
        self
    }

    /// Limit tool calls per minute, per client IP.
    #[must_use]
    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
        self.tool_rate_limit = Some(per_minute);
        self
    }

    /// Register a callback invoked once with a [`ReloadHandle`] so the host
    /// application can hot-reload keys, RBAC, and CRLs at runtime.
    #[must_use]
    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
    where
        F: FnOnce(ReloadHandle) + Send + 'static,
    {
        self.on_reload_ready = Some(Box::new(callback));
        self
    }

    /// Enable gzip/brotli response compression for responses of at least
    /// `min_size` bytes.
    #[must_use]
    pub fn enable_compression(mut self, min_size: u16) -> Self {
        self.compression_enabled = true;
        self.compression_min_size = min_size;
        self
    }

    /// Enable the `/admin/*` endpoints, restricted to `role`. Requires auth
    /// to be configured and enabled (enforced by [`Self::validate`]).
    #[must_use]
    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
        self.admin_enabled = true;
        self.admin_role = role.into();
        self
    }

    /// Log request headers in the origin-check middleware (diagnostics).
    #[must_use]
    pub fn enable_request_header_logging(mut self) -> Self {
        self.log_request_headers = true;
        self
    }

    /// Enable the metrics listener on the given bind address.
    #[cfg(feature = "metrics")]
    #[must_use]
    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
        self.metrics_enabled = true;
        self.metrics_bind = bind.into();
        self
    }

    /// Run all consistency checks and wrap the configuration in
    /// [`Validated`], the type required by the serve entry points.
    ///
    /// # Errors
    /// Returns [`McpxError::Config`] describing the first failed check.
    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
        self.check()?;
        Ok(Validated(self))
    }

    /// Internal consistency checks shared by [`Self::validate`].
    fn check(&self) -> Result<(), McpxError> {
        // Admin endpoints are gated on authentication; without it they would
        // be reachable by anyone.
        if self.admin_enabled {
            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
            if !auth_enabled {
                return Err(McpxError::Config(
                    "admin_enabled=true requires auth to be configured and enabled".into(),
                ));
            }
        }
        // TLS certificate and key must be provided together.
        match (&self.tls_cert_path, &self.tls_key_path) {
            (Some(_), None) => {
                return Err(McpxError::Config(
                    "tls_cert_path is set but tls_key_path is missing".into(),
                ));
            }
            (None, Some(_)) => {
                return Err(McpxError::Config(
                    "tls_key_path is set but tls_cert_path is missing".into(),
                ));
            }
            _ => {}
        }
        if self.bind_addr.parse::<SocketAddr>().is_err() {
            return Err(McpxError::Config(format!(
                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
                self.bind_addr
            )));
        }
        if let Some(ref url) = self.public_url
            && !(url.starts_with("http://") || url.starts_with("https://"))
        {
            return Err(McpxError::Config(format!(
                "public_url {url:?} must start with http:// or https://"
            )));
        }
        for origin in &self.allowed_origins {
            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
                return Err(McpxError::Config(format!(
                    "allowed_origins entry {origin:?} must start with http:// or https://"
                )));
            }
        }
        if self.max_request_body == 0 {
            return Err(McpxError::Config(
                "max_request_body must be greater than zero".into(),
            ));
        }
        // Validate metrics_bind up front: the metrics listener runs in a
        // detached background task, so a bad address would otherwise only
        // surface as an error log at runtime.
        #[cfg(feature = "metrics")]
        if self.metrics_enabled && self.metrics_bind.parse::<SocketAddr>().is_err() {
            return Err(McpxError::Config(format!(
                "metrics_bind {:?} is not a valid socket address (expected e.g. 127.0.0.1:9090)",
                self.metrics_bind
            )));
        }
        #[cfg(feature = "oauth")]
        if let Some(auth_cfg) = &self.auth
            && let Some(oauth_cfg) = &auth_cfg.oauth
        {
            oauth_cfg.validate()?;
        }
        Ok(())
    }
}
/// Handle given to the `on_reload_ready` callback, allowing the host
/// application to hot-swap auth keys, the RBAC policy, and CRLs while the
/// server runs. Each field is `None` when the corresponding subsystem was not
/// configured.
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    /// Live auth state; `None` when auth is disabled.
    auth: Option<Arc<AuthState>>,
    /// Hot-swappable RBAC policy slot.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    /// mTLS CRL state; `None` unless CRL checking is enabled.
    crl_set: Option<Arc<CrlSet>>,
}
impl ReloadHandle {
    /// Swap in a new set of API keys. Silently a no-op when auth was not
    /// configured on this server.
    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
        if let Some(auth) = self.auth.as_ref() {
            auth.reload_keys(keys);
        }
    }

    /// Atomically replace the active RBAC policy. No-op (and no log line)
    /// when RBAC hot-reload was not wired up.
    pub fn reload_rbac(&self, policy: RbacPolicy) {
        if let Some(slot) = self.rbac.as_ref() {
            slot.store(Arc::new(policy));
            tracing::info!("RBAC policy reloaded");
        }
    }

    /// Force an immediate CRL refresh; errors when CRL support is absent.
    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
        match self.crl_set {
            Some(ref crls) => crls.force_refresh().await,
            None => Err(McpxError::Config(
                "CRL refresh requested but mTLS CRL support is not configured".into(),
            )),
        }
    }
}
/// Everything `build_app_router` hands back to the serve entry points besides
/// the router itself: TLS material, shutdown plumbing, and reload state.
///
/// Note: the previous `#[allow(clippy::too_many_lines, cognitive_complexity)]`
/// was removed — both are function-level lints and were inert on this struct.
struct AppRunParams {
    /// Certificate/key paths when TLS is enabled.
    tls_paths: Option<(PathBuf, PathBuf)>,
    /// mTLS client-auth settings, when configured.
    mtls_config: Option<MtlsConfig>,
    /// Grace period before shutdown is forced.
    shutdown_timeout: Duration,
    /// Shared auth state (None when auth is disabled).
    auth_state: Option<Arc<AuthState>>,
    /// Hot-swappable RBAC policy.
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    /// Deferred `ReloadHandle` callback, consumed by `run_server`.
    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Cancellation token wired into the MCP transport and shutdown flow.
    ct: CancellationToken,
    /// "http" or "https", for startup logging.
    scheme: &'static str,
    /// Server name, for startup logging.
    name: String,
}
/// Assemble the complete axum router (MCP endpoint, health/version routes,
/// middleware stack, optional admin/OAuth/metrics wiring) plus the
/// [`AppRunParams`] the serve entry points need to actually run it.
///
/// Layering note: `Router::layer` wraps everything added before it, so
/// later-added layers sit outermost and see requests first.
#[allow(
    clippy::cognitive_complexity,
    reason = "router assembly is intrinsically sequential; splitting harms readability"
)]
#[allow(
    deprecated,
    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
)]
fn build_app_router<H, F>(
    mut config: McpServerConfig,
    handler_factory: F,
) -> anyhow::Result<(axum::Router, AppRunParams)>
where
    H: ServerHandler + 'static,
    F: Fn() -> H + Send + Sync + Clone + 'static,
{
    let ct = CancellationToken::new();
    // Host allow-list handed to the Streamable HTTP transport.
    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
    // MCP service: per-connection handler factory, local session manager with
    // the configured idle timeout, SSE keep-alive, and a child cancellation
    // token so server shutdown tears sessions down.
    let mcp_service = StreamableHttpService::new(
        move || Ok(handler_factory()),
        {
            let mut mgr = LocalSessionManager::default();
            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
            mgr.into()
        },
        StreamableHttpServerConfig::default()
            .with_allowed_hosts(allowed_hosts)
            .with_sse_keep_alive(Some(config.sse_keep_alive))
            .with_cancellation_token(ct.child_token()),
    );
    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
    // Shared auth state is built only when auth is present AND enabled.
    let auth_state: Option<Arc<AuthState>> = match config.auth {
        Some(ref auth_config) if auth_config.enabled => {
            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
            let pre_auth_limiter = auth_config
                .rate_limit
                .as_ref()
                .map(crate::auth::build_pre_auth_limiter);
            #[cfg(feature = "oauth")]
            let jwks_cache = auth_config
                .oauth
                .as_ref()
                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
                .transpose()
                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
            Some(Arc::new(AuthState {
                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
                rate_limiter,
                pre_auth_limiter,
                #[cfg(feature = "oauth")]
                jwks_cache,
                seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
                counters: crate::auth::AuthCounters::default(),
            }))
        }
        _ => None,
    };
    // RBAC policy behind an ArcSwap so ReloadHandle can hot-swap it later.
    let rbac_swap = Arc::new(ArcSwap::new(
        config
            .rbac
            .clone()
            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
    ));
    if config.admin_enabled {
        // Same invariant McpServerConfig::check() enforces; re-checked here
        // as defense in depth since this function takes a raw config.
        let Some(ref auth_state_ref) = auth_state else {
            return Err(anyhow::anyhow!(
                "admin_enabled=true requires auth to be configured and enabled"
            ));
        };
        let admin_state = crate::admin::AdminState {
            started_at: std::time::Instant::now(),
            name: config.name.clone(),
            version: config.version.clone(),
            auth: Some(Arc::clone(auth_state_ref)),
            rbac: Arc::clone(&rbac_swap),
        };
        let admin_cfg = crate::admin::AdminConfig {
            role: config.admin_role.clone(),
        };
        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
    }
    // RBAC + tool rate-limit middleware (innermost layer on /mcp).
    {
        let tool_limiter: Option<Arc<ToolRateLimiter>> =
            config.tool_rate_limit.map(build_tool_rate_limiter);
        if rbac_swap.load().is_enabled() {
            tracing::info!("RBAC enforcement enabled on /mcp");
        }
        if let Some(limit) = config.tool_rate_limit {
            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
        }
        let rbac_for_mw = Arc::clone(&rbac_swap);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            // load_full() snapshots the policy per request so reloads apply
            // to subsequent requests without tearing.
            let p = rbac_for_mw.load_full();
            let tl = tool_limiter.clone();
            rbac_middleware(p, tl, req, next)
        }));
    }
    // Auth middleware wraps RBAC, so identity is established first.
    if let Some(ref auth_config) = config.auth
        && auth_config.enabled
    {
        let Some(ref state) = auth_state else {
            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
        };
        let methods: Vec<&str> = [
            auth_config.mtls.is_some().then_some("mTLS"),
            (!auth_config.api_keys.is_empty()).then_some("bearer"),
            #[cfg(feature = "oauth")]
            auth_config.oauth.is_some().then_some("oauth-jwt"),
        ]
        .into_iter()
        .flatten()
        .collect();
        tracing::info!(
            methods = %methods.join(", "),
            api_keys = auth_config.api_keys.len(),
            "auth enabled on /mcp"
        );
        let state_for_mw = Arc::clone(state);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            let s = Arc::clone(&state_for_mw);
            auth_middleware(s, req, next)
        }));
    }
    // Timeout and body-size limits apply to the whole /mcp subtree.
    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
        axum::http::StatusCode::REQUEST_TIMEOUT,
        config.request_timeout,
    ));
    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
        config.max_request_body,
    ));
    // With no explicit allow-list, derive a single origin from public_url.
    let mut effective_origins = config.allowed_origins.clone();
    if effective_origins.is_empty()
        && let Some(ref url) = config.public_url
    {
        if let Some(scheme_end) = url.find("://") {
            let after_scheme = &url[scheme_end + 3..];
            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
            let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
            tracing::info!(
                %origin,
                "auto-derived allowed origin from public_url"
            );
            effective_origins.push(origin);
        }
    }
    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
    let cors_origins = Arc::clone(&allowed_origins);
    let log_request_headers = config.log_request_headers;
    // /readyz uses the custom callback when provided, else plain health.
    let readyz_route = if let Some(check) = config.readiness_check.take() {
        axum::routing::get(move || readyz(Arc::clone(&check)))
    } else {
        axum::routing::get(healthz)
    };
    #[allow(unused_mut)]
    let mut router = axum::Router::new()
        .route("/healthz", axum::routing::get(healthz))
        .route("/readyz", readyz_route)
        .route(
            "/version",
            axum::routing::get({
                // Serialized once at startup; handlers just clone the Arc.
                let payload_bytes: Arc<[u8]> =
                    serialize_version_payload(&config.name, &config.version);
                move || {
                    let p = Arc::clone(&payload_bytes);
                    async move {
                        (
                            [(axum::http::header::CONTENT_TYPE, "application/json")],
                            p.to_vec(),
                        )
                    }
                }
            }),
        )
        .merge(mcp_router);
    if let Some(extra) = config.extra_router.take() {
        router = router.merge(extra);
    }
    // Base URL for OAuth protected-resource metadata: explicit public_url,
    // otherwise derived from the bind address and TLS setting.
    let server_url = if let Some(ref url) = config.public_url {
        url.trim_end_matches('/').to_owned()
    } else {
        let prm_scheme = if config.tls_cert_path.is_some() {
            "https"
        } else {
            "http"
        };
        format!("{prm_scheme}://{}", config.bind_addr)
    };
    let resource_url = format!("{server_url}/mcp");
    #[cfg(feature = "oauth")]
    let prm_metadata = if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
    {
        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
    } else {
        serde_json::json!({ "resource": resource_url })
    };
    #[cfg(not(feature = "oauth"))]
    let prm_metadata = serde_json::json!({ "resource": resource_url });
    router = router.route(
        "/.well-known/oauth-protected-resource",
        axum::routing::get(move || {
            let m = prm_metadata.clone();
            async move { axum::Json(m) }
        }),
    );
    #[cfg(feature = "oauth")]
    if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
        && oauth_config.proxy.is_some()
    {
        router = install_oauth_proxy_routes(router, &server_url, oauth_config)?;
    }
    let is_tls = config.tls_cert_path.is_some();
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        security_headers_middleware(is_tls, req, next)
    }));
    // CORS only when there is at least one configured/derived origin;
    // unparseable origin strings are silently dropped here (already shape-
    // checked in McpServerConfig::check()).
    if !cors_origins.is_empty() {
        let cors = tower_http::cors::CorsLayer::new()
            .allow_origin(
                cors_origins
                    .iter()
                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
                    .collect::<Vec<_>>(),
            )
            .allow_methods([
                axum::http::Method::GET,
                axum::http::Method::POST,
                axum::http::Method::OPTIONS,
            ])
            .allow_headers([
                axum::http::header::CONTENT_TYPE,
                axum::http::header::AUTHORIZATION,
            ]);
        router = router.layer(cors);
    }
    if config.compression_enabled {
        use tower_http::compression::Predicate as _;
        // Default predicate plus a size floor so tiny responses skip
        // compression overhead.
        let predicate = tower_http::compression::DefaultPredicate::new().and(
            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
        );
        router = router.layer(
            tower_http::compression::CompressionLayer::new()
                .gzip(true)
                .br(true)
                .compress_when(predicate),
        );
        tracing::info!(
            min_size = config.compression_min_size,
            "response compression enabled (gzip, br)"
        );
    }
    // Optional global load-shedding: when the concurrency limit is hit,
    // load_shed turns the overflow into an error that the handler maps to 503.
    if let Some(max) = config.max_concurrent_requests {
        let overload_handler = tower::ServiceBuilder::new()
            .layer(axum::error_handling::HandleErrorLayer::new(
                |_err: tower::BoxError| async {
                    (
                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
                        axum::Json(serde_json::json!({
                            "error": "overloaded",
                            "error_description": "server is at capacity, retry later"
                        })),
                    )
                },
            ))
            .layer(tower::load_shed::LoadShedLayer::new())
            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
        router = router.layer(overload_handler);
        tracing::info!(max, "global concurrency limit enabled");
    }
    // JSON 404 for anything unrouted.
    router = router.fallback(|| async {
        (
            axum::http::StatusCode::NOT_FOUND,
            axum::Json(serde_json::json!({
                "error": "not_found",
                "error_description": "The requested endpoint does not exist"
            })),
        )
    });
    #[cfg(feature = "metrics")]
    if config.metrics_enabled {
        let metrics = Arc::new(
            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
        );
        let m = Arc::clone(&metrics);
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let m = Arc::clone(&m);
                metrics_middleware(m, req, next)
            },
        ));
        // Metrics are served from a separate, detached listener; failures
        // there are logged but do not stop the main server.
        let metrics_bind = config.metrics_bind.clone();
        tokio::spawn(async move {
            if let Err(e) = crate::metrics::serve_metrics(metrics_bind, metrics).await {
                tracing::error!("metrics listener failed: {e}");
            }
        });
    }
    // Outermost layer: origin check (and optional header logging) runs first.
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        let origins = Arc::clone(&allowed_origins);
        origin_check_middleware(origins, log_request_headers, req, next)
    }));
    let scheme = if config.tls_cert_path.is_some() {
        "https"
    } else {
        "http"
    };
    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
        _ => None,
    };
    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
    Ok((
        router,
        AppRunParams {
            tls_paths,
            mtls_config,
            shutdown_timeout: config.shutdown_timeout,
            auth_state,
            rbac_swap,
            on_reload_ready: config.on_reload_ready.take(),
            ct,
            scheme,
            name: config.name.clone(),
        },
    ))
}
/// Bind `config.bind_addr` and run the server until a shutdown signal.
///
/// `handler_factory` is invoked once per MCP connection to create a fresh
/// handler.
///
/// # Errors
/// Returns [`McpxError::Startup`] if the router cannot be assembled, the
/// listener fails to bind, or the server fails while running.
pub async fn serve<H, F>(
    config: Validated<McpServerConfig>,
    handler_factory: F,
) -> Result<(), McpxError>
where
    H: ServerHandler + 'static,
    F: Fn() -> H + Send + Sync + Clone + 'static,
{
    let config = config.into_inner();
    #[allow(
        deprecated,
        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
    )]
    // Clone before the config is consumed by build_app_router below.
    let bind_addr = config.bind_addr.clone();
    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
    let listener = TcpListener::bind(&bind_addr)
        .await
        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
    log_listening(&params.name, params.scheme, &bind_addr);
    run_server(
        router,
        listener,
        params.tls_paths,
        params.mtls_config,
        params.shutdown_timeout,
        params.auth_state,
        params.rbac_swap,
        params.on_reload_ready,
        params.ct,
    )
    .await
    .map_err(anyhow_to_startup)
}
/// Run the server on an already-bound listener (useful for tests and for
/// binding to port 0). The actual local address is sent through `ready_tx`
/// when provided, and an optional external `shutdown` token is bridged into
/// the server's internal cancellation token.
///
/// # Errors
/// Returns [`McpxError::Startup`] if the local address cannot be read, the
/// router cannot be assembled, or the server fails while running.
pub async fn serve_with_listener<H, F>(
    listener: TcpListener,
    config: Validated<McpServerConfig>,
    handler_factory: F,
    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
    shutdown: Option<CancellationToken>,
) -> Result<(), McpxError>
where
    H: ServerHandler + 'static,
    F: Fn() -> H + Send + Sync + Clone + 'static,
{
    let config = config.into_inner();
    let local_addr = listener
        .local_addr()
        .map_err(|e| io_to_startup("listener.local_addr", e))?;
    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
    log_listening(&params.name, params.scheme, &local_addr.to_string());
    // Bridge the caller's shutdown token into the internal one so either can
    // stop the server.
    if let Some(external) = shutdown {
        let internal = params.ct.clone();
        tokio::spawn(async move {
            external.cancelled().await;
            internal.cancel();
        });
    }
    // Best-effort: the caller may have dropped the receiver.
    if let Some(tx) = ready_tx {
        let _ = tx.send(local_addr);
    }
    run_server(
        router,
        listener,
        params.tls_paths,
        params.mtls_config,
        params.shutdown_timeout,
        params.auth_state,
        params.rbac_swap,
        params.on_reload_ready,
        params.ct,
    )
    .await
    .map_err(anyhow_to_startup)
}
/// Log the bound address and the derived endpoint URLs at startup.
#[allow(
    clippy::cognitive_complexity,
    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
)]
fn log_listening(name: &str, scheme: &str, addr: &str) {
    tracing::info!("{name} listening on {addr}");
    tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
    tracing::info!(" Health check: {scheme}://{addr}/healthz");
    tracing::info!(" Readiness: {scheme}://{addr}/readyz");
}
/// Run the assembled router on `listener` until an OS shutdown signal or `ct`
/// cancellation, optionally wrapped in TLS/mTLS, then shut down gracefully
/// with a hard exit after `shutdown_timeout`.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // Fires exactly once, on either an OS shutdown signal or cancellation of
    // the parent token.
    let shutdown_trigger = CancellationToken::new();
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }
    // Graceful path: when triggered, cancel `ct` so session/background tasks
    // wind down while axum drains in-flight requests.
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };
    // Hard stop: if graceful shutdown overruns the grace period, the select!
    // below abandons axum::serve.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };
    if let Some((cert_path, key_path)) = tls_paths {
        // Optional CRL support: fetch CRLs up front, then keep them fresh in
        // a background task tied to the server's cancellation token.
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };
        // Hand the reload handle to the host app before the accept loop runs.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }
        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
        )?;
        // TlsConnInfo carries the peer address plus any mTLS identity.
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        // Plain HTTP: no CRL state, and connect info is just the peer addr.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }
        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }
    Ok(())
}
/// Install the OAuth 2.1 proxy endpoints (`/authorize`, `/token`,
/// `/register`, plus authorization-server metadata) when
/// `oauth_config.proxy` is configured; returns the router untouched
/// otherwise. `/introspect` and `/revoke` are only added when admin endpoints
/// are exposed and the corresponding upstream URL is known.
///
/// # Errors
/// Fails if the outbound OAuth HTTP client cannot be constructed.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    mut router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;
    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let metadata = asm.clone();
            async move { axum::Json(metadata) }
        }),
    );
    let authorize_proxy = proxy.clone();
    router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let cfg = authorize_proxy.clone();
                async move { crate::oauth::handle_authorize(&cfg, &query.unwrap_or_default()) }
            },
        ),
    );
    let token_proxy = proxy.clone();
    let token_http = http.clone();
    router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let cfg = token_proxy.clone();
            let client = token_http.clone();
            async move { crate::oauth::handle_token(&client, &cfg, &body).await }
        }),
    );
    let register_proxy = proxy.clone();
    router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let cfg = register_proxy;
            async move { axum::Json(crate::oauth::handle_register(&cfg, &body)) }
        }),
    );
    let want_introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some();
    let want_revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some();
    if want_introspect {
        let introspect_proxy = proxy.clone();
        let introspect_http = http.clone();
        router = router.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let cfg = introspect_proxy.clone();
                let client = introspect_http.clone();
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }
    if want_revoke {
        let revoke_proxy = proxy.clone();
        let revoke_http = http;
        router = router.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let cfg = revoke_proxy.clone();
                let client = revoke_http.clone();
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }
    tracing::info!(
        introspect = want_introspect,
        revoke = want_revoke,
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
/// Build the Host-header allow-list for the Streamable HTTP transport:
/// loopback names plus the host and full authority (host[:port]) of both the
/// public URL (when given) and the bind address.
///
/// The previous implementation repeated the same push-if-absent logic four
/// times; it is factored into nested helpers here. Output order is unchanged:
/// loopback entries, then public-url host/authority, then bind host/authority.
fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
    /// Append `candidate` unless an identical entry already exists.
    fn push_unique(hosts: &mut Vec<String>, candidate: String) {
        if !hosts.iter().any(|h| h == &candidate) {
            hosts.push(candidate);
        }
    }
    /// Parse `url` as a URI and record its host and full authority, if any.
    fn push_authority_of(hosts: &mut Vec<String>, url: &str) {
        if let Ok(uri) = url.parse::<axum::http::Uri>()
            && let Some(authority) = uri.authority()
        {
            push_unique(hosts, authority.host().to_owned());
            push_unique(hosts, authority.as_str().to_owned());
        }
    }
    let mut hosts = vec![
        "localhost".to_owned(),
        "127.0.0.1".to_owned(),
        "::1".to_owned(),
    ];
    if let Some(url) = public_url {
        push_authority_of(&mut hosts, url);
    }
    // bind_addr is a bare host:port; prefix a scheme so it parses as a URI.
    push_authority_of(&mut hosts, &format!("http://{bind_addr}"));
    hosts
}
impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
    for TlsConnInfo
{
    /// Capture the peer address plus any mTLS identity that was extracted
    /// during the TLS handshake for this connection.
    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
        let peer_addr = *target.remote_addr();
        let peer_identity = target.io().identity().cloned();
        TlsConnInfo::new(peer_addr, peer_identity)
    }
}
/// TCP listener paired with a rustls acceptor; used by `run_server` when TLS
/// is configured (the accept/handshake plumbing lives elsewhere in this file).
struct TlsListener {
    /// Underlying TCP acceptor.
    inner: TcpListener,
    /// Shared rustls server configuration used to wrap accepted streams.
    acceptor: tokio_rustls::TlsAcceptor,
    /// Fallback role handed to mTLS identity extraction.
    mtls_default_role: String,
}
impl TlsListener {
    /// Build a TLS listener from PEM cert/key paths, optionally configuring
    /// mTLS client verification — either CRL-aware (dynamic verifier over
    /// `crl_set`) or plain WebPKI, required or optional per `mtls_config`.
    fn new(
        inner: TcpListener,
        cert_path: &Path,
        key_path: &Path,
        mtls_config: Option<&MtlsConfig>,
        crl_set: Option<Arc<CrlSet>>,
    ) -> anyhow::Result<Self> {
        // Best effort: install_default() errs if a process-wide crypto
        // provider was already installed, which is fine here.
        rustls::crypto::ring::default_provider()
            .install_default()
            .ok();
        let certs = load_certs(cert_path)?;
        let key = load_key(key_path)?;
        let mtls_default_role;
        let tls_config = if let Some(mtls) = mtls_config {
            mtls_default_role = mtls.default_role.clone();
            // CRL-enabled mode swaps in the dynamic verifier so revocations
            // picked up by the background refresher apply to new handshakes.
            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
            {
                let Some(crl_set) = crl_set else {
                    return Err(anyhow::anyhow!(
                        "mTLS CRL verifier requested but CRL state was not initialized"
                    ));
                };
                Arc::new(DynamicClientCertVerifier::new(crl_set))
            } else {
                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
                if mtls.required {
                    rustls::server::WebPkiClientVerifier::builder(root_store)
                        .build()
                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
                } else {
                    // Optional mTLS: clients without a cert still connect.
                    rustls::server::WebPkiClientVerifier::builder(root_store)
                        .allow_unauthenticated()
                        .build()
                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
                }
            };
            tracing::info!(
                ca = %mtls.ca_cert_path.display(),
                required = mtls.required,
                crl_enabled = mtls.crl_enabled,
                "mTLS client auth configured"
            );
            rustls::ServerConfig::builder_with_protocol_versions(&[
                &rustls::version::TLS12,
                &rustls::version::TLS13,
            ])
            .with_client_cert_verifier(verifier)
            .with_single_cert(certs, key)?
        } else {
            mtls_default_role = "viewer".to_owned();
            rustls::ServerConfig::builder_with_protocol_versions(&[
                &rustls::version::TLS12,
                &rustls::version::TLS13,
            ])
            .with_no_client_auth()
            .with_single_cert(certs, key)?
        };
        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
        tracing::info!(
            "TLS enabled (cert: {}, key: {})",
            cert_path.display(),
            key_path.display()
        );
        Ok(Self {
            inner,
            acceptor,
            mtls_default_role,
        })
    }

    /// Pull the client identity out of a completed handshake: takes the first
    /// peer certificate (if any) and runs it through `extract_mtls_identity`
    /// with `default_role` as the fallback role. Returns `None` when no cert
    /// was presented or no identity could be extracted.
    fn extract_handshake_identity(
        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
        default_role: &str,
        addr: SocketAddr,
    ) -> Option<AuthIdentity> {
        let (_, server_conn) = tls_stream.get_ref();
        let cert_der = server_conn.peer_certificates()?.first()?;
        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
        Some(id)
    }
}
/// TLS stream paired with the mTLS identity captured at handshake time,
/// so later layers can read it without touching the connection again.
pub(crate) struct AuthenticatedTlsStream {
    // Underlying handshaken TLS stream; all I/O delegates to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    // `Some` only when the client presented a certificate that yielded an
    // identity during `TlsListener::accept`.
    identity: Option<AuthIdentity>,
}
impl AuthenticatedTlsStream {
    /// Returns the mTLS identity established during the TLS handshake, if
    /// the peer presented a usable client certificate.
    #[must_use]
    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
        self.identity.as_ref()
    }
}
impl std::fmt::Debug for AuthenticatedTlsStream {
    /// Debug output exposes only the identity's name, never stream internals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let identity_name = self.identity.as_ref().map(|id| &id.name);
        f.debug_struct("AuthenticatedTlsStream")
            .field("identity", &identity_name)
            .finish_non_exhaustive()
    }
}
impl tokio::io::AsyncRead for AuthenticatedTlsStream {
    // Read path forwards verbatim to the wrapped TLS stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
    }
}
impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
    // Every write-path method forwards verbatim to the wrapped TLS stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
    }
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
    }
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
    }
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<std::io::Result<usize>> {
        Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
    }
    // Delegate so vectored writes stay enabled when the TLS stream supports them.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}
impl axum::serve::Listener for TlsListener {
    type Io = AuthenticatedTlsStream;
    type Addr = SocketAddr;
    /// Accepts the next connection that completes a TLS handshake.
    ///
    /// TCP accept failures and failed handshakes are logged at debug level
    /// and skipped; the loop returns only once a handshake succeeds.
    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
        loop {
            let accepted = self.inner.accept().await;
            let Ok((tcp, addr)) = accepted.inspect_err(|e| {
                tracing::debug!("TCP accept error: {e}");
            }) else {
                continue;
            };
            match self.acceptor.accept(tcp).await {
                Ok(tls) => {
                    // Capture any mTLS identity now, while the handshake
                    // state is at hand.
                    let identity =
                        Self::extract_handshake_identity(&tls, &self.mtls_default_role, addr);
                    return (
                        AuthenticatedTlsStream {
                            inner: tls,
                            identity,
                        },
                        addr,
                    );
                }
                Err(e) => {
                    tracing::debug!("TLS handshake failed from {addr}: {e}");
                }
            }
        }
    }
    fn local_addr(&self) -> std::io::Result<Self::Addr> {
        self.inner.local_addr()
    }
}
/// Reads every PEM-encoded certificate from `path`.
///
/// # Errors
/// Fails when the file cannot be read, any certificate is invalid, or the
/// file contains no certificates at all.
fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
    use rustls::pki_types::pem::PemObject;
    let iter = rustls::pki_types::CertificateDer::pem_file_iter(path)
        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?;
    let mut certs = Vec::new();
    for item in iter {
        let cert = item.map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
        certs.push(cert);
    }
    anyhow::ensure!(
        !certs.is_empty(),
        "no certificates found in {}",
        path.display()
    );
    Ok(certs)
}
fn load_client_auth_roots(
path: &Path,
) -> anyhow::Result<(
Vec<rustls::pki_types::CertificateDer<'static>>,
Arc<RootCertStore>,
)> {
let ca_certs = load_certs(path)?;
let mut root_store = RootCertStore::empty();
for cert in &ca_certs {
root_store
.add(cert.clone())
.map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
}
Ok((ca_certs, Arc::new(root_store)))
}
/// Reads a PEM-encoded private key from `path`.
fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
    use rustls::pki_types::pem::PemObject;
    match rustls::pki_types::PrivateKeyDer::from_pem_file(path) {
        Ok(key) => Ok(key),
        Err(e) => Err(anyhow::anyhow!(
            "failed to read key from {}: {e}",
            path.display()
        )),
    }
}
/// Liveness probe: always answers 200 with `{"status": "ok"}` and nothing
/// else, so no identifying information leaks from an unauthenticated path.
#[allow(clippy::unused_async)]
async fn healthz() -> impl IntoResponse {
    let body = serde_json::json!({ "status": "ok" });
    axum::Json(body)
}
/// Assembles the version-endpoint JSON: server identity plus build metadata
/// baked in at compile time via the `MCPX_BUILD_*` / `MCPX_RUSTC_VERSION`
/// environment variables (falling back to "unknown" when unset).
fn version_payload(name: &str, version: &str) -> serde_json::Value {
    const UNKNOWN: &str = "unknown";
    serde_json::json!({
        "name": name,
        "version": version,
        "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or(UNKNOWN),
        "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or(UNKNOWN),
        "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or(UNKNOWN),
        "mcpx_version": env!("CARGO_PKG_VERSION"),
    })
}
/// Serializes the version payload once into cheaply clonable shared bytes;
/// falls back to `{}` should serialization ever fail.
fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
    match serde_json::to_vec(&version_payload(name, version)) {
        Ok(bytes) => Arc::from(bytes),
        Err(_) => Arc::from(&b"{}"[..]),
    }
}
/// Readiness probe: runs the user-supplied check and returns its JSON
/// verbatim. Responds 200 only when the payload carries `"ready": true`;
/// a false, missing, or non-boolean `ready` field yields 503.
async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
    let status = check().await;
    let is_ready = matches!(status.get("ready"), Some(serde_json::Value::Bool(true)));
    let code = if is_ready {
        axum::http::StatusCode::OK
    } else {
        axum::http::StatusCode::SERVICE_UNAVAILABLE
    };
    (code, axum::Json(status))
}
/// Resolves when the process receives a shutdown signal: SIGINT (Ctrl-C)
/// or, on Unix, SIGTERM as well.
async fn shutdown_signal() {
    let ctrl_c = tokio::signal::ctrl_c();
    #[cfg(unix)]
    {
        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
            Ok(mut term) => {
                // Complete on whichever signal arrives first.
                tokio::select! {
                    _ = ctrl_c => {}
                    _ = term.recv() => {}
                }
            }
            Err(e) => {
                // SIGTERM registration can fail; degrade to Ctrl-C only
                // rather than giving up on graceful shutdown entirely.
                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
                ctrl_c.await.ok();
            }
        }
    }
    #[cfg(not(unix))]
    {
        // Non-Unix targets: Ctrl-C is the only supported signal.
        ctrl_c.await.ok();
    }
}
/// Records per-request metrics: a counter labeled (method, path, status)
/// and a latency histogram labeled (method, path).
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Capture labels before the request is consumed by `next`.
    let method = req.method().to_string();
    let path = req.uri().path().to_owned();
    let started_at = std::time::Instant::now();
    let response = next.run(req).await;
    let elapsed = started_at.elapsed().as_secs_f64();
    let status = response.status().as_u16().to_string();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(elapsed);
    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    response
}
/// Applies baseline hardening headers to every response and strips
/// implementation-revealing ones (`Server`, `X-Powered-By`).
///
/// `Strict-Transport-Security` is added only when `is_tls` is set, since
/// the header is meaningful only over TLS.
async fn security_headers_middleware(
    is_tls: bool,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderName, HeaderValue, header};
    // Static hardening headers, inserted unconditionally. Names must be
    // lowercase for `HeaderName::from_static`.
    const STATIC_HEADERS: &[(&str, &str)] = &[
        ("x-content-type-options", "nosniff"),
        ("x-frame-options", "deny"),
        ("cache-control", "no-store, max-age=0"),
        ("referrer-policy", "no-referrer"),
        ("cross-origin-opener-policy", "same-origin"),
        ("cross-origin-resource-policy", "same-origin"),
        ("cross-origin-embedder-policy", "require-corp"),
        (
            "permissions-policy",
            "accelerometer=(), camera=(), geolocation=(), microphone=()",
        ),
        ("x-permitted-cross-domain-policies", "none"),
        (
            "content-security-policy",
            "default-src 'none'; frame-ancestors 'none'",
        ),
        ("x-dns-prefetch-control", "off"),
    ];
    let mut resp = next.run(req).await;
    let headers = resp.headers_mut();
    // Avoid advertising server software details.
    headers.remove(header::SERVER);
    headers.remove(HeaderName::from_static("x-powered-by"));
    for &(name, value) in STATIC_HEADERS {
        headers.insert(
            HeaderName::from_static(name),
            HeaderValue::from_static(value),
        );
    }
    if is_tls {
        // Two-year max-age, covering subdomains.
        headers.insert(
            header::STRICT_TRANSPORT_SECURITY,
            HeaderValue::from_static("max-age=63072000; includeSubDomains"),
        );
    }
    resp
}
/// Rejects requests whose `Origin` header is not in the allowlist (exact
/// string match); requests without an `Origin` header pass through. Also
/// logs every incoming request up front.
async fn origin_check_middleware(
    allowed: Arc<[String]>,
    log_request_headers: bool,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    let method = req.method().clone();
    let path = req.uri().path().to_owned();
    log_incoming_request(&method, &path, req.headers(), log_request_headers);
    let origin_header = req.headers().get(axum::http::header::ORIGIN);
    if let Some(origin) = origin_header {
        // Non-UTF-8 origin values become "" and will not match any real
        // allowlist entry.
        let origin_str = origin.to_str().unwrap_or("");
        let permitted = allowed.iter().any(|a| a == origin_str);
        if !permitted {
            tracing::warn!(
                origin = origin_str,
                %method,
                %path,
                allowed = ?&*allowed,
                "rejected request: Origin not allowed"
            );
            return (
                axum::http::StatusCode::FORBIDDEN,
                "Forbidden: Origin not allowed",
            )
                .into_response();
        }
    }
    next.run(req).await
}
/// Logs an incoming request at debug level. Header values are included
/// only when `log_request_headers` is set, after redaction by
/// `format_request_headers_for_log`.
fn log_incoming_request(
    method: &axum::http::Method,
    path: &str,
    headers: &axum::http::HeaderMap,
    log_request_headers: bool,
) {
    if !log_request_headers {
        tracing::debug!(%method, %path, "incoming request");
        return;
    }
    tracing::debug!(
        %method,
        %path,
        headers = %format_request_headers_for_log(headers),
        "incoming request"
    );
}
/// Renders request headers as `name: value` pairs joined by ", ",
/// redacting credential-bearing headers and substituting a placeholder
/// for non-UTF-8 values.
fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
    const REDACTED: [&str; 3] = ["authorization", "cookie", "proxy-authorization"];
    let mut parts = Vec::with_capacity(headers.len());
    for (key, value) in headers {
        // `HeaderName::as_str` is always lowercase, so exact match suffices.
        let name = key.as_str();
        let rendered = if REDACTED.contains(&name) {
            format!("{name}: [REDACTED]")
        } else {
            format!("{name}: {}", value.to_str().unwrap_or("<non-utf8>"))
        };
        parts.push(rendered);
    }
    parts.join(", ")
}
/// Serves the MCP handler over stdin/stdout (stdio transport).
///
/// No transport-level security applies in this mode — auth, RBAC, TLS, and
/// Origin checks are all bypassed — so a warning is logged up front.
///
/// # Errors
/// Returns `McpxError::Startup` when the initial MCP handshake fails; a
/// session that later ends with an error is only logged, not propagated.
#[allow(clippy::cognitive_complexity)]
pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
where
    H: ServerHandler + 'static,
{
    use rmcp::ServiceExt as _;
    tracing::info!("stdio transport: serving on stdin/stdout");
    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
    let transport = rmcp::transport::io::stdio();
    let service = handler
        .serve(transport)
        .await
        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
    // Block until the peer disconnects or the session errors out.
    if let Err(e) = service.waiting().await {
        tracing::warn!(error = %e, "stdio session ended with error");
    }
    tracing::info!("stdio session ended");
    Ok(())
}
// Unit tests for the server plumbing: config defaults/validation,
// allowed-host derivation, health endpoints, origin and security-header
// middleware, header redaction, version payload, and layer composition.
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        deprecated,
        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
    )]
    use std::sync::Arc;
    use axum::{
        body::Body,
        http::{Request, StatusCode, header},
        response::IntoResponse,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt as _;
    use super::*;

    // `McpServerConfig::new` must fill in the documented defaults.
    #[test]
    fn server_config_new_defaults() {
        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
        assert_eq!(cfg.name, "test-server");
        assert_eq!(cfg.version, "1.0.0");
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert!(cfg.auth.is_none());
        assert!(cfg.rbac.is_none());
        assert!(cfg.allowed_origins.is_empty());
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.readiness_check.is_none());
        assert_eq!(cfg.max_request_body, 1024 * 1024);
        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
        assert!(!cfg.log_request_headers);
    }

    // `validate()` consumes the raw config into a validated wrapper;
    // `into_inner()` returns the raw config; invalid settings must fail.
    #[test]
    fn validate_consumes_and_proves() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        let validated = cfg.validate().expect("valid config");
        assert_eq!(validated.name, "test-server");
        let raw = validated.into_inner();
        assert_eq!(raw.name, "test-server");
        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        bad.max_request_body = 0;
        assert!(bad.validate().is_err(), "zero body cap must fail validate");
    }

    // Host allowlist derivation: public URL and bind address contribute
    // both bare host and full authority entries.
    #[test]
    fn derive_allowed_hosts_includes_public_host() {
        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
        assert!(
            hosts.iter().any(|h| h == "mcp.example.com"),
            "public_url host must be allowed"
        );
    }

    #[test]
    fn derive_allowed_hosts_includes_bind_authority() {
        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1"),
            "bind host must be allowed"
        );
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1:8080"),
            "bind authority must be allowed"
        );
    }

    // Health endpoints must never leak server name/version (they are
    // reachable without authentication).
    #[tokio::test]
    async fn healthz_returns_ok_json() {
        let resp = healthz().await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
        assert!(
            json.get("name").is_none(),
            "healthz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "healthz must not expose version"
        );
    }

    #[tokio::test]
    async fn readyz_returns_ok_when_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["ready"], true);
        assert!(
            json.get("name").is_none(),
            "readyz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "readyz must not expose version"
        );
        assert_eq!(json["db"], "connected");
    }

    #[tokio::test]
    async fn readyz_returns_503_when_not_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    // A readiness payload without a "ready" key must be treated as not ready.
    #[tokio::test]
    async fn readyz_returns_503_when_ready_missing() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    // Builds a one-route router wrapped in the Origin-check middleware.
    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
        let allowed: Arc<[String]> = Arc::from(origins);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let a = Arc::clone(&allowed);
                origin_check_middleware(a, log_request_headers, req, next)
            }))
    }

    #[tokio::test]
    async fn origin_allowed_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://localhost:3000")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_rejected_returns_403() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // Requests without an Origin header (non-browser clients) always pass.
    #[tokio::test]
    async fn no_origin_header_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn empty_allowlist_rejects_any_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://anything.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn empty_allowlist_passes_without_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // Credential-bearing headers must never appear in logs.
    #[test]
    fn format_request_headers_redacts_sensitive_values() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
        headers.insert("cookie", "sid=abc".parse().unwrap());
        headers.insert("x-request-id", "req-123".parse().unwrap());
        let out = format_request_headers_for_log(&headers);
        assert!(out.contains("authorization: [REDACTED]"));
        assert!(out.contains("cookie: [REDACTED]"));
        assert!(out.contains("x-request-id: req-123"));
        assert!(!out.contains("secret-token"));
    }

    // Builds a one-route router wrapped in the security-headers middleware.
    fn security_router(is_tls: bool) -> axum::Router {
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                security_headers_middleware(is_tls, req, next)
            }))
    }

    #[tokio::test]
    async fn security_headers_set_on_response() {
        let app = security_router(false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let h = resp.headers();
        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
        assert_eq!(
            h.get("cross-origin-resource-policy").unwrap(),
            "same-origin"
        );
        assert_eq!(
            h.get("cross-origin-embedder-policy").unwrap(),
            "require-corp"
        );
        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
        assert!(
            h.get("permissions-policy")
                .unwrap()
                .to_str()
                .unwrap()
                .contains("camera=()"),
            "permissions-policy must restrict browser features"
        );
        assert_eq!(
            h.get("content-security-policy").unwrap(),
            "default-src 'none'; frame-ancestors 'none'"
        );
        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
        // HSTS must be absent on a non-TLS response.
        assert!(h.get("strict-transport-security").is_none());
    }

    #[tokio::test]
    async fn hsts_set_when_tls_enabled() {
        let app = security_router(true);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        let hsts = resp.headers().get("strict-transport-security").unwrap();
        assert!(
            hsts.to_str().unwrap().contains("max-age=63072000"),
            "HSTS must set 2-year max-age"
        );
    }

    #[test]
    fn version_payload_contains_expected_fields() {
        let v = version_payload("my-server", "1.2.3");
        assert_eq!(v["name"], "my-server");
        assert_eq!(v["version"], "1.2.3");
        assert!(v["build_git_sha"].is_string());
        assert!(v["build_timestamp"].is_string());
        assert!(v["rust_version"].is_string());
        assert!(v["mcpx_version"].is_string());
    }

    // Smoke-test that the load-shed + concurrency-limit stack composes and
    // still serves a plain request.
    #[tokio::test]
    async fn concurrency_limit_layer_composes_and_serves() {
        let app = axum::Router::new()
            .route("/ok", axum::routing::get(|| async { "ok" }))
            .layer(
                tower::ServiceBuilder::new()
                    .layer(axum::error_handling::HandleErrorLayer::new(
                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
                    ))
                    .layer(tower::load_shed::LoadShedLayer::new())
                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
            );
        let resp = app
            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // A body above the size predicate with `Accept-Encoding: gzip` must come
    // back gzip-encoded.
    #[tokio::test]
    async fn compression_layer_gzip_encodes_response() {
        use tower_http::compression::Predicate as _;
        let big_body = "a".repeat(4096);
        let app = axum::Router::new()
            .route(
                "/big",
                axum::routing::get(move || {
                    let body = big_body.clone();
                    async move { body }
                }),
            )
            .layer(
                tower_http::compression::CompressionLayer::new()
                    .gzip(true)
                    .br(true)
                    .compress_when(
                        tower_http::compression::DefaultPredicate::new()
                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
                    ),
            );
        let req = Request::builder()
            .uri("/big")
            .header(header::ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }
}