1use std::{
2 future::Future,
3 net::SocketAddr,
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12 body::Body,
13 extract::{ConnectInfo, Request},
14 middleware::Next,
15 response::IntoResponse,
16};
17use rmcp::{
18 ServerHandler,
19 transport::streamable_http_server::{
20 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21 },
22};
23use rustls::RootCertStore;
24use tokio::{
25 net::TcpListener,
26 sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31 auth::{
32 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33 build_rate_limiter, extract_mtls_identity,
34 },
35 error::McpxError,
36 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
37 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
38};
39
40#[allow(
44 clippy::needless_pass_by_value,
45 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
46)]
47fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
48 McpxError::Startup(format!("{e:#}"))
49}
50
51#[allow(
57 clippy::needless_pass_by_value,
58 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
59)]
60fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
61 McpxError::Startup(format!("{op}: {e}"))
62}
63
64pub type ReadinessCheck =
69 Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
70
/// Remote socket address of the connection that carried the current request.
///
/// Available to handlers as an axum extractor. Extraction fails with
/// `500 Internal Server Error` when the request was not routed through this
/// crate's `serve()` plumbing.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PeerAddr {
    /// The peer's IP address and port.
    pub addr: SocketAddr,
}

impl PeerAddr {
    /// Wraps `addr`. Crate-internal: outside code cannot build the
    /// `#[non_exhaustive]` struct directly.
    #[must_use]
    pub(crate) const fn new(addr: SocketAddr) -> Self {
        PeerAddr { addr }
    }
}
129}
130
131impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
140 type Rejection = (axum::http::StatusCode, &'static str);
141
142 async fn from_request_parts(
143 parts: &mut axum::http::request::Parts,
144 _state: &S,
145 ) -> Result<Self, Self::Rejection> {
146 parts.extensions.get::<Self>().copied().ok_or((
147 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
148 "peer address unavailable: not running under rmcp-server-kit serve()",
149 ))
150 }
151}
152
/// Per-header overrides for the security-headers middleware.
///
/// Each field maps to the HTTP response header of the same name. `Default`
/// leaves every field `None`. Whether `None` means "use the middleware's
/// built-in value" or "omit the header" is decided by
/// `security_headers_middleware` / `validate_security_headers` — confirm there.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` value.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` value.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` value.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` value.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` value.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` value.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` value.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` value.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` value.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` value.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` value.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` value. The middleware receives an `is_tls`
    /// flag, so HSTS handling presumably depends on TLS — confirm in
    /// `security_headers_middleware`.
    pub strict_transport_security: Option<String>,
}
206
/// Configuration for an MCP Streamable HTTP server.
///
/// Build it with [`McpServerConfig::new`] plus the `with_*` / `enable_*`
/// builder methods. Then call `validate()` to obtain the `Validated` wrapper
/// that `serve` / `serve_with_listener` require. Direct field access is
/// deprecated.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Listener address. `validate()` requires it to parse as a `SocketAddr`
    /// (e.g. `127.0.0.1:8080`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name. It is served by `/version`, used by the admin endpoints
    /// and used in the startup log line.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version. It is served by `/version` and used by the admin
    /// endpoints.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// TLS certificate path. It must be set together with `tls_key_path`.
    /// When present, the server uses the `https` scheme.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private key path. It must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings. The auth middleware is installed on `/mcp`
    /// only when this is `Some` and `enabled` is true.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy for `/mcp`. `None` falls back to `RbacPolicy::disabled()`.
    /// The policy is held in a swappable cell, so it can be replaced at runtime
    /// via `ReloadHandle::reload_rbac`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Origins accepted by the origin check and the CORS layer. Each entry
    /// must start with `http://` or `https://`. When this list is empty and
    /// `public_url` is set, an origin is derived from `public_url`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Tool-call rate limit, in calls per minute (logged as "per IP").
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Probe backing `/readyz`. When `None`, `/readyz` uses the `/healthz`
    /// handler.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes for `/mcp` (default 1 MiB). Must be
    /// non-zero.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Per-request timeout on `/mcp`. Expiry yields `408 Request Timeout`
    /// (default 2 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Graceful-shutdown budget forwarded to the server run loop (default 30 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Keep-alive for idle MCP sessions in the local session manager (default
    /// 20 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval for the Streamable HTTP service (default 15 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// Callback that receives a `ReloadHandle`. It is handed to the server
    /// run loop; the exact invocation point is defined there — confirm in
    /// `run_server`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Extra routes merged into the top-level router. They are NOT behind the
    /// `/mcp` auth/RBAC layers. They ARE behind the global layers: security
    /// headers, CORS, compression, concurrency limit and origin check.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL; must start with `http://` or `https://`.
    /// It is used for:
    /// - allowed-host derivation;
    /// - the protected-resource metadata;
    /// - a fallback allowed origin.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Request-header logging flag, forwarded to `origin_check_middleware`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Minimum response size in bytes before compression applies (default 1024).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global in-flight request cap. Excess requests are shed with `503`
    /// instead of queued. `Some(0)` is rejected by `validate()`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Enables the `/admin/*` endpoints. Requires auth to be configured and
    /// enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// Role handed to the admin router (default `"admin"`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enables the metrics middleware and a separate metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Bind address of the metrics listener (default `127.0.0.1:9090`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Overrides for the security-headers middleware.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
    /// TLS handshake timeout forwarded to the server run loop. Must be non-zero.
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_handshake_timeout: Duration,
    /// Cap on concurrent TLS handshakes, forwarded to the server run loop.
    /// Must be non-zero.
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_tls_handshakes: usize,
}
436
/// Marker wrapper proving its contents passed validation.
///
/// For `McpServerConfig` it is produced only by `McpServerConfig::validate`,
/// and it is what the `serve*` entry points accept.
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Prints only the wrapper name, never the wrapped value, so secrets
    /// inside a configuration cannot reach logs through `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Validated { .. }")
    }
}
505
506impl<T> Validated<T> {
507 #[must_use]
509 pub fn as_inner(&self) -> &T {
510 &self.0
511 }
512
513 #[must_use]
518 pub fn into_inner(self) -> T {
519 self.0
520 }
521}
522
523#[allow(
524 deprecated,
525 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
526)]
527impl McpServerConfig {
528 #[must_use]
536 pub fn new(
537 bind_addr: impl Into<String>,
538 name: impl Into<String>,
539 version: impl Into<String>,
540 ) -> Self {
541 Self {
542 bind_addr: bind_addr.into(),
543 name: name.into(),
544 version: version.into(),
545 tls_cert_path: None,
546 tls_key_path: None,
547 auth: None,
548 rbac: None,
549 allowed_origins: Vec::new(),
550 tool_rate_limit: None,
551 readiness_check: None,
552 max_request_body: 1024 * 1024,
553 request_timeout: Duration::from_mins(2),
554 shutdown_timeout: Duration::from_secs(30),
555 session_idle_timeout: Duration::from_mins(20),
556 sse_keep_alive: Duration::from_secs(15),
557 on_reload_ready: None,
558 extra_router: None,
559 public_url: None,
560 log_request_headers: false,
561 compression_enabled: false,
562 compression_min_size: 1024,
563 max_concurrent_requests: None,
564 admin_enabled: false,
565 admin_role: "admin".to_owned(),
566 #[cfg(feature = "metrics")]
567 metrics_enabled: false,
568 #[cfg(feature = "metrics")]
569 metrics_bind: "127.0.0.1:9090".into(),
570 security_headers: SecurityHeadersConfig::default(),
571 tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
572 max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
573 }
574 }
575
576 #[must_use]
586 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
587 self.auth = Some(auth);
588 self
589 }
590
591 #[must_use]
596 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
597 self.security_headers = headers;
598 self
599 }
600
601 #[must_use]
605 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
606 self.bind_addr = addr.into();
607 self
608 }
609
610 #[must_use]
613 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
614 self.rbac = Some(rbac);
615 self
616 }
617
618 #[must_use]
622 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
623 self.tls_cert_path = Some(cert_path.into());
624 self.tls_key_path = Some(key_path.into());
625 self
626 }
627
628 #[must_use]
632 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
633 self.public_url = Some(url.into());
634 self
635 }
636
637 #[must_use]
641 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
642 where
643 I: IntoIterator<Item = S>,
644 S: Into<String>,
645 {
646 self.allowed_origins = origins.into_iter().map(Into::into).collect();
647 self
648 }
649
650 #[must_use]
663 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
664 self.extra_router = Some(router);
665 self
666 }
667
668 #[must_use]
671 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
672 self.readiness_check = Some(check);
673 self
674 }
675
676 #[must_use]
679 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
680 self.max_request_body = bytes;
681 self
682 }
683
684 #[must_use]
686 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
687 self.request_timeout = timeout;
688 self
689 }
690
691 #[must_use]
693 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
694 self.shutdown_timeout = timeout;
695 self
696 }
697
698 #[must_use]
700 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
701 self.session_idle_timeout = timeout;
702 self
703 }
704
705 #[must_use]
707 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
708 self.sse_keep_alive = interval;
709 self
710 }
711
712 #[must_use]
716 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
717 self.max_concurrent_requests = Some(limit);
718 self
719 }
720
721 #[must_use]
729 pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
730 self.tls_handshake_timeout = timeout;
731 self
732 }
733
734 #[must_use]
743 pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
744 self.max_concurrent_tls_handshakes = limit;
745 self
746 }
747
748 #[must_use]
751 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
752 self.tool_rate_limit = Some(per_minute);
753 self
754 }
755
756 #[must_use]
760 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
761 where
762 F: FnOnce(ReloadHandle) + Send + 'static,
763 {
764 self.on_reload_ready = Some(Box::new(callback));
765 self
766 }
767
768 #[must_use]
772 pub fn enable_compression(mut self, min_size: u16) -> Self {
773 self.compression_enabled = true;
774 self.compression_min_size = min_size;
775 self
776 }
777
778 #[must_use]
783 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
784 self.admin_enabled = true;
785 self.admin_role = role.into();
786 self
787 }
788
789 #[must_use]
792 pub fn enable_request_header_logging(mut self) -> Self {
793 self.log_request_headers = true;
794 self
795 }
796
797 #[cfg(feature = "metrics")]
800 #[must_use]
801 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
802 self.metrics_enabled = true;
803 self.metrics_bind = bind.into();
804 self
805 }
806
807 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
840 self.check()?;
841 Ok(Validated(self))
842 }
843
844 fn check(&self) -> Result<(), McpxError> {
848 if self.admin_enabled {
852 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
853 if !auth_enabled {
854 return Err(McpxError::Config(
855 "admin_enabled=true requires auth to be configured and enabled".into(),
856 ));
857 }
858 }
859
860 match (&self.tls_cert_path, &self.tls_key_path) {
862 (Some(_), None) => {
863 return Err(McpxError::Config(
864 "tls_cert_path is set but tls_key_path is missing".into(),
865 ));
866 }
867 (None, Some(_)) => {
868 return Err(McpxError::Config(
869 "tls_key_path is set but tls_cert_path is missing".into(),
870 ));
871 }
872 _ => {}
873 }
874
875 if self.bind_addr.parse::<SocketAddr>().is_err() {
877 return Err(McpxError::Config(format!(
878 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
879 self.bind_addr
880 )));
881 }
882
883 if let Some(ref url) = self.public_url
885 && !(url.starts_with("http://") || url.starts_with("https://"))
886 {
887 return Err(McpxError::Config(format!(
888 "public_url {url:?} must start with http:// or https://"
889 )));
890 }
891
892 for origin in &self.allowed_origins {
894 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
895 return Err(McpxError::Config(format!(
896 "allowed_origins entry {origin:?} must start with http:// or https://"
897 )));
898 }
899 }
900
901 if self.max_request_body == 0 {
903 return Err(McpxError::Config(
904 "max_request_body must be greater than zero".into(),
905 ));
906 }
907
908 #[cfg(feature = "oauth")]
910 if let Some(auth_cfg) = &self.auth
911 && let Some(oauth_cfg) = &auth_cfg.oauth
912 {
913 oauth_cfg.validate()?;
914 }
915
916 validate_security_headers(&self.security_headers)?;
919
920 if let Some(0) = self.max_concurrent_requests {
924 return Err(McpxError::Config(
925 "max_concurrent_requests must be greater than zero when set".into(),
926 ));
927 }
928
929 if let Some(auth_cfg) = &self.auth
933 && let Some(rl) = &auth_cfg.rate_limit
934 && rl.max_tracked_keys == 0
935 {
936 return Err(McpxError::Config(
937 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
938 ));
939 }
940
941 if self.tls_handshake_timeout == Duration::ZERO {
946 return Err(McpxError::Config(
947 "tls_handshake_timeout must be greater than zero".into(),
948 ));
949 }
950
951 if self.max_concurrent_tls_handshakes == 0 {
956 return Err(McpxError::Config(
957 "max_concurrent_tls_handshakes must be greater than zero".into(),
958 ));
959 }
960
961 Ok(())
962 }
963}
964
/// Handle for replacing runtime security state (API keys, RBAC policy, CRLs)
/// without restarting the server.
///
/// Every field is optional. When a field is `None`, the corresponding reload
/// method is a no-op (keys, RBAC) or returns a configuration error (CRLs).
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    // Shared authentication state whose API keys `reload_auth_keys` replaces.
    auth: Option<Arc<AuthState>>,
    // Swappable RBAC policy. The RBAC middleware loads it per request, so a
    // `store` affects subsequent requests.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    // mTLS certificate revocation list set refreshed by `refresh_crls`.
    crl_set: Option<Arc<CrlSet>>,
}
979
980impl ReloadHandle {
981 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
983 if let Some(ref auth) = self.auth {
984 auth.reload_keys(keys);
985 }
986 }
987
988 pub fn reload_rbac(&self, policy: RbacPolicy) {
990 if let Some(ref rbac) = self.rbac {
991 rbac.store(Arc::new(policy));
992 tracing::info!("RBAC policy reloaded");
993 }
994 }
995
996 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1002 let Some(ref crl_set) = self.crl_set else {
1003 return Err(McpxError::Config(
1004 "CRL refresh requested but mTLS CRL support is not configured".into(),
1005 ));
1006 };
1007
1008 crl_set.force_refresh().await
1009 }
1010}
1011
1012#[allow(
1029 clippy::too_many_lines,
1030 clippy::cognitive_complexity,
1031 reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
1032)]
1033struct AppRunParams {
1037 tls_paths: Option<(PathBuf, PathBuf)>,
1039 tls_handshake_timeout: Duration,
1041 max_concurrent_tls_handshakes: usize,
1043 mtls_config: Option<MtlsConfig>,
1045 shutdown_timeout: Duration,
1047 auth_state: Option<Arc<AuthState>>,
1049 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1051 on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1053 ct: CancellationToken,
1057 scheme: &'static str,
1059 name: String,
1061}
1062
1063#[allow(
1073 clippy::cognitive_complexity,
1074 reason = "router assembly is intrinsically sequential; splitting harms readability"
1075)]
1076#[allow(
1077 deprecated,
1078 reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
1079)]
1080fn build_app_router<H, F>(
1081 mut config: McpServerConfig,
1082 handler_factory: F,
1083) -> anyhow::Result<(axum::Router, AppRunParams)>
1084where
1085 H: ServerHandler + 'static,
1086 F: Fn() -> H + Send + Sync + Clone + 'static,
1087{
1088 let ct = CancellationToken::new();
1089
1090 let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
1091 tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
1092
1093 let mcp_service = StreamableHttpService::new(
1094 move || Ok(handler_factory()),
1095 {
1096 let mut mgr = LocalSessionManager::default();
1097 mgr.session_config.keep_alive = Some(config.session_idle_timeout);
1098 mgr.into()
1099 },
1100 StreamableHttpServerConfig::default()
1101 .with_allowed_hosts(allowed_hosts)
1102 .with_sse_keep_alive(Some(config.sse_keep_alive))
1103 .with_cancellation_token(ct.child_token()),
1104 );
1105
1106 let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
1108
1109 let auth_state: Option<Arc<AuthState>> = match config.auth {
1113 Some(ref auth_config) if auth_config.enabled => {
1114 let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
1115 let pre_auth_limiter = auth_config
1116 .rate_limit
1117 .as_ref()
1118 .map(crate::auth::build_pre_auth_limiter);
1119
1120 #[cfg(feature = "oauth")]
1121 let jwks_cache = auth_config
1122 .oauth
1123 .as_ref()
1124 .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
1125 .transpose()
1126 .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
1127
1128 Some(Arc::new(AuthState {
1129 api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
1130 rate_limiter,
1131 pre_auth_limiter,
1132 #[cfg(feature = "oauth")]
1133 jwks_cache,
1134 seen_identities: crate::auth::SeenIdentitySet::new(),
1135 counters: crate::auth::AuthCounters::default(),
1136 }))
1137 }
1138 _ => None,
1139 };
1140
1141 let rbac_swap = Arc::new(ArcSwap::new(
1144 config
1145 .rbac
1146 .clone()
1147 .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
1148 ));
1149
1150 if config.admin_enabled {
1153 let Some(ref auth_state_ref) = auth_state else {
1154 return Err(anyhow::anyhow!(
1155 "admin_enabled=true requires auth to be configured and enabled"
1156 ));
1157 };
1158 let admin_state = crate::admin::AdminState {
1159 started_at: std::time::Instant::now(),
1160 name: config.name.clone(),
1161 version: config.version.clone(),
1162 auth: Some(Arc::clone(auth_state_ref)),
1163 rbac: Arc::clone(&rbac_swap),
1164 };
1165 let admin_cfg = crate::admin::AdminConfig {
1166 role: config.admin_role.clone(),
1167 };
1168 mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
1169 tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
1170 }
1171
1172 {
1205 let tool_limiter: Option<Arc<ToolRateLimiter>> =
1206 config.tool_rate_limit.map(build_tool_rate_limiter);
1207
1208 if rbac_swap.load().is_enabled() {
1209 tracing::info!("RBAC enforcement enabled on /mcp");
1210 }
1211 if let Some(limit) = config.tool_rate_limit {
1212 tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1213 }
1214
1215 let rbac_for_mw = Arc::clone(&rbac_swap);
1216 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1217 let p = rbac_for_mw.load_full();
1218 let tl = tool_limiter.clone();
1219 rbac_middleware(p, tl, req, next)
1220 }));
1221 }
1222
1223 if let Some(ref auth_config) = config.auth
1225 && auth_config.enabled
1226 {
1227 let Some(ref state) = auth_state else {
1228 return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1229 };
1230
1231 let methods: Vec<&str> = [
1232 auth_config.mtls.is_some().then_some("mTLS"),
1233 (!auth_config.api_keys.is_empty()).then_some("bearer"),
1234 #[cfg(feature = "oauth")]
1235 auth_config.oauth.is_some().then_some("oauth-jwt"),
1236 ]
1237 .into_iter()
1238 .flatten()
1239 .collect();
1240
1241 tracing::info!(
1242 methods = %methods.join(", "),
1243 api_keys = auth_config.api_keys.len(),
1244 "auth enabled on /mcp"
1245 );
1246
1247 let state_for_mw = Arc::clone(state);
1248 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1249 let s = Arc::clone(&state_for_mw);
1250 auth_middleware(s, req, next)
1251 }));
1252 }
1253
1254 mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1257 axum::http::StatusCode::REQUEST_TIMEOUT,
1258 config.request_timeout,
1259 ));
1260
1261 mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1265 config.max_request_body,
1266 ));
1267
1268 let mut effective_origins = config.allowed_origins.clone();
1275 if effective_origins.is_empty()
1276 && let Some(ref url) = config.public_url
1277 {
1278 if let Some(scheme_end) = url.find("://") {
1283 let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1284 let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1285 let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1286 let host = after_scheme.get(..host_end).unwrap_or_default();
1287 let origin = format!("{scheme_with_sep}{host}");
1288 tracing::info!(
1289 %origin,
1290 "auto-derived allowed origin from public_url"
1291 );
1292 effective_origins.push(origin);
1293 }
1294 }
1295 let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1296 let cors_origins = Arc::clone(&allowed_origins);
1297 let log_request_headers = config.log_request_headers;
1298
1299 let readyz_route = if let Some(check) = config.readiness_check.take() {
1300 axum::routing::get(move || readyz(Arc::clone(&check)))
1301 } else {
1302 axum::routing::get(healthz)
1303 };
1304
1305 #[allow(unused_mut)] let mut router = axum::Router::new()
1307 .route("/healthz", axum::routing::get(healthz))
1308 .route("/readyz", readyz_route)
1309 .route(
1310 "/version",
1311 axum::routing::get({
1312 let payload_bytes: Arc<[u8]> =
1317 serialize_version_payload(&config.name, &config.version);
1318 move || {
1319 let p = Arc::clone(&payload_bytes);
1320 async move {
1321 (
1322 [(axum::http::header::CONTENT_TYPE, "application/json")],
1323 p.to_vec(),
1324 )
1325 }
1326 }
1327 }),
1328 )
1329 .merge(mcp_router);
1330
1331 if let Some(extra) = config.extra_router.take() {
1333 router = router.merge(extra);
1334 }
1335
1336 let server_url = if let Some(ref url) = config.public_url {
1343 url.trim_end_matches('/').to_owned()
1344 } else {
1345 let prm_scheme = if config.tls_cert_path.is_some() {
1346 "https"
1347 } else {
1348 "http"
1349 };
1350 format!("{prm_scheme}://{}", config.bind_addr)
1351 };
1352 let resource_url = format!("{server_url}/mcp");
1353
1354 #[cfg(feature = "oauth")]
1355 let prm_metadata = if let Some(ref auth_config) = config.auth
1356 && let Some(ref oauth_config) = auth_config.oauth
1357 {
1358 crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1359 } else {
1360 serde_json::json!({ "resource": resource_url })
1361 };
1362 #[cfg(not(feature = "oauth"))]
1363 let prm_metadata = serde_json::json!({ "resource": resource_url });
1364
1365 router = router.route(
1366 "/.well-known/oauth-protected-resource",
1367 axum::routing::get(move || {
1368 let m = prm_metadata.clone();
1369 async move { axum::Json(m) }
1370 }),
1371 );
1372
1373 #[cfg(feature = "oauth")]
1378 if let Some(ref auth_config) = config.auth
1379 && let Some(ref oauth_config) = auth_config.oauth
1380 && oauth_config.proxy.is_some()
1381 {
1382 router =
1383 install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1384 }
1385
1386 let is_tls = config.tls_cert_path.is_some();
1389 let security_headers_cfg = Arc::new(config.security_headers.clone());
1390 router = router.layer(axum::middleware::from_fn(move |req, next| {
1391 let cfg = Arc::clone(&security_headers_cfg);
1392 security_headers_middleware(is_tls, cfg, req, next)
1393 }));
1394
1395 if !cors_origins.is_empty() {
1399 let cors = tower_http::cors::CorsLayer::new()
1400 .allow_origin(
1401 cors_origins
1402 .iter()
1403 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1404 .collect::<Vec<_>>(),
1405 )
1406 .allow_methods([
1407 axum::http::Method::GET,
1408 axum::http::Method::POST,
1409 axum::http::Method::OPTIONS,
1410 ])
1411 .allow_headers([
1412 axum::http::header::CONTENT_TYPE,
1413 axum::http::header::AUTHORIZATION,
1414 ]);
1415 router = router.layer(cors);
1416 }
1417
1418 if config.compression_enabled {
1422 use tower_http::compression::Predicate as _;
1423 let predicate = tower_http::compression::DefaultPredicate::new().and(
1424 tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1425 );
1426 router = router.layer(
1427 tower_http::compression::CompressionLayer::new()
1428 .gzip(true)
1429 .br(true)
1430 .compress_when(predicate),
1431 );
1432 tracing::info!(
1433 min_size = config.compression_min_size,
1434 "response compression enabled (gzip, br)"
1435 );
1436 }
1437
1438 if let Some(max) = config.max_concurrent_requests {
1441 let overload_handler = tower::ServiceBuilder::new()
1442 .layer(axum::error_handling::HandleErrorLayer::new(
1443 |_err: tower::BoxError| async {
1444 (
1445 axum::http::StatusCode::SERVICE_UNAVAILABLE,
1446 axum::Json(serde_json::json!({
1447 "error": "overloaded",
1448 "error_description": "server is at capacity, retry later"
1449 })),
1450 )
1451 },
1452 ))
1453 .layer(tower::load_shed::LoadShedLayer::new())
1454 .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1455 router = router.layer(overload_handler);
1456 tracing::info!(max, "global concurrency limit enabled");
1457 }
1458
1459 router = router.fallback(|| async {
1463 (
1464 axum::http::StatusCode::NOT_FOUND,
1465 axum::Json(serde_json::json!({
1466 "error": "not_found",
1467 "error_description": "The requested endpoint does not exist"
1468 })),
1469 )
1470 });
1471
1472 #[cfg(feature = "metrics")]
1474 if config.metrics_enabled {
1475 let metrics = Arc::new(
1476 crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1477 );
1478 let m = Arc::clone(&metrics);
1479 router = router.layer(axum::middleware::from_fn(
1480 move |req: Request<Body>, next: Next| {
1481 let m = Arc::clone(&m);
1482 metrics_middleware(m, req, next)
1483 },
1484 ));
1485 let metrics_bind = config.metrics_bind.clone();
1486 let metrics_shutdown = ct.clone();
1487 tokio::spawn(async move {
1488 if let Err(e) =
1489 crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1490 {
1491 tracing::error!("metrics listener failed: {e}");
1492 }
1493 });
1494 }
1495
1496 router = router.layer(axum::middleware::from_fn(normalize_peer_addr_middleware));
1504
1505 router = router.layer(axum::middleware::from_fn(move |req, next| {
1516 let origins = Arc::clone(&allowed_origins);
1517 origin_check_middleware(origins, log_request_headers, req, next)
1518 }));
1519
1520 let scheme = if config.tls_cert_path.is_some() {
1521 "https"
1522 } else {
1523 "http"
1524 };
1525
1526 let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1527 (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1528 _ => None,
1529 };
1530 let tls_handshake_timeout = config.tls_handshake_timeout;
1531 let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
1532 let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1533
1534 Ok((
1535 router,
1536 AppRunParams {
1537 tls_paths,
1538 tls_handshake_timeout,
1539 max_concurrent_tls_handshakes,
1540 mtls_config,
1541 shutdown_timeout: config.shutdown_timeout,
1542 auth_state,
1543 rbac_swap,
1544 on_reload_ready: config.on_reload_ready.take(),
1545 ct,
1546 scheme,
1547 name: config.name.clone(),
1548 },
1549 ))
1550}
1551
1552pub async fn serve<H, F>(
1569 config: Validated<McpServerConfig>,
1570 handler_factory: F,
1571) -> Result<(), McpxError>
1572where
1573 H: ServerHandler + 'static,
1574 F: Fn() -> H + Send + Sync + Clone + 'static,
1575{
1576 let config = config.into_inner();
1577 #[allow(
1578 deprecated,
1579 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1580 )]
1581 let bind_addr = config.bind_addr.clone();
1582 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1583
1584 let listener = TcpListener::bind(&bind_addr)
1585 .await
1586 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1587 log_listening(¶ms.name, params.scheme, &bind_addr);
1588
1589 run_server(
1590 router,
1591 listener,
1592 params.tls_paths,
1593 params.tls_handshake_timeout,
1594 params.max_concurrent_tls_handshakes,
1595 params.mtls_config,
1596 params.shutdown_timeout,
1597 params.auth_state,
1598 params.rbac_swap,
1599 params.on_reload_ready,
1600 params.ct,
1601 )
1602 .await
1603 .map_err(anyhow_to_startup)
1604}
1605
1606pub async fn serve_with_listener<H, F>(
1636 listener: TcpListener,
1637 config: Validated<McpServerConfig>,
1638 handler_factory: F,
1639 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1640 shutdown: Option<CancellationToken>,
1641) -> Result<(), McpxError>
1642where
1643 H: ServerHandler + 'static,
1644 F: Fn() -> H + Send + Sync + Clone + 'static,
1645{
1646 let config = config.into_inner();
1647 let local_addr = listener
1648 .local_addr()
1649 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1650 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1651
1652 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1653
1654 if let Some(external) = shutdown {
1658 let internal = params.ct.clone();
1659 tokio::spawn(async move {
1660 external.cancelled().await;
1661 internal.cancel();
1662 });
1663 }
1664
1665 if let Some(tx) = ready_tx {
1669 let _ = tx.send(local_addr);
1671 }
1672
1673 run_server(
1674 router,
1675 listener,
1676 params.tls_paths,
1677 params.tls_handshake_timeout,
1678 params.max_concurrent_tls_handshakes,
1679 params.mtls_config,
1680 params.shutdown_timeout,
1681 params.auth_state,
1682 params.rbac_swap,
1683 params.on_reload_ready,
1684 params.ct,
1685 )
1686 .await
1687 .map_err(anyhow_to_startup)
1688}
1689
1690#[allow(
1693 clippy::cognitive_complexity,
1694 reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1695)]
1696fn log_listening(name: &str, scheme: &str, addr: &str) {
1697 tracing::info!("{name} listening on {addr}");
1698 tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
1699 tracing::info!(" Health check: {scheme}://{addr}/healthz");
1700 tracing::info!(" Readiness: {scheme}://{addr}/readyz");
1701}
1702
/// Serve `router` on `listener` until shutdown, using TLS when `tls_paths` is set.
///
/// Shutdown flow:
/// * A spawned watcher cancels `shutdown_trigger` when `shutdown_signal()`
///   resolves or when `ct` is cancelled, whichever happens first.
/// * `graceful` is handed to axum's graceful shutdown. Once triggered, it logs
///   the grace period and cancels `ct`.
/// * `force_exit_timer` sleeps `shutdown_timeout` after the trigger. If it
///   completes before axum finishes draining, the `select!` drops the serve
///   future (abandoning open connections) and this function returns `Ok(())`.
///
/// TLS path (`tls_paths` is `Some((cert, key))`):
/// * When `mtls_config` has `crl_enabled`, the CA bundle is loaded and CRL
///   state is bootstrapped. A CRL refresher task is then spawned with a clone
///   of `ct`.
/// * `on_reload_ready`, if set, receives a `ReloadHandle` (auth state, RBAC
///   swap, CRL set). This happens before the TLS listener is built.
/// * Connections are accepted through `TlsListener` and exposed to handlers
///   as `TlsConnInfo` connect info.
///
/// Plain path: `on_reload_ready` gets a handle with `crl_set: None`, and the
/// raw `listener` is served with `SocketAddr` connect info.
///
/// # Errors
///
/// Fails if CA loading or CRL bootstrap fails, or if the TLS listener cannot
/// be constructed (certificate, key or verifier problems). Also fails if
/// `axum::serve` returns an error.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    tls_handshake_timeout: Duration,
    max_concurrent_tls_handshakes: usize,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // Single "begin shutdown" fan-out point: both `graceful` and
    // `force_exit_timer` below wait on this token.
    let shutdown_trigger = CancellationToken::new();
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        // Fire the trigger on the first of: an OS signal, or cancellation of `ct`.
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            // Propagate shutdown to everything else holding `ct` (e.g. the
            // CRL refresher spawned below).
            ct.cancel();
        }
    };

    // Upper bound on the graceful drain. Once it elapses, the `select!`
    // below stops waiting for axum to finish.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // CRL state exists only for mTLS with CRL checking enabled.
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        // Hand reload handles to the caller before the listener starts accepting.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
            tls_handshake_timeout,
            max_concurrent_tls_handshakes,
        )?;
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
1842
/// Mount the OAuth 2.1 proxy endpoints on `router` when a proxy is configured.
///
/// Returns `router` unchanged when `oauth_config.proxy` is `None`. Otherwise
/// it adds:
/// * `GET /.well-known/oauth-authorization-server` — metadata built by
///   `crate::oauth::authorization_server_metadata`;
/// * `GET /authorize`, `POST /token` and `POST /register`;
/// * `POST /introspect` / `POST /revoke` when `expose_admin_endpoints` is set
///   and the matching upstream URL is configured (built by
///   `build_oauth_admin_router`).
///
/// `/token` and `/register` carry `oauth_token_cache_headers_middleware`.
///
/// # Errors
///
/// Propagates failures from building the OAuth HTTP client and from
/// `build_oauth_admin_router`.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            // axum clones the handler per request, so moving out of the
            // captured value here is fine.
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    // Admin routes are mounted only when exposed AND at least one upstream
    // endpoint they would proxy to is configured.
    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Warn only when admin routes are actually mounted. Without an
    // introspection/revocation URL nothing is exposed, and the warning would
    // be a false alarm.
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1943
/// Build the router holding the OAuth proxy admin endpoints.
///
/// `POST /introspect` is added when `proxy.introspection_url` is set, and
/// `POST /revoke` when `proxy.revocation_url` is set; both share `http`.
/// Every admin route gets `oauth_token_cache_headers_middleware`. When
/// `proxy.require_auth_on_admin_endpoints` is true, the whole router is also
/// wrapped in `auth_middleware` using `auth_state`.
///
/// # Errors
///
/// [`McpxError::Startup`] when authentication is required but `auth_state`
/// is `None`.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let shared_cfg = proxy.clone();
        let shared_client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let cfg = shared_cfg.clone();
                let client = shared_client.clone();
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let shared_cfg = proxy.clone();
        let shared_client = http;
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let cfg = shared_cfg.clone();
                let client = shared_client.clone();
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.map(Arc::clone).ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&state), req, next)
    })))
}
2002
2003fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2008 let mut hosts = vec![
2009 "localhost".to_owned(),
2010 "127.0.0.1".to_owned(),
2011 "::1".to_owned(),
2012 ];
2013
2014 if let Some(url) = public_url
2015 && let Ok(uri) = url.parse::<axum::http::Uri>()
2016 && let Some(authority) = uri.authority()
2017 {
2018 let host = authority.host().to_owned();
2019 if !hosts.iter().any(|h| h == &host) {
2020 hosts.push(host);
2021 }
2022
2023 let authority = authority.as_str().to_owned();
2024 if !hosts.iter().any(|h| h == &authority) {
2025 hosts.push(authority);
2026 }
2027 }
2028
2029 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2030 && let Some(authority) = uri.authority()
2031 {
2032 let host = authority.host().to_owned();
2033 if !hosts.iter().any(|h| h == &host) {
2034 hosts.push(host);
2035 }
2036
2037 let authority = authority.as_str().to_owned();
2038 if !hosts.iter().any(|h| h == &authority) {
2039 hosts.push(authority);
2040 }
2041 }
2042
2043 hosts
2044}
2045
2046impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2059 for TlsConnInfo
2060{
2061 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2062 let addr = *target.remote_addr();
2063 let identity = target.io().identity().cloned();
2064 TlsConnInfo::new(addr, identity)
2065 }
2066}
2067
/// Default TLS handshake timeout (10 s).
///
/// NOTE(review): presumably the default for
/// `McpServerConfig::tls_handshake_timeout`; the unit tests expect a 10 s
/// default. Confirm against the config constructor.
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Default cap on concurrently in-flight TLS handshakes.
///
/// NOTE(review): presumably the default for
/// `McpServerConfig::max_concurrent_tls_handshakes`; the unit tests expect
/// 256. Confirm against the config constructor.
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;

/// Capacity of the channel that carries completed TLS streams from the
/// acceptor task to `TlsListener::accept`.
///
/// When the channel is full, finished handshake tasks wait in `send` while
/// still holding their handshake permit.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
2090
/// `axum::serve::Listener` that performs TLS handshakes off the accept path.
///
/// A background task (`run_tls_acceptor`) accepts TCP connections, completes
/// each TLS handshake, and sends the resulting streams through `rx`. `accept`
/// only receives from that channel.
struct TlsListener {
    /// Address of the underlying TCP listener, captured at construction and
    /// returned from `Listener::local_addr`.
    local_addr: SocketAddr,
    /// Completed TLS streams (with any mTLS identity) paired with their peer
    /// addresses, produced by the acceptor task.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Handle to the spawned acceptor task; aborted when the listener is dropped.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2116
2117impl TlsListener {
2118 fn new(
2119 inner: TcpListener,
2120 cert_path: &Path,
2121 key_path: &Path,
2122 mtls_config: Option<&MtlsConfig>,
2123 crl_set: Option<Arc<CrlSet>>,
2124 handshake_timeout: Duration,
2125 max_concurrent_handshakes: usize,
2126 ) -> anyhow::Result<Self> {
2127 rustls::crypto::ring::default_provider()
2129 .install_default()
2130 .ok();
2131
2132 let certs = load_certs(cert_path)?;
2133 let key = load_key(key_path)?;
2134
2135 let mtls_default_role;
2136
2137 let tls_config = if let Some(mtls) = mtls_config {
2138 mtls_default_role = mtls.default_role.clone();
2139 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2140 {
2141 let Some(crl_set) = crl_set else {
2142 return Err(anyhow::anyhow!(
2143 "mTLS CRL verifier requested but CRL state was not initialized"
2144 ));
2145 };
2146 Arc::new(DynamicClientCertVerifier::new(crl_set))
2147 } else {
2148 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2149 if mtls.required {
2150 rustls::server::WebPkiClientVerifier::builder(root_store)
2151 .build()
2152 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2153 } else {
2154 rustls::server::WebPkiClientVerifier::builder(root_store)
2155 .allow_unauthenticated()
2156 .build()
2157 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2158 }
2159 };
2160
2161 tracing::info!(
2162 ca = %mtls.ca_cert_path.display(),
2163 required = mtls.required,
2164 crl_enabled = mtls.crl_enabled,
2165 "mTLS client auth configured"
2166 );
2167
2168 rustls::ServerConfig::builder_with_protocol_versions(&[
2169 &rustls::version::TLS12,
2170 &rustls::version::TLS13,
2171 ])
2172 .with_client_cert_verifier(verifier)
2173 .with_single_cert(certs, key)?
2174 } else {
2175 mtls_default_role = "viewer".to_owned();
2176 rustls::ServerConfig::builder_with_protocol_versions(&[
2177 &rustls::version::TLS12,
2178 &rustls::version::TLS13,
2179 ])
2180 .with_no_client_auth()
2181 .with_single_cert(certs, key)?
2182 };
2183
2184 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2185 tracing::info!(
2186 "TLS enabled (cert: {}, key: {})",
2187 cert_path.display(),
2188 key_path.display()
2189 );
2190 let local_addr = inner.local_addr()?;
2191 let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2192 let acceptor_task = tokio::spawn(run_tls_acceptor(
2193 inner,
2194 acceptor,
2195 mtls_default_role,
2196 tx,
2197 handshake_timeout,
2198 max_concurrent_handshakes,
2199 ));
2200 Ok(Self {
2201 local_addr,
2202 rx,
2203 acceptor_task,
2204 })
2205 }
2206
2207 fn extract_handshake_identity(
2211 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2212 default_role: &str,
2213 addr: SocketAddr,
2214 ) -> Option<AuthIdentity> {
2215 let (_, server_conn) = tls_stream.get_ref();
2216 let cert_der = server_conn.peer_certificates()?.first()?;
2217 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2218 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2219 Some(id)
2220 }
2221}
2222
2223async fn run_tls_acceptor(
2231 listener: TcpListener,
2232 acceptor: tokio_rustls::TlsAcceptor,
2233 default_role: String,
2234 tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2235 handshake_timeout: Duration,
2236 max_concurrent_handshakes: usize,
2237) {
2238 let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2239 loop {
2240 let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2244 return;
2246 };
2247 let (stream, addr) = match listener.accept().await {
2248 Ok(pair) => pair,
2249 Err(e) => {
2250 tracing::debug!("TCP accept error: {e}");
2251 continue;
2252 }
2253 };
2254 if tx.is_closed() {
2255 return;
2257 }
2258 let acceptor = acceptor.clone();
2259 let default_role = default_role.clone();
2260 let tx = tx.clone();
2261 tokio::spawn(async move {
2262 let _permit = permit;
2263 match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2264 Ok(Ok(tls_stream)) => {
2265 let identity =
2266 TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2267 let wrapped = AuthenticatedTlsStream {
2268 inner: tls_stream,
2269 identity,
2270 };
2271 let _ = tx.send((wrapped, addr)).await;
2274 }
2275 Ok(Err(e)) => {
2276 tracing::debug!("TLS handshake failed from {addr}: {e}");
2277 }
2278 Err(_elapsed) => {
2279 tracing::debug!(
2280 "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2281 );
2282 }
2283 }
2284 });
2285 }
2286}
2287
/// TLS server stream paired with the mTLS identity extracted during its
/// handshake.
///
/// Produced by `run_tls_acceptor` and yielded from `TlsListener::accept`. The
/// identity reaches request handling through the `Connected` impl for
/// `TlsConnInfo`.
pub(crate) struct AuthenticatedTlsStream {
    /// Underlying rustls server stream over TCP; all I/O is delegated to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Client identity taken from the peer certificate. `None` when the
    /// client presented no certificate or no identity could be extracted.
    identity: Option<AuthIdentity>,
}
2303
2304impl AuthenticatedTlsStream {
2305 #[must_use]
2307 pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
2308 self.identity.as_ref()
2309 }
2310}
2311
2312impl std::fmt::Debug for AuthenticatedTlsStream {
2313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2314 f.debug_struct("AuthenticatedTlsStream")
2315 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2316 .finish_non_exhaustive()
2317 }
2318}
2319
2320impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2321 fn poll_read(
2322 mut self: Pin<&mut Self>,
2323 cx: &mut std::task::Context<'_>,
2324 buf: &mut tokio::io::ReadBuf<'_>,
2325 ) -> std::task::Poll<std::io::Result<()>> {
2326 Pin::new(&mut self.inner).poll_read(cx, buf)
2327 }
2328}
2329
2330impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2331 fn poll_write(
2332 mut self: Pin<&mut Self>,
2333 cx: &mut std::task::Context<'_>,
2334 buf: &[u8],
2335 ) -> std::task::Poll<std::io::Result<usize>> {
2336 Pin::new(&mut self.inner).poll_write(cx, buf)
2337 }
2338
2339 fn poll_flush(
2340 mut self: Pin<&mut Self>,
2341 cx: &mut std::task::Context<'_>,
2342 ) -> std::task::Poll<std::io::Result<()>> {
2343 Pin::new(&mut self.inner).poll_flush(cx)
2344 }
2345
2346 fn poll_shutdown(
2347 mut self: Pin<&mut Self>,
2348 cx: &mut std::task::Context<'_>,
2349 ) -> std::task::Poll<std::io::Result<()>> {
2350 Pin::new(&mut self.inner).poll_shutdown(cx)
2351 }
2352
2353 fn poll_write_vectored(
2354 mut self: Pin<&mut Self>,
2355 cx: &mut std::task::Context<'_>,
2356 bufs: &[std::io::IoSlice<'_>],
2357 ) -> std::task::Poll<std::io::Result<usize>> {
2358 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2359 }
2360
2361 fn is_write_vectored(&self) -> bool {
2362 self.inner.is_write_vectored()
2363 }
2364}
2365
2366impl axum::serve::Listener for TlsListener {
2367 type Io = AuthenticatedTlsStream;
2368 type Addr = SocketAddr;
2369
2370 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2376 if let Some(pair) = self.rx.recv().await {
2377 return pair;
2378 }
2379 tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2385 std::future::pending().await
2386 }
2387
2388 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2389 Ok(self.local_addr)
2390 }
2391}
2392
2393impl Drop for TlsListener {
2394 fn drop(&mut self) {
2395 self.acceptor_task.abort();
2398 }
2399}
2400
2401fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2402 use rustls::pki_types::pem::PemObject;
2403 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2404 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2405 .collect::<Result<_, _>>()
2406 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2407 anyhow::ensure!(
2408 !certs.is_empty(),
2409 "no certificates found in {}",
2410 path.display()
2411 );
2412 Ok(certs)
2413}
2414
2415fn load_client_auth_roots(
2416 path: &Path,
2417) -> anyhow::Result<(
2418 Vec<rustls::pki_types::CertificateDer<'static>>,
2419 Arc<RootCertStore>,
2420)> {
2421 let ca_certs = load_certs(path)?;
2422 let mut root_store = RootCertStore::empty();
2423 for cert in &ca_certs {
2424 root_store
2425 .add(cert.clone())
2426 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2427 }
2428
2429 Ok((ca_certs, Arc::new(root_store)))
2430}
2431
2432fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2433 use rustls::pki_types::pem::PemObject;
2434 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2435 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2436}
2437
2438#[allow(
2439 clippy::unused_async,
2440 reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2441)]
2442async fn healthz() -> impl IntoResponse {
2443 axum::Json(serde_json::json!({
2444 "status": "ok",
2445 }))
2446}
2447
2448fn version_payload(name: &str, version: &str) -> serde_json::Value {
2455 serde_json::json!({
2456 "name": name,
2457 "version": version,
2458 "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2459 "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2460 "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2461 "mcpx_version": env!("CARGO_PKG_VERSION"),
2462 })
2463}
2464
2465fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2475 let value = version_payload(name, version);
2476 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2477}
2478
2479async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2480 let status = check().await;
2481 let ready = status
2482 .get("ready")
2483 .and_then(serde_json::Value::as_bool)
2484 .unwrap_or(false);
2485 let code = if ready {
2486 axum::http::StatusCode::OK
2487 } else {
2488 axum::http::StatusCode::SERVICE_UNAVAILABLE
2489 };
2490 (code, axum::Json(status))
2491}
2492
2493async fn shutdown_signal() {
2497 let ctrl_c = tokio::signal::ctrl_c();
2498
2499 #[cfg(unix)]
2500 {
2501 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2502 Ok(mut term) => {
2503 tokio::select! {
2504 _ = ctrl_c => {}
2505 _ = term.recv() => {}
2506 }
2507 }
2508 Err(e) => {
2509 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2510 ctrl_c.await.ok();
2511 }
2512 }
2513 }
2514
2515 #[cfg(not(unix))]
2516 {
2517 ctrl_c.await.ok();
2518 }
2519}
2520
/// Path label used for requests that matched no route, so that arbitrary
/// request paths cannot create new metric series.
#[cfg(feature = "metrics")]
const UNMATCHED_PATH_LABEL: &str = "unmatched";

/// Record request count and latency for every HTTP request.
///
/// Label values are kept to a bounded set so remote clients cannot inflate
/// metric cardinality (and memory):
/// * `path` is the matched route template (`MatchedPath`) or `"unmatched"`;
/// * `method` is the standard method name, or `"OTHER"` for extension methods;
/// * `status` is the numeric response status code.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    let method = metrics_method_label(req.method().as_str()).to_owned();
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or(UNMATCHED_PATH_LABEL, |matched| matched.as_str())
        .to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}

/// Map an HTTP method name to a bounded metric label: the RFC 9110 / RFC 5789
/// method names pass through unchanged, and anything else becomes `"OTHER"`.
#[cfg(feature = "metrics")]
fn metrics_method_label(method: &str) -> &str {
    const STANDARD_METHODS: [&str; 9] = [
        "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH",
    ];
    if STANDARD_METHODS.contains(&method) {
        method
    } else {
        "OTHER"
    }
}
2552
2553async fn security_headers_middleware(
2565 is_tls: bool,
2566 cfg: Arc<SecurityHeadersConfig>,
2567 req: Request<Body>,
2568 next: Next,
2569) -> axum::response::Response {
2570 use axum::http::{HeaderName, header};
2571
2572 let mut resp = next.run(req).await;
2573 let headers = resp.headers_mut();
2574
2575 headers.remove(header::SERVER);
2577 headers.remove(HeaderName::from_static("x-powered-by"));
2578
2579 apply_security_header(
2580 headers,
2581 header::X_CONTENT_TYPE_OPTIONS,
2582 cfg.x_content_type_options.as_deref(),
2583 "nosniff",
2584 );
2585 apply_security_header(
2586 headers,
2587 header::X_FRAME_OPTIONS,
2588 cfg.x_frame_options.as_deref(),
2589 "deny",
2590 );
2591 apply_security_header(
2592 headers,
2593 header::CACHE_CONTROL,
2594 cfg.cache_control.as_deref(),
2595 "no-store, max-age=0",
2596 );
2597 apply_security_header(
2598 headers,
2599 header::REFERRER_POLICY,
2600 cfg.referrer_policy.as_deref(),
2601 "no-referrer",
2602 );
2603 apply_security_header(
2604 headers,
2605 HeaderName::from_static("cross-origin-opener-policy"),
2606 cfg.cross_origin_opener_policy.as_deref(),
2607 "same-origin",
2608 );
2609 apply_security_header(
2610 headers,
2611 HeaderName::from_static("cross-origin-resource-policy"),
2612 cfg.cross_origin_resource_policy.as_deref(),
2613 "same-origin",
2614 );
2615 apply_security_header(
2616 headers,
2617 HeaderName::from_static("cross-origin-embedder-policy"),
2618 cfg.cross_origin_embedder_policy.as_deref(),
2619 "require-corp",
2620 );
2621 apply_security_header(
2622 headers,
2623 HeaderName::from_static("permissions-policy"),
2624 cfg.permissions_policy.as_deref(),
2625 "accelerometer=(), camera=(), geolocation=(), microphone=()",
2626 );
2627 apply_security_header(
2628 headers,
2629 HeaderName::from_static("x-permitted-cross-domain-policies"),
2630 cfg.x_permitted_cross_domain_policies.as_deref(),
2631 "none",
2632 );
2633 apply_security_header(
2634 headers,
2635 HeaderName::from_static("content-security-policy"),
2636 cfg.content_security_policy.as_deref(),
2637 "default-src 'none'; frame-ancestors 'none'",
2638 );
2639 apply_security_header(
2640 headers,
2641 HeaderName::from_static("x-dns-prefetch-control"),
2642 cfg.x_dns_prefetch_control.as_deref(),
2643 "off",
2644 );
2645
2646 if is_tls {
2647 apply_security_header(
2648 headers,
2649 header::STRICT_TRANSPORT_SECURITY,
2650 cfg.strict_transport_security.as_deref(),
2651 "max-age=63072000; includeSubDomains",
2652 );
2653 }
2654
2655 resp
2656}
2657
2658fn apply_security_header(
2669 headers: &mut axum::http::HeaderMap,
2670 name: axum::http::HeaderName,
2671 override_value: Option<&str>,
2672 default: &'static str,
2673) {
2674 use axum::http::HeaderValue;
2675
2676 match override_value {
2677 None => {
2678 headers.insert(name, HeaderValue::from_static(default));
2679 }
2680 Some("") => {
2681 }
2683 Some(v) => match HeaderValue::from_str(v) {
2684 Ok(hv) => {
2685 headers.insert(name, hv);
2686 }
2687 Err(err) => {
2688 tracing::error!(
2689 header = %name,
2690 error = %err,
2691 "invalid security header override reached middleware; using default"
2692 );
2693 headers.insert(name, HeaderValue::from_static(default));
2694 }
2695 },
2696 }
2697}
2698
2699fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2710 use axum::http::HeaderValue;
2711
2712 let fields: &[(&str, Option<&str>)] = &[
2713 (
2714 "x_content_type_options",
2715 cfg.x_content_type_options.as_deref(),
2716 ),
2717 ("x_frame_options", cfg.x_frame_options.as_deref()),
2718 ("cache_control", cfg.cache_control.as_deref()),
2719 ("referrer_policy", cfg.referrer_policy.as_deref()),
2720 (
2721 "cross_origin_opener_policy",
2722 cfg.cross_origin_opener_policy.as_deref(),
2723 ),
2724 (
2725 "cross_origin_resource_policy",
2726 cfg.cross_origin_resource_policy.as_deref(),
2727 ),
2728 (
2729 "cross_origin_embedder_policy",
2730 cfg.cross_origin_embedder_policy.as_deref(),
2731 ),
2732 ("permissions_policy", cfg.permissions_policy.as_deref()),
2733 (
2734 "x_permitted_cross_domain_policies",
2735 cfg.x_permitted_cross_domain_policies.as_deref(),
2736 ),
2737 (
2738 "content_security_policy",
2739 cfg.content_security_policy.as_deref(),
2740 ),
2741 (
2742 "x_dns_prefetch_control",
2743 cfg.x_dns_prefetch_control.as_deref(),
2744 ),
2745 (
2746 "strict_transport_security",
2747 cfg.strict_transport_security.as_deref(),
2748 ),
2749 ];
2750
2751 for (field, value) in fields {
2752 let Some(v) = value else { continue };
2753 if v.is_empty() {
2754 continue;
2755 }
2756 if let Err(err) = HeaderValue::from_str(v) {
2757 return Err(McpxError::Config(format!(
2758 "invalid security_headers.{field}: {err}"
2759 )));
2760 }
2761 }
2762
2763 if let Some(v) = cfg.strict_transport_security.as_deref()
2764 && !v.is_empty()
2765 && v.to_ascii_lowercase().contains("preload")
2766 {
2767 return Err(McpxError::Config(format!(
2768 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2769 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2770 )));
2771 }
2772
2773 Ok(())
2774}
2775
/// Apply the cache headers that RFC 6749 §5.1 mandates for responses carrying
/// tokens or credentials.
///
/// Sets `Cache-Control: no-store` and `Pragma: no-cache`, and appends
/// `Vary: Authorization`. `Cache-Control` is set here explicitly rather than
/// relying only on `security_headers_middleware`, whose `cache_control`
/// setting can be disabled with an empty override.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut resp = next.run(req).await;
    let headers = resp.headers_mut();
    headers.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store"));
    headers.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
    headers.append(header::VARY, HeaderValue::from_static("Authorization"));
    resp
}
2803
2804async fn normalize_peer_addr_middleware(
2828 mut req: Request<Body>,
2829 next: Next,
2830) -> axum::response::Response {
2831 let direct = req
2832 .extensions()
2833 .get::<ConnectInfo<SocketAddr>>()
2834 .map(|ci| ci.0);
2835 let from_tls = req
2836 .extensions()
2837 .get::<ConnectInfo<TlsConnInfo>>()
2838 .map(|ci| ci.0.addr);
2839 if let Some(addr) = direct.or(from_tls) {
2840 if direct.is_none() {
2841 req.extensions_mut().insert(ConnectInfo(addr));
2842 }
2843 req.extensions_mut().insert(PeerAddr::new(addr));
2844 }
2845 next.run(req).await
2846}
2847
2848async fn origin_check_middleware(
2852 allowed: Arc<[String]>,
2853 log_request_headers: bool,
2854 req: Request<Body>,
2855 next: Next,
2856) -> axum::response::Response {
2857 let method = req.method().clone();
2858 let path = req.uri().path().to_owned();
2859
2860 log_incoming_request(&method, &path, req.headers(), log_request_headers);
2861
2862 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2863 let origin_str = origin.to_str().unwrap_or("");
2864 if !allowed.iter().any(|a| a == origin_str) {
2865 tracing::warn!(
2866 origin = origin_str,
2867 %method,
2868 %path,
2869 allowed = ?&*allowed,
2870 "rejected request: Origin not allowed"
2871 );
2872 return (
2873 axum::http::StatusCode::FORBIDDEN,
2874 "Forbidden: Origin not allowed",
2875 )
2876 .into_response();
2877 }
2878 }
2879 next.run(req).await
2880}
2881
/// Emit a debug-level `"incoming request"` event with the method and path.
///
/// When `log_request_headers` is true, the event also carries the request
/// headers rendered by `format_request_headers_for_log`, which redacts
/// credential-bearing headers.
fn log_incoming_request(
    method: &axum::http::Method,
    path: &str,
    headers: &axum::http::HeaderMap,
    log_request_headers: bool,
) {
    if log_request_headers {
        // The field expression is only evaluated when the debug event is
        // enabled, so headers are not formatted otherwise.
        tracing::debug!(
            %method,
            %path,
            headers = %format_request_headers_for_log(headers),
            "incoming request"
        );
    } else {
        tracing::debug!(%method, %path, "incoming request");
    }
}
2901
2902fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2903 headers
2904 .iter()
2905 .map(|(k, v)| {
2906 let name = k.as_str();
2907 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2908 format!("{name}: [REDACTED]")
2909 } else {
2910 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2911 }
2912 })
2913 .collect::<Vec<_>>()
2914 .join(", ")
2915}
2916
2917#[allow(
2941 clippy::cognitive_complexity,
2942 reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
2943)]
2944pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2945where
2946 H: ServerHandler + 'static,
2947{
2948 use rmcp::ServiceExt as _;
2949
2950 tracing::info!("stdio transport: serving on stdin/stdout");
2951 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2952
2953 let transport = rmcp::transport::io::stdio();
2954
2955 let service = handler
2956 .serve(transport)
2957 .await
2958 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2959
2960 if let Err(e) = service.waiting().await {
2961 tracing::warn!(error = %e, "stdio session ended with error");
2962 }
2963 tracing::info!("stdio session ended");
2964 Ok(())
2965}
2966
// Unit tests for this module: server-config defaults and validation, allowed
// Host derivation, health/readiness handlers, peer-address normalization,
// Origin checks, header-log redaction, security headers, OAuth token cache
// headers (feature `oauth`), and the TLS handshake timeout.
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        deprecated,
        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
    )]
    use std::{sync::Arc, time::Duration};

    use axum::{
        body::Body,
        http::{Request, StatusCode, header},
        response::IntoResponse,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt as _;

    use super::*;

    /// Pins the values `McpServerConfig::new` assigns to every optional setting.
    #[test]
    fn server_config_new_defaults() {
        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
        assert_eq!(cfg.name, "test-server");
        assert_eq!(cfg.version, "1.0.0");
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert!(cfg.auth.is_none());
        assert!(cfg.rbac.is_none());
        assert!(cfg.allowed_origins.is_empty());
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.readiness_check.is_none());
        assert_eq!(cfg.max_request_body, 1024 * 1024);
        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
        assert!(!cfg.log_request_headers);
        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
    }

    /// The TLS handshake builder methods store their arguments on the config.
    #[test]
    fn tls_handshake_builders_set_fields() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
            .with_tls_handshake_timeout(Duration::from_secs(3))
            .with_max_concurrent_tls_handshakes(64);
        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
        assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
    }

    /// A zero handshake timeout is rejected, and the error names the field.
    #[test]
    fn validate_rejects_zero_tls_handshake_timeout() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
            .with_tls_handshake_timeout(Duration::ZERO);
        let err = cfg.validate().expect_err("zero handshake timeout");
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    /// A zero handshake concurrency cap is rejected, and the error names the field.
    #[test]
    fn validate_rejects_zero_max_concurrent_tls_handshakes() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
            .with_max_concurrent_tls_handshakes(0);
        let err = cfg.validate().expect_err("zero handshake concurrency");
        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
    }

    /// `validate` returns a wrapper exposing the config through `as_inner` and
    /// `into_inner`; a zero `max_request_body` fails validation.
    #[test]
    fn validate_consumes_and_proves() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        let validated = cfg.validate().expect("valid config");
        assert_eq!(validated.as_inner().name, "test-server");
        let raw = validated.into_inner();
        assert_eq!(raw.name, "test-server");

        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        bad.max_request_body = 0;
        assert!(bad.validate().is_err(), "zero body cap must fail validate");
    }

    /// A zero request concurrency cap is rejected, and the error names the field.
    #[test]
    fn validate_rejects_zero_max_concurrent_requests() {
        let cfg =
            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
        let err = cfg.validate().expect_err("zero concurrency cap must fail");
        assert!(
            format!("{err}").contains("max_concurrent_requests"),
            "error should mention max_concurrent_requests, got: {err}"
        );
    }

    /// An auth rate limiter configured to track zero keys is rejected.
    #[test]
    fn validate_rejects_zero_max_tracked_keys() {
        let rl = crate::auth::RateLimitConfig {
            max_attempts_per_minute: 30,
            pre_auth_max_per_minute: None,
            max_tracked_keys: 0,
            idle_eviction: Duration::from_secs(15 * 60),
        };
        let auth_cfg = AuthConfig {
            enabled: true,
            api_keys: Vec::new(),
            mtls: None,
            rate_limit: Some(rl),
            #[cfg(feature = "oauth")]
            oauth: None,
        };
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
        assert!(
            format!("{err}").contains("max_tracked_keys"),
            "error should mention max_tracked_keys, got: {err}"
        );
    }

    /// The host component of `public_url` appears in the allowed Host list.
    #[test]
    fn derive_allowed_hosts_includes_public_host() {
        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
        assert!(
            hosts.iter().any(|h| h == "mcp.example.com"),
            "public_url host must be allowed"
        );
    }

    /// Both the bare bind host and the full `host:port` bind authority are allowed.
    #[test]
    fn derive_allowed_hosts_includes_bind_authority() {
        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1"),
            "bind host must be allowed"
        );
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1:8080"),
            "bind authority must be allowed"
        );
    }

    /// `healthz` answers 200 with `status: ok` and does not leak name or version.
    #[tokio::test]
    async fn healthz_returns_ok_json() {
        let resp = healthz().await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
        assert!(
            json.get("name").is_none(),
            "healthz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "healthz must not expose version"
        );
    }

    /// A check reporting `ready: true` yields 200 and passes its extra fields
    /// through, without adding name or version.
    #[tokio::test]
    async fn readyz_returns_ok_when_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["ready"], true);
        assert!(
            json.get("name").is_none(),
            "readyz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "readyz must not expose version"
        );
        assert_eq!(json["db"], "connected");
    }

    /// A check reporting `ready: false` yields 503.
    #[tokio::test]
    async fn readyz_returns_503_when_not_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    /// A payload without a `ready` key is treated as not ready (503).
    #[tokio::test]
    async fn readyz_returns_503_when_ready_missing() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    /// Router whose `/probe` handler echoes `"<ConnectInfo>|<PeerAddr>"` as seen
    /// after `normalize_peer_addr_middleware`; a missing extension renders as "".
    fn peer_probe_router() -> axum::Router {
        async fn probe(req: Request<Body>) -> String {
            let ci = req
                .extensions()
                .get::<ConnectInfo<SocketAddr>>()
                .map(|c| c.0.to_string())
                .unwrap_or_default();
            let pa = req
                .extensions()
                .get::<PeerAddr>()
                .map(|p| p.addr.to_string())
                .unwrap_or_default();
            format!("{ci}|{pa}")
        }
        axum::Router::new()
            .route("/probe", axum::routing::get(probe))
            .layer(axum::middleware::from_fn(normalize_peer_addr_middleware))
    }

    /// Collects a response body into a UTF-8 `String`.
    async fn body_string(resp: axum::response::Response) -> String {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// An existing plain `ConnectInfo<SocketAddr>` takes precedence over the TLS
    /// one and is mirrored into `PeerAddr`.
    #[tokio::test]
    async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
        let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
        let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
        let req = Request::builder()
            .uri("/probe")
            .extension(ConnectInfo(plain))
            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
            .body(Body::empty())
            .unwrap();
        let resp = peer_probe_router().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
    }

    /// With only TLS connect info present, both `ConnectInfo<SocketAddr>` and
    /// `PeerAddr` are derived from it.
    #[tokio::test]
    async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
        let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
        let req = Request::builder()
            .uri("/probe")
            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
            .body(Body::empty())
            .unwrap();
        let resp = peer_probe_router().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
    }

    /// Without any connect info, neither extension is inserted.
    #[tokio::test]
    async fn normalize_no_op_without_any_connect_info() {
        let req = Request::builder()
            .uri("/probe")
            .body(Body::empty())
            .unwrap();
        let resp = peer_probe_router().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, "|");
    }

    /// The `PeerAddr` extractor rejects with 500 when the extension is absent.
    #[tokio::test]
    async fn peer_addr_extractor_rejects_when_absent() {
        async fn h(peer: PeerAddr) -> String {
            peer.addr.to_string()
        }
        let app = axum::Router::new().route("/p", axum::routing::get(h));
        let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// The `PeerAddr` extractor yields the extension's address when present.
    #[tokio::test]
    async fn peer_addr_extractor_returns_value_when_present() {
        async fn h(peer: PeerAddr) -> String {
            peer.addr.to_string()
        }
        let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
        let app = axum::Router::new().route("/p", axum::routing::get(h));
        let req = Request::builder()
            .uri("/p")
            .extension(PeerAddr::new(addr))
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, addr.to_string());
    }

    /// `PeerAddr` is also reachable through `axum::Extension`.
    #[tokio::test]
    async fn peer_addr_via_extension_extractor() {
        async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
            peer.addr.to_string()
        }
        let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
        let app = axum::Router::new().route("/p", axum::routing::get(h));
        let req = Request::builder()
            .uri("/p")
            .extension(PeerAddr::new(addr))
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, addr.to_string());
    }

    /// Router with a `/test` route guarded by `origin_check_middleware` using
    /// the given allow-list.
    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
        let allowed: Arc<[String]> = Arc::from(origins);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let a = Arc::clone(&allowed);
                origin_check_middleware(a, log_request_headers, req, next)
            }))
    }

    /// An Origin on the allow-list is let through.
    #[tokio::test]
    async fn origin_allowed_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://localhost:3000")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// An Origin not on the allow-list is rejected with 403.
    #[tokio::test]
    async fn origin_rejected_returns_403() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    /// Requests without an Origin header are not subject to the check.
    #[tokio::test]
    async fn no_origin_header_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// An empty allow-list rejects every request that carries an Origin.
    #[tokio::test]
    async fn empty_allowlist_rejects_any_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://anything.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    /// An empty allow-list still admits requests without an Origin header.
    #[tokio::test]
    async fn empty_allowlist_passes_without_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// Every credential-bearing header is redacted; other headers are logged verbatim.
    #[test]
    fn format_request_headers_redacts_sensitive_values() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
        headers.insert("cookie", "sid=abc".parse().unwrap());
        headers.insert("proxy-authorization", "Basic proxy-secret".parse().unwrap());
        headers.insert("x-request-id", "req-123".parse().unwrap());

        let out = format_request_headers_for_log(&headers);
        assert!(out.contains("authorization: [REDACTED]"));
        assert!(out.contains("cookie: [REDACTED]"));
        assert!(out.contains("proxy-authorization: [REDACTED]"));
        assert!(out.contains("x-request-id: req-123"));
        assert!(!out.contains("secret-token"));
        assert!(!out.contains("sid=abc"));
        assert!(!out.contains("proxy-secret"));
    }

    /// Router wrapped in `security_headers_middleware` with the default config.
    fn security_router(is_tls: bool) -> axum::Router {
        security_router_with(is_tls, SecurityHeadersConfig::default())
    }

    /// Router wrapped in `security_headers_middleware` with a caller-supplied config.
    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
        let cfg = Arc::new(cfg);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let c = Arc::clone(&cfg);
                security_headers_middleware(is_tls, c, req, next)
            }))
    }

    /// The default config emits the full hardening header set, and no HSTS on plaintext.
    #[tokio::test]
    async fn security_headers_set_on_response() {
        let app = security_router(false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let h = resp.headers();
        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
        assert_eq!(
            h.get("cross-origin-resource-policy").unwrap(),
            "same-origin"
        );
        assert_eq!(
            h.get("cross-origin-embedder-policy").unwrap(),
            "require-corp"
        );
        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
        assert!(
            h.get("permissions-policy")
                .unwrap()
                .to_str()
                .unwrap()
                .contains("camera=()"),
            "permissions-policy must restrict browser features"
        );
        assert_eq!(
            h.get("content-security-policy").unwrap(),
            "default-src 'none'; frame-ancestors 'none'"
        );
        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
        assert!(h.get("strict-transport-security").is_none());
    }

    /// With TLS enabled, the default HSTS header carries a two-year max-age.
    #[tokio::test]
    async fn hsts_set_when_tls_enabled() {
        let app = security_router(true);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let hsts = resp.headers().get("strict-transport-security").unwrap();
        assert!(
            hsts.to_str().unwrap().contains("max-age=63072000"),
            "HSTS must set 2-year max-age"
        );
    }

    /// Runs `McpServerConfig::check` on a config carrying `headers`.
    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
        let cfg =
            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
        cfg.check()
    }

    /// The default security-header config passes validation.
    #[test]
    fn security_headers_config_default_validates() {
        check_with_security_headers(SecurityHeadersConfig::default())
            .expect("default SecurityHeadersConfig must validate");
    }

    /// `Some("")` on every field (omit every header) passes validation.
    #[test]
    fn security_headers_config_validate_accepts_empty_string() {
        let h = SecurityHeadersConfig {
            x_content_type_options: Some(String::new()),
            x_frame_options: Some(String::new()),
            cache_control: Some(String::new()),
            referrer_policy: Some(String::new()),
            cross_origin_opener_policy: Some(String::new()),
            cross_origin_resource_policy: Some(String::new()),
            cross_origin_embedder_policy: Some(String::new()),
            permissions_policy: Some(String::new()),
            x_permitted_cross_domain_policies: Some(String::new()),
            content_security_policy: Some(String::new()),
            x_dns_prefetch_control: Some(String::new()),
            strict_transport_security: Some(String::new()),
        };
        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
    }

    /// A control character in a header value is rejected, and the error names the field.
    #[test]
    fn security_headers_config_validate_rejects_bad_value() {
        let h = SecurityHeadersConfig {
            referrer_policy: Some("\u{0007}".into()),
            ..SecurityHeadersConfig::default()
        };
        let err = check_with_security_headers(h)
            .expect_err("control char in referrer_policy must reject");
        let msg = err.to_string();
        assert!(
            msg.contains("referrer_policy"),
            "error must name the offending field, got: {msg}"
        );
    }

    /// An HSTS override containing `preload` is rejected with a descriptive error.
    #[test]
    fn security_headers_config_validate_rejects_hsts_preload() {
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
            ..SecurityHeadersConfig::default()
        };
        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
        let msg = err.to_string();
        assert!(
            msg.contains("strict_transport_security"),
            "error must name the field, got: {msg}"
        );
        assert!(
            msg.to_lowercase().contains("preload"),
            "error must mention `preload`, got: {msg}"
        );
    }

    /// The HSTS `preload` check ignores case.
    #[test]
    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=600; PRELOAD".into()),
            ..SecurityHeadersConfig::default()
        };
        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
    }

    /// A non-empty override replaces the default header value.
    #[tokio::test]
    async fn security_headers_override_honored() {
        let h = SecurityHeadersConfig {
            x_frame_options: Some("SAMEORIGIN".into()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let xfo = resp.headers().get("x-frame-options").unwrap();
        assert_eq!(xfo, "SAMEORIGIN");
    }

    /// `Some("")` omits only that header; the other defaults are still emitted.
    #[tokio::test]
    async fn security_headers_empty_string_omits() {
        let h = SecurityHeadersConfig {
            referrer_policy: Some(String::new()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        assert!(
            resp.headers().get("referrer-policy").is_none(),
            "Some(\"\") must omit the header"
        );
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    /// An explicit HSTS override is still suppressed on plaintext listeners.
    #[tokio::test]
    async fn security_headers_hsts_only_when_tls() {
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=600".into()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert!(
            resp.headers().get("strict-transport-security").is_none(),
            "HSTS must remain absent on plaintext deployments even with override"
        );
    }

    /// Token responses get `Pragma: no-cache` and a `Vary` that includes `Authorization`.
    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn oauth_token_cache_headers_set_pragma_and_vary() {
        let app = axum::Router::new()
            .route("/token", axum::routing::post(|| async { "{}" }))
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            ));
        let req = Request::builder()
            .method("POST")
            .uri("/token")
            .body(Body::from("{}"))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let h = resp.headers();
        assert_eq!(
            h.get("pragma").unwrap(),
            "no-cache",
            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
        );
        let vary_values: Vec<String> = h
            .get_all("vary")
            .iter()
            .filter_map(|v| v.to_str().ok().map(str::to_owned))
            .collect();
        assert!(
            vary_values
                .iter()
                .any(|v| v.eq_ignore_ascii_case("Authorization")),
            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
        );
    }

    /// A `Vary` value already set by the handler is kept, and `Authorization` is added.
    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn oauth_token_cache_headers_preserve_existing_vary() {
        let app = axum::Router::new()
            .route(
                "/token",
                axum::routing::post(|| async {
                    axum::response::Response::builder()
                        .header("vary", "Accept-Encoding")
                        .body(axum::body::Body::from("{}"))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            ));
        let req = Request::builder()
            .method("POST")
            .uri("/token")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let vary: Vec<String> = resp
            .headers()
            .get_all("vary")
            .iter()
            .filter_map(|v| v.to_str().ok().map(str::to_owned))
            .collect();
        assert!(
            vary.iter().any(|v| v.contains("Accept-Encoding")),
            "must preserve pre-existing Vary value, got {vary:?}"
        );
        assert!(
            vary.iter().any(|v| v.contains("Authorization")),
            "must append Authorization to Vary, got {vary:?}"
        );
    }

    /// The version payload carries name, version, and string-valued build metadata.
    #[test]
    fn version_payload_contains_expected_fields() {
        let v = version_payload("my-server", "1.2.3");
        assert_eq!(v["name"], "my-server");
        assert_eq!(v["version"], "1.2.3");
        assert!(v["build_git_sha"].is_string());
        assert!(v["build_timestamp"].is_string());
        assert!(v["rust_version"].is_string());
        assert!(v["mcpx_version"].is_string());
    }

    /// The load-shed + concurrency-limit layer stack composes with an axum router
    /// and serves a request while under the limit.
    #[tokio::test]
    async fn concurrency_limit_layer_composes_and_serves() {
        let app = axum::Router::new()
            .route("/ok", axum::routing::get(|| async { "ok" }))
            .layer(
                tower::ServiceBuilder::new()
                    .layer(axum::error_handling::HandleErrorLayer::new(
                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
                    ))
                    .layer(tower::load_shed::LoadShedLayer::new())
                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
            );
        let resp = app
            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// With the size predicate, a 4 KiB body is gzip-encoded when the client accepts gzip.
    #[tokio::test]
    async fn compression_layer_gzip_encodes_response() {
        use tower_http::compression::Predicate as _;

        let big_body = "a".repeat(4096);
        let app = axum::Router::new()
            .route(
                "/big",
                axum::routing::get(move || {
                    let body = big_body.clone();
                    async move { body }
                }),
            )
            .layer(
                tower_http::compression::CompressionLayer::new()
                    .gzip(true)
                    .br(true)
                    .compress_when(
                        tower_http::compression::DefaultPredicate::new()
                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
                    ),
            );

        let req = Request::builder()
            .uri("/big")
            .header(header::ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }

    /// A client that opens TCP but never sends a ClientHello is disconnected
    /// once the listener's handshake timeout elapses.
    #[tokio::test]
    async fn tls_handshake_timeout_reaps_idle_connections() {
        use tokio::io::AsyncReadExt as _;

        /// Removes the per-test certificate directory on drop. Because it runs
        /// during unwinding too, a failing assertion does not leave the
        /// generated private key behind in the system temp directory.
        struct TempDirGuard(std::path::PathBuf);

        impl Drop for TempDirGuard {
            fn drop(&mut self) {
                // Best-effort: a cleanup failure must not mask the test outcome.
                let _ = std::fs::remove_dir_all(&self.0);
            }
        }

        let _ = rustls::crypto::ring::default_provider().install_default();

        let key = rcgen::KeyPair::generate().expect("generate key");
        let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
            .expect("cert params")
            .self_signed(&key)
            .expect("self-signed cert");
        let dir = std::env::temp_dir().join(format!(
            "rmcp-server-kit-hs-timeout-{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("clock after epoch")
                .as_nanos()
        ));
        tokio::fs::create_dir_all(&dir).await.expect("temp dir");
        // Declared before `tls`, so it is dropped after the listener.
        let _cleanup = TempDirGuard(dir.clone());
        let cert_path = dir.join("server.crt");
        let key_path = dir.join("server.key");
        tokio::fs::write(&cert_path, cert.pem())
            .await
            .expect("write cert");
        tokio::fs::write(&key_path, key.serialize_pem())
            .await
            .expect("write key");

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let tls = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            None,
            None,
            Duration::from_millis(200),
            // NOTE(review): presumably the handshake concurrency cap; confirm
            // against `TlsListener::new`.
            8,
        )
        .expect("tls listener");
        let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");

        // Connect at the TCP level only and never write, so the server-side
        // handshake stalls until the 200 ms timeout fires.
        let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
        let mut buf = [0_u8; 16];
        let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
            .await
            .expect("server must reap the idle handshake within its timeout");
        match read {
            // EOF or an I/O error (e.g. a reset) both mean the server dropped
            // the stalled connection.
            Ok(0) | Err(_) => {}
            Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
        }

        drop(tls);
    }
}