1use std::{
2 future::Future,
3 net::SocketAddr,
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13 ServerHandler,
14 transport::streamable_http_server::{
15 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16 },
17};
18use rustls::RootCertStore;
19use tokio::{
20 net::TcpListener,
21 sync::{Semaphore, mpsc},
22};
23use tokio_util::sync::CancellationToken;
24
25use crate::{
26 auth::{
27 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
28 build_rate_limiter, extract_mtls_identity,
29 },
30 error::McpxError,
31 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
32 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
33};
34
35#[allow(
39 clippy::needless_pass_by_value,
40 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
41)]
42fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
43 McpxError::Startup(format!("{e:#}"))
44}
45
46#[allow(
52 clippy::needless_pass_by_value,
53 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
54)]
55fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
56 McpxError::Startup(format!("{op}: {e}"))
57}
58
/// Asynchronous readiness probe backing `GET /readyz`.
///
/// The closure is shared (`Arc`) and must be `Send + Sync`; each invocation
/// returns a boxed, `Send` future resolving to a JSON value. When configured
/// via `McpServerConfig::with_readiness_check`, the `/readyz` route hands a
/// clone of this `Arc` to the `readyz` handler on every request; without one,
/// `/readyz` falls back to the plain `healthz` handler.
// NOTE(review): how `readyz` maps the returned JSON to an HTTP status is not
// visible here — confirm against that handler.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
65
/// Per-header overrides for the security response headers added to every
/// response by `security_headers_middleware`.
///
/// Each field corresponds to the like-named HTTP response header. The struct
/// is checked by `validate_security_headers` during
/// `McpServerConfig::validate`. `Default` leaves every field `None`.
// NOTE(review): the exact meaning of `None` (built-in default value vs. header
// omitted) lives in `security_headers_middleware`, which is not visible here —
// confirm there before relying on it.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// Override for `X-Content-Type-Options`.
    pub x_content_type_options: Option<String>,
    /// Override for `X-Frame-Options`.
    pub x_frame_options: Option<String>,
    /// Override for `Cache-Control`.
    pub cache_control: Option<String>,
    /// Override for `Referrer-Policy`.
    pub referrer_policy: Option<String>,
    /// Override for `Cross-Origin-Opener-Policy`.
    pub cross_origin_opener_policy: Option<String>,
    /// Override for `Cross-Origin-Resource-Policy`.
    pub cross_origin_resource_policy: Option<String>,
    /// Override for `Cross-Origin-Embedder-Policy`.
    pub cross_origin_embedder_policy: Option<String>,
    /// Override for `Permissions-Policy`.
    pub permissions_policy: Option<String>,
    /// Override for `X-Permitted-Cross-Domain-Policies`.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// Override for `Content-Security-Policy`.
    pub content_security_policy: Option<String>,
    /// Override for `X-DNS-Prefetch-Control`.
    pub x_dns_prefetch_control: Option<String>,
    /// Override for `Strict-Transport-Security`.
    // NOTE(review): the middleware is told whether TLS is active; presumably
    // HSTS is only emitted over HTTPS — confirm in `security_headers_middleware`.
    pub strict_transport_security: Option<String>,
}
119
/// Configuration for the MCP Streamable HTTP server.
///
/// Build it with [`McpServerConfig::new`] and the `with_*` / `enable_*`
/// builder methods, then call [`McpServerConfig::validate`] to obtain the
/// [`Validated`] wrapper accepted by [`serve`] and [`serve_with_listener`].
///
/// The public fields are deprecated; they remain readable/writable for
/// backward compatibility but will become `pub(crate)` in a future major
/// release.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Listen address. `validate()` requires it to parse as a `SocketAddr`
    /// (e.g. `127.0.0.1:8080`); `serve` binds its `TcpListener` to it.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name, reported by `/version`, the admin API and start-up logs.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version, reported by `/version` and the admin API.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// PEM certificate path for HTTPS. Must be set together with
    /// `tls_key_path` (enforced by `validate()`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// Private-key path for HTTPS. Must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings. The auth middleware is installed on `/mcp`
    /// only when this is `Some` *and* its `enabled` flag is true.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy for `/mcp`. `None` is replaced by `RbacPolicy::disabled()`
    /// when the router is built; the policy can later be hot-swapped through
    /// a [`ReloadHandle`].
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Origins accepted by CORS and the origin-check middleware. Each entry
    /// must start with `http://` or `https://`. When empty and `public_url`
    /// is set, a single origin is derived from `public_url`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Tool-call rate limit in calls per minute (logged as "per IP").
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Probe backing `/readyz`; when `None`, `/readyz` behaves like `/healthz`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes on `/mcp` (default 1 MiB; must be
    /// non-zero).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Per-request timeout on `/mcp`; expiry yields `408 Request Timeout`
    /// (default 2 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Grace period between the shutdown trigger and forced exit
    /// (default 30 seconds).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Keep-alive applied to MCP sessions in the local session manager
    /// (default 20 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval for streamed responses (default 15 seconds).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// Callback handed a [`ReloadHandle`] by `run_server` at start-up, for
    /// hot-reloading API keys, RBAC policy and CRLs.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Additional routes merged into the top-level router. They receive the
    /// outer layers (security headers, CORS, compression, concurrency limit,
    /// origin check) but not the `/mcp` auth/RBAC layers.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL (must start with `http://` or `https://`).
    /// Used for allowed-host derivation, the default CORS origin, and the
    /// OAuth protected-resource metadata URL.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Flag forwarded to `origin_check_middleware`.
    // NOTE(review): presumably enables logging of incoming request headers,
    // which may include credentials — confirm in that middleware.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Responses must exceed this many bytes to be compressed (default 1024).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global in-flight request cap; excess requests get `503` with an
    /// `"overloaded"` JSON body. Must be non-zero when set.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Enables the `/admin/*` endpoints; requires auth to be configured and
    /// enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// Role required by the admin endpoints (default `"admin"`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enables metrics collection and the separate metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Bind address for the metrics listener (default `127.0.0.1:9090`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Security response header overrides applied to all responses.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
}
325
/// Proof that a value has passed validation.
///
/// Produced by [`McpServerConfig::validate`]; the tuple field is private, so
/// code outside this module can only obtain a `Validated<T>` through a
/// validating constructor. Use [`Validated::as_inner`] to inspect the value
/// and [`Validated::into_inner`] to take it back.
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Renders as `Validated { .. }` without requiring `T: Debug` and without
    /// leaking the wrapped contents (e.g. secrets inside a config) into logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Validated").finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        &self.0
    }

    /// Consumes the wrapper and returns the validated value.
    #[must_use]
    pub fn into_inner(self) -> T {
        self.0
    }
}
411
412#[allow(
413 deprecated,
414 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
415)]
416impl McpServerConfig {
417 #[must_use]
425 pub fn new(
426 bind_addr: impl Into<String>,
427 name: impl Into<String>,
428 version: impl Into<String>,
429 ) -> Self {
430 Self {
431 bind_addr: bind_addr.into(),
432 name: name.into(),
433 version: version.into(),
434 tls_cert_path: None,
435 tls_key_path: None,
436 auth: None,
437 rbac: None,
438 allowed_origins: Vec::new(),
439 tool_rate_limit: None,
440 readiness_check: None,
441 max_request_body: 1024 * 1024,
442 request_timeout: Duration::from_mins(2),
443 shutdown_timeout: Duration::from_secs(30),
444 session_idle_timeout: Duration::from_mins(20),
445 sse_keep_alive: Duration::from_secs(15),
446 on_reload_ready: None,
447 extra_router: None,
448 public_url: None,
449 log_request_headers: false,
450 compression_enabled: false,
451 compression_min_size: 1024,
452 max_concurrent_requests: None,
453 admin_enabled: false,
454 admin_role: "admin".to_owned(),
455 #[cfg(feature = "metrics")]
456 metrics_enabled: false,
457 #[cfg(feature = "metrics")]
458 metrics_bind: "127.0.0.1:9090".into(),
459 security_headers: SecurityHeadersConfig::default(),
460 }
461 }
462
463 #[must_use]
473 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
474 self.auth = Some(auth);
475 self
476 }
477
478 #[must_use]
483 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
484 self.security_headers = headers;
485 self
486 }
487
488 #[must_use]
492 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
493 self.bind_addr = addr.into();
494 self
495 }
496
497 #[must_use]
500 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
501 self.rbac = Some(rbac);
502 self
503 }
504
505 #[must_use]
509 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
510 self.tls_cert_path = Some(cert_path.into());
511 self.tls_key_path = Some(key_path.into());
512 self
513 }
514
515 #[must_use]
519 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
520 self.public_url = Some(url.into());
521 self
522 }
523
524 #[must_use]
528 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
529 where
530 I: IntoIterator<Item = S>,
531 S: Into<String>,
532 {
533 self.allowed_origins = origins.into_iter().map(Into::into).collect();
534 self
535 }
536
537 #[must_use]
541 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
542 self.extra_router = Some(router);
543 self
544 }
545
546 #[must_use]
549 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
550 self.readiness_check = Some(check);
551 self
552 }
553
554 #[must_use]
557 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
558 self.max_request_body = bytes;
559 self
560 }
561
562 #[must_use]
564 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
565 self.request_timeout = timeout;
566 self
567 }
568
569 #[must_use]
571 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
572 self.shutdown_timeout = timeout;
573 self
574 }
575
576 #[must_use]
578 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
579 self.session_idle_timeout = timeout;
580 self
581 }
582
583 #[must_use]
585 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
586 self.sse_keep_alive = interval;
587 self
588 }
589
590 #[must_use]
594 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
595 self.max_concurrent_requests = Some(limit);
596 self
597 }
598
599 #[must_use]
602 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
603 self.tool_rate_limit = Some(per_minute);
604 self
605 }
606
607 #[must_use]
611 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
612 where
613 F: FnOnce(ReloadHandle) + Send + 'static,
614 {
615 self.on_reload_ready = Some(Box::new(callback));
616 self
617 }
618
619 #[must_use]
623 pub fn enable_compression(mut self, min_size: u16) -> Self {
624 self.compression_enabled = true;
625 self.compression_min_size = min_size;
626 self
627 }
628
629 #[must_use]
634 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
635 self.admin_enabled = true;
636 self.admin_role = role.into();
637 self
638 }
639
640 #[must_use]
643 pub fn enable_request_header_logging(mut self) -> Self {
644 self.log_request_headers = true;
645 self
646 }
647
648 #[cfg(feature = "metrics")]
651 #[must_use]
652 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
653 self.metrics_enabled = true;
654 self.metrics_bind = bind.into();
655 self
656 }
657
658 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
691 self.check()?;
692 Ok(Validated(self))
693 }
694
695 fn check(&self) -> Result<(), McpxError> {
699 if self.admin_enabled {
703 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
704 if !auth_enabled {
705 return Err(McpxError::Config(
706 "admin_enabled=true requires auth to be configured and enabled".into(),
707 ));
708 }
709 }
710
711 match (&self.tls_cert_path, &self.tls_key_path) {
713 (Some(_), None) => {
714 return Err(McpxError::Config(
715 "tls_cert_path is set but tls_key_path is missing".into(),
716 ));
717 }
718 (None, Some(_)) => {
719 return Err(McpxError::Config(
720 "tls_key_path is set but tls_cert_path is missing".into(),
721 ));
722 }
723 _ => {}
724 }
725
726 if self.bind_addr.parse::<SocketAddr>().is_err() {
728 return Err(McpxError::Config(format!(
729 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
730 self.bind_addr
731 )));
732 }
733
734 if let Some(ref url) = self.public_url
736 && !(url.starts_with("http://") || url.starts_with("https://"))
737 {
738 return Err(McpxError::Config(format!(
739 "public_url {url:?} must start with http:// or https://"
740 )));
741 }
742
743 for origin in &self.allowed_origins {
745 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
746 return Err(McpxError::Config(format!(
747 "allowed_origins entry {origin:?} must start with http:// or https://"
748 )));
749 }
750 }
751
752 if self.max_request_body == 0 {
754 return Err(McpxError::Config(
755 "max_request_body must be greater than zero".into(),
756 ));
757 }
758
759 #[cfg(feature = "oauth")]
761 if let Some(auth_cfg) = &self.auth
762 && let Some(oauth_cfg) = &auth_cfg.oauth
763 {
764 oauth_cfg.validate()?;
765 }
766
767 validate_security_headers(&self.security_headers)?;
770
771 if let Some(0) = self.max_concurrent_requests {
775 return Err(McpxError::Config(
776 "max_concurrent_requests must be greater than zero when set".into(),
777 ));
778 }
779
780 if let Some(auth_cfg) = &self.auth
784 && let Some(rl) = &auth_cfg.rate_limit
785 && rl.max_tracked_keys == 0
786 {
787 return Err(McpxError::Config(
788 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
789 ));
790 }
791
792 Ok(())
793 }
794}
795
/// Handle for hot-reloading server state while it is running.
///
/// Passed to the callback registered with
/// `McpServerConfig::with_reload_callback`. Each component is optional; the
/// corresponding reload method is a no-op or error when it is absent.
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    /// Shared auth state; `None` when auth is not configured/enabled.
    auth: Option<Arc<AuthState>>,
    /// Swappable RBAC policy shared with the RBAC middleware.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    /// mTLS certificate revocation list set; `None` unless CRL support is on.
    crl_set: Option<Arc<CrlSet>>,
}
810
811impl ReloadHandle {
812 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
814 if let Some(ref auth) = self.auth {
815 auth.reload_keys(keys);
816 }
817 }
818
819 pub fn reload_rbac(&self, policy: RbacPolicy) {
821 if let Some(ref rbac) = self.rbac {
822 rbac.store(Arc::new(policy));
823 tracing::info!("RBAC policy reloaded");
824 }
825 }
826
827 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
833 let Some(ref crl_set) = self.crl_set else {
834 return Err(McpxError::Config(
835 "CRL refresh requested but mTLS CRL support is not configured".into(),
836 ));
837 };
838
839 crl_set.force_refresh().await
840 }
841}
842
843#[allow(
860 clippy::too_many_lines,
861 clippy::cognitive_complexity,
862 reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
863)]
864struct AppRunParams {
868 tls_paths: Option<(PathBuf, PathBuf)>,
870 mtls_config: Option<MtlsConfig>,
872 shutdown_timeout: Duration,
874 auth_state: Option<Arc<AuthState>>,
876 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
878 on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
880 ct: CancellationToken,
884 scheme: &'static str,
886 name: String,
888}
889
890#[allow(
900 clippy::cognitive_complexity,
901 reason = "router assembly is intrinsically sequential; splitting harms readability"
902)]
903#[allow(
904 deprecated,
905 reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
906)]
907fn build_app_router<H, F>(
908 mut config: McpServerConfig,
909 handler_factory: F,
910) -> anyhow::Result<(axum::Router, AppRunParams)>
911where
912 H: ServerHandler + 'static,
913 F: Fn() -> H + Send + Sync + Clone + 'static,
914{
915 let ct = CancellationToken::new();
916
917 let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
918 tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
919
920 let mcp_service = StreamableHttpService::new(
921 move || Ok(handler_factory()),
922 {
923 let mut mgr = LocalSessionManager::default();
924 mgr.session_config.keep_alive = Some(config.session_idle_timeout);
925 mgr.into()
926 },
927 StreamableHttpServerConfig::default()
928 .with_allowed_hosts(allowed_hosts)
929 .with_sse_keep_alive(Some(config.sse_keep_alive))
930 .with_cancellation_token(ct.child_token()),
931 );
932
933 let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
935
936 let auth_state: Option<Arc<AuthState>> = match config.auth {
940 Some(ref auth_config) if auth_config.enabled => {
941 let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
942 let pre_auth_limiter = auth_config
943 .rate_limit
944 .as_ref()
945 .map(crate::auth::build_pre_auth_limiter);
946
947 #[cfg(feature = "oauth")]
948 let jwks_cache = auth_config
949 .oauth
950 .as_ref()
951 .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
952 .transpose()
953 .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
954
955 Some(Arc::new(AuthState {
956 api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
957 rate_limiter,
958 pre_auth_limiter,
959 #[cfg(feature = "oauth")]
960 jwks_cache,
961 seen_identities: crate::auth::SeenIdentitySet::new(),
962 counters: crate::auth::AuthCounters::default(),
963 }))
964 }
965 _ => None,
966 };
967
968 let rbac_swap = Arc::new(ArcSwap::new(
971 config
972 .rbac
973 .clone()
974 .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
975 ));
976
977 if config.admin_enabled {
980 let Some(ref auth_state_ref) = auth_state else {
981 return Err(anyhow::anyhow!(
982 "admin_enabled=true requires auth to be configured and enabled"
983 ));
984 };
985 let admin_state = crate::admin::AdminState {
986 started_at: std::time::Instant::now(),
987 name: config.name.clone(),
988 version: config.version.clone(),
989 auth: Some(Arc::clone(auth_state_ref)),
990 rbac: Arc::clone(&rbac_swap),
991 };
992 let admin_cfg = crate::admin::AdminConfig {
993 role: config.admin_role.clone(),
994 };
995 mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
996 tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
997 }
998
999 {
1032 let tool_limiter: Option<Arc<ToolRateLimiter>> =
1033 config.tool_rate_limit.map(build_tool_rate_limiter);
1034
1035 if rbac_swap.load().is_enabled() {
1036 tracing::info!("RBAC enforcement enabled on /mcp");
1037 }
1038 if let Some(limit) = config.tool_rate_limit {
1039 tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1040 }
1041
1042 let rbac_for_mw = Arc::clone(&rbac_swap);
1043 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1044 let p = rbac_for_mw.load_full();
1045 let tl = tool_limiter.clone();
1046 rbac_middleware(p, tl, req, next)
1047 }));
1048 }
1049
1050 if let Some(ref auth_config) = config.auth
1052 && auth_config.enabled
1053 {
1054 let Some(ref state) = auth_state else {
1055 return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1056 };
1057
1058 let methods: Vec<&str> = [
1059 auth_config.mtls.is_some().then_some("mTLS"),
1060 (!auth_config.api_keys.is_empty()).then_some("bearer"),
1061 #[cfg(feature = "oauth")]
1062 auth_config.oauth.is_some().then_some("oauth-jwt"),
1063 ]
1064 .into_iter()
1065 .flatten()
1066 .collect();
1067
1068 tracing::info!(
1069 methods = %methods.join(", "),
1070 api_keys = auth_config.api_keys.len(),
1071 "auth enabled on /mcp"
1072 );
1073
1074 let state_for_mw = Arc::clone(state);
1075 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1076 let s = Arc::clone(&state_for_mw);
1077 auth_middleware(s, req, next)
1078 }));
1079 }
1080
1081 mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1084 axum::http::StatusCode::REQUEST_TIMEOUT,
1085 config.request_timeout,
1086 ));
1087
1088 mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1092 config.max_request_body,
1093 ));
1094
1095 let mut effective_origins = config.allowed_origins.clone();
1102 if effective_origins.is_empty()
1103 && let Some(ref url) = config.public_url
1104 {
1105 if let Some(scheme_end) = url.find("://") {
1110 let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1111 let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1112 let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1113 let host = after_scheme.get(..host_end).unwrap_or_default();
1114 let origin = format!("{scheme_with_sep}{host}");
1115 tracing::info!(
1116 %origin,
1117 "auto-derived allowed origin from public_url"
1118 );
1119 effective_origins.push(origin);
1120 }
1121 }
1122 let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1123 let cors_origins = Arc::clone(&allowed_origins);
1124 let log_request_headers = config.log_request_headers;
1125
1126 let readyz_route = if let Some(check) = config.readiness_check.take() {
1127 axum::routing::get(move || readyz(Arc::clone(&check)))
1128 } else {
1129 axum::routing::get(healthz)
1130 };
1131
1132 #[allow(unused_mut)] let mut router = axum::Router::new()
1134 .route("/healthz", axum::routing::get(healthz))
1135 .route("/readyz", readyz_route)
1136 .route(
1137 "/version",
1138 axum::routing::get({
1139 let payload_bytes: Arc<[u8]> =
1144 serialize_version_payload(&config.name, &config.version);
1145 move || {
1146 let p = Arc::clone(&payload_bytes);
1147 async move {
1148 (
1149 [(axum::http::header::CONTENT_TYPE, "application/json")],
1150 p.to_vec(),
1151 )
1152 }
1153 }
1154 }),
1155 )
1156 .merge(mcp_router);
1157
1158 if let Some(extra) = config.extra_router.take() {
1160 router = router.merge(extra);
1161 }
1162
1163 let server_url = if let Some(ref url) = config.public_url {
1170 url.trim_end_matches('/').to_owned()
1171 } else {
1172 let prm_scheme = if config.tls_cert_path.is_some() {
1173 "https"
1174 } else {
1175 "http"
1176 };
1177 format!("{prm_scheme}://{}", config.bind_addr)
1178 };
1179 let resource_url = format!("{server_url}/mcp");
1180
1181 #[cfg(feature = "oauth")]
1182 let prm_metadata = if let Some(ref auth_config) = config.auth
1183 && let Some(ref oauth_config) = auth_config.oauth
1184 {
1185 crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1186 } else {
1187 serde_json::json!({ "resource": resource_url })
1188 };
1189 #[cfg(not(feature = "oauth"))]
1190 let prm_metadata = serde_json::json!({ "resource": resource_url });
1191
1192 router = router.route(
1193 "/.well-known/oauth-protected-resource",
1194 axum::routing::get(move || {
1195 let m = prm_metadata.clone();
1196 async move { axum::Json(m) }
1197 }),
1198 );
1199
1200 #[cfg(feature = "oauth")]
1205 if let Some(ref auth_config) = config.auth
1206 && let Some(ref oauth_config) = auth_config.oauth
1207 && oauth_config.proxy.is_some()
1208 {
1209 router =
1210 install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1211 }
1212
1213 let is_tls = config.tls_cert_path.is_some();
1216 let security_headers_cfg = Arc::new(config.security_headers.clone());
1217 router = router.layer(axum::middleware::from_fn(move |req, next| {
1218 let cfg = Arc::clone(&security_headers_cfg);
1219 security_headers_middleware(is_tls, cfg, req, next)
1220 }));
1221
1222 if !cors_origins.is_empty() {
1226 let cors = tower_http::cors::CorsLayer::new()
1227 .allow_origin(
1228 cors_origins
1229 .iter()
1230 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1231 .collect::<Vec<_>>(),
1232 )
1233 .allow_methods([
1234 axum::http::Method::GET,
1235 axum::http::Method::POST,
1236 axum::http::Method::OPTIONS,
1237 ])
1238 .allow_headers([
1239 axum::http::header::CONTENT_TYPE,
1240 axum::http::header::AUTHORIZATION,
1241 ]);
1242 router = router.layer(cors);
1243 }
1244
1245 if config.compression_enabled {
1249 use tower_http::compression::Predicate as _;
1250 let predicate = tower_http::compression::DefaultPredicate::new().and(
1251 tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1252 );
1253 router = router.layer(
1254 tower_http::compression::CompressionLayer::new()
1255 .gzip(true)
1256 .br(true)
1257 .compress_when(predicate),
1258 );
1259 tracing::info!(
1260 min_size = config.compression_min_size,
1261 "response compression enabled (gzip, br)"
1262 );
1263 }
1264
1265 if let Some(max) = config.max_concurrent_requests {
1268 let overload_handler = tower::ServiceBuilder::new()
1269 .layer(axum::error_handling::HandleErrorLayer::new(
1270 |_err: tower::BoxError| async {
1271 (
1272 axum::http::StatusCode::SERVICE_UNAVAILABLE,
1273 axum::Json(serde_json::json!({
1274 "error": "overloaded",
1275 "error_description": "server is at capacity, retry later"
1276 })),
1277 )
1278 },
1279 ))
1280 .layer(tower::load_shed::LoadShedLayer::new())
1281 .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1282 router = router.layer(overload_handler);
1283 tracing::info!(max, "global concurrency limit enabled");
1284 }
1285
1286 router = router.fallback(|| async {
1290 (
1291 axum::http::StatusCode::NOT_FOUND,
1292 axum::Json(serde_json::json!({
1293 "error": "not_found",
1294 "error_description": "The requested endpoint does not exist"
1295 })),
1296 )
1297 });
1298
1299 #[cfg(feature = "metrics")]
1301 if config.metrics_enabled {
1302 let metrics = Arc::new(
1303 crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1304 );
1305 let m = Arc::clone(&metrics);
1306 router = router.layer(axum::middleware::from_fn(
1307 move |req: Request<Body>, next: Next| {
1308 let m = Arc::clone(&m);
1309 metrics_middleware(m, req, next)
1310 },
1311 ));
1312 let metrics_bind = config.metrics_bind.clone();
1313 let metrics_shutdown = ct.clone();
1314 tokio::spawn(async move {
1315 if let Err(e) =
1316 crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1317 {
1318 tracing::error!("metrics listener failed: {e}");
1319 }
1320 });
1321 }
1322
1323 router = router.layer(axum::middleware::from_fn(move |req, next| {
1334 let origins = Arc::clone(&allowed_origins);
1335 origin_check_middleware(origins, log_request_headers, req, next)
1336 }));
1337
1338 let scheme = if config.tls_cert_path.is_some() {
1339 "https"
1340 } else {
1341 "http"
1342 };
1343
1344 let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1345 (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1346 _ => None,
1347 };
1348 let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1349
1350 Ok((
1351 router,
1352 AppRunParams {
1353 tls_paths,
1354 mtls_config,
1355 shutdown_timeout: config.shutdown_timeout,
1356 auth_state,
1357 rbac_swap,
1358 on_reload_ready: config.on_reload_ready.take(),
1359 ct,
1360 scheme,
1361 name: config.name.clone(),
1362 },
1363 ))
1364}
1365
1366pub async fn serve<H, F>(
1383 config: Validated<McpServerConfig>,
1384 handler_factory: F,
1385) -> Result<(), McpxError>
1386where
1387 H: ServerHandler + 'static,
1388 F: Fn() -> H + Send + Sync + Clone + 'static,
1389{
1390 let config = config.into_inner();
1391 #[allow(
1392 deprecated,
1393 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1394 )]
1395 let bind_addr = config.bind_addr.clone();
1396 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1397
1398 let listener = TcpListener::bind(&bind_addr)
1399 .await
1400 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1401 log_listening(¶ms.name, params.scheme, &bind_addr);
1402
1403 run_server(
1404 router,
1405 listener,
1406 params.tls_paths,
1407 params.mtls_config,
1408 params.shutdown_timeout,
1409 params.auth_state,
1410 params.rbac_swap,
1411 params.on_reload_ready,
1412 params.ct,
1413 )
1414 .await
1415 .map_err(anyhow_to_startup)
1416}
1417
1418pub async fn serve_with_listener<H, F>(
1448 listener: TcpListener,
1449 config: Validated<McpServerConfig>,
1450 handler_factory: F,
1451 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1452 shutdown: Option<CancellationToken>,
1453) -> Result<(), McpxError>
1454where
1455 H: ServerHandler + 'static,
1456 F: Fn() -> H + Send + Sync + Clone + 'static,
1457{
1458 let config = config.into_inner();
1459 let local_addr = listener
1460 .local_addr()
1461 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1462 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1463
1464 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1465
1466 if let Some(external) = shutdown {
1470 let internal = params.ct.clone();
1471 tokio::spawn(async move {
1472 external.cancelled().await;
1473 internal.cancel();
1474 });
1475 }
1476
1477 if let Some(tx) = ready_tx {
1481 let _ = tx.send(local_addr);
1483 }
1484
1485 run_server(
1486 router,
1487 listener,
1488 params.tls_paths,
1489 params.mtls_config,
1490 params.shutdown_timeout,
1491 params.auth_state,
1492 params.rbac_swap,
1493 params.on_reload_ready,
1494 params.ct,
1495 )
1496 .await
1497 .map_err(anyhow_to_startup)
1498}
1499
1500#[allow(
1503 clippy::cognitive_complexity,
1504 reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1505)]
1506fn log_listening(name: &str, scheme: &str, addr: &str) {
1507 tracing::info!("{name} listening on {addr}");
1508 tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
1509 tracing::info!(" Health check: {scheme}://{addr}/healthz");
1510 tracing::info!(" Readiness: {scheme}://{addr}/readyz");
1511}
1512
1513#[allow(
1536 clippy::too_many_arguments,
1537 clippy::cognitive_complexity,
1538 reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1539)]
1540async fn run_server(
1541 router: axum::Router,
1542 listener: TcpListener,
1543 tls_paths: Option<(PathBuf, PathBuf)>,
1544 mtls_config: Option<MtlsConfig>,
1545 shutdown_timeout: Duration,
1546 auth_state: Option<Arc<AuthState>>,
1547 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1548 mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1549 ct: CancellationToken,
1550) -> anyhow::Result<()> {
1551 let shutdown_trigger = CancellationToken::new();
1555 {
1556 let trigger = shutdown_trigger.clone();
1557 let parent = ct.clone();
1558 tokio::spawn(async move {
1559 tokio::select! {
1560 () = shutdown_signal() => {}
1561 () = parent.cancelled() => {}
1562 }
1563 trigger.cancel();
1564 });
1565 }
1566
1567 let graceful = {
1568 let trigger = shutdown_trigger.clone();
1569 let ct = ct.clone();
1570 async move {
1571 trigger.cancelled().await;
1572 tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1573 ct.cancel();
1574 }
1575 };
1576
1577 let force_exit_timer = {
1578 let trigger = shutdown_trigger.clone();
1579 async move {
1580 trigger.cancelled().await;
1581 tokio::time::sleep(shutdown_timeout).await;
1582 }
1583 };
1584
1585 if let Some((cert_path, key_path)) = tls_paths {
1586 let crl_set = if let Some(mtls) = mtls_config.as_ref()
1587 && mtls.crl_enabled
1588 {
1589 let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1590 let (crl_set, discover_rx) =
1591 mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1592 .await
1593 .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1594 tokio::spawn(mtls_revocation::run_crl_refresher(
1595 Arc::clone(&crl_set),
1596 discover_rx,
1597 ct.clone(),
1598 ));
1599 Some(crl_set)
1600 } else {
1601 None
1602 };
1603
1604 if let Some(cb) = on_reload_ready.take() {
1605 cb(ReloadHandle {
1606 auth: auth_state.clone(),
1607 rbac: Some(Arc::clone(&rbac_swap)),
1608 crl_set: crl_set.clone(),
1609 });
1610 }
1611
1612 let tls_listener = TlsListener::new(
1613 listener,
1614 &cert_path,
1615 &key_path,
1616 mtls_config.as_ref(),
1617 crl_set,
1618 TLS_HANDSHAKE_TIMEOUT,
1619 )?;
1620 let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1621 tokio::select! {
1622 result = axum::serve(tls_listener, make_svc)
1623 .with_graceful_shutdown(graceful) => { result?; }
1624 () = force_exit_timer => {
1625 tracing::warn!("shutdown timeout exceeded, forcing exit");
1626 }
1627 }
1628 } else {
1629 if let Some(cb) = on_reload_ready.take() {
1630 cb(ReloadHandle {
1631 auth: auth_state,
1632 rbac: Some(rbac_swap),
1633 crl_set: None,
1634 });
1635 }
1636
1637 let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1638 tokio::select! {
1639 result = axum::serve(listener, make_svc)
1640 .with_graceful_shutdown(graceful) => { result?; }
1641 () = force_exit_timer => {
1642 tracing::warn!("shutdown timeout exceeded, forcing exit");
1643 }
1644 }
1645 }
1646
1647 Ok(())
1648}
1649
/// Mounts the OAuth 2.1 proxy routes on `router` when `oauth_config.proxy` is
/// configured, and returns `router` unchanged otherwise.
///
/// The following routes are always added:
/// * `/.well-known/oauth-authorization-server` — authorization-server metadata.
/// * `/authorize` — proxied authorization request.
/// * `/token` and `/register` — these two are wrapped in the OAuth
///   cache-header middleware.
///
/// `/introspect` and `/revoke` are added (through `build_oauth_admin_router`)
/// only when `expose_admin_endpoints` is set and the matching upstream URL is
/// configured.
///
/// # Errors
///
/// Propagates failures from building the upstream OAuth HTTP client and from
/// `build_oauth_admin_router` (for example, admin auth required but no auth
/// state available).
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            // Clone like the sibling handlers so the closure stays reusable.
            let p = proxy_register.clone();
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Only warn when admin routes are actually mounted. With no introspection
    // or revocation URL there is nothing unauthenticated to warn about.
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1750
/// Builds the router holding the OAuth proxy admin endpoints.
///
/// `/introspect` is added when `proxy.introspection_url` is set, and
/// `/revoke` when `proxy.revocation_url` is set. Every admin route gets the
/// OAuth cache-header middleware. When `require_auth_on_admin_endpoints` is
/// set, the routes are additionally wrapped in `auth_middleware`, using
/// `auth_state`.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when admin auth is required but
/// `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    // Resolve the auth requirement first; building routes has no side effects,
    // so failing early is observably identical.
    let guard_state = if proxy.require_auth_on_admin_endpoints {
        let state = auth_state.ok_or_else(|| {
            McpxError::Startup("oauth proxy admin endpoints require auth state".into())
        })?;
        Some(Arc::clone(state))
    } else {
        None
    };

    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let cfg = proxy.clone();
        let client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let (cfg, client) = (cfg.clone(), client.clone());
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let cfg = proxy.clone();
        let client = http;
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let (cfg, client) = (cfg.clone(), client.clone());
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    Ok(match guard_state {
        Some(state) => routes.layer(axum::middleware::from_fn(move |req, next| {
            auth_middleware(Arc::clone(&state), req, next)
        })),
        None => routes,
    })
}
1809
1810fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1815 let mut hosts = vec![
1816 "localhost".to_owned(),
1817 "127.0.0.1".to_owned(),
1818 "::1".to_owned(),
1819 ];
1820
1821 if let Some(url) = public_url
1822 && let Ok(uri) = url.parse::<axum::http::Uri>()
1823 && let Some(authority) = uri.authority()
1824 {
1825 let host = authority.host().to_owned();
1826 if !hosts.iter().any(|h| h == &host) {
1827 hosts.push(host);
1828 }
1829
1830 let authority = authority.as_str().to_owned();
1831 if !hosts.iter().any(|h| h == &authority) {
1832 hosts.push(authority);
1833 }
1834 }
1835
1836 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1837 && let Some(authority) = uri.authority()
1838 {
1839 let host = authority.host().to_owned();
1840 if !hosts.iter().any(|h| h == &host) {
1841 hosts.push(host);
1842 }
1843
1844 let authority = authority.as_str().to_owned();
1845 if !hosts.iter().any(|h| h == &authority) {
1846 hosts.push(authority);
1847 }
1848 }
1849
1850 hosts
1851}
1852
1853impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1866 for TlsConnInfo
1867{
1868 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1869 let addr = *target.remote_addr();
1870 let identity = target.io().identity().cloned();
1871 TlsConnInfo::new(addr, identity)
1872 }
1873}
1874
/// Upper bound on a single client's TLS handshake. The acceptor task drops
/// handshakes that do not finish in time.
const TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Maximum number of TLS handshakes in flight at once. While every permit is
/// held, the acceptor task stops pulling new TCP connections off the listener.
const MAX_INFLIGHT_TLS_HANDSHAKES: usize = 256;

/// Capacity of the channel that hands completed TLS streams from the acceptor
/// task to `TlsListener::accept`.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
1896
/// `axum::serve::Listener` that yields TLS streams whose handshake has
/// already completed.
///
/// TCP accepts and TLS handshakes run on a background task
/// (`run_tls_acceptor`). Completed streams arrive through `rx`.
struct TlsListener {
    /// Address the underlying TCP listener is bound to. It is captured at
    /// construction, so `local_addr` never touches the socket again.
    local_addr: SocketAddr,
    /// Receiving end of the channel the acceptor task fills with
    /// `(stream, peer address)` pairs.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Handle to the acceptor task. It is aborted when the listener is dropped.
    acceptor_task: tokio::task::JoinHandle<()>,
}
1920
1921impl TlsListener {
1922 fn new(
1923 inner: TcpListener,
1924 cert_path: &Path,
1925 key_path: &Path,
1926 mtls_config: Option<&MtlsConfig>,
1927 crl_set: Option<Arc<CrlSet>>,
1928 handshake_timeout: Duration,
1929 ) -> anyhow::Result<Self> {
1930 rustls::crypto::ring::default_provider()
1932 .install_default()
1933 .ok();
1934
1935 let certs = load_certs(cert_path)?;
1936 let key = load_key(key_path)?;
1937
1938 let mtls_default_role;
1939
1940 let tls_config = if let Some(mtls) = mtls_config {
1941 mtls_default_role = mtls.default_role.clone();
1942 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1943 {
1944 let Some(crl_set) = crl_set else {
1945 return Err(anyhow::anyhow!(
1946 "mTLS CRL verifier requested but CRL state was not initialized"
1947 ));
1948 };
1949 Arc::new(DynamicClientCertVerifier::new(crl_set))
1950 } else {
1951 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1952 if mtls.required {
1953 rustls::server::WebPkiClientVerifier::builder(root_store)
1954 .build()
1955 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1956 } else {
1957 rustls::server::WebPkiClientVerifier::builder(root_store)
1958 .allow_unauthenticated()
1959 .build()
1960 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1961 }
1962 };
1963
1964 tracing::info!(
1965 ca = %mtls.ca_cert_path.display(),
1966 required = mtls.required,
1967 crl_enabled = mtls.crl_enabled,
1968 "mTLS client auth configured"
1969 );
1970
1971 rustls::ServerConfig::builder_with_protocol_versions(&[
1972 &rustls::version::TLS12,
1973 &rustls::version::TLS13,
1974 ])
1975 .with_client_cert_verifier(verifier)
1976 .with_single_cert(certs, key)?
1977 } else {
1978 mtls_default_role = "viewer".to_owned();
1979 rustls::ServerConfig::builder_with_protocol_versions(&[
1980 &rustls::version::TLS12,
1981 &rustls::version::TLS13,
1982 ])
1983 .with_no_client_auth()
1984 .with_single_cert(certs, key)?
1985 };
1986
1987 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1988 tracing::info!(
1989 "TLS enabled (cert: {}, key: {})",
1990 cert_path.display(),
1991 key_path.display()
1992 );
1993 let local_addr = inner.local_addr()?;
1994 let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
1995 let acceptor_task = tokio::spawn(run_tls_acceptor(
1996 inner,
1997 acceptor,
1998 mtls_default_role,
1999 tx,
2000 handshake_timeout,
2001 ));
2002 Ok(Self {
2003 local_addr,
2004 rx,
2005 acceptor_task,
2006 })
2007 }
2008
2009 fn extract_handshake_identity(
2013 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2014 default_role: &str,
2015 addr: SocketAddr,
2016 ) -> Option<AuthIdentity> {
2017 let (_, server_conn) = tls_stream.get_ref();
2018 let cert_der = server_conn.peer_certificates()?.first()?;
2019 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2020 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2021 Some(id)
2022 }
2023}
2024
2025async fn run_tls_acceptor(
2033 listener: TcpListener,
2034 acceptor: tokio_rustls::TlsAcceptor,
2035 default_role: String,
2036 tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2037 handshake_timeout: Duration,
2038) {
2039 let inflight = Arc::new(Semaphore::new(MAX_INFLIGHT_TLS_HANDSHAKES));
2040 loop {
2041 let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2045 return;
2047 };
2048 let (stream, addr) = match listener.accept().await {
2049 Ok(pair) => pair,
2050 Err(e) => {
2051 tracing::debug!("TCP accept error: {e}");
2052 continue;
2053 }
2054 };
2055 if tx.is_closed() {
2056 return;
2058 }
2059 let acceptor = acceptor.clone();
2060 let default_role = default_role.clone();
2061 let tx = tx.clone();
2062 tokio::spawn(async move {
2063 let _permit = permit;
2064 match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2065 Ok(Ok(tls_stream)) => {
2066 let identity =
2067 TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2068 let wrapped = AuthenticatedTlsStream {
2069 inner: tls_stream,
2070 identity,
2071 };
2072 let _ = tx.send((wrapped, addr)).await;
2075 }
2076 Ok(Err(e)) => {
2077 tracing::debug!("TLS handshake failed from {addr}: {e}");
2078 }
2079 Err(_elapsed) => {
2080 tracing::debug!(
2081 "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2082 );
2083 }
2084 }
2085 });
2086 }
2087}
2088
/// Server-side TLS stream paired with the client identity extracted from its
/// certificate during the handshake, if any.
pub(crate) struct AuthenticatedTlsStream {
    /// The underlying rustls server stream. All async I/O is delegated to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity derived from the peer's leaf certificate. `None` when no
    /// certificate was presented or `extract_mtls_identity` returned `None`.
    identity: Option<AuthIdentity>,
}
2104
2105impl AuthenticatedTlsStream {
2106 #[must_use]
2108 pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
2109 self.identity.as_ref()
2110 }
2111}
2112
2113impl std::fmt::Debug for AuthenticatedTlsStream {
2114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2115 f.debug_struct("AuthenticatedTlsStream")
2116 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2117 .finish_non_exhaustive()
2118 }
2119}
2120
2121impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2122 fn poll_read(
2123 mut self: Pin<&mut Self>,
2124 cx: &mut std::task::Context<'_>,
2125 buf: &mut tokio::io::ReadBuf<'_>,
2126 ) -> std::task::Poll<std::io::Result<()>> {
2127 Pin::new(&mut self.inner).poll_read(cx, buf)
2128 }
2129}
2130
2131impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2132 fn poll_write(
2133 mut self: Pin<&mut Self>,
2134 cx: &mut std::task::Context<'_>,
2135 buf: &[u8],
2136 ) -> std::task::Poll<std::io::Result<usize>> {
2137 Pin::new(&mut self.inner).poll_write(cx, buf)
2138 }
2139
2140 fn poll_flush(
2141 mut self: Pin<&mut Self>,
2142 cx: &mut std::task::Context<'_>,
2143 ) -> std::task::Poll<std::io::Result<()>> {
2144 Pin::new(&mut self.inner).poll_flush(cx)
2145 }
2146
2147 fn poll_shutdown(
2148 mut self: Pin<&mut Self>,
2149 cx: &mut std::task::Context<'_>,
2150 ) -> std::task::Poll<std::io::Result<()>> {
2151 Pin::new(&mut self.inner).poll_shutdown(cx)
2152 }
2153
2154 fn poll_write_vectored(
2155 mut self: Pin<&mut Self>,
2156 cx: &mut std::task::Context<'_>,
2157 bufs: &[std::io::IoSlice<'_>],
2158 ) -> std::task::Poll<std::io::Result<usize>> {
2159 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2160 }
2161
2162 fn is_write_vectored(&self) -> bool {
2163 self.inner.is_write_vectored()
2164 }
2165}
2166
2167impl axum::serve::Listener for TlsListener {
2168 type Io = AuthenticatedTlsStream;
2169 type Addr = SocketAddr;
2170
2171 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2177 if let Some(pair) = self.rx.recv().await {
2178 return pair;
2179 }
2180 tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2186 std::future::pending().await
2187 }
2188
2189 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2190 Ok(self.local_addr)
2191 }
2192}
2193
2194impl Drop for TlsListener {
2195 fn drop(&mut self) {
2196 self.acceptor_task.abort();
2199 }
2200}
2201
2202fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2203 use rustls::pki_types::pem::PemObject;
2204 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2205 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2206 .collect::<Result<_, _>>()
2207 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2208 anyhow::ensure!(
2209 !certs.is_empty(),
2210 "no certificates found in {}",
2211 path.display()
2212 );
2213 Ok(certs)
2214}
2215
2216fn load_client_auth_roots(
2217 path: &Path,
2218) -> anyhow::Result<(
2219 Vec<rustls::pki_types::CertificateDer<'static>>,
2220 Arc<RootCertStore>,
2221)> {
2222 let ca_certs = load_certs(path)?;
2223 let mut root_store = RootCertStore::empty();
2224 for cert in &ca_certs {
2225 root_store
2226 .add(cert.clone())
2227 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2228 }
2229
2230 Ok((ca_certs, Arc::new(root_store)))
2231}
2232
2233fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2234 use rustls::pki_types::pem::PemObject;
2235 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2236 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2237}
2238
2239#[allow(
2240 clippy::unused_async,
2241 reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2242)]
2243async fn healthz() -> impl IntoResponse {
2244 axum::Json(serde_json::json!({
2245 "status": "ok",
2246 }))
2247}
2248
2249fn version_payload(name: &str, version: &str) -> serde_json::Value {
2256 serde_json::json!({
2257 "name": name,
2258 "version": version,
2259 "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2260 "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2261 "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2262 "mcpx_version": env!("CARGO_PKG_VERSION"),
2263 })
2264}
2265
2266fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2276 let value = version_payload(name, version);
2277 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2278}
2279
2280async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2281 let status = check().await;
2282 let ready = status
2283 .get("ready")
2284 .and_then(serde_json::Value::as_bool)
2285 .unwrap_or(false);
2286 let code = if ready {
2287 axum::http::StatusCode::OK
2288 } else {
2289 axum::http::StatusCode::SERVICE_UNAVAILABLE
2290 };
2291 (code, axum::Json(status))
2292}
2293
2294async fn shutdown_signal() {
2298 let ctrl_c = tokio::signal::ctrl_c();
2299
2300 #[cfg(unix)]
2301 {
2302 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2303 Ok(mut term) => {
2304 tokio::select! {
2305 _ = ctrl_c => {}
2306 _ = term.recv() => {}
2307 }
2308 }
2309 Err(e) => {
2310 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2311 ctrl_c.await.ok();
2312 }
2313 }
2314 }
2315
2316 #[cfg(not(unix))]
2317 {
2318 ctrl_c.await.ok();
2319 }
2320}
2321
/// Records per-request HTTP metrics:
/// * a request counter labelled by method, route and status;
/// * a latency histogram labelled by method and route.
///
/// Label values are drawn from a bounded set, so clients cannot blow up the
/// metrics registry with arbitrary label combinations:
/// * `path` is the matched route template from
///   [`axum::extract::MatchedPath`] (e.g. `/mcp`), or `"<unmatched>"` when the
///   request did not match a route;
/// * `method` is the standard HTTP method name, or `"OTHER"` for extension
///   methods.
///
/// NOTE(review): `MatchedPath` is only present when this middleware runs
/// after routing (attached via `Router::layer`/`route_layer`). Confirm this
/// against the router construction.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    const UNMATCHED_PATH: &str = "<unmatched>";
    const OTHER_METHOD: &str = "OTHER";

    let method = match req.method().as_str() {
        known @ ("GET" | "HEAD" | "POST" | "PUT" | "DELETE" | "CONNECT" | "OPTIONS"
        | "TRACE" | "PATCH") => known.to_owned(),
        _ => OTHER_METHOD.to_owned(),
    };
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or(UNMATCHED_PATH, axum::extract::MatchedPath::as_str)
        .to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2353
2354async fn security_headers_middleware(
2366 is_tls: bool,
2367 cfg: Arc<SecurityHeadersConfig>,
2368 req: Request<Body>,
2369 next: Next,
2370) -> axum::response::Response {
2371 use axum::http::{HeaderName, header};
2372
2373 let mut resp = next.run(req).await;
2374 let headers = resp.headers_mut();
2375
2376 headers.remove(header::SERVER);
2378 headers.remove(HeaderName::from_static("x-powered-by"));
2379
2380 apply_security_header(
2381 headers,
2382 header::X_CONTENT_TYPE_OPTIONS,
2383 cfg.x_content_type_options.as_deref(),
2384 "nosniff",
2385 );
2386 apply_security_header(
2387 headers,
2388 header::X_FRAME_OPTIONS,
2389 cfg.x_frame_options.as_deref(),
2390 "deny",
2391 );
2392 apply_security_header(
2393 headers,
2394 header::CACHE_CONTROL,
2395 cfg.cache_control.as_deref(),
2396 "no-store, max-age=0",
2397 );
2398 apply_security_header(
2399 headers,
2400 header::REFERRER_POLICY,
2401 cfg.referrer_policy.as_deref(),
2402 "no-referrer",
2403 );
2404 apply_security_header(
2405 headers,
2406 HeaderName::from_static("cross-origin-opener-policy"),
2407 cfg.cross_origin_opener_policy.as_deref(),
2408 "same-origin",
2409 );
2410 apply_security_header(
2411 headers,
2412 HeaderName::from_static("cross-origin-resource-policy"),
2413 cfg.cross_origin_resource_policy.as_deref(),
2414 "same-origin",
2415 );
2416 apply_security_header(
2417 headers,
2418 HeaderName::from_static("cross-origin-embedder-policy"),
2419 cfg.cross_origin_embedder_policy.as_deref(),
2420 "require-corp",
2421 );
2422 apply_security_header(
2423 headers,
2424 HeaderName::from_static("permissions-policy"),
2425 cfg.permissions_policy.as_deref(),
2426 "accelerometer=(), camera=(), geolocation=(), microphone=()",
2427 );
2428 apply_security_header(
2429 headers,
2430 HeaderName::from_static("x-permitted-cross-domain-policies"),
2431 cfg.x_permitted_cross_domain_policies.as_deref(),
2432 "none",
2433 );
2434 apply_security_header(
2435 headers,
2436 HeaderName::from_static("content-security-policy"),
2437 cfg.content_security_policy.as_deref(),
2438 "default-src 'none'; frame-ancestors 'none'",
2439 );
2440 apply_security_header(
2441 headers,
2442 HeaderName::from_static("x-dns-prefetch-control"),
2443 cfg.x_dns_prefetch_control.as_deref(),
2444 "off",
2445 );
2446
2447 if is_tls {
2448 apply_security_header(
2449 headers,
2450 header::STRICT_TRANSPORT_SECURITY,
2451 cfg.strict_transport_security.as_deref(),
2452 "max-age=63072000; includeSubDomains",
2453 );
2454 }
2455
2456 resp
2457}
2458
2459fn apply_security_header(
2470 headers: &mut axum::http::HeaderMap,
2471 name: axum::http::HeaderName,
2472 override_value: Option<&str>,
2473 default: &'static str,
2474) {
2475 use axum::http::HeaderValue;
2476
2477 match override_value {
2478 None => {
2479 headers.insert(name, HeaderValue::from_static(default));
2480 }
2481 Some("") => {
2482 }
2484 Some(v) => match HeaderValue::from_str(v) {
2485 Ok(hv) => {
2486 headers.insert(name, hv);
2487 }
2488 Err(err) => {
2489 tracing::error!(
2490 header = %name,
2491 error = %err,
2492 "invalid security header override reached middleware; using default"
2493 );
2494 headers.insert(name, HeaderValue::from_static(default));
2495 }
2496 },
2497 }
2498}
2499
2500fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2511 use axum::http::HeaderValue;
2512
2513 let fields: &[(&str, Option<&str>)] = &[
2514 (
2515 "x_content_type_options",
2516 cfg.x_content_type_options.as_deref(),
2517 ),
2518 ("x_frame_options", cfg.x_frame_options.as_deref()),
2519 ("cache_control", cfg.cache_control.as_deref()),
2520 ("referrer_policy", cfg.referrer_policy.as_deref()),
2521 (
2522 "cross_origin_opener_policy",
2523 cfg.cross_origin_opener_policy.as_deref(),
2524 ),
2525 (
2526 "cross_origin_resource_policy",
2527 cfg.cross_origin_resource_policy.as_deref(),
2528 ),
2529 (
2530 "cross_origin_embedder_policy",
2531 cfg.cross_origin_embedder_policy.as_deref(),
2532 ),
2533 ("permissions_policy", cfg.permissions_policy.as_deref()),
2534 (
2535 "x_permitted_cross_domain_policies",
2536 cfg.x_permitted_cross_domain_policies.as_deref(),
2537 ),
2538 (
2539 "content_security_policy",
2540 cfg.content_security_policy.as_deref(),
2541 ),
2542 (
2543 "x_dns_prefetch_control",
2544 cfg.x_dns_prefetch_control.as_deref(),
2545 ),
2546 (
2547 "strict_transport_security",
2548 cfg.strict_transport_security.as_deref(),
2549 ),
2550 ];
2551
2552 for (field, value) in fields {
2553 let Some(v) = value else { continue };
2554 if v.is_empty() {
2555 continue;
2556 }
2557 if let Err(err) = HeaderValue::from_str(v) {
2558 return Err(McpxError::Config(format!(
2559 "invalid security_headers.{field}: {err}"
2560 )));
2561 }
2562 }
2563
2564 if let Some(v) = cfg.strict_transport_security.as_deref()
2565 && !v.is_empty()
2566 && v.to_ascii_lowercase().contains("preload")
2567 {
2568 return Err(McpxError::Config(format!(
2569 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2570 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2571 )));
2572 }
2573
2574 Ok(())
2575}
2576
/// Adds caching-related headers to OAuth endpoint responses. It sets
/// `Pragma: no-cache` and appends `Authorization` to `Vary`, keeping any
/// existing `Vary` values.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;
    {
        let out = response.headers_mut();
        out.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
        out.append(header::VARY, HeaderValue::from_static("Authorization"));
    }
    response
}
2604
2605async fn origin_check_middleware(
2609 allowed: Arc<[String]>,
2610 log_request_headers: bool,
2611 req: Request<Body>,
2612 next: Next,
2613) -> axum::response::Response {
2614 let method = req.method().clone();
2615 let path = req.uri().path().to_owned();
2616
2617 log_incoming_request(&method, &path, req.headers(), log_request_headers);
2618
2619 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2620 let origin_str = origin.to_str().unwrap_or("");
2621 if !allowed.iter().any(|a| a == origin_str) {
2622 tracing::warn!(
2623 origin = origin_str,
2624 %method,
2625 %path,
2626 allowed = ?&*allowed,
2627 "rejected request: Origin not allowed"
2628 );
2629 return (
2630 axum::http::StatusCode::FORBIDDEN,
2631 "Forbidden: Origin not allowed",
2632 )
2633 .into_response();
2634 }
2635 }
2636 next.run(req).await
2637}
2638
2639fn log_incoming_request(
2642 method: &axum::http::Method,
2643 path: &str,
2644 headers: &axum::http::HeaderMap,
2645 log_request_headers: bool,
2646) {
2647 if log_request_headers {
2648 tracing::debug!(
2649 %method,
2650 %path,
2651 headers = %format_request_headers_for_log(headers),
2652 "incoming request"
2653 );
2654 } else {
2655 tracing::debug!(%method, %path, "incoming request");
2656 }
2657}
2658
2659fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2660 headers
2661 .iter()
2662 .map(|(k, v)| {
2663 let name = k.as_str();
2664 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2665 format!("{name}: [REDACTED]")
2666 } else {
2667 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2668 }
2669 })
2670 .collect::<Vec<_>>()
2671 .join(", ")
2672}
2673
2674#[allow(
2698 clippy::cognitive_complexity,
2699 reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
2700)]
2701pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2702where
2703 H: ServerHandler + 'static,
2704{
2705 use rmcp::ServiceExt as _;
2706
2707 tracing::info!("stdio transport: serving on stdin/stdout");
2708 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2709
2710 let transport = rmcp::transport::io::stdio();
2711
2712 let service = handler
2713 .serve(transport)
2714 .await
2715 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2716
2717 if let Err(e) = service.waiting().await {
2718 tracing::warn!(error = %e, "stdio session ended with error");
2719 }
2720 tracing::info!("stdio session ended");
2721 Ok(())
2722}
2723
2724#[cfg(test)]
2725mod tests {
2726 #![allow(
2727 clippy::unwrap_used,
2728 clippy::expect_used,
2729 clippy::panic,
2730 clippy::indexing_slicing,
2731 clippy::unwrap_in_result,
2732 clippy::print_stdout,
2733 clippy::print_stderr,
2734 deprecated,
2735 reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2736 )]
2737 use std::{sync::Arc, time::Duration};
2738
2739 use axum::{
2740 body::Body,
2741 http::{Request, StatusCode, header},
2742 response::IntoResponse,
2743 };
2744 use http_body_util::BodyExt;
2745 use tower::ServiceExt as _;
2746
2747 use super::*;
2748
2749 #[test]
2752 fn server_config_new_defaults() {
2753 let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2754 assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2755 assert_eq!(cfg.name, "test-server");
2756 assert_eq!(cfg.version, "1.0.0");
2757 assert!(cfg.tls_cert_path.is_none());
2758 assert!(cfg.tls_key_path.is_none());
2759 assert!(cfg.auth.is_none());
2760 assert!(cfg.rbac.is_none());
2761 assert!(cfg.allowed_origins.is_empty());
2762 assert!(cfg.tool_rate_limit.is_none());
2763 assert!(cfg.readiness_check.is_none());
2764 assert_eq!(cfg.max_request_body, 1024 * 1024);
2765 assert_eq!(cfg.request_timeout, Duration::from_mins(2));
2766 assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
2767 assert!(!cfg.log_request_headers);
2768 }
2769
2770 #[test]
2771 fn validate_consumes_and_proves() {
2772 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2774 let validated = cfg.validate().expect("valid config");
2775 assert_eq!(validated.as_inner().name, "test-server");
2777 let raw = validated.into_inner();
2779 assert_eq!(raw.name, "test-server");
2780
2781 let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2783 bad.max_request_body = 0;
2784 assert!(bad.validate().is_err(), "zero body cap must fail validate");
2785 }
2786
2787 #[test]
2788 fn validate_rejects_zero_max_concurrent_requests() {
2789 let cfg =
2790 McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
2791 let err = cfg.validate().expect_err("zero concurrency cap must fail");
2792 assert!(
2793 format!("{err}").contains("max_concurrent_requests"),
2794 "error should mention max_concurrent_requests, got: {err}"
2795 );
2796 }
2797
2798 #[test]
2799 fn validate_rejects_zero_max_tracked_keys() {
2800 let rl = crate::auth::RateLimitConfig {
2803 max_attempts_per_minute: 30,
2804 pre_auth_max_per_minute: None,
2805 max_tracked_keys: 0,
2806 idle_eviction: Duration::from_secs(15 * 60),
2807 };
2808 let auth_cfg = AuthConfig {
2809 enabled: true,
2810 api_keys: Vec::new(),
2811 mtls: None,
2812 rate_limit: Some(rl),
2813 #[cfg(feature = "oauth")]
2814 oauth: None,
2815 };
2816 let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
2817 let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
2818 assert!(
2819 format!("{err}").contains("max_tracked_keys"),
2820 "error should mention max_tracked_keys, got: {err}"
2821 );
2822 }
2823
2824 #[test]
2825 fn derive_allowed_hosts_includes_public_host() {
2826 let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
2827 assert!(
2828 hosts.iter().any(|h| h == "mcp.example.com"),
2829 "public_url host must be allowed"
2830 );
2831 }
2832
2833 #[test]
2834 fn derive_allowed_hosts_includes_bind_authority() {
2835 let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
2836 assert!(
2837 hosts.iter().any(|h| h == "127.0.0.1"),
2838 "bind host must be allowed"
2839 );
2840 assert!(
2841 hosts.iter().any(|h| h == "127.0.0.1:8080"),
2842 "bind authority must be allowed"
2843 );
2844 }
2845
2846 #[tokio::test]
2849 async fn healthz_returns_ok_json() {
2850 let resp = healthz().await.into_response();
2851 assert_eq!(resp.status(), StatusCode::OK);
2852 let body = resp.into_body().collect().await.unwrap().to_bytes();
2853 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2854 assert_eq!(json["status"], "ok");
2855 assert!(
2856 json.get("name").is_none(),
2857 "healthz must not expose server name"
2858 );
2859 assert!(
2860 json.get("version").is_none(),
2861 "healthz must not expose version"
2862 );
2863 }
2864
2865 #[tokio::test]
2868 async fn readyz_returns_ok_when_ready() {
2869 let check: ReadinessCheck =
2870 Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
2871 let resp = readyz(check).await.into_response();
2872 assert_eq!(resp.status(), StatusCode::OK);
2873 let body = resp.into_body().collect().await.unwrap().to_bytes();
2874 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2875 assert_eq!(json["ready"], true);
2876 assert!(
2877 json.get("name").is_none(),
2878 "readyz must not expose server name"
2879 );
2880 assert!(
2881 json.get("version").is_none(),
2882 "readyz must not expose version"
2883 );
2884 assert_eq!(json["db"], "connected");
2885 }
2886
2887 #[tokio::test]
2888 async fn readyz_returns_503_when_not_ready() {
2889 let check: ReadinessCheck =
2890 Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
2891 let resp = readyz(check).await.into_response();
2892 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2893 }
2894
2895 #[tokio::test]
2896 async fn readyz_returns_503_when_ready_missing() {
2897 let check: ReadinessCheck =
2898 Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
2899 let resp = readyz(check).await.into_response();
2900 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2902 }
2903
2904 fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
2908 let allowed: Arc<[String]> = Arc::from(origins);
2909 axum::Router::new()
2910 .route("/test", axum::routing::get(|| async { "ok" }))
2911 .layer(axum::middleware::from_fn(move |req, next| {
2912 let a = Arc::clone(&allowed);
2913 origin_check_middleware(a, log_request_headers, req, next)
2914 }))
2915 }
2916
2917 #[tokio::test]
2918 async fn origin_allowed_passes() {
2919 let app = origin_router(vec!["http://localhost:3000".into()], false);
2920 let req = Request::builder()
2921 .uri("/test")
2922 .header(header::ORIGIN, "http://localhost:3000")
2923 .body(Body::empty())
2924 .unwrap();
2925 let resp = app.oneshot(req).await.unwrap();
2926 assert_eq!(resp.status(), StatusCode::OK);
2927 }
2928
2929 #[tokio::test]
2930 async fn origin_rejected_returns_403() {
2931 let app = origin_router(vec!["http://localhost:3000".into()], false);
2932 let req = Request::builder()
2933 .uri("/test")
2934 .header(header::ORIGIN, "http://evil.com")
2935 .body(Body::empty())
2936 .unwrap();
2937 let resp = app.oneshot(req).await.unwrap();
2938 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2939 }
2940
2941 #[tokio::test]
2942 async fn no_origin_header_passes() {
2943 let app = origin_router(vec!["http://localhost:3000".into()], false);
2944 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2945 let resp = app.oneshot(req).await.unwrap();
2946 assert_eq!(resp.status(), StatusCode::OK);
2947 }
2948
2949 #[tokio::test]
2950 async fn empty_allowlist_rejects_any_origin() {
2951 let app = origin_router(vec![], false);
2952 let req = Request::builder()
2953 .uri("/test")
2954 .header(header::ORIGIN, "http://anything.com")
2955 .body(Body::empty())
2956 .unwrap();
2957 let resp = app.oneshot(req).await.unwrap();
2958 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2959 }
2960
2961 #[tokio::test]
2962 async fn empty_allowlist_passes_without_origin() {
2963 let app = origin_router(vec![], false);
2964 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2965 let resp = app.oneshot(req).await.unwrap();
2966 assert_eq!(resp.status(), StatusCode::OK);
2967 }
2968
2969 #[test]
2970 fn format_request_headers_redacts_sensitive_values() {
2971 let mut headers = axum::http::HeaderMap::new();
2972 headers.insert("authorization", "Bearer secret-token".parse().unwrap());
2973 headers.insert("cookie", "sid=abc".parse().unwrap());
2974 headers.insert("x-request-id", "req-123".parse().unwrap());
2975
2976 let out = format_request_headers_for_log(&headers);
2977 assert!(out.contains("authorization: [REDACTED]"));
2978 assert!(out.contains("cookie: [REDACTED]"));
2979 assert!(out.contains("x-request-id: req-123"));
2980 assert!(!out.contains("secret-token"));
2981 }
2982
2983 fn security_router(is_tls: bool) -> axum::Router {
2986 security_router_with(is_tls, SecurityHeadersConfig::default())
2987 }
2988
2989 fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
2990 let cfg = Arc::new(cfg);
2991 axum::Router::new()
2992 .route("/test", axum::routing::get(|| async { "ok" }))
2993 .layer(axum::middleware::from_fn(move |req, next| {
2994 let c = Arc::clone(&cfg);
2995 security_headers_middleware(is_tls, c, req, next)
2996 }))
2997 }
2998
2999 #[tokio::test]
3000 async fn security_headers_set_on_response() {
3001 let app = security_router(false);
3002 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3003 let resp = app.oneshot(req).await.unwrap();
3004 assert_eq!(resp.status(), StatusCode::OK);
3005
3006 let h = resp.headers();
3007 assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
3008 assert_eq!(h.get("x-frame-options").unwrap(), "deny");
3009 assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
3010 assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
3011 assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
3012 assert_eq!(
3013 h.get("cross-origin-resource-policy").unwrap(),
3014 "same-origin"
3015 );
3016 assert_eq!(
3017 h.get("cross-origin-embedder-policy").unwrap(),
3018 "require-corp"
3019 );
3020 assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
3021 assert!(
3022 h.get("permissions-policy")
3023 .unwrap()
3024 .to_str()
3025 .unwrap()
3026 .contains("camera=()"),
3027 "permissions-policy must restrict browser features"
3028 );
3029 assert_eq!(
3030 h.get("content-security-policy").unwrap(),
3031 "default-src 'none'; frame-ancestors 'none'"
3032 );
3033 assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
3034 assert!(h.get("strict-transport-security").is_none());
3036 }
3037
3038 #[tokio::test]
3039 async fn hsts_set_when_tls_enabled() {
3040 let app = security_router(true);
3041 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3042 let resp = app.oneshot(req).await.unwrap();
3043
3044 let hsts = resp.headers().get("strict-transport-security").unwrap();
3045 assert!(
3046 hsts.to_str().unwrap().contains("max-age=63072000"),
3047 "HSTS must set 2-year max-age"
3048 );
3049 }
3050
3051 fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
3057 let cfg =
3058 McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
3059 cfg.check()
3060 }
3061
3062 #[test]
3063 fn security_headers_config_default_validates() {
3064 check_with_security_headers(SecurityHeadersConfig::default())
3065 .expect("default SecurityHeadersConfig must validate");
3066 }
3067
3068 #[test]
3069 fn security_headers_config_validate_accepts_empty_string() {
3070 let h = SecurityHeadersConfig {
3072 x_content_type_options: Some(String::new()),
3073 x_frame_options: Some(String::new()),
3074 cache_control: Some(String::new()),
3075 referrer_policy: Some(String::new()),
3076 cross_origin_opener_policy: Some(String::new()),
3077 cross_origin_resource_policy: Some(String::new()),
3078 cross_origin_embedder_policy: Some(String::new()),
3079 permissions_policy: Some(String::new()),
3080 x_permitted_cross_domain_policies: Some(String::new()),
3081 content_security_policy: Some(String::new()),
3082 x_dns_prefetch_control: Some(String::new()),
3083 strict_transport_security: Some(String::new()),
3084 };
3085 check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
3086 }
3087
3088 #[test]
3089 fn security_headers_config_validate_rejects_bad_value() {
3090 let h = SecurityHeadersConfig {
3092 referrer_policy: Some("\u{0007}".into()),
3093 ..SecurityHeadersConfig::default()
3094 };
3095 let err = check_with_security_headers(h)
3096 .expect_err("control char in referrer_policy must reject");
3097 let msg = err.to_string();
3098 assert!(
3099 msg.contains("referrer_policy"),
3100 "error must name the offending field, got: {msg}"
3101 );
3102 }
3103
3104 #[test]
3105 fn security_headers_config_validate_rejects_hsts_preload() {
3106 let h = SecurityHeadersConfig {
3107 strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
3108 ..SecurityHeadersConfig::default()
3109 };
3110 let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
3111 let msg = err.to_string();
3112 assert!(
3113 msg.contains("strict_transport_security"),
3114 "error must name the field, got: {msg}"
3115 );
3116 assert!(
3117 msg.to_lowercase().contains("preload"),
3118 "error must mention `preload`, got: {msg}"
3119 );
3120 }
3121
3122 #[test]
3123 fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
3124 let h = SecurityHeadersConfig {
3126 strict_transport_security: Some("max-age=600; PRELOAD".into()),
3127 ..SecurityHeadersConfig::default()
3128 };
3129 check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
3130 }
3131
3132 #[tokio::test]
3133 async fn security_headers_override_honored() {
3134 let h = SecurityHeadersConfig {
3136 x_frame_options: Some("SAMEORIGIN".into()),
3137 ..SecurityHeadersConfig::default()
3138 };
3139 let app = security_router_with(false, h);
3140 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3141 let resp = app.oneshot(req).await.unwrap();
3142 assert_eq!(resp.status(), StatusCode::OK);
3143
3144 let xfo = resp.headers().get("x-frame-options").unwrap();
3145 assert_eq!(xfo, "SAMEORIGIN");
3146 }
3147
3148 #[tokio::test]
3149 async fn security_headers_empty_string_omits() {
3150 let h = SecurityHeadersConfig {
3152 referrer_policy: Some(String::new()),
3153 ..SecurityHeadersConfig::default()
3154 };
3155 let app = security_router_with(false, h);
3156 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3157 let resp = app.oneshot(req).await.unwrap();
3158 assert_eq!(resp.status(), StatusCode::OK);
3159
3160 assert!(
3161 resp.headers().get("referrer-policy").is_none(),
3162 "Some(\"\") must omit the header"
3163 );
3164 assert_eq!(
3166 resp.headers().get("x-content-type-options").unwrap(),
3167 "nosniff"
3168 );
3169 }
3170
3171 #[tokio::test]
3172 async fn security_headers_hsts_only_when_tls() {
3173 let h = SecurityHeadersConfig {
3175 strict_transport_security: Some("max-age=600".into()),
3176 ..SecurityHeadersConfig::default()
3177 };
3178 let app = security_router_with(false, h);
3179 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3180 let resp = app.oneshot(req).await.unwrap();
3181 assert!(
3182 resp.headers().get("strict-transport-security").is_none(),
3183 "HSTS must remain absent on plaintext deployments even with override"
3184 );
3185 }
3186
3187 #[cfg(feature = "oauth")]
3190 #[tokio::test]
3191 async fn oauth_token_cache_headers_set_pragma_and_vary() {
3192 let app = axum::Router::new()
3193 .route("/token", axum::routing::post(|| async { "{}" }))
3194 .layer(axum::middleware::from_fn(
3195 oauth_token_cache_headers_middleware,
3196 ));
3197 let req = Request::builder()
3198 .method("POST")
3199 .uri("/token")
3200 .body(Body::from("{}"))
3201 .unwrap();
3202 let resp = app.oneshot(req).await.unwrap();
3203 assert_eq!(resp.status(), StatusCode::OK);
3204
3205 let h = resp.headers();
3206 assert_eq!(
3207 h.get("pragma").unwrap(),
3208 "no-cache",
3209 "RFC 6749 §5.1: token responses must set Pragma: no-cache"
3210 );
3211 let vary_values: Vec<String> = h
3212 .get_all("vary")
3213 .iter()
3214 .filter_map(|v| v.to_str().ok().map(str::to_owned))
3215 .collect();
3216 assert!(
3217 vary_values
3218 .iter()
3219 .any(|v| v.eq_ignore_ascii_case("Authorization")),
3220 "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
3221 );
3222 }
3223
3224 #[cfg(feature = "oauth")]
3225 #[tokio::test]
3226 async fn oauth_token_cache_headers_preserve_existing_vary() {
3227 let app = axum::Router::new()
3230 .route(
3231 "/token",
3232 axum::routing::post(|| async {
3233 axum::response::Response::builder()
3234 .header("vary", "Accept-Encoding")
3235 .body(axum::body::Body::from("{}"))
3236 .unwrap()
3237 }),
3238 )
3239 .layer(axum::middleware::from_fn(
3240 oauth_token_cache_headers_middleware,
3241 ));
3242 let req = Request::builder()
3243 .method("POST")
3244 .uri("/token")
3245 .body(Body::empty())
3246 .unwrap();
3247 let resp = app.oneshot(req).await.unwrap();
3248
3249 let vary: Vec<String> = resp
3250 .headers()
3251 .get_all("vary")
3252 .iter()
3253 .filter_map(|v| v.to_str().ok().map(str::to_owned))
3254 .collect();
3255 assert!(
3256 vary.iter().any(|v| v.contains("Accept-Encoding")),
3257 "must preserve pre-existing Vary value, got {vary:?}"
3258 );
3259 assert!(
3260 vary.iter().any(|v| v.contains("Authorization")),
3261 "must append Authorization to Vary, got {vary:?}"
3262 );
3263 }
3264
3265 #[test]
3268 fn version_payload_contains_expected_fields() {
3269 let v = version_payload("my-server", "1.2.3");
3270 assert_eq!(v["name"], "my-server");
3271 assert_eq!(v["version"], "1.2.3");
3272 assert!(v["build_git_sha"].is_string());
3273 assert!(v["build_timestamp"].is_string());
3274 assert!(v["rust_version"].is_string());
3275 assert!(v["mcpx_version"].is_string());
3276 }
3277
3278 #[tokio::test]
3281 async fn concurrency_limit_layer_composes_and_serves() {
3282 let app = axum::Router::new()
3286 .route("/ok", axum::routing::get(|| async { "ok" }))
3287 .layer(
3288 tower::ServiceBuilder::new()
3289 .layer(axum::error_handling::HandleErrorLayer::new(
3290 |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3291 ))
3292 .layer(tower::load_shed::LoadShedLayer::new())
3293 .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3294 );
3295 let resp = app
3296 .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3297 .await
3298 .unwrap();
3299 assert_eq!(resp.status(), StatusCode::OK);
3300 }
3301
3302 #[tokio::test]
3305 async fn compression_layer_gzip_encodes_response() {
3306 use tower_http::compression::Predicate as _;
3307
3308 let big_body = "a".repeat(4096);
3309 let app = axum::Router::new()
3310 .route(
3311 "/big",
3312 axum::routing::get(move || {
3313 let body = big_body.clone();
3314 async move { body }
3315 }),
3316 )
3317 .layer(
3318 tower_http::compression::CompressionLayer::new()
3319 .gzip(true)
3320 .br(true)
3321 .compress_when(
3322 tower_http::compression::DefaultPredicate::new()
3323 .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3324 ),
3325 );
3326
3327 let req = Request::builder()
3328 .uri("/big")
3329 .header(header::ACCEPT_ENCODING, "gzip")
3330 .body(Body::empty())
3331 .unwrap();
3332 let resp = app.oneshot(req).await.unwrap();
3333 assert_eq!(resp.status(), StatusCode::OK);
3334 assert_eq!(
3335 resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3336 "gzip"
3337 );
3338 }
3339
3340 #[tokio::test]
3343 async fn tls_handshake_timeout_reaps_idle_connections() {
3344 use tokio::io::AsyncReadExt as _;
3345
3346 let _ = rustls::crypto::ring::default_provider().install_default();
3347
3348 let key = rcgen::KeyPair::generate().expect("generate key");
3350 let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
3351 .expect("cert params")
3352 .self_signed(&key)
3353 .expect("self-signed cert");
3354 let dir = std::env::temp_dir().join(format!(
3355 "rmcp-server-kit-hs-timeout-{}",
3356 std::time::SystemTime::now()
3357 .duration_since(std::time::UNIX_EPOCH)
3358 .expect("clock after epoch")
3359 .as_nanos()
3360 ));
3361 tokio::fs::create_dir_all(&dir).await.expect("temp dir");
3362 let cert_path = dir.join("server.crt");
3363 let key_path = dir.join("server.key");
3364 tokio::fs::write(&cert_path, cert.pem())
3365 .await
3366 .expect("write cert");
3367 tokio::fs::write(&key_path, key.serialize_pem())
3368 .await
3369 .expect("write key");
3370
3371 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
3372 let tls = TlsListener::new(
3373 listener,
3374 &cert_path,
3375 &key_path,
3376 None,
3377 None,
3378 Duration::from_millis(200),
3379 )
3380 .expect("tls listener");
3381 let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
3382
3383 let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
3387 let mut buf = [0_u8; 16];
3388 let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
3389 .await
3390 .expect("server must reap the idle handshake within its timeout");
3391 match read {
3392 Ok(0) | Err(_) => {} Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
3394 }
3395
3396 drop(tls);
3397 }
3398}