1use std::{
2 future::Future,
3 net::SocketAddr,
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13 ServerHandler,
14 transport::streamable_http_server::{
15 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16 },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23 auth::{
24 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25 build_rate_limiter, extract_mtls_identity,
26 },
27 error::McpxError,
28 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32#[allow(
36 clippy::needless_pass_by_value,
37 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40 McpxError::Startup(format!("{e:#}"))
41}
42
43#[allow(
49 clippy::needless_pass_by_value,
50 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53 McpxError::Startup(format!("{op}: {e}"))
54}
55
56pub type ReadinessCheck =
61 Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
/// Per-header overrides for the security headers the server attaches to
/// responses (the config is handed, wrapped in an `Arc`, to the router-wide
/// security-headers middleware).
///
/// `Default` leaves every field `None`. Each field appears to map to the HTTP
/// response header of the same name. How `None` is interpreted (built-in
/// default vs. header omitted) is decided by the middleware and by
/// `validate_security_headers`, neither of which is shown here — confirm there.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` override.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` override.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` override.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` override.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` override.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` override.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` override.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` override.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` override.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` override.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` override.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` override. The middleware also receives an
    /// `is_tls` flag, presumably to gate HSTS on TLS — confirm in the middleware.
    pub strict_transport_security: Option<String>,
}
116
/// Configuration for an MCP Streamable HTTP server.
///
/// Build it with [`McpServerConfig::new`] plus the `with_*` / `enable_*`
/// builders, then call `validate()` to obtain the `Validated` wrapper that
/// `serve` / `serve_with_listener` require. Direct field access is deprecated.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Address the listener binds to. Validation requires it to parse as a
    /// `SocketAddr` (e.g. `127.0.0.1:8080`); hostnames are rejected.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name, used in start-up logs, the `/version` payload and the
    /// admin state.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version, reported by `/version` and the admin state.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// PEM certificate path. Must be set together with `tls_key_path`; its
    /// presence switches the advertised scheme to `https`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// Private-key path. Must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings. Auth middleware is installed on `/mcp` only
    /// when this is `Some` and `enabled` is true.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy for `/mcp`. `None` means `RbacPolicy::disabled()`. The
    /// policy can be hot-swapped later through `ReloadHandle::reload_rbac`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Allowed request origins. Each entry must start with `http://` or
    /// `https://`. When empty and `public_url` is set, one origin is derived
    /// from `public_url`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Per-IP tool-call limit in calls per minute; `None` disables limiting.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Custom probe for `/readyz`. When `None`, `/readyz` behaves like
    /// `/healthz`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes for `/mcp` (default 1 MiB). Must be
    /// non-zero.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Per-request timeout on `/mcp` (default 2 minutes). Exceeding it yields
    /// `408 Request Timeout`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Grace period after a shutdown is triggered before the server is forced
    /// to exit (default 30 seconds).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Session keep-alive handed to the local session manager (default 20
    /// minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval for the Streamable HTTP transport (default 15
    /// seconds).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// One-shot callback that receives a `ReloadHandle` once the server is
    /// being started.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Extra routes merged into the top-level router. They sit outside the
    /// `/mcp` auth and RBAC middleware.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL (must start with `http://` or `https://`).
    /// Used for the allowed hosts, the auto-derived origin and the
    /// protected-resource metadata URLs.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Passed to the origin-check middleware to enable request-header logging.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Size threshold in bytes for compression (default 1024). It is passed to
    /// `SizeAbove`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global concurrency cap. Excess requests are shed with `503` and an
    /// `"overloaded"` JSON body. Must be non-zero when set.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Mounts the `/admin/*` endpoints. Requires auth to be configured and
    /// enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// Role required for the admin endpoints (default `"admin"`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enables the metrics middleware and the separate metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Bind address of the metrics listener (default `127.0.0.1:9090`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Overrides for the security response headers. They are checked by
    /// `validate_security_headers` during validation.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
}
322
/// Proof-carrying wrapper: a `Validated<T>` can only be produced by a
/// validation routine in this module (the tuple field is private). APIs that
/// take it can therefore assume the inner value already passed its checks.
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Always renders as `Validated { .. }`, whatever `T` is, so secrets in
    /// the wrapped value never reach logs through `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Validated").finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Unwraps and returns the validated value, consuming the wrapper.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

impl<T> std::ops::Deref for Validated<T> {
    type Target = T;

    /// Transparent read-only access, equivalent to [`Validated::as_inner`].
    fn deref(&self) -> &Self::Target {
        self.as_inner()
    }
}
416
417#[allow(
418 deprecated,
419 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
420)]
421impl McpServerConfig {
422 #[must_use]
430 pub fn new(
431 bind_addr: impl Into<String>,
432 name: impl Into<String>,
433 version: impl Into<String>,
434 ) -> Self {
435 Self {
436 bind_addr: bind_addr.into(),
437 name: name.into(),
438 version: version.into(),
439 tls_cert_path: None,
440 tls_key_path: None,
441 auth: None,
442 rbac: None,
443 allowed_origins: Vec::new(),
444 tool_rate_limit: None,
445 readiness_check: None,
446 max_request_body: 1024 * 1024,
447 request_timeout: Duration::from_mins(2),
448 shutdown_timeout: Duration::from_secs(30),
449 session_idle_timeout: Duration::from_mins(20),
450 sse_keep_alive: Duration::from_secs(15),
451 on_reload_ready: None,
452 extra_router: None,
453 public_url: None,
454 log_request_headers: false,
455 compression_enabled: false,
456 compression_min_size: 1024,
457 max_concurrent_requests: None,
458 admin_enabled: false,
459 admin_role: "admin".to_owned(),
460 #[cfg(feature = "metrics")]
461 metrics_enabled: false,
462 #[cfg(feature = "metrics")]
463 metrics_bind: "127.0.0.1:9090".into(),
464 security_headers: SecurityHeadersConfig::default(),
465 }
466 }
467
468 #[must_use]
478 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
479 self.auth = Some(auth);
480 self
481 }
482
483 #[must_use]
488 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
489 self.security_headers = headers;
490 self
491 }
492
493 #[must_use]
497 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
498 self.bind_addr = addr.into();
499 self
500 }
501
502 #[must_use]
505 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
506 self.rbac = Some(rbac);
507 self
508 }
509
510 #[must_use]
514 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
515 self.tls_cert_path = Some(cert_path.into());
516 self.tls_key_path = Some(key_path.into());
517 self
518 }
519
520 #[must_use]
524 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
525 self.public_url = Some(url.into());
526 self
527 }
528
529 #[must_use]
533 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
534 where
535 I: IntoIterator<Item = S>,
536 S: Into<String>,
537 {
538 self.allowed_origins = origins.into_iter().map(Into::into).collect();
539 self
540 }
541
542 #[must_use]
546 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
547 self.extra_router = Some(router);
548 self
549 }
550
551 #[must_use]
554 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
555 self.readiness_check = Some(check);
556 self
557 }
558
559 #[must_use]
562 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
563 self.max_request_body = bytes;
564 self
565 }
566
567 #[must_use]
569 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
570 self.request_timeout = timeout;
571 self
572 }
573
574 #[must_use]
576 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
577 self.shutdown_timeout = timeout;
578 self
579 }
580
581 #[must_use]
583 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
584 self.session_idle_timeout = timeout;
585 self
586 }
587
588 #[must_use]
590 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
591 self.sse_keep_alive = interval;
592 self
593 }
594
595 #[must_use]
599 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
600 self.max_concurrent_requests = Some(limit);
601 self
602 }
603
604 #[must_use]
607 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
608 self.tool_rate_limit = Some(per_minute);
609 self
610 }
611
612 #[must_use]
616 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
617 where
618 F: FnOnce(ReloadHandle) + Send + 'static,
619 {
620 self.on_reload_ready = Some(Box::new(callback));
621 self
622 }
623
624 #[must_use]
628 pub fn enable_compression(mut self, min_size: u16) -> Self {
629 self.compression_enabled = true;
630 self.compression_min_size = min_size;
631 self
632 }
633
634 #[must_use]
639 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
640 self.admin_enabled = true;
641 self.admin_role = role.into();
642 self
643 }
644
645 #[must_use]
648 pub fn enable_request_header_logging(mut self) -> Self {
649 self.log_request_headers = true;
650 self
651 }
652
653 #[cfg(feature = "metrics")]
656 #[must_use]
657 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
658 self.metrics_enabled = true;
659 self.metrics_bind = bind.into();
660 self
661 }
662
663 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
696 self.check()?;
697 Ok(Validated(self))
698 }
699
700 fn check(&self) -> Result<(), McpxError> {
704 if self.admin_enabled {
708 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
709 if !auth_enabled {
710 return Err(McpxError::Config(
711 "admin_enabled=true requires auth to be configured and enabled".into(),
712 ));
713 }
714 }
715
716 match (&self.tls_cert_path, &self.tls_key_path) {
718 (Some(_), None) => {
719 return Err(McpxError::Config(
720 "tls_cert_path is set but tls_key_path is missing".into(),
721 ));
722 }
723 (None, Some(_)) => {
724 return Err(McpxError::Config(
725 "tls_key_path is set but tls_cert_path is missing".into(),
726 ));
727 }
728 _ => {}
729 }
730
731 if self.bind_addr.parse::<SocketAddr>().is_err() {
733 return Err(McpxError::Config(format!(
734 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
735 self.bind_addr
736 )));
737 }
738
739 if let Some(ref url) = self.public_url
741 && !(url.starts_with("http://") || url.starts_with("https://"))
742 {
743 return Err(McpxError::Config(format!(
744 "public_url {url:?} must start with http:// or https://"
745 )));
746 }
747
748 for origin in &self.allowed_origins {
750 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
751 return Err(McpxError::Config(format!(
752 "allowed_origins entry {origin:?} must start with http:// or https://"
753 )));
754 }
755 }
756
757 if self.max_request_body == 0 {
759 return Err(McpxError::Config(
760 "max_request_body must be greater than zero".into(),
761 ));
762 }
763
764 #[cfg(feature = "oauth")]
766 if let Some(auth_cfg) = &self.auth
767 && let Some(oauth_cfg) = &auth_cfg.oauth
768 {
769 oauth_cfg.validate()?;
770 }
771
772 validate_security_headers(&self.security_headers)?;
775
776 if let Some(0) = self.max_concurrent_requests {
780 return Err(McpxError::Config(
781 "max_concurrent_requests must be greater than zero when set".into(),
782 ));
783 }
784
785 if let Some(auth_cfg) = &self.auth
789 && let Some(rl) = &auth_cfg.rate_limit
790 && rl.max_tracked_keys == 0
791 {
792 return Err(McpxError::Config(
793 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
794 ));
795 }
796
797 Ok(())
798 }
799}
800
801#[allow(
807 missing_debug_implementations,
808 reason = "contains Arc<AuthState> with non-Debug fields"
809)]
810pub struct ReloadHandle {
811 auth: Option<Arc<AuthState>>,
812 rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
813 crl_set: Option<Arc<CrlSet>>,
814}
815
816impl ReloadHandle {
817 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
819 if let Some(ref auth) = self.auth {
820 auth.reload_keys(keys);
821 }
822 }
823
824 pub fn reload_rbac(&self, policy: RbacPolicy) {
826 if let Some(ref rbac) = self.rbac {
827 rbac.store(Arc::new(policy));
828 tracing::info!("RBAC policy reloaded");
829 }
830 }
831
832 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
838 let Some(ref crl_set) = self.crl_set else {
839 return Err(McpxError::Config(
840 "CRL refresh requested but mTLS CRL support is not configured".into(),
841 ));
842 };
843
844 crl_set.force_refresh().await
845 }
846}
847
848#[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
865struct AppRunParams {
869 tls_paths: Option<(PathBuf, PathBuf)>,
871 mtls_config: Option<MtlsConfig>,
873 shutdown_timeout: Duration,
875 auth_state: Option<Arc<AuthState>>,
877 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
879 on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
881 ct: CancellationToken,
885 scheme: &'static str,
887 name: String,
889}
890
/// Assembles the full axum router for the server and collects the parameters
/// `run_server` needs.
///
/// Consumes `config`: `readiness_check`, `extra_router` and `on_reload_ready`
/// are `take()`n out of it. Also creates the root `CancellationToken`. When
/// the `metrics` feature and `metrics_enabled` are on, it spawns the metrics
/// listener task tied to that token.
///
/// # Errors
///
/// Fails if the JWKS cache cannot be built (oauth feature), if
/// `admin_enabled` is set without usable auth state, if the OAuth proxy routes
/// cannot be installed, or if metrics initialisation fails.
#[allow(
    clippy::cognitive_complexity,
    reason = "router assembly is intrinsically sequential; splitting harms readability"
)]
#[allow(
    deprecated,
    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
)]
fn build_app_router<H, F>(
    mut config: McpServerConfig,
    handler_factory: F,
) -> anyhow::Result<(axum::Router, AppRunParams)>
where
    H: ServerHandler + 'static,
    F: Fn() -> H + Send + Sync + Clone + 'static,
{
    // Root token for the whole server; sub-components get clones/child tokens.
    let ct = CancellationToken::new();

    // Host allow-list for the Streamable HTTP transport, computed by
    // `derive_allowed_hosts` (defined elsewhere) from bind addr + public URL.
    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");

    // The MCP transport itself: a fresh handler per session from the factory,
    // plus an in-process session manager with the configured keep-alive.
    let mcp_service = StreamableHttpService::new(
        move || Ok(handler_factory()),
        {
            let mut mgr = LocalSessionManager::default();
            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
            mgr.into()
        },
        StreamableHttpServerConfig::default()
            .with_allowed_hosts(allowed_hosts)
            .with_sse_keep_alive(Some(config.sse_keep_alive))
            .with_cancellation_token(ct.child_token()),
    );

    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);

    // Shared auth state exists only when auth is configured AND enabled.
    let auth_state: Option<Arc<AuthState>> = match config.auth {
        Some(ref auth_config) if auth_config.enabled => {
            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
            let pre_auth_limiter = auth_config
                .rate_limit
                .as_ref()
                .map(crate::auth::build_pre_auth_limiter);

            #[cfg(feature = "oauth")]
            let jwks_cache = auth_config
                .oauth
                .as_ref()
                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
                .transpose()
                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;

            Some(Arc::new(AuthState {
                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
                rate_limiter,
                pre_auth_limiter,
                #[cfg(feature = "oauth")]
                jwks_cache,
                seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
                counters: crate::auth::AuthCounters::default(),
            }))
        }
        _ => None,
    };

    // RBAC policy lives in an ArcSwap so ReloadHandle can replace it at
    // runtime; with no policy configured, a disabled policy is installed.
    let rbac_swap = Arc::new(ArcSwap::new(
        config
            .rbac
            .clone()
            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
    ));

    // Admin routes are merged into `mcp_router` BEFORE the RBAC/auth layers
    // below are added, so they end up behind those layers as well.
    if config.admin_enabled {
        let Some(ref auth_state_ref) = auth_state else {
            return Err(anyhow::anyhow!(
                "admin_enabled=true requires auth to be configured and enabled"
            ));
        };
        let admin_state = crate::admin::AdminState {
            started_at: std::time::Instant::now(),
            name: config.name.clone(),
            version: config.version.clone(),
            auth: Some(Arc::clone(auth_state_ref)),
            rbac: Arc::clone(&rbac_swap),
        };
        let admin_cfg = crate::admin::AdminConfig {
            role: config.admin_role.clone(),
        };
        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
    }

    // RBAC + tool rate limiting. The middleware is always installed; it loads
    // the current policy on every request so RBAC reloads take effect live.
    {
        let tool_limiter: Option<Arc<ToolRateLimiter>> =
            config.tool_rate_limit.map(build_tool_rate_limiter);

        if rbac_swap.load().is_enabled() {
            tracing::info!("RBAC enforcement enabled on /mcp");
        }
        if let Some(limit) = config.tool_rate_limit {
            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
        }

        let rbac_for_mw = Arc::clone(&rbac_swap);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            let p = rbac_for_mw.load_full();
            let tl = tool_limiter.clone();
            rbac_middleware(p, tl, req, next)
        }));
    }

    // Auth layer is added after RBAC, so (assuming standard axum layering,
    // where later layers wrap earlier ones) it runs before RBAC on requests.
    if let Some(ref auth_config) = config.auth
        && auth_config.enabled
    {
        let Some(ref state) = auth_state else {
            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
        };

        let methods: Vec<&str> = [
            auth_config.mtls.is_some().then_some("mTLS"),
            (!auth_config.api_keys.is_empty()).then_some("bearer"),
            #[cfg(feature = "oauth")]
            auth_config.oauth.is_some().then_some("oauth-jwt"),
        ]
        .into_iter()
        .flatten()
        .collect();

        tracing::info!(
            methods = %methods.join(", "),
            api_keys = auth_config.api_keys.len(),
            "auth enabled on /mcp"
        );

        let state_for_mw = Arc::clone(state);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            let s = Arc::clone(&state_for_mw);
            auth_middleware(s, req, next)
        }));
    }

    // Request timeout (-> 408) and body-size limit apply to /mcp (and admin).
    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
        axum::http::StatusCode::REQUEST_TIMEOUT,
        config.request_timeout,
    ));

    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
        config.max_request_body,
    ));

    // With no explicit origins, derive one as scheme://authority from
    // public_url (everything before the first '/' after "://").
    let mut effective_origins = config.allowed_origins.clone();
    if effective_origins.is_empty()
        && let Some(ref url) = config.public_url
    {
        if let Some(scheme_end) = url.find("://") {
            let after_scheme = &url[scheme_end + 3..];
            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
            let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
            tracing::info!(
                %origin,
                "auto-derived allowed origin from public_url"
            );
            effective_origins.push(origin);
        }
    }
    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
    let cors_origins = Arc::clone(&allowed_origins);
    let log_request_headers = config.log_request_headers;

    // /readyz uses the custom probe if one was supplied, else the health check.
    let readyz_route = if let Some(check) = config.readiness_check.take() {
        axum::routing::get(move || readyz(Arc::clone(&check)))
    } else {
        axum::routing::get(healthz)
    };

    // Top-level router: probes, a precomputed /version payload, and /mcp.
    #[allow(unused_mut)] let mut router = axum::Router::new()
        .route("/healthz", axum::routing::get(healthz))
        .route("/readyz", readyz_route)
        .route(
            "/version",
            axum::routing::get({
                let payload_bytes: Arc<[u8]> =
                    serialize_version_payload(&config.name, &config.version);
                move || {
                    let p = Arc::clone(&payload_bytes);
                    async move {
                        (
                            [(axum::http::header::CONTENT_TYPE, "application/json")],
                            p.to_vec(),
                        )
                    }
                }
            }),
        )
        .merge(mcp_router);

    // Extra routes join the top-level router, i.e. outside /mcp auth/RBAC.
    if let Some(extra) = config.extra_router.take() {
        router = router.merge(extra);
    }

    // Base URL advertised in protected-resource metadata: public_url without
    // trailing slashes, else scheme + bind_addr.
    let server_url = if let Some(ref url) = config.public_url {
        url.trim_end_matches('/').to_owned()
    } else {
        let prm_scheme = if config.tls_cert_path.is_some() {
            "https"
        } else {
            "http"
        };
        format!("{prm_scheme}://{}", config.bind_addr)
    };
    let resource_url = format!("{server_url}/mcp");

    #[cfg(feature = "oauth")]
    let prm_metadata = if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
    {
        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
    } else {
        serde_json::json!({ "resource": resource_url })
    };
    #[cfg(not(feature = "oauth"))]
    let prm_metadata = serde_json::json!({ "resource": resource_url });

    router = router.route(
        "/.well-known/oauth-protected-resource",
        axum::routing::get(move || {
            let m = prm_metadata.clone();
            async move { axum::Json(m) }
        }),
    );

    #[cfg(feature = "oauth")]
    if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
        && oauth_config.proxy.is_some()
    {
        router =
            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
    }

    let is_tls = config.tls_cert_path.is_some();
    let security_headers_cfg = Arc::new(config.security_headers.clone());
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        let cfg = Arc::clone(&security_headers_cfg);
        security_headers_middleware(is_tls, cfg, req, next)
    }));

    // CORS only when there is at least one origin; unparsable entries are
    // skipped silently by filter_map.
    if !cors_origins.is_empty() {
        let cors = tower_http::cors::CorsLayer::new()
            .allow_origin(
                cors_origins
                    .iter()
                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
                    .collect::<Vec<_>>(),
            )
            .allow_methods([
                axum::http::Method::GET,
                axum::http::Method::POST,
                axum::http::Method::OPTIONS,
            ])
            .allow_headers([
                axum::http::header::CONTENT_TYPE,
                axum::http::header::AUTHORIZATION,
            ]);
        router = router.layer(cors);
    }

    if config.compression_enabled {
        use tower_http::compression::Predicate as _;
        let predicate = tower_http::compression::DefaultPredicate::new().and(
            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
        );
        router = router.layer(
            tower_http::compression::CompressionLayer::new()
                .gzip(true)
                .br(true)
                .compress_when(predicate),
        );
        tracing::info!(
            min_size = config.compression_min_size,
            "response compression enabled (gzip, br)"
        );
    }

    // Global concurrency cap: LoadShed rejects immediately when the limit is
    // reached and HandleErrorLayer turns that into a 503 JSON body.
    if let Some(max) = config.max_concurrent_requests {
        let overload_handler = tower::ServiceBuilder::new()
            .layer(axum::error_handling::HandleErrorLayer::new(
                |_err: tower::BoxError| async {
                    (
                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
                        axum::Json(serde_json::json!({
                            "error": "overloaded",
                            "error_description": "server is at capacity, retry later"
                        })),
                    )
                },
            ))
            .layer(tower::load_shed::LoadShedLayer::new())
            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
        router = router.layer(overload_handler);
        tracing::info!(max, "global concurrency limit enabled");
    }

    // JSON 404 for unknown paths.
    // NOTE(review): axum documents that `Router::layer` only wraps routes (and
    // fallback) that already exist when it is called. This fallback is
    // registered after the security-header, CORS, compression and concurrency
    // layers, so 404 responses likely bypass them. Confirm that is intended.
    router = router.fallback(|| async {
        (
            axum::http::StatusCode::NOT_FOUND,
            axum::Json(serde_json::json!({
                "error": "not_found",
                "error_description": "The requested endpoint does not exist"
            })),
        )
    });

    // Metrics middleware plus a separate listener task that stops when `ct`
    // is cancelled (it receives a clone of `ct` as its shutdown token).
    #[cfg(feature = "metrics")]
    if config.metrics_enabled {
        let metrics = Arc::new(
            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
        );
        let m = Arc::clone(&metrics);
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let m = Arc::clone(&m);
                metrics_middleware(m, req, next)
            },
        ));
        let metrics_bind = config.metrics_bind.clone();
        let metrics_shutdown = ct.clone();
        tokio::spawn(async move {
            if let Err(e) =
                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
            {
                tracing::error!("metrics listener failed: {e}");
            }
        });
    }

    // Origin check is the last layer added, i.e. the outermost one.
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        let origins = Arc::clone(&allowed_origins);
        origin_check_middleware(origins, log_request_headers, req, next)
    }));

    let scheme = if config.tls_cert_path.is_some() {
        "https"
    } else {
        "http"
    };

    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
        _ => None,
    };
    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();

    Ok((
        router,
        AppRunParams {
            tls_paths,
            mtls_config,
            shutdown_timeout: config.shutdown_timeout,
            auth_state,
            rbac_swap,
            on_reload_ready: config.on_reload_ready.take(),
            ct,
            scheme,
            name: config.name.clone(),
        },
    ))
}
1362
1363pub async fn serve<H, F>(
1380 config: Validated<McpServerConfig>,
1381 handler_factory: F,
1382) -> Result<(), McpxError>
1383where
1384 H: ServerHandler + 'static,
1385 F: Fn() -> H + Send + Sync + Clone + 'static,
1386{
1387 let config = config.into_inner();
1388 #[allow(
1389 deprecated,
1390 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1391 )]
1392 let bind_addr = config.bind_addr.clone();
1393 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1394
1395 let listener = TcpListener::bind(&bind_addr)
1396 .await
1397 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1398 log_listening(¶ms.name, params.scheme, &bind_addr);
1399
1400 run_server(
1401 router,
1402 listener,
1403 params.tls_paths,
1404 params.mtls_config,
1405 params.shutdown_timeout,
1406 params.auth_state,
1407 params.rbac_swap,
1408 params.on_reload_ready,
1409 params.ct,
1410 )
1411 .await
1412 .map_err(anyhow_to_startup)
1413}
1414
1415pub async fn serve_with_listener<H, F>(
1445 listener: TcpListener,
1446 config: Validated<McpServerConfig>,
1447 handler_factory: F,
1448 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1449 shutdown: Option<CancellationToken>,
1450) -> Result<(), McpxError>
1451where
1452 H: ServerHandler + 'static,
1453 F: Fn() -> H + Send + Sync + Clone + 'static,
1454{
1455 let config = config.into_inner();
1456 let local_addr = listener
1457 .local_addr()
1458 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1459 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1460
1461 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1462
1463 if let Some(external) = shutdown {
1467 let internal = params.ct.clone();
1468 tokio::spawn(async move {
1469 external.cancelled().await;
1470 internal.cancel();
1471 });
1472 }
1473
1474 if let Some(tx) = ready_tx {
1478 let _ = tx.send(local_addr);
1480 }
1481
1482 run_server(
1483 router,
1484 listener,
1485 params.tls_paths,
1486 params.mtls_config,
1487 params.shutdown_timeout,
1488 params.auth_state,
1489 params.rbac_swap,
1490 params.on_reload_ready,
1491 params.ct,
1492 )
1493 .await
1494 .map_err(anyhow_to_startup)
1495}
1496
/// Logs the start-up banner: the listen address followed by the MCP,
/// health-check and readiness URLs built from `scheme` and `addr`.
#[allow(
    clippy::cognitive_complexity,
    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
)]
fn log_listening(name: &str, scheme: &str, addr: &str) {
    tracing::info!("{name} listening on {addr}");
    tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
    tracing::info!(" Health check: {scheme}://{addr}/healthz");
    tracing::info!(" Readiness: {scheme}://{addr}/readyz");
}
1509
1510#[allow(
1533 clippy::too_many_arguments,
1534 clippy::cognitive_complexity,
1535 reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1536)]
1537async fn run_server(
1538 router: axum::Router,
1539 listener: TcpListener,
1540 tls_paths: Option<(PathBuf, PathBuf)>,
1541 mtls_config: Option<MtlsConfig>,
1542 shutdown_timeout: Duration,
1543 auth_state: Option<Arc<AuthState>>,
1544 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1545 mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1546 ct: CancellationToken,
1547) -> anyhow::Result<()> {
1548 let shutdown_trigger = CancellationToken::new();
1552 {
1553 let trigger = shutdown_trigger.clone();
1554 let parent = ct.clone();
1555 tokio::spawn(async move {
1556 tokio::select! {
1557 () = shutdown_signal() => {}
1558 () = parent.cancelled() => {}
1559 }
1560 trigger.cancel();
1561 });
1562 }
1563
1564 let graceful = {
1565 let trigger = shutdown_trigger.clone();
1566 let ct = ct.clone();
1567 async move {
1568 trigger.cancelled().await;
1569 tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1570 ct.cancel();
1571 }
1572 };
1573
1574 let force_exit_timer = {
1575 let trigger = shutdown_trigger.clone();
1576 async move {
1577 trigger.cancelled().await;
1578 tokio::time::sleep(shutdown_timeout).await;
1579 }
1580 };
1581
1582 if let Some((cert_path, key_path)) = tls_paths {
1583 let crl_set = if let Some(mtls) = mtls_config.as_ref()
1584 && mtls.crl_enabled
1585 {
1586 let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1587 let (crl_set, discover_rx) =
1588 mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1589 .await
1590 .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1591 tokio::spawn(mtls_revocation::run_crl_refresher(
1592 Arc::clone(&crl_set),
1593 discover_rx,
1594 ct.clone(),
1595 ));
1596 Some(crl_set)
1597 } else {
1598 None
1599 };
1600
1601 if let Some(cb) = on_reload_ready.take() {
1602 cb(ReloadHandle {
1603 auth: auth_state.clone(),
1604 rbac: Some(Arc::clone(&rbac_swap)),
1605 crl_set: crl_set.clone(),
1606 });
1607 }
1608
1609 let tls_listener = TlsListener::new(
1610 listener,
1611 &cert_path,
1612 &key_path,
1613 mtls_config.as_ref(),
1614 crl_set,
1615 )?;
1616 let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1617 tokio::select! {
1618 result = axum::serve(tls_listener, make_svc)
1619 .with_graceful_shutdown(graceful) => { result?; }
1620 () = force_exit_timer => {
1621 tracing::warn!("shutdown timeout exceeded, forcing exit");
1622 }
1623 }
1624 } else {
1625 if let Some(cb) = on_reload_ready.take() {
1626 cb(ReloadHandle {
1627 auth: auth_state,
1628 rbac: Some(rbac_swap),
1629 crl_set: None,
1630 });
1631 }
1632
1633 let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1634 tokio::select! {
1635 result = axum::serve(listener, make_svc)
1636 .with_graceful_shutdown(graceful) => { result?; }
1637 () = force_exit_timer => {
1638 tracing::warn!("shutdown timeout exceeded, forcing exit");
1639 }
1640 }
1641 }
1642
1643 Ok(())
1644}
1645
/// Mounts the OAuth 2.1 proxy routes on `router` when `oauth_config.proxy` is set.
///
/// These routes are always added:
/// * `GET /.well-known/oauth-authorization-server`
/// * `GET /authorize`
/// * `POST /token`
/// * `POST /register`
///
/// `/token` and `/register` carry `oauth_token_cache_headers_middleware`.
/// `/introspect` and `/revoke` are added only when `expose_admin_endpoints` is
/// set and the matching upstream URL is configured (see
/// `build_oauth_admin_router`).
///
/// Returns `router` unchanged when no proxy is configured.
///
/// # Errors
///
/// Propagates failures from building the upstream HTTP client and from
/// `build_oauth_admin_router`.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    // Admin routes are mounted only if exposure is requested AND at least one
    // upstream (introspection or revocation) exists to proxy to.
    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Warn only when unauthenticated admin routes will actually be served;
    // with no upstream URL configured no route exists and the warning would
    // be misleading.
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1746
/// Builds the `/introspect` and `/revoke` proxy routes.
///
/// Each route is added only if the corresponding upstream URL is configured.
/// All routes get `oauth_token_cache_headers_middleware`. When
/// `require_auth_on_admin_endpoints` is set, they are also wrapped in
/// `auth_middleware`.
///
/// The router fails closed: serving these routes without authentication
/// requires the explicit `allow_unauthenticated_admin_endpoints` opt-out.
///
/// # Errors
///
/// Returns `McpxError::Startup` if auth is required but `auth_state` is
/// `None`. Also returns it if auth is not required and the unauthenticated
/// opt-out was not given.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    // Token introspection/revocation must never become reachable without auth
    // by accident. The only way to run them open is the explicit opt-out flag.
    if !proxy.require_auth_on_admin_endpoints && !proxy.allow_unauthenticated_admin_endpoints {
        return Err(McpxError::Startup(
            "oauth proxy admin endpoints would be served without authentication; \
             enable require_auth_on_admin_endpoints or explicitly opt out via \
             allow_unauthenticated_admin_endpoints"
                .into(),
        ));
    }

    let mut admin_router = axum::Router::new();
    if proxy.introspection_url.is_some() {
        let proxy_introspect = proxy.clone();
        let introspect_http = http.clone();
        admin_router = admin_router.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let p = proxy_introspect.clone();
                let h = introspect_http.clone();
                async move { crate::oauth::handle_introspect(&h, &p, &body).await }
            }),
        );
    }
    if proxy.revocation_url.is_some() {
        let proxy_revoke = proxy.clone();
        let revoke_http = http;
        admin_router = admin_router.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let p = proxy_revoke.clone();
                let h = revoke_http.clone();
                async move { crate::oauth::handle_revoke(&h, &p, &body).await }
            }),
        );
    }

    let admin_router = admin_router.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if proxy.require_auth_on_admin_endpoints {
        let Some(state) = auth_state else {
            return Err(McpxError::Startup(
                "oauth proxy admin endpoints require auth state".into(),
            ));
        };
        let state_for_mw = Arc::clone(state);
        Ok(
            admin_router.layer(axum::middleware::from_fn(move |req, next| {
                let s = Arc::clone(&state_for_mw);
                auth_middleware(s, req, next)
            })),
        )
    } else {
        Ok(admin_router)
    }
}
1805
1806fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1811 let mut hosts = vec![
1812 "localhost".to_owned(),
1813 "127.0.0.1".to_owned(),
1814 "::1".to_owned(),
1815 ];
1816
1817 if let Some(url) = public_url
1818 && let Ok(uri) = url.parse::<axum::http::Uri>()
1819 && let Some(authority) = uri.authority()
1820 {
1821 let host = authority.host().to_owned();
1822 if !hosts.iter().any(|h| h == &host) {
1823 hosts.push(host);
1824 }
1825
1826 let authority = authority.as_str().to_owned();
1827 if !hosts.iter().any(|h| h == &authority) {
1828 hosts.push(authority);
1829 }
1830 }
1831
1832 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1833 && let Some(authority) = uri.authority()
1834 {
1835 let host = authority.host().to_owned();
1836 if !hosts.iter().any(|h| h == &host) {
1837 hosts.push(host);
1838 }
1839
1840 let authority = authority.as_str().to_owned();
1841 if !hosts.iter().any(|h| h == &authority) {
1842 hosts.push(authority);
1843 }
1844 }
1845
1846 hosts
1847}
1848
1849impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1862 for TlsConnInfo
1863{
1864 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1865 let addr = *target.remote_addr();
1866 let identity = target.io().identity().cloned();
1867 TlsConnInfo::new(addr, identity)
1868 }
1869}
1870
1871struct TlsListener {
1879 inner: TcpListener,
1880 acceptor: tokio_rustls::TlsAcceptor,
1881 mtls_default_role: String,
1882}
1883
1884impl TlsListener {
1885 fn new(
1886 inner: TcpListener,
1887 cert_path: &Path,
1888 key_path: &Path,
1889 mtls_config: Option<&MtlsConfig>,
1890 crl_set: Option<Arc<CrlSet>>,
1891 ) -> anyhow::Result<Self> {
1892 rustls::crypto::ring::default_provider()
1894 .install_default()
1895 .ok();
1896
1897 let certs = load_certs(cert_path)?;
1898 let key = load_key(key_path)?;
1899
1900 let mtls_default_role;
1901
1902 let tls_config = if let Some(mtls) = mtls_config {
1903 mtls_default_role = mtls.default_role.clone();
1904 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1905 {
1906 let Some(crl_set) = crl_set else {
1907 return Err(anyhow::anyhow!(
1908 "mTLS CRL verifier requested but CRL state was not initialized"
1909 ));
1910 };
1911 Arc::new(DynamicClientCertVerifier::new(crl_set))
1912 } else {
1913 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1914 if mtls.required {
1915 rustls::server::WebPkiClientVerifier::builder(root_store)
1916 .build()
1917 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1918 } else {
1919 rustls::server::WebPkiClientVerifier::builder(root_store)
1920 .allow_unauthenticated()
1921 .build()
1922 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1923 }
1924 };
1925
1926 tracing::info!(
1927 ca = %mtls.ca_cert_path.display(),
1928 required = mtls.required,
1929 crl_enabled = mtls.crl_enabled,
1930 "mTLS client auth configured"
1931 );
1932
1933 rustls::ServerConfig::builder_with_protocol_versions(&[
1934 &rustls::version::TLS12,
1935 &rustls::version::TLS13,
1936 ])
1937 .with_client_cert_verifier(verifier)
1938 .with_single_cert(certs, key)?
1939 } else {
1940 mtls_default_role = "viewer".to_owned();
1941 rustls::ServerConfig::builder_with_protocol_versions(&[
1942 &rustls::version::TLS12,
1943 &rustls::version::TLS13,
1944 ])
1945 .with_no_client_auth()
1946 .with_single_cert(certs, key)?
1947 };
1948
1949 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1950 tracing::info!(
1951 "TLS enabled (cert: {}, key: {})",
1952 cert_path.display(),
1953 key_path.display()
1954 );
1955 Ok(Self {
1956 inner,
1957 acceptor,
1958 mtls_default_role,
1959 })
1960 }
1961
1962 fn extract_handshake_identity(
1966 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1967 default_role: &str,
1968 addr: SocketAddr,
1969 ) -> Option<AuthIdentity> {
1970 let (_, server_conn) = tls_stream.get_ref();
1971 let cert_der = server_conn.peer_certificates()?.first()?;
1972 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1973 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1974 Some(id)
1975 }
1976}
1977
1978pub(crate) struct AuthenticatedTlsStream {
1990 inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1991 identity: Option<AuthIdentity>,
1992}
1993
1994impl AuthenticatedTlsStream {
1995 #[must_use]
1997 pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
1998 self.identity.as_ref()
1999 }
2000}
2001
2002impl std::fmt::Debug for AuthenticatedTlsStream {
2003 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2004 f.debug_struct("AuthenticatedTlsStream")
2005 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2006 .finish_non_exhaustive()
2007 }
2008}
2009
2010impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2011 fn poll_read(
2012 mut self: Pin<&mut Self>,
2013 cx: &mut std::task::Context<'_>,
2014 buf: &mut tokio::io::ReadBuf<'_>,
2015 ) -> std::task::Poll<std::io::Result<()>> {
2016 Pin::new(&mut self.inner).poll_read(cx, buf)
2017 }
2018}
2019
2020impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2021 fn poll_write(
2022 mut self: Pin<&mut Self>,
2023 cx: &mut std::task::Context<'_>,
2024 buf: &[u8],
2025 ) -> std::task::Poll<std::io::Result<usize>> {
2026 Pin::new(&mut self.inner).poll_write(cx, buf)
2027 }
2028
2029 fn poll_flush(
2030 mut self: Pin<&mut Self>,
2031 cx: &mut std::task::Context<'_>,
2032 ) -> std::task::Poll<std::io::Result<()>> {
2033 Pin::new(&mut self.inner).poll_flush(cx)
2034 }
2035
2036 fn poll_shutdown(
2037 mut self: Pin<&mut Self>,
2038 cx: &mut std::task::Context<'_>,
2039 ) -> std::task::Poll<std::io::Result<()>> {
2040 Pin::new(&mut self.inner).poll_shutdown(cx)
2041 }
2042
2043 fn poll_write_vectored(
2044 mut self: Pin<&mut Self>,
2045 cx: &mut std::task::Context<'_>,
2046 bufs: &[std::io::IoSlice<'_>],
2047 ) -> std::task::Poll<std::io::Result<usize>> {
2048 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2049 }
2050
2051 fn is_write_vectored(&self) -> bool {
2052 self.inner.is_write_vectored()
2053 }
2054}
2055
2056impl axum::serve::Listener for TlsListener {
2057 type Io = AuthenticatedTlsStream;
2058 type Addr = SocketAddr;
2059
2060 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2061 loop {
2062 let (stream, addr) = match self.inner.accept().await {
2063 Ok(pair) => pair,
2064 Err(e) => {
2065 tracing::debug!("TCP accept error: {e}");
2066 continue;
2067 }
2068 };
2069 let tls_stream = match self.acceptor.accept(stream).await {
2070 Ok(s) => s,
2071 Err(e) => {
2072 tracing::debug!("TLS handshake failed from {addr}: {e}");
2073 continue;
2074 }
2075 };
2076 let identity =
2077 Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
2078 let wrapped = AuthenticatedTlsStream {
2079 inner: tls_stream,
2080 identity,
2081 };
2082 return (wrapped, addr);
2083 }
2084 }
2085
2086 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2087 self.inner.local_addr()
2088 }
2089}
2090
2091fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2092 use rustls::pki_types::pem::PemObject;
2093 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2094 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2095 .collect::<Result<_, _>>()
2096 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2097 anyhow::ensure!(
2098 !certs.is_empty(),
2099 "no certificates found in {}",
2100 path.display()
2101 );
2102 Ok(certs)
2103}
2104
2105fn load_client_auth_roots(
2106 path: &Path,
2107) -> anyhow::Result<(
2108 Vec<rustls::pki_types::CertificateDer<'static>>,
2109 Arc<RootCertStore>,
2110)> {
2111 let ca_certs = load_certs(path)?;
2112 let mut root_store = RootCertStore::empty();
2113 for cert in &ca_certs {
2114 root_store
2115 .add(cert.clone())
2116 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2117 }
2118
2119 Ok((ca_certs, Arc::new(root_store)))
2120}
2121
2122fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2123 use rustls::pki_types::pem::PemObject;
2124 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2125 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2126}
2127
2128#[allow(clippy::unused_async)]
2129async fn healthz() -> impl IntoResponse {
2130 axum::Json(serde_json::json!({
2131 "status": "ok",
2132 }))
2133}
2134
2135fn version_payload(name: &str, version: &str) -> serde_json::Value {
2141 serde_json::json!({
2142 "name": name,
2143 "version": version,
2144 "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or("unknown"),
2145 "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or("unknown"),
2146 "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or("unknown"),
2147 "mcpx_version": env!("CARGO_PKG_VERSION"),
2148 })
2149}
2150
2151fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2161 let value = version_payload(name, version);
2162 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2163}
2164
2165async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2166 let status = check().await;
2167 let ready = status
2168 .get("ready")
2169 .and_then(serde_json::Value::as_bool)
2170 .unwrap_or(false);
2171 let code = if ready {
2172 axum::http::StatusCode::OK
2173 } else {
2174 axum::http::StatusCode::SERVICE_UNAVAILABLE
2175 };
2176 (code, axum::Json(status))
2177}
2178
2179async fn shutdown_signal() {
2183 let ctrl_c = tokio::signal::ctrl_c();
2184
2185 #[cfg(unix)]
2186 {
2187 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2188 Ok(mut term) => {
2189 tokio::select! {
2190 _ = ctrl_c => {}
2191 _ = term.recv() => {}
2192 }
2193 }
2194 Err(e) => {
2195 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2196 ctrl_c.await.ok();
2197 }
2198 }
2199 }
2200
2201 #[cfg(not(unix))]
2202 {
2203 ctrl_c.await.ok();
2204 }
2205}
2206
/// Records request count and latency for every HTTP request.
///
/// Feeds `http_requests_total{method, path, status}` and
/// `http_request_duration_seconds{method, path}`. Label values are kept to a
/// bounded set so clients cannot create unlimited time series:
/// * `path` is the matched route template (axum `MatchedPath`). Requests that
///   matched no route and ended in `404` are labelled `<unmatched>`. Any other
///   request without a matched route falls back to the raw path.
/// * `method` is the standard HTTP method name, or `OTHER` for extension
///   methods.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    let method = match req.method().as_str() {
        known @ ("GET" | "HEAD" | "POST" | "PUT" | "DELETE" | "PATCH" | "OPTIONS"
        | "CONNECT" | "TRACE") => known.to_owned(),
        _ => "OTHER".to_owned(),
    };
    let matched_path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map(|route| route.as_str().to_owned());
    let raw_path = req.uri().path().to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let duration = start.elapsed().as_secs_f64();
    let status_code = response.status();
    // Scanners probing random URLs would otherwise mint one series per path.
    let path = match matched_path {
        Some(route) => route,
        None if status_code == axum::http::StatusCode::NOT_FOUND => "<unmatched>".to_owned(),
        None => raw_path,
    };
    let status = status_code.as_u16().to_string();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2238
2239async fn security_headers_middleware(
2251 is_tls: bool,
2252 cfg: Arc<SecurityHeadersConfig>,
2253 req: Request<Body>,
2254 next: Next,
2255) -> axum::response::Response {
2256 use axum::http::{HeaderName, header};
2257
2258 let mut resp = next.run(req).await;
2259 let headers = resp.headers_mut();
2260
2261 headers.remove(header::SERVER);
2263 headers.remove(HeaderName::from_static("x-powered-by"));
2264
2265 apply_security_header(
2266 headers,
2267 header::X_CONTENT_TYPE_OPTIONS,
2268 cfg.x_content_type_options.as_deref(),
2269 "nosniff",
2270 );
2271 apply_security_header(
2272 headers,
2273 header::X_FRAME_OPTIONS,
2274 cfg.x_frame_options.as_deref(),
2275 "deny",
2276 );
2277 apply_security_header(
2278 headers,
2279 header::CACHE_CONTROL,
2280 cfg.cache_control.as_deref(),
2281 "no-store, max-age=0",
2282 );
2283 apply_security_header(
2284 headers,
2285 header::REFERRER_POLICY,
2286 cfg.referrer_policy.as_deref(),
2287 "no-referrer",
2288 );
2289 apply_security_header(
2290 headers,
2291 HeaderName::from_static("cross-origin-opener-policy"),
2292 cfg.cross_origin_opener_policy.as_deref(),
2293 "same-origin",
2294 );
2295 apply_security_header(
2296 headers,
2297 HeaderName::from_static("cross-origin-resource-policy"),
2298 cfg.cross_origin_resource_policy.as_deref(),
2299 "same-origin",
2300 );
2301 apply_security_header(
2302 headers,
2303 HeaderName::from_static("cross-origin-embedder-policy"),
2304 cfg.cross_origin_embedder_policy.as_deref(),
2305 "require-corp",
2306 );
2307 apply_security_header(
2308 headers,
2309 HeaderName::from_static("permissions-policy"),
2310 cfg.permissions_policy.as_deref(),
2311 "accelerometer=(), camera=(), geolocation=(), microphone=()",
2312 );
2313 apply_security_header(
2314 headers,
2315 HeaderName::from_static("x-permitted-cross-domain-policies"),
2316 cfg.x_permitted_cross_domain_policies.as_deref(),
2317 "none",
2318 );
2319 apply_security_header(
2320 headers,
2321 HeaderName::from_static("content-security-policy"),
2322 cfg.content_security_policy.as_deref(),
2323 "default-src 'none'; frame-ancestors 'none'",
2324 );
2325 apply_security_header(
2326 headers,
2327 HeaderName::from_static("x-dns-prefetch-control"),
2328 cfg.x_dns_prefetch_control.as_deref(),
2329 "off",
2330 );
2331
2332 if is_tls {
2333 apply_security_header(
2334 headers,
2335 header::STRICT_TRANSPORT_SECURITY,
2336 cfg.strict_transport_security.as_deref(),
2337 "max-age=63072000; includeSubDomains",
2338 );
2339 }
2340
2341 resp
2342}
2343
2344fn apply_security_header(
2355 headers: &mut axum::http::HeaderMap,
2356 name: axum::http::HeaderName,
2357 override_value: Option<&str>,
2358 default: &'static str,
2359) {
2360 use axum::http::HeaderValue;
2361
2362 match override_value {
2363 None => {
2364 headers.insert(name, HeaderValue::from_static(default));
2365 }
2366 Some("") => {
2367 }
2369 Some(v) => match HeaderValue::from_str(v) {
2370 Ok(hv) => {
2371 headers.insert(name, hv);
2372 }
2373 Err(err) => {
2374 tracing::error!(
2375 header = %name,
2376 error = %err,
2377 "invalid security header override reached middleware; using default"
2378 );
2379 headers.insert(name, HeaderValue::from_static(default));
2380 }
2381 },
2382 }
2383}
2384
2385fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2396 use axum::http::HeaderValue;
2397
2398 let fields: &[(&str, Option<&str>)] = &[
2399 (
2400 "x_content_type_options",
2401 cfg.x_content_type_options.as_deref(),
2402 ),
2403 ("x_frame_options", cfg.x_frame_options.as_deref()),
2404 ("cache_control", cfg.cache_control.as_deref()),
2405 ("referrer_policy", cfg.referrer_policy.as_deref()),
2406 (
2407 "cross_origin_opener_policy",
2408 cfg.cross_origin_opener_policy.as_deref(),
2409 ),
2410 (
2411 "cross_origin_resource_policy",
2412 cfg.cross_origin_resource_policy.as_deref(),
2413 ),
2414 (
2415 "cross_origin_embedder_policy",
2416 cfg.cross_origin_embedder_policy.as_deref(),
2417 ),
2418 ("permissions_policy", cfg.permissions_policy.as_deref()),
2419 (
2420 "x_permitted_cross_domain_policies",
2421 cfg.x_permitted_cross_domain_policies.as_deref(),
2422 ),
2423 (
2424 "content_security_policy",
2425 cfg.content_security_policy.as_deref(),
2426 ),
2427 (
2428 "x_dns_prefetch_control",
2429 cfg.x_dns_prefetch_control.as_deref(),
2430 ),
2431 (
2432 "strict_transport_security",
2433 cfg.strict_transport_security.as_deref(),
2434 ),
2435 ];
2436
2437 for (field, value) in fields {
2438 let Some(v) = value else { continue };
2439 if v.is_empty() {
2440 continue;
2441 }
2442 if let Err(err) = HeaderValue::from_str(v) {
2443 return Err(McpxError::Config(format!(
2444 "invalid security_headers.{field}: {err}"
2445 )));
2446 }
2447 }
2448
2449 if let Some(v) = cfg.strict_transport_security.as_deref()
2450 && !v.is_empty()
2451 && v.to_ascii_lowercase().contains("preload")
2452 {
2453 return Err(McpxError::Config(format!(
2454 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2455 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2456 )));
2457 }
2458
2459 Ok(())
2460}
2461
/// Adds anti-caching headers to OAuth token-bearing responses.
///
/// Inserts `Pragma: no-cache`, replacing any existing value. Appends
/// `Vary: Authorization`, leaving any existing `Vary` values in place.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;
    let out = response.headers_mut();
    out.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
    out.append(header::VARY, HeaderValue::from_static("Authorization"));
    response
}
2489
2490async fn origin_check_middleware(
2494 allowed: Arc<[String]>,
2495 log_request_headers: bool,
2496 req: Request<Body>,
2497 next: Next,
2498) -> axum::response::Response {
2499 let method = req.method().clone();
2500 let path = req.uri().path().to_owned();
2501
2502 log_incoming_request(&method, &path, req.headers(), log_request_headers);
2503
2504 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2505 let origin_str = origin.to_str().unwrap_or("");
2506 if !allowed.iter().any(|a| a == origin_str) {
2507 tracing::warn!(
2508 origin = origin_str,
2509 %method,
2510 %path,
2511 allowed = ?&*allowed,
2512 "rejected request: Origin not allowed"
2513 );
2514 return (
2515 axum::http::StatusCode::FORBIDDEN,
2516 "Forbidden: Origin not allowed",
2517 )
2518 .into_response();
2519 }
2520 }
2521 next.run(req).await
2522}
2523
2524fn log_incoming_request(
2527 method: &axum::http::Method,
2528 path: &str,
2529 headers: &axum::http::HeaderMap,
2530 log_request_headers: bool,
2531) {
2532 if log_request_headers {
2533 tracing::debug!(
2534 %method,
2535 %path,
2536 headers = %format_request_headers_for_log(headers),
2537 "incoming request"
2538 );
2539 } else {
2540 tracing::debug!(%method, %path, "incoming request");
2541 }
2542}
2543
2544fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2545 headers
2546 .iter()
2547 .map(|(k, v)| {
2548 let name = k.as_str();
2549 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2550 format!("{name}: [REDACTED]")
2551 } else {
2552 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2553 }
2554 })
2555 .collect::<Vec<_>>()
2556 .join(", ")
2557}
2558
2559#[allow(clippy::cognitive_complexity)]
2583pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2584where
2585 H: ServerHandler + 'static,
2586{
2587 use rmcp::ServiceExt as _;
2588
2589 tracing::info!("stdio transport: serving on stdin/stdout");
2590 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2591
2592 let transport = rmcp::transport::io::stdio();
2593
2594 let service = handler
2595 .serve(transport)
2596 .await
2597 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2598
2599 if let Err(e) = service.waiting().await {
2600 tracing::warn!(error = %e, "stdio session ended with error");
2601 }
2602 tracing::info!("stdio session ended");
2603 Ok(())
2604}
2605
2606#[cfg(test)]
2607mod tests {
2608 #![allow(
2609 clippy::unwrap_used,
2610 clippy::expect_used,
2611 clippy::panic,
2612 clippy::indexing_slicing,
2613 clippy::unwrap_in_result,
2614 clippy::print_stdout,
2615 clippy::print_stderr,
2616 deprecated,
2617 reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2618 )]
2619 use std::sync::Arc;
2620
2621 use axum::{
2622 body::Body,
2623 http::{Request, StatusCode, header},
2624 response::IntoResponse,
2625 };
2626 use http_body_util::BodyExt;
2627 use tower::ServiceExt as _;
2628
2629 use super::*;
2630
2631 #[test]
2634 fn server_config_new_defaults() {
2635 let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2636 assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2637 assert_eq!(cfg.name, "test-server");
2638 assert_eq!(cfg.version, "1.0.0");
2639 assert!(cfg.tls_cert_path.is_none());
2640 assert!(cfg.tls_key_path.is_none());
2641 assert!(cfg.auth.is_none());
2642 assert!(cfg.rbac.is_none());
2643 assert!(cfg.allowed_origins.is_empty());
2644 assert!(cfg.tool_rate_limit.is_none());
2645 assert!(cfg.readiness_check.is_none());
2646 assert_eq!(cfg.max_request_body, 1024 * 1024);
2647 assert_eq!(cfg.request_timeout, Duration::from_mins(2));
2648 assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
2649 assert!(!cfg.log_request_headers);
2650 }
2651
2652 #[test]
2653 fn validate_consumes_and_proves() {
2654 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2656 let validated = cfg.validate().expect("valid config");
2657 assert_eq!(validated.name, "test-server");
2659 let raw = validated.into_inner();
2661 assert_eq!(raw.name, "test-server");
2662
2663 let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2665 bad.max_request_body = 0;
2666 assert!(bad.validate().is_err(), "zero body cap must fail validate");
2667 }
2668
2669 #[test]
2670 fn validate_rejects_zero_max_concurrent_requests() {
2671 let cfg =
2672 McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
2673 let err = cfg.validate().expect_err("zero concurrency cap must fail");
2674 assert!(
2675 format!("{err}").contains("max_concurrent_requests"),
2676 "error should mention max_concurrent_requests, got: {err}"
2677 );
2678 }
2679
2680 #[test]
2681 fn validate_rejects_zero_max_tracked_keys() {
2682 let rl = crate::auth::RateLimitConfig {
2683 max_tracked_keys: 0,
2684 ..Default::default()
2685 };
2686 let auth_cfg = AuthConfig {
2687 enabled: true,
2688 rate_limit: Some(rl),
2689 ..Default::default()
2690 };
2691 let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
2692 let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
2693 assert!(
2694 format!("{err}").contains("max_tracked_keys"),
2695 "error should mention max_tracked_keys, got: {err}"
2696 );
2697 }
2698
2699 #[test]
2700 fn derive_allowed_hosts_includes_public_host() {
2701 let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
2702 assert!(
2703 hosts.iter().any(|h| h == "mcp.example.com"),
2704 "public_url host must be allowed"
2705 );
2706 }
2707
2708 #[test]
2709 fn derive_allowed_hosts_includes_bind_authority() {
2710 let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
2711 assert!(
2712 hosts.iter().any(|h| h == "127.0.0.1"),
2713 "bind host must be allowed"
2714 );
2715 assert!(
2716 hosts.iter().any(|h| h == "127.0.0.1:8080"),
2717 "bind authority must be allowed"
2718 );
2719 }
2720
2721 #[tokio::test]
2724 async fn healthz_returns_ok_json() {
2725 let resp = healthz().await.into_response();
2726 assert_eq!(resp.status(), StatusCode::OK);
2727 let body = resp.into_body().collect().await.unwrap().to_bytes();
2728 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2729 assert_eq!(json["status"], "ok");
2730 assert!(
2731 json.get("name").is_none(),
2732 "healthz must not expose server name"
2733 );
2734 assert!(
2735 json.get("version").is_none(),
2736 "healthz must not expose version"
2737 );
2738 }
2739
2740 #[tokio::test]
2743 async fn readyz_returns_ok_when_ready() {
2744 let check: ReadinessCheck =
2745 Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
2746 let resp = readyz(check).await.into_response();
2747 assert_eq!(resp.status(), StatusCode::OK);
2748 let body = resp.into_body().collect().await.unwrap().to_bytes();
2749 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2750 assert_eq!(json["ready"], true);
2751 assert!(
2752 json.get("name").is_none(),
2753 "readyz must not expose server name"
2754 );
2755 assert!(
2756 json.get("version").is_none(),
2757 "readyz must not expose version"
2758 );
2759 assert_eq!(json["db"], "connected");
2760 }
2761
2762 #[tokio::test]
2763 async fn readyz_returns_503_when_not_ready() {
2764 let check: ReadinessCheck =
2765 Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
2766 let resp = readyz(check).await.into_response();
2767 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2768 }
2769
2770 #[tokio::test]
2771 async fn readyz_returns_503_when_ready_missing() {
2772 let check: ReadinessCheck =
2773 Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
2774 let resp = readyz(check).await.into_response();
2775 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2777 }
2778
2779 fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
2783 let allowed: Arc<[String]> = Arc::from(origins);
2784 axum::Router::new()
2785 .route("/test", axum::routing::get(|| async { "ok" }))
2786 .layer(axum::middleware::from_fn(move |req, next| {
2787 let a = Arc::clone(&allowed);
2788 origin_check_middleware(a, log_request_headers, req, next)
2789 }))
2790 }
2791
2792 #[tokio::test]
2793 async fn origin_allowed_passes() {
2794 let app = origin_router(vec!["http://localhost:3000".into()], false);
2795 let req = Request::builder()
2796 .uri("/test")
2797 .header(header::ORIGIN, "http://localhost:3000")
2798 .body(Body::empty())
2799 .unwrap();
2800 let resp = app.oneshot(req).await.unwrap();
2801 assert_eq!(resp.status(), StatusCode::OK);
2802 }
2803
2804 #[tokio::test]
2805 async fn origin_rejected_returns_403() {
2806 let app = origin_router(vec!["http://localhost:3000".into()], false);
2807 let req = Request::builder()
2808 .uri("/test")
2809 .header(header::ORIGIN, "http://evil.com")
2810 .body(Body::empty())
2811 .unwrap();
2812 let resp = app.oneshot(req).await.unwrap();
2813 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2814 }
2815
2816 #[tokio::test]
2817 async fn no_origin_header_passes() {
2818 let app = origin_router(vec!["http://localhost:3000".into()], false);
2819 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2820 let resp = app.oneshot(req).await.unwrap();
2821 assert_eq!(resp.status(), StatusCode::OK);
2822 }
2823
2824 #[tokio::test]
2825 async fn empty_allowlist_rejects_any_origin() {
2826 let app = origin_router(vec![], false);
2827 let req = Request::builder()
2828 .uri("/test")
2829 .header(header::ORIGIN, "http://anything.com")
2830 .body(Body::empty())
2831 .unwrap();
2832 let resp = app.oneshot(req).await.unwrap();
2833 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2834 }
2835
2836 #[tokio::test]
2837 async fn empty_allowlist_passes_without_origin() {
2838 let app = origin_router(vec![], false);
2839 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2840 let resp = app.oneshot(req).await.unwrap();
2841 assert_eq!(resp.status(), StatusCode::OK);
2842 }
2843
/// `format_request_headers_for_log` masks the values of credential-bearing
/// headers (`authorization`, `cookie`) while logging other headers verbatim,
/// and no raw secret value may appear anywhere in the output.
#[test]
fn format_request_headers_redacts_sensitive_values() {
    let mut headers = axum::http::HeaderMap::new();
    headers.insert("authorization", "Bearer secret-token".parse().unwrap());
    headers.insert("cookie", "sid=abc".parse().unwrap());
    headers.insert("x-request-id", "req-123".parse().unwrap());

    let out = format_request_headers_for_log(&headers);
    // Sensitive header names stay visible; only their values are masked.
    assert!(out.contains("authorization: [REDACTED]"));
    assert!(out.contains("cookie: [REDACTED]"));
    // Non-sensitive headers are logged as-is.
    assert!(out.contains("x-request-id: req-123"));
    // Neither secret may leak, even alongside the redaction marker.
    assert!(
        !out.contains("secret-token"),
        "authorization value leaked into log output: {out}"
    );
    assert!(
        !out.contains("sid=abc"),
        "cookie value leaked into log output: {out}"
    );
}
2857
2858 fn security_router(is_tls: bool) -> axum::Router {
2861 security_router_with(is_tls, SecurityHeadersConfig::default())
2862 }
2863
2864 fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
2865 let cfg = Arc::new(cfg);
2866 axum::Router::new()
2867 .route("/test", axum::routing::get(|| async { "ok" }))
2868 .layer(axum::middleware::from_fn(move |req, next| {
2869 let c = Arc::clone(&cfg);
2870 security_headers_middleware(is_tls, c, req, next)
2871 }))
2872 }
2873
2874 #[tokio::test]
2875 async fn security_headers_set_on_response() {
2876 let app = security_router(false);
2877 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2878 let resp = app.oneshot(req).await.unwrap();
2879 assert_eq!(resp.status(), StatusCode::OK);
2880
2881 let h = resp.headers();
2882 assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
2883 assert_eq!(h.get("x-frame-options").unwrap(), "deny");
2884 assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
2885 assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
2886 assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
2887 assert_eq!(
2888 h.get("cross-origin-resource-policy").unwrap(),
2889 "same-origin"
2890 );
2891 assert_eq!(
2892 h.get("cross-origin-embedder-policy").unwrap(),
2893 "require-corp"
2894 );
2895 assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
2896 assert!(
2897 h.get("permissions-policy")
2898 .unwrap()
2899 .to_str()
2900 .unwrap()
2901 .contains("camera=()"),
2902 "permissions-policy must restrict browser features"
2903 );
2904 assert_eq!(
2905 h.get("content-security-policy").unwrap(),
2906 "default-src 'none'; frame-ancestors 'none'"
2907 );
2908 assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
2909 assert!(h.get("strict-transport-security").is_none());
2911 }
2912
2913 #[tokio::test]
2914 async fn hsts_set_when_tls_enabled() {
2915 let app = security_router(true);
2916 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2917 let resp = app.oneshot(req).await.unwrap();
2918
2919 let hsts = resp.headers().get("strict-transport-security").unwrap();
2920 assert!(
2921 hsts.to_str().unwrap().contains("max-age=63072000"),
2922 "HSTS must set 2-year max-age"
2923 );
2924 }
2925
2926 fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
2932 let cfg =
2933 McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
2934 cfg.check()
2935 }
2936
/// The default `SecurityHeadersConfig` passes config validation.
#[test]
fn security_headers_config_default_validates() {
    let outcome = check_with_security_headers(SecurityHeadersConfig::default());
    outcome.expect("default SecurityHeadersConfig must validate");
}
2942
/// `Some("")` (the "omit this header" form) must validate on every field at once.
#[test]
fn security_headers_config_validate_accepts_empty_string() {
    let omit = || Some(String::new());
    let h = SecurityHeadersConfig {
        x_content_type_options: omit(),
        x_frame_options: omit(),
        cache_control: omit(),
        referrer_policy: omit(),
        cross_origin_opener_policy: omit(),
        cross_origin_resource_policy: omit(),
        cross_origin_embedder_policy: omit(),
        permissions_policy: omit(),
        x_permitted_cross_domain_policies: omit(),
        content_security_policy: omit(),
        x_dns_prefetch_control: omit(),
        strict_transport_security: omit(),
    };
    check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
}
2962
/// A control character (BEL, U+0007) in a header value is rejected, and the
/// error identifies which field was bad.
#[test]
fn security_headers_config_validate_rejects_bad_value() {
    let cfg = SecurityHeadersConfig {
        referrer_policy: Some(String::from("\u{0007}")),
        ..SecurityHeadersConfig::default()
    };
    let message = check_with_security_headers(cfg)
        .expect_err("control char in referrer_policy must reject")
        .to_string();
    assert!(
        message.contains("referrer_policy"),
        "error must name the offending field, got: {message}"
    );
}
2978
/// An HSTS override containing the `preload` directive is rejected; the error
/// must name the field and mention `preload`.
#[test]
fn security_headers_config_validate_rejects_hsts_preload() {
    let cfg = SecurityHeadersConfig {
        strict_transport_security: Some(String::from(
            "max-age=63072000; includeSubDomains; preload",
        )),
        ..SecurityHeadersConfig::default()
    };
    let message = check_with_security_headers(cfg)
        .expect_err("HSTS with preload must reject")
        .to_string();
    assert!(
        message.contains("strict_transport_security"),
        "error must name the field, got: {message}"
    );
    let lowered = message.to_lowercase();
    assert!(
        lowered.contains("preload"),
        "error must mention `preload`, got: {message}"
    );
}
2996
/// The HSTS `preload` check is case-insensitive: `PRELOAD` is rejected too,
/// and the error names the offending field, just as for the lower-case form.
#[test]
fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
    let h = SecurityHeadersConfig {
        strict_transport_security: Some("max-age=600; PRELOAD".into()),
        ..SecurityHeadersConfig::default()
    };
    let err = check_with_security_headers(h)
        .expect_err("HSTS preload check must be case-insensitive");
    // Pin the field as well, so a rejection coming from some unrelated
    // field or path cannot satisfy this test by accident.
    let msg = err.to_string();
    assert!(
        msg.contains("strict_transport_security"),
        "error must name the field, got: {msg}"
    );
}
3006
3007 #[tokio::test]
3008 async fn security_headers_override_honored() {
3009 let h = SecurityHeadersConfig {
3011 x_frame_options: Some("SAMEORIGIN".into()),
3012 ..SecurityHeadersConfig::default()
3013 };
3014 let app = security_router_with(false, h);
3015 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3016 let resp = app.oneshot(req).await.unwrap();
3017 assert_eq!(resp.status(), StatusCode::OK);
3018
3019 let xfo = resp.headers().get("x-frame-options").unwrap();
3020 assert_eq!(xfo, "SAMEORIGIN");
3021 }
3022
3023 #[tokio::test]
3024 async fn security_headers_empty_string_omits() {
3025 let h = SecurityHeadersConfig {
3027 referrer_policy: Some(String::new()),
3028 ..SecurityHeadersConfig::default()
3029 };
3030 let app = security_router_with(false, h);
3031 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3032 let resp = app.oneshot(req).await.unwrap();
3033 assert_eq!(resp.status(), StatusCode::OK);
3034
3035 assert!(
3036 resp.headers().get("referrer-policy").is_none(),
3037 "Some(\"\") must omit the header"
3038 );
3039 assert_eq!(
3041 resp.headers().get("x-content-type-options").unwrap(),
3042 "nosniff"
3043 );
3044 }
3045
3046 #[tokio::test]
3047 async fn security_headers_hsts_only_when_tls() {
3048 let h = SecurityHeadersConfig {
3050 strict_transport_security: Some("max-age=600".into()),
3051 ..SecurityHeadersConfig::default()
3052 };
3053 let app = security_router_with(false, h);
3054 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3055 let resp = app.oneshot(req).await.unwrap();
3056 assert!(
3057 resp.headers().get("strict-transport-security").is_none(),
3058 "HSTS must remain absent on plaintext deployments even with override"
3059 );
3060 }
3061
/// Token responses get `Pragma: no-cache` and a `Vary` that includes
/// `Authorization`.
#[cfg(feature = "oauth")]
#[tokio::test]
async fn oauth_token_cache_headers_set_pragma_and_vary() {
    let router = axum::Router::new()
        .route("/token", axum::routing::post(|| async { "{}" }))
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        ));
    let request = Request::builder()
        .method("POST")
        .uri("/token")
        .body(Body::from("{}"))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let headers = response.headers();
    assert_eq!(
        headers.get("pragma").unwrap(),
        "no-cache",
        "RFC 6749 §5.1: token responses must set Pragma: no-cache"
    );

    // Gather every `Vary` line that is valid visible ASCII.
    let mut vary_values: Vec<String> = Vec::new();
    for value in headers.get_all("vary") {
        if let Ok(text) = value.to_str() {
            vary_values.push(text.to_owned());
        }
    }
    let has_authorization = vary_values
        .iter()
        .any(|v| v.eq_ignore_ascii_case("Authorization"));
    assert!(
        has_authorization,
        "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
    );
}
3098
/// When the handler already sets `Vary: Accept-Encoding`, the token
/// cache-header middleware adds `Authorization` without dropping the existing
/// entry.
#[cfg(feature = "oauth")]
#[tokio::test]
async fn oauth_token_cache_headers_preserve_existing_vary() {
    let app = axum::Router::new()
        .route(
            "/token",
            axum::routing::post(|| async {
                axum::response::Response::builder()
                    .header("vary", "Accept-Encoding")
                    .body(axum::body::Body::from("{}"))
                    .unwrap()
            }),
        )
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        ));
    let req = Request::builder()
        .method("POST")
        .uri("/token")
        .body(Body::empty())
        .unwrap();
    let resp = app.oneshot(req).await.unwrap();

    let vary: Vec<String> = resp
        .headers()
        .get_all("vary")
        .iter()
        .filter_map(|v| v.to_str().ok().map(str::to_owned))
        .collect();

    // Vary is a comma-separated list of case-insensitive field names, which
    // may be spread over several header lines. Compare whole tokens, not raw
    // substrings: `contains("Authorization")` would wrongly accept
    // `X-Authorization` and wrongly reject a lower-cased `authorization`.
    let lists = |name: &str| {
        vary.iter()
            .flat_map(|line| line.split(','))
            .any(|token| token.trim().eq_ignore_ascii_case(name))
    };
    assert!(
        lists("Accept-Encoding"),
        "must preserve pre-existing Vary value, got {vary:?}"
    );
    assert!(
        lists("Authorization"),
        "must append Authorization to Vary, got {vary:?}"
    );
}
3139
/// `version_payload` echoes the given name and version and includes the
/// build-metadata fields as strings.
#[test]
fn version_payload_contains_expected_fields() {
    let payload = version_payload("my-server", "1.2.3");
    assert_eq!(payload["name"], "my-server");
    assert_eq!(payload["version"], "1.2.3");
    for key in [
        "build_git_sha",
        "build_timestamp",
        "rust_version",
        "mcpx_version",
    ] {
        assert!(payload[key].is_string());
    }
}
3152
3153 #[tokio::test]
3156 async fn concurrency_limit_layer_composes_and_serves() {
3157 let app = axum::Router::new()
3161 .route("/ok", axum::routing::get(|| async { "ok" }))
3162 .layer(
3163 tower::ServiceBuilder::new()
3164 .layer(axum::error_handling::HandleErrorLayer::new(
3165 |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3166 ))
3167 .layer(tower::load_shed::LoadShedLayer::new())
3168 .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3169 );
3170 let resp = app
3171 .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3172 .await
3173 .unwrap();
3174 assert_eq!(resp.status(), StatusCode::OK);
3175 }
3176
3177 #[tokio::test]
3180 async fn compression_layer_gzip_encodes_response() {
3181 use tower_http::compression::Predicate as _;
3182
3183 let big_body = "a".repeat(4096);
3184 let app = axum::Router::new()
3185 .route(
3186 "/big",
3187 axum::routing::get(move || {
3188 let body = big_body.clone();
3189 async move { body }
3190 }),
3191 )
3192 .layer(
3193 tower_http::compression::CompressionLayer::new()
3194 .gzip(true)
3195 .br(true)
3196 .compress_when(
3197 tower_http::compression::DefaultPredicate::new()
3198 .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3199 ),
3200 );
3201
3202 let req = Request::builder()
3203 .uri("/big")
3204 .header(header::ACCEPT_ENCODING, "gzip")
3205 .body(Body::empty())
3206 .unwrap();
3207 let resp = app.oneshot(req).await.unwrap();
3208 assert_eq!(resp.status(), StatusCode::OK);
3209 assert_eq!(
3210 resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3211 "gzip"
3212 );
3213 }
3214}