1use std::{
2 future::Future,
3 net::{IpAddr, SocketAddr},
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12 body::Body,
13 extract::{ConnectInfo, Request},
14 middleware::Next,
15 response::IntoResponse,
16};
17use rmcp::{
18 ServerHandler,
19 transport::streamable_http_server::{
20 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21 },
22};
23use rustls::RootCertStore;
24use tokio::{
25 net::TcpListener,
26 sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31 auth::{
32 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33 build_rate_limiter, extract_mtls_identity,
34 },
35 bounded_limiter::BoundedKeyedLimiter,
36 error::McpxError,
37 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
38 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
39};
40
41#[allow(
45 clippy::needless_pass_by_value,
46 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
47)]
48fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
49 McpxError::Startup(format!("{e:#}"))
50}
51
52#[allow(
58 clippy::needless_pass_by_value,
59 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
60)]
61fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
62 McpxError::Startup(format!("{op}: {e}"))
63}
64
/// Async readiness probe backing the `/readyz` route.
///
/// Each invocation returns a boxed, `Send` future that resolves to a JSON
/// value. When no check is configured, `/readyz` is served by the same handler
/// as `/healthz`. NOTE(review): how the returned JSON maps to an HTTP status is
/// decided by the `readyz` handler, which is not visible here — confirm there.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
71
/// Remote socket address of the connection a request arrived on.
///
/// Usable as an axum extractor (see its `FromRequestParts` impl). Extraction
/// reads a `PeerAddr` from the request extensions and fails with
/// `500 Internal Server Error` when none is present.
/// NOTE(review): the extension is presumably inserted by the server's
/// connection/peer-address middleware (`normalize_peer_addr_middleware`) —
/// confirm where it is populated.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PeerAddr {
    /// The peer's socket address (IP and port).
    pub addr: SocketAddr,
}
122
123impl PeerAddr {
124 #[must_use]
127 pub(crate) const fn new(addr: SocketAddr) -> Self {
128 Self { addr }
129 }
130}
131
132impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
141 type Rejection = (axum::http::StatusCode, &'static str);
142
143 async fn from_request_parts(
144 parts: &mut axum::http::request::Parts,
145 _state: &S,
146 ) -> Result<Self, Self::Rejection> {
147 parts.extensions.get::<Self>().copied().ok_or((
148 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
149 "peer address unavailable: not running under rmcp-server-kit serve()",
150 ))
151 }
152}
153
/// Settings for the HTTP security response headers emitted by the server.
///
/// Every field is an optional header value; `Default` leaves all of them
/// `None`. The struct is checked by `validate_security_headers` during
/// [`McpServerConfig::validate`] and consumed by the security-headers
/// middleware. NOTE(review): whether `None` means "use a built-in default" or
/// "omit the header" is decided by that middleware, which is not visible
/// here — confirm before relying on either reading.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` header value.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` header value.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` header value.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` header value.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` header value.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` header value.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` header value.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` header value.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` header value.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` header value.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` header value.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` header value.
    /// NOTE(review): the middleware is told whether the listener uses TLS;
    /// HSTS is presumably only sent over TLS — confirm.
    pub strict_transport_security: Option<String>,
}
207
/// Complete configuration for an MCP Streamable HTTP server.
///
/// Build it with [`McpServerConfig::new`] plus the `with_*` / `enable_*`
/// builder methods, then call [`McpServerConfig::validate`] to obtain the
/// [`Validated`] wrapper accepted by `serve` / `serve_with_listener`. The
/// public fields are deprecated in favour of the builders.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Listen address; must parse as a `SocketAddr` (`ip:port`) to validate.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name, reported by `/version` and the admin endpoints.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version, reported by `/version` and the admin endpoints.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// TLS certificate path; must be set together with `tls_key_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private-key path; must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings; auth middleware is installed only when
    /// present and `enabled`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy for `/mcp`; a disabled policy is used when `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Allowed browser origins (origin check and CORS). When empty, a single
    /// origin is derived from `public_url` if one is set.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Per-IP tool-call limit (calls per minute) enforced alongside RBAC.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Per-IP request limit (per minute) for routes from `extra_router`;
    /// `Some(0)` is rejected by validation.
    #[deprecated(
        since = "1.11.0",
        note = "use McpServerConfig::with_extra_route_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_route_rate_limit: Option<u32>,
    /// Optional readiness probe for `/readyz`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes for `/mcp` (default 1 MiB, must be > 0).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Per-request timeout on `/mcp`; expiry yields `408` (default 2 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Shutdown grace period passed to the server loop (default 30 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// MCP session keep-alive used by the session manager (default 20 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval for the Streamable HTTP service (default 15 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// One-shot callback that receives a [`ReloadHandle`].
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Additional application routes merged next to the built-in ones. They
    /// are not wrapped by the `/mcp` auth/RBAC layers.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL (must start with `http://` or `https://`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Passed to the origin-check middleware to enable request-header logging.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Minimum response size (bytes) eligible for compression (default 1024).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global in-flight request cap; excess requests are shed with `503`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Mounts `/admin/*`; requires enabled auth.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// Role handed to the admin router configuration (default `"admin"`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enables the metrics middleware and the separate metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Bind address of the metrics listener (default `127.0.0.1:9090`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Security response header settings.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
    /// TLS handshake timeout handed to the accept loop (must be > 0).
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_handshake_timeout: Duration,
    /// Maximum concurrent TLS handshakes handed to the accept loop (must be > 0).
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_tls_handshakes: usize,
}
454
/// Marker wrapper for a value that has passed validation.
///
/// For [`McpServerConfig`] it is produced by [`McpServerConfig::validate`],
/// and `serve` only accepts `Validated<McpServerConfig>`. The tuple field is
/// private, so code outside this module cannot construct one directly. The
/// `Debug` impl intentionally prints no inner contents.
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);
517
518impl<T> std::fmt::Debug for Validated<T> {
519 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520 f.debug_struct("Validated").finish_non_exhaustive()
521 }
522}
523
524impl<T> Validated<T> {
525 #[must_use]
527 pub fn as_inner(&self) -> &T {
528 &self.0
529 }
530
531 #[must_use]
536 pub fn into_inner(self) -> T {
537 self.0
538 }
539}
540
541#[allow(
542 deprecated,
543 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
544)]
545impl McpServerConfig {
546 #[must_use]
554 pub fn new(
555 bind_addr: impl Into<String>,
556 name: impl Into<String>,
557 version: impl Into<String>,
558 ) -> Self {
559 Self {
560 bind_addr: bind_addr.into(),
561 name: name.into(),
562 version: version.into(),
563 tls_cert_path: None,
564 tls_key_path: None,
565 auth: None,
566 rbac: None,
567 allowed_origins: Vec::new(),
568 tool_rate_limit: None,
569 readiness_check: None,
570 max_request_body: 1024 * 1024,
571 request_timeout: Duration::from_mins(2),
572 shutdown_timeout: Duration::from_secs(30),
573 session_idle_timeout: Duration::from_mins(20),
574 sse_keep_alive: Duration::from_secs(15),
575 on_reload_ready: None,
576 extra_router: None,
577 public_url: None,
578 log_request_headers: false,
579 compression_enabled: false,
580 compression_min_size: 1024,
581 max_concurrent_requests: None,
582 admin_enabled: false,
583 admin_role: "admin".to_owned(),
584 #[cfg(feature = "metrics")]
585 metrics_enabled: false,
586 #[cfg(feature = "metrics")]
587 metrics_bind: "127.0.0.1:9090".into(),
588 security_headers: SecurityHeadersConfig::default(),
589 tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
590 max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
591 extra_route_rate_limit: None,
592 }
593 }
594
595 #[must_use]
605 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
606 self.auth = Some(auth);
607 self
608 }
609
610 #[must_use]
615 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
616 self.security_headers = headers;
617 self
618 }
619
620 #[must_use]
624 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
625 self.bind_addr = addr.into();
626 self
627 }
628
629 #[must_use]
632 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
633 self.rbac = Some(rbac);
634 self
635 }
636
637 #[must_use]
641 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
642 self.tls_cert_path = Some(cert_path.into());
643 self.tls_key_path = Some(key_path.into());
644 self
645 }
646
647 #[must_use]
651 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
652 self.public_url = Some(url.into());
653 self
654 }
655
656 #[must_use]
660 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
661 where
662 I: IntoIterator<Item = S>,
663 S: Into<String>,
664 {
665 self.allowed_origins = origins.into_iter().map(Into::into).collect();
666 self
667 }
668
669 #[must_use]
682 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
683 self.extra_router = Some(router);
684 self
685 }
686
687 #[must_use]
690 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
691 self.readiness_check = Some(check);
692 self
693 }
694
695 #[must_use]
698 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
699 self.max_request_body = bytes;
700 self
701 }
702
703 #[must_use]
705 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
706 self.request_timeout = timeout;
707 self
708 }
709
710 #[must_use]
712 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
713 self.shutdown_timeout = timeout;
714 self
715 }
716
717 #[must_use]
719 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
720 self.session_idle_timeout = timeout;
721 self
722 }
723
724 #[must_use]
726 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
727 self.sse_keep_alive = interval;
728 self
729 }
730
731 #[must_use]
735 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
736 self.max_concurrent_requests = Some(limit);
737 self
738 }
739
740 #[must_use]
748 pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
749 self.tls_handshake_timeout = timeout;
750 self
751 }
752
753 #[must_use]
762 pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
763 self.max_concurrent_tls_handshakes = limit;
764 self
765 }
766
767 #[must_use]
770 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
771 self.tool_rate_limit = Some(per_minute);
772 self
773 }
774
775 #[must_use]
786 pub fn with_extra_route_rate_limit(mut self, per_minute: u32) -> Self {
787 self.extra_route_rate_limit = Some(per_minute);
788 self
789 }
790
791 #[must_use]
795 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
796 where
797 F: FnOnce(ReloadHandle) + Send + 'static,
798 {
799 self.on_reload_ready = Some(Box::new(callback));
800 self
801 }
802
803 #[must_use]
807 pub fn enable_compression(mut self, min_size: u16) -> Self {
808 self.compression_enabled = true;
809 self.compression_min_size = min_size;
810 self
811 }
812
813 #[must_use]
818 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
819 self.admin_enabled = true;
820 self.admin_role = role.into();
821 self
822 }
823
824 #[must_use]
827 pub fn enable_request_header_logging(mut self) -> Self {
828 self.log_request_headers = true;
829 self
830 }
831
832 #[cfg(feature = "metrics")]
835 #[must_use]
836 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
837 self.metrics_enabled = true;
838 self.metrics_bind = bind.into();
839 self
840 }
841
842 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
875 self.check()?;
876 Ok(Validated(self))
877 }
878
879 fn check(&self) -> Result<(), McpxError> {
883 if self.admin_enabled {
887 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
888 if !auth_enabled {
889 return Err(McpxError::Config(
890 "admin_enabled=true requires auth to be configured and enabled".into(),
891 ));
892 }
893 }
894
895 match (&self.tls_cert_path, &self.tls_key_path) {
897 (Some(_), None) => {
898 return Err(McpxError::Config(
899 "tls_cert_path is set but tls_key_path is missing".into(),
900 ));
901 }
902 (None, Some(_)) => {
903 return Err(McpxError::Config(
904 "tls_key_path is set but tls_cert_path is missing".into(),
905 ));
906 }
907 _ => {}
908 }
909
910 if self.bind_addr.parse::<SocketAddr>().is_err() {
912 return Err(McpxError::Config(format!(
913 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
914 self.bind_addr
915 )));
916 }
917
918 if let Some(ref url) = self.public_url
920 && !(url.starts_with("http://") || url.starts_with("https://"))
921 {
922 return Err(McpxError::Config(format!(
923 "public_url {url:?} must start with http:// or https://"
924 )));
925 }
926
927 for origin in &self.allowed_origins {
929 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
930 return Err(McpxError::Config(format!(
931 "allowed_origins entry {origin:?} must start with http:// or https://"
932 )));
933 }
934 }
935
936 if self.max_request_body == 0 {
938 return Err(McpxError::Config(
939 "max_request_body must be greater than zero".into(),
940 ));
941 }
942
943 if self.extra_route_rate_limit == Some(0) {
947 return Err(McpxError::Config(
948 "extra_route_rate_limit must be greater than zero".into(),
949 ));
950 }
951
952 #[cfg(feature = "oauth")]
954 if let Some(auth_cfg) = &self.auth
955 && let Some(oauth_cfg) = &auth_cfg.oauth
956 {
957 oauth_cfg.validate()?;
958 }
959
960 validate_security_headers(&self.security_headers)?;
963
964 if let Some(0) = self.max_concurrent_requests {
968 return Err(McpxError::Config(
969 "max_concurrent_requests must be greater than zero when set".into(),
970 ));
971 }
972
973 if let Some(auth_cfg) = &self.auth
977 && let Some(rl) = &auth_cfg.rate_limit
978 && rl.max_tracked_keys == 0
979 {
980 return Err(McpxError::Config(
981 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
982 ));
983 }
984
985 if self.tls_handshake_timeout == Duration::ZERO {
990 return Err(McpxError::Config(
991 "tls_handshake_timeout must be greater than zero".into(),
992 ));
993 }
994
995 if self.max_concurrent_tls_handshakes == 0 {
1000 return Err(McpxError::Config(
1001 "max_concurrent_tls_handshakes must be greater than zero".into(),
1002 ));
1003 }
1004
1005 Ok(())
1006 }
1007}
1008
/// Handle for hot-reloading parts of a running server's security state.
///
/// The callback registered with [`McpServerConfig::with_reload_callback`]
/// receives it by value. Each component is optional. When a component is
/// absent:
/// - `reload_auth_keys` / `reload_rbac` do nothing;
/// - `refresh_crls` returns [`McpxError::Config`].
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    /// Auth state whose API keys can be replaced; `None` presumably when auth
    /// is not enabled — confirm at the construction site.
    auth: Option<Arc<AuthState>>,
    /// Swappable RBAC policy; presumably the same swap the RBAC middleware
    /// loads per request — confirm at the construction site.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    /// CRL set used for mTLS revocation checks, when configured.
    crl_set: Option<Arc<CrlSet>>,
}
1023
1024impl ReloadHandle {
1025 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
1027 if let Some(ref auth) = self.auth {
1028 auth.reload_keys(keys);
1029 }
1030 }
1031
1032 pub fn reload_rbac(&self, policy: RbacPolicy) {
1034 if let Some(ref rbac) = self.rbac {
1035 rbac.store(Arc::new(policy));
1036 tracing::info!("RBAC policy reloaded");
1037 }
1038 }
1039
1040 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1046 let Some(ref crl_set) = self.crl_set else {
1047 return Err(McpxError::Config(
1048 "CRL refresh requested but mTLS CRL support is not configured".into(),
1049 ));
1050 };
1051
1052 crl_set.force_refresh().await
1053 }
1054}
1055
1056#[allow(
1073 clippy::too_many_lines,
1074 clippy::cognitive_complexity,
1075 reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
1076)]
1077struct AppRunParams {
1081 tls_paths: Option<(PathBuf, PathBuf)>,
1083 tls_handshake_timeout: Duration,
1085 max_concurrent_tls_handshakes: usize,
1087 mtls_config: Option<MtlsConfig>,
1089 shutdown_timeout: Duration,
1091 auth_state: Option<Arc<AuthState>>,
1093 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1095 on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1097 ct: CancellationToken,
1101 scheme: &'static str,
1103 name: String,
1105}
1106
/// Assembles the complete axum [`axum::Router`] and the parameters the accept
/// loop needs.
///
/// Steps, in order:
/// 1. Build the rmcp Streamable HTTP service at `/mcp`, with allowed hosts
///    derived from `bind_addr` / `public_url`.
/// 2. Build [`AuthState`] when `auth` is configured and enabled.
/// 3. Optionally merge `/admin/*` into the `/mcp` sub-router.
/// 4. Layer the `/mcp` sub-router with RBAC + tool rate limit, then auth, then
///    request timeout, then body limit.
/// 5. Add `/healthz`, `/readyz`, `/version`, the optional extra router and
///    `/.well-known/oauth-protected-resource`. With the `oauth` feature and a
///    proxy configured, also add the OAuth proxy routes.
/// 6. Wrap everything in the outer layers:
///    - security headers;
///    - optional CORS, compression and global concurrency limit;
///    - the JSON 404 fallback;
///    - optional metrics;
///    - peer-address normalization;
///    - the origin check.
///
/// Per axum's `Router::layer` semantics, a layer added later wraps the layers
/// added before it. The last-added layer therefore sees a request first, so
/// the origin check is outermost and RBAC is innermost on `/mcp`.
///
/// # Errors
///
/// Fails if any of the following happens:
/// - the JWKS HTTP client cannot be created (feature `oauth`);
/// - `admin_enabled` is set without auth state;
/// - auth is enabled but its state is missing;
/// - installing the OAuth proxy routes fails (feature `oauth`);
/// - metrics initialization fails (feature `metrics`).
///
/// # Side effects
///
/// With the `metrics` feature enabled and configured, this spawns the metrics
/// listener via `tokio::spawn`. The task receives a clone of the returned
/// `AppRunParams::ct`, so callers that abandon startup should cancel that
/// token.
#[allow(
    clippy::cognitive_complexity,
    reason = "router assembly is intrinsically sequential; splitting harms readability"
)]
#[allow(
    deprecated,
    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
)]
fn build_app_router<H, F>(
    mut config: McpServerConfig,
    handler_factory: F,
) -> anyhow::Result<(axum::Router, AppRunParams)>
where
    H: ServerHandler + 'static,
    F: Fn() -> H + Send + Sync + Clone + 'static,
{
    // Root token: the MCP service gets a child token, the metrics listener a clone.
    let ct = CancellationToken::new();

    // Host-header allow-list for the Streamable HTTP transport.
    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");

    // A fresh handler is created per session via `handler_factory`; idle
    // sessions are governed by `session_idle_timeout`.
    let mcp_service = StreamableHttpService::new(
        move || Ok(handler_factory()),
        {
            let mut mgr = LocalSessionManager::default();
            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
            mgr.into()
        },
        StreamableHttpServerConfig::default()
            .with_allowed_hosts(allowed_hosts)
            .with_sse_keep_alive(Some(config.sse_keep_alive))
            .with_cancellation_token(ct.child_token()),
    );

    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);

    // Auth state exists only when auth is configured AND enabled.
    let auth_state: Option<Arc<AuthState>> = match config.auth {
        Some(ref auth_config) if auth_config.enabled => {
            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
            let pre_auth_limiter = auth_config
                .rate_limit
                .as_ref()
                .map(crate::auth::build_pre_auth_limiter);

            #[cfg(feature = "oauth")]
            let jwks_cache = auth_config
                .oauth
                .as_ref()
                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
                .transpose()
                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;

            Some(Arc::new(AuthState {
                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
                rate_limiter,
                pre_auth_limiter,
                #[cfg(feature = "oauth")]
                jwks_cache,
                seen_identities: crate::auth::SeenIdentitySet::new(),
                counters: crate::auth::AuthCounters::default(),
            }))
        }
        _ => None,
    };

    // Swappable RBAC policy shared by the middleware, admin endpoints and the
    // reload handle; a disabled policy is used when none is configured.
    let rbac_swap = Arc::new(ArcSwap::new(
        config
            .rbac
            .clone()
            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
    ));

    // Admin routes are merged into the /mcp sub-router *before* the RBAC and
    // auth layers below, so they sit behind the same middleware. `validate()`
    // already rejects admin-without-auth; this re-check is defensive.
    if config.admin_enabled {
        let Some(ref auth_state_ref) = auth_state else {
            return Err(anyhow::anyhow!(
                "admin_enabled=true requires auth to be configured and enabled"
            ));
        };
        let admin_state = crate::admin::AdminState {
            started_at: std::time::Instant::now(),
            name: config.name.clone(),
            version: config.version.clone(),
            auth: Some(Arc::clone(auth_state_ref)),
            rbac: Arc::clone(&rbac_swap),
        };
        let admin_cfg = crate::admin::AdminConfig {
            role: config.admin_role.clone(),
        };
        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
    }

    // Innermost /mcp layer: RBAC plus optional per-IP tool rate limiting. The
    // policy is re-loaded from the swap on every request, so reloads take
    // effect without rebuilding the router.
    {
        let tool_limiter: Option<Arc<ToolRateLimiter>> =
            config.tool_rate_limit.map(build_tool_rate_limiter);

        if rbac_swap.load().is_enabled() {
            tracing::info!("RBAC enforcement enabled on /mcp");
        }
        if let Some(limit) = config.tool_rate_limit {
            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
        }

        let rbac_for_mw = Arc::clone(&rbac_swap);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            let p = rbac_for_mw.load_full();
            let tl = tool_limiter.clone();
            rbac_middleware(p, tl, req, next)
        }));
    }

    // Auth is layered after RBAC, so it wraps RBAC and runs before it.
    if let Some(ref auth_config) = config.auth
        && auth_config.enabled
    {
        let Some(ref state) = auth_state else {
            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
        };

        // Only used for the startup log line.
        let methods: Vec<&str> = [
            auth_config.mtls.is_some().then_some("mTLS"),
            (!auth_config.api_keys.is_empty()).then_some("bearer"),
            #[cfg(feature = "oauth")]
            auth_config.oauth.is_some().then_some("oauth-jwt"),
        ]
        .into_iter()
        .flatten()
        .collect();

        tracing::info!(
            methods = %methods.join(", "),
            api_keys = auth_config.api_keys.len(),
            "auth enabled on /mcp"
        );

        let state_for_mw = Arc::clone(state);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            let s = Arc::clone(&state_for_mw);
            auth_middleware(s, req, next)
        }));
    }

    // Per-request timeout (408 on expiry) and body-size cap, /mcp sub-router only.
    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
        axum::http::StatusCode::REQUEST_TIMEOUT,
        config.request_timeout,
    ));

    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
        config.max_request_body,
    ));

    // With no explicit origins, derive one from `public_url` as
    // "<scheme>://<everything up to the first '/'>", so the port is kept.
    // NOTE(review): a URL with a query or fragment but no path
    // (e.g. "https://host?x") would keep "?x" in the derived origin.
    let mut effective_origins = config.allowed_origins.clone();
    if effective_origins.is_empty()
        && let Some(ref url) = config.public_url
    {
        if let Some(scheme_end) = url.find("://") {
            let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
            let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
            let host = after_scheme.get(..host_end).unwrap_or_default();
            let origin = format!("{scheme_with_sep}{host}");
            tracing::info!(
                %origin,
                "auto-derived allowed origin from public_url"
            );
            effective_origins.push(origin);
        }
    }
    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
    let cors_origins = Arc::clone(&allowed_origins);
    let log_request_headers = config.log_request_headers;

    // /readyz falls back to the /healthz handler when no check is configured.
    let readyz_route = if let Some(check) = config.readiness_check.take() {
        axum::routing::get(move || readyz(Arc::clone(&check)))
    } else {
        axum::routing::get(healthz)
    };

    #[allow(unused_mut)]
    let mut router = axum::Router::new()
        .route("/healthz", axum::routing::get(healthz))
        .route("/readyz", readyz_route)
        .route(
            "/version",
            axum::routing::get({
                // Serialized once at startup; each request copies the bytes.
                let payload_bytes: Arc<[u8]> =
                    serialize_version_payload(&config.name, &config.version);
                move || {
                    let p = Arc::clone(&payload_bytes);
                    async move {
                        (
                            [(axum::http::header::CONTENT_TYPE, "application/json")],
                            p.to_vec(),
                        )
                    }
                }
            }),
        )
        .merge(mcp_router);

    // Extra application routes are merged outside the /mcp auth/RBAC layers;
    // they only get the optional per-IP rate limit plus the outer layers below.
    if let Some(extra) = config.extra_router.take() {
        let extra = match config.extra_route_rate_limit {
            Some(per_minute) => {
                let limiter = build_extra_route_rate_limiter(per_minute);
                tracing::info!(per_minute, "extra-route per-IP rate limit enabled");
                extra.layer(axum::middleware::from_fn(move |req, next| {
                    let l = Arc::clone(&limiter);
                    extra_route_rate_limit_middleware(l, req, next)
                }))
            }
            None => extra,
        };
        router = router.merge(extra);
    }

    // Base URL for OAuth protected-resource metadata: `public_url` without a
    // trailing slash, else derived from the scheme and `bind_addr`.
    let server_url = if let Some(ref url) = config.public_url {
        url.trim_end_matches('/').to_owned()
    } else {
        let prm_scheme = if config.tls_cert_path.is_some() {
            "https"
        } else {
            "http"
        };
        format!("{prm_scheme}://{}", config.bind_addr)
    };
    let resource_url = format!("{server_url}/mcp");

    #[cfg(feature = "oauth")]
    let prm_metadata = if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
    {
        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
    } else {
        serde_json::json!({ "resource": resource_url })
    };
    #[cfg(not(feature = "oauth"))]
    let prm_metadata = serde_json::json!({ "resource": resource_url });

    router = router.route(
        "/.well-known/oauth-protected-resource",
        axum::routing::get(move || {
            let m = prm_metadata.clone();
            async move { axum::Json(m) }
        }),
    );

    #[cfg(feature = "oauth")]
    if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
        && oauth_config.proxy.is_some()
    {
        router =
            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
    }

    // Security headers; the middleware is told whether the listener uses TLS.
    let is_tls = config.tls_cert_path.is_some();
    let security_headers_cfg = Arc::new(config.security_headers.clone());
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        let cfg = Arc::clone(&security_headers_cfg);
        security_headers_middleware(is_tls, cfg, req, next)
    }));

    // CORS only when at least one origin is configured or derived. Entries
    // that do not parse as a `HeaderValue` are silently skipped.
    if !cors_origins.is_empty() {
        let cors = tower_http::cors::CorsLayer::new()
            .allow_origin(
                cors_origins
                    .iter()
                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
                    .collect::<Vec<_>>(),
            )
            .allow_methods([
                axum::http::Method::GET,
                axum::http::Method::POST,
                axum::http::Method::OPTIONS,
            ])
            .allow_headers([
                axum::http::header::CONTENT_TYPE,
                axum::http::header::AUTHORIZATION,
            ]);
        router = router.layer(cors);
    }

    // Compress only responses that pass tower-http's default predicate AND are
    // larger than `compression_min_size`.
    if config.compression_enabled {
        use tower_http::compression::Predicate as _;
        let predicate = tower_http::compression::DefaultPredicate::new().and(
            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
        );
        router = router.layer(
            tower_http::compression::CompressionLayer::new()
                .gzip(true)
                .br(true)
                .compress_when(predicate),
        );
        tracing::info!(
            min_size = config.compression_min_size,
            "response compression enabled (gzip, br)"
        );
    }

    // Global concurrency cap. When `max` requests are in flight, LoadShed
    // fails fast instead of queueing, and HandleErrorLayer turns that error
    // into a 503 JSON body.
    if let Some(max) = config.max_concurrent_requests {
        let overload_handler = tower::ServiceBuilder::new()
            .layer(axum::error_handling::HandleErrorLayer::new(
                |_err: tower::BoxError| async {
                    (
                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
                        axum::Json(serde_json::json!({
                            "error": "overloaded",
                            "error_description": "server is at capacity, retry later"
                        })),
                    )
                },
            ))
            .layer(tower::load_shed::LoadShedLayer::new())
            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
        router = router.layer(overload_handler);
        tracing::info!(max, "global concurrency limit enabled");
    }

    // JSON 404 for unknown paths.
    // NOTE(review): axum's `Router::layer` docs state that layers only apply to
    // routes/fallbacks that already exist. This fallback is registered after
    // the security-headers/CORS/compression/concurrency layers, so 404s are
    // presumably not wrapped by them (later layers still apply). Confirm this
    // is intended.
    router = router.fallback(|| async {
        (
            axum::http::StatusCode::NOT_FOUND,
            axum::Json(serde_json::json!({
                "error": "not_found",
                "error_description": "The requested endpoint does not exist"
            })),
        )
    });

    // Metrics middleware plus a separate listener. The spawned task gets a
    // clone of `ct` as its shutdown signal, so it keeps running until that
    // token is cancelled.
    #[cfg(feature = "metrics")]
    if config.metrics_enabled {
        let metrics = Arc::new(
            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
        );
        let m = Arc::clone(&metrics);
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let m = Arc::clone(&m);
                metrics_middleware(m, req, next)
            },
        ));
        let metrics_bind = config.metrics_bind.clone();
        let metrics_shutdown = ct.clone();
        tokio::spawn(async move {
            if let Err(e) =
                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
            {
                tracing::error!("metrics listener failed: {e}");
            }
        });
    }

    // NOTE(review): presumably normalizes/inserts the `PeerAddr` request
    // extension read by the `PeerAddr` extractor — confirm.
    router = router.layer(axum::middleware::from_fn(normalize_peer_addr_middleware));

    // Outermost layer: the origin check runs before everything else.
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        let origins = Arc::clone(&allowed_origins);
        origin_check_middleware(origins, log_request_headers, req, next)
    }));

    let scheme = if config.tls_cert_path.is_some() {
        "https"
    } else {
        "http"
    };

    // TLS is used only when both paths are present (validate() rejects a
    // half-configured pair).
    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
        _ => None,
    };
    let tls_handshake_timeout = config.tls_handshake_timeout;
    let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();

    Ok((
        router,
        AppRunParams {
            tls_paths,
            tls_handshake_timeout,
            max_concurrent_tls_handshakes,
            mtls_config,
            shutdown_timeout: config.shutdown_timeout,
            auth_state,
            rbac_swap,
            on_reload_ready: config.on_reload_ready.take(),
            ct,
            scheme,
            name: config.name.clone(),
        },
    ))
}
1611
1612pub async fn serve<H, F>(
1629 config: Validated<McpServerConfig>,
1630 handler_factory: F,
1631) -> Result<(), McpxError>
1632where
1633 H: ServerHandler + 'static,
1634 F: Fn() -> H + Send + Sync + Clone + 'static,
1635{
1636 let config = config.into_inner();
1637 #[allow(
1638 deprecated,
1639 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1640 )]
1641 let bind_addr = config.bind_addr.clone();
1642 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1643
1644 let listener = TcpListener::bind(&bind_addr)
1645 .await
1646 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1647 log_listening(¶ms.name, params.scheme, &bind_addr);
1648
1649 run_server(
1650 router,
1651 listener,
1652 params.tls_paths,
1653 params.tls_handshake_timeout,
1654 params.max_concurrent_tls_handshakes,
1655 params.mtls_config,
1656 params.shutdown_timeout,
1657 params.auth_state,
1658 params.rbac_swap,
1659 params.on_reload_ready,
1660 params.ct,
1661 )
1662 .await
1663 .map_err(anyhow_to_startup)
1664}
1665
1666pub async fn serve_with_listener<H, F>(
1696 listener: TcpListener,
1697 config: Validated<McpServerConfig>,
1698 handler_factory: F,
1699 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1700 shutdown: Option<CancellationToken>,
1701) -> Result<(), McpxError>
1702where
1703 H: ServerHandler + 'static,
1704 F: Fn() -> H + Send + Sync + Clone + 'static,
1705{
1706 let config = config.into_inner();
1707 let local_addr = listener
1708 .local_addr()
1709 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1710 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1711
1712 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1713
1714 if let Some(external) = shutdown {
1718 let internal = params.ct.clone();
1719 tokio::spawn(async move {
1720 external.cancelled().await;
1721 internal.cancel();
1722 });
1723 }
1724
1725 if let Some(tx) = ready_tx {
1729 let _ = tx.send(local_addr);
1731 }
1732
1733 run_server(
1734 router,
1735 listener,
1736 params.tls_paths,
1737 params.tls_handshake_timeout,
1738 params.max_concurrent_tls_handshakes,
1739 params.mtls_config,
1740 params.shutdown_timeout,
1741 params.auth_state,
1742 params.rbac_swap,
1743 params.on_reload_ready,
1744 params.ct,
1745 )
1746 .await
1747 .map_err(anyhow_to_startup)
1748}
1749
1750#[allow(
1753 clippy::cognitive_complexity,
1754 reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1755)]
1756fn log_listening(name: &str, scheme: &str, addr: &str) {
1757 tracing::info!("{name} listening on {addr}");
1758 tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
1759 tracing::info!(" Health check: {scheme}://{addr}/healthz");
1760 tracing::info!(" Readiness: {scheme}://{addr}/readyz");
1761}
1762
/// Serves `router` on `listener` until shutdown, over TLS when `tls_paths` is set.
///
/// Shutdown is triggered by whichever comes first: `shutdown_signal()`
/// (SIGINT/SIGTERM) or cancellation of `ct`. Once triggered:
/// - the `graceful` future logs the grace period and cancels `ct`, which
///   starts `axum::serve`'s graceful shutdown;
/// - `force_exit_timer` sleeps `shutdown_timeout`. If it wins the `select!`,
///   the serve future is dropped, a warning is logged, and this function
///   still returns `Ok(())`.
///
/// In both modes, `on_reload_ready` (if set) is called with a `ReloadHandle`
/// before serving starts. In TLS mode this happens before `TlsListener::new`,
/// so the callback has already run even if the listener then fails to build.
///
/// # Errors
///
/// Returns an error if loading the client-auth CA roots or the initial CRL
/// fetch fails, if `TlsListener::new` fails, or if `axum::serve` returns an
/// error.
///
/// NOTE(review): the signal-watcher task spawned first stays alive if serving
/// ends without the trigger firing (e.g. `axum::serve` errors out). It exits
/// only once a signal arrives or `ct` is cancelled.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    tls_handshake_timeout: Duration,
    max_concurrent_tls_handshakes: usize,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // Single internal trigger shared by the graceful-shutdown future and the
    // force-exit timer, so both start counting from the same moment.
    let shutdown_trigger = CancellationToken::new();
    // Fire the trigger on an OS signal or when the parent token is cancelled.
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    // Handed to `with_graceful_shutdown`. Cancelling `ct` also notifies every
    // other task that was given a clone of it (e.g. the CRL refresher below).
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };

    // Completes `shutdown_timeout` after the trigger. Used in `select!` to cap
    // how long graceful shutdown may take.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // CRL checking needs the CA roots and an initial CRL fetch before the
        // verifier can be built. A background refresher receives `ct` so it can
        // observe shutdown.
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
            tls_handshake_timeout,
            max_concurrent_tls_handshakes,
        )?;
        // `TlsConnInfo` carries the peer address plus any mTLS identity.
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
1902
/// Mounts the OAuth 2.1 proxy endpoints on `router` when
/// `oauth_config.proxy` is configured. Otherwise `router` is returned unchanged.
///
/// When a proxy is configured, these routes are always mounted:
/// - `GET /.well-known/oauth-authorization-server`: metadata computed once
///   from `server_url` and `oauth_config`.
/// - `GET /authorize`: passes the raw query string (or `""`) to
///   `handle_authorize`.
/// - `POST /token` and `POST /register`: wrapped in
///   `oauth_token_cache_headers_middleware`.
///
/// `/introspect` and `/revoke` are added only when `expose_admin_endpoints` is
/// set and at least one of the introspection/revocation URLs is configured
/// (see `build_oauth_admin_router`).
///
/// # Errors
///
/// Returns an error if the OAuth HTTP client cannot be built, or if the admin
/// router requires auth but `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            // Clone per request, like the other handlers, instead of moving the
            // captured config out of the closure.
            let p = proxy_register.clone();
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Warn only when admin routes are actually mounted. With no introspection
    // or revocation URL there is no unauthenticated endpoint to warn about.
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
2003
/// Builds the router holding the OAuth admin endpoints.
///
/// `POST /introspect` is added when `proxy.introspection_url` is set, and
/// `POST /revoke` when `proxy.revocation_url` is set. Every admin route gets
/// the OAuth cache-header middleware. When
/// `proxy.require_auth_on_admin_endpoints` is true, `auth_middleware` is
/// layered outermost, so it runs before the cache-header layer.
///
/// # Errors
///
/// Returns `McpxError::Startup` if auth is required but `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let introspect_proxy = proxy.clone();
        let introspect_client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let p = introspect_proxy.clone();
                let h = introspect_client.clone();
                async move { crate::oauth::handle_introspect(&h, &p, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let revoke_proxy = proxy.clone();
        // `http` is moved into this closure; no further use follows.
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let p = revoke_proxy.clone();
                let h = http.clone();
                async move { crate::oauth::handle_revoke(&h, &p, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    let state_for_mw = Arc::clone(state);
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        let s = Arc::clone(&state_for_mw);
        auth_middleware(s, req, next)
    })))
}
2062
2063fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2068 let mut hosts = vec![
2069 "localhost".to_owned(),
2070 "127.0.0.1".to_owned(),
2071 "::1".to_owned(),
2072 ];
2073
2074 if let Some(url) = public_url
2075 && let Ok(uri) = url.parse::<axum::http::Uri>()
2076 && let Some(authority) = uri.authority()
2077 {
2078 let host = authority.host().to_owned();
2079 if !hosts.iter().any(|h| h == &host) {
2080 hosts.push(host);
2081 }
2082
2083 let authority = authority.as_str().to_owned();
2084 if !hosts.iter().any(|h| h == &authority) {
2085 hosts.push(authority);
2086 }
2087 }
2088
2089 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2090 && let Some(authority) = uri.authority()
2091 {
2092 let host = authority.host().to_owned();
2093 if !hosts.iter().any(|h| h == &host) {
2094 hosts.push(host);
2095 }
2096
2097 let authority = authority.as_str().to_owned();
2098 if !hosts.iter().any(|h| h == &authority) {
2099 hosts.push(authority);
2100 }
2101 }
2102
2103 hosts
2104}
2105
2106impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2119 for TlsConnInfo
2120{
2121 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2122 let addr = *target.remote_addr();
2123 let identity = target.io().identity().cloned();
2124 TlsConnInfo::new(addr, identity)
2125 }
2126}
2127
/// Default TLS handshake timeout (10 seconds).
///
/// NOTE(review): not referenced in this part of the file; presumably the config
/// default for `tls_handshake_timeout`. Confirm.
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
2135
/// Default cap on concurrently in-flight TLS handshakes (256).
///
/// NOTE(review): not referenced in this part of the file; presumably the config
/// default for `max_concurrent_tls_handshakes`, which sizes the semaphore in
/// `run_tls_acceptor`. Confirm.
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;
2144
/// Capacity of the channel between the TLS acceptor task and `TlsListener`.
/// When it is full, handshake tasks wait in `tx.send` (still holding their
/// handshake permit) until axum accepts more connections.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
2150
/// `axum::serve::Listener` that yields TLS streams whose handshake has
/// already completed.
///
/// TCP accepts and TLS handshakes run on a background task
/// (`run_tls_acceptor`). Finished connections are delivered over `rx`.
struct TlsListener {
    /// Local address of the wrapped TCP listener, captured at construction.
    local_addr: SocketAddr,
    /// Completed TLS connections from the acceptor task, with the peer address.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Handle to the acceptor task, kept so it can be aborted when the
    /// listener is dropped.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2176
impl TlsListener {
    /// Builds the rustls server config and spawns the acceptor task.
    ///
    /// - Installs the `ring` crypto provider as the process default. The
    ///   error returned when a provider is already installed is ignored.
    /// - Allows only TLS 1.2 and TLS 1.3.
    /// - Without `mtls_config`: no client auth, and the default role handed to
    ///   the acceptor is `"viewer"`.
    /// - With `mtls_config`: if `crl_enabled`, uses `DynamicClientCertVerifier`
    ///   over `crl_set`, which must be `Some`. Otherwise uses a WebPKI client
    ///   verifier built from `ca_cert_path`, which accepts unauthenticated
    ///   clients unless `required` is set. The default role comes from
    ///   `mtls.default_role`.
    ///
    /// # Errors
    ///
    /// Fails if the cert/key/CA files cannot be loaded, the verifier cannot be
    /// built, CRL mode is requested without `crl_set`, the key does not match
    /// the certificate chain, or the local address cannot be read.
    fn new(
        inner: TcpListener,
        cert_path: &Path,
        key_path: &Path,
        mtls_config: Option<&MtlsConfig>,
        crl_set: Option<Arc<CrlSet>>,
        handshake_timeout: Duration,
        max_concurrent_handshakes: usize,
    ) -> anyhow::Result<Self> {
        // `.ok()`: installing fails harmlessly if a default provider already exists.
        rustls::crypto::ring::default_provider()
            .install_default()
            .ok();

        let certs = load_certs(cert_path)?;
        let key = load_key(key_path)?;

        // Role assigned to mTLS identities; set by whichever branch below runs.
        let mtls_default_role;

        let tls_config = if let Some(mtls) = mtls_config {
            mtls_default_role = mtls.default_role.clone();
            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
            {
                let Some(crl_set) = crl_set else {
                    return Err(anyhow::anyhow!(
                        "mTLS CRL verifier requested but CRL state was not initialized"
                    ));
                };
                Arc::new(DynamicClientCertVerifier::new(crl_set))
            } else {
                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
                if mtls.required {
                    rustls::server::WebPkiClientVerifier::builder(root_store)
                        .build()
                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
                } else {
                    // Optional mTLS: a client cert is verified if presented,
                    // but connections without one are still accepted.
                    rustls::server::WebPkiClientVerifier::builder(root_store)
                        .allow_unauthenticated()
                        .build()
                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
                }
            };

            tracing::info!(
                ca = %mtls.ca_cert_path.display(),
                required = mtls.required,
                crl_enabled = mtls.crl_enabled,
                "mTLS client auth configured"
            );

            rustls::ServerConfig::builder_with_protocol_versions(&[
                &rustls::version::TLS12,
                &rustls::version::TLS13,
            ])
            .with_client_cert_verifier(verifier)
            .with_single_cert(certs, key)?
        } else {
            mtls_default_role = "viewer".to_owned();
            rustls::ServerConfig::builder_with_protocol_versions(&[
                &rustls::version::TLS12,
                &rustls::version::TLS13,
            ])
            .with_no_client_auth()
            .with_single_cert(certs, key)?
        };

        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
        tracing::info!(
            "TLS enabled (cert: {}, key: {})",
            cert_path.display(),
            key_path.display()
        );
        let local_addr = inner.local_addr()?;
        let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
        let acceptor_task = tokio::spawn(run_tls_acceptor(
            inner,
            acceptor,
            mtls_default_role,
            tx,
            handshake_timeout,
            max_concurrent_handshakes,
        ));
        Ok(Self {
            local_addr,
            rx,
            acceptor_task,
        })
    }

    /// Extracts an mTLS identity from the first (leaf) peer certificate.
    ///
    /// Returns `None` when the client presented no certificate or
    /// `extract_mtls_identity` yields nothing. On success, logs at debug level.
    fn extract_handshake_identity(
        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
        default_role: &str,
        addr: SocketAddr,
    ) -> Option<AuthIdentity> {
        let (_, server_conn) = tls_stream.get_ref();
        let cert_der = server_conn.peer_certificates()?.first()?;
        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
        Some(id)
    }
}
2282
2283async fn run_tls_acceptor(
2291 listener: TcpListener,
2292 acceptor: tokio_rustls::TlsAcceptor,
2293 default_role: String,
2294 tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2295 handshake_timeout: Duration,
2296 max_concurrent_handshakes: usize,
2297) {
2298 let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2299 loop {
2300 let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2304 return;
2306 };
2307 let (stream, addr) = match listener.accept().await {
2308 Ok(pair) => pair,
2309 Err(e) => {
2310 tracing::debug!("TCP accept error: {e}");
2311 continue;
2312 }
2313 };
2314 if tx.is_closed() {
2315 return;
2317 }
2318 let acceptor = acceptor.clone();
2319 let default_role = default_role.clone();
2320 let tx = tx.clone();
2321 tokio::spawn(async move {
2322 let _permit = permit;
2323 match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2324 Ok(Ok(tls_stream)) => {
2325 let identity =
2326 TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2327 let wrapped = AuthenticatedTlsStream {
2328 inner: tls_stream,
2329 identity,
2330 };
2331 let _ = tx.send((wrapped, addr)).await;
2334 }
2335 Ok(Err(e)) => {
2336 tracing::debug!("TLS handshake failed from {addr}: {e}");
2337 }
2338 Err(_elapsed) => {
2339 tracing::debug!(
2340 "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2341 );
2342 }
2343 }
2344 });
2345 }
2346}
2347
/// A server-side TLS stream together with the mTLS identity, if any, that was
/// extracted when the handshake completed.
///
/// Implements `AsyncRead`/`AsyncWrite` by delegating to `inner`, so axum can
/// serve it directly. The identity is read back through `identity()` when
/// building `TlsConnInfo`.
pub(crate) struct AuthenticatedTlsStream {
    /// rustls server stream over the accepted TCP connection.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity from the client certificate. `None` when no certificate was
    /// presented or no identity could be extracted from it.
    identity: Option<AuthIdentity>,
}
2363
impl AuthenticatedTlsStream {
    /// Returns the mTLS identity captured during the handshake, if any.
    #[must_use]
    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
        self.identity.as_ref()
    }
}
2371
2372impl std::fmt::Debug for AuthenticatedTlsStream {
2373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2374 f.debug_struct("AuthenticatedTlsStream")
2375 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2376 .finish_non_exhaustive()
2377 }
2378}
2379
/// Reads are forwarded unchanged to the wrapped TLS stream.
impl tokio::io::AsyncRead for AuthenticatedTlsStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // `inner` is `Unpin` (required by `Pin::new`), so it can be re-pinned safely.
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}
2389
/// Writes, flushes, and shutdown are forwarded unchanged to the wrapped TLS
/// stream. The vectored-write methods are forwarded too, so the inner stream's
/// vectored-I/O support is kept rather than replaced by the trait defaults.
impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<std::io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}
2425
2426impl axum::serve::Listener for TlsListener {
2427 type Io = AuthenticatedTlsStream;
2428 type Addr = SocketAddr;
2429
2430 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2436 if let Some(pair) = self.rx.recv().await {
2437 return pair;
2438 }
2439 tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2445 std::future::pending().await
2446 }
2447
2448 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2449 Ok(self.local_addr)
2450 }
2451}
2452
/// Stops the background accept loop when the listener is dropped.
///
/// Only the acceptor task is aborted. Handshake tasks it already spawned are
/// separate `tokio::spawn`ed tasks that run to completion; their `tx.send`
/// fails once `rx` (dropped with `self`) is gone.
impl Drop for TlsListener {
    fn drop(&mut self) {
        self.acceptor_task.abort();
    }
}
2460
2461fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2462 use rustls::pki_types::pem::PemObject;
2463 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2464 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2465 .collect::<Result<_, _>>()
2466 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2467 anyhow::ensure!(
2468 !certs.is_empty(),
2469 "no certificates found in {}",
2470 path.display()
2471 );
2472 Ok(certs)
2473}
2474
2475fn load_client_auth_roots(
2476 path: &Path,
2477) -> anyhow::Result<(
2478 Vec<rustls::pki_types::CertificateDer<'static>>,
2479 Arc<RootCertStore>,
2480)> {
2481 let ca_certs = load_certs(path)?;
2482 let mut root_store = RootCertStore::empty();
2483 for cert in &ca_certs {
2484 root_store
2485 .add(cert.clone())
2486 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2487 }
2488
2489 Ok((ca_certs, Arc::new(root_store)))
2490}
2491
/// Reads a PEM private key from `path`. Which key encodings are accepted, and
/// which key is chosen if there are several, is decided by rustls-pki-types'
/// `PrivateKeyDer::from_pem_file`.
///
/// # Errors
///
/// Fails if the file cannot be read or contains no usable private key.
fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
    use rustls::pki_types::pem::PemObject;
    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
}
2497
/// Liveness handler: always responds with the JSON body `{"status":"ok"}`.
#[allow(
    clippy::unused_async,
    reason = "axum route handler signature requires `async fn` even when the body is synchronous"
)]
async fn healthz() -> impl IntoResponse {
    axum::Json(serde_json::json!({
        "status": "ok",
    }))
}
2507
/// Builds the JSON version document for the server `name`/`version`.
///
/// The build metadata fields come from `RMCP_SERVER_KIT_BUILD_*` /
/// `RMCP_SERVER_KIT_RUSTC_VERSION` environment variables read at compile time
/// (`option_env!`). Each falls back to `"unknown"` when unset. `mcpx_version`
/// is this crate's own `CARGO_PKG_VERSION`.
fn version_payload(name: &str, version: &str) -> serde_json::Value {
    serde_json::json!({
        "name": name,
        "version": version,
        "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
        "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
        "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
        "mcpx_version": env!("CARGO_PKG_VERSION"),
    })
}
2524
2525fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2535 let value = version_payload(name, version);
2536 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2537}
2538
2539async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2540 let status = check().await;
2541 let ready = status
2542 .get("ready")
2543 .and_then(serde_json::Value::as_bool)
2544 .unwrap_or(false);
2545 let code = if ready {
2546 axum::http::StatusCode::OK
2547 } else {
2548 axum::http::StatusCode::SERVICE_UNAVAILABLE
2549 };
2550 (code, axum::Json(status))
2551}
2552
/// Resolves when the process receives a shutdown signal.
///
/// On Unix this waits for SIGINT (Ctrl-C) or SIGTERM. If the SIGTERM handler
/// cannot be registered, it logs a warning and waits for SIGINT only. On other
/// platforms it waits for Ctrl-C only.
///
/// NOTE(review): errors from `ctrl_c()` are discarded (`_ =` / `.ok()`), so a
/// failure to register the SIGINT handler also resolves this future and would
/// trigger shutdown immediately. Confirm that this is intended.
async fn shutdown_signal() {
    let ctrl_c = tokio::signal::ctrl_c();

    #[cfg(unix)]
    {
        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
            Ok(mut term) => {
                tokio::select! {
                    _ = ctrl_c => {}
                    _ = term.recv() => {}
                }
            }
            Err(e) => {
                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
                ctrl_c.await.ok();
            }
        }
    }

    #[cfg(not(unix))]
    {
        ctrl_c.await.ok();
    }
}
2580
/// Records request count and latency for every HTTP request.
///
/// Labels are `method`, route, and response `status` (count only). The route
/// label is the matched route template (`MatchedPath`), or `"<unmatched>"`
/// for requests that hit no registered route, such as the 404 fallback.
/// Using the raw URI path would let any client create an unbounded number of
/// Prometheus time series simply by requesting random paths.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    /// Route label for requests that did not match any registered route.
    const UNMATCHED_ROUTE: &str = "<unmatched>";

    let method = req.method().to_string();
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or_else(|| UNMATCHED_ROUTE.to_owned(), |p| p.as_str().to_owned());
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2612
2613async fn security_headers_middleware(
2625 is_tls: bool,
2626 cfg: Arc<SecurityHeadersConfig>,
2627 req: Request<Body>,
2628 next: Next,
2629) -> axum::response::Response {
2630 use axum::http::{HeaderName, header};
2631
2632 let mut resp = next.run(req).await;
2633 let headers = resp.headers_mut();
2634
2635 headers.remove(header::SERVER);
2637 headers.remove(HeaderName::from_static("x-powered-by"));
2638
2639 apply_security_header(
2640 headers,
2641 header::X_CONTENT_TYPE_OPTIONS,
2642 cfg.x_content_type_options.as_deref(),
2643 "nosniff",
2644 );
2645 apply_security_header(
2646 headers,
2647 header::X_FRAME_OPTIONS,
2648 cfg.x_frame_options.as_deref(),
2649 "deny",
2650 );
2651 apply_security_header(
2652 headers,
2653 header::CACHE_CONTROL,
2654 cfg.cache_control.as_deref(),
2655 "no-store, max-age=0",
2656 );
2657 apply_security_header(
2658 headers,
2659 header::REFERRER_POLICY,
2660 cfg.referrer_policy.as_deref(),
2661 "no-referrer",
2662 );
2663 apply_security_header(
2664 headers,
2665 HeaderName::from_static("cross-origin-opener-policy"),
2666 cfg.cross_origin_opener_policy.as_deref(),
2667 "same-origin",
2668 );
2669 apply_security_header(
2670 headers,
2671 HeaderName::from_static("cross-origin-resource-policy"),
2672 cfg.cross_origin_resource_policy.as_deref(),
2673 "same-origin",
2674 );
2675 apply_security_header(
2676 headers,
2677 HeaderName::from_static("cross-origin-embedder-policy"),
2678 cfg.cross_origin_embedder_policy.as_deref(),
2679 "require-corp",
2680 );
2681 apply_security_header(
2682 headers,
2683 HeaderName::from_static("permissions-policy"),
2684 cfg.permissions_policy.as_deref(),
2685 "accelerometer=(), camera=(), geolocation=(), microphone=()",
2686 );
2687 apply_security_header(
2688 headers,
2689 HeaderName::from_static("x-permitted-cross-domain-policies"),
2690 cfg.x_permitted_cross_domain_policies.as_deref(),
2691 "none",
2692 );
2693 apply_security_header(
2694 headers,
2695 HeaderName::from_static("content-security-policy"),
2696 cfg.content_security_policy.as_deref(),
2697 "default-src 'none'; frame-ancestors 'none'",
2698 );
2699 apply_security_header(
2700 headers,
2701 HeaderName::from_static("x-dns-prefetch-control"),
2702 cfg.x_dns_prefetch_control.as_deref(),
2703 "off",
2704 );
2705
2706 if is_tls {
2707 apply_security_header(
2708 headers,
2709 header::STRICT_TRANSPORT_SECURITY,
2710 cfg.strict_transport_security.as_deref(),
2711 "max-age=63072000; includeSubDomains",
2712 );
2713 }
2714
2715 resp
2716}
2717
2718fn apply_security_header(
2729 headers: &mut axum::http::HeaderMap,
2730 name: axum::http::HeaderName,
2731 override_value: Option<&str>,
2732 default: &'static str,
2733) {
2734 use axum::http::HeaderValue;
2735
2736 match override_value {
2737 None => {
2738 headers.insert(name, HeaderValue::from_static(default));
2739 }
2740 Some("") => {
2741 }
2743 Some(v) => match HeaderValue::from_str(v) {
2744 Ok(hv) => {
2745 headers.insert(name, hv);
2746 }
2747 Err(err) => {
2748 tracing::error!(
2749 header = %name,
2750 error = %err,
2751 "invalid security header override reached middleware; using default"
2752 );
2753 headers.insert(name, HeaderValue::from_static(default));
2754 }
2755 },
2756 }
2757}
2758
2759fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2770 use axum::http::HeaderValue;
2771
2772 let fields: &[(&str, Option<&str>)] = &[
2773 (
2774 "x_content_type_options",
2775 cfg.x_content_type_options.as_deref(),
2776 ),
2777 ("x_frame_options", cfg.x_frame_options.as_deref()),
2778 ("cache_control", cfg.cache_control.as_deref()),
2779 ("referrer_policy", cfg.referrer_policy.as_deref()),
2780 (
2781 "cross_origin_opener_policy",
2782 cfg.cross_origin_opener_policy.as_deref(),
2783 ),
2784 (
2785 "cross_origin_resource_policy",
2786 cfg.cross_origin_resource_policy.as_deref(),
2787 ),
2788 (
2789 "cross_origin_embedder_policy",
2790 cfg.cross_origin_embedder_policy.as_deref(),
2791 ),
2792 ("permissions_policy", cfg.permissions_policy.as_deref()),
2793 (
2794 "x_permitted_cross_domain_policies",
2795 cfg.x_permitted_cross_domain_policies.as_deref(),
2796 ),
2797 (
2798 "content_security_policy",
2799 cfg.content_security_policy.as_deref(),
2800 ),
2801 (
2802 "x_dns_prefetch_control",
2803 cfg.x_dns_prefetch_control.as_deref(),
2804 ),
2805 (
2806 "strict_transport_security",
2807 cfg.strict_transport_security.as_deref(),
2808 ),
2809 ];
2810
2811 for (field, value) in fields {
2812 let Some(v) = value else { continue };
2813 if v.is_empty() {
2814 continue;
2815 }
2816 if let Err(err) = HeaderValue::from_str(v) {
2817 return Err(McpxError::Config(format!(
2818 "invalid security_headers.{field}: {err}"
2819 )));
2820 }
2821 }
2822
2823 if let Some(v) = cfg.strict_transport_security.as_deref()
2824 && !v.is_empty()
2825 && v.to_ascii_lowercase().contains("preload")
2826 {
2827 return Err(McpxError::Config(format!(
2828 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2829 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2830 )));
2831 }
2832
2833 Ok(())
2834}
2835
/// Adds anti-caching headers to OAuth token-bearing responses.
///
/// Sets `Pragma: no-cache`, replacing any existing value, and appends
/// `Vary: Authorization` to whatever `Vary` values are already present.
///
/// NOTE(review): `Cache-Control` is not set here. Presumably it is covered by
/// the global security-header layer's `no-store` default when that layer is
/// installed. Confirm.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut resp = next.run(req).await;
    let headers = resp.headers_mut();
    headers.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
    headers.append(header::VARY, HeaderValue::from_static("Authorization"));
    resp
}
2863
2864async fn normalize_peer_addr_middleware(
2888 mut req: Request<Body>,
2889 next: Next,
2890) -> axum::response::Response {
2891 let direct = req
2892 .extensions()
2893 .get::<ConnectInfo<SocketAddr>>()
2894 .map(|ci| ci.0);
2895 let from_tls = req
2896 .extensions()
2897 .get::<ConnectInfo<TlsConnInfo>>()
2898 .map(|ci| ci.0.addr);
2899 if let Some(addr) = direct.or(from_tls) {
2900 if direct.is_none() {
2901 req.extensions_mut().insert(ConnectInfo(addr));
2902 }
2903 req.extensions_mut().insert(PeerAddr::new(addr));
2904 }
2905 next.run(req).await
2906}
2907
/// Rate limiter for extra (application) routes, keyed by client IP address.
pub(crate) type ExtraRouteRateLimiter = BoundedKeyedLimiter<IpAddr>;
2912
/// Key-count bound passed to the extra-route limiter (distinct client IPs).
/// What happens once it is reached is up to `BoundedKeyedLimiter`; see its docs.
const EXTRA_ROUTE_MAX_TRACKED_KEYS: usize = 10_000;
2919
/// Idle period (15 minutes) passed to the extra-route limiter. Presumably,
/// state for an IP idle this long may be evicted; see `BoundedKeyedLimiter`.
const EXTRA_ROUTE_IDLE_EVICTION: Duration = Duration::from_mins(15);
2923
/// Creates the shared per-IP limiter for extra routes, allowing `per_minute`
/// requests per client IP. It uses `EXTRA_ROUTE_MAX_TRACKED_KEYS` and
/// `EXTRA_ROUTE_IDLE_EVICTION` for the limiter's key-tracking bounds.
fn build_extra_route_rate_limiter(per_minute: u32) -> Arc<ExtraRouteRateLimiter> {
    Arc::new(BoundedKeyedLimiter::with_per_minute(
        per_minute,
        EXTRA_ROUTE_MAX_TRACKED_KEYS,
        EXTRA_ROUTE_IDLE_EVICTION,
    ))
}
2936
2937async fn extra_route_rate_limit_middleware(
2953 limiter: Arc<ExtraRouteRateLimiter>,
2954 req: Request<Body>,
2955 next: Next,
2956) -> axum::response::Response {
2957 let peer_ip: Option<IpAddr> = req
2958 .extensions()
2959 .get::<ConnectInfo<SocketAddr>>()
2960 .map(|ci| ci.0.ip())
2961 .or_else(|| {
2962 req.extensions()
2963 .get::<ConnectInfo<TlsConnInfo>>()
2964 .map(|ci| ci.0.addr.ip())
2965 });
2966 if let Some(ip) = peer_ip
2967 && limiter.check_key(&ip).is_err()
2968 {
2969 tracing::warn!(%ip, "extra route request rate limited");
2970 return McpxError::RateLimited(
2971 "too many requests to application routes from this source".into(),
2972 )
2973 .into_response();
2974 }
2975 next.run(req).await
2976}
2977
2978async fn origin_check_middleware(
2982 allowed: Arc<[String]>,
2983 log_request_headers: bool,
2984 req: Request<Body>,
2985 next: Next,
2986) -> axum::response::Response {
2987 let method = req.method().clone();
2988 let path = req.uri().path().to_owned();
2989
2990 log_incoming_request(&method, &path, req.headers(), log_request_headers);
2991
2992 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2993 let origin_str = origin.to_str().unwrap_or("");
2994 if !allowed.iter().any(|a| a == origin_str) {
2995 tracing::warn!(
2996 origin = origin_str,
2997 %method,
2998 %path,
2999 allowed = ?&*allowed,
3000 "rejected request: Origin not allowed"
3001 );
3002 return (
3003 axum::http::StatusCode::FORBIDDEN,
3004 "Forbidden: Origin not allowed",
3005 )
3006 .into_response();
3007 }
3008 }
3009 next.run(req).await
3010}
3011
3012fn log_incoming_request(
3015 method: &axum::http::Method,
3016 path: &str,
3017 headers: &axum::http::HeaderMap,
3018 log_request_headers: bool,
3019) {
3020 if log_request_headers {
3021 tracing::debug!(
3022 %method,
3023 %path,
3024 headers = %format_request_headers_for_log(headers),
3025 "incoming request"
3026 );
3027 } else {
3028 tracing::debug!(%method, %path, "incoming request");
3029 }
3030}
3031
3032fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
3033 headers
3034 .iter()
3035 .map(|(k, v)| {
3036 let name = k.as_str();
3037 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
3038 format!("{name}: [REDACTED]")
3039 } else {
3040 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
3041 }
3042 })
3043 .collect::<Vec<_>>()
3044 .join(", ")
3045}
3046
3047#[allow(
3071 clippy::cognitive_complexity,
3072 reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
3073)]
3074pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
3075where
3076 H: ServerHandler + 'static,
3077{
3078 use rmcp::ServiceExt as _;
3079
3080 tracing::info!("stdio transport: serving on stdin/stdout");
3081 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
3082
3083 let transport = rmcp::transport::io::stdio();
3084
3085 let service = handler
3086 .serve(transport)
3087 .await
3088 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
3089
3090 if let Err(e) = service.waiting().await {
3091 tracing::warn!(error = %e, "stdio session ended with error");
3092 }
3093 tracing::info!("stdio session ended");
3094 Ok(())
3095}
3096
// Unit tests for configuration validation, the health/readiness handlers, peer-address
// normalization, the extra-route rate limiter, Origin checking, request-header logging,
// security headers, and the TLS handshake timeout.
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        deprecated,
        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
    )]
    use std::{sync::Arc, time::Duration};

    use axum::{
        body::Body,
        http::{Request, StatusCode, header},
        response::IntoResponse,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt as _;

    use super::*;

    #[test]
    fn server_config_new_defaults() {
        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
        assert_eq!(cfg.name, "test-server");
        assert_eq!(cfg.version, "1.0.0");
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert!(cfg.auth.is_none());
        assert!(cfg.rbac.is_none());
        assert!(cfg.allowed_origins.is_empty());
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.readiness_check.is_none());
        assert_eq!(cfg.max_request_body, 1024 * 1024);
        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
        assert!(!cfg.log_request_headers);
        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
    }

    #[test]
    fn tls_handshake_builders_set_fields() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
            .with_tls_handshake_timeout(Duration::from_secs(3))
            .with_max_concurrent_tls_handshakes(64);
        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
        assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
    }

    #[test]
    fn validate_rejects_zero_tls_handshake_timeout() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
            .with_tls_handshake_timeout(Duration::ZERO);
        let err = cfg.validate().expect_err("zero handshake timeout");
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    #[test]
    fn validate_rejects_zero_max_concurrent_tls_handshakes() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
            .with_max_concurrent_tls_handshakes(0);
        let err = cfg.validate().expect_err("zero handshake concurrency");
        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
    }

    #[test]
    fn validate_consumes_and_proves() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        let validated = cfg.validate().expect("valid config");
        assert_eq!(validated.as_inner().name, "test-server");
        let raw = validated.into_inner();
        assert_eq!(raw.name, "test-server");

        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        bad.max_request_body = 0;
        assert!(bad.validate().is_err(), "zero body cap must fail validate");
    }

    #[test]
    fn validate_rejects_zero_max_concurrent_requests() {
        let cfg =
            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
        let err = cfg.validate().expect_err("zero concurrency cap must fail");
        assert!(
            format!("{err}").contains("max_concurrent_requests"),
            "error should mention max_concurrent_requests, got: {err}"
        );
    }

    #[test]
    fn validate_rejects_zero_max_tracked_keys() {
        let rl = crate::auth::RateLimitConfig {
            max_attempts_per_minute: 30,
            pre_auth_max_per_minute: None,
            max_tracked_keys: 0,
            idle_eviction: Duration::from_secs(15 * 60),
        };
        let auth_cfg = AuthConfig {
            enabled: true,
            api_keys: Vec::new(),
            mtls: None,
            rate_limit: Some(rl),
            #[cfg(feature = "oauth")]
            oauth: None,
        };
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
        assert!(
            format!("{err}").contains("max_tracked_keys"),
            "error should mention max_tracked_keys, got: {err}"
        );
    }

    #[test]
    fn derive_allowed_hosts_includes_public_host() {
        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
        assert!(
            hosts.iter().any(|h| h == "mcp.example.com"),
            "public_url host must be allowed"
        );
    }

    #[test]
    fn derive_allowed_hosts_includes_bind_authority() {
        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1"),
            "bind host must be allowed"
        );
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1:8080"),
            "bind authority must be allowed"
        );
    }

    #[tokio::test]
    async fn healthz_returns_ok_json() {
        let resp = healthz().await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
        assert!(
            json.get("name").is_none(),
            "healthz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "healthz must not expose version"
        );
    }

    #[tokio::test]
    async fn readyz_returns_ok_when_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["ready"], true);
        assert!(
            json.get("name").is_none(),
            "readyz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "readyz must not expose version"
        );
        assert_eq!(json["db"], "connected");
    }

    #[tokio::test]
    async fn readyz_returns_503_when_not_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn readyz_returns_503_when_ready_missing() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    /// Router whose `/probe` handler echoes `"{ConnectInfo<SocketAddr>}|{PeerAddr}"` as seen
    /// after `normalize_peer_addr_middleware`; a missing extension renders as an empty string.
    fn peer_probe_router() -> axum::Router {
        async fn probe(req: Request<Body>) -> String {
            let ci = req
                .extensions()
                .get::<ConnectInfo<SocketAddr>>()
                .map(|c| c.0.to_string())
                .unwrap_or_default();
            let pa = req
                .extensions()
                .get::<PeerAddr>()
                .map(|p| p.addr.to_string())
                .unwrap_or_default();
            format!("{ci}|{pa}")
        }
        axum::Router::new()
            .route("/probe", axum::routing::get(probe))
            .layer(axum::middleware::from_fn(normalize_peer_addr_middleware))
    }

    /// Collects a response body into a UTF-8 `String`.
    async fn body_string(resp: axum::response::Response) -> String {
        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
        // When both extensions exist, the plain `ConnectInfo<SocketAddr>` wins.
        let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
        let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
        let req = Request::builder()
            .uri("/probe")
            .extension(ConnectInfo(plain))
            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
            .body(Body::empty())
            .unwrap();
        let resp = peer_probe_router().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
    }

    #[tokio::test]
    async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
        let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
        let req = Request::builder()
            .uri("/probe")
            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
            .body(Body::empty())
            .unwrap();
        let resp = peer_probe_router().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
    }

    #[tokio::test]
    async fn normalize_no_op_without_any_connect_info() {
        let req = Request::builder()
            .uri("/probe")
            .body(Body::empty())
            .unwrap();
        let resp = peer_probe_router().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, "|");
    }

    #[tokio::test]
    async fn peer_addr_extractor_rejects_when_absent() {
        async fn h(peer: PeerAddr) -> String {
            peer.addr.to_string()
        }
        let app = axum::Router::new().route("/p", axum::routing::get(h));
        let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn peer_addr_extractor_returns_value_when_present() {
        async fn h(peer: PeerAddr) -> String {
            peer.addr.to_string()
        }
        let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
        let app = axum::Router::new().route("/p", axum::routing::get(h));
        let req = Request::builder()
            .uri("/p")
            .extension(PeerAddr::new(addr))
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, addr.to_string());
    }

    #[tokio::test]
    async fn peer_addr_via_extension_extractor() {
        async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
            peer.addr.to_string()
        }
        let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
        let app = axum::Router::new().route("/p", axum::routing::get(h));
        let req = Request::builder()
            .uri("/p")
            .extension(PeerAddr::new(addr))
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_string(resp).await, addr.to_string());
    }

    /// Router with a single `/limited` route behind the extra-route rate limiter.
    fn limited_router(per_minute: u32) -> axum::Router {
        let limiter = build_extra_route_rate_limiter(per_minute);
        axum::Router::new()
            .route("/limited", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let l = Arc::clone(&limiter);
                extra_route_rate_limit_middleware(l, req, next)
            }))
    }

    /// `GET /limited` request that appears to come from `ip` over plain TCP.
    fn limited_req(ip: &str) -> Request<Body> {
        let addr: SocketAddr = format!("{ip}:40000").parse().unwrap();
        Request::builder()
            .uri("/limited")
            .extension(ConnectInfo(addr))
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn extra_route_limiter_denies_over_quota() {
        let app = limited_router(2);
        for i in 0..2 {
            let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK, "request {i} should pass");
        }
        let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        let body = body_string(resp).await;
        assert!(
            body.contains("too many requests to application routes"),
            "deny body should match the limiter message, got: {body}"
        );
    }

    #[tokio::test]
    async fn extra_route_limiter_isolates_keys() {
        let app = limited_router(2);
        for _ in 0..2 {
            let resp = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }
        let exhausted = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
        assert_eq!(exhausted.status(), StatusCode::TOO_MANY_REQUESTS);
        // A different source IP has its own, untouched quota.
        let other = app.clone().oneshot(limited_req("10.3.3.3")).await.unwrap();
        assert_eq!(other.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn extra_route_limiter_fails_open_without_peer() {
        let app = limited_router(1);
        for i in 0..3 {
            let req = Request::builder()
                .uri("/limited")
                .body(Body::empty())
                .unwrap();
            let resp = app.clone().oneshot(req).await.unwrap();
            assert_eq!(
                resp.status(),
                StatusCode::OK,
                "request {i} should fail open"
            );
        }
    }

    #[tokio::test]
    async fn extra_route_limiter_extracts_tls_conn_info() {
        let app = limited_router(2);
        let mk = || {
            let addr: SocketAddr = "192.168.9.9:55555".parse().unwrap();
            Request::builder()
                .uri("/limited")
                .extension(ConnectInfo(TlsConnInfo::new(addr, None)))
                .body(Body::empty())
                .unwrap()
        };
        for _ in 0..2 {
            assert_eq!(
                app.clone().oneshot(mk()).await.unwrap().status(),
                StatusCode::OK
            );
        }
        let resp = app.clone().oneshot(mk()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn validate_rejects_zero_extra_route_rate_limit() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
            .with_extra_route_rate_limit(0);
        let err = cfg.validate().expect_err("zero extra route rate limit");
        assert!(err.to_string().contains("extra_route_rate_limit"));
    }

    /// Router with a single `/test` route behind the Origin allow-list check.
    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
        let allowed: Arc<[String]> = Arc::from(origins);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let a = Arc::clone(&allowed);
                origin_check_middleware(a, log_request_headers, req, next)
            }))
    }

    #[tokio::test]
    async fn origin_allowed_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://localhost:3000")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_rejected_returns_403() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn no_origin_header_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn empty_allowlist_rejects_any_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://anything.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn empty_allowlist_passes_without_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[test]
    fn format_request_headers_redacts_sensitive_values() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
        headers.insert(
            "proxy-authorization",
            "Basic cHJveHk6c2VjcmV0".parse().unwrap(),
        );
        headers.insert("cookie", "sid=abc".parse().unwrap());
        headers.insert("x-request-id", "req-123".parse().unwrap());

        let out = format_request_headers_for_log(&headers);
        assert!(out.contains("authorization: [REDACTED]"));
        assert!(out.contains("proxy-authorization: [REDACTED]"));
        assert!(out.contains("cookie: [REDACTED]"));
        assert!(out.contains("x-request-id: req-123"));
        assert!(!out.contains("secret-token"));
        assert!(!out.contains("cHJveHk6c2VjcmV0"));
    }

    /// Router with a single `/test` route behind the default security headers.
    fn security_router(is_tls: bool) -> axum::Router {
        security_router_with(is_tls, SecurityHeadersConfig::default())
    }

    /// Router with a single `/test` route behind security headers built from `cfg`.
    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
        let cfg = Arc::new(cfg);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let c = Arc::clone(&cfg);
                security_headers_middleware(is_tls, c, req, next)
            }))
    }

    #[tokio::test]
    async fn security_headers_set_on_response() {
        let app = security_router(false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let h = resp.headers();
        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
        assert_eq!(
            h.get("cross-origin-resource-policy").unwrap(),
            "same-origin"
        );
        assert_eq!(
            h.get("cross-origin-embedder-policy").unwrap(),
            "require-corp"
        );
        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
        assert!(
            h.get("permissions-policy")
                .unwrap()
                .to_str()
                .unwrap()
                .contains("camera=()"),
            "permissions-policy must restrict browser features"
        );
        assert_eq!(
            h.get("content-security-policy").unwrap(),
            "default-src 'none'; frame-ancestors 'none'"
        );
        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
        // Plaintext deployment: HSTS must not be sent.
        assert!(h.get("strict-transport-security").is_none());
    }

    #[tokio::test]
    async fn hsts_set_when_tls_enabled() {
        let app = security_router(true);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let hsts = resp.headers().get("strict-transport-security").unwrap();
        assert!(
            hsts.to_str().unwrap().contains("max-age=63072000"),
            "HSTS must set 2-year max-age"
        );
    }

    /// Runs `McpServerConfig::check` on a default config carrying `headers`.
    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
        let cfg =
            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
        cfg.check()
    }

    #[test]
    fn security_headers_config_default_validates() {
        check_with_security_headers(SecurityHeadersConfig::default())
            .expect("default SecurityHeadersConfig must validate");
    }

    #[test]
    fn security_headers_config_validate_accepts_empty_string() {
        let h = SecurityHeadersConfig {
            x_content_type_options: Some(String::new()),
            x_frame_options: Some(String::new()),
            cache_control: Some(String::new()),
            referrer_policy: Some(String::new()),
            cross_origin_opener_policy: Some(String::new()),
            cross_origin_resource_policy: Some(String::new()),
            cross_origin_embedder_policy: Some(String::new()),
            permissions_policy: Some(String::new()),
            x_permitted_cross_domain_policies: Some(String::new()),
            content_security_policy: Some(String::new()),
            x_dns_prefetch_control: Some(String::new()),
            strict_transport_security: Some(String::new()),
        };
        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
    }

    #[test]
    fn security_headers_config_validate_rejects_bad_value() {
        let h = SecurityHeadersConfig {
            referrer_policy: Some("\u{0007}".into()),
            ..SecurityHeadersConfig::default()
        };
        let err = check_with_security_headers(h)
            .expect_err("control char in referrer_policy must reject");
        let msg = err.to_string();
        assert!(
            msg.contains("referrer_policy"),
            "error must name the offending field, got: {msg}"
        );
    }

    #[test]
    fn security_headers_config_validate_rejects_hsts_preload() {
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
            ..SecurityHeadersConfig::default()
        };
        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
        let msg = err.to_string();
        assert!(
            msg.contains("strict_transport_security"),
            "error must name the field, got: {msg}"
        );
        assert!(
            msg.to_lowercase().contains("preload"),
            "error must mention `preload`, got: {msg}"
        );
    }

    #[test]
    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=600; PRELOAD".into()),
            ..SecurityHeadersConfig::default()
        };
        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
    }

    #[tokio::test]
    async fn security_headers_override_honored() {
        let h = SecurityHeadersConfig {
            x_frame_options: Some("SAMEORIGIN".into()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let xfo = resp.headers().get("x-frame-options").unwrap();
        assert_eq!(xfo, "SAMEORIGIN");
    }

    #[tokio::test]
    async fn security_headers_empty_string_omits() {
        let h = SecurityHeadersConfig {
            referrer_policy: Some(String::new()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        assert!(
            resp.headers().get("referrer-policy").is_none(),
            "Some(\"\") must omit the header"
        );
        // Fields left at their default are still emitted.
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn security_headers_hsts_only_when_tls() {
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=600".into()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert!(
            resp.headers().get("strict-transport-security").is_none(),
            "HSTS must remain absent on plaintext deployments even with override"
        );
    }

    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn oauth_token_cache_headers_set_pragma_and_vary() {
        let app = axum::Router::new()
            .route("/token", axum::routing::post(|| async { "{}" }))
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            ));
        let req = Request::builder()
            .method("POST")
            .uri("/token")
            .body(Body::from("{}"))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let h = resp.headers();
        assert_eq!(
            h.get("pragma").unwrap(),
            "no-cache",
            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
        );
        let vary_values: Vec<String> = h
            .get_all("vary")
            .iter()
            .filter_map(|v| v.to_str().ok().map(str::to_owned))
            .collect();
        assert!(
            vary_values
                .iter()
                .any(|v| v.eq_ignore_ascii_case("Authorization")),
            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
        );
    }

    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn oauth_token_cache_headers_preserve_existing_vary() {
        // The handler sets its own `Vary`; the middleware must append, not overwrite.
        let app = axum::Router::new()
            .route(
                "/token",
                axum::routing::post(|| async {
                    axum::response::Response::builder()
                        .header("vary", "Accept-Encoding")
                        .body(axum::body::Body::from("{}"))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            ));
        let req = Request::builder()
            .method("POST")
            .uri("/token")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let vary: Vec<String> = resp
            .headers()
            .get_all("vary")
            .iter()
            .filter_map(|v| v.to_str().ok().map(str::to_owned))
            .collect();
        assert!(
            vary.iter().any(|v| v.contains("Accept-Encoding")),
            "must preserve pre-existing Vary value, got {vary:?}"
        );
        assert!(
            vary.iter().any(|v| v.contains("Authorization")),
            "must append Authorization to Vary, got {vary:?}"
        );
    }

    #[test]
    fn version_payload_contains_expected_fields() {
        let v = version_payload("my-server", "1.2.3");
        assert_eq!(v["name"], "my-server");
        assert_eq!(v["version"], "1.2.3");
        assert!(v["build_git_sha"].is_string());
        assert!(v["build_timestamp"].is_string());
        assert!(v["rust_version"].is_string());
        assert!(v["mcpx_version"].is_string());
    }

    #[tokio::test]
    async fn concurrency_limit_layer_composes_and_serves() {
        let app = axum::Router::new()
            .route("/ok", axum::routing::get(|| async { "ok" }))
            .layer(
                tower::ServiceBuilder::new()
                    .layer(axum::error_handling::HandleErrorLayer::new(
                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
                    ))
                    .layer(tower::load_shed::LoadShedLayer::new())
                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
            );
        let resp = app
            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn compression_layer_gzip_encodes_response() {
        use tower_http::compression::Predicate as _;

        // 4 KiB body, well above the 1 KiB `SizeAbove` threshold below.
        let big_body = "a".repeat(4096);
        let app = axum::Router::new()
            .route(
                "/big",
                axum::routing::get(move || {
                    let body = big_body.clone();
                    async move { body }
                }),
            )
            .layer(
                tower_http::compression::CompressionLayer::new()
                    .gzip(true)
                    .br(true)
                    .compress_when(
                        tower_http::compression::DefaultPredicate::new()
                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
                    ),
            );

        let req = Request::builder()
            .uri("/big")
            .header(header::ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }

    #[tokio::test]
    async fn tls_handshake_timeout_reaps_idle_connections() {
        use tokio::io::AsyncReadExt as _;

        let _ = rustls::crypto::ring::default_provider().install_default();

        // Self-signed `localhost` certificate, written to a unique temp dir because
        // `TlsListener::new` takes certificate and key file paths.
        let key = rcgen::KeyPair::generate().expect("generate key");
        let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
            .expect("cert params")
            .self_signed(&key)
            .expect("self-signed cert");
        let dir = std::env::temp_dir().join(format!(
            "rmcp-server-kit-hs-timeout-{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("clock after epoch")
                .as_nanos()
        ));
        tokio::fs::create_dir_all(&dir).await.expect("temp dir");
        let cert_path = dir.join("server.crt");
        let key_path = dir.join("server.key");
        tokio::fs::write(&cert_path, cert.pem())
            .await
            .expect("write cert");
        tokio::fs::write(&key_path, key.serialize_pem())
            .await
            .expect("write key");

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let tls = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            None,
            None,
            Duration::from_millis(200),
            8,
        )
        .expect("tls listener");
        let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");

        // Open a TCP connection but never send a ClientHello. The server is expected to
        // close it once the 200 ms handshake timeout elapses, well inside the 2 s budget.
        let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
        let mut buf = [0_u8; 16];
        let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
            .await
            .expect("server must reap the idle handshake within its timeout");
        match read {
            // EOF or a connection error both mean the server dropped the stalled handshake.
            Ok(0) | Err(_) => {}
            Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
        }

        drop(tls);
        // Best-effort cleanup so repeated runs do not accumulate key material in the
        // system temp dir; a failure here must not fail the test.
        let _ = tokio::fs::remove_dir_all(&dir).await;
    }
}