1use std::{
2 future::Future,
3 net::SocketAddr,
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13 ServerHandler,
14 transport::streamable_http_server::{
15 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16 },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23 auth::{
24 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25 build_rate_limiter, extract_mtls_identity,
26 },
27 error::McpxError,
28 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32#[allow(
36 clippy::needless_pass_by_value,
37 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40 McpxError::Startup(format!("{e:#}"))
41}
42
43#[allow(
49 clippy::needless_pass_by_value,
50 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53 McpxError::Startup(format!("{op}: {e}"))
54}
55
56pub type ReadinessCheck =
61 Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
/// Overrides for the HTTP security response headers.
///
/// The headers are applied by `security_headers_middleware`, which
/// `build_app_router` layers over the top-level router.
///
/// Each field corresponds, by name, to one response header. `Default` leaves
/// every field `None`. What `None` means (a built-in default value, or
/// omitting the header) is decided by the middleware. NOTE(review): confirm
/// this against `security_headers_middleware`.
///
/// Values are checked by `validate_security_headers` when
/// `McpServerConfig::validate` runs.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options`.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options`.
    pub x_frame_options: Option<String>,
    /// `Cache-Control`.
    pub cache_control: Option<String>,
    /// `Referrer-Policy`.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy`.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy`.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy`.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy`.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies`.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy`.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control`.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security`.
    ///
    /// The middleware is also told whether TLS is enabled. Presumably HSTS is
    /// only emitted over TLS. TODO: confirm.
    pub strict_transport_security: Option<String>,
}
116
/// Configuration for the MCP Streamable HTTP server.
///
/// Typical use:
/// 1. Start from [`McpServerConfig::new`].
/// 2. Adjust with the `with_*` / `enable_*` builder methods.
/// 3. Call [`McpServerConfig::validate`] to obtain the [`Validated`] wrapper
///    that `serve` and `serve_with_listener` accept.
///
/// The fields are `pub` only for backward compatibility; direct access is
/// deprecated.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Listen address. `validate` requires it to parse as a `SocketAddr`
    /// (e.g. `127.0.0.1:8080`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in 1.0"
    )]
    pub bind_addr: String,
    /// Server name. Reported by `/version` and the admin state, and used in
    /// startup log lines.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
    )]
    pub name: String,
    /// Server version. Reported by `/version` and the admin state.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
    )]
    pub version: String,
    /// TLS certificate path.
    ///
    /// TLS (and the `https` scheme) is used when this is set. `validate`
    /// rejects setting only one of the cert/key pair.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private-key path; must be paired with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings.
    ///
    /// The auth middleware is installed on the MCP router (which also carries
    /// the admin routes) only when this is `Some` with `enabled == true`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in 1.0"
    )]
    pub auth: Option<AuthConfig>,
    /// Initial RBAC policy. `None` means `RbacPolicy::disabled()`.
    ///
    /// The policy can be swapped at runtime via `ReloadHandle::reload_rbac`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in 1.0"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Origins accepted by the origin-check middleware and CORS.
    ///
    /// Each entry must start with `http://` or `https://`. When the list is
    /// empty and `public_url` is set, one origin is derived from `public_url`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in 1.0"
    )]
    pub allowed_origins: Vec<String>,
    /// Tool-call limit per minute, enforced alongside RBAC on the MCP router.
    ///
    /// The startup log describes it as per client IP.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in 1.0"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Custom `/readyz` probe. Without one, `/readyz` uses the `/healthz`
    /// handler.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in 1.0"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes on the MCP router.
    ///
    /// Must be non-zero; `new()` uses 1 MiB.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in 1.0"
    )]
    pub max_request_body: usize,
    /// Per-request timeout on the MCP router.
    ///
    /// Timed-out requests get `408 Request Timeout`; `new()` uses 2 minutes.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub request_timeout: Duration,
    /// Grace period between the start of shutdown and forced exit.
    ///
    /// `new()` uses 30 seconds.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub shutdown_timeout: Duration,
    /// Passed to `LocalSessionManager` as `session_config.keep_alive`.
    ///
    /// `new()` uses 20 minutes.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval for the Streamable HTTP service.
    ///
    /// `new()` uses 15 seconds.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in 1.0"
    )]
    pub sse_keep_alive: Duration,
    /// Called once with a `ReloadHandle` before the server starts serving.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in 1.0"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Extra routes merged into the top-level router.
    ///
    /// They sit outside the MCP router's auth/RBAC stack, but inside the
    /// global layers.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in 1.0"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL; must start with `http://` or `https://`.
    ///
    /// Used for allowed hosts, the protected-resource metadata URL, and as a
    /// fallback source of the allowed origin.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in 1.0"
    )]
    pub public_url: Option<String>,
    /// Forwarded to the origin-check middleware.
    ///
    /// Presumably this enables logging of request headers there. TODO: confirm.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in 1.0"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
    )]
    pub compression_enabled: bool,
    /// Size threshold in bytes for compression; `new()` uses 1024.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
    )]
    pub compression_min_size: u16,
    /// Global cap on in-flight requests.
    ///
    /// Excess requests are shed with `503` and an `overloaded` JSON body.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in 1.0"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Mounts the admin endpoints. Requires auth to be configured and enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
    )]
    pub admin_enabled: bool,
    /// Role handed to the admin router's config; `new()` uses `"admin"`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
    )]
    pub admin_role: String,
    /// Enables the metrics middleware and a separate metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
    )]
    pub metrics_enabled: bool,
    /// Bind address for the metrics listener; `new()` uses `127.0.0.1:9090`.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
    )]
    pub metrics_bind: String,
    /// Overrides for the security response headers (see
    /// [`SecurityHeadersConfig`]).
    // NOTE(review): `since = "1.5.0"` contradicts this note's "will become
    // pub(crate) in 1.0" (and the "0.13.0" used by every other field). One of
    // the two version strings looks wrong; confirm against the release history.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in 1.0"
    )]
    pub security_headers: SecurityHeadersConfig,
}
322
/// Marker wrapper for a value that has passed validation.
///
/// The tuple field is private, so a `Validated<T>` can only be built inside
/// this module. In the visible code, that only happens in
/// `McpServerConfig::validate`. The wrapper derefs to `T` for read access.
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately opaque: the wrapped value is never printed.
        f.write_str("Validated { .. }")
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the validated value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

impl<T> std::ops::Deref for Validated<T> {
    type Target = T;

    fn deref(&self) -> &T {
        self.as_inner()
    }
}
416
417#[allow(
418 deprecated,
419 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
420)]
421impl McpServerConfig {
422 #[must_use]
430 pub fn new(
431 bind_addr: impl Into<String>,
432 name: impl Into<String>,
433 version: impl Into<String>,
434 ) -> Self {
435 Self {
436 bind_addr: bind_addr.into(),
437 name: name.into(),
438 version: version.into(),
439 tls_cert_path: None,
440 tls_key_path: None,
441 auth: None,
442 rbac: None,
443 allowed_origins: Vec::new(),
444 tool_rate_limit: None,
445 readiness_check: None,
446 max_request_body: 1024 * 1024,
447 request_timeout: Duration::from_mins(2),
448 shutdown_timeout: Duration::from_secs(30),
449 session_idle_timeout: Duration::from_mins(20),
450 sse_keep_alive: Duration::from_secs(15),
451 on_reload_ready: None,
452 extra_router: None,
453 public_url: None,
454 log_request_headers: false,
455 compression_enabled: false,
456 compression_min_size: 1024,
457 max_concurrent_requests: None,
458 admin_enabled: false,
459 admin_role: "admin".to_owned(),
460 #[cfg(feature = "metrics")]
461 metrics_enabled: false,
462 #[cfg(feature = "metrics")]
463 metrics_bind: "127.0.0.1:9090".into(),
464 security_headers: SecurityHeadersConfig::default(),
465 }
466 }
467
468 #[must_use]
478 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
479 self.auth = Some(auth);
480 self
481 }
482
483 #[must_use]
488 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
489 self.security_headers = headers;
490 self
491 }
492
493 #[must_use]
497 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
498 self.bind_addr = addr.into();
499 self
500 }
501
502 #[must_use]
505 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
506 self.rbac = Some(rbac);
507 self
508 }
509
510 #[must_use]
514 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
515 self.tls_cert_path = Some(cert_path.into());
516 self.tls_key_path = Some(key_path.into());
517 self
518 }
519
520 #[must_use]
524 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
525 self.public_url = Some(url.into());
526 self
527 }
528
529 #[must_use]
533 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
534 where
535 I: IntoIterator<Item = S>,
536 S: Into<String>,
537 {
538 self.allowed_origins = origins.into_iter().map(Into::into).collect();
539 self
540 }
541
542 #[must_use]
546 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
547 self.extra_router = Some(router);
548 self
549 }
550
551 #[must_use]
554 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
555 self.readiness_check = Some(check);
556 self
557 }
558
559 #[must_use]
562 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
563 self.max_request_body = bytes;
564 self
565 }
566
567 #[must_use]
569 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
570 self.request_timeout = timeout;
571 self
572 }
573
574 #[must_use]
576 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
577 self.shutdown_timeout = timeout;
578 self
579 }
580
581 #[must_use]
583 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
584 self.session_idle_timeout = timeout;
585 self
586 }
587
588 #[must_use]
590 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
591 self.sse_keep_alive = interval;
592 self
593 }
594
595 #[must_use]
599 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
600 self.max_concurrent_requests = Some(limit);
601 self
602 }
603
604 #[must_use]
607 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
608 self.tool_rate_limit = Some(per_minute);
609 self
610 }
611
612 #[must_use]
616 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
617 where
618 F: FnOnce(ReloadHandle) + Send + 'static,
619 {
620 self.on_reload_ready = Some(Box::new(callback));
621 self
622 }
623
624 #[must_use]
628 pub fn enable_compression(mut self, min_size: u16) -> Self {
629 self.compression_enabled = true;
630 self.compression_min_size = min_size;
631 self
632 }
633
634 #[must_use]
639 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
640 self.admin_enabled = true;
641 self.admin_role = role.into();
642 self
643 }
644
645 #[must_use]
648 pub fn enable_request_header_logging(mut self) -> Self {
649 self.log_request_headers = true;
650 self
651 }
652
653 #[cfg(feature = "metrics")]
656 #[must_use]
657 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
658 self.metrics_enabled = true;
659 self.metrics_bind = bind.into();
660 self
661 }
662
663 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
696 self.check()?;
697 Ok(Validated(self))
698 }
699
700 fn check(&self) -> Result<(), McpxError> {
704 if self.admin_enabled {
708 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
709 if !auth_enabled {
710 return Err(McpxError::Config(
711 "admin_enabled=true requires auth to be configured and enabled".into(),
712 ));
713 }
714 }
715
716 match (&self.tls_cert_path, &self.tls_key_path) {
718 (Some(_), None) => {
719 return Err(McpxError::Config(
720 "tls_cert_path is set but tls_key_path is missing".into(),
721 ));
722 }
723 (None, Some(_)) => {
724 return Err(McpxError::Config(
725 "tls_key_path is set but tls_cert_path is missing".into(),
726 ));
727 }
728 _ => {}
729 }
730
731 if self.bind_addr.parse::<SocketAddr>().is_err() {
733 return Err(McpxError::Config(format!(
734 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
735 self.bind_addr
736 )));
737 }
738
739 if let Some(ref url) = self.public_url
741 && !(url.starts_with("http://") || url.starts_with("https://"))
742 {
743 return Err(McpxError::Config(format!(
744 "public_url {url:?} must start with http:// or https://"
745 )));
746 }
747
748 for origin in &self.allowed_origins {
750 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
751 return Err(McpxError::Config(format!(
752 "allowed_origins entry {origin:?} must start with http:// or https://"
753 )));
754 }
755 }
756
757 if self.max_request_body == 0 {
759 return Err(McpxError::Config(
760 "max_request_body must be greater than zero".into(),
761 ));
762 }
763
764 #[cfg(feature = "oauth")]
766 if let Some(auth_cfg) = &self.auth
767 && let Some(oauth_cfg) = &auth_cfg.oauth
768 {
769 oauth_cfg.validate()?;
770 }
771
772 validate_security_headers(&self.security_headers)?;
775
776 Ok(())
777 }
778}
779
780#[allow(
786 missing_debug_implementations,
787 reason = "contains Arc<AuthState> with non-Debug fields"
788)]
789pub struct ReloadHandle {
790 auth: Option<Arc<AuthState>>,
791 rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
792 crl_set: Option<Arc<CrlSet>>,
793}
794
795impl ReloadHandle {
796 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
798 if let Some(ref auth) = self.auth {
799 auth.reload_keys(keys);
800 }
801 }
802
803 pub fn reload_rbac(&self, policy: RbacPolicy) {
805 if let Some(ref rbac) = self.rbac {
806 rbac.store(Arc::new(policy));
807 tracing::info!("RBAC policy reloaded");
808 }
809 }
810
811 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
817 let Some(ref crl_set) = self.crl_set else {
818 return Err(McpxError::Config(
819 "CRL refresh requested but mTLS CRL support is not configured".into(),
820 ));
821 };
822
823 crl_set.force_refresh().await
824 }
825}
826
827#[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
844struct AppRunParams {
848 tls_paths: Option<(PathBuf, PathBuf)>,
850 mtls_config: Option<MtlsConfig>,
852 shutdown_timeout: Duration,
854 auth_state: Option<Arc<AuthState>>,
856 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
858 on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
860 ct: CancellationToken,
864 scheme: &'static str,
866 name: String,
868}
869
870#[allow(
880 clippy::cognitive_complexity,
881 reason = "router assembly is intrinsically sequential; splitting harms readability"
882)]
883#[allow(
884 deprecated,
885 reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
886)]
887fn build_app_router<H, F>(
888 mut config: McpServerConfig,
889 handler_factory: F,
890) -> anyhow::Result<(axum::Router, AppRunParams)>
891where
892 H: ServerHandler + 'static,
893 F: Fn() -> H + Send + Sync + Clone + 'static,
894{
895 let ct = CancellationToken::new();
896
897 let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
898 tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
899
900 let mcp_service = StreamableHttpService::new(
901 move || Ok(handler_factory()),
902 {
903 let mut mgr = LocalSessionManager::default();
904 mgr.session_config.keep_alive = Some(config.session_idle_timeout);
905 mgr.into()
906 },
907 StreamableHttpServerConfig::default()
908 .with_allowed_hosts(allowed_hosts)
909 .with_sse_keep_alive(Some(config.sse_keep_alive))
910 .with_cancellation_token(ct.child_token()),
911 );
912
913 let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
915
916 let auth_state: Option<Arc<AuthState>> = match config.auth {
920 Some(ref auth_config) if auth_config.enabled => {
921 let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
922 let pre_auth_limiter = auth_config
923 .rate_limit
924 .as_ref()
925 .map(crate::auth::build_pre_auth_limiter);
926
927 #[cfg(feature = "oauth")]
928 let jwks_cache = auth_config
929 .oauth
930 .as_ref()
931 .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
932 .transpose()
933 .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
934
935 Some(Arc::new(AuthState {
936 api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
937 rate_limiter,
938 pre_auth_limiter,
939 #[cfg(feature = "oauth")]
940 jwks_cache,
941 seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
942 counters: crate::auth::AuthCounters::default(),
943 }))
944 }
945 _ => None,
946 };
947
948 let rbac_swap = Arc::new(ArcSwap::new(
951 config
952 .rbac
953 .clone()
954 .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
955 ));
956
957 if config.admin_enabled {
960 let Some(ref auth_state_ref) = auth_state else {
961 return Err(anyhow::anyhow!(
962 "admin_enabled=true requires auth to be configured and enabled"
963 ));
964 };
965 let admin_state = crate::admin::AdminState {
966 started_at: std::time::Instant::now(),
967 name: config.name.clone(),
968 version: config.version.clone(),
969 auth: Some(Arc::clone(auth_state_ref)),
970 rbac: Arc::clone(&rbac_swap),
971 };
972 let admin_cfg = crate::admin::AdminConfig {
973 role: config.admin_role.clone(),
974 };
975 mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
976 tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
977 }
978
979 {
1012 let tool_limiter: Option<Arc<ToolRateLimiter>> =
1013 config.tool_rate_limit.map(build_tool_rate_limiter);
1014
1015 if rbac_swap.load().is_enabled() {
1016 tracing::info!("RBAC enforcement enabled on /mcp");
1017 }
1018 if let Some(limit) = config.tool_rate_limit {
1019 tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1020 }
1021
1022 let rbac_for_mw = Arc::clone(&rbac_swap);
1023 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1024 let p = rbac_for_mw.load_full();
1025 let tl = tool_limiter.clone();
1026 rbac_middleware(p, tl, req, next)
1027 }));
1028 }
1029
1030 if let Some(ref auth_config) = config.auth
1032 && auth_config.enabled
1033 {
1034 let Some(ref state) = auth_state else {
1035 return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1036 };
1037
1038 let methods: Vec<&str> = [
1039 auth_config.mtls.is_some().then_some("mTLS"),
1040 (!auth_config.api_keys.is_empty()).then_some("bearer"),
1041 #[cfg(feature = "oauth")]
1042 auth_config.oauth.is_some().then_some("oauth-jwt"),
1043 ]
1044 .into_iter()
1045 .flatten()
1046 .collect();
1047
1048 tracing::info!(
1049 methods = %methods.join(", "),
1050 api_keys = auth_config.api_keys.len(),
1051 "auth enabled on /mcp"
1052 );
1053
1054 let state_for_mw = Arc::clone(state);
1055 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1056 let s = Arc::clone(&state_for_mw);
1057 auth_middleware(s, req, next)
1058 }));
1059 }
1060
1061 mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1064 axum::http::StatusCode::REQUEST_TIMEOUT,
1065 config.request_timeout,
1066 ));
1067
1068 mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1072 config.max_request_body,
1073 ));
1074
1075 let mut effective_origins = config.allowed_origins.clone();
1082 if effective_origins.is_empty()
1083 && let Some(ref url) = config.public_url
1084 {
1085 if let Some(scheme_end) = url.find("://") {
1088 let after_scheme = &url[scheme_end + 3..];
1089 let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1090 let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
1091 tracing::info!(
1092 %origin,
1093 "auto-derived allowed origin from public_url"
1094 );
1095 effective_origins.push(origin);
1096 }
1097 }
1098 let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1099 let cors_origins = Arc::clone(&allowed_origins);
1100 let log_request_headers = config.log_request_headers;
1101
1102 let readyz_route = if let Some(check) = config.readiness_check.take() {
1103 axum::routing::get(move || readyz(Arc::clone(&check)))
1104 } else {
1105 axum::routing::get(healthz)
1106 };
1107
1108 #[allow(unused_mut)] let mut router = axum::Router::new()
1110 .route("/healthz", axum::routing::get(healthz))
1111 .route("/readyz", readyz_route)
1112 .route(
1113 "/version",
1114 axum::routing::get({
1115 let payload_bytes: Arc<[u8]> =
1120 serialize_version_payload(&config.name, &config.version);
1121 move || {
1122 let p = Arc::clone(&payload_bytes);
1123 async move {
1124 (
1125 [(axum::http::header::CONTENT_TYPE, "application/json")],
1126 p.to_vec(),
1127 )
1128 }
1129 }
1130 }),
1131 )
1132 .merge(mcp_router);
1133
1134 if let Some(extra) = config.extra_router.take() {
1136 router = router.merge(extra);
1137 }
1138
1139 let server_url = if let Some(ref url) = config.public_url {
1146 url.trim_end_matches('/').to_owned()
1147 } else {
1148 let prm_scheme = if config.tls_cert_path.is_some() {
1149 "https"
1150 } else {
1151 "http"
1152 };
1153 format!("{prm_scheme}://{}", config.bind_addr)
1154 };
1155 let resource_url = format!("{server_url}/mcp");
1156
1157 #[cfg(feature = "oauth")]
1158 let prm_metadata = if let Some(ref auth_config) = config.auth
1159 && let Some(ref oauth_config) = auth_config.oauth
1160 {
1161 crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1162 } else {
1163 serde_json::json!({ "resource": resource_url })
1164 };
1165 #[cfg(not(feature = "oauth"))]
1166 let prm_metadata = serde_json::json!({ "resource": resource_url });
1167
1168 router = router.route(
1169 "/.well-known/oauth-protected-resource",
1170 axum::routing::get(move || {
1171 let m = prm_metadata.clone();
1172 async move { axum::Json(m) }
1173 }),
1174 );
1175
1176 #[cfg(feature = "oauth")]
1181 if let Some(ref auth_config) = config.auth
1182 && let Some(ref oauth_config) = auth_config.oauth
1183 && oauth_config.proxy.is_some()
1184 {
1185 router =
1186 install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1187 }
1188
1189 let is_tls = config.tls_cert_path.is_some();
1192 let security_headers_cfg = Arc::new(config.security_headers.clone());
1193 router = router.layer(axum::middleware::from_fn(move |req, next| {
1194 let cfg = Arc::clone(&security_headers_cfg);
1195 security_headers_middleware(is_tls, cfg, req, next)
1196 }));
1197
1198 if !cors_origins.is_empty() {
1202 let cors = tower_http::cors::CorsLayer::new()
1203 .allow_origin(
1204 cors_origins
1205 .iter()
1206 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1207 .collect::<Vec<_>>(),
1208 )
1209 .allow_methods([
1210 axum::http::Method::GET,
1211 axum::http::Method::POST,
1212 axum::http::Method::OPTIONS,
1213 ])
1214 .allow_headers([
1215 axum::http::header::CONTENT_TYPE,
1216 axum::http::header::AUTHORIZATION,
1217 ]);
1218 router = router.layer(cors);
1219 }
1220
1221 if config.compression_enabled {
1225 use tower_http::compression::Predicate as _;
1226 let predicate = tower_http::compression::DefaultPredicate::new().and(
1227 tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1228 );
1229 router = router.layer(
1230 tower_http::compression::CompressionLayer::new()
1231 .gzip(true)
1232 .br(true)
1233 .compress_when(predicate),
1234 );
1235 tracing::info!(
1236 min_size = config.compression_min_size,
1237 "response compression enabled (gzip, br)"
1238 );
1239 }
1240
1241 if let Some(max) = config.max_concurrent_requests {
1244 let overload_handler = tower::ServiceBuilder::new()
1245 .layer(axum::error_handling::HandleErrorLayer::new(
1246 |_err: tower::BoxError| async {
1247 (
1248 axum::http::StatusCode::SERVICE_UNAVAILABLE,
1249 axum::Json(serde_json::json!({
1250 "error": "overloaded",
1251 "error_description": "server is at capacity, retry later"
1252 })),
1253 )
1254 },
1255 ))
1256 .layer(tower::load_shed::LoadShedLayer::new())
1257 .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1258 router = router.layer(overload_handler);
1259 tracing::info!(max, "global concurrency limit enabled");
1260 }
1261
1262 router = router.fallback(|| async {
1266 (
1267 axum::http::StatusCode::NOT_FOUND,
1268 axum::Json(serde_json::json!({
1269 "error": "not_found",
1270 "error_description": "The requested endpoint does not exist"
1271 })),
1272 )
1273 });
1274
1275 #[cfg(feature = "metrics")]
1277 if config.metrics_enabled {
1278 let metrics = Arc::new(
1279 crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1280 );
1281 let m = Arc::clone(&metrics);
1282 router = router.layer(axum::middleware::from_fn(
1283 move |req: Request<Body>, next: Next| {
1284 let m = Arc::clone(&m);
1285 metrics_middleware(m, req, next)
1286 },
1287 ));
1288 let metrics_bind = config.metrics_bind.clone();
1289 tokio::spawn(async move {
1290 if let Err(e) = crate::metrics::serve_metrics(metrics_bind, metrics).await {
1291 tracing::error!("metrics listener failed: {e}");
1292 }
1293 });
1294 }
1295
1296 router = router.layer(axum::middleware::from_fn(move |req, next| {
1307 let origins = Arc::clone(&allowed_origins);
1308 origin_check_middleware(origins, log_request_headers, req, next)
1309 }));
1310
1311 let scheme = if config.tls_cert_path.is_some() {
1312 "https"
1313 } else {
1314 "http"
1315 };
1316
1317 let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1318 (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1319 _ => None,
1320 };
1321 let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1322
1323 Ok((
1324 router,
1325 AppRunParams {
1326 tls_paths,
1327 mtls_config,
1328 shutdown_timeout: config.shutdown_timeout,
1329 auth_state,
1330 rbac_swap,
1331 on_reload_ready: config.on_reload_ready.take(),
1332 ct,
1333 scheme,
1334 name: config.name.clone(),
1335 },
1336 ))
1337}
1338
1339pub async fn serve<H, F>(
1356 config: Validated<McpServerConfig>,
1357 handler_factory: F,
1358) -> Result<(), McpxError>
1359where
1360 H: ServerHandler + 'static,
1361 F: Fn() -> H + Send + Sync + Clone + 'static,
1362{
1363 let config = config.into_inner();
1364 #[allow(
1365 deprecated,
1366 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1367 )]
1368 let bind_addr = config.bind_addr.clone();
1369 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1370
1371 let listener = TcpListener::bind(&bind_addr)
1372 .await
1373 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1374 log_listening(¶ms.name, params.scheme, &bind_addr);
1375
1376 run_server(
1377 router,
1378 listener,
1379 params.tls_paths,
1380 params.mtls_config,
1381 params.shutdown_timeout,
1382 params.auth_state,
1383 params.rbac_swap,
1384 params.on_reload_ready,
1385 params.ct,
1386 )
1387 .await
1388 .map_err(anyhow_to_startup)
1389}
1390
1391pub async fn serve_with_listener<H, F>(
1421 listener: TcpListener,
1422 config: Validated<McpServerConfig>,
1423 handler_factory: F,
1424 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1425 shutdown: Option<CancellationToken>,
1426) -> Result<(), McpxError>
1427where
1428 H: ServerHandler + 'static,
1429 F: Fn() -> H + Send + Sync + Clone + 'static,
1430{
1431 let config = config.into_inner();
1432 let local_addr = listener
1433 .local_addr()
1434 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1435 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1436
1437 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1438
1439 if let Some(external) = shutdown {
1443 let internal = params.ct.clone();
1444 tokio::spawn(async move {
1445 external.cancelled().await;
1446 internal.cancel();
1447 });
1448 }
1449
1450 if let Some(tx) = ready_tx {
1454 let _ = tx.send(local_addr);
1456 }
1457
1458 run_server(
1459 router,
1460 listener,
1461 params.tls_paths,
1462 params.mtls_config,
1463 params.shutdown_timeout,
1464 params.auth_state,
1465 params.rbac_swap,
1466 params.on_reload_ready,
1467 params.ct,
1468 )
1469 .await
1470 .map_err(anyhow_to_startup)
1471}
1472
1473#[allow(
1476 clippy::cognitive_complexity,
1477 reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1478)]
1479fn log_listening(name: &str, scheme: &str, addr: &str) {
1480 tracing::info!("{name} listening on {addr}");
1481 tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
1482 tracing::info!(" Health check: {scheme}://{addr}/healthz");
1483 tracing::info!(" Readiness: {scheme}://{addr}/readyz");
1484}
1485
/// Serves `router` on `listener` until shutdown, with optional TLS/mTLS.
///
/// Shutdown is triggered by either an OS shutdown signal (`shutdown_signal`)
/// or cancellation of `ct`. Once triggered, `ct` is cancelled (so background
/// tasks observing it stop) and axum's graceful shutdown begins. If in-flight
/// connections have not drained within `shutdown_timeout`, the serve future is
/// dropped and this function returns `Ok(())` anyway.
///
/// * `tls_paths` — `(cert, key)` PEM paths; `None` serves plain HTTP.
/// * `mtls_config` — client-certificate settings, only consulted with TLS.
///   When `crl_enabled` is set, CRLs are fetched up front and refreshed by a
///   background task tied to `ct`.
/// * `on_reload_ready` — invoked at most once with a [`ReloadHandle`] giving
///   access to the live auth state, RBAC policy swap, and CRL set (TLS only).
///
/// # Errors
/// Returns an error if CA roots or CRLs cannot be loaded, the TLS listener
/// cannot be built, or `axum::serve` itself fails.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // Single internal token that fires on whichever shutdown source wins;
    // both the graceful-shutdown future and the force-exit timer key off it.
    let shutdown_trigger = CancellationToken::new();
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        // Detached watcher: OS signal or external cancellation => trigger.
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    // Handed to `with_graceful_shutdown`: once triggered, also cancels `ct`
    // so tasks spawned with it (e.g. the CRL refresher) wind down too.
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };

    // Starts counting only after shutdown is triggered; racing it against the
    // serve future caps how long draining connections may take.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // CRL state exists only when mTLS with CRL checking is configured.
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        // NOTE(review): the reload callback fires before `TlsListener::new`,
        // so it may observe a handle for a server that then fails to start.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
        )?;
        // `TlsConnInfo` carries the peer address plus any mTLS identity.
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
1621
#[cfg(feature = "oauth")]
/// Mounts the OAuth 2.1 proxy endpoints onto `router` when
/// `oauth_config.proxy` is configured; otherwise returns `router` unchanged.
///
/// Always mounted (with a proxy): `/.well-known/oauth-authorization-server`,
/// `/authorize`, `/token`, `/register`. `/introspect` and `/revoke` are only
/// mounted when `expose_admin_endpoints` is set and the matching upstream URL
/// is configured; see `build_oauth_admin_router`.
///
/// # Errors
/// Fails if the upstream HTTP client cannot be built, or if admin endpoints
/// require authentication but no `auth_state` was supplied.
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    // One upstream client shared (via cheap clones) by every proxy handler.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    // Admin routes exist only if exposed AND at least one upstream is set.
    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Only warn when admin routes are actually mounted without auth; with no
    // introspection/revocation URL there is nothing unauthenticated to warn about.
    if admin_routes_enabled && !proxy.require_auth_on_admin_endpoints {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated; consider setting require_auth_on_admin_endpoints = true"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1714
#[cfg(feature = "oauth")]
/// Builds the router holding the OAuth proxy admin endpoints.
///
/// `/introspect` is added when `proxy.introspection_url` is set and `/revoke`
/// when `proxy.revocation_url` is set. Every admin route gets the token cache
/// headers. When `require_auth_on_admin_endpoints` is set, the routes are also
/// wrapped in `auth_middleware`.
///
/// # Errors
/// Returns `McpxError::Startup` if auth is required but `auth_state` is `None`.
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    // Resolve the auth requirement first; building routes has no side effects,
    // so failing early yields the same outcome.
    let auth_guard = if proxy.require_auth_on_admin_endpoints {
        let state = auth_state.ok_or_else(|| {
            McpxError::Startup("oauth proxy admin endpoints require auth state".into())
        })?;
        Some(Arc::clone(state))
    } else {
        None
    };

    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let introspect_cfg = proxy.clone();
        let introspect_client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let cfg = introspect_cfg.clone();
                let client = introspect_client.clone();
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let revoke_cfg = proxy.clone();
        // Last user of the client, so it is moved rather than cloned.
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let cfg = revoke_cfg.clone();
                let client = http.clone();
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    Ok(match auth_guard {
        Some(state) => routes.layer(axum::middleware::from_fn(move |req, next| {
            auth_middleware(Arc::clone(&state), req, next)
        })),
        None => routes,
    })
}
1773
1774fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1779 let mut hosts = vec![
1780 "localhost".to_owned(),
1781 "127.0.0.1".to_owned(),
1782 "::1".to_owned(),
1783 ];
1784
1785 if let Some(url) = public_url
1786 && let Ok(uri) = url.parse::<axum::http::Uri>()
1787 && let Some(authority) = uri.authority()
1788 {
1789 let host = authority.host().to_owned();
1790 if !hosts.iter().any(|h| h == &host) {
1791 hosts.push(host);
1792 }
1793
1794 let authority = authority.as_str().to_owned();
1795 if !hosts.iter().any(|h| h == &authority) {
1796 hosts.push(authority);
1797 }
1798 }
1799
1800 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1801 && let Some(authority) = uri.authority()
1802 {
1803 let host = authority.host().to_owned();
1804 if !hosts.iter().any(|h| h == &host) {
1805 hosts.push(host);
1806 }
1807
1808 let authority = authority.as_str().to_owned();
1809 if !hosts.iter().any(|h| h == &authority) {
1810 hosts.push(authority);
1811 }
1812 }
1813
1814 hosts
1815}
1816
1817impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1830 for TlsConnInfo
1831{
1832 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1833 let addr = *target.remote_addr();
1834 let identity = target.io().identity().cloned();
1835 TlsConnInfo::new(addr, identity)
1836 }
1837}
1838
/// TCP listener that completes a TLS handshake on every accepted connection
/// before handing it to axum (see the `axum::serve::Listener` impl).
struct TlsListener {
    // Underlying bound TCP listener.
    inner: TcpListener,
    // rustls acceptor built from the server cert/key and client-auth policy.
    acceptor: tokio_rustls::TlsAcceptor,
    // Role passed to `extract_mtls_identity` for client-certificate
    // identities; set to "viewer" when mTLS is not configured.
    mtls_default_role: String,
}
1851
1852impl TlsListener {
1853 fn new(
1854 inner: TcpListener,
1855 cert_path: &Path,
1856 key_path: &Path,
1857 mtls_config: Option<&MtlsConfig>,
1858 crl_set: Option<Arc<CrlSet>>,
1859 ) -> anyhow::Result<Self> {
1860 rustls::crypto::ring::default_provider()
1862 .install_default()
1863 .ok();
1864
1865 let certs = load_certs(cert_path)?;
1866 let key = load_key(key_path)?;
1867
1868 let mtls_default_role;
1869
1870 let tls_config = if let Some(mtls) = mtls_config {
1871 mtls_default_role = mtls.default_role.clone();
1872 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1873 {
1874 let Some(crl_set) = crl_set else {
1875 return Err(anyhow::anyhow!(
1876 "mTLS CRL verifier requested but CRL state was not initialized"
1877 ));
1878 };
1879 Arc::new(DynamicClientCertVerifier::new(crl_set))
1880 } else {
1881 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1882 if mtls.required {
1883 rustls::server::WebPkiClientVerifier::builder(root_store)
1884 .build()
1885 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1886 } else {
1887 rustls::server::WebPkiClientVerifier::builder(root_store)
1888 .allow_unauthenticated()
1889 .build()
1890 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1891 }
1892 };
1893
1894 tracing::info!(
1895 ca = %mtls.ca_cert_path.display(),
1896 required = mtls.required,
1897 crl_enabled = mtls.crl_enabled,
1898 "mTLS client auth configured"
1899 );
1900
1901 rustls::ServerConfig::builder_with_protocol_versions(&[
1902 &rustls::version::TLS12,
1903 &rustls::version::TLS13,
1904 ])
1905 .with_client_cert_verifier(verifier)
1906 .with_single_cert(certs, key)?
1907 } else {
1908 mtls_default_role = "viewer".to_owned();
1909 rustls::ServerConfig::builder_with_protocol_versions(&[
1910 &rustls::version::TLS12,
1911 &rustls::version::TLS13,
1912 ])
1913 .with_no_client_auth()
1914 .with_single_cert(certs, key)?
1915 };
1916
1917 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1918 tracing::info!(
1919 "TLS enabled (cert: {}, key: {})",
1920 cert_path.display(),
1921 key_path.display()
1922 );
1923 Ok(Self {
1924 inner,
1925 acceptor,
1926 mtls_default_role,
1927 })
1928 }
1929
1930 fn extract_handshake_identity(
1934 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1935 default_role: &str,
1936 addr: SocketAddr,
1937 ) -> Option<AuthIdentity> {
1938 let (_, server_conn) = tls_stream.get_ref();
1939 let cert_der = server_conn.peer_certificates()?.first()?;
1940 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1941 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1942 Some(id)
1943 }
1944}
1945
/// Server-side TLS stream paired with the mTLS identity extracted during its
/// handshake; the `AsyncRead`/`AsyncWrite` impls delegate to `inner`.
pub(crate) struct AuthenticatedTlsStream {
    // The established TLS connection.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    // `None` when no usable client certificate was presented.
    identity: Option<AuthIdentity>,
}
1961
impl AuthenticatedTlsStream {
    /// Returns the identity derived from the client certificate at handshake
    /// time, or `None` if no usable certificate was presented.
    #[must_use]
    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
        self.identity.as_ref()
    }
}
1969
1970impl std::fmt::Debug for AuthenticatedTlsStream {
1971 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1972 f.debug_struct("AuthenticatedTlsStream")
1973 .field("identity", &self.identity.as_ref().map(|id| &id.name))
1974 .finish_non_exhaustive()
1975 }
1976}
1977
/// Reads are forwarded unchanged to the wrapped TLS stream.
impl tokio::io::AsyncRead for AuthenticatedTlsStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // `Pin::new` on the field requires the inner stream to be `Unpin`.
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}
1987
/// Writes, flushes, and shutdown are forwarded unchanged to the wrapped TLS
/// stream, including its vectored-write support.
impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }

    // Forwarded explicitly so the inner stream's vectored path is used
    // instead of the trait's default single-buffer fallback.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<std::io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}
2023
2024impl axum::serve::Listener for TlsListener {
2025 type Io = AuthenticatedTlsStream;
2026 type Addr = SocketAddr;
2027
2028 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2029 loop {
2030 let (stream, addr) = match self.inner.accept().await {
2031 Ok(pair) => pair,
2032 Err(e) => {
2033 tracing::debug!("TCP accept error: {e}");
2034 continue;
2035 }
2036 };
2037 let tls_stream = match self.acceptor.accept(stream).await {
2038 Ok(s) => s,
2039 Err(e) => {
2040 tracing::debug!("TLS handshake failed from {addr}: {e}");
2041 continue;
2042 }
2043 };
2044 let identity =
2045 Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
2046 let wrapped = AuthenticatedTlsStream {
2047 inner: tls_stream,
2048 identity,
2049 };
2050 return (wrapped, addr);
2051 }
2052 }
2053
2054 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2055 self.inner.local_addr()
2056 }
2057}
2058
2059fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2060 use rustls::pki_types::pem::PemObject;
2061 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2062 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2063 .collect::<Result<_, _>>()
2064 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2065 anyhow::ensure!(
2066 !certs.is_empty(),
2067 "no certificates found in {}",
2068 path.display()
2069 );
2070 Ok(certs)
2071}
2072
2073fn load_client_auth_roots(
2074 path: &Path,
2075) -> anyhow::Result<(
2076 Vec<rustls::pki_types::CertificateDer<'static>>,
2077 Arc<RootCertStore>,
2078)> {
2079 let ca_certs = load_certs(path)?;
2080 let mut root_store = RootCertStore::empty();
2081 for cert in &ca_certs {
2082 root_store
2083 .add(cert.clone())
2084 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2085 }
2086
2087 Ok((ca_certs, Arc::new(root_store)))
2088}
2089
2090fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2091 use rustls::pki_types::pem::PemObject;
2092 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2093 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2094}
2095
2096#[allow(clippy::unused_async)]
2097async fn healthz() -> impl IntoResponse {
2098 axum::Json(serde_json::json!({
2099 "status": "ok",
2100 }))
2101}
2102
2103fn version_payload(name: &str, version: &str) -> serde_json::Value {
2109 serde_json::json!({
2110 "name": name,
2111 "version": version,
2112 "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or("unknown"),
2113 "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or("unknown"),
2114 "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or("unknown"),
2115 "mcpx_version": env!("CARGO_PKG_VERSION"),
2116 })
2117}
2118
2119fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2129 let value = version_payload(name, version);
2130 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2131}
2132
2133async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2134 let status = check().await;
2135 let ready = status
2136 .get("ready")
2137 .and_then(serde_json::Value::as_bool)
2138 .unwrap_or(false);
2139 let code = if ready {
2140 axum::http::StatusCode::OK
2141 } else {
2142 axum::http::StatusCode::SERVICE_UNAVAILABLE
2143 };
2144 (code, axum::Json(status))
2145}
2146
2147async fn shutdown_signal() {
2151 let ctrl_c = tokio::signal::ctrl_c();
2152
2153 #[cfg(unix)]
2154 {
2155 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2156 Ok(mut term) => {
2157 tokio::select! {
2158 _ = ctrl_c => {}
2159 _ = term.recv() => {}
2160 }
2161 }
2162 Err(e) => {
2163 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2164 ctrl_c.await.ok();
2165 }
2166 }
2167 }
2168
2169 #[cfg(not(unix))]
2170 {
2171 ctrl_c.await.ok();
2172 }
2173}
2174
#[cfg(feature = "metrics")]
/// Records the request count and latency for every request passing through.
///
/// Labels are bounded. Each distinct label set creates a Prometheus time
/// series that lives for the rest of the process. The raw request path and
/// method are client-controlled, so using them directly lets any client grow
/// memory without limit. Instead:
/// * `path` is the matched route template (`axum::extract::MatchedPath`), or
///   `"unmatched"` when no route matched (e.g. 404s).
/// * `method` is one of the standard HTTP methods, or `"OTHER"`.
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    let method = match req.method().as_str() {
        known @ ("GET" | "HEAD" | "POST" | "PUT" | "DELETE" | "CONNECT" | "OPTIONS"
        | "TRACE" | "PATCH") => known.to_owned(),
        _ => "OTHER".to_owned(),
    };
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or_else(
            || "unmatched".to_owned(),
            |matched| matched.as_str().to_owned(),
        );
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2206
2207async fn security_headers_middleware(
2219 is_tls: bool,
2220 cfg: Arc<SecurityHeadersConfig>,
2221 req: Request<Body>,
2222 next: Next,
2223) -> axum::response::Response {
2224 use axum::http::{HeaderName, header};
2225
2226 let mut resp = next.run(req).await;
2227 let headers = resp.headers_mut();
2228
2229 headers.remove(header::SERVER);
2231 headers.remove(HeaderName::from_static("x-powered-by"));
2232
2233 apply_security_header(
2234 headers,
2235 header::X_CONTENT_TYPE_OPTIONS,
2236 cfg.x_content_type_options.as_deref(),
2237 "nosniff",
2238 );
2239 apply_security_header(
2240 headers,
2241 header::X_FRAME_OPTIONS,
2242 cfg.x_frame_options.as_deref(),
2243 "deny",
2244 );
2245 apply_security_header(
2246 headers,
2247 header::CACHE_CONTROL,
2248 cfg.cache_control.as_deref(),
2249 "no-store, max-age=0",
2250 );
2251 apply_security_header(
2252 headers,
2253 header::REFERRER_POLICY,
2254 cfg.referrer_policy.as_deref(),
2255 "no-referrer",
2256 );
2257 apply_security_header(
2258 headers,
2259 HeaderName::from_static("cross-origin-opener-policy"),
2260 cfg.cross_origin_opener_policy.as_deref(),
2261 "same-origin",
2262 );
2263 apply_security_header(
2264 headers,
2265 HeaderName::from_static("cross-origin-resource-policy"),
2266 cfg.cross_origin_resource_policy.as_deref(),
2267 "same-origin",
2268 );
2269 apply_security_header(
2270 headers,
2271 HeaderName::from_static("cross-origin-embedder-policy"),
2272 cfg.cross_origin_embedder_policy.as_deref(),
2273 "require-corp",
2274 );
2275 apply_security_header(
2276 headers,
2277 HeaderName::from_static("permissions-policy"),
2278 cfg.permissions_policy.as_deref(),
2279 "accelerometer=(), camera=(), geolocation=(), microphone=()",
2280 );
2281 apply_security_header(
2282 headers,
2283 HeaderName::from_static("x-permitted-cross-domain-policies"),
2284 cfg.x_permitted_cross_domain_policies.as_deref(),
2285 "none",
2286 );
2287 apply_security_header(
2288 headers,
2289 HeaderName::from_static("content-security-policy"),
2290 cfg.content_security_policy.as_deref(),
2291 "default-src 'none'; frame-ancestors 'none'",
2292 );
2293 apply_security_header(
2294 headers,
2295 HeaderName::from_static("x-dns-prefetch-control"),
2296 cfg.x_dns_prefetch_control.as_deref(),
2297 "off",
2298 );
2299
2300 if is_tls {
2301 apply_security_header(
2302 headers,
2303 header::STRICT_TRANSPORT_SECURITY,
2304 cfg.strict_transport_security.as_deref(),
2305 "max-age=63072000; includeSubDomains",
2306 );
2307 }
2308
2309 resp
2310}
2311
2312fn apply_security_header(
2323 headers: &mut axum::http::HeaderMap,
2324 name: axum::http::HeaderName,
2325 override_value: Option<&str>,
2326 default: &'static str,
2327) {
2328 use axum::http::HeaderValue;
2329
2330 match override_value {
2331 None => {
2332 headers.insert(name, HeaderValue::from_static(default));
2333 }
2334 Some("") => {
2335 }
2337 Some(v) => match HeaderValue::from_str(v) {
2338 Ok(hv) => {
2339 headers.insert(name, hv);
2340 }
2341 Err(err) => {
2342 tracing::error!(
2343 header = %name,
2344 error = %err,
2345 "invalid security header override reached middleware; using default"
2346 );
2347 headers.insert(name, HeaderValue::from_static(default));
2348 }
2349 },
2350 }
2351}
2352
2353fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2364 use axum::http::HeaderValue;
2365
2366 let fields: &[(&str, Option<&str>)] = &[
2367 (
2368 "x_content_type_options",
2369 cfg.x_content_type_options.as_deref(),
2370 ),
2371 ("x_frame_options", cfg.x_frame_options.as_deref()),
2372 ("cache_control", cfg.cache_control.as_deref()),
2373 ("referrer_policy", cfg.referrer_policy.as_deref()),
2374 (
2375 "cross_origin_opener_policy",
2376 cfg.cross_origin_opener_policy.as_deref(),
2377 ),
2378 (
2379 "cross_origin_resource_policy",
2380 cfg.cross_origin_resource_policy.as_deref(),
2381 ),
2382 (
2383 "cross_origin_embedder_policy",
2384 cfg.cross_origin_embedder_policy.as_deref(),
2385 ),
2386 ("permissions_policy", cfg.permissions_policy.as_deref()),
2387 (
2388 "x_permitted_cross_domain_policies",
2389 cfg.x_permitted_cross_domain_policies.as_deref(),
2390 ),
2391 (
2392 "content_security_policy",
2393 cfg.content_security_policy.as_deref(),
2394 ),
2395 (
2396 "x_dns_prefetch_control",
2397 cfg.x_dns_prefetch_control.as_deref(),
2398 ),
2399 (
2400 "strict_transport_security",
2401 cfg.strict_transport_security.as_deref(),
2402 ),
2403 ];
2404
2405 for (field, value) in fields {
2406 let Some(v) = value else { continue };
2407 if v.is_empty() {
2408 continue;
2409 }
2410 if let Err(err) = HeaderValue::from_str(v) {
2411 return Err(McpxError::Config(format!(
2412 "invalid security_headers.{field}: {err}"
2413 )));
2414 }
2415 }
2416
2417 if let Some(v) = cfg.strict_transport_security.as_deref()
2418 && !v.is_empty()
2419 && v.to_ascii_lowercase().contains("preload")
2420 {
2421 return Err(McpxError::Config(format!(
2422 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2423 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2424 )));
2425 }
2426
2427 Ok(())
2428}
2429
#[cfg(feature = "oauth")]
/// Adds anti-caching hints to OAuth token-bearing responses.
///
/// Sets `Pragma: no-cache` (replacing any existing value) and appends
/// `Vary: Authorization`, so intermediaries never share responses across
/// callers.
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;
    response
        .headers_mut()
        .insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
    response
        .headers_mut()
        .append(header::VARY, HeaderValue::from_static("Authorization"));
    response
}
2457
2458async fn origin_check_middleware(
2462 allowed: Arc<[String]>,
2463 log_request_headers: bool,
2464 req: Request<Body>,
2465 next: Next,
2466) -> axum::response::Response {
2467 let method = req.method().clone();
2468 let path = req.uri().path().to_owned();
2469
2470 log_incoming_request(&method, &path, req.headers(), log_request_headers);
2471
2472 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2473 let origin_str = origin.to_str().unwrap_or("");
2474 if !allowed.iter().any(|a| a == origin_str) {
2475 tracing::warn!(
2476 origin = origin_str,
2477 %method,
2478 %path,
2479 allowed = ?&*allowed,
2480 "rejected request: Origin not allowed"
2481 );
2482 return (
2483 axum::http::StatusCode::FORBIDDEN,
2484 "Forbidden: Origin not allowed",
2485 )
2486 .into_response();
2487 }
2488 }
2489 next.run(req).await
2490}
2491
2492fn log_incoming_request(
2495 method: &axum::http::Method,
2496 path: &str,
2497 headers: &axum::http::HeaderMap,
2498 log_request_headers: bool,
2499) {
2500 if log_request_headers {
2501 tracing::debug!(
2502 %method,
2503 %path,
2504 headers = %format_request_headers_for_log(headers),
2505 "incoming request"
2506 );
2507 } else {
2508 tracing::debug!(%method, %path, "incoming request");
2509 }
2510}
2511
2512fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2513 headers
2514 .iter()
2515 .map(|(k, v)| {
2516 let name = k.as_str();
2517 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2518 format!("{name}: [REDACTED]")
2519 } else {
2520 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2521 }
2522 })
2523 .collect::<Vec<_>>()
2524 .join(", ")
2525}
2526
2527#[allow(clippy::cognitive_complexity)]
2551pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2552where
2553 H: ServerHandler + 'static,
2554{
2555 use rmcp::ServiceExt as _;
2556
2557 tracing::info!("stdio transport: serving on stdin/stdout");
2558 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2559
2560 let transport = rmcp::transport::io::stdio();
2561
2562 let service = handler
2563 .serve(transport)
2564 .await
2565 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2566
2567 if let Err(e) = service.waiting().await {
2568 tracing::warn!(error = %e, "stdio session ended with error");
2569 }
2570 tracing::info!("stdio session ended");
2571 Ok(())
2572}
2573
2574#[cfg(test)]
2575mod tests {
2576 #![allow(
2577 clippy::unwrap_used,
2578 clippy::expect_used,
2579 clippy::panic,
2580 clippy::indexing_slicing,
2581 clippy::unwrap_in_result,
2582 clippy::print_stdout,
2583 clippy::print_stderr,
2584 deprecated,
2585 reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2586 )]
2587 use std::sync::Arc;
2588
2589 use axum::{
2590 body::Body,
2591 http::{Request, StatusCode, header},
2592 response::IntoResponse,
2593 };
2594 use http_body_util::BodyExt;
2595 use tower::ServiceExt as _;
2596
2597 use super::*;
2598
    /// `McpServerConfig::new` sets the three required fields and leaves every
    /// optional feature off, with the documented timeout/body-size defaults.
    #[test]
    fn server_config_new_defaults() {
        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
        assert_eq!(cfg.name, "test-server");
        assert_eq!(cfg.version, "1.0.0");
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert!(cfg.auth.is_none());
        assert!(cfg.rbac.is_none());
        assert!(cfg.allowed_origins.is_empty());
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.readiness_check.is_none());
        assert_eq!(cfg.max_request_body, 1024 * 1024);
        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
        assert!(!cfg.log_request_headers);
    }

    /// `validate` consumes the config and returns a wrapper that derefs to and
    /// can be unwrapped back into it; a zero body cap must be rejected.
    #[test]
    fn validate_consumes_and_proves() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        let validated = cfg.validate().expect("valid config");
        assert_eq!(validated.name, "test-server");
        let raw = validated.into_inner();
        assert_eq!(raw.name, "test-server");

        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        bad.max_request_body = 0;
        assert!(bad.validate().is_err(), "zero body cap must fail validate");
    }
2636
    /// The host part of `public_url` is added to the allow-list.
    #[test]
    fn derive_allowed_hosts_includes_public_host() {
        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
        assert!(
            hosts.iter().any(|h| h == "mcp.example.com"),
            "public_url host must be allowed"
        );
    }

    /// Both the bare bind host and the full `host:port` authority are allowed.
    #[test]
    fn derive_allowed_hosts_includes_bind_authority() {
        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1"),
            "bind host must be allowed"
        );
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1:8080"),
            "bind authority must be allowed"
        );
    }
2658
    /// `/healthz` returns 200 with `{"status":"ok"}` and leaks no identity.
    #[tokio::test]
    async fn healthz_returns_ok_json() {
        let resp = healthz().await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
        assert!(
            json.get("name").is_none(),
            "healthz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "healthz must not expose version"
        );
    }

    /// `"ready": true` yields 200 and the check's document is echoed verbatim
    /// (no server name/version injected).
    #[tokio::test]
    async fn readyz_returns_ok_when_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["ready"], true);
        assert!(
            json.get("name").is_none(),
            "readyz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "readyz must not expose version"
        );
        assert_eq!(json["db"], "connected");
    }

    /// `"ready": false` yields 503.
    #[tokio::test]
    async fn readyz_returns_503_when_not_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    /// A document without a `ready` key is treated as not ready (503).
    #[tokio::test]
    async fn readyz_returns_503_when_ready_missing() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }
2716
    /// Test router: a single `GET /test` route behind `origin_check_middleware`
    /// configured with `origins`.
    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
        let allowed: Arc<[String]> = Arc::from(origins);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let a = Arc::clone(&allowed);
                origin_check_middleware(a, log_request_headers, req, next)
            }))
    }

    /// An exactly-matching Origin is let through.
    #[tokio::test]
    async fn origin_allowed_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://localhost:3000")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// An Origin outside the allow-list is rejected with 403.
    #[tokio::test]
    async fn origin_rejected_returns_403() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    /// Requests without an Origin header (non-browser clients) pass.
    #[tokio::test]
    async fn no_origin_header_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// An empty allow-list is deny-all for requests that carry an Origin.
    #[tokio::test]
    async fn empty_allowlist_rejects_any_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://anything.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    /// ...but still admits requests without an Origin header.
    #[tokio::test]
    async fn empty_allowlist_passes_without_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
2781
    /// Credential-bearing headers are redacted while others are logged as-is.
    #[test]
    fn format_request_headers_redacts_sensitive_values() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
        headers.insert("cookie", "sid=abc".parse().unwrap());
        headers.insert("x-request-id", "req-123".parse().unwrap());

        let out = format_request_headers_for_log(&headers);
        assert!(out.contains("authorization: [REDACTED]"));
        assert!(out.contains("cookie: [REDACTED]"));
        assert!(out.contains("x-request-id: req-123"));
        assert!(!out.contains("secret-token"));
    }
2795
2796 fn security_router(is_tls: bool) -> axum::Router {
2799 security_router_with(is_tls, SecurityHeadersConfig::default())
2800 }
2801
2802 fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
2803 let cfg = Arc::new(cfg);
2804 axum::Router::new()
2805 .route("/test", axum::routing::get(|| async { "ok" }))
2806 .layer(axum::middleware::from_fn(move |req, next| {
2807 let c = Arc::clone(&cfg);
2808 security_headers_middleware(is_tls, c, req, next)
2809 }))
2810 }
2811
2812 #[tokio::test]
2813 async fn security_headers_set_on_response() {
2814 let app = security_router(false);
2815 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2816 let resp = app.oneshot(req).await.unwrap();
2817 assert_eq!(resp.status(), StatusCode::OK);
2818
2819 let h = resp.headers();
2820 assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
2821 assert_eq!(h.get("x-frame-options").unwrap(), "deny");
2822 assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
2823 assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
2824 assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
2825 assert_eq!(
2826 h.get("cross-origin-resource-policy").unwrap(),
2827 "same-origin"
2828 );
2829 assert_eq!(
2830 h.get("cross-origin-embedder-policy").unwrap(),
2831 "require-corp"
2832 );
2833 assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
2834 assert!(
2835 h.get("permissions-policy")
2836 .unwrap()
2837 .to_str()
2838 .unwrap()
2839 .contains("camera=()"),
2840 "permissions-policy must restrict browser features"
2841 );
2842 assert_eq!(
2843 h.get("content-security-policy").unwrap(),
2844 "default-src 'none'; frame-ancestors 'none'"
2845 );
2846 assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
2847 assert!(h.get("strict-transport-security").is_none());
2849 }
2850
2851 #[tokio::test]
2852 async fn hsts_set_when_tls_enabled() {
2853 let app = security_router(true);
2854 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2855 let resp = app.oneshot(req).await.unwrap();
2856
2857 let hsts = resp.headers().get("strict-transport-security").unwrap();
2858 assert!(
2859 hsts.to_str().unwrap().contains("max-age=63072000"),
2860 "HSTS must set 2-year max-age"
2861 );
2862 }
2863
2864 fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
2870 let cfg =
2871 McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
2872 cfg.check()
2873 }
2874
2875 #[test]
2876 fn security_headers_config_default_validates() {
2877 check_with_security_headers(SecurityHeadersConfig::default())
2878 .expect("default SecurityHeadersConfig must validate");
2879 }
2880
2881 #[test]
2882 fn security_headers_config_validate_accepts_empty_string() {
2883 let h = SecurityHeadersConfig {
2885 x_content_type_options: Some(String::new()),
2886 x_frame_options: Some(String::new()),
2887 cache_control: Some(String::new()),
2888 referrer_policy: Some(String::new()),
2889 cross_origin_opener_policy: Some(String::new()),
2890 cross_origin_resource_policy: Some(String::new()),
2891 cross_origin_embedder_policy: Some(String::new()),
2892 permissions_policy: Some(String::new()),
2893 x_permitted_cross_domain_policies: Some(String::new()),
2894 content_security_policy: Some(String::new()),
2895 x_dns_prefetch_control: Some(String::new()),
2896 strict_transport_security: Some(String::new()),
2897 };
2898 check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
2899 }
2900
2901 #[test]
2902 fn security_headers_config_validate_rejects_bad_value() {
2903 let h = SecurityHeadersConfig {
2905 referrer_policy: Some("\u{0007}".into()),
2906 ..SecurityHeadersConfig::default()
2907 };
2908 let err = check_with_security_headers(h)
2909 .expect_err("control char in referrer_policy must reject");
2910 let msg = err.to_string();
2911 assert!(
2912 msg.contains("referrer_policy"),
2913 "error must name the offending field, got: {msg}"
2914 );
2915 }
2916
2917 #[test]
2918 fn security_headers_config_validate_rejects_hsts_preload() {
2919 let h = SecurityHeadersConfig {
2920 strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
2921 ..SecurityHeadersConfig::default()
2922 };
2923 let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
2924 let msg = err.to_string();
2925 assert!(
2926 msg.contains("strict_transport_security"),
2927 "error must name the field, got: {msg}"
2928 );
2929 assert!(
2930 msg.to_lowercase().contains("preload"),
2931 "error must mention `preload`, got: {msg}"
2932 );
2933 }
2934
2935 #[test]
2936 fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
2937 let h = SecurityHeadersConfig {
2939 strict_transport_security: Some("max-age=600; PRELOAD".into()),
2940 ..SecurityHeadersConfig::default()
2941 };
2942 check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
2943 }
2944
2945 #[tokio::test]
2946 async fn security_headers_override_honored() {
2947 let h = SecurityHeadersConfig {
2949 x_frame_options: Some("SAMEORIGIN".into()),
2950 ..SecurityHeadersConfig::default()
2951 };
2952 let app = security_router_with(false, h);
2953 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2954 let resp = app.oneshot(req).await.unwrap();
2955 assert_eq!(resp.status(), StatusCode::OK);
2956
2957 let xfo = resp.headers().get("x-frame-options").unwrap();
2958 assert_eq!(xfo, "SAMEORIGIN");
2959 }
2960
2961 #[tokio::test]
2962 async fn security_headers_empty_string_omits() {
2963 let h = SecurityHeadersConfig {
2965 referrer_policy: Some(String::new()),
2966 ..SecurityHeadersConfig::default()
2967 };
2968 let app = security_router_with(false, h);
2969 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2970 let resp = app.oneshot(req).await.unwrap();
2971 assert_eq!(resp.status(), StatusCode::OK);
2972
2973 assert!(
2974 resp.headers().get("referrer-policy").is_none(),
2975 "Some(\"\") must omit the header"
2976 );
2977 assert_eq!(
2979 resp.headers().get("x-content-type-options").unwrap(),
2980 "nosniff"
2981 );
2982 }
2983
2984 #[tokio::test]
2985 async fn security_headers_hsts_only_when_tls() {
2986 let h = SecurityHeadersConfig {
2988 strict_transport_security: Some("max-age=600".into()),
2989 ..SecurityHeadersConfig::default()
2990 };
2991 let app = security_router_with(false, h);
2992 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2993 let resp = app.oneshot(req).await.unwrap();
2994 assert!(
2995 resp.headers().get("strict-transport-security").is_none(),
2996 "HSTS must remain absent on plaintext deployments even with override"
2997 );
2998 }
2999
3000 #[cfg(feature = "oauth")]
3003 #[tokio::test]
3004 async fn oauth_token_cache_headers_set_pragma_and_vary() {
3005 let app = axum::Router::new()
3006 .route("/token", axum::routing::post(|| async { "{}" }))
3007 .layer(axum::middleware::from_fn(
3008 oauth_token_cache_headers_middleware,
3009 ));
3010 let req = Request::builder()
3011 .method("POST")
3012 .uri("/token")
3013 .body(Body::from("{}"))
3014 .unwrap();
3015 let resp = app.oneshot(req).await.unwrap();
3016 assert_eq!(resp.status(), StatusCode::OK);
3017
3018 let h = resp.headers();
3019 assert_eq!(
3020 h.get("pragma").unwrap(),
3021 "no-cache",
3022 "RFC 6749 §5.1: token responses must set Pragma: no-cache"
3023 );
3024 let vary_values: Vec<String> = h
3025 .get_all("vary")
3026 .iter()
3027 .filter_map(|v| v.to_str().ok().map(str::to_owned))
3028 .collect();
3029 assert!(
3030 vary_values
3031 .iter()
3032 .any(|v| v.eq_ignore_ascii_case("Authorization")),
3033 "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
3034 );
3035 }
3036
3037 #[cfg(feature = "oauth")]
3038 #[tokio::test]
3039 async fn oauth_token_cache_headers_preserve_existing_vary() {
3040 let app = axum::Router::new()
3043 .route(
3044 "/token",
3045 axum::routing::post(|| async {
3046 axum::response::Response::builder()
3047 .header("vary", "Accept-Encoding")
3048 .body(axum::body::Body::from("{}"))
3049 .unwrap()
3050 }),
3051 )
3052 .layer(axum::middleware::from_fn(
3053 oauth_token_cache_headers_middleware,
3054 ));
3055 let req = Request::builder()
3056 .method("POST")
3057 .uri("/token")
3058 .body(Body::empty())
3059 .unwrap();
3060 let resp = app.oneshot(req).await.unwrap();
3061
3062 let vary: Vec<String> = resp
3063 .headers()
3064 .get_all("vary")
3065 .iter()
3066 .filter_map(|v| v.to_str().ok().map(str::to_owned))
3067 .collect();
3068 assert!(
3069 vary.iter().any(|v| v.contains("Accept-Encoding")),
3070 "must preserve pre-existing Vary value, got {vary:?}"
3071 );
3072 assert!(
3073 vary.iter().any(|v| v.contains("Authorization")),
3074 "must append Authorization to Vary, got {vary:?}"
3075 );
3076 }
3077
3078 #[test]
3081 fn version_payload_contains_expected_fields() {
3082 let v = version_payload("my-server", "1.2.3");
3083 assert_eq!(v["name"], "my-server");
3084 assert_eq!(v["version"], "1.2.3");
3085 assert!(v["build_git_sha"].is_string());
3086 assert!(v["build_timestamp"].is_string());
3087 assert!(v["rust_version"].is_string());
3088 assert!(v["mcpx_version"].is_string());
3089 }
3090
3091 #[tokio::test]
3094 async fn concurrency_limit_layer_composes_and_serves() {
3095 let app = axum::Router::new()
3099 .route("/ok", axum::routing::get(|| async { "ok" }))
3100 .layer(
3101 tower::ServiceBuilder::new()
3102 .layer(axum::error_handling::HandleErrorLayer::new(
3103 |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3104 ))
3105 .layer(tower::load_shed::LoadShedLayer::new())
3106 .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3107 );
3108 let resp = app
3109 .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3110 .await
3111 .unwrap();
3112 assert_eq!(resp.status(), StatusCode::OK);
3113 }
3114
3115 #[tokio::test]
3118 async fn compression_layer_gzip_encodes_response() {
3119 use tower_http::compression::Predicate as _;
3120
3121 let big_body = "a".repeat(4096);
3122 let app = axum::Router::new()
3123 .route(
3124 "/big",
3125 axum::routing::get(move || {
3126 let body = big_body.clone();
3127 async move { body }
3128 }),
3129 )
3130 .layer(
3131 tower_http::compression::CompressionLayer::new()
3132 .gzip(true)
3133 .br(true)
3134 .compress_when(
3135 tower_http::compression::DefaultPredicate::new()
3136 .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3137 ),
3138 );
3139
3140 let req = Request::builder()
3141 .uri("/big")
3142 .header(header::ACCEPT_ENCODING, "gzip")
3143 .body(Body::empty())
3144 .unwrap();
3145 let resp = app.oneshot(req).await.unwrap();
3146 assert_eq!(resp.status(), StatusCode::OK);
3147 assert_eq!(
3148 resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3149 "gzip"
3150 );
3151 }
3152}