1use std::{
2 future::Future,
3 net::SocketAddr,
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13 ServerHandler,
14 transport::streamable_http_server::{
15 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16 },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23 auth::{
24 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25 build_rate_limiter, extract_mtls_identity,
26 },
27 error::McpxError,
28 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32#[allow(
36 clippy::needless_pass_by_value,
37 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40 McpxError::Startup(format!("{e:#}"))
41}
42
43#[allow(
49 clippy::needless_pass_by_value,
50 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53 McpxError::Startup(format!("{op}: {e}"))
54}
55
/// Asynchronous readiness probe backing the `/readyz` route.
///
/// Each invocation returns a boxed `Send` future that resolves to a JSON value.
/// `build_app_router` hands the check to the `readyz` handler (defined elsewhere
/// in this file), which decides how that value maps to an HTTP response. When no
/// check is configured, `/readyz` is served by the same handler as `/healthz`.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
/// Per-header settings for the security response headers emitted by
/// `security_headers_middleware`, which `build_app_router` installs as a
/// router-wide layer (it also receives whether TLS is configured).
///
/// Every field defaults to `None` (`#[derive(Default)]`). Whether `None` means
/// "use the built-in default value" or "omit the header" is decided by
/// `security_headers_middleware`, which is not shown here — TODO confirm.
/// Values are checked by `validate_security_headers` during
/// [`McpServerConfig::validate`].
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// Setting for the `X-Content-Type-Options` header.
    pub x_content_type_options: Option<String>,
    /// Setting for the `X-Frame-Options` header.
    pub x_frame_options: Option<String>,
    /// Setting for the `Cache-Control` header.
    pub cache_control: Option<String>,
    /// Setting for the `Referrer-Policy` header.
    pub referrer_policy: Option<String>,
    /// Setting for the `Cross-Origin-Opener-Policy` header.
    pub cross_origin_opener_policy: Option<String>,
    /// Setting for the `Cross-Origin-Resource-Policy` header.
    pub cross_origin_resource_policy: Option<String>,
    /// Setting for the `Cross-Origin-Embedder-Policy` header.
    pub cross_origin_embedder_policy: Option<String>,
    /// Setting for the `Permissions-Policy` header.
    pub permissions_policy: Option<String>,
    /// Setting for the `X-Permitted-Cross-Domain-Policies` header.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// Setting for the `Content-Security-Policy` header.
    pub content_security_policy: Option<String>,
    /// Setting for the `X-DNS-Prefetch-Control` header.
    pub x_dns_prefetch_control: Option<String>,
    /// Setting for the `Strict-Transport-Security` header.
    /// NOTE(review): the middleware is told whether TLS is on; presumably HSTS is
    /// only sent over TLS — confirm in `security_headers_middleware`.
    pub strict_transport_security: Option<String>,
}
116
/// Complete configuration for [`serve`] / [`serve_with_listener`].
///
/// Build it with [`McpServerConfig::new`] plus the `with_*` / `enable_*`
/// builders, then call [`McpServerConfig::validate`] to obtain the
/// [`Validated`] wrapper the serve functions require. Direct field access is
/// deprecated (see each field's `#[deprecated]` note).
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Address the server listens on. `validate` requires it to parse as a
    /// `SocketAddr`. It is also used to derive the Streamable HTTP allowed hosts
    /// and, without `public_url`, the advertised resource URL.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name, reported by `/version` and the admin endpoints and used in
    /// the start-up log.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version, reported by `/version` and the admin endpoints.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// TLS certificate path. Must be set together with `tls_key_path`. When set,
    /// the server uses the `https` scheme.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private-key path. Must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings. The auth middleware is installed on `/mcp`
    /// (and `/admin/*`) only when this is `Some` with `enabled == true`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy for `/mcp`. `None` uses `RbacPolicy::disabled()`. The policy
    /// can be hot-swapped at runtime through `ReloadHandle::reload_rbac`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Allowed browser origins. Each entry must start with `http://` or
    /// `https://`. When empty and `public_url` is set, the origin of
    /// `public_url` is used. A CORS layer is installed only when the effective
    /// list is non-empty.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Per-client tool-call rate limit (the start-up log describes it as
    /// calls/min per IP), enforced by the RBAC middleware.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Custom probe for `/readyz`. Without one, `/readyz` behaves like `/healthz`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes for `/mcp` and `/admin/*`
    /// (default 1 MiB). Must be greater than zero.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Per-request timeout for `/mcp` and `/admin/*`. Expired requests get
    /// `408 Request Timeout` (default 2 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Grace period between the shutdown trigger and forced exit (default 30 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Session keep-alive handed to the rmcp `LocalSessionManager`
    /// (default 20 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval for the Streamable HTTP transport (default 15 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// Callback that receives a [`ReloadHandle`] once the server is being started.
    /// NOTE(review): visible for the TLS path in `run_server`; presumably also
    /// for plain HTTP — confirm.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Extra routes merged into the top-level router. They are merged after the
    /// auth/RBAC layers were attached to `/mcp`, so those middlewares do not wrap
    /// them. The router-wide layers (security headers, origin check, …) do.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL; must start with `http://` or `https://`.
    /// It is used for allowed hosts, the auto-derived allowed origin, and the
    /// protected-resource metadata URL.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Passed to `origin_check_middleware` to enable request-header logging.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Minimum response size in bytes before compression applies (default 1024).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global in-flight request limit. Excess requests are shed with `503` and a
    /// JSON `"overloaded"` body. `Some(0)` is rejected by `validate`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Enables the `/admin/*` endpoints. Requires auth to be configured and enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// Role handed to the admin router's `AdminConfig` (default `"admin"`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enables the metrics middleware and the separate metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Bind address of the metrics listener (default `127.0.0.1:9090`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Security response header settings; validated by `validate`.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
}
322
/// A value that has passed validation (e.g. via [`McpServerConfig::validate`]).
///
/// The only way to build one is through a validating constructor in this
/// module, so holding a `Validated<T>` is proof the checks ran.
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Always prints the opaque `Validated { .. }`, never the wrapped value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Validated { .. }")
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the validated value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
408
409#[allow(
410 deprecated,
411 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
412)]
413impl McpServerConfig {
414 #[must_use]
422 pub fn new(
423 bind_addr: impl Into<String>,
424 name: impl Into<String>,
425 version: impl Into<String>,
426 ) -> Self {
427 Self {
428 bind_addr: bind_addr.into(),
429 name: name.into(),
430 version: version.into(),
431 tls_cert_path: None,
432 tls_key_path: None,
433 auth: None,
434 rbac: None,
435 allowed_origins: Vec::new(),
436 tool_rate_limit: None,
437 readiness_check: None,
438 max_request_body: 1024 * 1024,
439 request_timeout: Duration::from_mins(2),
440 shutdown_timeout: Duration::from_secs(30),
441 session_idle_timeout: Duration::from_mins(20),
442 sse_keep_alive: Duration::from_secs(15),
443 on_reload_ready: None,
444 extra_router: None,
445 public_url: None,
446 log_request_headers: false,
447 compression_enabled: false,
448 compression_min_size: 1024,
449 max_concurrent_requests: None,
450 admin_enabled: false,
451 admin_role: "admin".to_owned(),
452 #[cfg(feature = "metrics")]
453 metrics_enabled: false,
454 #[cfg(feature = "metrics")]
455 metrics_bind: "127.0.0.1:9090".into(),
456 security_headers: SecurityHeadersConfig::default(),
457 }
458 }
459
460 #[must_use]
470 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
471 self.auth = Some(auth);
472 self
473 }
474
475 #[must_use]
480 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
481 self.security_headers = headers;
482 self
483 }
484
485 #[must_use]
489 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
490 self.bind_addr = addr.into();
491 self
492 }
493
494 #[must_use]
497 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
498 self.rbac = Some(rbac);
499 self
500 }
501
502 #[must_use]
506 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
507 self.tls_cert_path = Some(cert_path.into());
508 self.tls_key_path = Some(key_path.into());
509 self
510 }
511
512 #[must_use]
516 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
517 self.public_url = Some(url.into());
518 self
519 }
520
521 #[must_use]
525 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
526 where
527 I: IntoIterator<Item = S>,
528 S: Into<String>,
529 {
530 self.allowed_origins = origins.into_iter().map(Into::into).collect();
531 self
532 }
533
534 #[must_use]
538 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
539 self.extra_router = Some(router);
540 self
541 }
542
543 #[must_use]
546 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
547 self.readiness_check = Some(check);
548 self
549 }
550
551 #[must_use]
554 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
555 self.max_request_body = bytes;
556 self
557 }
558
559 #[must_use]
561 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
562 self.request_timeout = timeout;
563 self
564 }
565
566 #[must_use]
568 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
569 self.shutdown_timeout = timeout;
570 self
571 }
572
573 #[must_use]
575 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
576 self.session_idle_timeout = timeout;
577 self
578 }
579
580 #[must_use]
582 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
583 self.sse_keep_alive = interval;
584 self
585 }
586
587 #[must_use]
591 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
592 self.max_concurrent_requests = Some(limit);
593 self
594 }
595
596 #[must_use]
599 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
600 self.tool_rate_limit = Some(per_minute);
601 self
602 }
603
604 #[must_use]
608 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
609 where
610 F: FnOnce(ReloadHandle) + Send + 'static,
611 {
612 self.on_reload_ready = Some(Box::new(callback));
613 self
614 }
615
616 #[must_use]
620 pub fn enable_compression(mut self, min_size: u16) -> Self {
621 self.compression_enabled = true;
622 self.compression_min_size = min_size;
623 self
624 }
625
626 #[must_use]
631 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
632 self.admin_enabled = true;
633 self.admin_role = role.into();
634 self
635 }
636
637 #[must_use]
640 pub fn enable_request_header_logging(mut self) -> Self {
641 self.log_request_headers = true;
642 self
643 }
644
645 #[cfg(feature = "metrics")]
648 #[must_use]
649 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
650 self.metrics_enabled = true;
651 self.metrics_bind = bind.into();
652 self
653 }
654
655 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
688 self.check()?;
689 Ok(Validated(self))
690 }
691
692 fn check(&self) -> Result<(), McpxError> {
696 if self.admin_enabled {
700 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
701 if !auth_enabled {
702 return Err(McpxError::Config(
703 "admin_enabled=true requires auth to be configured and enabled".into(),
704 ));
705 }
706 }
707
708 match (&self.tls_cert_path, &self.tls_key_path) {
710 (Some(_), None) => {
711 return Err(McpxError::Config(
712 "tls_cert_path is set but tls_key_path is missing".into(),
713 ));
714 }
715 (None, Some(_)) => {
716 return Err(McpxError::Config(
717 "tls_key_path is set but tls_cert_path is missing".into(),
718 ));
719 }
720 _ => {}
721 }
722
723 if self.bind_addr.parse::<SocketAddr>().is_err() {
725 return Err(McpxError::Config(format!(
726 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
727 self.bind_addr
728 )));
729 }
730
731 if let Some(ref url) = self.public_url
733 && !(url.starts_with("http://") || url.starts_with("https://"))
734 {
735 return Err(McpxError::Config(format!(
736 "public_url {url:?} must start with http:// or https://"
737 )));
738 }
739
740 for origin in &self.allowed_origins {
742 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
743 return Err(McpxError::Config(format!(
744 "allowed_origins entry {origin:?} must start with http:// or https://"
745 )));
746 }
747 }
748
749 if self.max_request_body == 0 {
751 return Err(McpxError::Config(
752 "max_request_body must be greater than zero".into(),
753 ));
754 }
755
756 #[cfg(feature = "oauth")]
758 if let Some(auth_cfg) = &self.auth
759 && let Some(oauth_cfg) = &auth_cfg.oauth
760 {
761 oauth_cfg.validate()?;
762 }
763
764 validate_security_headers(&self.security_headers)?;
767
768 if let Some(0) = self.max_concurrent_requests {
772 return Err(McpxError::Config(
773 "max_concurrent_requests must be greater than zero when set".into(),
774 ));
775 }
776
777 if let Some(auth_cfg) = &self.auth
781 && let Some(rl) = &auth_cfg.rate_limit
782 && rl.max_tracked_keys == 0
783 {
784 return Err(McpxError::Config(
785 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
786 ));
787 }
788
789 Ok(())
790 }
791}
792
793#[allow(
799 missing_debug_implementations,
800 reason = "contains Arc<AuthState> with non-Debug fields"
801)]
802pub struct ReloadHandle {
803 auth: Option<Arc<AuthState>>,
804 rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
805 crl_set: Option<Arc<CrlSet>>,
806}
807
808impl ReloadHandle {
809 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
811 if let Some(ref auth) = self.auth {
812 auth.reload_keys(keys);
813 }
814 }
815
816 pub fn reload_rbac(&self, policy: RbacPolicy) {
818 if let Some(ref rbac) = self.rbac {
819 rbac.store(Arc::new(policy));
820 tracing::info!("RBAC policy reloaded");
821 }
822 }
823
824 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
830 let Some(ref crl_set) = self.crl_set else {
831 return Err(McpxError::Config(
832 "CRL refresh requested but mTLS CRL support is not configured".into(),
833 ));
834 };
835
836 crl_set.force_refresh().await
837 }
838}
839
840#[allow(
857 clippy::too_many_lines,
858 clippy::cognitive_complexity,
859 reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
860)]
861struct AppRunParams {
865 tls_paths: Option<(PathBuf, PathBuf)>,
867 mtls_config: Option<MtlsConfig>,
869 shutdown_timeout: Duration,
871 auth_state: Option<Arc<AuthState>>,
873 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
875 on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
877 ct: CancellationToken,
881 scheme: &'static str,
883 name: String,
885}
886
887#[allow(
897 clippy::cognitive_complexity,
898 reason = "router assembly is intrinsically sequential; splitting harms readability"
899)]
900#[allow(
901 deprecated,
902 reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
903)]
904fn build_app_router<H, F>(
905 mut config: McpServerConfig,
906 handler_factory: F,
907) -> anyhow::Result<(axum::Router, AppRunParams)>
908where
909 H: ServerHandler + 'static,
910 F: Fn() -> H + Send + Sync + Clone + 'static,
911{
912 let ct = CancellationToken::new();
913
914 let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
915 tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
916
917 let mcp_service = StreamableHttpService::new(
918 move || Ok(handler_factory()),
919 {
920 let mut mgr = LocalSessionManager::default();
921 mgr.session_config.keep_alive = Some(config.session_idle_timeout);
922 mgr.into()
923 },
924 StreamableHttpServerConfig::default()
925 .with_allowed_hosts(allowed_hosts)
926 .with_sse_keep_alive(Some(config.sse_keep_alive))
927 .with_cancellation_token(ct.child_token()),
928 );
929
930 let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
932
933 let auth_state: Option<Arc<AuthState>> = match config.auth {
937 Some(ref auth_config) if auth_config.enabled => {
938 let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
939 let pre_auth_limiter = auth_config
940 .rate_limit
941 .as_ref()
942 .map(crate::auth::build_pre_auth_limiter);
943
944 #[cfg(feature = "oauth")]
945 let jwks_cache = auth_config
946 .oauth
947 .as_ref()
948 .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
949 .transpose()
950 .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
951
952 Some(Arc::new(AuthState {
953 api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
954 rate_limiter,
955 pre_auth_limiter,
956 #[cfg(feature = "oauth")]
957 jwks_cache,
958 seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
959 counters: crate::auth::AuthCounters::default(),
960 }))
961 }
962 _ => None,
963 };
964
965 let rbac_swap = Arc::new(ArcSwap::new(
968 config
969 .rbac
970 .clone()
971 .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
972 ));
973
974 if config.admin_enabled {
977 let Some(ref auth_state_ref) = auth_state else {
978 return Err(anyhow::anyhow!(
979 "admin_enabled=true requires auth to be configured and enabled"
980 ));
981 };
982 let admin_state = crate::admin::AdminState {
983 started_at: std::time::Instant::now(),
984 name: config.name.clone(),
985 version: config.version.clone(),
986 auth: Some(Arc::clone(auth_state_ref)),
987 rbac: Arc::clone(&rbac_swap),
988 };
989 let admin_cfg = crate::admin::AdminConfig {
990 role: config.admin_role.clone(),
991 };
992 mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
993 tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
994 }
995
996 {
1029 let tool_limiter: Option<Arc<ToolRateLimiter>> =
1030 config.tool_rate_limit.map(build_tool_rate_limiter);
1031
1032 if rbac_swap.load().is_enabled() {
1033 tracing::info!("RBAC enforcement enabled on /mcp");
1034 }
1035 if let Some(limit) = config.tool_rate_limit {
1036 tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1037 }
1038
1039 let rbac_for_mw = Arc::clone(&rbac_swap);
1040 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1041 let p = rbac_for_mw.load_full();
1042 let tl = tool_limiter.clone();
1043 rbac_middleware(p, tl, req, next)
1044 }));
1045 }
1046
1047 if let Some(ref auth_config) = config.auth
1049 && auth_config.enabled
1050 {
1051 let Some(ref state) = auth_state else {
1052 return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1053 };
1054
1055 let methods: Vec<&str> = [
1056 auth_config.mtls.is_some().then_some("mTLS"),
1057 (!auth_config.api_keys.is_empty()).then_some("bearer"),
1058 #[cfg(feature = "oauth")]
1059 auth_config.oauth.is_some().then_some("oauth-jwt"),
1060 ]
1061 .into_iter()
1062 .flatten()
1063 .collect();
1064
1065 tracing::info!(
1066 methods = %methods.join(", "),
1067 api_keys = auth_config.api_keys.len(),
1068 "auth enabled on /mcp"
1069 );
1070
1071 let state_for_mw = Arc::clone(state);
1072 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1073 let s = Arc::clone(&state_for_mw);
1074 auth_middleware(s, req, next)
1075 }));
1076 }
1077
1078 mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1081 axum::http::StatusCode::REQUEST_TIMEOUT,
1082 config.request_timeout,
1083 ));
1084
1085 mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1089 config.max_request_body,
1090 ));
1091
1092 let mut effective_origins = config.allowed_origins.clone();
1099 if effective_origins.is_empty()
1100 && let Some(ref url) = config.public_url
1101 {
1102 if let Some(scheme_end) = url.find("://") {
1105 let after_scheme = &url[scheme_end + 3..];
1106 let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1107 let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
1108 tracing::info!(
1109 %origin,
1110 "auto-derived allowed origin from public_url"
1111 );
1112 effective_origins.push(origin);
1113 }
1114 }
1115 let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1116 let cors_origins = Arc::clone(&allowed_origins);
1117 let log_request_headers = config.log_request_headers;
1118
1119 let readyz_route = if let Some(check) = config.readiness_check.take() {
1120 axum::routing::get(move || readyz(Arc::clone(&check)))
1121 } else {
1122 axum::routing::get(healthz)
1123 };
1124
1125 #[allow(unused_mut)] let mut router = axum::Router::new()
1127 .route("/healthz", axum::routing::get(healthz))
1128 .route("/readyz", readyz_route)
1129 .route(
1130 "/version",
1131 axum::routing::get({
1132 let payload_bytes: Arc<[u8]> =
1137 serialize_version_payload(&config.name, &config.version);
1138 move || {
1139 let p = Arc::clone(&payload_bytes);
1140 async move {
1141 (
1142 [(axum::http::header::CONTENT_TYPE, "application/json")],
1143 p.to_vec(),
1144 )
1145 }
1146 }
1147 }),
1148 )
1149 .merge(mcp_router);
1150
1151 if let Some(extra) = config.extra_router.take() {
1153 router = router.merge(extra);
1154 }
1155
1156 let server_url = if let Some(ref url) = config.public_url {
1163 url.trim_end_matches('/').to_owned()
1164 } else {
1165 let prm_scheme = if config.tls_cert_path.is_some() {
1166 "https"
1167 } else {
1168 "http"
1169 };
1170 format!("{prm_scheme}://{}", config.bind_addr)
1171 };
1172 let resource_url = format!("{server_url}/mcp");
1173
1174 #[cfg(feature = "oauth")]
1175 let prm_metadata = if let Some(ref auth_config) = config.auth
1176 && let Some(ref oauth_config) = auth_config.oauth
1177 {
1178 crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1179 } else {
1180 serde_json::json!({ "resource": resource_url })
1181 };
1182 #[cfg(not(feature = "oauth"))]
1183 let prm_metadata = serde_json::json!({ "resource": resource_url });
1184
1185 router = router.route(
1186 "/.well-known/oauth-protected-resource",
1187 axum::routing::get(move || {
1188 let m = prm_metadata.clone();
1189 async move { axum::Json(m) }
1190 }),
1191 );
1192
1193 #[cfg(feature = "oauth")]
1198 if let Some(ref auth_config) = config.auth
1199 && let Some(ref oauth_config) = auth_config.oauth
1200 && oauth_config.proxy.is_some()
1201 {
1202 router =
1203 install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1204 }
1205
1206 let is_tls = config.tls_cert_path.is_some();
1209 let security_headers_cfg = Arc::new(config.security_headers.clone());
1210 router = router.layer(axum::middleware::from_fn(move |req, next| {
1211 let cfg = Arc::clone(&security_headers_cfg);
1212 security_headers_middleware(is_tls, cfg, req, next)
1213 }));
1214
1215 if !cors_origins.is_empty() {
1219 let cors = tower_http::cors::CorsLayer::new()
1220 .allow_origin(
1221 cors_origins
1222 .iter()
1223 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1224 .collect::<Vec<_>>(),
1225 )
1226 .allow_methods([
1227 axum::http::Method::GET,
1228 axum::http::Method::POST,
1229 axum::http::Method::OPTIONS,
1230 ])
1231 .allow_headers([
1232 axum::http::header::CONTENT_TYPE,
1233 axum::http::header::AUTHORIZATION,
1234 ]);
1235 router = router.layer(cors);
1236 }
1237
1238 if config.compression_enabled {
1242 use tower_http::compression::Predicate as _;
1243 let predicate = tower_http::compression::DefaultPredicate::new().and(
1244 tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1245 );
1246 router = router.layer(
1247 tower_http::compression::CompressionLayer::new()
1248 .gzip(true)
1249 .br(true)
1250 .compress_when(predicate),
1251 );
1252 tracing::info!(
1253 min_size = config.compression_min_size,
1254 "response compression enabled (gzip, br)"
1255 );
1256 }
1257
1258 if let Some(max) = config.max_concurrent_requests {
1261 let overload_handler = tower::ServiceBuilder::new()
1262 .layer(axum::error_handling::HandleErrorLayer::new(
1263 |_err: tower::BoxError| async {
1264 (
1265 axum::http::StatusCode::SERVICE_UNAVAILABLE,
1266 axum::Json(serde_json::json!({
1267 "error": "overloaded",
1268 "error_description": "server is at capacity, retry later"
1269 })),
1270 )
1271 },
1272 ))
1273 .layer(tower::load_shed::LoadShedLayer::new())
1274 .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1275 router = router.layer(overload_handler);
1276 tracing::info!(max, "global concurrency limit enabled");
1277 }
1278
1279 router = router.fallback(|| async {
1283 (
1284 axum::http::StatusCode::NOT_FOUND,
1285 axum::Json(serde_json::json!({
1286 "error": "not_found",
1287 "error_description": "The requested endpoint does not exist"
1288 })),
1289 )
1290 });
1291
1292 #[cfg(feature = "metrics")]
1294 if config.metrics_enabled {
1295 let metrics = Arc::new(
1296 crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1297 );
1298 let m = Arc::clone(&metrics);
1299 router = router.layer(axum::middleware::from_fn(
1300 move |req: Request<Body>, next: Next| {
1301 let m = Arc::clone(&m);
1302 metrics_middleware(m, req, next)
1303 },
1304 ));
1305 let metrics_bind = config.metrics_bind.clone();
1306 let metrics_shutdown = ct.clone();
1307 tokio::spawn(async move {
1308 if let Err(e) =
1309 crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1310 {
1311 tracing::error!("metrics listener failed: {e}");
1312 }
1313 });
1314 }
1315
1316 router = router.layer(axum::middleware::from_fn(move |req, next| {
1327 let origins = Arc::clone(&allowed_origins);
1328 origin_check_middleware(origins, log_request_headers, req, next)
1329 }));
1330
1331 let scheme = if config.tls_cert_path.is_some() {
1332 "https"
1333 } else {
1334 "http"
1335 };
1336
1337 let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1338 (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1339 _ => None,
1340 };
1341 let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1342
1343 Ok((
1344 router,
1345 AppRunParams {
1346 tls_paths,
1347 mtls_config,
1348 shutdown_timeout: config.shutdown_timeout,
1349 auth_state,
1350 rbac_swap,
1351 on_reload_ready: config.on_reload_ready.take(),
1352 ct,
1353 scheme,
1354 name: config.name.clone(),
1355 },
1356 ))
1357}
1358
1359pub async fn serve<H, F>(
1376 config: Validated<McpServerConfig>,
1377 handler_factory: F,
1378) -> Result<(), McpxError>
1379where
1380 H: ServerHandler + 'static,
1381 F: Fn() -> H + Send + Sync + Clone + 'static,
1382{
1383 let config = config.into_inner();
1384 #[allow(
1385 deprecated,
1386 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1387 )]
1388 let bind_addr = config.bind_addr.clone();
1389 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1390
1391 let listener = TcpListener::bind(&bind_addr)
1392 .await
1393 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1394 log_listening(¶ms.name, params.scheme, &bind_addr);
1395
1396 run_server(
1397 router,
1398 listener,
1399 params.tls_paths,
1400 params.mtls_config,
1401 params.shutdown_timeout,
1402 params.auth_state,
1403 params.rbac_swap,
1404 params.on_reload_ready,
1405 params.ct,
1406 )
1407 .await
1408 .map_err(anyhow_to_startup)
1409}
1410
1411pub async fn serve_with_listener<H, F>(
1441 listener: TcpListener,
1442 config: Validated<McpServerConfig>,
1443 handler_factory: F,
1444 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1445 shutdown: Option<CancellationToken>,
1446) -> Result<(), McpxError>
1447where
1448 H: ServerHandler + 'static,
1449 F: Fn() -> H + Send + Sync + Clone + 'static,
1450{
1451 let config = config.into_inner();
1452 let local_addr = listener
1453 .local_addr()
1454 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1455 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1456
1457 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1458
1459 if let Some(external) = shutdown {
1463 let internal = params.ct.clone();
1464 tokio::spawn(async move {
1465 external.cancelled().await;
1466 internal.cancel();
1467 });
1468 }
1469
1470 if let Some(tx) = ready_tx {
1474 let _ = tx.send(local_addr);
1476 }
1477
1478 run_server(
1479 router,
1480 listener,
1481 params.tls_paths,
1482 params.mtls_config,
1483 params.shutdown_timeout,
1484 params.auth_state,
1485 params.rbac_swap,
1486 params.on_reload_ready,
1487 params.ct,
1488 )
1489 .await
1490 .map_err(anyhow_to_startup)
1491}
1492
1493#[allow(
1496 clippy::cognitive_complexity,
1497 reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1498)]
1499fn log_listening(name: &str, scheme: &str, addr: &str) {
1500 tracing::info!("{name} listening on {addr}");
1501 tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
1502 tracing::info!(" Health check: {scheme}://{addr}/healthz");
1503 tracing::info!(" Readiness: {scheme}://{addr}/readyz");
1504}
1505
/// Serves `router` on `listener` (plain HTTP or TLS) until shutdown.
///
/// Shutdown is triggered when `shutdown_signal()` resolves or when `ct` is
/// cancelled. Once triggered:
/// - `ct` is cancelled;
/// - axum's graceful shutdown begins;
/// - if graceful shutdown has not completed within `shutdown_timeout`, a
///   warning is logged and the function returns `Ok(())` anyway.
///
/// When `tls_paths` is `Some((cert, key))`, connections go through a
/// `TlsListener` configured from `mtls_config`. If mTLS CRL checking is
/// enabled, CRLs are fetched before serving and a refresher task is spawned
/// with `ct`.
///
/// `on_reload_ready`, if present, is invoked exactly once, before serving
/// starts. It receives a `ReloadHandle` carrying the auth state, the RBAC swap
/// and, in TLS mode, the CRL set.
///
/// # Errors
///
/// Returns an error if any of the following fails:
/// - loading the client CA roots;
/// - the CRL bootstrap fetch;
/// - TLS listener construction;
/// - `axum::serve` itself.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // Local one-shot trigger: fired by an OS signal or by cancellation of the
    // caller's token. Both the graceful-shutdown future and the force-exit
    // timer below wait on it.
    let shutdown_trigger = CancellationToken::new();
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    // Handed to axum's graceful shutdown. Once triggered it also cancels `ct`,
    // which is the token given to the CRL refresher spawned below.
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };

    // Completes `shutdown_timeout` after the trigger fires. Racing it against
    // the server below bounds how long graceful shutdown may take.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // With CRL checking enabled, fetch CRLs before accepting connections,
        // then keep refreshing them in a background task tied to `ct`.
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        // Publish the reload handle before serving starts.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
        )?;
        // `TlsConnInfo` carries the peer address plus any mTLS identity.
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        // Plain HTTP: there is no CRL state to expose.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
1641
#[cfg(feature = "oauth")]
/// Mounts the OAuth 2.1 proxy endpoints on `router` when `oauth_config.proxy`
/// is configured. Otherwise `router` is returned unchanged.
///
/// Always installs:
/// - `/.well-known/oauth-authorization-server`
/// - `/authorize`
/// - `/token`
/// - `/register`
///
/// `/introspect` and `/revoke` are added only when `expose_admin_endpoints` is
/// set and the corresponding upstream URL is configured.
///
/// # Errors
///
/// Returns `McpxError::Startup` if the admin endpoints would be exposed with
/// neither `require_auth_on_admin_endpoints` nor the explicit
/// `allow_unauthenticated_admin_endpoints` opt-out. Also propagates errors
/// from building the OAuth HTTP client or the admin router.
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());

    // Fail closed. Serving the admin routes without authentication must be an
    // explicit opt-out, never the accidental result of leaving
    // `require_auth_on_admin_endpoints` unset.
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && !proxy.allow_unauthenticated_admin_endpoints
    {
        return Err(McpxError::Startup(
            "oauth proxy admin endpoints (/introspect, /revoke) would be exposed without \
             authentication; enable require_auth_on_admin_endpoints or set the explicit \
             allow_unauthenticated_admin_endpoints opt-out"
                .into(),
        ));
    }

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register.clone();
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    if proxy.expose_admin_endpoints
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1742
#[cfg(feature = "oauth")]
/// Builds the router for the OAuth proxy admin endpoints.
///
/// `/introspect` is mounted if `proxy.introspection_url` is set, and
/// `/revoke` if `proxy.revocation_url` is set. Both carry the OAuth
/// cache-header middleware. When `proxy.require_auth_on_admin_endpoints` is
/// true, the whole router is additionally wrapped in `auth_middleware`.
///
/// # Errors
///
/// Returns `McpxError::Startup` when authentication is required but
/// `auth_state` is `None`.
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let cfg = proxy.clone();
        let client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let (p, h) = (cfg.clone(), client.clone());
                async move { crate::oauth::handle_introspect(&h, &p, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let cfg = proxy.clone();
        let client = http;
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let (p, h) = (cfg.clone(), client.clone());
                async move { crate::oauth::handle_revoke(&h, &p, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    let shared_state = Arc::clone(state);
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&shared_state), req, next)
    })))
}
1801
1802fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1807 let mut hosts = vec![
1808 "localhost".to_owned(),
1809 "127.0.0.1".to_owned(),
1810 "::1".to_owned(),
1811 ];
1812
1813 if let Some(url) = public_url
1814 && let Ok(uri) = url.parse::<axum::http::Uri>()
1815 && let Some(authority) = uri.authority()
1816 {
1817 let host = authority.host().to_owned();
1818 if !hosts.iter().any(|h| h == &host) {
1819 hosts.push(host);
1820 }
1821
1822 let authority = authority.as_str().to_owned();
1823 if !hosts.iter().any(|h| h == &authority) {
1824 hosts.push(authority);
1825 }
1826 }
1827
1828 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1829 && let Some(authority) = uri.authority()
1830 {
1831 let host = authority.host().to_owned();
1832 if !hosts.iter().any(|h| h == &host) {
1833 hosts.push(host);
1834 }
1835
1836 let authority = authority.as_str().to_owned();
1837 if !hosts.iter().any(|h| h == &authority) {
1838 hosts.push(authority);
1839 }
1840 }
1841
1842 hosts
1843}
1844
1845impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1858 for TlsConnInfo
1859{
1860 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1861 let addr = *target.remote_addr();
1862 let identity = target.io().identity().cloned();
1863 TlsConnInfo::new(addr, identity)
1864 }
1865}
1866
/// TCP listener that completes a TLS handshake on each accepted connection,
/// optionally with mTLS client authentication, before handing it to axum.
struct TlsListener {
    /// Underlying TCP listener that connections are accepted from.
    inner: TcpListener,
    /// rustls acceptor that runs the server side of each handshake.
    acceptor: tokio_rustls::TlsAcceptor,
    /// Role passed to `extract_mtls_identity` when deriving an identity from a
    /// client certificate. It comes from the mTLS config, or is `"viewer"`
    /// when mTLS is not configured.
    mtls_default_role: String,
}
1879
1880impl TlsListener {
1881 fn new(
1882 inner: TcpListener,
1883 cert_path: &Path,
1884 key_path: &Path,
1885 mtls_config: Option<&MtlsConfig>,
1886 crl_set: Option<Arc<CrlSet>>,
1887 ) -> anyhow::Result<Self> {
1888 rustls::crypto::ring::default_provider()
1890 .install_default()
1891 .ok();
1892
1893 let certs = load_certs(cert_path)?;
1894 let key = load_key(key_path)?;
1895
1896 let mtls_default_role;
1897
1898 let tls_config = if let Some(mtls) = mtls_config {
1899 mtls_default_role = mtls.default_role.clone();
1900 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1901 {
1902 let Some(crl_set) = crl_set else {
1903 return Err(anyhow::anyhow!(
1904 "mTLS CRL verifier requested but CRL state was not initialized"
1905 ));
1906 };
1907 Arc::new(DynamicClientCertVerifier::new(crl_set))
1908 } else {
1909 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1910 if mtls.required {
1911 rustls::server::WebPkiClientVerifier::builder(root_store)
1912 .build()
1913 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1914 } else {
1915 rustls::server::WebPkiClientVerifier::builder(root_store)
1916 .allow_unauthenticated()
1917 .build()
1918 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1919 }
1920 };
1921
1922 tracing::info!(
1923 ca = %mtls.ca_cert_path.display(),
1924 required = mtls.required,
1925 crl_enabled = mtls.crl_enabled,
1926 "mTLS client auth configured"
1927 );
1928
1929 rustls::ServerConfig::builder_with_protocol_versions(&[
1930 &rustls::version::TLS12,
1931 &rustls::version::TLS13,
1932 ])
1933 .with_client_cert_verifier(verifier)
1934 .with_single_cert(certs, key)?
1935 } else {
1936 mtls_default_role = "viewer".to_owned();
1937 rustls::ServerConfig::builder_with_protocol_versions(&[
1938 &rustls::version::TLS12,
1939 &rustls::version::TLS13,
1940 ])
1941 .with_no_client_auth()
1942 .with_single_cert(certs, key)?
1943 };
1944
1945 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1946 tracing::info!(
1947 "TLS enabled (cert: {}, key: {})",
1948 cert_path.display(),
1949 key_path.display()
1950 );
1951 Ok(Self {
1952 inner,
1953 acceptor,
1954 mtls_default_role,
1955 })
1956 }
1957
1958 fn extract_handshake_identity(
1962 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1963 default_role: &str,
1964 addr: SocketAddr,
1965 ) -> Option<AuthIdentity> {
1966 let (_, server_conn) = tls_stream.get_ref();
1967 let cert_der = server_conn.peer_certificates()?.first()?;
1968 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1969 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1970 Some(id)
1971 }
1972}
1973
/// Server-side TLS stream paired with the identity extracted from the peer's
/// client certificate during the handshake, if any.
pub(crate) struct AuthenticatedTlsStream {
    /// The established TLS stream; all reads and writes are delegated to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity derived from the first peer certificate. It is `None` when the
    /// peer presented no certificate or the certificate yielded no identity.
    identity: Option<AuthIdentity>,
}
1989
1990impl AuthenticatedTlsStream {
1991 #[must_use]
1993 pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
1994 self.identity.as_ref()
1995 }
1996}
1997
1998impl std::fmt::Debug for AuthenticatedTlsStream {
1999 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2000 f.debug_struct("AuthenticatedTlsStream")
2001 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2002 .finish_non_exhaustive()
2003 }
2004}
2005
2006impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2007 fn poll_read(
2008 mut self: Pin<&mut Self>,
2009 cx: &mut std::task::Context<'_>,
2010 buf: &mut tokio::io::ReadBuf<'_>,
2011 ) -> std::task::Poll<std::io::Result<()>> {
2012 Pin::new(&mut self.inner).poll_read(cx, buf)
2013 }
2014}
2015
2016impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2017 fn poll_write(
2018 mut self: Pin<&mut Self>,
2019 cx: &mut std::task::Context<'_>,
2020 buf: &[u8],
2021 ) -> std::task::Poll<std::io::Result<usize>> {
2022 Pin::new(&mut self.inner).poll_write(cx, buf)
2023 }
2024
2025 fn poll_flush(
2026 mut self: Pin<&mut Self>,
2027 cx: &mut std::task::Context<'_>,
2028 ) -> std::task::Poll<std::io::Result<()>> {
2029 Pin::new(&mut self.inner).poll_flush(cx)
2030 }
2031
2032 fn poll_shutdown(
2033 mut self: Pin<&mut Self>,
2034 cx: &mut std::task::Context<'_>,
2035 ) -> std::task::Poll<std::io::Result<()>> {
2036 Pin::new(&mut self.inner).poll_shutdown(cx)
2037 }
2038
2039 fn poll_write_vectored(
2040 mut self: Pin<&mut Self>,
2041 cx: &mut std::task::Context<'_>,
2042 bufs: &[std::io::IoSlice<'_>],
2043 ) -> std::task::Poll<std::io::Result<usize>> {
2044 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2045 }
2046
2047 fn is_write_vectored(&self) -> bool {
2048 self.inner.is_write_vectored()
2049 }
2050}
2051
2052impl axum::serve::Listener for TlsListener {
2053 type Io = AuthenticatedTlsStream;
2054 type Addr = SocketAddr;
2055
2056 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2057 loop {
2058 let (stream, addr) = match self.inner.accept().await {
2059 Ok(pair) => pair,
2060 Err(e) => {
2061 tracing::debug!("TCP accept error: {e}");
2062 continue;
2063 }
2064 };
2065 let tls_stream = match self.acceptor.accept(stream).await {
2066 Ok(s) => s,
2067 Err(e) => {
2068 tracing::debug!("TLS handshake failed from {addr}: {e}");
2069 continue;
2070 }
2071 };
2072 let identity =
2073 Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
2074 let wrapped = AuthenticatedTlsStream {
2075 inner: tls_stream,
2076 identity,
2077 };
2078 return (wrapped, addr);
2079 }
2080 }
2081
2082 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2083 self.inner.local_addr()
2084 }
2085}
2086
2087fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2088 use rustls::pki_types::pem::PemObject;
2089 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2090 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2091 .collect::<Result<_, _>>()
2092 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2093 anyhow::ensure!(
2094 !certs.is_empty(),
2095 "no certificates found in {}",
2096 path.display()
2097 );
2098 Ok(certs)
2099}
2100
2101fn load_client_auth_roots(
2102 path: &Path,
2103) -> anyhow::Result<(
2104 Vec<rustls::pki_types::CertificateDer<'static>>,
2105 Arc<RootCertStore>,
2106)> {
2107 let ca_certs = load_certs(path)?;
2108 let mut root_store = RootCertStore::empty();
2109 for cert in &ca_certs {
2110 root_store
2111 .add(cert.clone())
2112 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2113 }
2114
2115 Ok((ca_certs, Arc::new(root_store)))
2116}
2117
2118fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2119 use rustls::pki_types::pem::PemObject;
2120 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2121 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2122}
2123
2124#[allow(clippy::unused_async)]
2125async fn healthz() -> impl IntoResponse {
2126 axum::Json(serde_json::json!({
2127 "status": "ok",
2128 }))
2129}
2130
2131fn version_payload(name: &str, version: &str) -> serde_json::Value {
2137 serde_json::json!({
2138 "name": name,
2139 "version": version,
2140 "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or("unknown"),
2141 "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or("unknown"),
2142 "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or("unknown"),
2143 "mcpx_version": env!("CARGO_PKG_VERSION"),
2144 })
2145}
2146
2147fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2157 let value = version_payload(name, version);
2158 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2159}
2160
2161async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2162 let status = check().await;
2163 let ready = status
2164 .get("ready")
2165 .and_then(serde_json::Value::as_bool)
2166 .unwrap_or(false);
2167 let code = if ready {
2168 axum::http::StatusCode::OK
2169 } else {
2170 axum::http::StatusCode::SERVICE_UNAVAILABLE
2171 };
2172 (code, axum::Json(status))
2173}
2174
2175async fn shutdown_signal() {
2179 let ctrl_c = tokio::signal::ctrl_c();
2180
2181 #[cfg(unix)]
2182 {
2183 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2184 Ok(mut term) => {
2185 tokio::select! {
2186 _ = ctrl_c => {}
2187 _ = term.recv() => {}
2188 }
2189 }
2190 Err(e) => {
2191 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2192 ctrl_c.await.ok();
2193 }
2194 }
2195 }
2196
2197 #[cfg(not(unix))]
2198 {
2199 ctrl_c.await.ok();
2200 }
2201}
2202
#[cfg(feature = "metrics")]
/// Records `http_requests_total` (method, path, status) and
/// `http_request_duration_seconds` (method, path) for each request.
///
/// Label values are kept bounded, so clients cannot mint an unlimited number
/// of time series:
/// - `path` is the matched route template (`MatchedPath`) when available.
///   Requests that matched no route and were answered `404` share the
///   `"<unmatched>"` label; any other unmatched request keeps its raw path.
/// - `method` is the verb for standard HTTP methods and `"OTHER"` for
///   extension methods.
///
/// NOTE(review): `MatchedPath` is only present when this middleware sits
/// inside axum's routing (e.g. added via `Router::layer`). Confirm this at the
/// installation site.
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    const UNMATCHED_PATH: &str = "<unmatched>";
    const OTHER_METHOD: &str = "OTHER";

    let method = match req.method().as_str() {
        known @ ("GET" | "HEAD" | "POST" | "PUT" | "DELETE" | "CONNECT" | "OPTIONS"
        | "TRACE" | "PATCH") => known.to_owned(),
        _ => OTHER_METHOD.to_owned(),
    };
    let matched_path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map(|route| route.as_str().to_owned());
    let raw_path = req.uri().path().to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let duration = start.elapsed().as_secs_f64();
    let status = response.status().as_u16().to_string();
    let path = match matched_path {
        Some(route) => route,
        // Fallback 404s (e.g. path scans) would otherwise create one series
        // per distinct requested path.
        None if response.status() == axum::http::StatusCode::NOT_FOUND => {
            UNMATCHED_PATH.to_owned()
        }
        None => raw_path,
    };

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2234
2235async fn security_headers_middleware(
2247 is_tls: bool,
2248 cfg: Arc<SecurityHeadersConfig>,
2249 req: Request<Body>,
2250 next: Next,
2251) -> axum::response::Response {
2252 use axum::http::{HeaderName, header};
2253
2254 let mut resp = next.run(req).await;
2255 let headers = resp.headers_mut();
2256
2257 headers.remove(header::SERVER);
2259 headers.remove(HeaderName::from_static("x-powered-by"));
2260
2261 apply_security_header(
2262 headers,
2263 header::X_CONTENT_TYPE_OPTIONS,
2264 cfg.x_content_type_options.as_deref(),
2265 "nosniff",
2266 );
2267 apply_security_header(
2268 headers,
2269 header::X_FRAME_OPTIONS,
2270 cfg.x_frame_options.as_deref(),
2271 "deny",
2272 );
2273 apply_security_header(
2274 headers,
2275 header::CACHE_CONTROL,
2276 cfg.cache_control.as_deref(),
2277 "no-store, max-age=0",
2278 );
2279 apply_security_header(
2280 headers,
2281 header::REFERRER_POLICY,
2282 cfg.referrer_policy.as_deref(),
2283 "no-referrer",
2284 );
2285 apply_security_header(
2286 headers,
2287 HeaderName::from_static("cross-origin-opener-policy"),
2288 cfg.cross_origin_opener_policy.as_deref(),
2289 "same-origin",
2290 );
2291 apply_security_header(
2292 headers,
2293 HeaderName::from_static("cross-origin-resource-policy"),
2294 cfg.cross_origin_resource_policy.as_deref(),
2295 "same-origin",
2296 );
2297 apply_security_header(
2298 headers,
2299 HeaderName::from_static("cross-origin-embedder-policy"),
2300 cfg.cross_origin_embedder_policy.as_deref(),
2301 "require-corp",
2302 );
2303 apply_security_header(
2304 headers,
2305 HeaderName::from_static("permissions-policy"),
2306 cfg.permissions_policy.as_deref(),
2307 "accelerometer=(), camera=(), geolocation=(), microphone=()",
2308 );
2309 apply_security_header(
2310 headers,
2311 HeaderName::from_static("x-permitted-cross-domain-policies"),
2312 cfg.x_permitted_cross_domain_policies.as_deref(),
2313 "none",
2314 );
2315 apply_security_header(
2316 headers,
2317 HeaderName::from_static("content-security-policy"),
2318 cfg.content_security_policy.as_deref(),
2319 "default-src 'none'; frame-ancestors 'none'",
2320 );
2321 apply_security_header(
2322 headers,
2323 HeaderName::from_static("x-dns-prefetch-control"),
2324 cfg.x_dns_prefetch_control.as_deref(),
2325 "off",
2326 );
2327
2328 if is_tls {
2329 apply_security_header(
2330 headers,
2331 header::STRICT_TRANSPORT_SECURITY,
2332 cfg.strict_transport_security.as_deref(),
2333 "max-age=63072000; includeSubDomains",
2334 );
2335 }
2336
2337 resp
2338}
2339
2340fn apply_security_header(
2351 headers: &mut axum::http::HeaderMap,
2352 name: axum::http::HeaderName,
2353 override_value: Option<&str>,
2354 default: &'static str,
2355) {
2356 use axum::http::HeaderValue;
2357
2358 match override_value {
2359 None => {
2360 headers.insert(name, HeaderValue::from_static(default));
2361 }
2362 Some("") => {
2363 }
2365 Some(v) => match HeaderValue::from_str(v) {
2366 Ok(hv) => {
2367 headers.insert(name, hv);
2368 }
2369 Err(err) => {
2370 tracing::error!(
2371 header = %name,
2372 error = %err,
2373 "invalid security header override reached middleware; using default"
2374 );
2375 headers.insert(name, HeaderValue::from_static(default));
2376 }
2377 },
2378 }
2379}
2380
2381fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2392 use axum::http::HeaderValue;
2393
2394 let fields: &[(&str, Option<&str>)] = &[
2395 (
2396 "x_content_type_options",
2397 cfg.x_content_type_options.as_deref(),
2398 ),
2399 ("x_frame_options", cfg.x_frame_options.as_deref()),
2400 ("cache_control", cfg.cache_control.as_deref()),
2401 ("referrer_policy", cfg.referrer_policy.as_deref()),
2402 (
2403 "cross_origin_opener_policy",
2404 cfg.cross_origin_opener_policy.as_deref(),
2405 ),
2406 (
2407 "cross_origin_resource_policy",
2408 cfg.cross_origin_resource_policy.as_deref(),
2409 ),
2410 (
2411 "cross_origin_embedder_policy",
2412 cfg.cross_origin_embedder_policy.as_deref(),
2413 ),
2414 ("permissions_policy", cfg.permissions_policy.as_deref()),
2415 (
2416 "x_permitted_cross_domain_policies",
2417 cfg.x_permitted_cross_domain_policies.as_deref(),
2418 ),
2419 (
2420 "content_security_policy",
2421 cfg.content_security_policy.as_deref(),
2422 ),
2423 (
2424 "x_dns_prefetch_control",
2425 cfg.x_dns_prefetch_control.as_deref(),
2426 ),
2427 (
2428 "strict_transport_security",
2429 cfg.strict_transport_security.as_deref(),
2430 ),
2431 ];
2432
2433 for (field, value) in fields {
2434 let Some(v) = value else { continue };
2435 if v.is_empty() {
2436 continue;
2437 }
2438 if let Err(err) = HeaderValue::from_str(v) {
2439 return Err(McpxError::Config(format!(
2440 "invalid security_headers.{field}: {err}"
2441 )));
2442 }
2443 }
2444
2445 if let Some(v) = cfg.strict_transport_security.as_deref()
2446 && !v.is_empty()
2447 && v.to_ascii_lowercase().contains("preload")
2448 {
2449 return Err(McpxError::Config(format!(
2450 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2451 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2452 )));
2453 }
2454
2455 Ok(())
2456}
2457
#[cfg(feature = "oauth")]
/// Adds anti-caching headers to OAuth endpoint responses.
///
/// Sets `Pragma: no-cache`, replacing any existing value. Appends
/// `Vary: Authorization` alongside any `Vary` values the handler already set.
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;
    {
        let headers = response.headers_mut();
        headers.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
        headers.append(header::VARY, HeaderValue::from_static("Authorization"));
    }
    response
}
2485
2486async fn origin_check_middleware(
2490 allowed: Arc<[String]>,
2491 log_request_headers: bool,
2492 req: Request<Body>,
2493 next: Next,
2494) -> axum::response::Response {
2495 let method = req.method().clone();
2496 let path = req.uri().path().to_owned();
2497
2498 log_incoming_request(&method, &path, req.headers(), log_request_headers);
2499
2500 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2501 let origin_str = origin.to_str().unwrap_or("");
2502 if !allowed.iter().any(|a| a == origin_str) {
2503 tracing::warn!(
2504 origin = origin_str,
2505 %method,
2506 %path,
2507 allowed = ?&*allowed,
2508 "rejected request: Origin not allowed"
2509 );
2510 return (
2511 axum::http::StatusCode::FORBIDDEN,
2512 "Forbidden: Origin not allowed",
2513 )
2514 .into_response();
2515 }
2516 }
2517 next.run(req).await
2518}
2519
2520fn log_incoming_request(
2523 method: &axum::http::Method,
2524 path: &str,
2525 headers: &axum::http::HeaderMap,
2526 log_request_headers: bool,
2527) {
2528 if log_request_headers {
2529 tracing::debug!(
2530 %method,
2531 %path,
2532 headers = %format_request_headers_for_log(headers),
2533 "incoming request"
2534 );
2535 } else {
2536 tracing::debug!(%method, %path, "incoming request");
2537 }
2538}
2539
2540fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2541 headers
2542 .iter()
2543 .map(|(k, v)| {
2544 let name = k.as_str();
2545 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2546 format!("{name}: [REDACTED]")
2547 } else {
2548 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2549 }
2550 })
2551 .collect::<Vec<_>>()
2552 .join(", ")
2553}
2554
2555#[allow(
2579 clippy::cognitive_complexity,
2580 reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
2581)]
2582pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2583where
2584 H: ServerHandler + 'static,
2585{
2586 use rmcp::ServiceExt as _;
2587
2588 tracing::info!("stdio transport: serving on stdin/stdout");
2589 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2590
2591 let transport = rmcp::transport::io::stdio();
2592
2593 let service = handler
2594 .serve(transport)
2595 .await
2596 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2597
2598 if let Err(e) = service.waiting().await {
2599 tracing::warn!(error = %e, "stdio session ended with error");
2600 }
2601 tracing::info!("stdio session ended");
2602 Ok(())
2603}
2604
2605#[cfg(test)]
2606mod tests {
2607 #![allow(
2608 clippy::unwrap_used,
2609 clippy::expect_used,
2610 clippy::panic,
2611 clippy::indexing_slicing,
2612 clippy::unwrap_in_result,
2613 clippy::print_stdout,
2614 clippy::print_stderr,
2615 deprecated,
2616 reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2617 )]
2618 use std::sync::Arc;
2619
2620 use axum::{
2621 body::Body,
2622 http::{Request, StatusCode, header},
2623 response::IntoResponse,
2624 };
2625 use http_body_util::BodyExt;
2626 use tower::ServiceExt as _;
2627
2628 use super::*;
2629
2630 #[test]
2633 fn server_config_new_defaults() {
2634 let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2635 assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2636 assert_eq!(cfg.name, "test-server");
2637 assert_eq!(cfg.version, "1.0.0");
2638 assert!(cfg.tls_cert_path.is_none());
2639 assert!(cfg.tls_key_path.is_none());
2640 assert!(cfg.auth.is_none());
2641 assert!(cfg.rbac.is_none());
2642 assert!(cfg.allowed_origins.is_empty());
2643 assert!(cfg.tool_rate_limit.is_none());
2644 assert!(cfg.readiness_check.is_none());
2645 assert_eq!(cfg.max_request_body, 1024 * 1024);
2646 assert_eq!(cfg.request_timeout, Duration::from_mins(2));
2647 assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
2648 assert!(!cfg.log_request_headers);
2649 }
2650
2651 #[test]
2652 fn validate_consumes_and_proves() {
2653 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2655 let validated = cfg.validate().expect("valid config");
2656 assert_eq!(validated.as_inner().name, "test-server");
2658 let raw = validated.into_inner();
2660 assert_eq!(raw.name, "test-server");
2661
2662 let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2664 bad.max_request_body = 0;
2665 assert!(bad.validate().is_err(), "zero body cap must fail validate");
2666 }
2667
2668 #[test]
2669 fn validate_rejects_zero_max_concurrent_requests() {
2670 let cfg =
2671 McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
2672 let err = cfg.validate().expect_err("zero concurrency cap must fail");
2673 assert!(
2674 format!("{err}").contains("max_concurrent_requests"),
2675 "error should mention max_concurrent_requests, got: {err}"
2676 );
2677 }
2678
2679 #[test]
2680 fn validate_rejects_zero_max_tracked_keys() {
2681 let rl = crate::auth::RateLimitConfig {
2682 max_tracked_keys: 0,
2683 ..Default::default()
2684 };
2685 let auth_cfg = AuthConfig {
2686 enabled: true,
2687 rate_limit: Some(rl),
2688 ..Default::default()
2689 };
2690 let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
2691 let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
2692 assert!(
2693 format!("{err}").contains("max_tracked_keys"),
2694 "error should mention max_tracked_keys, got: {err}"
2695 );
2696 }
2697
2698 #[test]
2699 fn derive_allowed_hosts_includes_public_host() {
2700 let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
2701 assert!(
2702 hosts.iter().any(|h| h == "mcp.example.com"),
2703 "public_url host must be allowed"
2704 );
2705 }
2706
2707 #[test]
2708 fn derive_allowed_hosts_includes_bind_authority() {
2709 let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
2710 assert!(
2711 hosts.iter().any(|h| h == "127.0.0.1"),
2712 "bind host must be allowed"
2713 );
2714 assert!(
2715 hosts.iter().any(|h| h == "127.0.0.1:8080"),
2716 "bind authority must be allowed"
2717 );
2718 }
2719
2720 #[tokio::test]
2723 async fn healthz_returns_ok_json() {
2724 let resp = healthz().await.into_response();
2725 assert_eq!(resp.status(), StatusCode::OK);
2726 let body = resp.into_body().collect().await.unwrap().to_bytes();
2727 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2728 assert_eq!(json["status"], "ok");
2729 assert!(
2730 json.get("name").is_none(),
2731 "healthz must not expose server name"
2732 );
2733 assert!(
2734 json.get("version").is_none(),
2735 "healthz must not expose version"
2736 );
2737 }
2738
2739 #[tokio::test]
2742 async fn readyz_returns_ok_when_ready() {
2743 let check: ReadinessCheck =
2744 Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
2745 let resp = readyz(check).await.into_response();
2746 assert_eq!(resp.status(), StatusCode::OK);
2747 let body = resp.into_body().collect().await.unwrap().to_bytes();
2748 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2749 assert_eq!(json["ready"], true);
2750 assert!(
2751 json.get("name").is_none(),
2752 "readyz must not expose server name"
2753 );
2754 assert!(
2755 json.get("version").is_none(),
2756 "readyz must not expose version"
2757 );
2758 assert_eq!(json["db"], "connected");
2759 }
2760
2761 #[tokio::test]
2762 async fn readyz_returns_503_when_not_ready() {
2763 let check: ReadinessCheck =
2764 Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
2765 let resp = readyz(check).await.into_response();
2766 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2767 }
2768
2769 #[tokio::test]
2770 async fn readyz_returns_503_when_ready_missing() {
2771 let check: ReadinessCheck =
2772 Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
2773 let resp = readyz(check).await.into_response();
2774 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2776 }
2777
2778 fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
2782 let allowed: Arc<[String]> = Arc::from(origins);
2783 axum::Router::new()
2784 .route("/test", axum::routing::get(|| async { "ok" }))
2785 .layer(axum::middleware::from_fn(move |req, next| {
2786 let a = Arc::clone(&allowed);
2787 origin_check_middleware(a, log_request_headers, req, next)
2788 }))
2789 }
2790
2791 #[tokio::test]
2792 async fn origin_allowed_passes() {
2793 let app = origin_router(vec!["http://localhost:3000".into()], false);
2794 let req = Request::builder()
2795 .uri("/test")
2796 .header(header::ORIGIN, "http://localhost:3000")
2797 .body(Body::empty())
2798 .unwrap();
2799 let resp = app.oneshot(req).await.unwrap();
2800 assert_eq!(resp.status(), StatusCode::OK);
2801 }
2802
2803 #[tokio::test]
2804 async fn origin_rejected_returns_403() {
2805 let app = origin_router(vec!["http://localhost:3000".into()], false);
2806 let req = Request::builder()
2807 .uri("/test")
2808 .header(header::ORIGIN, "http://evil.com")
2809 .body(Body::empty())
2810 .unwrap();
2811 let resp = app.oneshot(req).await.unwrap();
2812 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2813 }
2814
2815 #[tokio::test]
2816 async fn no_origin_header_passes() {
2817 let app = origin_router(vec!["http://localhost:3000".into()], false);
2818 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2819 let resp = app.oneshot(req).await.unwrap();
2820 assert_eq!(resp.status(), StatusCode::OK);
2821 }
2822
2823 #[tokio::test]
2824 async fn empty_allowlist_rejects_any_origin() {
2825 let app = origin_router(vec![], false);
2826 let req = Request::builder()
2827 .uri("/test")
2828 .header(header::ORIGIN, "http://anything.com")
2829 .body(Body::empty())
2830 .unwrap();
2831 let resp = app.oneshot(req).await.unwrap();
2832 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2833 }
2834
2835 #[tokio::test]
2836 async fn empty_allowlist_passes_without_origin() {
2837 let app = origin_router(vec![], false);
2838 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2839 let resp = app.oneshot(req).await.unwrap();
2840 assert_eq!(resp.status(), StatusCode::OK);
2841 }
2842
2843 #[test]
2844 fn format_request_headers_redacts_sensitive_values() {
2845 let mut headers = axum::http::HeaderMap::new();
2846 headers.insert("authorization", "Bearer secret-token".parse().unwrap());
2847 headers.insert("cookie", "sid=abc".parse().unwrap());
2848 headers.insert("x-request-id", "req-123".parse().unwrap());
2849
2850 let out = format_request_headers_for_log(&headers);
2851 assert!(out.contains("authorization: [REDACTED]"));
2852 assert!(out.contains("cookie: [REDACTED]"));
2853 assert!(out.contains("x-request-id: req-123"));
2854 assert!(!out.contains("secret-token"));
2855 }
2856
2857 fn security_router(is_tls: bool) -> axum::Router {
2860 security_router_with(is_tls, SecurityHeadersConfig::default())
2861 }
2862
2863 fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
2864 let cfg = Arc::new(cfg);
2865 axum::Router::new()
2866 .route("/test", axum::routing::get(|| async { "ok" }))
2867 .layer(axum::middleware::from_fn(move |req, next| {
2868 let c = Arc::clone(&cfg);
2869 security_headers_middleware(is_tls, c, req, next)
2870 }))
2871 }
2872
2873 #[tokio::test]
2874 async fn security_headers_set_on_response() {
2875 let app = security_router(false);
2876 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2877 let resp = app.oneshot(req).await.unwrap();
2878 assert_eq!(resp.status(), StatusCode::OK);
2879
2880 let h = resp.headers();
2881 assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
2882 assert_eq!(h.get("x-frame-options").unwrap(), "deny");
2883 assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
2884 assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
2885 assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
2886 assert_eq!(
2887 h.get("cross-origin-resource-policy").unwrap(),
2888 "same-origin"
2889 );
2890 assert_eq!(
2891 h.get("cross-origin-embedder-policy").unwrap(),
2892 "require-corp"
2893 );
2894 assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
2895 assert!(
2896 h.get("permissions-policy")
2897 .unwrap()
2898 .to_str()
2899 .unwrap()
2900 .contains("camera=()"),
2901 "permissions-policy must restrict browser features"
2902 );
2903 assert_eq!(
2904 h.get("content-security-policy").unwrap(),
2905 "default-src 'none'; frame-ancestors 'none'"
2906 );
2907 assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
2908 assert!(h.get("strict-transport-security").is_none());
2910 }
2911
2912 #[tokio::test]
2913 async fn hsts_set_when_tls_enabled() {
2914 let app = security_router(true);
2915 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2916 let resp = app.oneshot(req).await.unwrap();
2917
2918 let hsts = resp.headers().get("strict-transport-security").unwrap();
2919 assert!(
2920 hsts.to_str().unwrap().contains("max-age=63072000"),
2921 "HSTS must set 2-year max-age"
2922 );
2923 }
2924
2925 fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
2931 let cfg =
2932 McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
2933 cfg.check()
2934 }
2935
2936 #[test]
2937 fn security_headers_config_default_validates() {
2938 check_with_security_headers(SecurityHeadersConfig::default())
2939 .expect("default SecurityHeadersConfig must validate");
2940 }
2941
2942 #[test]
2943 fn security_headers_config_validate_accepts_empty_string() {
2944 let h = SecurityHeadersConfig {
2946 x_content_type_options: Some(String::new()),
2947 x_frame_options: Some(String::new()),
2948 cache_control: Some(String::new()),
2949 referrer_policy: Some(String::new()),
2950 cross_origin_opener_policy: Some(String::new()),
2951 cross_origin_resource_policy: Some(String::new()),
2952 cross_origin_embedder_policy: Some(String::new()),
2953 permissions_policy: Some(String::new()),
2954 x_permitted_cross_domain_policies: Some(String::new()),
2955 content_security_policy: Some(String::new()),
2956 x_dns_prefetch_control: Some(String::new()),
2957 strict_transport_security: Some(String::new()),
2958 };
2959 check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
2960 }
2961
2962 #[test]
2963 fn security_headers_config_validate_rejects_bad_value() {
2964 let h = SecurityHeadersConfig {
2966 referrer_policy: Some("\u{0007}".into()),
2967 ..SecurityHeadersConfig::default()
2968 };
2969 let err = check_with_security_headers(h)
2970 .expect_err("control char in referrer_policy must reject");
2971 let msg = err.to_string();
2972 assert!(
2973 msg.contains("referrer_policy"),
2974 "error must name the offending field, got: {msg}"
2975 );
2976 }
2977
2978 #[test]
2979 fn security_headers_config_validate_rejects_hsts_preload() {
2980 let h = SecurityHeadersConfig {
2981 strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
2982 ..SecurityHeadersConfig::default()
2983 };
2984 let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
2985 let msg = err.to_string();
2986 assert!(
2987 msg.contains("strict_transport_security"),
2988 "error must name the field, got: {msg}"
2989 );
2990 assert!(
2991 msg.to_lowercase().contains("preload"),
2992 "error must mention `preload`, got: {msg}"
2993 );
2994 }
2995
2996 #[test]
2997 fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
2998 let h = SecurityHeadersConfig {
3000 strict_transport_security: Some("max-age=600; PRELOAD".into()),
3001 ..SecurityHeadersConfig::default()
3002 };
3003 check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
3004 }
3005
3006 #[tokio::test]
3007 async fn security_headers_override_honored() {
3008 let h = SecurityHeadersConfig {
3010 x_frame_options: Some("SAMEORIGIN".into()),
3011 ..SecurityHeadersConfig::default()
3012 };
3013 let app = security_router_with(false, h);
3014 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3015 let resp = app.oneshot(req).await.unwrap();
3016 assert_eq!(resp.status(), StatusCode::OK);
3017
3018 let xfo = resp.headers().get("x-frame-options").unwrap();
3019 assert_eq!(xfo, "SAMEORIGIN");
3020 }
3021
3022 #[tokio::test]
3023 async fn security_headers_empty_string_omits() {
3024 let h = SecurityHeadersConfig {
3026 referrer_policy: Some(String::new()),
3027 ..SecurityHeadersConfig::default()
3028 };
3029 let app = security_router_with(false, h);
3030 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3031 let resp = app.oneshot(req).await.unwrap();
3032 assert_eq!(resp.status(), StatusCode::OK);
3033
3034 assert!(
3035 resp.headers().get("referrer-policy").is_none(),
3036 "Some(\"\") must omit the header"
3037 );
3038 assert_eq!(
3040 resp.headers().get("x-content-type-options").unwrap(),
3041 "nosniff"
3042 );
3043 }
3044
3045 #[tokio::test]
3046 async fn security_headers_hsts_only_when_tls() {
3047 let h = SecurityHeadersConfig {
3049 strict_transport_security: Some("max-age=600".into()),
3050 ..SecurityHeadersConfig::default()
3051 };
3052 let app = security_router_with(false, h);
3053 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3054 let resp = app.oneshot(req).await.unwrap();
3055 assert!(
3056 resp.headers().get("strict-transport-security").is_none(),
3057 "HSTS must remain absent on plaintext deployments even with override"
3058 );
3059 }
3060
3061 #[cfg(feature = "oauth")]
3064 #[tokio::test]
3065 async fn oauth_token_cache_headers_set_pragma_and_vary() {
3066 let app = axum::Router::new()
3067 .route("/token", axum::routing::post(|| async { "{}" }))
3068 .layer(axum::middleware::from_fn(
3069 oauth_token_cache_headers_middleware,
3070 ));
3071 let req = Request::builder()
3072 .method("POST")
3073 .uri("/token")
3074 .body(Body::from("{}"))
3075 .unwrap();
3076 let resp = app.oneshot(req).await.unwrap();
3077 assert_eq!(resp.status(), StatusCode::OK);
3078
3079 let h = resp.headers();
3080 assert_eq!(
3081 h.get("pragma").unwrap(),
3082 "no-cache",
3083 "RFC 6749 §5.1: token responses must set Pragma: no-cache"
3084 );
3085 let vary_values: Vec<String> = h
3086 .get_all("vary")
3087 .iter()
3088 .filter_map(|v| v.to_str().ok().map(str::to_owned))
3089 .collect();
3090 assert!(
3091 vary_values
3092 .iter()
3093 .any(|v| v.eq_ignore_ascii_case("Authorization")),
3094 "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
3095 );
3096 }
3097
3098 #[cfg(feature = "oauth")]
3099 #[tokio::test]
3100 async fn oauth_token_cache_headers_preserve_existing_vary() {
3101 let app = axum::Router::new()
3104 .route(
3105 "/token",
3106 axum::routing::post(|| async {
3107 axum::response::Response::builder()
3108 .header("vary", "Accept-Encoding")
3109 .body(axum::body::Body::from("{}"))
3110 .unwrap()
3111 }),
3112 )
3113 .layer(axum::middleware::from_fn(
3114 oauth_token_cache_headers_middleware,
3115 ));
3116 let req = Request::builder()
3117 .method("POST")
3118 .uri("/token")
3119 .body(Body::empty())
3120 .unwrap();
3121 let resp = app.oneshot(req).await.unwrap();
3122
3123 let vary: Vec<String> = resp
3124 .headers()
3125 .get_all("vary")
3126 .iter()
3127 .filter_map(|v| v.to_str().ok().map(str::to_owned))
3128 .collect();
3129 assert!(
3130 vary.iter().any(|v| v.contains("Accept-Encoding")),
3131 "must preserve pre-existing Vary value, got {vary:?}"
3132 );
3133 assert!(
3134 vary.iter().any(|v| v.contains("Authorization")),
3135 "must append Authorization to Vary, got {vary:?}"
3136 );
3137 }
3138
3139 #[test]
3142 fn version_payload_contains_expected_fields() {
3143 let v = version_payload("my-server", "1.2.3");
3144 assert_eq!(v["name"], "my-server");
3145 assert_eq!(v["version"], "1.2.3");
3146 assert!(v["build_git_sha"].is_string());
3147 assert!(v["build_timestamp"].is_string());
3148 assert!(v["rust_version"].is_string());
3149 assert!(v["mcpx_version"].is_string());
3150 }
3151
3152 #[tokio::test]
3155 async fn concurrency_limit_layer_composes_and_serves() {
3156 let app = axum::Router::new()
3160 .route("/ok", axum::routing::get(|| async { "ok" }))
3161 .layer(
3162 tower::ServiceBuilder::new()
3163 .layer(axum::error_handling::HandleErrorLayer::new(
3164 |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3165 ))
3166 .layer(tower::load_shed::LoadShedLayer::new())
3167 .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3168 );
3169 let resp = app
3170 .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3171 .await
3172 .unwrap();
3173 assert_eq!(resp.status(), StatusCode::OK);
3174 }
3175
3176 #[tokio::test]
3179 async fn compression_layer_gzip_encodes_response() {
3180 use tower_http::compression::Predicate as _;
3181
3182 let big_body = "a".repeat(4096);
3183 let app = axum::Router::new()
3184 .route(
3185 "/big",
3186 axum::routing::get(move || {
3187 let body = big_body.clone();
3188 async move { body }
3189 }),
3190 )
3191 .layer(
3192 tower_http::compression::CompressionLayer::new()
3193 .gzip(true)
3194 .br(true)
3195 .compress_when(
3196 tower_http::compression::DefaultPredicate::new()
3197 .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3198 ),
3199 );
3200
3201 let req = Request::builder()
3202 .uri("/big")
3203 .header(header::ACCEPT_ENCODING, "gzip")
3204 .body(Body::empty())
3205 .unwrap();
3206 let resp = app.oneshot(req).await.unwrap();
3207 assert_eq!(resp.status(), StatusCode::OK);
3208 assert_eq!(
3209 resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3210 "gzip"
3211 );
3212 }
3213}