1use std::{
2 future::Future,
3 net::SocketAddr,
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13 ServerHandler,
14 transport::streamable_http_server::{
15 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16 },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23 auth::{
24 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25 build_rate_limiter, extract_mtls_identity,
26 },
27 error::McpxError,
28 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32#[allow(
36 clippy::needless_pass_by_value,
37 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40 McpxError::Startup(format!("{e:#}"))
41}
42
43#[allow(
49 clippy::needless_pass_by_value,
50 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53 McpxError::Startup(format!("{op}: {e}"))
54}
55
56pub type ReadinessCheck =
61 Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
63#[allow(
65 missing_debug_implementations,
66 reason = "contains callback/trait objects that don't impl Debug"
67)]
68#[allow(
69 clippy::struct_excessive_bools,
70 reason = "server configuration naturally has many boolean feature flags"
71)]
72#[non_exhaustive]
73pub struct McpServerConfig {
74 #[deprecated(
76 since = "0.13.0",
77 note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in 1.0"
78 )]
79 pub bind_addr: String,
80 #[deprecated(
82 since = "0.13.0",
83 note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
84 )]
85 pub name: String,
86 #[deprecated(
88 since = "0.13.0",
89 note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
90 )]
91 pub version: String,
92 #[deprecated(
94 since = "0.13.0",
95 note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
96 )]
97 pub tls_cert_path: Option<PathBuf>,
98 #[deprecated(
100 since = "0.13.0",
101 note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
102 )]
103 pub tls_key_path: Option<PathBuf>,
104 #[deprecated(
107 since = "0.13.0",
108 note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in 1.0"
109 )]
110 pub auth: Option<AuthConfig>,
111 #[deprecated(
114 since = "0.13.0",
115 note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in 1.0"
116 )]
117 pub rbac: Option<Arc<RbacPolicy>>,
118 #[deprecated(
124 since = "0.13.0",
125 note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in 1.0"
126 )]
127 pub allowed_origins: Vec<String>,
128 #[deprecated(
131 since = "0.13.0",
132 note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in 1.0"
133 )]
134 pub tool_rate_limit: Option<u32>,
135 #[deprecated(
138 since = "0.13.0",
139 note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in 1.0"
140 )]
141 pub readiness_check: Option<ReadinessCheck>,
142 #[deprecated(
145 since = "0.13.0",
146 note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in 1.0"
147 )]
148 pub max_request_body: usize,
149 #[deprecated(
152 since = "0.13.0",
153 note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in 1.0"
154 )]
155 pub request_timeout: Duration,
156 #[deprecated(
159 since = "0.13.0",
160 note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in 1.0"
161 )]
162 pub shutdown_timeout: Duration,
163 #[deprecated(
166 since = "0.13.0",
167 note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in 1.0"
168 )]
169 pub session_idle_timeout: Duration,
170 #[deprecated(
173 since = "0.13.0",
174 note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in 1.0"
175 )]
176 pub sse_keep_alive: Duration,
177 #[deprecated(
181 since = "0.13.0",
182 note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in 1.0"
183 )]
184 pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
185 #[deprecated(
189 since = "0.13.0",
190 note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in 1.0"
191 )]
192 pub extra_router: Option<axum::Router>,
193 #[deprecated(
198 since = "0.13.0",
199 note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in 1.0"
200 )]
201 pub public_url: Option<String>,
202 #[deprecated(
205 since = "0.13.0",
206 note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in 1.0"
207 )]
208 pub log_request_headers: bool,
209 #[deprecated(
212 since = "0.13.0",
213 note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
214 )]
215 pub compression_enabled: bool,
216 #[deprecated(
219 since = "0.13.0",
220 note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
221 )]
222 pub compression_min_size: u16,
223 #[deprecated(
227 since = "0.13.0",
228 note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in 1.0"
229 )]
230 pub max_concurrent_requests: Option<usize>,
231 #[deprecated(
234 since = "0.13.0",
235 note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
236 )]
237 pub admin_enabled: bool,
238 #[deprecated(
240 since = "0.13.0",
241 note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
242 )]
243 pub admin_role: String,
244 #[cfg(feature = "metrics")]
247 #[deprecated(
248 since = "0.13.0",
249 note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
250 )]
251 pub metrics_enabled: bool,
252 #[cfg(feature = "metrics")]
254 #[deprecated(
255 since = "0.13.0",
256 note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
257 )]
258 pub metrics_bind: String,
259}
260
/// Proof that a value passed validation.
///
/// The inner field is private, so a `Validated<T>` can only be produced
/// inside this module (e.g. by `McpServerConfig::validate`).
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Deliberately opaque: the wrapped value is never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Validated").finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the validated value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

impl<T> std::ops::Deref for Validated<T> {
    type Target = T;

    fn deref(&self) -> &T {
        self.as_inner()
    }
}
354
355#[allow(
356 deprecated,
357 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
358)]
359impl McpServerConfig {
360 #[must_use]
368 pub fn new(
369 bind_addr: impl Into<String>,
370 name: impl Into<String>,
371 version: impl Into<String>,
372 ) -> Self {
373 Self {
374 bind_addr: bind_addr.into(),
375 name: name.into(),
376 version: version.into(),
377 tls_cert_path: None,
378 tls_key_path: None,
379 auth: None,
380 rbac: None,
381 allowed_origins: Vec::new(),
382 tool_rate_limit: None,
383 readiness_check: None,
384 max_request_body: 1024 * 1024,
385 request_timeout: Duration::from_mins(2),
386 shutdown_timeout: Duration::from_secs(30),
387 session_idle_timeout: Duration::from_mins(20),
388 sse_keep_alive: Duration::from_secs(15),
389 on_reload_ready: None,
390 extra_router: None,
391 public_url: None,
392 log_request_headers: false,
393 compression_enabled: false,
394 compression_min_size: 1024,
395 max_concurrent_requests: None,
396 admin_enabled: false,
397 admin_role: "admin".to_owned(),
398 #[cfg(feature = "metrics")]
399 metrics_enabled: false,
400 #[cfg(feature = "metrics")]
401 metrics_bind: "127.0.0.1:9090".into(),
402 }
403 }
404
405 #[must_use]
415 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
416 self.auth = Some(auth);
417 self
418 }
419
420 #[must_use]
424 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
425 self.bind_addr = addr.into();
426 self
427 }
428
429 #[must_use]
432 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
433 self.rbac = Some(rbac);
434 self
435 }
436
437 #[must_use]
441 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
442 self.tls_cert_path = Some(cert_path.into());
443 self.tls_key_path = Some(key_path.into());
444 self
445 }
446
447 #[must_use]
451 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
452 self.public_url = Some(url.into());
453 self
454 }
455
456 #[must_use]
460 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
461 where
462 I: IntoIterator<Item = S>,
463 S: Into<String>,
464 {
465 self.allowed_origins = origins.into_iter().map(Into::into).collect();
466 self
467 }
468
469 #[must_use]
473 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
474 self.extra_router = Some(router);
475 self
476 }
477
478 #[must_use]
481 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
482 self.readiness_check = Some(check);
483 self
484 }
485
486 #[must_use]
489 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
490 self.max_request_body = bytes;
491 self
492 }
493
494 #[must_use]
496 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
497 self.request_timeout = timeout;
498 self
499 }
500
501 #[must_use]
503 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
504 self.shutdown_timeout = timeout;
505 self
506 }
507
508 #[must_use]
510 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
511 self.session_idle_timeout = timeout;
512 self
513 }
514
515 #[must_use]
517 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
518 self.sse_keep_alive = interval;
519 self
520 }
521
522 #[must_use]
526 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
527 self.max_concurrent_requests = Some(limit);
528 self
529 }
530
531 #[must_use]
534 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
535 self.tool_rate_limit = Some(per_minute);
536 self
537 }
538
539 #[must_use]
543 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
544 where
545 F: FnOnce(ReloadHandle) + Send + 'static,
546 {
547 self.on_reload_ready = Some(Box::new(callback));
548 self
549 }
550
551 #[must_use]
555 pub fn enable_compression(mut self, min_size: u16) -> Self {
556 self.compression_enabled = true;
557 self.compression_min_size = min_size;
558 self
559 }
560
561 #[must_use]
566 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
567 self.admin_enabled = true;
568 self.admin_role = role.into();
569 self
570 }
571
572 #[must_use]
575 pub fn enable_request_header_logging(mut self) -> Self {
576 self.log_request_headers = true;
577 self
578 }
579
580 #[cfg(feature = "metrics")]
583 #[must_use]
584 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
585 self.metrics_enabled = true;
586 self.metrics_bind = bind.into();
587 self
588 }
589
590 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
623 self.check()?;
624 Ok(Validated(self))
625 }
626
627 fn check(&self) -> Result<(), McpxError> {
631 if self.admin_enabled {
635 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
636 if !auth_enabled {
637 return Err(McpxError::Config(
638 "admin_enabled=true requires auth to be configured and enabled".into(),
639 ));
640 }
641 }
642
643 match (&self.tls_cert_path, &self.tls_key_path) {
645 (Some(_), None) => {
646 return Err(McpxError::Config(
647 "tls_cert_path is set but tls_key_path is missing".into(),
648 ));
649 }
650 (None, Some(_)) => {
651 return Err(McpxError::Config(
652 "tls_key_path is set but tls_cert_path is missing".into(),
653 ));
654 }
655 _ => {}
656 }
657
658 if self.bind_addr.parse::<SocketAddr>().is_err() {
660 return Err(McpxError::Config(format!(
661 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
662 self.bind_addr
663 )));
664 }
665
666 if let Some(ref url) = self.public_url
668 && !(url.starts_with("http://") || url.starts_with("https://"))
669 {
670 return Err(McpxError::Config(format!(
671 "public_url {url:?} must start with http:// or https://"
672 )));
673 }
674
675 for origin in &self.allowed_origins {
677 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
678 return Err(McpxError::Config(format!(
679 "allowed_origins entry {origin:?} must start with http:// or https://"
680 )));
681 }
682 }
683
684 if self.max_request_body == 0 {
686 return Err(McpxError::Config(
687 "max_request_body must be greater than zero".into(),
688 ));
689 }
690
691 #[cfg(feature = "oauth")]
693 if let Some(auth_cfg) = &self.auth
694 && let Some(oauth_cfg) = &auth_cfg.oauth
695 {
696 oauth_cfg.validate()?;
697 }
698
699 Ok(())
700 }
701}
702
703#[allow(
709 missing_debug_implementations,
710 reason = "contains Arc<AuthState> with non-Debug fields"
711)]
712pub struct ReloadHandle {
713 auth: Option<Arc<AuthState>>,
714 rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
715 crl_set: Option<Arc<CrlSet>>,
716}
717
718impl ReloadHandle {
719 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
721 if let Some(ref auth) = self.auth {
722 auth.reload_keys(keys);
723 }
724 }
725
726 pub fn reload_rbac(&self, policy: RbacPolicy) {
728 if let Some(ref rbac) = self.rbac {
729 rbac.store(Arc::new(policy));
730 tracing::info!("RBAC policy reloaded");
731 }
732 }
733
734 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
740 let Some(ref crl_set) = self.crl_set else {
741 return Err(McpxError::Config(
742 "CRL refresh requested but mTLS CRL support is not configured".into(),
743 ));
744 };
745
746 crl_set.force_refresh().await
747 }
748}
749
750#[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
767struct AppRunParams {
771 tls_paths: Option<(PathBuf, PathBuf)>,
773 mtls_config: Option<MtlsConfig>,
775 shutdown_timeout: Duration,
777 auth_state: Option<Arc<AuthState>>,
779 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
781 on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
783 ct: CancellationToken,
787 scheme: &'static str,
789 name: String,
791}
792
793#[allow(
803 clippy::cognitive_complexity,
804 reason = "router assembly is intrinsically sequential; splitting harms readability"
805)]
806#[allow(
807 deprecated,
808 reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
809)]
810fn build_app_router<H, F>(
811 mut config: McpServerConfig,
812 handler_factory: F,
813) -> anyhow::Result<(axum::Router, AppRunParams)>
814where
815 H: ServerHandler + 'static,
816 F: Fn() -> H + Send + Sync + Clone + 'static,
817{
818 let ct = CancellationToken::new();
819
820 let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
821 tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
822
823 let mcp_service = StreamableHttpService::new(
824 move || Ok(handler_factory()),
825 {
826 let mut mgr = LocalSessionManager::default();
827 mgr.session_config.keep_alive = Some(config.session_idle_timeout);
828 mgr.into()
829 },
830 StreamableHttpServerConfig::default()
831 .with_allowed_hosts(allowed_hosts)
832 .with_sse_keep_alive(Some(config.sse_keep_alive))
833 .with_cancellation_token(ct.child_token()),
834 );
835
836 let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
838
839 let auth_state: Option<Arc<AuthState>> = match config.auth {
843 Some(ref auth_config) if auth_config.enabled => {
844 let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
845 let pre_auth_limiter = auth_config
846 .rate_limit
847 .as_ref()
848 .map(crate::auth::build_pre_auth_limiter);
849
850 #[cfg(feature = "oauth")]
851 let jwks_cache = auth_config
852 .oauth
853 .as_ref()
854 .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
855 .transpose()
856 .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
857
858 Some(Arc::new(AuthState {
859 api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
860 rate_limiter,
861 pre_auth_limiter,
862 #[cfg(feature = "oauth")]
863 jwks_cache,
864 seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
865 counters: crate::auth::AuthCounters::default(),
866 }))
867 }
868 _ => None,
869 };
870
871 let rbac_swap = Arc::new(ArcSwap::new(
874 config
875 .rbac
876 .clone()
877 .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
878 ));
879
880 if config.admin_enabled {
883 let Some(ref auth_state_ref) = auth_state else {
884 return Err(anyhow::anyhow!(
885 "admin_enabled=true requires auth to be configured and enabled"
886 ));
887 };
888 let admin_state = crate::admin::AdminState {
889 started_at: std::time::Instant::now(),
890 name: config.name.clone(),
891 version: config.version.clone(),
892 auth: Some(Arc::clone(auth_state_ref)),
893 rbac: Arc::clone(&rbac_swap),
894 };
895 let admin_cfg = crate::admin::AdminConfig {
896 role: config.admin_role.clone(),
897 };
898 mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
899 tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
900 }
901
902 {
935 let tool_limiter: Option<Arc<ToolRateLimiter>> =
936 config.tool_rate_limit.map(build_tool_rate_limiter);
937
938 if rbac_swap.load().is_enabled() {
939 tracing::info!("RBAC enforcement enabled on /mcp");
940 }
941 if let Some(limit) = config.tool_rate_limit {
942 tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
943 }
944
945 let rbac_for_mw = Arc::clone(&rbac_swap);
946 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
947 let p = rbac_for_mw.load_full();
948 let tl = tool_limiter.clone();
949 rbac_middleware(p, tl, req, next)
950 }));
951 }
952
953 if let Some(ref auth_config) = config.auth
955 && auth_config.enabled
956 {
957 let Some(ref state) = auth_state else {
958 return Err(anyhow::anyhow!("auth state missing despite enabled config"));
959 };
960
961 let methods: Vec<&str> = [
962 auth_config.mtls.is_some().then_some("mTLS"),
963 (!auth_config.api_keys.is_empty()).then_some("bearer"),
964 #[cfg(feature = "oauth")]
965 auth_config.oauth.is_some().then_some("oauth-jwt"),
966 ]
967 .into_iter()
968 .flatten()
969 .collect();
970
971 tracing::info!(
972 methods = %methods.join(", "),
973 api_keys = auth_config.api_keys.len(),
974 "auth enabled on /mcp"
975 );
976
977 let state_for_mw = Arc::clone(state);
978 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
979 let s = Arc::clone(&state_for_mw);
980 auth_middleware(s, req, next)
981 }));
982 }
983
984 mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
987 axum::http::StatusCode::REQUEST_TIMEOUT,
988 config.request_timeout,
989 ));
990
991 mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
995 config.max_request_body,
996 ));
997
998 let mut effective_origins = config.allowed_origins.clone();
1005 if effective_origins.is_empty()
1006 && let Some(ref url) = config.public_url
1007 {
1008 if let Some(scheme_end) = url.find("://") {
1011 let after_scheme = &url[scheme_end + 3..];
1012 let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1013 let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
1014 tracing::info!(
1015 %origin,
1016 "auto-derived allowed origin from public_url"
1017 );
1018 effective_origins.push(origin);
1019 }
1020 }
1021 let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1022 let cors_origins = Arc::clone(&allowed_origins);
1023 let log_request_headers = config.log_request_headers;
1024
1025 let readyz_route = if let Some(check) = config.readiness_check.take() {
1026 axum::routing::get(move || readyz(Arc::clone(&check)))
1027 } else {
1028 axum::routing::get(healthz)
1029 };
1030
1031 #[allow(unused_mut)] let mut router = axum::Router::new()
1033 .route("/healthz", axum::routing::get(healthz))
1034 .route("/readyz", readyz_route)
1035 .route(
1036 "/version",
1037 axum::routing::get({
1038 let payload_bytes: Arc<[u8]> =
1043 serialize_version_payload(&config.name, &config.version);
1044 move || {
1045 let p = Arc::clone(&payload_bytes);
1046 async move {
1047 (
1048 [(axum::http::header::CONTENT_TYPE, "application/json")],
1049 p.to_vec(),
1050 )
1051 }
1052 }
1053 }),
1054 )
1055 .merge(mcp_router);
1056
1057 if let Some(extra) = config.extra_router.take() {
1059 router = router.merge(extra);
1060 }
1061
1062 let server_url = if let Some(ref url) = config.public_url {
1069 url.trim_end_matches('/').to_owned()
1070 } else {
1071 let prm_scheme = if config.tls_cert_path.is_some() {
1072 "https"
1073 } else {
1074 "http"
1075 };
1076 format!("{prm_scheme}://{}", config.bind_addr)
1077 };
1078 let resource_url = format!("{server_url}/mcp");
1079
1080 #[cfg(feature = "oauth")]
1081 let prm_metadata = if let Some(ref auth_config) = config.auth
1082 && let Some(ref oauth_config) = auth_config.oauth
1083 {
1084 crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1085 } else {
1086 serde_json::json!({ "resource": resource_url })
1087 };
1088 #[cfg(not(feature = "oauth"))]
1089 let prm_metadata = serde_json::json!({ "resource": resource_url });
1090
1091 router = router.route(
1092 "/.well-known/oauth-protected-resource",
1093 axum::routing::get(move || {
1094 let m = prm_metadata.clone();
1095 async move { axum::Json(m) }
1096 }),
1097 );
1098
1099 #[cfg(feature = "oauth")]
1104 if let Some(ref auth_config) = config.auth
1105 && let Some(ref oauth_config) = auth_config.oauth
1106 && oauth_config.proxy.is_some()
1107 {
1108 router = install_oauth_proxy_routes(router, &server_url, oauth_config)?;
1109 }
1110
1111 let is_tls = config.tls_cert_path.is_some();
1114 router = router.layer(axum::middleware::from_fn(move |req, next| {
1115 security_headers_middleware(is_tls, req, next)
1116 }));
1117
1118 if !cors_origins.is_empty() {
1122 let cors = tower_http::cors::CorsLayer::new()
1123 .allow_origin(
1124 cors_origins
1125 .iter()
1126 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1127 .collect::<Vec<_>>(),
1128 )
1129 .allow_methods([
1130 axum::http::Method::GET,
1131 axum::http::Method::POST,
1132 axum::http::Method::OPTIONS,
1133 ])
1134 .allow_headers([
1135 axum::http::header::CONTENT_TYPE,
1136 axum::http::header::AUTHORIZATION,
1137 ]);
1138 router = router.layer(cors);
1139 }
1140
1141 if config.compression_enabled {
1145 use tower_http::compression::Predicate as _;
1146 let predicate = tower_http::compression::DefaultPredicate::new().and(
1147 tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1148 );
1149 router = router.layer(
1150 tower_http::compression::CompressionLayer::new()
1151 .gzip(true)
1152 .br(true)
1153 .compress_when(predicate),
1154 );
1155 tracing::info!(
1156 min_size = config.compression_min_size,
1157 "response compression enabled (gzip, br)"
1158 );
1159 }
1160
1161 if let Some(max) = config.max_concurrent_requests {
1164 let overload_handler = tower::ServiceBuilder::new()
1165 .layer(axum::error_handling::HandleErrorLayer::new(
1166 |_err: tower::BoxError| async {
1167 (
1168 axum::http::StatusCode::SERVICE_UNAVAILABLE,
1169 axum::Json(serde_json::json!({
1170 "error": "overloaded",
1171 "error_description": "server is at capacity, retry later"
1172 })),
1173 )
1174 },
1175 ))
1176 .layer(tower::load_shed::LoadShedLayer::new())
1177 .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1178 router = router.layer(overload_handler);
1179 tracing::info!(max, "global concurrency limit enabled");
1180 }
1181
1182 router = router.fallback(|| async {
1186 (
1187 axum::http::StatusCode::NOT_FOUND,
1188 axum::Json(serde_json::json!({
1189 "error": "not_found",
1190 "error_description": "The requested endpoint does not exist"
1191 })),
1192 )
1193 });
1194
1195 #[cfg(feature = "metrics")]
1197 if config.metrics_enabled {
1198 let metrics = Arc::new(
1199 crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1200 );
1201 let m = Arc::clone(&metrics);
1202 router = router.layer(axum::middleware::from_fn(
1203 move |req: Request<Body>, next: Next| {
1204 let m = Arc::clone(&m);
1205 metrics_middleware(m, req, next)
1206 },
1207 ));
1208 let metrics_bind = config.metrics_bind.clone();
1209 tokio::spawn(async move {
1210 if let Err(e) = crate::metrics::serve_metrics(metrics_bind, metrics).await {
1211 tracing::error!("metrics listener failed: {e}");
1212 }
1213 });
1214 }
1215
1216 router = router.layer(axum::middleware::from_fn(move |req, next| {
1227 let origins = Arc::clone(&allowed_origins);
1228 origin_check_middleware(origins, log_request_headers, req, next)
1229 }));
1230
1231 let scheme = if config.tls_cert_path.is_some() {
1232 "https"
1233 } else {
1234 "http"
1235 };
1236
1237 let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1238 (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1239 _ => None,
1240 };
1241 let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1242
1243 Ok((
1244 router,
1245 AppRunParams {
1246 tls_paths,
1247 mtls_config,
1248 shutdown_timeout: config.shutdown_timeout,
1249 auth_state,
1250 rbac_swap,
1251 on_reload_ready: config.on_reload_ready.take(),
1252 ct,
1253 scheme,
1254 name: config.name.clone(),
1255 },
1256 ))
1257}
1258
1259pub async fn serve<H, F>(
1276 config: Validated<McpServerConfig>,
1277 handler_factory: F,
1278) -> Result<(), McpxError>
1279where
1280 H: ServerHandler + 'static,
1281 F: Fn() -> H + Send + Sync + Clone + 'static,
1282{
1283 let config = config.into_inner();
1284 #[allow(
1285 deprecated,
1286 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1287 )]
1288 let bind_addr = config.bind_addr.clone();
1289 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1290
1291 let listener = TcpListener::bind(&bind_addr)
1292 .await
1293 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1294 log_listening(¶ms.name, params.scheme, &bind_addr);
1295
1296 run_server(
1297 router,
1298 listener,
1299 params.tls_paths,
1300 params.mtls_config,
1301 params.shutdown_timeout,
1302 params.auth_state,
1303 params.rbac_swap,
1304 params.on_reload_ready,
1305 params.ct,
1306 )
1307 .await
1308 .map_err(anyhow_to_startup)
1309}
1310
1311pub async fn serve_with_listener<H, F>(
1341 listener: TcpListener,
1342 config: Validated<McpServerConfig>,
1343 handler_factory: F,
1344 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1345 shutdown: Option<CancellationToken>,
1346) -> Result<(), McpxError>
1347where
1348 H: ServerHandler + 'static,
1349 F: Fn() -> H + Send + Sync + Clone + 'static,
1350{
1351 let config = config.into_inner();
1352 let local_addr = listener
1353 .local_addr()
1354 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1355 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1356
1357 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1358
1359 if let Some(external) = shutdown {
1363 let internal = params.ct.clone();
1364 tokio::spawn(async move {
1365 external.cancelled().await;
1366 internal.cancel();
1367 });
1368 }
1369
1370 if let Some(tx) = ready_tx {
1374 let _ = tx.send(local_addr);
1376 }
1377
1378 run_server(
1379 router,
1380 listener,
1381 params.tls_paths,
1382 params.mtls_config,
1383 params.shutdown_timeout,
1384 params.auth_state,
1385 params.rbac_swap,
1386 params.on_reload_ready,
1387 params.ct,
1388 )
1389 .await
1390 .map_err(anyhow_to_startup)
1391}
1392
1393#[allow(
1396 clippy::cognitive_complexity,
1397 reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1398)]
1399fn log_listening(name: &str, scheme: &str, addr: &str) {
1400 tracing::info!("{name} listening on {addr}");
1401 tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
1402 tracing::info!(" Health check: {scheme}://{addr}/healthz");
1403 tracing::info!(" Readiness: {scheme}://{addr}/readyz");
1404}
1405
1406#[allow(
1429 clippy::too_many_arguments,
1430 clippy::cognitive_complexity,
1431 reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1432)]
1433async fn run_server(
1434 router: axum::Router,
1435 listener: TcpListener,
1436 tls_paths: Option<(PathBuf, PathBuf)>,
1437 mtls_config: Option<MtlsConfig>,
1438 shutdown_timeout: Duration,
1439 auth_state: Option<Arc<AuthState>>,
1440 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1441 mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1442 ct: CancellationToken,
1443) -> anyhow::Result<()> {
1444 let shutdown_trigger = CancellationToken::new();
1448 {
1449 let trigger = shutdown_trigger.clone();
1450 let parent = ct.clone();
1451 tokio::spawn(async move {
1452 tokio::select! {
1453 () = shutdown_signal() => {}
1454 () = parent.cancelled() => {}
1455 }
1456 trigger.cancel();
1457 });
1458 }
1459
1460 let graceful = {
1461 let trigger = shutdown_trigger.clone();
1462 let ct = ct.clone();
1463 async move {
1464 trigger.cancelled().await;
1465 tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1466 ct.cancel();
1467 }
1468 };
1469
1470 let force_exit_timer = {
1471 let trigger = shutdown_trigger.clone();
1472 async move {
1473 trigger.cancelled().await;
1474 tokio::time::sleep(shutdown_timeout).await;
1475 }
1476 };
1477
1478 if let Some((cert_path, key_path)) = tls_paths {
1479 let crl_set = if let Some(mtls) = mtls_config.as_ref()
1480 && mtls.crl_enabled
1481 {
1482 let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1483 let (crl_set, discover_rx) =
1484 mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1485 .await
1486 .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1487 tokio::spawn(mtls_revocation::run_crl_refresher(
1488 Arc::clone(&crl_set),
1489 discover_rx,
1490 ct.clone(),
1491 ));
1492 Some(crl_set)
1493 } else {
1494 None
1495 };
1496
1497 if let Some(cb) = on_reload_ready.take() {
1498 cb(ReloadHandle {
1499 auth: auth_state.clone(),
1500 rbac: Some(Arc::clone(&rbac_swap)),
1501 crl_set: crl_set.clone(),
1502 });
1503 }
1504
1505 let tls_listener = TlsListener::new(
1506 listener,
1507 &cert_path,
1508 &key_path,
1509 mtls_config.as_ref(),
1510 crl_set,
1511 )?;
1512 let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1513 tokio::select! {
1514 result = axum::serve(tls_listener, make_svc)
1515 .with_graceful_shutdown(graceful) => { result?; }
1516 () = force_exit_timer => {
1517 tracing::warn!("shutdown timeout exceeded, forcing exit");
1518 }
1519 }
1520 } else {
1521 if let Some(cb) = on_reload_ready.take() {
1522 cb(ReloadHandle {
1523 auth: auth_state,
1524 rbac: Some(rbac_swap),
1525 crl_set: None,
1526 });
1527 }
1528
1529 let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1530 tokio::select! {
1531 result = axum::serve(listener, make_svc)
1532 .with_graceful_shutdown(graceful) => { result?; }
1533 () = force_exit_timer => {
1534 tracing::warn!("shutdown timeout exceeded, forcing exit");
1535 }
1536 }
1537 }
1538
1539 Ok(())
1540}
1541
#[cfg(feature = "oauth")]
/// Mounts the OAuth 2.1 proxy endpoints on `router` when a proxy is configured.
///
/// Always installs `/.well-known/oauth-authorization-server`, `/authorize`,
/// `/token` and `/register`. `/introspect` and `/revoke` are added only when
/// `expose_admin_endpoints` is set and the matching upstream URL is present.
/// Without `oauth_config.proxy` the router is returned untouched.
///
/// # Errors
///
/// Propagates the failure to build the shared upstream HTTP client.
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
) -> Result<axum::Router, McpxError> {
    let Some(proxy) = oauth_config.proxy.as_ref() else {
        return Ok(router);
    };

    // Admin endpoints need both the opt-in flag and a configured upstream URL.
    let introspect_enabled = proxy.expose_admin_endpoints && proxy.introspection_url.is_some();
    let revoke_enabled = proxy.expose_admin_endpoints && proxy.revocation_url.is_some();

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let metadata = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let mut router = router
        .route(
            "/.well-known/oauth-authorization-server",
            axum::routing::get(move || {
                let body = metadata.clone();
                async move { axum::Json(body) }
            }),
        )
        .route("/authorize", {
            let proxy = proxy.clone();
            axum::routing::get(
                move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                    let p = proxy.clone();
                    async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
                },
            )
        })
        .route("/token", {
            let proxy = proxy.clone();
            let http = http.clone();
            axum::routing::post(move |body: String| {
                let p = proxy.clone();
                let h = http.clone();
                async move { crate::oauth::handle_token(&h, &p, &body).await }
            })
        })
        .route("/register", {
            let proxy = proxy.clone();
            axum::routing::post(
                move |axum::Json(body): axum::Json<serde_json::Value>| async move {
                    axum::Json(crate::oauth::handle_register(&proxy, &body))
                },
            )
        });

    if introspect_enabled {
        let proxy = proxy.clone();
        let http = http.clone();
        router = router.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let p = proxy.clone();
                let h = http.clone();
                async move { crate::oauth::handle_introspect(&h, &p, &body).await }
            }),
        );
    }

    if revoke_enabled {
        let proxy = proxy.clone();
        router = router.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let p = proxy.clone();
                let h = http.clone();
                async move { crate::oauth::handle_revoke(&h, &p, &body).await }
            }),
        );
    }

    tracing::info!(
        introspect = introspect_enabled,
        revoke = revoke_enabled,
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1641
1642fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1647 let mut hosts = vec![
1648 "localhost".to_owned(),
1649 "127.0.0.1".to_owned(),
1650 "::1".to_owned(),
1651 ];
1652
1653 if let Some(url) = public_url
1654 && let Ok(uri) = url.parse::<axum::http::Uri>()
1655 && let Some(authority) = uri.authority()
1656 {
1657 let host = authority.host().to_owned();
1658 if !hosts.iter().any(|h| h == &host) {
1659 hosts.push(host);
1660 }
1661
1662 let authority = authority.as_str().to_owned();
1663 if !hosts.iter().any(|h| h == &authority) {
1664 hosts.push(authority);
1665 }
1666 }
1667
1668 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1669 && let Some(authority) = uri.authority()
1670 {
1671 let host = authority.host().to_owned();
1672 if !hosts.iter().any(|h| h == &host) {
1673 hosts.push(host);
1674 }
1675
1676 let authority = authority.as_str().to_owned();
1677 if !hosts.iter().any(|h| h == &authority) {
1678 hosts.push(authority);
1679 }
1680 }
1681
1682 hosts
1683}
1684
1685impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1698 for TlsConnInfo
1699{
1700 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1701 let addr = *target.remote_addr();
1702 let identity = target.io().identity().cloned();
1703 TlsConnInfo::new(addr, identity)
1704 }
1705}
1706
/// TCP listener that completes a TLS handshake (optionally verifying client
/// certificates for mTLS) before handing each connection to `axum::serve`.
struct TlsListener {
    /// Listener producing the raw TCP connections to upgrade.
    inner: TcpListener,
    /// Acceptor configured with the server certificate and client-auth policy.
    acceptor: tokio_rustls::TlsAcceptor,
    /// Role passed to `extract_mtls_identity` for client certificates;
    /// `"viewer"` when mTLS is not configured.
    mtls_default_role: String,
}
1719
1720impl TlsListener {
1721 fn new(
1722 inner: TcpListener,
1723 cert_path: &Path,
1724 key_path: &Path,
1725 mtls_config: Option<&MtlsConfig>,
1726 crl_set: Option<Arc<CrlSet>>,
1727 ) -> anyhow::Result<Self> {
1728 rustls::crypto::ring::default_provider()
1730 .install_default()
1731 .ok();
1732
1733 let certs = load_certs(cert_path)?;
1734 let key = load_key(key_path)?;
1735
1736 let mtls_default_role;
1737
1738 let tls_config = if let Some(mtls) = mtls_config {
1739 mtls_default_role = mtls.default_role.clone();
1740 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1741 {
1742 let Some(crl_set) = crl_set else {
1743 return Err(anyhow::anyhow!(
1744 "mTLS CRL verifier requested but CRL state was not initialized"
1745 ));
1746 };
1747 Arc::new(DynamicClientCertVerifier::new(crl_set))
1748 } else {
1749 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1750 if mtls.required {
1751 rustls::server::WebPkiClientVerifier::builder(root_store)
1752 .build()
1753 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1754 } else {
1755 rustls::server::WebPkiClientVerifier::builder(root_store)
1756 .allow_unauthenticated()
1757 .build()
1758 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1759 }
1760 };
1761
1762 tracing::info!(
1763 ca = %mtls.ca_cert_path.display(),
1764 required = mtls.required,
1765 crl_enabled = mtls.crl_enabled,
1766 "mTLS client auth configured"
1767 );
1768
1769 rustls::ServerConfig::builder_with_protocol_versions(&[
1770 &rustls::version::TLS12,
1771 &rustls::version::TLS13,
1772 ])
1773 .with_client_cert_verifier(verifier)
1774 .with_single_cert(certs, key)?
1775 } else {
1776 mtls_default_role = "viewer".to_owned();
1777 rustls::ServerConfig::builder_with_protocol_versions(&[
1778 &rustls::version::TLS12,
1779 &rustls::version::TLS13,
1780 ])
1781 .with_no_client_auth()
1782 .with_single_cert(certs, key)?
1783 };
1784
1785 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1786 tracing::info!(
1787 "TLS enabled (cert: {}, key: {})",
1788 cert_path.display(),
1789 key_path.display()
1790 );
1791 Ok(Self {
1792 inner,
1793 acceptor,
1794 mtls_default_role,
1795 })
1796 }
1797
1798 fn extract_handshake_identity(
1802 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1803 default_role: &str,
1804 addr: SocketAddr,
1805 ) -> Option<AuthIdentity> {
1806 let (_, server_conn) = tls_stream.get_ref();
1807 let cert_der = server_conn.peer_certificates()?.first()?;
1808 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1809 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1810 Some(id)
1811 }
1812}
1813
/// A server-side TLS stream paired with the identity derived from the client
/// certificate presented during the handshake.
pub(crate) struct AuthenticatedTlsStream {
    /// The established TLS session over the accepted TCP connection.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity produced by `TlsListener::extract_handshake_identity`; `None`
    /// when no usable client certificate was presented.
    identity: Option<AuthIdentity>,
}
1829
1830impl AuthenticatedTlsStream {
1831 #[must_use]
1833 pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
1834 self.identity.as_ref()
1835 }
1836}
1837
1838impl std::fmt::Debug for AuthenticatedTlsStream {
1839 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1840 f.debug_struct("AuthenticatedTlsStream")
1841 .field("identity", &self.identity.as_ref().map(|id| &id.name))
1842 .finish_non_exhaustive()
1843 }
1844}
1845
1846impl tokio::io::AsyncRead for AuthenticatedTlsStream {
1847 fn poll_read(
1848 mut self: Pin<&mut Self>,
1849 cx: &mut std::task::Context<'_>,
1850 buf: &mut tokio::io::ReadBuf<'_>,
1851 ) -> std::task::Poll<std::io::Result<()>> {
1852 Pin::new(&mut self.inner).poll_read(cx, buf)
1853 }
1854}
1855
1856impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
1857 fn poll_write(
1858 mut self: Pin<&mut Self>,
1859 cx: &mut std::task::Context<'_>,
1860 buf: &[u8],
1861 ) -> std::task::Poll<std::io::Result<usize>> {
1862 Pin::new(&mut self.inner).poll_write(cx, buf)
1863 }
1864
1865 fn poll_flush(
1866 mut self: Pin<&mut Self>,
1867 cx: &mut std::task::Context<'_>,
1868 ) -> std::task::Poll<std::io::Result<()>> {
1869 Pin::new(&mut self.inner).poll_flush(cx)
1870 }
1871
1872 fn poll_shutdown(
1873 mut self: Pin<&mut Self>,
1874 cx: &mut std::task::Context<'_>,
1875 ) -> std::task::Poll<std::io::Result<()>> {
1876 Pin::new(&mut self.inner).poll_shutdown(cx)
1877 }
1878
1879 fn poll_write_vectored(
1880 mut self: Pin<&mut Self>,
1881 cx: &mut std::task::Context<'_>,
1882 bufs: &[std::io::IoSlice<'_>],
1883 ) -> std::task::Poll<std::io::Result<usize>> {
1884 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
1885 }
1886
1887 fn is_write_vectored(&self) -> bool {
1888 self.inner.is_write_vectored()
1889 }
1890}
1891
1892impl axum::serve::Listener for TlsListener {
1893 type Io = AuthenticatedTlsStream;
1894 type Addr = SocketAddr;
1895
1896 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
1897 loop {
1898 let (stream, addr) = match self.inner.accept().await {
1899 Ok(pair) => pair,
1900 Err(e) => {
1901 tracing::debug!("TCP accept error: {e}");
1902 continue;
1903 }
1904 };
1905 let tls_stream = match self.acceptor.accept(stream).await {
1906 Ok(s) => s,
1907 Err(e) => {
1908 tracing::debug!("TLS handshake failed from {addr}: {e}");
1909 continue;
1910 }
1911 };
1912 let identity =
1913 Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
1914 let wrapped = AuthenticatedTlsStream {
1915 inner: tls_stream,
1916 identity,
1917 };
1918 return (wrapped, addr);
1919 }
1920 }
1921
1922 fn local_addr(&self) -> std::io::Result<Self::Addr> {
1923 self.inner.local_addr()
1924 }
1925}
1926
1927fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
1928 use rustls::pki_types::pem::PemObject;
1929 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
1930 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
1931 .collect::<Result<_, _>>()
1932 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
1933 anyhow::ensure!(
1934 !certs.is_empty(),
1935 "no certificates found in {}",
1936 path.display()
1937 );
1938 Ok(certs)
1939}
1940
1941fn load_client_auth_roots(
1942 path: &Path,
1943) -> anyhow::Result<(
1944 Vec<rustls::pki_types::CertificateDer<'static>>,
1945 Arc<RootCertStore>,
1946)> {
1947 let ca_certs = load_certs(path)?;
1948 let mut root_store = RootCertStore::empty();
1949 for cert in &ca_certs {
1950 root_store
1951 .add(cert.clone())
1952 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
1953 }
1954
1955 Ok((ca_certs, Arc::new(root_store)))
1956}
1957
1958fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
1959 use rustls::pki_types::pem::PemObject;
1960 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
1961 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
1962}
1963
1964#[allow(clippy::unused_async)]
1965async fn healthz() -> impl IntoResponse {
1966 axum::Json(serde_json::json!({
1967 "status": "ok",
1968 }))
1969}
1970
1971fn version_payload(name: &str, version: &str) -> serde_json::Value {
1977 serde_json::json!({
1978 "name": name,
1979 "version": version,
1980 "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or("unknown"),
1981 "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or("unknown"),
1982 "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or("unknown"),
1983 "mcpx_version": env!("CARGO_PKG_VERSION"),
1984 })
1985}
1986
1987fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
1997 let value = version_payload(name, version);
1998 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
1999}
2000
2001async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2002 let status = check().await;
2003 let ready = status
2004 .get("ready")
2005 .and_then(serde_json::Value::as_bool)
2006 .unwrap_or(false);
2007 let code = if ready {
2008 axum::http::StatusCode::OK
2009 } else {
2010 axum::http::StatusCode::SERVICE_UNAVAILABLE
2011 };
2012 (code, axum::Json(status))
2013}
2014
2015async fn shutdown_signal() {
2019 let ctrl_c = tokio::signal::ctrl_c();
2020
2021 #[cfg(unix)]
2022 {
2023 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2024 Ok(mut term) => {
2025 tokio::select! {
2026 _ = ctrl_c => {}
2027 _ = term.recv() => {}
2028 }
2029 }
2030 Err(e) => {
2031 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2032 ctrl_c.await.ok();
2033 }
2034 }
2035 }
2036
2037 #[cfg(not(unix))]
2038 {
2039 ctrl_c.await.ok();
2040 }
2041}
2042
#[cfg(feature = "metrics")]
/// Records request count and latency for every HTTP request passing through.
///
/// Label values are restricted to bounded sets so that client input cannot
/// create unbounded Prometheus time series:
/// - the path label is the matched route template, or `"unmatched"` for
///   requests that hit no route;
/// - HTTP methods outside the standard set collapse into `"OTHER"`.
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    const STANDARD_METHODS: [&str; 9] = [
        "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH",
    ];

    // Raw paths and extension methods are client-controlled; using them as
    // label values verbatim lets any client mint new series without bound.
    let raw_method = req.method().as_str();
    let method = if STANDARD_METHODS.contains(&raw_method) {
        raw_method
    } else {
        "OTHER"
    }
    .to_owned();
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or_else(|| "unmatched".to_owned(), |matched| matched.as_str().to_owned());
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2074
2075async fn security_headers_middleware(
2083 is_tls: bool,
2084 req: Request<Body>,
2085 next: Next,
2086) -> axum::response::Response {
2087 use axum::http::{HeaderName, HeaderValue, header};
2088
2089 let mut resp = next.run(req).await;
2090 let headers = resp.headers_mut();
2091
2092 headers.remove(header::SERVER);
2094 headers.remove(HeaderName::from_static("x-powered-by"));
2095
2096 headers.insert(
2097 header::X_CONTENT_TYPE_OPTIONS,
2098 HeaderValue::from_static("nosniff"),
2099 );
2100 headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("deny"));
2101 headers.insert(
2102 header::CACHE_CONTROL,
2103 HeaderValue::from_static("no-store, max-age=0"),
2104 );
2105 headers.insert(
2106 header::REFERRER_POLICY,
2107 HeaderValue::from_static("no-referrer"),
2108 );
2109 headers.insert(
2110 HeaderName::from_static("cross-origin-opener-policy"),
2111 HeaderValue::from_static("same-origin"),
2112 );
2113 headers.insert(
2114 HeaderName::from_static("cross-origin-resource-policy"),
2115 HeaderValue::from_static("same-origin"),
2116 );
2117 headers.insert(
2118 HeaderName::from_static("cross-origin-embedder-policy"),
2119 HeaderValue::from_static("require-corp"),
2120 );
2121 headers.insert(
2122 HeaderName::from_static("permissions-policy"),
2123 HeaderValue::from_static("accelerometer=(), camera=(), geolocation=(), microphone=()"),
2124 );
2125 headers.insert(
2126 HeaderName::from_static("x-permitted-cross-domain-policies"),
2127 HeaderValue::from_static("none"),
2128 );
2129 headers.insert(
2130 HeaderName::from_static("content-security-policy"),
2131 HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
2132 );
2133 headers.insert(
2134 HeaderName::from_static("x-dns-prefetch-control"),
2135 HeaderValue::from_static("off"),
2136 );
2137
2138 if is_tls {
2139 headers.insert(
2140 header::STRICT_TRANSPORT_SECURITY,
2141 HeaderValue::from_static("max-age=63072000; includeSubDomains"),
2142 );
2143 }
2144
2145 resp
2146}
2147
2148async fn origin_check_middleware(
2152 allowed: Arc<[String]>,
2153 log_request_headers: bool,
2154 req: Request<Body>,
2155 next: Next,
2156) -> axum::response::Response {
2157 let method = req.method().clone();
2158 let path = req.uri().path().to_owned();
2159
2160 log_incoming_request(&method, &path, req.headers(), log_request_headers);
2161
2162 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2163 let origin_str = origin.to_str().unwrap_or("");
2164 if !allowed.iter().any(|a| a == origin_str) {
2165 tracing::warn!(
2166 origin = origin_str,
2167 %method,
2168 %path,
2169 allowed = ?&*allowed,
2170 "rejected request: Origin not allowed"
2171 );
2172 return (
2173 axum::http::StatusCode::FORBIDDEN,
2174 "Forbidden: Origin not allowed",
2175 )
2176 .into_response();
2177 }
2178 }
2179 next.run(req).await
2180}
2181
2182fn log_incoming_request(
2185 method: &axum::http::Method,
2186 path: &str,
2187 headers: &axum::http::HeaderMap,
2188 log_request_headers: bool,
2189) {
2190 if log_request_headers {
2191 tracing::debug!(
2192 %method,
2193 %path,
2194 headers = %format_request_headers_for_log(headers),
2195 "incoming request"
2196 );
2197 } else {
2198 tracing::debug!(%method, %path, "incoming request");
2199 }
2200}
2201
2202fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2203 headers
2204 .iter()
2205 .map(|(k, v)| {
2206 let name = k.as_str();
2207 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2208 format!("{name}: [REDACTED]")
2209 } else {
2210 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2211 }
2212 })
2213 .collect::<Vec<_>>()
2214 .join(", ")
2215}
2216
2217#[allow(clippy::cognitive_complexity)]
2241pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2242where
2243 H: ServerHandler + 'static,
2244{
2245 use rmcp::ServiceExt as _;
2246
2247 tracing::info!("stdio transport: serving on stdin/stdout");
2248 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2249
2250 let transport = rmcp::transport::io::stdio();
2251
2252 let service = handler
2253 .serve(transport)
2254 .await
2255 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2256
2257 if let Err(e) = service.waiting().await {
2258 tracing::warn!(error = %e, "stdio session ended with error");
2259 }
2260 tracing::info!("stdio session ended");
2261 Ok(())
2262}
2263
#[cfg(test)]
mod tests {
    // Test-only lint relaxations; `deprecated` is allowed because these tests
    // read and write the deprecated `pub` config fields directly.
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        deprecated,
        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
    )]
    use std::sync::Arc;

    use axum::{
        body::Body,
        http::{Request, StatusCode, header},
        response::IntoResponse,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt as _;

    use super::*;

    // --- McpServerConfig construction and validation ---

    #[test]
    // Pins every default set by `McpServerConfig::new`.
    fn server_config_new_defaults() {
        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
        assert_eq!(cfg.name, "test-server");
        assert_eq!(cfg.version, "1.0.0");
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert!(cfg.auth.is_none());
        assert!(cfg.rbac.is_none());
        assert!(cfg.allowed_origins.is_empty());
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.readiness_check.is_none());
        assert_eq!(cfg.max_request_body, 1024 * 1024);
        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
        assert!(!cfg.log_request_headers);
    }

    #[test]
    fn validate_consumes_and_proves() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        // A default config validates, and the validated wrapper exposes the fields.
        let validated = cfg.validate().expect("valid config");
        assert_eq!(validated.name, "test-server");
        // `into_inner` hands back the underlying config unchanged.
        let raw = validated.into_inner();
        assert_eq!(raw.name, "test-server");

        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        // A zero request-body cap is rejected by validation.
        bad.max_request_body = 0;
        assert!(bad.validate().is_err(), "zero body cap must fail validate");
    }

    // --- derive_allowed_hosts ---

    #[test]
    fn derive_allowed_hosts_includes_public_host() {
        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
        assert!(
            hosts.iter().any(|h| h == "mcp.example.com"),
            "public_url host must be allowed"
        );
    }

    #[test]
    fn derive_allowed_hosts_includes_bind_authority() {
        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1"),
            "bind host must be allowed"
        );
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1:8080"),
            "bind authority must be allowed"
        );
    }

    // --- health / readiness probes ---

    #[tokio::test]
    // healthz reports liveness only and must not leak identifying metadata.
    async fn healthz_returns_ok_json() {
        let resp = healthz().await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
        assert!(
            json.get("name").is_none(),
            "healthz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "healthz must not expose version"
        );
    }

    #[tokio::test]
    // readyz echoes the check's report verbatim and answers 200 when ready.
    async fn readyz_returns_ok_when_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["ready"], true);
        assert!(
            json.get("name").is_none(),
            "readyz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "readyz must not expose version"
        );
        assert_eq!(json["db"], "connected");
    }

    #[tokio::test]
    async fn readyz_returns_503_when_not_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn readyz_returns_503_when_ready_missing() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        // A report without a boolean `ready` key is treated as not ready.
    }

    // --- Origin allow-list middleware ---

    // Builds a one-route router guarded by `origin_check_middleware`.
    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
        let allowed: Arc<[String]> = Arc::from(origins);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let a = Arc::clone(&allowed);
                origin_check_middleware(a, log_request_headers, req, next)
            }))
    }

    #[tokio::test]
    async fn origin_allowed_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://localhost:3000")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_rejected_returns_403() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    // Non-browser clients send no Origin header and must not be blocked.
    async fn no_origin_header_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn empty_allowlist_rejects_any_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://anything.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn empty_allowlist_passes_without_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // --- request-header logging ---

    #[test]
    fn format_request_headers_redacts_sensitive_values() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
        headers.insert("cookie", "sid=abc".parse().unwrap());
        headers.insert("x-request-id", "req-123".parse().unwrap());

        let out = format_request_headers_for_log(&headers);
        assert!(out.contains("authorization: [REDACTED]"));
        assert!(out.contains("cookie: [REDACTED]"));
        assert!(out.contains("x-request-id: req-123"));
        assert!(!out.contains("secret-token"));
    }

    // --- security headers middleware ---

    // Builds a one-route router wrapped in `security_headers_middleware`.
    fn security_router(is_tls: bool) -> axum::Router {
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                security_headers_middleware(is_tls, req, next)
            }))
    }

    #[tokio::test]
    async fn security_headers_set_on_response() {
        let app = security_router(false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let h = resp.headers();
        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
        assert_eq!(
            h.get("cross-origin-resource-policy").unwrap(),
            "same-origin"
        );
        assert_eq!(
            h.get("cross-origin-embedder-policy").unwrap(),
            "require-corp"
        );
        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
        assert!(
            h.get("permissions-policy")
                .unwrap()
                .to_str()
                .unwrap()
                .contains("camera=()"),
            "permissions-policy must restrict browser features"
        );
        assert_eq!(
            h.get("content-security-policy").unwrap(),
            "default-src 'none'; frame-ancestors 'none'"
        );
        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
        assert!(h.get("strict-transport-security").is_none());
        // HSTS must be absent on plain-HTTP responses.
    }

    #[tokio::test]
    async fn hsts_set_when_tls_enabled() {
        let app = security_router(true);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let hsts = resp.headers().get("strict-transport-security").unwrap();
        assert!(
            hsts.to_str().unwrap().contains("max-age=63072000"),
            "HSTS must set 2-year max-age"
        );
    }

    // --- version payload ---

    #[test]
    // Build-metadata fields are always strings ("unknown" when not injected).
    fn version_payload_contains_expected_fields() {
        let v = version_payload("my-server", "1.2.3");
        assert_eq!(v["name"], "my-server");
        assert_eq!(v["version"], "1.2.3");
        assert!(v["build_git_sha"].is_string());
        assert!(v["build_timestamp"].is_string());
        assert!(v["rust_version"].is_string());
        assert!(v["mcpx_version"].is_string());
    }

    // --- tower layer composition smoke tests ---

    #[tokio::test]
    // Smoke test: load shedding plus a concurrency limit compose on a router
    // and still serve a single request.
    async fn concurrency_limit_layer_composes_and_serves() {
        let app = axum::Router::new()
            .route("/ok", axum::routing::get(|| async { "ok" }))
            .layer(
                tower::ServiceBuilder::new()
                    .layer(axum::error_handling::HandleErrorLayer::new(
                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
                    ))
                    .layer(tower::load_shed::LoadShedLayer::new())
                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
            );
        let resp = app
            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    // A 4 KiB body exceeds the 1 KiB size predicate, so gzip must be applied.
    async fn compression_layer_gzip_encodes_response() {
        use tower_http::compression::Predicate as _;

        let big_body = "a".repeat(4096);
        let app = axum::Router::new()
            .route(
                "/big",
                axum::routing::get(move || {
                    let body = big_body.clone();
                    async move { body }
                }),
            )
            .layer(
                tower_http::compression::CompressionLayer::new()
                    .gzip(true)
                    .br(true)
                    .compress_when(
                        tower_http::compression::DefaultPredicate::new()
                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
                    ),
            );

        let req = Request::builder()
            .uri("/big")
            .header(header::ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }
}