1use std::{
2 future::Future,
3 net::SocketAddr,
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13 ServerHandler,
14 transport::streamable_http_server::{
15 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16 },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23 auth::{
24 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25 build_rate_limiter, extract_mtls_identity,
26 },
27 error::McpxError,
28 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32#[allow(
36 clippy::needless_pass_by_value,
37 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40 McpxError::Startup(format!("{e:#}"))
41}
42
43#[allow(
49 clippy::needless_pass_by_value,
50 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53 McpxError::Startup(format!("{op}: {e}"))
54}
55
/// Async readiness probe served on `/readyz` when installed through
/// [`McpServerConfig::with_readiness_check`].
///
/// Each call returns a boxed `Send` future that resolves to a JSON value. The
/// value is handed to the `readyz` handler, which is not visible here.
/// Presumably it becomes the response body; confirm in `readyz`.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
/// Per-header settings for the router-wide `security_headers_middleware` layer.
///
/// Every field defaults to `None` (via `Default`). What `None` means for a
/// given header is decided by `security_headers_middleware`, which is not
/// visible here. It could mean "use the middleware's built-in value" or "omit
/// the header"; check before relying on either. Values are checked by
/// `validate_security_headers` during [`McpServerConfig::validate`].
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// Value for the `X-Content-Type-Options` response header.
    pub x_content_type_options: Option<String>,
    /// Value for the `X-Frame-Options` response header.
    pub x_frame_options: Option<String>,
    /// Value for the `Cache-Control` response header.
    pub cache_control: Option<String>,
    /// Value for the `Referrer-Policy` response header.
    pub referrer_policy: Option<String>,
    /// Value for the `Cross-Origin-Opener-Policy` response header.
    pub cross_origin_opener_policy: Option<String>,
    /// Value for the `Cross-Origin-Resource-Policy` response header.
    pub cross_origin_resource_policy: Option<String>,
    /// Value for the `Cross-Origin-Embedder-Policy` response header.
    pub cross_origin_embedder_policy: Option<String>,
    /// Value for the `Permissions-Policy` response header.
    pub permissions_policy: Option<String>,
    /// Value for the `X-Permitted-Cross-Domain-Policies` response header.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// Value for the `Content-Security-Policy` response header.
    pub content_security_policy: Option<String>,
    /// Value for the `X-DNS-Prefetch-Control` response header.
    pub x_dns_prefetch_control: Option<String>,
    /// Value for the `Strict-Transport-Security` response header.
    // NOTE(review): the middleware also receives an `is_tls` flag, presumably
    // so HSTS is only emitted over TLS — confirm in `security_headers_middleware`.
    pub strict_transport_security: Option<String>,
}
116
/// Configuration for the MCP HTTP server.
///
/// Build it with [`McpServerConfig::new`] plus the `with_*` / `enable_*`
/// builder methods. Then call [`McpServerConfig::validate`] to obtain the
/// [`Validated`] value accepted by `serve` and `serve_with_listener`. Direct
/// field access is deprecated in favour of those builders.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Listen address, e.g. `127.0.0.1:8080`; `validate` requires it to parse
    /// as a `SocketAddr`.
    ///
    /// Also used to derive the allowed Host values and, when `public_url` is
    /// unset, the advertised resource URL.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name; reported by `/version` and the admin endpoints and used in
    /// the start-up log.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version; reported by `/version` and the admin endpoints.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// Path to the TLS certificate.
    ///
    /// When set, the server is served over TLS and advertises `https`. It must
    /// be set together with `tls_key_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// Path to the TLS private key; must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings.
    ///
    /// Auth middleware is only installed on `/mcp` when this is `Some` and its
    /// `enabled` flag is true.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy enforced on `/mcp`; `None` falls back to
    /// `RbacPolicy::disabled()`.
    ///
    /// The policy can be swapped at runtime via `ReloadHandle::reload_rbac`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Browser origins accepted by the CORS and origin-check layers.
    ///
    /// Each entry must start with `http://` or `https://`. When empty and
    /// `public_url` is set, a single origin is derived from `public_url`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Tool-call rate limit passed to `build_tool_rate_limiter`.
    ///
    /// The start-up log describes it as calls per minute per IP; `None`
    /// disables it.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Custom `/readyz` probe; when `None`, `/readyz` uses the `/healthz`
    /// handler.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum `/mcp` request body size in bytes (default 1 MiB); must be
    /// non-zero.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Per-request timeout on `/mcp`; expired requests get `408 Request
    /// Timeout` (default 2 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Grace period after a shutdown is triggered (default 30 seconds).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Keep-alive passed to the MCP session manager for idle sessions
    /// (default 20 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval for Streamable HTTP streams (default
    /// 15 seconds).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// Callback invoked once with a [`ReloadHandle`] during server start-up.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Extra routes merged into the top-level router.
    ///
    /// They sit outside the `/mcp`-only auth/RBAC layers but inside the
    /// router-wide layers.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL; must start with `http://` or `https://`.
    ///
    /// Feeds the allowed hosts, the OAuth protected-resource metadata and the
    /// derived allowed origin.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Passed to the origin-check middleware to enable request-header logging
    /// (default `false`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression (default `false`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Size threshold in bytes for compression (default 1024).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global concurrency cap.
    ///
    /// Requests above the cap are shed with `503`. Must be non-zero when set.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Mounts the `/admin/*` endpoints; requires auth to be configured and
    /// enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// Role handed to the admin router's `AdminConfig` (default `"admin"`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Installs the metrics middleware and starts the separate metrics
    /// listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Bind address of the metrics listener (default `127.0.0.1:9090`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Security response header settings, checked by
    /// `validate_security_headers`.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
}
322
/// A value that has passed validation.
///
/// The tuple field is private, so outside this module the wrapper can only be
/// obtained from [`McpServerConfig::validate`]. Code that accepts a
/// `Validated<McpServerConfig>` can therefore rely on the checks having run.
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Prints `Validated { .. }` without touching the wrapped value.
    ///
    /// This works for any `T` and keeps configuration contents out of logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Validated");
        builder.finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Unwraps the validated value, discarding the validation marker.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
408
409#[allow(
410 deprecated,
411 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
412)]
413impl McpServerConfig {
414 #[must_use]
422 pub fn new(
423 bind_addr: impl Into<String>,
424 name: impl Into<String>,
425 version: impl Into<String>,
426 ) -> Self {
427 Self {
428 bind_addr: bind_addr.into(),
429 name: name.into(),
430 version: version.into(),
431 tls_cert_path: None,
432 tls_key_path: None,
433 auth: None,
434 rbac: None,
435 allowed_origins: Vec::new(),
436 tool_rate_limit: None,
437 readiness_check: None,
438 max_request_body: 1024 * 1024,
439 request_timeout: Duration::from_mins(2),
440 shutdown_timeout: Duration::from_secs(30),
441 session_idle_timeout: Duration::from_mins(20),
442 sse_keep_alive: Duration::from_secs(15),
443 on_reload_ready: None,
444 extra_router: None,
445 public_url: None,
446 log_request_headers: false,
447 compression_enabled: false,
448 compression_min_size: 1024,
449 max_concurrent_requests: None,
450 admin_enabled: false,
451 admin_role: "admin".to_owned(),
452 #[cfg(feature = "metrics")]
453 metrics_enabled: false,
454 #[cfg(feature = "metrics")]
455 metrics_bind: "127.0.0.1:9090".into(),
456 security_headers: SecurityHeadersConfig::default(),
457 }
458 }
459
460 #[must_use]
470 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
471 self.auth = Some(auth);
472 self
473 }
474
475 #[must_use]
480 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
481 self.security_headers = headers;
482 self
483 }
484
485 #[must_use]
489 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
490 self.bind_addr = addr.into();
491 self
492 }
493
494 #[must_use]
497 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
498 self.rbac = Some(rbac);
499 self
500 }
501
502 #[must_use]
506 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
507 self.tls_cert_path = Some(cert_path.into());
508 self.tls_key_path = Some(key_path.into());
509 self
510 }
511
512 #[must_use]
516 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
517 self.public_url = Some(url.into());
518 self
519 }
520
521 #[must_use]
525 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
526 where
527 I: IntoIterator<Item = S>,
528 S: Into<String>,
529 {
530 self.allowed_origins = origins.into_iter().map(Into::into).collect();
531 self
532 }
533
534 #[must_use]
538 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
539 self.extra_router = Some(router);
540 self
541 }
542
543 #[must_use]
546 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
547 self.readiness_check = Some(check);
548 self
549 }
550
551 #[must_use]
554 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
555 self.max_request_body = bytes;
556 self
557 }
558
559 #[must_use]
561 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
562 self.request_timeout = timeout;
563 self
564 }
565
566 #[must_use]
568 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
569 self.shutdown_timeout = timeout;
570 self
571 }
572
573 #[must_use]
575 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
576 self.session_idle_timeout = timeout;
577 self
578 }
579
580 #[must_use]
582 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
583 self.sse_keep_alive = interval;
584 self
585 }
586
587 #[must_use]
591 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
592 self.max_concurrent_requests = Some(limit);
593 self
594 }
595
596 #[must_use]
599 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
600 self.tool_rate_limit = Some(per_minute);
601 self
602 }
603
604 #[must_use]
608 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
609 where
610 F: FnOnce(ReloadHandle) + Send + 'static,
611 {
612 self.on_reload_ready = Some(Box::new(callback));
613 self
614 }
615
616 #[must_use]
620 pub fn enable_compression(mut self, min_size: u16) -> Self {
621 self.compression_enabled = true;
622 self.compression_min_size = min_size;
623 self
624 }
625
626 #[must_use]
631 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
632 self.admin_enabled = true;
633 self.admin_role = role.into();
634 self
635 }
636
637 #[must_use]
640 pub fn enable_request_header_logging(mut self) -> Self {
641 self.log_request_headers = true;
642 self
643 }
644
645 #[cfg(feature = "metrics")]
648 #[must_use]
649 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
650 self.metrics_enabled = true;
651 self.metrics_bind = bind.into();
652 self
653 }
654
655 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
688 self.check()?;
689 Ok(Validated(self))
690 }
691
692 fn check(&self) -> Result<(), McpxError> {
696 if self.admin_enabled {
700 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
701 if !auth_enabled {
702 return Err(McpxError::Config(
703 "admin_enabled=true requires auth to be configured and enabled".into(),
704 ));
705 }
706 }
707
708 match (&self.tls_cert_path, &self.tls_key_path) {
710 (Some(_), None) => {
711 return Err(McpxError::Config(
712 "tls_cert_path is set but tls_key_path is missing".into(),
713 ));
714 }
715 (None, Some(_)) => {
716 return Err(McpxError::Config(
717 "tls_key_path is set but tls_cert_path is missing".into(),
718 ));
719 }
720 _ => {}
721 }
722
723 if self.bind_addr.parse::<SocketAddr>().is_err() {
725 return Err(McpxError::Config(format!(
726 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
727 self.bind_addr
728 )));
729 }
730
731 if let Some(ref url) = self.public_url
733 && !(url.starts_with("http://") || url.starts_with("https://"))
734 {
735 return Err(McpxError::Config(format!(
736 "public_url {url:?} must start with http:// or https://"
737 )));
738 }
739
740 for origin in &self.allowed_origins {
742 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
743 return Err(McpxError::Config(format!(
744 "allowed_origins entry {origin:?} must start with http:// or https://"
745 )));
746 }
747 }
748
749 if self.max_request_body == 0 {
751 return Err(McpxError::Config(
752 "max_request_body must be greater than zero".into(),
753 ));
754 }
755
756 #[cfg(feature = "oauth")]
758 if let Some(auth_cfg) = &self.auth
759 && let Some(oauth_cfg) = &auth_cfg.oauth
760 {
761 oauth_cfg.validate()?;
762 }
763
764 validate_security_headers(&self.security_headers)?;
767
768 if let Some(0) = self.max_concurrent_requests {
772 return Err(McpxError::Config(
773 "max_concurrent_requests must be greater than zero when set".into(),
774 ));
775 }
776
777 if let Some(auth_cfg) = &self.auth
781 && let Some(rl) = &auth_cfg.rate_limit
782 && rl.max_tracked_keys == 0
783 {
784 return Err(McpxError::Config(
785 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
786 ));
787 }
788
789 Ok(())
790 }
791}
792
/// Handle for updating a running server's API keys, RBAC policy and CRLs in
/// place.
///
/// It is passed to the callback registered with
/// [`McpServerConfig::with_reload_callback`].
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    /// Shared auth state; `None` when authentication is not enabled.
    auth: Option<Arc<AuthState>>,
    /// Swappable RBAC policy shared with the RBAC middleware.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    /// CRL set used for mTLS revocation checks; `None` when mTLS CRL checking is not set up.
    crl_set: Option<Arc<CrlSet>>,
}
807
808impl ReloadHandle {
809 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
811 if let Some(ref auth) = self.auth {
812 auth.reload_keys(keys);
813 }
814 }
815
816 pub fn reload_rbac(&self, policy: RbacPolicy) {
818 if let Some(ref rbac) = self.rbac {
819 rbac.store(Arc::new(policy));
820 tracing::info!("RBAC policy reloaded");
821 }
822 }
823
824 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
830 let Some(ref crl_set) = self.crl_set else {
831 return Err(McpxError::Config(
832 "CRL refresh requested but mTLS CRL support is not configured".into(),
833 ));
834 };
835
836 crl_set.force_refresh().await
837 }
838}
839
/// Everything `run_server` needs besides the router itself.
///
/// Produced by `build_app_router` and consumed by `serve` /
/// `serve_with_listener`.
// NOTE(review): the `too_many_lines` / `cognitive_complexity` allow below (and
// its reason about middleware layer order) reads as if it was meant for
// `build_app_router`. These lints fire on functions, so on this struct it
// appears to have no effect — confirm and move it.
#[allow(
    clippy::too_many_lines,
    clippy::cognitive_complexity,
    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
)]
struct AppRunParams {
    /// `(cert_path, key_path)` when TLS is configured; `None` serves plain HTTP.
    tls_paths: Option<(PathBuf, PathBuf)>,
    /// mTLS settings copied from `auth.mtls`, if present.
    mtls_config: Option<MtlsConfig>,
    /// Graceful-shutdown grace period.
    shutdown_timeout: Duration,
    /// Shared auth state; `Some` only when auth is enabled.
    auth_state: Option<Arc<AuthState>>,
    /// Swappable RBAC policy shared with the RBAC middleware and reload handle.
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    /// Optional callback that receives the `ReloadHandle` at start-up.
    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Root cancellation token for the server and the tasks spawned while building it.
    ct: CancellationToken,
    /// `"https"` when a TLS certificate is configured, else `"http"`; used for logging.
    scheme: &'static str,
    /// Server name, used for logging.
    name: String,
}
886
887#[allow(
897 clippy::cognitive_complexity,
898 reason = "router assembly is intrinsically sequential; splitting harms readability"
899)]
900#[allow(
901 deprecated,
902 reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
903)]
904fn build_app_router<H, F>(
905 mut config: McpServerConfig,
906 handler_factory: F,
907) -> anyhow::Result<(axum::Router, AppRunParams)>
908where
909 H: ServerHandler + 'static,
910 F: Fn() -> H + Send + Sync + Clone + 'static,
911{
912 let ct = CancellationToken::new();
913
914 let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
915 tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
916
917 let mcp_service = StreamableHttpService::new(
918 move || Ok(handler_factory()),
919 {
920 let mut mgr = LocalSessionManager::default();
921 mgr.session_config.keep_alive = Some(config.session_idle_timeout);
922 mgr.into()
923 },
924 StreamableHttpServerConfig::default()
925 .with_allowed_hosts(allowed_hosts)
926 .with_sse_keep_alive(Some(config.sse_keep_alive))
927 .with_cancellation_token(ct.child_token()),
928 );
929
930 let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
932
933 let auth_state: Option<Arc<AuthState>> = match config.auth {
937 Some(ref auth_config) if auth_config.enabled => {
938 let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
939 let pre_auth_limiter = auth_config
940 .rate_limit
941 .as_ref()
942 .map(crate::auth::build_pre_auth_limiter);
943
944 #[cfg(feature = "oauth")]
945 let jwks_cache = auth_config
946 .oauth
947 .as_ref()
948 .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
949 .transpose()
950 .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
951
952 Some(Arc::new(AuthState {
953 api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
954 rate_limiter,
955 pre_auth_limiter,
956 #[cfg(feature = "oauth")]
957 jwks_cache,
958 seen_identities: crate::auth::SeenIdentitySet::new(),
959 counters: crate::auth::AuthCounters::default(),
960 }))
961 }
962 _ => None,
963 };
964
965 let rbac_swap = Arc::new(ArcSwap::new(
968 config
969 .rbac
970 .clone()
971 .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
972 ));
973
974 if config.admin_enabled {
977 let Some(ref auth_state_ref) = auth_state else {
978 return Err(anyhow::anyhow!(
979 "admin_enabled=true requires auth to be configured and enabled"
980 ));
981 };
982 let admin_state = crate::admin::AdminState {
983 started_at: std::time::Instant::now(),
984 name: config.name.clone(),
985 version: config.version.clone(),
986 auth: Some(Arc::clone(auth_state_ref)),
987 rbac: Arc::clone(&rbac_swap),
988 };
989 let admin_cfg = crate::admin::AdminConfig {
990 role: config.admin_role.clone(),
991 };
992 mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
993 tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
994 }
995
996 {
1029 let tool_limiter: Option<Arc<ToolRateLimiter>> =
1030 config.tool_rate_limit.map(build_tool_rate_limiter);
1031
1032 if rbac_swap.load().is_enabled() {
1033 tracing::info!("RBAC enforcement enabled on /mcp");
1034 }
1035 if let Some(limit) = config.tool_rate_limit {
1036 tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1037 }
1038
1039 let rbac_for_mw = Arc::clone(&rbac_swap);
1040 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1041 let p = rbac_for_mw.load_full();
1042 let tl = tool_limiter.clone();
1043 rbac_middleware(p, tl, req, next)
1044 }));
1045 }
1046
1047 if let Some(ref auth_config) = config.auth
1049 && auth_config.enabled
1050 {
1051 let Some(ref state) = auth_state else {
1052 return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1053 };
1054
1055 let methods: Vec<&str> = [
1056 auth_config.mtls.is_some().then_some("mTLS"),
1057 (!auth_config.api_keys.is_empty()).then_some("bearer"),
1058 #[cfg(feature = "oauth")]
1059 auth_config.oauth.is_some().then_some("oauth-jwt"),
1060 ]
1061 .into_iter()
1062 .flatten()
1063 .collect();
1064
1065 tracing::info!(
1066 methods = %methods.join(", "),
1067 api_keys = auth_config.api_keys.len(),
1068 "auth enabled on /mcp"
1069 );
1070
1071 let state_for_mw = Arc::clone(state);
1072 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1073 let s = Arc::clone(&state_for_mw);
1074 auth_middleware(s, req, next)
1075 }));
1076 }
1077
1078 mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1081 axum::http::StatusCode::REQUEST_TIMEOUT,
1082 config.request_timeout,
1083 ));
1084
1085 mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1089 config.max_request_body,
1090 ));
1091
1092 let mut effective_origins = config.allowed_origins.clone();
1099 if effective_origins.is_empty()
1100 && let Some(ref url) = config.public_url
1101 {
1102 if let Some(scheme_end) = url.find("://") {
1105 let after_scheme = &url[scheme_end + 3..];
1106 let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1107 let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
1108 tracing::info!(
1109 %origin,
1110 "auto-derived allowed origin from public_url"
1111 );
1112 effective_origins.push(origin);
1113 }
1114 }
1115 let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1116 let cors_origins = Arc::clone(&allowed_origins);
1117 let log_request_headers = config.log_request_headers;
1118
1119 let readyz_route = if let Some(check) = config.readiness_check.take() {
1120 axum::routing::get(move || readyz(Arc::clone(&check)))
1121 } else {
1122 axum::routing::get(healthz)
1123 };
1124
1125 #[allow(unused_mut)] let mut router = axum::Router::new()
1127 .route("/healthz", axum::routing::get(healthz))
1128 .route("/readyz", readyz_route)
1129 .route(
1130 "/version",
1131 axum::routing::get({
1132 let payload_bytes: Arc<[u8]> =
1137 serialize_version_payload(&config.name, &config.version);
1138 move || {
1139 let p = Arc::clone(&payload_bytes);
1140 async move {
1141 (
1142 [(axum::http::header::CONTENT_TYPE, "application/json")],
1143 p.to_vec(),
1144 )
1145 }
1146 }
1147 }),
1148 )
1149 .merge(mcp_router);
1150
1151 if let Some(extra) = config.extra_router.take() {
1153 router = router.merge(extra);
1154 }
1155
1156 let server_url = if let Some(ref url) = config.public_url {
1163 url.trim_end_matches('/').to_owned()
1164 } else {
1165 let prm_scheme = if config.tls_cert_path.is_some() {
1166 "https"
1167 } else {
1168 "http"
1169 };
1170 format!("{prm_scheme}://{}", config.bind_addr)
1171 };
1172 let resource_url = format!("{server_url}/mcp");
1173
1174 #[cfg(feature = "oauth")]
1175 let prm_metadata = if let Some(ref auth_config) = config.auth
1176 && let Some(ref oauth_config) = auth_config.oauth
1177 {
1178 crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1179 } else {
1180 serde_json::json!({ "resource": resource_url })
1181 };
1182 #[cfg(not(feature = "oauth"))]
1183 let prm_metadata = serde_json::json!({ "resource": resource_url });
1184
1185 router = router.route(
1186 "/.well-known/oauth-protected-resource",
1187 axum::routing::get(move || {
1188 let m = prm_metadata.clone();
1189 async move { axum::Json(m) }
1190 }),
1191 );
1192
1193 #[cfg(feature = "oauth")]
1198 if let Some(ref auth_config) = config.auth
1199 && let Some(ref oauth_config) = auth_config.oauth
1200 && oauth_config.proxy.is_some()
1201 {
1202 router =
1203 install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1204 }
1205
1206 let is_tls = config.tls_cert_path.is_some();
1209 let security_headers_cfg = Arc::new(config.security_headers.clone());
1210 router = router.layer(axum::middleware::from_fn(move |req, next| {
1211 let cfg = Arc::clone(&security_headers_cfg);
1212 security_headers_middleware(is_tls, cfg, req, next)
1213 }));
1214
1215 if !cors_origins.is_empty() {
1219 let cors = tower_http::cors::CorsLayer::new()
1220 .allow_origin(
1221 cors_origins
1222 .iter()
1223 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1224 .collect::<Vec<_>>(),
1225 )
1226 .allow_methods([
1227 axum::http::Method::GET,
1228 axum::http::Method::POST,
1229 axum::http::Method::OPTIONS,
1230 ])
1231 .allow_headers([
1232 axum::http::header::CONTENT_TYPE,
1233 axum::http::header::AUTHORIZATION,
1234 ]);
1235 router = router.layer(cors);
1236 }
1237
1238 if config.compression_enabled {
1242 use tower_http::compression::Predicate as _;
1243 let predicate = tower_http::compression::DefaultPredicate::new().and(
1244 tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1245 );
1246 router = router.layer(
1247 tower_http::compression::CompressionLayer::new()
1248 .gzip(true)
1249 .br(true)
1250 .compress_when(predicate),
1251 );
1252 tracing::info!(
1253 min_size = config.compression_min_size,
1254 "response compression enabled (gzip, br)"
1255 );
1256 }
1257
1258 if let Some(max) = config.max_concurrent_requests {
1261 let overload_handler = tower::ServiceBuilder::new()
1262 .layer(axum::error_handling::HandleErrorLayer::new(
1263 |_err: tower::BoxError| async {
1264 (
1265 axum::http::StatusCode::SERVICE_UNAVAILABLE,
1266 axum::Json(serde_json::json!({
1267 "error": "overloaded",
1268 "error_description": "server is at capacity, retry later"
1269 })),
1270 )
1271 },
1272 ))
1273 .layer(tower::load_shed::LoadShedLayer::new())
1274 .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1275 router = router.layer(overload_handler);
1276 tracing::info!(max, "global concurrency limit enabled");
1277 }
1278
1279 router = router.fallback(|| async {
1283 (
1284 axum::http::StatusCode::NOT_FOUND,
1285 axum::Json(serde_json::json!({
1286 "error": "not_found",
1287 "error_description": "The requested endpoint does not exist"
1288 })),
1289 )
1290 });
1291
1292 #[cfg(feature = "metrics")]
1294 if config.metrics_enabled {
1295 let metrics = Arc::new(
1296 crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1297 );
1298 let m = Arc::clone(&metrics);
1299 router = router.layer(axum::middleware::from_fn(
1300 move |req: Request<Body>, next: Next| {
1301 let m = Arc::clone(&m);
1302 metrics_middleware(m, req, next)
1303 },
1304 ));
1305 let metrics_bind = config.metrics_bind.clone();
1306 let metrics_shutdown = ct.clone();
1307 tokio::spawn(async move {
1308 if let Err(e) =
1309 crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1310 {
1311 tracing::error!("metrics listener failed: {e}");
1312 }
1313 });
1314 }
1315
1316 router = router.layer(axum::middleware::from_fn(move |req, next| {
1327 let origins = Arc::clone(&allowed_origins);
1328 origin_check_middleware(origins, log_request_headers, req, next)
1329 }));
1330
1331 let scheme = if config.tls_cert_path.is_some() {
1332 "https"
1333 } else {
1334 "http"
1335 };
1336
1337 let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1338 (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1339 _ => None,
1340 };
1341 let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1342
1343 Ok((
1344 router,
1345 AppRunParams {
1346 tls_paths,
1347 mtls_config,
1348 shutdown_timeout: config.shutdown_timeout,
1349 auth_state,
1350 rbac_swap,
1351 on_reload_ready: config.on_reload_ready.take(),
1352 ct,
1353 scheme,
1354 name: config.name.clone(),
1355 },
1356 ))
1357}
1358
1359pub async fn serve<H, F>(
1376 config: Validated<McpServerConfig>,
1377 handler_factory: F,
1378) -> Result<(), McpxError>
1379where
1380 H: ServerHandler + 'static,
1381 F: Fn() -> H + Send + Sync + Clone + 'static,
1382{
1383 let config = config.into_inner();
1384 #[allow(
1385 deprecated,
1386 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1387 )]
1388 let bind_addr = config.bind_addr.clone();
1389 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1390
1391 let listener = TcpListener::bind(&bind_addr)
1392 .await
1393 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1394 log_listening(¶ms.name, params.scheme, &bind_addr);
1395
1396 run_server(
1397 router,
1398 listener,
1399 params.tls_paths,
1400 params.mtls_config,
1401 params.shutdown_timeout,
1402 params.auth_state,
1403 params.rbac_swap,
1404 params.on_reload_ready,
1405 params.ct,
1406 )
1407 .await
1408 .map_err(anyhow_to_startup)
1409}
1410
1411pub async fn serve_with_listener<H, F>(
1441 listener: TcpListener,
1442 config: Validated<McpServerConfig>,
1443 handler_factory: F,
1444 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1445 shutdown: Option<CancellationToken>,
1446) -> Result<(), McpxError>
1447where
1448 H: ServerHandler + 'static,
1449 F: Fn() -> H + Send + Sync + Clone + 'static,
1450{
1451 let config = config.into_inner();
1452 let local_addr = listener
1453 .local_addr()
1454 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1455 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1456
1457 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1458
1459 if let Some(external) = shutdown {
1463 let internal = params.ct.clone();
1464 tokio::spawn(async move {
1465 external.cancelled().await;
1466 internal.cancel();
1467 });
1468 }
1469
1470 if let Some(tx) = ready_tx {
1474 let _ = tx.send(local_addr);
1476 }
1477
1478 run_server(
1479 router,
1480 listener,
1481 params.tls_paths,
1482 params.mtls_config,
1483 params.shutdown_timeout,
1484 params.auth_state,
1485 params.rbac_swap,
1486 params.on_reload_ready,
1487 params.ct,
1488 )
1489 .await
1490 .map_err(anyhow_to_startup)
1491}
1492
1493#[allow(
1496 clippy::cognitive_complexity,
1497 reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1498)]
1499fn log_listening(name: &str, scheme: &str, addr: &str) {
1500 tracing::info!("{name} listening on {addr}");
1501 tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
1502 tracing::info!(" Health check: {scheme}://{addr}/healthz");
1503 tracing::info!(" Readiness: {scheme}://{addr}/readyz");
1504}
1505
1506#[allow(
1529 clippy::too_many_arguments,
1530 clippy::cognitive_complexity,
1531 reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1532)]
1533async fn run_server(
1534 router: axum::Router,
1535 listener: TcpListener,
1536 tls_paths: Option<(PathBuf, PathBuf)>,
1537 mtls_config: Option<MtlsConfig>,
1538 shutdown_timeout: Duration,
1539 auth_state: Option<Arc<AuthState>>,
1540 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1541 mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1542 ct: CancellationToken,
1543) -> anyhow::Result<()> {
1544 let shutdown_trigger = CancellationToken::new();
1548 {
1549 let trigger = shutdown_trigger.clone();
1550 let parent = ct.clone();
1551 tokio::spawn(async move {
1552 tokio::select! {
1553 () = shutdown_signal() => {}
1554 () = parent.cancelled() => {}
1555 }
1556 trigger.cancel();
1557 });
1558 }
1559
1560 let graceful = {
1561 let trigger = shutdown_trigger.clone();
1562 let ct = ct.clone();
1563 async move {
1564 trigger.cancelled().await;
1565 tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1566 ct.cancel();
1567 }
1568 };
1569
1570 let force_exit_timer = {
1571 let trigger = shutdown_trigger.clone();
1572 async move {
1573 trigger.cancelled().await;
1574 tokio::time::sleep(shutdown_timeout).await;
1575 }
1576 };
1577
1578 if let Some((cert_path, key_path)) = tls_paths {
1579 let crl_set = if let Some(mtls) = mtls_config.as_ref()
1580 && mtls.crl_enabled
1581 {
1582 let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1583 let (crl_set, discover_rx) =
1584 mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1585 .await
1586 .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1587 tokio::spawn(mtls_revocation::run_crl_refresher(
1588 Arc::clone(&crl_set),
1589 discover_rx,
1590 ct.clone(),
1591 ));
1592 Some(crl_set)
1593 } else {
1594 None
1595 };
1596
1597 if let Some(cb) = on_reload_ready.take() {
1598 cb(ReloadHandle {
1599 auth: auth_state.clone(),
1600 rbac: Some(Arc::clone(&rbac_swap)),
1601 crl_set: crl_set.clone(),
1602 });
1603 }
1604
1605 let tls_listener = TlsListener::new(
1606 listener,
1607 &cert_path,
1608 &key_path,
1609 mtls_config.as_ref(),
1610 crl_set,
1611 )?;
1612 let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1613 tokio::select! {
1614 result = axum::serve(tls_listener, make_svc)
1615 .with_graceful_shutdown(graceful) => { result?; }
1616 () = force_exit_timer => {
1617 tracing::warn!("shutdown timeout exceeded, forcing exit");
1618 }
1619 }
1620 } else {
1621 if let Some(cb) = on_reload_ready.take() {
1622 cb(ReloadHandle {
1623 auth: auth_state,
1624 rbac: Some(rbac_swap),
1625 crl_set: None,
1626 });
1627 }
1628
1629 let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1630 tokio::select! {
1631 result = axum::serve(listener, make_svc)
1632 .with_graceful_shutdown(graceful) => { result?; }
1633 () = force_exit_timer => {
1634 tracing::warn!("shutdown timeout exceeded, forcing exit");
1635 }
1636 }
1637 }
1638
1639 Ok(())
1640}
1641
#[cfg(feature = "oauth")]
/// Mounts the OAuth 2.1 proxy endpoints on `router` when `oauth_config.proxy`
/// is configured. Otherwise `router` is returned unchanged.
///
/// Always mounted when the proxy is configured:
/// - `GET /.well-known/oauth-authorization-server`: metadata built from `server_url`
/// - `GET /authorize`: forwards the raw query string to `handle_authorize`
/// - `POST /token` and `POST /register`: wrapped in
///   `oauth_token_cache_headers_middleware`
///
/// `/introspect` and `/revoke` are added only when `expose_admin_endpoints`
/// is set and the matching upstream URL is configured (see
/// `build_oauth_admin_router`).
///
/// # Errors
///
/// Propagates failures from building the upstream HTTP client and from
/// `build_oauth_admin_router` (e.g. admin auth required but `auth_state` is
/// `None`).
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    // Admin routes exist only if exposure is requested AND at least one
    // upstream (introspection or revocation) is configured.
    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Warn only when unauthenticated admin routes will actually be mounted;
    // otherwise the warning would describe endpoints that do not exist.
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1742
#[cfg(feature = "oauth")]
/// Builds the router holding the OAuth proxy admin endpoints.
///
/// `POST /introspect` is mounted only when `proxy.introspection_url` is set,
/// and `POST /revoke` only when `proxy.revocation_url` is set. Both carry the
/// token cache-header middleware. When `require_auth_on_admin_endpoints` is
/// true, the whole router is additionally wrapped in `auth_middleware`.
///
/// # Errors
///
/// Returns `McpxError::Startup` when auth is required on the admin endpoints
/// but no `auth_state` was supplied.
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let proxy_cfg = proxy.clone();
        let client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let (p, h) = (proxy_cfg.clone(), client.clone());
                async move { crate::oauth::handle_introspect(&h, &p, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let proxy_cfg = proxy.clone();
        let client = http;
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let (p, h) = (proxy_cfg.clone(), client.clone());
                async move { crate::oauth::handle_revoke(&h, &p, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    let shared_state = Arc::clone(state);
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&shared_state), req, next)
    })))
}
1801
1802fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1807 let mut hosts = vec![
1808 "localhost".to_owned(),
1809 "127.0.0.1".to_owned(),
1810 "::1".to_owned(),
1811 ];
1812
1813 if let Some(url) = public_url
1814 && let Ok(uri) = url.parse::<axum::http::Uri>()
1815 && let Some(authority) = uri.authority()
1816 {
1817 let host = authority.host().to_owned();
1818 if !hosts.iter().any(|h| h == &host) {
1819 hosts.push(host);
1820 }
1821
1822 let authority = authority.as_str().to_owned();
1823 if !hosts.iter().any(|h| h == &authority) {
1824 hosts.push(authority);
1825 }
1826 }
1827
1828 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1829 && let Some(authority) = uri.authority()
1830 {
1831 let host = authority.host().to_owned();
1832 if !hosts.iter().any(|h| h == &host) {
1833 hosts.push(host);
1834 }
1835
1836 let authority = authority.as_str().to_owned();
1837 if !hosts.iter().any(|h| h == &authority) {
1838 hosts.push(authority);
1839 }
1840 }
1841
1842 hosts
1843}
1844
1845impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1858 for TlsConnInfo
1859{
1860 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1861 let addr = *target.remote_addr();
1862 let identity = target.io().identity().cloned();
1863 TlsConnInfo::new(addr, identity)
1864 }
1865}
1866
/// TCP listener that performs a rustls handshake on every accepted connection.
///
/// Used as the `axum::serve` listener when TLS is configured. The accept loop
/// lives in its `axum::serve::Listener` impl.
struct TlsListener {
    /// Underlying plain-TCP listener that yields raw connections.
    inner: TcpListener,
    /// rustls acceptor built from the server certificate/key and, when mTLS
    /// is configured, the client-certificate verifier.
    acceptor: tokio_rustls::TlsAcceptor,
    /// Role handed to `extract_mtls_identity` for client certificates. Set to
    /// `MtlsConfig::default_role` when mTLS is configured, otherwise `"viewer"`.
    mtls_default_role: String,
}
1879
1880impl TlsListener {
1881 fn new(
1882 inner: TcpListener,
1883 cert_path: &Path,
1884 key_path: &Path,
1885 mtls_config: Option<&MtlsConfig>,
1886 crl_set: Option<Arc<CrlSet>>,
1887 ) -> anyhow::Result<Self> {
1888 rustls::crypto::ring::default_provider()
1890 .install_default()
1891 .ok();
1892
1893 let certs = load_certs(cert_path)?;
1894 let key = load_key(key_path)?;
1895
1896 let mtls_default_role;
1897
1898 let tls_config = if let Some(mtls) = mtls_config {
1899 mtls_default_role = mtls.default_role.clone();
1900 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1901 {
1902 let Some(crl_set) = crl_set else {
1903 return Err(anyhow::anyhow!(
1904 "mTLS CRL verifier requested but CRL state was not initialized"
1905 ));
1906 };
1907 Arc::new(DynamicClientCertVerifier::new(crl_set))
1908 } else {
1909 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1910 if mtls.required {
1911 rustls::server::WebPkiClientVerifier::builder(root_store)
1912 .build()
1913 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1914 } else {
1915 rustls::server::WebPkiClientVerifier::builder(root_store)
1916 .allow_unauthenticated()
1917 .build()
1918 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1919 }
1920 };
1921
1922 tracing::info!(
1923 ca = %mtls.ca_cert_path.display(),
1924 required = mtls.required,
1925 crl_enabled = mtls.crl_enabled,
1926 "mTLS client auth configured"
1927 );
1928
1929 rustls::ServerConfig::builder_with_protocol_versions(&[
1930 &rustls::version::TLS12,
1931 &rustls::version::TLS13,
1932 ])
1933 .with_client_cert_verifier(verifier)
1934 .with_single_cert(certs, key)?
1935 } else {
1936 mtls_default_role = "viewer".to_owned();
1937 rustls::ServerConfig::builder_with_protocol_versions(&[
1938 &rustls::version::TLS12,
1939 &rustls::version::TLS13,
1940 ])
1941 .with_no_client_auth()
1942 .with_single_cert(certs, key)?
1943 };
1944
1945 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1946 tracing::info!(
1947 "TLS enabled (cert: {}, key: {})",
1948 cert_path.display(),
1949 key_path.display()
1950 );
1951 Ok(Self {
1952 inner,
1953 acceptor,
1954 mtls_default_role,
1955 })
1956 }
1957
1958 fn extract_handshake_identity(
1962 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1963 default_role: &str,
1964 addr: SocketAddr,
1965 ) -> Option<AuthIdentity> {
1966 let (_, server_conn) = tls_stream.get_ref();
1967 let cert_der = server_conn.peer_certificates()?.first()?;
1968 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1969 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1970 Some(id)
1971 }
1972}
1973
/// A server-side TLS stream paired with the identity extracted from the
/// client certificate presented during the handshake, if any.
///
/// The identity is captured once at accept time. The `TlsConnInfo` connect-info
/// impl reads it through `identity()` without touching the TLS session again.
pub(crate) struct AuthenticatedTlsStream {
    /// The completed rustls server stream; all I/O is delegated to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// mTLS identity derived from the peer's leaf certificate. `None` when no
    /// usable client certificate was presented.
    identity: Option<AuthIdentity>,
}
1989
1990impl AuthenticatedTlsStream {
1991 #[must_use]
1993 pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
1994 self.identity.as_ref()
1995 }
1996}
1997
1998impl std::fmt::Debug for AuthenticatedTlsStream {
1999 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2000 f.debug_struct("AuthenticatedTlsStream")
2001 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2002 .finish_non_exhaustive()
2003 }
2004}
2005
2006impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2007 fn poll_read(
2008 mut self: Pin<&mut Self>,
2009 cx: &mut std::task::Context<'_>,
2010 buf: &mut tokio::io::ReadBuf<'_>,
2011 ) -> std::task::Poll<std::io::Result<()>> {
2012 Pin::new(&mut self.inner).poll_read(cx, buf)
2013 }
2014}
2015
2016impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2017 fn poll_write(
2018 mut self: Pin<&mut Self>,
2019 cx: &mut std::task::Context<'_>,
2020 buf: &[u8],
2021 ) -> std::task::Poll<std::io::Result<usize>> {
2022 Pin::new(&mut self.inner).poll_write(cx, buf)
2023 }
2024
2025 fn poll_flush(
2026 mut self: Pin<&mut Self>,
2027 cx: &mut std::task::Context<'_>,
2028 ) -> std::task::Poll<std::io::Result<()>> {
2029 Pin::new(&mut self.inner).poll_flush(cx)
2030 }
2031
2032 fn poll_shutdown(
2033 mut self: Pin<&mut Self>,
2034 cx: &mut std::task::Context<'_>,
2035 ) -> std::task::Poll<std::io::Result<()>> {
2036 Pin::new(&mut self.inner).poll_shutdown(cx)
2037 }
2038
2039 fn poll_write_vectored(
2040 mut self: Pin<&mut Self>,
2041 cx: &mut std::task::Context<'_>,
2042 bufs: &[std::io::IoSlice<'_>],
2043 ) -> std::task::Poll<std::io::Result<usize>> {
2044 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2045 }
2046
2047 fn is_write_vectored(&self) -> bool {
2048 self.inner.is_write_vectored()
2049 }
2050}
2051
2052impl axum::serve::Listener for TlsListener {
2053 type Io = AuthenticatedTlsStream;
2054 type Addr = SocketAddr;
2055
2056 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2057 loop {
2058 let (stream, addr) = match self.inner.accept().await {
2059 Ok(pair) => pair,
2060 Err(e) => {
2061 tracing::debug!("TCP accept error: {e}");
2062 continue;
2063 }
2064 };
2065 let tls_stream = match self.acceptor.accept(stream).await {
2066 Ok(s) => s,
2067 Err(e) => {
2068 tracing::debug!("TLS handshake failed from {addr}: {e}");
2069 continue;
2070 }
2071 };
2072 let identity =
2073 Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
2074 let wrapped = AuthenticatedTlsStream {
2075 inner: tls_stream,
2076 identity,
2077 };
2078 return (wrapped, addr);
2079 }
2080 }
2081
2082 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2083 self.inner.local_addr()
2084 }
2085}
2086
2087fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2088 use rustls::pki_types::pem::PemObject;
2089 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2090 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2091 .collect::<Result<_, _>>()
2092 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2093 anyhow::ensure!(
2094 !certs.is_empty(),
2095 "no certificates found in {}",
2096 path.display()
2097 );
2098 Ok(certs)
2099}
2100
2101fn load_client_auth_roots(
2102 path: &Path,
2103) -> anyhow::Result<(
2104 Vec<rustls::pki_types::CertificateDer<'static>>,
2105 Arc<RootCertStore>,
2106)> {
2107 let ca_certs = load_certs(path)?;
2108 let mut root_store = RootCertStore::empty();
2109 for cert in &ca_certs {
2110 root_store
2111 .add(cert.clone())
2112 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2113 }
2114
2115 Ok((ca_certs, Arc::new(root_store)))
2116}
2117
2118fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2119 use rustls::pki_types::pem::PemObject;
2120 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2121 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2122}
2123
2124#[allow(
2125 clippy::unused_async,
2126 reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2127)]
2128async fn healthz() -> impl IntoResponse {
2129 axum::Json(serde_json::json!({
2130 "status": "ok",
2131 }))
2132}
2133
2134fn version_payload(name: &str, version: &str) -> serde_json::Value {
2141 serde_json::json!({
2142 "name": name,
2143 "version": version,
2144 "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2145 "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2146 "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2147 "mcpx_version": env!("CARGO_PKG_VERSION"),
2148 })
2149}
2150
2151fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2161 let value = version_payload(name, version);
2162 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2163}
2164
2165async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2166 let status = check().await;
2167 let ready = status
2168 .get("ready")
2169 .and_then(serde_json::Value::as_bool)
2170 .unwrap_or(false);
2171 let code = if ready {
2172 axum::http::StatusCode::OK
2173 } else {
2174 axum::http::StatusCode::SERVICE_UNAVAILABLE
2175 };
2176 (code, axum::Json(status))
2177}
2178
2179async fn shutdown_signal() {
2183 let ctrl_c = tokio::signal::ctrl_c();
2184
2185 #[cfg(unix)]
2186 {
2187 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2188 Ok(mut term) => {
2189 tokio::select! {
2190 _ = ctrl_c => {}
2191 _ = term.recv() => {}
2192 }
2193 }
2194 Err(e) => {
2195 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2196 ctrl_c.await.ok();
2197 }
2198 }
2199 }
2200
2201 #[cfg(not(unix))]
2202 {
2203 ctrl_c.await.ok();
2204 }
2205}
2206
#[cfg(feature = "metrics")]
/// Records a request counter (method, path, status) and a latency histogram
/// (method, path) for every HTTP request passing through this middleware.
///
/// The `path` label is the matched route template
/// (`axum::extract::MatchedPath`), not the raw request URI. Requests that
/// matched no route share the single label `"<unmatched>"`. Labelling by raw
/// URI let any client create a new label set per distinct path (e.g. scanners
/// probing random URLs). Each label set passed to `with_label_values` creates
/// a child metric that stays registered, so memory grew with attacker input.
///
/// NOTE(review): `MatchedPath` is only present when this middleware runs
/// after routing (e.g. installed with `Router::layer`). If it is ever moved
/// outside the router, every request will be reported as `"<unmatched>"`.
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    const UNMATCHED: &str = "<unmatched>";

    let method = req.method().to_string();
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or(UNMATCHED, axum::extract::MatchedPath::as_str)
        .to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2238
2239async fn security_headers_middleware(
2251 is_tls: bool,
2252 cfg: Arc<SecurityHeadersConfig>,
2253 req: Request<Body>,
2254 next: Next,
2255) -> axum::response::Response {
2256 use axum::http::{HeaderName, header};
2257
2258 let mut resp = next.run(req).await;
2259 let headers = resp.headers_mut();
2260
2261 headers.remove(header::SERVER);
2263 headers.remove(HeaderName::from_static("x-powered-by"));
2264
2265 apply_security_header(
2266 headers,
2267 header::X_CONTENT_TYPE_OPTIONS,
2268 cfg.x_content_type_options.as_deref(),
2269 "nosniff",
2270 );
2271 apply_security_header(
2272 headers,
2273 header::X_FRAME_OPTIONS,
2274 cfg.x_frame_options.as_deref(),
2275 "deny",
2276 );
2277 apply_security_header(
2278 headers,
2279 header::CACHE_CONTROL,
2280 cfg.cache_control.as_deref(),
2281 "no-store, max-age=0",
2282 );
2283 apply_security_header(
2284 headers,
2285 header::REFERRER_POLICY,
2286 cfg.referrer_policy.as_deref(),
2287 "no-referrer",
2288 );
2289 apply_security_header(
2290 headers,
2291 HeaderName::from_static("cross-origin-opener-policy"),
2292 cfg.cross_origin_opener_policy.as_deref(),
2293 "same-origin",
2294 );
2295 apply_security_header(
2296 headers,
2297 HeaderName::from_static("cross-origin-resource-policy"),
2298 cfg.cross_origin_resource_policy.as_deref(),
2299 "same-origin",
2300 );
2301 apply_security_header(
2302 headers,
2303 HeaderName::from_static("cross-origin-embedder-policy"),
2304 cfg.cross_origin_embedder_policy.as_deref(),
2305 "require-corp",
2306 );
2307 apply_security_header(
2308 headers,
2309 HeaderName::from_static("permissions-policy"),
2310 cfg.permissions_policy.as_deref(),
2311 "accelerometer=(), camera=(), geolocation=(), microphone=()",
2312 );
2313 apply_security_header(
2314 headers,
2315 HeaderName::from_static("x-permitted-cross-domain-policies"),
2316 cfg.x_permitted_cross_domain_policies.as_deref(),
2317 "none",
2318 );
2319 apply_security_header(
2320 headers,
2321 HeaderName::from_static("content-security-policy"),
2322 cfg.content_security_policy.as_deref(),
2323 "default-src 'none'; frame-ancestors 'none'",
2324 );
2325 apply_security_header(
2326 headers,
2327 HeaderName::from_static("x-dns-prefetch-control"),
2328 cfg.x_dns_prefetch_control.as_deref(),
2329 "off",
2330 );
2331
2332 if is_tls {
2333 apply_security_header(
2334 headers,
2335 header::STRICT_TRANSPORT_SECURITY,
2336 cfg.strict_transport_security.as_deref(),
2337 "max-age=63072000; includeSubDomains",
2338 );
2339 }
2340
2341 resp
2342}
2343
2344fn apply_security_header(
2355 headers: &mut axum::http::HeaderMap,
2356 name: axum::http::HeaderName,
2357 override_value: Option<&str>,
2358 default: &'static str,
2359) {
2360 use axum::http::HeaderValue;
2361
2362 match override_value {
2363 None => {
2364 headers.insert(name, HeaderValue::from_static(default));
2365 }
2366 Some("") => {
2367 }
2369 Some(v) => match HeaderValue::from_str(v) {
2370 Ok(hv) => {
2371 headers.insert(name, hv);
2372 }
2373 Err(err) => {
2374 tracing::error!(
2375 header = %name,
2376 error = %err,
2377 "invalid security header override reached middleware; using default"
2378 );
2379 headers.insert(name, HeaderValue::from_static(default));
2380 }
2381 },
2382 }
2383}
2384
2385fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2396 use axum::http::HeaderValue;
2397
2398 let fields: &[(&str, Option<&str>)] = &[
2399 (
2400 "x_content_type_options",
2401 cfg.x_content_type_options.as_deref(),
2402 ),
2403 ("x_frame_options", cfg.x_frame_options.as_deref()),
2404 ("cache_control", cfg.cache_control.as_deref()),
2405 ("referrer_policy", cfg.referrer_policy.as_deref()),
2406 (
2407 "cross_origin_opener_policy",
2408 cfg.cross_origin_opener_policy.as_deref(),
2409 ),
2410 (
2411 "cross_origin_resource_policy",
2412 cfg.cross_origin_resource_policy.as_deref(),
2413 ),
2414 (
2415 "cross_origin_embedder_policy",
2416 cfg.cross_origin_embedder_policy.as_deref(),
2417 ),
2418 ("permissions_policy", cfg.permissions_policy.as_deref()),
2419 (
2420 "x_permitted_cross_domain_policies",
2421 cfg.x_permitted_cross_domain_policies.as_deref(),
2422 ),
2423 (
2424 "content_security_policy",
2425 cfg.content_security_policy.as_deref(),
2426 ),
2427 (
2428 "x_dns_prefetch_control",
2429 cfg.x_dns_prefetch_control.as_deref(),
2430 ),
2431 (
2432 "strict_transport_security",
2433 cfg.strict_transport_security.as_deref(),
2434 ),
2435 ];
2436
2437 for (field, value) in fields {
2438 let Some(v) = value else { continue };
2439 if v.is_empty() {
2440 continue;
2441 }
2442 if let Err(err) = HeaderValue::from_str(v) {
2443 return Err(McpxError::Config(format!(
2444 "invalid security_headers.{field}: {err}"
2445 )));
2446 }
2447 }
2448
2449 if let Some(v) = cfg.strict_transport_security.as_deref()
2450 && !v.is_empty()
2451 && v.to_ascii_lowercase().contains("preload")
2452 {
2453 return Err(McpxError::Config(format!(
2454 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2455 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2456 )));
2457 }
2458
2459 Ok(())
2460}
2461
#[cfg(feature = "oauth")]
/// Adds anti-caching headers to OAuth proxy responses.
///
/// Sets `Pragma: no-cache` (replacing any existing value) and appends
/// `Authorization` to `Vary`, keeping any `Vary` values already present.
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;
    {
        let headers = response.headers_mut();
        headers.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
        headers.append(header::VARY, HeaderValue::from_static("Authorization"));
    }
    response
}
2489
2490async fn origin_check_middleware(
2494 allowed: Arc<[String]>,
2495 log_request_headers: bool,
2496 req: Request<Body>,
2497 next: Next,
2498) -> axum::response::Response {
2499 let method = req.method().clone();
2500 let path = req.uri().path().to_owned();
2501
2502 log_incoming_request(&method, &path, req.headers(), log_request_headers);
2503
2504 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2505 let origin_str = origin.to_str().unwrap_or("");
2506 if !allowed.iter().any(|a| a == origin_str) {
2507 tracing::warn!(
2508 origin = origin_str,
2509 %method,
2510 %path,
2511 allowed = ?&*allowed,
2512 "rejected request: Origin not allowed"
2513 );
2514 return (
2515 axum::http::StatusCode::FORBIDDEN,
2516 "Forbidden: Origin not allowed",
2517 )
2518 .into_response();
2519 }
2520 }
2521 next.run(req).await
2522}
2523
2524fn log_incoming_request(
2527 method: &axum::http::Method,
2528 path: &str,
2529 headers: &axum::http::HeaderMap,
2530 log_request_headers: bool,
2531) {
2532 if log_request_headers {
2533 tracing::debug!(
2534 %method,
2535 %path,
2536 headers = %format_request_headers_for_log(headers),
2537 "incoming request"
2538 );
2539 } else {
2540 tracing::debug!(%method, %path, "incoming request");
2541 }
2542}
2543
2544fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2545 headers
2546 .iter()
2547 .map(|(k, v)| {
2548 let name = k.as_str();
2549 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2550 format!("{name}: [REDACTED]")
2551 } else {
2552 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2553 }
2554 })
2555 .collect::<Vec<_>>()
2556 .join(", ")
2557}
2558
2559#[allow(
2583 clippy::cognitive_complexity,
2584 reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
2585)]
2586pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2587where
2588 H: ServerHandler + 'static,
2589{
2590 use rmcp::ServiceExt as _;
2591
2592 tracing::info!("stdio transport: serving on stdin/stdout");
2593 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2594
2595 let transport = rmcp::transport::io::stdio();
2596
2597 let service = handler
2598 .serve(transport)
2599 .await
2600 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2601
2602 if let Err(e) = service.waiting().await {
2603 tracing::warn!(error = %e, "stdio session ended with error");
2604 }
2605 tracing::info!("stdio session ended");
2606 Ok(())
2607}
2608
2609#[cfg(test)]
2610mod tests {
2611 #![allow(
2612 clippy::unwrap_used,
2613 clippy::expect_used,
2614 clippy::panic,
2615 clippy::indexing_slicing,
2616 clippy::unwrap_in_result,
2617 clippy::print_stdout,
2618 clippy::print_stderr,
2619 deprecated,
2620 reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2621 )]
2622 use std::{sync::Arc, time::Duration};
2623
2624 use axum::{
2625 body::Body,
2626 http::{Request, StatusCode, header},
2627 response::IntoResponse,
2628 };
2629 use http_body_util::BodyExt;
2630 use tower::ServiceExt as _;
2631
2632 use super::*;
2633
2634 #[test]
2637 fn server_config_new_defaults() {
2638 let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2639 assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2640 assert_eq!(cfg.name, "test-server");
2641 assert_eq!(cfg.version, "1.0.0");
2642 assert!(cfg.tls_cert_path.is_none());
2643 assert!(cfg.tls_key_path.is_none());
2644 assert!(cfg.auth.is_none());
2645 assert!(cfg.rbac.is_none());
2646 assert!(cfg.allowed_origins.is_empty());
2647 assert!(cfg.tool_rate_limit.is_none());
2648 assert!(cfg.readiness_check.is_none());
2649 assert_eq!(cfg.max_request_body, 1024 * 1024);
2650 assert_eq!(cfg.request_timeout, Duration::from_mins(2));
2651 assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
2652 assert!(!cfg.log_request_headers);
2653 }
2654
2655 #[test]
2656 fn validate_consumes_and_proves() {
2657 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2659 let validated = cfg.validate().expect("valid config");
2660 assert_eq!(validated.as_inner().name, "test-server");
2662 let raw = validated.into_inner();
2664 assert_eq!(raw.name, "test-server");
2665
2666 let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2668 bad.max_request_body = 0;
2669 assert!(bad.validate().is_err(), "zero body cap must fail validate");
2670 }
2671
2672 #[test]
2673 fn validate_rejects_zero_max_concurrent_requests() {
2674 let cfg =
2675 McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
2676 let err = cfg.validate().expect_err("zero concurrency cap must fail");
2677 assert!(
2678 format!("{err}").contains("max_concurrent_requests"),
2679 "error should mention max_concurrent_requests, got: {err}"
2680 );
2681 }
2682
2683 #[test]
2684 fn validate_rejects_zero_max_tracked_keys() {
2685 let rl = crate::auth::RateLimitConfig {
2688 max_attempts_per_minute: 30,
2689 pre_auth_max_per_minute: None,
2690 max_tracked_keys: 0,
2691 idle_eviction: Duration::from_secs(15 * 60),
2692 };
2693 let auth_cfg = AuthConfig {
2694 enabled: true,
2695 api_keys: Vec::new(),
2696 mtls: None,
2697 rate_limit: Some(rl),
2698 #[cfg(feature = "oauth")]
2699 oauth: None,
2700 };
2701 let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
2702 let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
2703 assert!(
2704 format!("{err}").contains("max_tracked_keys"),
2705 "error should mention max_tracked_keys, got: {err}"
2706 );
2707 }
2708
2709 #[test]
2710 fn derive_allowed_hosts_includes_public_host() {
2711 let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
2712 assert!(
2713 hosts.iter().any(|h| h == "mcp.example.com"),
2714 "public_url host must be allowed"
2715 );
2716 }
2717
2718 #[test]
2719 fn derive_allowed_hosts_includes_bind_authority() {
2720 let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
2721 assert!(
2722 hosts.iter().any(|h| h == "127.0.0.1"),
2723 "bind host must be allowed"
2724 );
2725 assert!(
2726 hosts.iter().any(|h| h == "127.0.0.1:8080"),
2727 "bind authority must be allowed"
2728 );
2729 }
2730
2731 #[tokio::test]
2734 async fn healthz_returns_ok_json() {
2735 let resp = healthz().await.into_response();
2736 assert_eq!(resp.status(), StatusCode::OK);
2737 let body = resp.into_body().collect().await.unwrap().to_bytes();
2738 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2739 assert_eq!(json["status"], "ok");
2740 assert!(
2741 json.get("name").is_none(),
2742 "healthz must not expose server name"
2743 );
2744 assert!(
2745 json.get("version").is_none(),
2746 "healthz must not expose version"
2747 );
2748 }
2749
2750 #[tokio::test]
2753 async fn readyz_returns_ok_when_ready() {
2754 let check: ReadinessCheck =
2755 Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
2756 let resp = readyz(check).await.into_response();
2757 assert_eq!(resp.status(), StatusCode::OK);
2758 let body = resp.into_body().collect().await.unwrap().to_bytes();
2759 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2760 assert_eq!(json["ready"], true);
2761 assert!(
2762 json.get("name").is_none(),
2763 "readyz must not expose server name"
2764 );
2765 assert!(
2766 json.get("version").is_none(),
2767 "readyz must not expose version"
2768 );
2769 assert_eq!(json["db"], "connected");
2770 }
2771
2772 #[tokio::test]
2773 async fn readyz_returns_503_when_not_ready() {
2774 let check: ReadinessCheck =
2775 Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
2776 let resp = readyz(check).await.into_response();
2777 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2778 }
2779
2780 #[tokio::test]
2781 async fn readyz_returns_503_when_ready_missing() {
2782 let check: ReadinessCheck =
2783 Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
2784 let resp = readyz(check).await.into_response();
2785 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2787 }
2788
2789 fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
2793 let allowed: Arc<[String]> = Arc::from(origins);
2794 axum::Router::new()
2795 .route("/test", axum::routing::get(|| async { "ok" }))
2796 .layer(axum::middleware::from_fn(move |req, next| {
2797 let a = Arc::clone(&allowed);
2798 origin_check_middleware(a, log_request_headers, req, next)
2799 }))
2800 }
2801
2802 #[tokio::test]
2803 async fn origin_allowed_passes() {
2804 let app = origin_router(vec!["http://localhost:3000".into()], false);
2805 let req = Request::builder()
2806 .uri("/test")
2807 .header(header::ORIGIN, "http://localhost:3000")
2808 .body(Body::empty())
2809 .unwrap();
2810 let resp = app.oneshot(req).await.unwrap();
2811 assert_eq!(resp.status(), StatusCode::OK);
2812 }
2813
2814 #[tokio::test]
2815 async fn origin_rejected_returns_403() {
2816 let app = origin_router(vec!["http://localhost:3000".into()], false);
2817 let req = Request::builder()
2818 .uri("/test")
2819 .header(header::ORIGIN, "http://evil.com")
2820 .body(Body::empty())
2821 .unwrap();
2822 let resp = app.oneshot(req).await.unwrap();
2823 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2824 }
2825
2826 #[tokio::test]
2827 async fn no_origin_header_passes() {
2828 let app = origin_router(vec!["http://localhost:3000".into()], false);
2829 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2830 let resp = app.oneshot(req).await.unwrap();
2831 assert_eq!(resp.status(), StatusCode::OK);
2832 }
2833
2834 #[tokio::test]
2835 async fn empty_allowlist_rejects_any_origin() {
2836 let app = origin_router(vec![], false);
2837 let req = Request::builder()
2838 .uri("/test")
2839 .header(header::ORIGIN, "http://anything.com")
2840 .body(Body::empty())
2841 .unwrap();
2842 let resp = app.oneshot(req).await.unwrap();
2843 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2844 }
2845
2846 #[tokio::test]
2847 async fn empty_allowlist_passes_without_origin() {
2848 let app = origin_router(vec![], false);
2849 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2850 let resp = app.oneshot(req).await.unwrap();
2851 assert_eq!(resp.status(), StatusCode::OK);
2852 }
2853
2854 #[test]
2855 fn format_request_headers_redacts_sensitive_values() {
2856 let mut headers = axum::http::HeaderMap::new();
2857 headers.insert("authorization", "Bearer secret-token".parse().unwrap());
2858 headers.insert("cookie", "sid=abc".parse().unwrap());
2859 headers.insert("x-request-id", "req-123".parse().unwrap());
2860
2861 let out = format_request_headers_for_log(&headers);
2862 assert!(out.contains("authorization: [REDACTED]"));
2863 assert!(out.contains("cookie: [REDACTED]"));
2864 assert!(out.contains("x-request-id: req-123"));
2865 assert!(!out.contains("secret-token"));
2866 }
2867
2868 fn security_router(is_tls: bool) -> axum::Router {
2871 security_router_with(is_tls, SecurityHeadersConfig::default())
2872 }
2873
2874 fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
2875 let cfg = Arc::new(cfg);
2876 axum::Router::new()
2877 .route("/test", axum::routing::get(|| async { "ok" }))
2878 .layer(axum::middleware::from_fn(move |req, next| {
2879 let c = Arc::clone(&cfg);
2880 security_headers_middleware(is_tls, c, req, next)
2881 }))
2882 }
2883
2884 #[tokio::test]
2885 async fn security_headers_set_on_response() {
2886 let app = security_router(false);
2887 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2888 let resp = app.oneshot(req).await.unwrap();
2889 assert_eq!(resp.status(), StatusCode::OK);
2890
2891 let h = resp.headers();
2892 assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
2893 assert_eq!(h.get("x-frame-options").unwrap(), "deny");
2894 assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
2895 assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
2896 assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
2897 assert_eq!(
2898 h.get("cross-origin-resource-policy").unwrap(),
2899 "same-origin"
2900 );
2901 assert_eq!(
2902 h.get("cross-origin-embedder-policy").unwrap(),
2903 "require-corp"
2904 );
2905 assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
2906 assert!(
2907 h.get("permissions-policy")
2908 .unwrap()
2909 .to_str()
2910 .unwrap()
2911 .contains("camera=()"),
2912 "permissions-policy must restrict browser features"
2913 );
2914 assert_eq!(
2915 h.get("content-security-policy").unwrap(),
2916 "default-src 'none'; frame-ancestors 'none'"
2917 );
2918 assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
2919 assert!(h.get("strict-transport-security").is_none());
2921 }
2922
2923 #[tokio::test]
2924 async fn hsts_set_when_tls_enabled() {
2925 let app = security_router(true);
2926 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2927 let resp = app.oneshot(req).await.unwrap();
2928
2929 let hsts = resp.headers().get("strict-transport-security").unwrap();
2930 assert!(
2931 hsts.to_str().unwrap().contains("max-age=63072000"),
2932 "HSTS must set 2-year max-age"
2933 );
2934 }
2935
2936 fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
2942 let cfg =
2943 McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
2944 cfg.check()
2945 }
2946
2947 #[test]
2948 fn security_headers_config_default_validates() {
2949 check_with_security_headers(SecurityHeadersConfig::default())
2950 .expect("default SecurityHeadersConfig must validate");
2951 }
2952
2953 #[test]
2954 fn security_headers_config_validate_accepts_empty_string() {
2955 let h = SecurityHeadersConfig {
2957 x_content_type_options: Some(String::new()),
2958 x_frame_options: Some(String::new()),
2959 cache_control: Some(String::new()),
2960 referrer_policy: Some(String::new()),
2961 cross_origin_opener_policy: Some(String::new()),
2962 cross_origin_resource_policy: Some(String::new()),
2963 cross_origin_embedder_policy: Some(String::new()),
2964 permissions_policy: Some(String::new()),
2965 x_permitted_cross_domain_policies: Some(String::new()),
2966 content_security_policy: Some(String::new()),
2967 x_dns_prefetch_control: Some(String::new()),
2968 strict_transport_security: Some(String::new()),
2969 };
2970 check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
2971 }
2972
2973 #[test]
2974 fn security_headers_config_validate_rejects_bad_value() {
2975 let h = SecurityHeadersConfig {
2977 referrer_policy: Some("\u{0007}".into()),
2978 ..SecurityHeadersConfig::default()
2979 };
2980 let err = check_with_security_headers(h)
2981 .expect_err("control char in referrer_policy must reject");
2982 let msg = err.to_string();
2983 assert!(
2984 msg.contains("referrer_policy"),
2985 "error must name the offending field, got: {msg}"
2986 );
2987 }
2988
2989 #[test]
2990 fn security_headers_config_validate_rejects_hsts_preload() {
2991 let h = SecurityHeadersConfig {
2992 strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
2993 ..SecurityHeadersConfig::default()
2994 };
2995 let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
2996 let msg = err.to_string();
2997 assert!(
2998 msg.contains("strict_transport_security"),
2999 "error must name the field, got: {msg}"
3000 );
3001 assert!(
3002 msg.to_lowercase().contains("preload"),
3003 "error must mention `preload`, got: {msg}"
3004 );
3005 }
3006
3007 #[test]
3008 fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
3009 let h = SecurityHeadersConfig {
3011 strict_transport_security: Some("max-age=600; PRELOAD".into()),
3012 ..SecurityHeadersConfig::default()
3013 };
3014 check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
3015 }
3016
3017 #[tokio::test]
3018 async fn security_headers_override_honored() {
3019 let h = SecurityHeadersConfig {
3021 x_frame_options: Some("SAMEORIGIN".into()),
3022 ..SecurityHeadersConfig::default()
3023 };
3024 let app = security_router_with(false, h);
3025 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3026 let resp = app.oneshot(req).await.unwrap();
3027 assert_eq!(resp.status(), StatusCode::OK);
3028
3029 let xfo = resp.headers().get("x-frame-options").unwrap();
3030 assert_eq!(xfo, "SAMEORIGIN");
3031 }
3032
3033 #[tokio::test]
3034 async fn security_headers_empty_string_omits() {
3035 let h = SecurityHeadersConfig {
3037 referrer_policy: Some(String::new()),
3038 ..SecurityHeadersConfig::default()
3039 };
3040 let app = security_router_with(false, h);
3041 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3042 let resp = app.oneshot(req).await.unwrap();
3043 assert_eq!(resp.status(), StatusCode::OK);
3044
3045 assert!(
3046 resp.headers().get("referrer-policy").is_none(),
3047 "Some(\"\") must omit the header"
3048 );
3049 assert_eq!(
3051 resp.headers().get("x-content-type-options").unwrap(),
3052 "nosniff"
3053 );
3054 }
3055
3056 #[tokio::test]
3057 async fn security_headers_hsts_only_when_tls() {
3058 let h = SecurityHeadersConfig {
3060 strict_transport_security: Some("max-age=600".into()),
3061 ..SecurityHeadersConfig::default()
3062 };
3063 let app = security_router_with(false, h);
3064 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3065 let resp = app.oneshot(req).await.unwrap();
3066 assert!(
3067 resp.headers().get("strict-transport-security").is_none(),
3068 "HSTS must remain absent on plaintext deployments even with override"
3069 );
3070 }
3071
3072 #[cfg(feature = "oauth")]
3075 #[tokio::test]
3076 async fn oauth_token_cache_headers_set_pragma_and_vary() {
3077 let app = axum::Router::new()
3078 .route("/token", axum::routing::post(|| async { "{}" }))
3079 .layer(axum::middleware::from_fn(
3080 oauth_token_cache_headers_middleware,
3081 ));
3082 let req = Request::builder()
3083 .method("POST")
3084 .uri("/token")
3085 .body(Body::from("{}"))
3086 .unwrap();
3087 let resp = app.oneshot(req).await.unwrap();
3088 assert_eq!(resp.status(), StatusCode::OK);
3089
3090 let h = resp.headers();
3091 assert_eq!(
3092 h.get("pragma").unwrap(),
3093 "no-cache",
3094 "RFC 6749 §5.1: token responses must set Pragma: no-cache"
3095 );
3096 let vary_values: Vec<String> = h
3097 .get_all("vary")
3098 .iter()
3099 .filter_map(|v| v.to_str().ok().map(str::to_owned))
3100 .collect();
3101 assert!(
3102 vary_values
3103 .iter()
3104 .any(|v| v.eq_ignore_ascii_case("Authorization")),
3105 "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
3106 );
3107 }
3108
3109 #[cfg(feature = "oauth")]
3110 #[tokio::test]
3111 async fn oauth_token_cache_headers_preserve_existing_vary() {
3112 let app = axum::Router::new()
3115 .route(
3116 "/token",
3117 axum::routing::post(|| async {
3118 axum::response::Response::builder()
3119 .header("vary", "Accept-Encoding")
3120 .body(axum::body::Body::from("{}"))
3121 .unwrap()
3122 }),
3123 )
3124 .layer(axum::middleware::from_fn(
3125 oauth_token_cache_headers_middleware,
3126 ));
3127 let req = Request::builder()
3128 .method("POST")
3129 .uri("/token")
3130 .body(Body::empty())
3131 .unwrap();
3132 let resp = app.oneshot(req).await.unwrap();
3133
3134 let vary: Vec<String> = resp
3135 .headers()
3136 .get_all("vary")
3137 .iter()
3138 .filter_map(|v| v.to_str().ok().map(str::to_owned))
3139 .collect();
3140 assert!(
3141 vary.iter().any(|v| v.contains("Accept-Encoding")),
3142 "must preserve pre-existing Vary value, got {vary:?}"
3143 );
3144 assert!(
3145 vary.iter().any(|v| v.contains("Authorization")),
3146 "must append Authorization to Vary, got {vary:?}"
3147 );
3148 }
3149
3150 #[test]
3153 fn version_payload_contains_expected_fields() {
3154 let v = version_payload("my-server", "1.2.3");
3155 assert_eq!(v["name"], "my-server");
3156 assert_eq!(v["version"], "1.2.3");
3157 assert!(v["build_git_sha"].is_string());
3158 assert!(v["build_timestamp"].is_string());
3159 assert!(v["rust_version"].is_string());
3160 assert!(v["mcpx_version"].is_string());
3161 }
3162
3163 #[tokio::test]
3166 async fn concurrency_limit_layer_composes_and_serves() {
3167 let app = axum::Router::new()
3171 .route("/ok", axum::routing::get(|| async { "ok" }))
3172 .layer(
3173 tower::ServiceBuilder::new()
3174 .layer(axum::error_handling::HandleErrorLayer::new(
3175 |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3176 ))
3177 .layer(tower::load_shed::LoadShedLayer::new())
3178 .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3179 );
3180 let resp = app
3181 .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3182 .await
3183 .unwrap();
3184 assert_eq!(resp.status(), StatusCode::OK);
3185 }
3186
3187 #[tokio::test]
3190 async fn compression_layer_gzip_encodes_response() {
3191 use tower_http::compression::Predicate as _;
3192
3193 let big_body = "a".repeat(4096);
3194 let app = axum::Router::new()
3195 .route(
3196 "/big",
3197 axum::routing::get(move || {
3198 let body = big_body.clone();
3199 async move { body }
3200 }),
3201 )
3202 .layer(
3203 tower_http::compression::CompressionLayer::new()
3204 .gzip(true)
3205 .br(true)
3206 .compress_when(
3207 tower_http::compression::DefaultPredicate::new()
3208 .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3209 ),
3210 );
3211
3212 let req = Request::builder()
3213 .uri("/big")
3214 .header(header::ACCEPT_ENCODING, "gzip")
3215 .body(Body::empty())
3216 .unwrap();
3217 let resp = app.oneshot(req).await.unwrap();
3218 assert_eq!(resp.status(), StatusCode::OK);
3219 assert_eq!(
3220 resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3221 "gzip"
3222 );
3223 }
3224}