1use std::{
2 future::Future,
3 net::{IpAddr, SocketAddr},
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12 body::Body,
13 extract::{ConnectInfo, Request},
14 middleware::Next,
15 response::IntoResponse,
16};
17use rmcp::{
18 ServerHandler,
19 transport::streamable_http_server::{
20 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21 },
22};
23use rustls::RootCertStore;
24use tokio::{
25 net::TcpListener,
26 sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31 auth::{
32 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33 build_rate_limiter, extract_mtls_identity,
34 },
35 bounded_limiter::BoundedKeyedLimiter,
36 error::McpxError,
37 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
38 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
39};
40
41#[allow(
45 clippy::needless_pass_by_value,
46 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
47)]
48fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
49 McpxError::Startup(format!("{e:#}"))
50}
51
52#[allow(
58 clippy::needless_pass_by_value,
59 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
60)]
61fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
62 McpxError::Startup(format!("{op}: {e}"))
63}
64
/// Async readiness probe callback.
///
/// Each call returns a boxed `Send` future that resolves to a JSON value.
/// When configured via `McpServerConfig::with_readiness_check`, the router
/// serves it on `/readyz` through `readyz`. Without a probe, `/readyz` falls
/// back to the `healthz` handler.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
71
/// Socket address of the peer that sent a request.
///
/// Usable as an axum extractor. Extraction only succeeds when a `PeerAddr` has
/// been inserted into the request extensions; the rejection message implies
/// that the crate's `serve()` does this.
/// NOTE(review): presumably this is the direct TCP peer, as opposed to the
/// proxy-resolved `ClientIp` — confirm against the insertion site.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PeerAddr {
    /// The peer's IP address and port.
    pub addr: SocketAddr,
}

impl PeerAddr {
    /// Wraps `addr`; crate-internal so only the server itself creates values.
    #[must_use]
    pub(crate) const fn new(addr: SocketAddr) -> Self {
        Self { addr }
    }
}
131
132impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
141 type Rejection = (axum::http::StatusCode, &'static str);
142
143 async fn from_request_parts(
144 parts: &mut axum::http::request::Parts,
145 _state: &S,
146 ) -> Result<Self, Self::Rejection> {
147 parts.extensions.get::<Self>().copied().ok_or((
148 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
149 "peer address unavailable: not running under rmcp-server-kit serve()",
150 ))
151 }
152}
153
/// IP address attributed to the client of a request.
///
/// NOTE(review): presumably derived from the peer address or, behind trusted
/// proxies, from the configured forwarding header (`ForwardedHeaderMode`) —
/// the resolution logic is outside this view; confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct ClientIp {
    /// The resolved client IP address.
    pub ip: IpAddr,
}

impl ClientIp {
    /// Wraps `ip`; crate-internal so only the server itself creates values.
    #[must_use]
    pub(crate) const fn new(ip: IpAddr) -> Self {
        Self { ip }
    }
}
191
/// Which forwarding header to read when requests arrive through trusted proxies.
///
/// Deserialized in kebab-case, i.e. `"x-forwarded-for"` or `"forwarded"`.
/// `McpServerConfig::validate` requires non-empty `trusted_proxies` whenever
/// a mode is configured.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum ForwardedHeaderMode {
    /// Presumably the de-facto `X-Forwarded-For` header.
    XForwardedFor,
    /// Presumably the standardized `Forwarded` header (RFC 7239).
    Forwarded,
}
205
/// Appears to hold the state needed to resolve a client IP from forwarding headers.
///
/// NOTE(review): construction and use are outside this view; presumably built
/// from `McpServerConfig::trusted_proxies` / `forwarded_header` — confirm.
struct ForwardResolver {
    /// Networks whose forwarding headers are trusted.
    trusted: Vec<ipnet::IpNet>,
    /// Header format to parse.
    mode: ForwardedHeaderMode,
}
212
/// Overrides for the security-related response headers added to every response.
///
/// The server passes this config to `security_headers_middleware` together
/// with a flag saying whether TLS is configured. `McpServerConfig::validate`
/// checks it via `validate_security_headers`. `Default` leaves every field `None`.
/// NOTE(review): whether `None` means "use the built-in default value" or
/// "omit the header" is decided in the middleware, not shown here — confirm.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` value.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` value.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` value.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` value.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` value.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` value.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` value.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` value.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` value.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` value.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` value.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` value.
    /// NOTE(review): the middleware receives an `is_tls` flag, so HSTS is
    /// presumably only emitted when TLS is configured — confirm.
    pub strict_transport_security: Option<String>,
}
266
/// Full configuration for the MCP HTTP server.
///
/// Build it with `McpServerConfig::new` plus the `with_*` / `enable_*`
/// builder methods, then call `validate()` to obtain a `Validated` config.
/// All public fields are deprecated in favour of the builder methods.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Listen address, e.g. `127.0.0.1:8080`; must parse as a `SocketAddr`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name, reported by `/version` and the admin endpoints.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version, reported by `/version` and the admin endpoints.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// TLS certificate path; must be set together with `tls_key_path`.
    /// When set, the derived server URL uses the `https` scheme.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private-key path; must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings, applied to `/mcp` only when `enabled` is true.
    /// Required (and enabled) when `admin_enabled` is set.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy for `/mcp`; `None` means `RbacPolicy::disabled()` is used.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// CORS origins; each entry must start with `http://` or `https://`.
    /// When empty and `public_url` is set, an origin is derived from it.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Tool-call rate limit, logged as "calls/min per IP".
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Burst size for `tool_rate_limit`; must be > 0 and requires `tool_rate_limit`.
    #[deprecated(
        since = "1.12.0",
        note = "use McpServerConfig::with_tool_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit_burst: Option<u32>,
    /// Per-IP requests-per-minute limit for `extra_router` routes; must be > 0 when set.
    #[deprecated(
        since = "1.11.0",
        note = "use McpServerConfig::with_extra_route_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_route_rate_limit: Option<u32>,
    /// Burst size for `extra_route_rate_limit`; must be > 0 and requires it to be set.
    #[deprecated(
        since = "1.12.0",
        note = "use McpServerConfig::with_extra_route_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_route_rate_limit_burst: Option<u32>,
    /// Trusted reverse proxies as CIDR blocks or bare IP addresses.
    #[deprecated(
        since = "1.13.0",
        note = "use McpServerConfig::with_trusted_proxies(); direct field access will become pub(crate) in a future major release"
    )]
    pub trusted_proxies: Vec<String>,
    /// Forwarding header to honour; requires non-empty `trusted_proxies`.
    #[deprecated(
        since = "1.13.0",
        note = "use McpServerConfig::with_forwarded_header(); direct field access will become pub(crate) in a future major release"
    )]
    pub forwarded_header: Option<ForwardedHeaderMode>,
    /// Custom probe for `/readyz`; without it `/readyz` uses the health handler.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes on `/mcp` (default 1 MiB, must be > 0).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Per-request timeout on `/mcp`; expiry yields `408` (default 2 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Shutdown timeout (default 30 seconds).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Keep-alive for Streamable HTTP sessions in the local session manager
    /// (default 20 minutes).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval (default 15 seconds).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// One-shot callback receiving a `ReloadHandle`.
    /// NOTE(review): presumably invoked once the server is running — confirm.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Extra routes merged into the app. They are not wrapped by the `/mcp`
    /// auth/RBAC middleware, but optionally get `extra_route_rate_limit`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL; must start with `http://` or `https://`.
    /// Feeds allowed hosts, protected-resource metadata and the derived CORS origin.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Whether request headers are logged (default `false`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression (default `false`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Only responses above this size are compressed (default 1024).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Optional concurrency cap; must be > 0 when set.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Mounts the `/admin/*` endpoints; requires auth to be configured and enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// Role required for the admin endpoints (default `"admin"`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enables the metrics listener (default `false`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Metrics listen address (default `127.0.0.1:9090`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Security response header overrides.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
    /// TLS handshake timeout (default `DEFAULT_TLS_HANDSHAKE_TIMEOUT`, must be > 0).
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_handshake_timeout: Duration,
    /// Cap on in-flight TLS handshakes (default
    /// `DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES`, must be > 0).
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_tls_handshakes: usize,
}
555
/// A value that has passed validation.
///
/// The tuple field is private, so code outside this module cannot fabricate
/// one. In this file it is produced by `McpServerConfig::validate`.
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Always prints `Validated { .. }`; the wrapped value is never shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Validated");
        out.finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Unwraps and returns the validated value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
641
642#[allow(
643 deprecated,
644 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
645)]
646impl McpServerConfig {
647 #[must_use]
655 pub fn new(
656 bind_addr: impl Into<String>,
657 name: impl Into<String>,
658 version: impl Into<String>,
659 ) -> Self {
660 Self {
661 bind_addr: bind_addr.into(),
662 name: name.into(),
663 version: version.into(),
664 tls_cert_path: None,
665 tls_key_path: None,
666 auth: None,
667 rbac: None,
668 allowed_origins: Vec::new(),
669 tool_rate_limit: None,
670 readiness_check: None,
671 max_request_body: 1024 * 1024,
672 request_timeout: Duration::from_mins(2),
673 shutdown_timeout: Duration::from_secs(30),
674 session_idle_timeout: Duration::from_mins(20),
675 sse_keep_alive: Duration::from_secs(15),
676 on_reload_ready: None,
677 extra_router: None,
678 public_url: None,
679 log_request_headers: false,
680 compression_enabled: false,
681 compression_min_size: 1024,
682 max_concurrent_requests: None,
683 admin_enabled: false,
684 admin_role: "admin".to_owned(),
685 #[cfg(feature = "metrics")]
686 metrics_enabled: false,
687 #[cfg(feature = "metrics")]
688 metrics_bind: "127.0.0.1:9090".into(),
689 security_headers: SecurityHeadersConfig::default(),
690 tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
691 max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
692 extra_route_rate_limit: None,
693 tool_rate_limit_burst: None,
694 extra_route_rate_limit_burst: None,
695 trusted_proxies: Vec::new(),
696 forwarded_header: None,
697 }
698 }
699
700 #[must_use]
710 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
711 self.auth = Some(auth);
712 self
713 }
714
715 #[must_use]
720 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
721 self.security_headers = headers;
722 self
723 }
724
725 #[must_use]
729 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
730 self.bind_addr = addr.into();
731 self
732 }
733
734 #[must_use]
737 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
738 self.rbac = Some(rbac);
739 self
740 }
741
742 #[must_use]
746 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
747 self.tls_cert_path = Some(cert_path.into());
748 self.tls_key_path = Some(key_path.into());
749 self
750 }
751
752 #[must_use]
756 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
757 self.public_url = Some(url.into());
758 self
759 }
760
761 #[must_use]
765 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
766 where
767 I: IntoIterator<Item = S>,
768 S: Into<String>,
769 {
770 self.allowed_origins = origins.into_iter().map(Into::into).collect();
771 self
772 }
773
774 #[must_use]
787 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
788 self.extra_router = Some(router);
789 self
790 }
791
792 #[must_use]
795 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
796 self.readiness_check = Some(check);
797 self
798 }
799
800 #[must_use]
803 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
804 self.max_request_body = bytes;
805 self
806 }
807
808 #[must_use]
810 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
811 self.request_timeout = timeout;
812 self
813 }
814
815 #[must_use]
817 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
818 self.shutdown_timeout = timeout;
819 self
820 }
821
822 #[must_use]
824 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
825 self.session_idle_timeout = timeout;
826 self
827 }
828
829 #[must_use]
831 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
832 self.sse_keep_alive = interval;
833 self
834 }
835
836 #[must_use]
840 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
841 self.max_concurrent_requests = Some(limit);
842 self
843 }
844
845 #[must_use]
853 pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
854 self.tls_handshake_timeout = timeout;
855 self
856 }
857
858 #[must_use]
867 pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
868 self.max_concurrent_tls_handshakes = limit;
869 self
870 }
871
872 #[must_use]
875 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
876 self.tool_rate_limit = Some(per_minute);
877 self
878 }
879
880 #[must_use]
891 pub fn with_extra_route_rate_limit(mut self, per_minute: u32) -> Self {
892 self.extra_route_rate_limit = Some(per_minute);
893 self
894 }
895
896 #[must_use]
901 pub fn with_tool_rate_limit_burst(mut self, burst: u32) -> Self {
902 self.tool_rate_limit_burst = Some(burst);
903 self
904 }
905
906 #[must_use]
912 pub fn with_extra_route_rate_limit_burst(mut self, burst: u32) -> Self {
913 self.extra_route_rate_limit_burst = Some(burst);
914 self
915 }
916
917 #[must_use]
929 pub fn with_trusted_proxies<I, S>(mut self, proxies: I) -> Self
930 where
931 I: IntoIterator<Item = S>,
932 S: Into<String>,
933 {
934 self.trusted_proxies = proxies.into_iter().map(Into::into).collect();
935 self
936 }
937
938 #[must_use]
943 pub fn with_forwarded_header(mut self, mode: ForwardedHeaderMode) -> Self {
944 self.forwarded_header = Some(mode);
945 self
946 }
947
948 #[must_use]
952 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
953 where
954 F: FnOnce(ReloadHandle) + Send + 'static,
955 {
956 self.on_reload_ready = Some(Box::new(callback));
957 self
958 }
959
960 #[must_use]
964 pub fn enable_compression(mut self, min_size: u16) -> Self {
965 self.compression_enabled = true;
966 self.compression_min_size = min_size;
967 self
968 }
969
970 #[must_use]
975 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
976 self.admin_enabled = true;
977 self.admin_role = role.into();
978 self
979 }
980
981 #[must_use]
984 pub fn enable_request_header_logging(mut self) -> Self {
985 self.log_request_headers = true;
986 self
987 }
988
989 #[cfg(feature = "metrics")]
992 #[must_use]
993 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
994 self.metrics_enabled = true;
995 self.metrics_bind = bind.into();
996 self
997 }
998
999 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
1032 self.check()?;
1033 Ok(Validated(self))
1034 }
1035
1036 fn check_burst_knobs(&self) -> Result<(), McpxError> {
1043 if self.tool_rate_limit_burst == Some(0) {
1044 return Err(McpxError::Config(
1045 "tool_rate_limit_burst must be greater than zero".into(),
1046 ));
1047 }
1048 if self.extra_route_rate_limit_burst == Some(0) {
1049 return Err(McpxError::Config(
1050 "extra_route_rate_limit_burst must be greater than zero".into(),
1051 ));
1052 }
1053 if self.tool_rate_limit_burst.is_some() && self.tool_rate_limit.is_none() {
1054 return Err(McpxError::Config(
1055 "tool_rate_limit_burst requires tool_rate_limit to be set".into(),
1056 ));
1057 }
1058 if self.extra_route_rate_limit_burst.is_some() && self.extra_route_rate_limit.is_none() {
1059 return Err(McpxError::Config(
1060 "extra_route_rate_limit_burst requires extra_route_rate_limit to be set".into(),
1061 ));
1062 }
1063 if let Some(rl) = self.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
1064 if rl.burst == Some(0) {
1065 return Err(McpxError::Config(
1066 "auth rate_limit.burst must be greater than zero".into(),
1067 ));
1068 }
1069 if rl.pre_auth_burst == Some(0) {
1070 return Err(McpxError::Config(
1071 "auth rate_limit.pre_auth_burst must be greater than zero".into(),
1072 ));
1073 }
1074 }
1075 Ok(())
1076 }
1077
1078 fn check_trusted_forwarder(&self) -> Result<(), McpxError> {
1083 for entry in &self.trusted_proxies {
1084 if parse_proxy_net(entry).is_none() {
1085 return Err(McpxError::Config(format!(
1086 "trusted_proxies entry {entry:?} is neither a CIDR nor an IP address"
1087 )));
1088 }
1089 }
1090 if self.forwarded_header.is_some() && self.trusted_proxies.is_empty() {
1091 return Err(McpxError::Config(
1092 "forwarded_header requires trusted_proxies to be nonempty".into(),
1093 ));
1094 }
1095 Ok(())
1096 }
1097
1098 fn check(&self) -> Result<(), McpxError> {
1102 if self.admin_enabled {
1106 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
1107 if !auth_enabled {
1108 return Err(McpxError::Config(
1109 "admin_enabled=true requires auth to be configured and enabled".into(),
1110 ));
1111 }
1112 }
1113
1114 match (&self.tls_cert_path, &self.tls_key_path) {
1116 (Some(_), None) => {
1117 return Err(McpxError::Config(
1118 "tls_cert_path is set but tls_key_path is missing".into(),
1119 ));
1120 }
1121 (None, Some(_)) => {
1122 return Err(McpxError::Config(
1123 "tls_key_path is set but tls_cert_path is missing".into(),
1124 ));
1125 }
1126 _ => {}
1127 }
1128
1129 if self.bind_addr.parse::<SocketAddr>().is_err() {
1131 return Err(McpxError::Config(format!(
1132 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
1133 self.bind_addr
1134 )));
1135 }
1136
1137 if let Some(ref url) = self.public_url
1139 && !(url.starts_with("http://") || url.starts_with("https://"))
1140 {
1141 return Err(McpxError::Config(format!(
1142 "public_url {url:?} must start with http:// or https://"
1143 )));
1144 }
1145
1146 for origin in &self.allowed_origins {
1148 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
1149 return Err(McpxError::Config(format!(
1150 "allowed_origins entry {origin:?} must start with http:// or https://"
1151 )));
1152 }
1153 }
1154
1155 if self.max_request_body == 0 {
1157 return Err(McpxError::Config(
1158 "max_request_body must be greater than zero".into(),
1159 ));
1160 }
1161
1162 if self.extra_route_rate_limit == Some(0) {
1166 return Err(McpxError::Config(
1167 "extra_route_rate_limit must be greater than zero".into(),
1168 ));
1169 }
1170
1171 self.check_burst_knobs()?;
1173
1174 self.check_trusted_forwarder()?;
1176
1177 #[cfg(feature = "oauth")]
1179 if let Some(auth_cfg) = &self.auth
1180 && let Some(oauth_cfg) = &auth_cfg.oauth
1181 {
1182 oauth_cfg.validate()?;
1183 }
1184
1185 validate_security_headers(&self.security_headers)?;
1188
1189 if let Some(0) = self.max_concurrent_requests {
1193 return Err(McpxError::Config(
1194 "max_concurrent_requests must be greater than zero when set".into(),
1195 ));
1196 }
1197
1198 if let Some(auth_cfg) = &self.auth
1202 && let Some(rl) = &auth_cfg.rate_limit
1203 && rl.max_tracked_keys == 0
1204 {
1205 return Err(McpxError::Config(
1206 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
1207 ));
1208 }
1209
1210 if self.tls_handshake_timeout == Duration::ZERO {
1215 return Err(McpxError::Config(
1216 "tls_handshake_timeout must be greater than zero".into(),
1217 ));
1218 }
1219
1220 if self.max_concurrent_tls_handshakes == 0 {
1225 return Err(McpxError::Config(
1226 "max_concurrent_tls_handshakes must be greater than zero".into(),
1227 ));
1228 }
1229
1230 Ok(())
1231 }
1232}
1233
1234#[allow(
1240 missing_debug_implementations,
1241 reason = "contains Arc<AuthState> with non-Debug fields"
1242)]
1243pub struct ReloadHandle {
1244 auth: Option<Arc<AuthState>>,
1245 rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
1246 crl_set: Option<Arc<CrlSet>>,
1247}
1248
1249impl ReloadHandle {
1250 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
1252 if let Some(ref auth) = self.auth {
1253 auth.reload_keys(keys);
1254 }
1255 }
1256
1257 pub fn reload_rbac(&self, policy: RbacPolicy) {
1259 if let Some(ref rbac) = self.rbac {
1260 rbac.store(Arc::new(policy));
1261 tracing::info!("RBAC policy reloaded");
1262 }
1263 }
1264
1265 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1271 let Some(ref crl_set) = self.crl_set else {
1272 return Err(McpxError::Config(
1273 "CRL refresh requested but mTLS CRL support is not configured".into(),
1274 ));
1275 };
1276
1277 crl_set.force_refresh().await
1278 }
1279}
1280
/// Runtime parameters returned by `build_app_router` alongside the router.
///
/// Most fields mirror the `McpServerConfig` settings of the same name.
/// NOTE(review): these `too_many_lines`/`cognitive_complexity` allows target
/// function bodies and appear to have no effect on a struct; the reason text
/// describes router assembly, so they may have been meant for a function —
/// confirm.
#[allow(
    clippy::too_many_lines,
    clippy::cognitive_complexity,
    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
)]
struct AppRunParams {
    /// TLS certificate and key paths when TLS is enabled
    /// (presumably `(cert, key)` order — confirm).
    tls_paths: Option<(PathBuf, PathBuf)>,
    /// Timeout applied to each TLS handshake.
    tls_handshake_timeout: Duration,
    /// Cap on concurrently running TLS handshakes.
    max_concurrent_tls_handshakes: usize,
    /// mTLS settings, if client certificates are in use.
    mtls_config: Option<MtlsConfig>,
    /// Shutdown timeout.
    shutdown_timeout: Duration,
    /// Shared auth state; `None` when auth is disabled.
    auth_state: Option<Arc<AuthState>>,
    /// Hot-swappable RBAC policy shared with the `/mcp` middleware.
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    /// One-shot callback that receives the `ReloadHandle`.
    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Cancellation token; the MCP service is given a child of it.
    ct: CancellationToken,
    /// URL scheme label (presumably `"http"` or `"https"` — confirm).
    scheme: &'static str,
    /// Server name.
    name: String,
}
1331
1332#[allow(
1342 clippy::cognitive_complexity,
1343 reason = "router assembly is intrinsically sequential; splitting harms readability"
1344)]
1345#[allow(
1346 deprecated,
1347 reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
1348)]
1349fn build_app_router<H, F>(
1350 mut config: McpServerConfig,
1351 handler_factory: F,
1352) -> anyhow::Result<(axum::Router, AppRunParams)>
1353where
1354 H: ServerHandler + 'static,
1355 F: Fn() -> H + Send + Sync + Clone + 'static,
1356{
1357 let ct = CancellationToken::new();
1358
1359 let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
1360 tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
1361
1362 let mcp_service = StreamableHttpService::new(
1363 move || Ok(handler_factory()),
1364 {
1365 let mut mgr = LocalSessionManager::default();
1366 mgr.session_config.keep_alive = Some(config.session_idle_timeout);
1367 mgr.into()
1368 },
1369 StreamableHttpServerConfig::default()
1370 .with_allowed_hosts(allowed_hosts)
1371 .with_sse_keep_alive(Some(config.sse_keep_alive))
1372 .with_cancellation_token(ct.child_token()),
1373 );
1374
1375 let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
1377
1378 let auth_state: Option<Arc<AuthState>> = match config.auth {
1382 Some(ref auth_config) if auth_config.enabled => {
1383 let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
1384 let pre_auth_limiter = auth_config
1385 .rate_limit
1386 .as_ref()
1387 .map(crate::auth::build_pre_auth_limiter);
1388
1389 #[cfg(feature = "oauth")]
1390 let jwks_cache = auth_config
1391 .oauth
1392 .as_ref()
1393 .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
1394 .transpose()
1395 .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
1396
1397 Some(Arc::new(AuthState {
1398 api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
1399 rate_limiter,
1400 pre_auth_limiter,
1401 #[cfg(feature = "oauth")]
1402 jwks_cache,
1403 seen_identities: crate::auth::SeenIdentitySet::new(),
1404 counters: crate::auth::AuthCounters::default(),
1405 }))
1406 }
1407 _ => None,
1408 };
1409
1410 let rbac_swap = Arc::new(ArcSwap::new(
1413 config
1414 .rbac
1415 .clone()
1416 .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
1417 ));
1418
1419 if config.admin_enabled {
1422 let Some(ref auth_state_ref) = auth_state else {
1423 return Err(anyhow::anyhow!(
1424 "admin_enabled=true requires auth to be configured and enabled"
1425 ));
1426 };
1427 let admin_state = crate::admin::AdminState {
1428 started_at: std::time::Instant::now(),
1429 name: config.name.clone(),
1430 version: config.version.clone(),
1431 auth: Some(Arc::clone(auth_state_ref)),
1432 rbac: Arc::clone(&rbac_swap),
1433 };
1434 let admin_cfg = crate::admin::AdminConfig {
1435 role: config.admin_role.clone(),
1436 };
1437 mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
1438 tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
1439 }
1440
1441 {
1474 let tool_limiter: Option<Arc<ToolRateLimiter>> = config
1475 .tool_rate_limit
1476 .map(|per_minute| build_tool_rate_limiter(per_minute, config.tool_rate_limit_burst));
1477
1478 if rbac_swap.load().is_enabled() {
1479 tracing::info!("RBAC enforcement enabled on /mcp");
1480 }
1481 if let Some(limit) = config.tool_rate_limit {
1482 tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1483 }
1484
1485 let rbac_for_mw = Arc::clone(&rbac_swap);
1486 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1487 let p = rbac_for_mw.load_full();
1488 let tl = tool_limiter.clone();
1489 rbac_middleware(p, tl, req, next)
1490 }));
1491 }
1492
1493 if let Some(ref auth_config) = config.auth
1495 && auth_config.enabled
1496 {
1497 let Some(ref state) = auth_state else {
1498 return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1499 };
1500
1501 let methods: Vec<&str> = [
1502 auth_config.mtls.is_some().then_some("mTLS"),
1503 (!auth_config.api_keys.is_empty()).then_some("bearer"),
1504 #[cfg(feature = "oauth")]
1505 auth_config.oauth.is_some().then_some("oauth-jwt"),
1506 ]
1507 .into_iter()
1508 .flatten()
1509 .collect();
1510
1511 tracing::info!(
1512 methods = %methods.join(", "),
1513 api_keys = auth_config.api_keys.len(),
1514 "auth enabled on /mcp"
1515 );
1516
1517 let state_for_mw = Arc::clone(state);
1518 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1519 let s = Arc::clone(&state_for_mw);
1520 auth_middleware(s, req, next)
1521 }));
1522 }
1523
1524 mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1527 axum::http::StatusCode::REQUEST_TIMEOUT,
1528 config.request_timeout,
1529 ));
1530
1531 mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1535 config.max_request_body,
1536 ));
1537
1538 let mut effective_origins = config.allowed_origins.clone();
1545 if effective_origins.is_empty()
1546 && let Some(ref url) = config.public_url
1547 {
1548 if let Some(scheme_end) = url.find("://") {
1553 let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1554 let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1555 let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1556 let host = after_scheme.get(..host_end).unwrap_or_default();
1557 let origin = format!("{scheme_with_sep}{host}");
1558 tracing::info!(
1559 %origin,
1560 "auto-derived allowed origin from public_url"
1561 );
1562 effective_origins.push(origin);
1563 }
1564 }
1565 let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1566 let cors_origins = Arc::clone(&allowed_origins);
1567 let log_request_headers = config.log_request_headers;
1568
1569 let readyz_route = if let Some(check) = config.readiness_check.take() {
1570 axum::routing::get(move || readyz(Arc::clone(&check)))
1571 } else {
1572 axum::routing::get(healthz)
1573 };
1574
1575 #[allow(unused_mut)] let mut router = axum::Router::new()
1577 .route("/healthz", axum::routing::get(healthz))
1578 .route("/readyz", readyz_route)
1579 .route(
1580 "/version",
1581 axum::routing::get({
1582 let payload_bytes: Arc<[u8]> =
1587 serialize_version_payload(&config.name, &config.version);
1588 move || {
1589 let p = Arc::clone(&payload_bytes);
1590 async move {
1591 (
1592 [(axum::http::header::CONTENT_TYPE, "application/json")],
1593 p.to_vec(),
1594 )
1595 }
1596 }
1597 }),
1598 )
1599 .merge(mcp_router);
1600
1601 if let Some(extra) = config.extra_router.take() {
1608 let extra = match config.extra_route_rate_limit {
1609 Some(per_minute) => {
1610 let limiter =
1611 build_extra_route_rate_limiter(per_minute, config.extra_route_rate_limit_burst);
1612 tracing::info!(per_minute, "extra-route per-IP rate limit enabled");
1613 extra.layer(axum::middleware::from_fn(move |req, next| {
1614 let l = Arc::clone(&limiter);
1615 extra_route_rate_limit_middleware(l, req, next)
1616 }))
1617 }
1618 None => extra,
1619 };
1620 router = router.merge(extra);
1621 }
1622
1623 let server_url = if let Some(ref url) = config.public_url {
1630 url.trim_end_matches('/').to_owned()
1631 } else {
1632 let prm_scheme = if config.tls_cert_path.is_some() {
1633 "https"
1634 } else {
1635 "http"
1636 };
1637 format!("{prm_scheme}://{}", config.bind_addr)
1638 };
1639 let resource_url = format!("{server_url}/mcp");
1640
1641 #[cfg(feature = "oauth")]
1642 let prm_metadata = if let Some(ref auth_config) = config.auth
1643 && let Some(ref oauth_config) = auth_config.oauth
1644 {
1645 crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1646 } else {
1647 serde_json::json!({ "resource": resource_url })
1648 };
1649 #[cfg(not(feature = "oauth"))]
1650 let prm_metadata = serde_json::json!({ "resource": resource_url });
1651
1652 router = router.route(
1653 "/.well-known/oauth-protected-resource",
1654 axum::routing::get(move || {
1655 let m = prm_metadata.clone();
1656 async move { axum::Json(m) }
1657 }),
1658 );
1659
1660 #[cfg(feature = "oauth")]
1665 if let Some(ref auth_config) = config.auth
1666 && let Some(ref oauth_config) = auth_config.oauth
1667 && oauth_config.proxy.is_some()
1668 {
1669 router =
1670 install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1671 }
1672
1673 let is_tls = config.tls_cert_path.is_some();
1676 let security_headers_cfg = Arc::new(config.security_headers.clone());
1677 router = router.layer(axum::middleware::from_fn(move |req, next| {
1678 let cfg = Arc::clone(&security_headers_cfg);
1679 security_headers_middleware(is_tls, cfg, req, next)
1680 }));
1681
1682 if !cors_origins.is_empty() {
1686 let cors = tower_http::cors::CorsLayer::new()
1687 .allow_origin(
1688 cors_origins
1689 .iter()
1690 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1691 .collect::<Vec<_>>(),
1692 )
1693 .allow_methods([
1694 axum::http::Method::GET,
1695 axum::http::Method::POST,
1696 axum::http::Method::OPTIONS,
1697 ])
1698 .allow_headers([
1699 axum::http::header::CONTENT_TYPE,
1700 axum::http::header::AUTHORIZATION,
1701 ]);
1702 router = router.layer(cors);
1703 }
1704
1705 if config.compression_enabled {
1709 use tower_http::compression::Predicate as _;
1710 let predicate = tower_http::compression::DefaultPredicate::new().and(
1711 tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1712 );
1713 router = router.layer(
1714 tower_http::compression::CompressionLayer::new()
1715 .gzip(true)
1716 .br(true)
1717 .compress_when(predicate),
1718 );
1719 tracing::info!(
1720 min_size = config.compression_min_size,
1721 "response compression enabled (gzip, br)"
1722 );
1723 }
1724
1725 if let Some(max) = config.max_concurrent_requests {
1728 let overload_handler = tower::ServiceBuilder::new()
1729 .layer(axum::error_handling::HandleErrorLayer::new(
1730 |_err: tower::BoxError| async {
1731 (
1732 axum::http::StatusCode::SERVICE_UNAVAILABLE,
1733 axum::Json(serde_json::json!({
1734 "error": "overloaded",
1735 "error_description": "server is at capacity, retry later"
1736 })),
1737 )
1738 },
1739 ))
1740 .layer(tower::load_shed::LoadShedLayer::new())
1741 .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1742 router = router.layer(overload_handler);
1743 tracing::info!(max, "global concurrency limit enabled");
1744 }
1745
1746 router = router.fallback(|| async {
1750 (
1751 axum::http::StatusCode::NOT_FOUND,
1752 axum::Json(serde_json::json!({
1753 "error": "not_found",
1754 "error_description": "The requested endpoint does not exist"
1755 })),
1756 )
1757 });
1758
1759 #[cfg(feature = "metrics")]
1761 if config.metrics_enabled {
1762 let metrics = Arc::new(
1763 crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1764 );
1765 let m = Arc::clone(&metrics);
1766 router = router.layer(axum::middleware::from_fn(
1767 move |req: Request<Body>, next: Next| {
1768 let m = Arc::clone(&m);
1769 metrics_middleware(m, req, next)
1770 },
1771 ));
1772 let metrics_bind = config.metrics_bind.clone();
1773 let metrics_shutdown = ct.clone();
1774 tokio::spawn(async move {
1775 if let Err(e) =
1776 crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1777 {
1778 tracing::error!("metrics listener failed: {e}");
1779 }
1780 });
1781 }
1782
1783 let forward_resolver: Option<Arc<ForwardResolver>> = if config.trusted_proxies.is_empty() {
1791 None
1792 } else {
1793 Some(Arc::new(ForwardResolver {
1796 trusted: config
1797 .trusted_proxies
1798 .iter()
1799 .filter_map(|entry| parse_proxy_net(entry))
1800 .collect(),
1801 mode: config
1802 .forwarded_header
1803 .unwrap_or(ForwardedHeaderMode::XForwardedFor),
1804 }))
1805 };
1806 if forward_resolver.is_some() {
1807 tracing::info!(
1808 proxies = config.trusted_proxies.len(),
1809 "trusted-forwarder mode enabled: limiters key by resolved client IP"
1810 );
1811 }
1812 router = router.layer(axum::middleware::from_fn(move |req, next| {
1813 let r = forward_resolver.clone();
1814 normalize_peer_addr_middleware(r, req, next)
1815 }));
1816
1817 router = router.layer(axum::middleware::from_fn(move |req, next| {
1828 let origins = Arc::clone(&allowed_origins);
1829 origin_check_middleware(origins, log_request_headers, req, next)
1830 }));
1831
1832 let scheme = if config.tls_cert_path.is_some() {
1833 "https"
1834 } else {
1835 "http"
1836 };
1837
1838 let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1839 (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1840 _ => None,
1841 };
1842 let tls_handshake_timeout = config.tls_handshake_timeout;
1843 let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
1844 let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1845
1846 Ok((
1847 router,
1848 AppRunParams {
1849 tls_paths,
1850 tls_handshake_timeout,
1851 max_concurrent_tls_handshakes,
1852 mtls_config,
1853 shutdown_timeout: config.shutdown_timeout,
1854 auth_state,
1855 rbac_swap,
1856 on_reload_ready: config.on_reload_ready.take(),
1857 ct,
1858 scheme,
1859 name: config.name.clone(),
1860 },
1861 ))
1862}
1863
1864pub async fn serve<H, F>(
1881 config: Validated<McpServerConfig>,
1882 handler_factory: F,
1883) -> Result<(), McpxError>
1884where
1885 H: ServerHandler + 'static,
1886 F: Fn() -> H + Send + Sync + Clone + 'static,
1887{
1888 let config = config.into_inner();
1889 #[allow(
1890 deprecated,
1891 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1892 )]
1893 let bind_addr = config.bind_addr.clone();
1894 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1895
1896 let listener = TcpListener::bind(&bind_addr)
1897 .await
1898 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1899 log_listening(¶ms.name, params.scheme, &bind_addr);
1900
1901 run_server(
1902 router,
1903 listener,
1904 params.tls_paths,
1905 params.tls_handshake_timeout,
1906 params.max_concurrent_tls_handshakes,
1907 params.mtls_config,
1908 params.shutdown_timeout,
1909 params.auth_state,
1910 params.rbac_swap,
1911 params.on_reload_ready,
1912 params.ct,
1913 )
1914 .await
1915 .map_err(anyhow_to_startup)
1916}
1917
1918pub async fn serve_with_listener<H, F>(
1948 listener: TcpListener,
1949 config: Validated<McpServerConfig>,
1950 handler_factory: F,
1951 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1952 shutdown: Option<CancellationToken>,
1953) -> Result<(), McpxError>
1954where
1955 H: ServerHandler + 'static,
1956 F: Fn() -> H + Send + Sync + Clone + 'static,
1957{
1958 let config = config.into_inner();
1959 let local_addr = listener
1960 .local_addr()
1961 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1962 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1963
1964 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1965
1966 if let Some(external) = shutdown {
1970 let internal = params.ct.clone();
1971 tokio::spawn(async move {
1972 external.cancelled().await;
1973 internal.cancel();
1974 });
1975 }
1976
1977 if let Some(tx) = ready_tx {
1981 let _ = tx.send(local_addr);
1983 }
1984
1985 run_server(
1986 router,
1987 listener,
1988 params.tls_paths,
1989 params.tls_handshake_timeout,
1990 params.max_concurrent_tls_handshakes,
1991 params.mtls_config,
1992 params.shutdown_timeout,
1993 params.auth_state,
1994 params.rbac_swap,
1995 params.on_reload_ready,
1996 params.ct,
1997 )
1998 .await
1999 .map_err(anyhow_to_startup)
2000}
2001
2002#[allow(
2005 clippy::cognitive_complexity,
2006 reason = "tracing::info! macro expansions inflate the score; logic is trivial"
2007)]
2008fn log_listening(name: &str, scheme: &str, addr: &str) {
2009 tracing::info!("{name} listening on {addr}");
2010 tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
2011 tracing::info!(" Health check: {scheme}://{addr}/healthz");
2012 tracing::info!(" Readiness: {scheme}://{addr}/readyz");
2013}
2014
/// Serves `router` on `listener` until shutdown. Uses TLS when `tls_paths` is set and plain
/// HTTP otherwise.
///
/// Shutdown starts on the first of two events:
/// * an OS signal arrives (see `shutdown_signal`);
/// * `ct` is cancelled.
///
/// Once shutdown starts, `ct` is cancelled and axum begins draining connections. If draining
/// takes longer than `shutdown_timeout`, the function stops waiting, logs a warning, and
/// returns `Ok(())`.
///
/// If `on_reload_ready` is set, it is invoked once with a `ReloadHandle` before serving
/// begins. With TLS and `mtls_config.crl_enabled`, the initial CRLs are fetched first and a
/// background refresher (which is handed `ct`) is spawned.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * loading the client CA bundle fails;
/// * bootstrapping the CRLs fails;
/// * the TLS listener cannot be built;
/// * `axum::serve` returns an error.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    tls_handshake_timeout: Duration,
    max_concurrent_tls_handshakes: usize,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // Single "shutdown has begun" signal, fired by whichever comes first: an OS signal
    // or cancellation of the caller's token `ct`.
    let shutdown_trigger = CancellationToken::new();
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    // Future handed to axum's graceful shutdown. Once the trigger fires it also cancels
    // `ct`, so tasks holding a clone of `ct` (e.g. the CRL refresher spawned below)
    // observe the shutdown too.
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };

    // Hard deadline. It only starts counting after the trigger fires. If it wins the
    // `select!` below, connections still draining are abandoned.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // With CRL checking enabled, fetch the initial CRLs before serving, then keep a
        // background refresher running on the same `CrlSet`.
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        // Expose reload hooks before serving. The `CrlSet` handle given here is the same one
        // passed to the TLS listener below.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
            tls_handshake_timeout,
            max_concurrent_tls_handshakes,
        )?;
        // `TlsConnInfo` carries the peer address plus any mTLS identity to handlers.
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        // Plain HTTP: no CRL state, and `auth_state` / `rbac_swap` are not needed afterwards,
        // so they are moved into the handle.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
2154
/// Mounts the OAuth 2.1 proxy endpoints on `router` when `oauth_config.proxy` is set.
///
/// It always installs these routes:
/// * `/.well-known/oauth-authorization-server`;
/// * `/authorize`;
/// * `/token`;
/// * `/register`.
///
/// `/introspect` and `/revoke` are added only when `expose_admin_endpoints` is enabled and
/// the matching upstream URL is configured. Returns `router` unchanged when no proxy is
/// configured.
///
/// # Errors
///
/// Propagates failures from building the upstream OAuth HTTP client and from
/// `build_oauth_admin_router`, e.g. when admin auth is required but `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    // Admin routes exist only when exposure is enabled AND at least one upstream URL is
    // configured; otherwise there is nothing to mount.
    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Warn only when unauthenticated admin routes will actually be mounted; with no
    // introspection/revocation URL configured there is no /introspect or /revoke route.
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
2255
/// Builds the router holding the OAuth proxy's `/introspect` and `/revoke` endpoints.
///
/// Each route is added only when its upstream URL is configured. All admin responses go
/// through the token cache-header middleware. When `require_auth_on_admin_endpoints` is set,
/// the whole admin router is additionally wrapped in `auth_middleware`.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when admin auth is required but `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let introspect_proxy = proxy.clone();
        let introspect_client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let proxy_cfg = introspect_proxy.clone();
                let client = introspect_client.clone();
                async move { crate::oauth::handle_introspect(&client, &proxy_cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let revoke_proxy = proxy.clone();
        let revoke_client = http;
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let proxy_cfg = revoke_proxy.clone();
                let client = revoke_client.clone();
                async move { crate::oauth::handle_revoke(&client, &proxy_cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let guard_state = auth_state.cloned().ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        let state = Arc::clone(&guard_state);
        auth_middleware(state, req, next)
    })))
}
2314
2315fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2320 let mut hosts = vec![
2321 "localhost".to_owned(),
2322 "127.0.0.1".to_owned(),
2323 "::1".to_owned(),
2324 ];
2325
2326 if let Some(url) = public_url
2327 && let Ok(uri) = url.parse::<axum::http::Uri>()
2328 && let Some(authority) = uri.authority()
2329 {
2330 let host = authority.host().to_owned();
2331 if !hosts.iter().any(|h| h == &host) {
2332 hosts.push(host);
2333 }
2334
2335 let authority = authority.as_str().to_owned();
2336 if !hosts.iter().any(|h| h == &authority) {
2337 hosts.push(authority);
2338 }
2339 }
2340
2341 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2342 && let Some(authority) = uri.authority()
2343 {
2344 let host = authority.host().to_owned();
2345 if !hosts.iter().any(|h| h == &host) {
2346 hosts.push(host);
2347 }
2348
2349 let authority = authority.as_str().to_owned();
2350 if !hosts.iter().any(|h| h == &authority) {
2351 hosts.push(authority);
2352 }
2353 }
2354
2355 hosts
2356}
2357
2358impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2371 for TlsConnInfo
2372{
2373 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2374 let addr = *target.remote_addr();
2375 let identity = target.io().identity().cloned();
2376 TlsConnInfo::new(addr, identity)
2377 }
2378}
2379
/// Default upper bound on how long a single TLS handshake may take (10 seconds).
///
/// NOTE(review): not referenced in this part of the file; presumably the default for the
/// `tls_handshake_timeout` server setting. Confirm against the config defaults.
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Default cap on concurrently in-flight TLS handshakes.
///
/// NOTE(review): presumably the default for `max_concurrent_tls_handshakes`; confirm.
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;

/// Capacity of the channel that hands completed TLS connections from the background
/// acceptor task to `TlsListener`'s `accept`.
///
/// Each handshake task holds its concurrency permit until its `send` completes. When this
/// channel is full, those tasks wait and keep their permits, so a slow consumer eventually
/// stops new TCP accepts as well.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
2402
/// An `axum::serve::Listener` that yields connections whose TLS handshake has already
/// completed.
///
/// TCP accepts and handshakes run in a background task (`run_tls_acceptor`), each handshake
/// in its own spawned task. Finished connections are delivered over an mpsc channel, so a
/// slow or stalled handshake never blocks `accept`.
struct TlsListener {
    /// Address of the underlying TCP socket. It is captured at construction and returned
    /// from `Listener::local_addr`.
    local_addr: SocketAddr,
    /// Receiving end of the channel fed by the background acceptor task.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Handle to the background acceptor task; aborted when the listener is dropped.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2428
2429impl TlsListener {
2430 fn new(
2431 inner: TcpListener,
2432 cert_path: &Path,
2433 key_path: &Path,
2434 mtls_config: Option<&MtlsConfig>,
2435 crl_set: Option<Arc<CrlSet>>,
2436 handshake_timeout: Duration,
2437 max_concurrent_handshakes: usize,
2438 ) -> anyhow::Result<Self> {
2439 rustls::crypto::ring::default_provider()
2441 .install_default()
2442 .ok();
2443
2444 let certs = load_certs(cert_path)?;
2445 let key = load_key(key_path)?;
2446
2447 let mtls_default_role;
2448
2449 let tls_config = if let Some(mtls) = mtls_config {
2450 mtls_default_role = mtls.default_role.clone();
2451 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2452 {
2453 let Some(crl_set) = crl_set else {
2454 return Err(anyhow::anyhow!(
2455 "mTLS CRL verifier requested but CRL state was not initialized"
2456 ));
2457 };
2458 Arc::new(DynamicClientCertVerifier::new(crl_set))
2459 } else {
2460 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2461 if mtls.required {
2462 rustls::server::WebPkiClientVerifier::builder(root_store)
2463 .build()
2464 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2465 } else {
2466 rustls::server::WebPkiClientVerifier::builder(root_store)
2467 .allow_unauthenticated()
2468 .build()
2469 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2470 }
2471 };
2472
2473 tracing::info!(
2474 ca = %mtls.ca_cert_path.display(),
2475 required = mtls.required,
2476 crl_enabled = mtls.crl_enabled,
2477 "mTLS client auth configured"
2478 );
2479
2480 rustls::ServerConfig::builder_with_protocol_versions(&[
2481 &rustls::version::TLS12,
2482 &rustls::version::TLS13,
2483 ])
2484 .with_client_cert_verifier(verifier)
2485 .with_single_cert(certs, key)?
2486 } else {
2487 mtls_default_role = "viewer".to_owned();
2488 rustls::ServerConfig::builder_with_protocol_versions(&[
2489 &rustls::version::TLS12,
2490 &rustls::version::TLS13,
2491 ])
2492 .with_no_client_auth()
2493 .with_single_cert(certs, key)?
2494 };
2495
2496 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2497 tracing::info!(
2498 "TLS enabled (cert: {}, key: {})",
2499 cert_path.display(),
2500 key_path.display()
2501 );
2502 let local_addr = inner.local_addr()?;
2503 let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2504 let acceptor_task = tokio::spawn(run_tls_acceptor(
2505 inner,
2506 acceptor,
2507 mtls_default_role,
2508 tx,
2509 handshake_timeout,
2510 max_concurrent_handshakes,
2511 ));
2512 Ok(Self {
2513 local_addr,
2514 rx,
2515 acceptor_task,
2516 })
2517 }
2518
2519 fn extract_handshake_identity(
2523 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2524 default_role: &str,
2525 addr: SocketAddr,
2526 ) -> Option<AuthIdentity> {
2527 let (_, server_conn) = tls_stream.get_ref();
2528 let cert_der = server_conn.peer_certificates()?.first()?;
2529 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2530 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2531 Some(id)
2532 }
2533}
2534
2535async fn run_tls_acceptor(
2543 listener: TcpListener,
2544 acceptor: tokio_rustls::TlsAcceptor,
2545 default_role: String,
2546 tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2547 handshake_timeout: Duration,
2548 max_concurrent_handshakes: usize,
2549) {
2550 let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2551 loop {
2552 let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2556 return;
2558 };
2559 let (stream, addr) = match listener.accept().await {
2560 Ok(pair) => pair,
2561 Err(e) => {
2562 tracing::debug!("TCP accept error: {e}");
2563 continue;
2564 }
2565 };
2566 if tx.is_closed() {
2567 return;
2569 }
2570 let acceptor = acceptor.clone();
2571 let default_role = default_role.clone();
2572 let tx = tx.clone();
2573 tokio::spawn(async move {
2574 let _permit = permit;
2575 match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2576 Ok(Ok(tls_stream)) => {
2577 let identity =
2578 TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2579 let wrapped = AuthenticatedTlsStream {
2580 inner: tls_stream,
2581 identity,
2582 };
2583 let _ = tx.send((wrapped, addr)).await;
2586 }
2587 Ok(Err(e)) => {
2588 tracing::debug!("TLS handshake failed from {addr}: {e}");
2589 }
2590 Err(_elapsed) => {
2591 tracing::debug!(
2592 "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2593 );
2594 }
2595 }
2596 });
2597 }
2598}
2599
/// Server-side TLS stream paired with the client identity extracted from its certificate.
///
/// The identity is computed once, right after the handshake in the acceptor task. The
/// connect-info hook for `TlsConnInfo` can then read it without re-inspecting the rustls
/// session.
pub(crate) struct AuthenticatedTlsStream {
    /// The established TLS connection; all `AsyncRead`/`AsyncWrite` calls delegate to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity from the client's leaf certificate. `None` when the client presented no
    /// certificate or no identity could be extracted from it.
    identity: Option<AuthIdentity>,
}
2615
2616impl AuthenticatedTlsStream {
2617 #[must_use]
2619 pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
2620 self.identity.as_ref()
2621 }
2622}
2623
2624impl std::fmt::Debug for AuthenticatedTlsStream {
2625 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2626 f.debug_struct("AuthenticatedTlsStream")
2627 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2628 .finish_non_exhaustive()
2629 }
2630}
2631
2632impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2633 fn poll_read(
2634 mut self: Pin<&mut Self>,
2635 cx: &mut std::task::Context<'_>,
2636 buf: &mut tokio::io::ReadBuf<'_>,
2637 ) -> std::task::Poll<std::io::Result<()>> {
2638 Pin::new(&mut self.inner).poll_read(cx, buf)
2639 }
2640}
2641
2642impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2643 fn poll_write(
2644 mut self: Pin<&mut Self>,
2645 cx: &mut std::task::Context<'_>,
2646 buf: &[u8],
2647 ) -> std::task::Poll<std::io::Result<usize>> {
2648 Pin::new(&mut self.inner).poll_write(cx, buf)
2649 }
2650
2651 fn poll_flush(
2652 mut self: Pin<&mut Self>,
2653 cx: &mut std::task::Context<'_>,
2654 ) -> std::task::Poll<std::io::Result<()>> {
2655 Pin::new(&mut self.inner).poll_flush(cx)
2656 }
2657
2658 fn poll_shutdown(
2659 mut self: Pin<&mut Self>,
2660 cx: &mut std::task::Context<'_>,
2661 ) -> std::task::Poll<std::io::Result<()>> {
2662 Pin::new(&mut self.inner).poll_shutdown(cx)
2663 }
2664
2665 fn poll_write_vectored(
2666 mut self: Pin<&mut Self>,
2667 cx: &mut std::task::Context<'_>,
2668 bufs: &[std::io::IoSlice<'_>],
2669 ) -> std::task::Poll<std::io::Result<usize>> {
2670 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2671 }
2672
2673 fn is_write_vectored(&self) -> bool {
2674 self.inner.is_write_vectored()
2675 }
2676}
2677
2678impl axum::serve::Listener for TlsListener {
2679 type Io = AuthenticatedTlsStream;
2680 type Addr = SocketAddr;
2681
2682 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2688 if let Some(pair) = self.rx.recv().await {
2689 return pair;
2690 }
2691 tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2697 std::future::pending().await
2698 }
2699
2700 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2701 Ok(self.local_addr)
2702 }
2703}
2704
2705impl Drop for TlsListener {
2706 fn drop(&mut self) {
2707 self.acceptor_task.abort();
2710 }
2711}
2712
2713fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2714 use rustls::pki_types::pem::PemObject;
2715 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2716 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2717 .collect::<Result<_, _>>()
2718 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2719 anyhow::ensure!(
2720 !certs.is_empty(),
2721 "no certificates found in {}",
2722 path.display()
2723 );
2724 Ok(certs)
2725}
2726
2727fn load_client_auth_roots(
2728 path: &Path,
2729) -> anyhow::Result<(
2730 Vec<rustls::pki_types::CertificateDer<'static>>,
2731 Arc<RootCertStore>,
2732)> {
2733 let ca_certs = load_certs(path)?;
2734 let mut root_store = RootCertStore::empty();
2735 for cert in &ca_certs {
2736 root_store
2737 .add(cert.clone())
2738 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2739 }
2740
2741 Ok((ca_certs, Arc::new(root_store)))
2742}
2743
2744fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2745 use rustls::pki_types::pem::PemObject;
2746 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2747 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2748}
2749
2750#[allow(
2751 clippy::unused_async,
2752 reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2753)]
2754async fn healthz() -> impl IntoResponse {
2755 axum::Json(serde_json::json!({
2756 "status": "ok",
2757 }))
2758}
2759
2760fn version_payload(name: &str, version: &str) -> serde_json::Value {
2767 serde_json::json!({
2768 "name": name,
2769 "version": version,
2770 "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2771 "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2772 "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2773 "mcpx_version": env!("CARGO_PKG_VERSION"),
2774 })
2775}
2776
2777fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2787 let value = version_payload(name, version);
2788 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2789}
2790
2791async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2792 let status = check().await;
2793 let ready = status
2794 .get("ready")
2795 .and_then(serde_json::Value::as_bool)
2796 .unwrap_or(false);
2797 let code = if ready {
2798 axum::http::StatusCode::OK
2799 } else {
2800 axum::http::StatusCode::SERVICE_UNAVAILABLE
2801 };
2802 (code, axum::Json(status))
2803}
2804
2805async fn shutdown_signal() {
2809 let ctrl_c = tokio::signal::ctrl_c();
2810
2811 #[cfg(unix)]
2812 {
2813 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2814 Ok(mut term) => {
2815 tokio::select! {
2816 _ = ctrl_c => {}
2817 _ = term.recv() => {}
2818 }
2819 }
2820 Err(e) => {
2821 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2822 ctrl_c.await.ok();
2823 }
2824 }
2825 }
2826
2827 #[cfg(not(unix))]
2828 {
2829 ctrl_c.await.ok();
2830 }
2831}
2832
/// Records a request counter and a latency histogram for every request through the router.
///
/// Label values are kept to a bounded set, so clients cannot create arbitrarily many time
/// series:
/// * the path label is the matched route template (e.g. `/healthz`), or `unmatched` for
///   requests that hit no route (the fallback);
/// * methods outside the standard HTTP set are reported as `OTHER`.
///
/// Using the raw request path would let every distinct URL, including random 404 probes,
/// register a new labelled series through `with_label_values`.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    const STANDARD_METHODS: [&str; 9] = [
        "GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH",
    ];

    let method = match req.method().as_str() {
        known if STANDARD_METHODS.contains(&known) => known.to_owned(),
        _ => "OTHER".to_owned(),
    };
    // Routed requests carry the matched route template in their extensions; the fallback
    // (no matching route) does not.
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or("unmatched", |matched| matched.as_str())
        .to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2864
2865async fn security_headers_middleware(
2877 is_tls: bool,
2878 cfg: Arc<SecurityHeadersConfig>,
2879 req: Request<Body>,
2880 next: Next,
2881) -> axum::response::Response {
2882 use axum::http::{HeaderName, header};
2883
2884 let mut resp = next.run(req).await;
2885 let headers = resp.headers_mut();
2886
2887 headers.remove(header::SERVER);
2889 headers.remove(HeaderName::from_static("x-powered-by"));
2890
2891 apply_security_header(
2892 headers,
2893 header::X_CONTENT_TYPE_OPTIONS,
2894 cfg.x_content_type_options.as_deref(),
2895 "nosniff",
2896 );
2897 apply_security_header(
2898 headers,
2899 header::X_FRAME_OPTIONS,
2900 cfg.x_frame_options.as_deref(),
2901 "deny",
2902 );
2903 apply_security_header(
2904 headers,
2905 header::CACHE_CONTROL,
2906 cfg.cache_control.as_deref(),
2907 "no-store, max-age=0",
2908 );
2909 apply_security_header(
2910 headers,
2911 header::REFERRER_POLICY,
2912 cfg.referrer_policy.as_deref(),
2913 "no-referrer",
2914 );
2915 apply_security_header(
2916 headers,
2917 HeaderName::from_static("cross-origin-opener-policy"),
2918 cfg.cross_origin_opener_policy.as_deref(),
2919 "same-origin",
2920 );
2921 apply_security_header(
2922 headers,
2923 HeaderName::from_static("cross-origin-resource-policy"),
2924 cfg.cross_origin_resource_policy.as_deref(),
2925 "same-origin",
2926 );
2927 apply_security_header(
2928 headers,
2929 HeaderName::from_static("cross-origin-embedder-policy"),
2930 cfg.cross_origin_embedder_policy.as_deref(),
2931 "require-corp",
2932 );
2933 apply_security_header(
2934 headers,
2935 HeaderName::from_static("permissions-policy"),
2936 cfg.permissions_policy.as_deref(),
2937 "accelerometer=(), camera=(), geolocation=(), microphone=()",
2938 );
2939 apply_security_header(
2940 headers,
2941 HeaderName::from_static("x-permitted-cross-domain-policies"),
2942 cfg.x_permitted_cross_domain_policies.as_deref(),
2943 "none",
2944 );
2945 apply_security_header(
2946 headers,
2947 HeaderName::from_static("content-security-policy"),
2948 cfg.content_security_policy.as_deref(),
2949 "default-src 'none'; frame-ancestors 'none'",
2950 );
2951 apply_security_header(
2952 headers,
2953 HeaderName::from_static("x-dns-prefetch-control"),
2954 cfg.x_dns_prefetch_control.as_deref(),
2955 "off",
2956 );
2957
2958 if is_tls {
2959 apply_security_header(
2960 headers,
2961 header::STRICT_TRANSPORT_SECURITY,
2962 cfg.strict_transport_security.as_deref(),
2963 "max-age=63072000; includeSubDomains",
2964 );
2965 }
2966
2967 resp
2968}
2969
2970fn apply_security_header(
2981 headers: &mut axum::http::HeaderMap,
2982 name: axum::http::HeaderName,
2983 override_value: Option<&str>,
2984 default: &'static str,
2985) {
2986 use axum::http::HeaderValue;
2987
2988 match override_value {
2989 None => {
2990 headers.insert(name, HeaderValue::from_static(default));
2991 }
2992 Some("") => {
2993 }
2995 Some(v) => match HeaderValue::from_str(v) {
2996 Ok(hv) => {
2997 headers.insert(name, hv);
2998 }
2999 Err(err) => {
3000 tracing::error!(
3001 header = %name,
3002 error = %err,
3003 "invalid security header override reached middleware; using default"
3004 );
3005 headers.insert(name, HeaderValue::from_static(default));
3006 }
3007 },
3008 }
3009}
3010
3011fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
3022 use axum::http::HeaderValue;
3023
3024 let fields: &[(&str, Option<&str>)] = &[
3025 (
3026 "x_content_type_options",
3027 cfg.x_content_type_options.as_deref(),
3028 ),
3029 ("x_frame_options", cfg.x_frame_options.as_deref()),
3030 ("cache_control", cfg.cache_control.as_deref()),
3031 ("referrer_policy", cfg.referrer_policy.as_deref()),
3032 (
3033 "cross_origin_opener_policy",
3034 cfg.cross_origin_opener_policy.as_deref(),
3035 ),
3036 (
3037 "cross_origin_resource_policy",
3038 cfg.cross_origin_resource_policy.as_deref(),
3039 ),
3040 (
3041 "cross_origin_embedder_policy",
3042 cfg.cross_origin_embedder_policy.as_deref(),
3043 ),
3044 ("permissions_policy", cfg.permissions_policy.as_deref()),
3045 (
3046 "x_permitted_cross_domain_policies",
3047 cfg.x_permitted_cross_domain_policies.as_deref(),
3048 ),
3049 (
3050 "content_security_policy",
3051 cfg.content_security_policy.as_deref(),
3052 ),
3053 (
3054 "x_dns_prefetch_control",
3055 cfg.x_dns_prefetch_control.as_deref(),
3056 ),
3057 (
3058 "strict_transport_security",
3059 cfg.strict_transport_security.as_deref(),
3060 ),
3061 ];
3062
3063 for (field, value) in fields {
3064 let Some(v) = value else { continue };
3065 if v.is_empty() {
3066 continue;
3067 }
3068 if let Err(err) = HeaderValue::from_str(v) {
3069 return Err(McpxError::Config(format!(
3070 "invalid security_headers.{field}: {err}"
3071 )));
3072 }
3073 }
3074
3075 if let Some(v) = cfg.strict_transport_security.as_deref()
3076 && !v.is_empty()
3077 && v.to_ascii_lowercase().contains("preload")
3078 {
3079 return Err(McpxError::Config(format!(
3080 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
3081 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
3082 )));
3083 }
3084
3085 Ok(())
3086}
3087
#[cfg(feature = "oauth")]
/// Runs the inner service, then stamps `Pragma: no-cache` (replacing any
/// existing value) and appends `Vary: Authorization` to the response.
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;
    {
        let out = response.headers_mut();
        out.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
        // `append` keeps any Vary entries the handler already emitted.
        out.append(header::VARY, HeaderValue::from_static("Authorization"));
    }
    response
}
3115
3116async fn normalize_peer_addr_middleware(
3145 resolver: Option<Arc<ForwardResolver>>,
3146 mut req: Request<Body>,
3147 next: Next,
3148) -> axum::response::Response {
3149 let direct = req
3150 .extensions()
3151 .get::<ConnectInfo<SocketAddr>>()
3152 .map(|ci| ci.0);
3153 let from_tls = req
3154 .extensions()
3155 .get::<ConnectInfo<TlsConnInfo>>()
3156 .map(|ci| ci.0.addr);
3157 if let Some(addr) = direct.or(from_tls) {
3158 if direct.is_none() {
3159 req.extensions_mut().insert(ConnectInfo(addr));
3160 }
3161 req.extensions_mut().insert(PeerAddr::new(addr));
3162 let client_ip = match &resolver {
3163 Some(r) => {
3164 crate::forwarded::resolve_client_ip(addr.ip(), req.headers(), &r.trusted, r.mode)
3165 .unwrap_or_else(|reason| {
3166 tracing::debug!(
3167 reason = ?reason,
3168 "forwarded-header resolution fell back to direct peer"
3169 );
3170 addr.ip()
3171 })
3172 }
3173 None => addr.ip(),
3174 };
3175 req.extensions_mut().insert(ClientIp::new(client_ip));
3176 }
3177 next.run(req).await
3178}
3179
3180fn parse_proxy_net(entry: &str) -> Option<ipnet::IpNet> {
3183 if let Ok(net) = entry.parse::<ipnet::IpNet>() {
3184 return Some(net);
3185 }
3186 entry.parse::<IpAddr>().ok().map(ipnet::IpNet::from)
3187}
3188
3189pub(crate) fn limiter_client_ip(extensions: &axum::http::Extensions) -> Option<IpAddr> {
3193 if let Some(client) = extensions.get::<ClientIp>() {
3194 return Some(client.ip);
3195 }
3196 extensions
3197 .get::<ConnectInfo<SocketAddr>>()
3198 .map(|ci| ci.0.ip())
3199 .or_else(|| {
3200 extensions
3201 .get::<ConnectInfo<TlsConnInfo>>()
3202 .map(|ci| ci.0.addr.ip())
3203 })
3204}
3205
/// Per-IP limiter used for application ("extra") routes; keyed by the IP chosen
/// by `limiter_client_ip`.
pub(crate) type ExtraRouteRateLimiter = BoundedKeyedLimiter<IpAddr>;

/// Second argument to `BoundedKeyedLimiter::new` in `build_extra_route_rate_limiter`;
/// presumably the cap on distinct IP keys tracked at once (TODO confirm against
/// `BoundedKeyedLimiter`).
const EXTRA_ROUTE_MAX_TRACKED_KEYS: usize = 10_000;

/// Third argument to `BoundedKeyedLimiter::new` in `build_extra_route_rate_limiter`;
/// presumably how long an idle key is kept before eviction (TODO confirm).
const EXTRA_ROUTE_IDLE_EVICTION: Duration = Duration::from_mins(15);
3221
3222fn build_extra_route_rate_limiter(
3229 per_minute: u32,
3230 burst: Option<u32>,
3231) -> Arc<ExtraRouteRateLimiter> {
3232 let rate = std::num::NonZeroU32::new(per_minute.max(1)).unwrap_or(std::num::NonZeroU32::MIN);
3233 let mut quota = governor::Quota::per_minute(rate);
3234 if let Some(b) = burst.and_then(std::num::NonZeroU32::new) {
3235 quota = quota.allow_burst(b);
3236 }
3237 Arc::new(BoundedKeyedLimiter::new(
3238 quota,
3239 EXTRA_ROUTE_MAX_TRACKED_KEYS,
3240 EXTRA_ROUTE_IDLE_EVICTION,
3241 ))
3242}
3243
3244async fn extra_route_rate_limit_middleware(
3261 limiter: Arc<ExtraRouteRateLimiter>,
3262 req: Request<Body>,
3263 next: Next,
3264) -> axum::response::Response {
3265 let peer_ip: Option<IpAddr> = limiter_client_ip(req.extensions());
3266 if let Some(ip) = peer_ip
3267 && let Err(wait) = limiter.check_key_wait(&ip)
3268 {
3269 tracing::warn!(%ip, "extra route request rate limited");
3270 return McpxError::RateLimitedFor {
3271 message: "too many requests to application routes from this source".into(),
3272 retry_after: wait,
3273 }
3274 .into_response();
3275 }
3276 next.run(req).await
3277}
3278
3279async fn origin_check_middleware(
3283 allowed: Arc<[String]>,
3284 log_request_headers: bool,
3285 req: Request<Body>,
3286 next: Next,
3287) -> axum::response::Response {
3288 let method = req.method().clone();
3289 let path = req.uri().path().to_owned();
3290
3291 log_incoming_request(&method, &path, req.headers(), log_request_headers);
3292
3293 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
3294 let origin_str = origin.to_str().unwrap_or("");
3295 if !allowed.iter().any(|a| a == origin_str) {
3296 tracing::warn!(
3297 origin = origin_str,
3298 %method,
3299 %path,
3300 allowed = ?&*allowed,
3301 "rejected request: Origin not allowed"
3302 );
3303 return (
3304 axum::http::StatusCode::FORBIDDEN,
3305 "Forbidden: Origin not allowed",
3306 )
3307 .into_response();
3308 }
3309 }
3310 next.run(req).await
3311}
3312
3313fn log_incoming_request(
3316 method: &axum::http::Method,
3317 path: &str,
3318 headers: &axum::http::HeaderMap,
3319 log_request_headers: bool,
3320) {
3321 if log_request_headers {
3322 tracing::debug!(
3323 %method,
3324 %path,
3325 headers = %format_request_headers_for_log(headers),
3326 "incoming request"
3327 );
3328 } else {
3329 tracing::debug!(%method, %path, "incoming request");
3330 }
3331}
3332
3333fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
3334 headers
3335 .iter()
3336 .map(|(k, v)| {
3337 let name = k.as_str();
3338 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
3339 format!("{name}: [REDACTED]")
3340 } else {
3341 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
3342 }
3343 })
3344 .collect::<Vec<_>>()
3345 .join(", ")
3346}
3347
3348#[allow(
3372 clippy::cognitive_complexity,
3373 reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
3374)]
3375pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
3376where
3377 H: ServerHandler + 'static,
3378{
3379 use rmcp::ServiceExt as _;
3380
3381 tracing::info!("stdio transport: serving on stdin/stdout");
3382 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
3383
3384 let transport = rmcp::transport::io::stdio();
3385
3386 let service = handler
3387 .serve(transport)
3388 .await
3389 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
3390
3391 if let Err(e) = service.waiting().await {
3392 tracing::warn!(error = %e, "stdio session ended with error");
3393 }
3394 tracing::info!("stdio session ended");
3395 Ok(())
3396}
3397
3398#[cfg(test)]
3399mod tests {
3400 #![allow(
3401 clippy::unwrap_used,
3402 clippy::expect_used,
3403 clippy::panic,
3404 clippy::indexing_slicing,
3405 clippy::unwrap_in_result,
3406 clippy::print_stdout,
3407 clippy::print_stderr,
3408 deprecated,
3409 reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
3410 )]
3411 use std::{sync::Arc, time::Duration};
3412
3413 use axum::{
3414 body::Body,
3415 http::{Request, StatusCode, header},
3416 response::IntoResponse,
3417 };
3418 use http_body_util::BodyExt;
3419 use tower::ServiceExt as _;
3420
3421 use super::*;
3422
3423 #[test]
3426 fn server_config_new_defaults() {
3427 let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
3428 assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
3429 assert_eq!(cfg.name, "test-server");
3430 assert_eq!(cfg.version, "1.0.0");
3431 assert!(cfg.tls_cert_path.is_none());
3432 assert!(cfg.tls_key_path.is_none());
3433 assert!(cfg.auth.is_none());
3434 assert!(cfg.rbac.is_none());
3435 assert!(cfg.allowed_origins.is_empty());
3436 assert!(cfg.tool_rate_limit.is_none());
3437 assert!(cfg.readiness_check.is_none());
3438 assert_eq!(cfg.max_request_body, 1024 * 1024);
3439 assert_eq!(cfg.request_timeout, Duration::from_mins(2));
3440 assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
3441 assert!(!cfg.log_request_headers);
3442 assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
3443 assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
3444 }
3445
3446 #[test]
3447 fn tls_handshake_builders_set_fields() {
3448 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3449 .with_tls_handshake_timeout(Duration::from_secs(3))
3450 .with_max_concurrent_tls_handshakes(64);
3451 assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
3452 assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
3453 }
3454
3455 #[test]
3456 fn validate_rejects_zero_tls_handshake_timeout() {
3457 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3458 .with_tls_handshake_timeout(Duration::ZERO);
3459 let err = cfg.validate().expect_err("zero handshake timeout");
3460 assert!(err.to_string().contains("tls_handshake_timeout"));
3461 }
3462
3463 #[test]
3464 fn validate_rejects_zero_max_concurrent_tls_handshakes() {
3465 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3466 .with_max_concurrent_tls_handshakes(0);
3467 let err = cfg.validate().expect_err("zero handshake concurrency");
3468 assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
3469 }
3470
3471 #[test]
3472 fn validate_consumes_and_proves() {
3473 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3475 let validated = cfg.validate().expect("valid config");
3476 assert_eq!(validated.as_inner().name, "test-server");
3478 let raw = validated.into_inner();
3480 assert_eq!(raw.name, "test-server");
3481
3482 let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3484 bad.max_request_body = 0;
3485 assert!(bad.validate().is_err(), "zero body cap must fail validate");
3486 }
3487
3488 #[test]
3489 fn validate_rejects_zero_max_concurrent_requests() {
3490 let cfg =
3491 McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
3492 let err = cfg.validate().expect_err("zero concurrency cap must fail");
3493 assert!(
3494 format!("{err}").contains("max_concurrent_requests"),
3495 "error should mention max_concurrent_requests, got: {err}"
3496 );
3497 }
3498
3499 #[test]
3500 fn validate_rejects_zero_max_tracked_keys() {
3501 let rl = crate::auth::RateLimitConfig {
3504 max_attempts_per_minute: 30,
3505 pre_auth_max_per_minute: None,
3506 max_tracked_keys: 0,
3507 idle_eviction: Duration::from_secs(15 * 60),
3508 burst: None,
3509 pre_auth_burst: None,
3510 };
3511 let auth_cfg = AuthConfig {
3512 enabled: true,
3513 api_keys: Vec::new(),
3514 mtls: None,
3515 rate_limit: Some(rl),
3516 #[cfg(feature = "oauth")]
3517 oauth: None,
3518 };
3519 let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
3520 let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
3521 assert!(
3522 format!("{err}").contains("max_tracked_keys"),
3523 "error should mention max_tracked_keys, got: {err}"
3524 );
3525 }
3526
3527 #[test]
3528 fn derive_allowed_hosts_includes_public_host() {
3529 let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
3530 assert!(
3531 hosts.iter().any(|h| h == "mcp.example.com"),
3532 "public_url host must be allowed"
3533 );
3534 }
3535
3536 #[test]
3537 fn derive_allowed_hosts_includes_bind_authority() {
3538 let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
3539 assert!(
3540 hosts.iter().any(|h| h == "127.0.0.1"),
3541 "bind host must be allowed"
3542 );
3543 assert!(
3544 hosts.iter().any(|h| h == "127.0.0.1:8080"),
3545 "bind authority must be allowed"
3546 );
3547 }
3548
3549 #[tokio::test]
3552 async fn healthz_returns_ok_json() {
3553 let resp = healthz().await.into_response();
3554 assert_eq!(resp.status(), StatusCode::OK);
3555 let body = resp.into_body().collect().await.unwrap().to_bytes();
3556 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3557 assert_eq!(json["status"], "ok");
3558 assert!(
3559 json.get("name").is_none(),
3560 "healthz must not expose server name"
3561 );
3562 assert!(
3563 json.get("version").is_none(),
3564 "healthz must not expose version"
3565 );
3566 }
3567
3568 #[tokio::test]
3571 async fn readyz_returns_ok_when_ready() {
3572 let check: ReadinessCheck =
3573 Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
3574 let resp = readyz(check).await.into_response();
3575 assert_eq!(resp.status(), StatusCode::OK);
3576 let body = resp.into_body().collect().await.unwrap().to_bytes();
3577 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3578 assert_eq!(json["ready"], true);
3579 assert!(
3580 json.get("name").is_none(),
3581 "readyz must not expose server name"
3582 );
3583 assert!(
3584 json.get("version").is_none(),
3585 "readyz must not expose version"
3586 );
3587 assert_eq!(json["db"], "connected");
3588 }
3589
3590 #[tokio::test]
3591 async fn readyz_returns_503_when_not_ready() {
3592 let check: ReadinessCheck =
3593 Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
3594 let resp = readyz(check).await.into_response();
3595 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3596 }
3597
3598 #[tokio::test]
3599 async fn readyz_returns_503_when_ready_missing() {
3600 let check: ReadinessCheck =
3601 Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
3602 let resp = readyz(check).await.into_response();
3603 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3605 }
3606
3607 fn peer_probe_router() -> axum::Router {
3612 async fn probe(req: Request<Body>) -> String {
3613 let ci = req
3614 .extensions()
3615 .get::<ConnectInfo<SocketAddr>>()
3616 .map(|c| c.0.to_string())
3617 .unwrap_or_default();
3618 let pa = req
3619 .extensions()
3620 .get::<PeerAddr>()
3621 .map(|p| p.addr.to_string())
3622 .unwrap_or_default();
3623 format!("{ci}|{pa}")
3624 }
3625 axum::Router::new()
3626 .route("/probe", axum::routing::get(probe))
3627 .layer(axum::middleware::from_fn(|req, next| {
3628 normalize_peer_addr_middleware(None, req, next)
3629 }))
3630 }
3631
3632 async fn body_string(resp: axum::response::Response) -> String {
3633 let bytes = resp.into_body().collect().await.unwrap().to_bytes();
3634 String::from_utf8(bytes.to_vec()).unwrap()
3635 }
3636
3637 #[tokio::test]
3638 async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
3639 let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
3642 let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
3643 let req = Request::builder()
3644 .uri("/probe")
3645 .extension(ConnectInfo(plain))
3646 .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3647 .body(Body::empty())
3648 .unwrap();
3649 let resp = peer_probe_router().oneshot(req).await.unwrap();
3650 assert_eq!(resp.status(), StatusCode::OK);
3651 assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
3652 }
3653
3654 #[tokio::test]
3655 async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
3656 let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
3657 let req = Request::builder()
3658 .uri("/probe")
3659 .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3660 .body(Body::empty())
3661 .unwrap();
3662 let resp = peer_probe_router().oneshot(req).await.unwrap();
3663 assert_eq!(resp.status(), StatusCode::OK);
3664 assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
3665 }
3666
3667 #[tokio::test]
3668 async fn normalize_no_op_without_any_connect_info() {
3669 let req = Request::builder()
3670 .uri("/probe")
3671 .body(Body::empty())
3672 .unwrap();
3673 let resp = peer_probe_router().oneshot(req).await.unwrap();
3674 assert_eq!(resp.status(), StatusCode::OK);
3675 assert_eq!(body_string(resp).await, "|");
3676 }
3677
3678 #[tokio::test]
3679 async fn peer_addr_extractor_rejects_when_absent() {
3680 async fn h(peer: PeerAddr) -> String {
3681 peer.addr.to_string()
3682 }
3683 let app = axum::Router::new().route("/p", axum::routing::get(h));
3684 let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
3685 let resp = app.oneshot(req).await.unwrap();
3686 assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
3687 }
3688
3689 #[tokio::test]
3690 async fn peer_addr_extractor_returns_value_when_present() {
3691 async fn h(peer: PeerAddr) -> String {
3692 peer.addr.to_string()
3693 }
3694 let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
3695 let app = axum::Router::new().route("/p", axum::routing::get(h));
3696 let req = Request::builder()
3697 .uri("/p")
3698 .extension(PeerAddr::new(addr))
3699 .body(Body::empty())
3700 .unwrap();
3701 let resp = app.oneshot(req).await.unwrap();
3702 assert_eq!(resp.status(), StatusCode::OK);
3703 assert_eq!(body_string(resp).await, addr.to_string());
3704 }
3705
3706 #[tokio::test]
3707 async fn peer_addr_via_extension_extractor() {
3708 async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
3709 peer.addr.to_string()
3710 }
3711 let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
3712 let app = axum::Router::new().route("/p", axum::routing::get(h));
3713 let req = Request::builder()
3714 .uri("/p")
3715 .extension(PeerAddr::new(addr))
3716 .body(Body::empty())
3717 .unwrap();
3718 let resp = app.oneshot(req).await.unwrap();
3719 assert_eq!(resp.status(), StatusCode::OK);
3720 assert_eq!(body_string(resp).await, addr.to_string());
3721 }
3722
3723 fn limited_router(per_minute: u32) -> axum::Router {
3728 limited_router_with_burst(per_minute, None)
3729 }
3730
3731 fn limited_router_with_burst(per_minute: u32, burst: Option<u32>) -> axum::Router {
3733 let limiter = build_extra_route_rate_limiter(per_minute, burst);
3734 axum::Router::new()
3735 .route("/limited", axum::routing::get(|| async { "ok" }))
3736 .layer(axum::middleware::from_fn(move |req, next| {
3737 let l = Arc::clone(&limiter);
3738 extra_route_rate_limit_middleware(l, req, next)
3739 }))
3740 }
3741
3742 fn limited_req(ip: &str) -> Request<Body> {
3743 let addr: SocketAddr = format!("{ip}:40000").parse().unwrap();
3744 Request::builder()
3745 .uri("/limited")
3746 .extension(ConnectInfo(addr))
3747 .body(Body::empty())
3748 .unwrap()
3749 }
3750
3751 #[tokio::test]
3752 async fn extra_route_limiter_denies_over_quota() {
3753 let app = limited_router(2);
3754 for i in 0..2 {
3755 let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3756 assert_eq!(resp.status(), StatusCode::OK, "request {i} should pass");
3757 }
3758 let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3759 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3760 let body = body_string(resp).await;
3761 assert!(
3762 body.contains("too many requests to application routes"),
3763 "deny body should match the limiter message, got: {body}"
3764 );
3765 }
3766
3767 #[tokio::test]
3768 async fn extra_route_limiter_isolates_keys() {
3769 let app = limited_router(2);
3770 for _ in 0..2 {
3771 let resp = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3772 assert_eq!(resp.status(), StatusCode::OK);
3773 }
3774 let exhausted = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3775 assert_eq!(exhausted.status(), StatusCode::TOO_MANY_REQUESTS);
3776 let other = app.clone().oneshot(limited_req("10.3.3.3")).await.unwrap();
3778 assert_eq!(other.status(), StatusCode::OK);
3779 }
3780
3781 #[tokio::test]
3782 async fn extra_route_limiter_fails_open_without_peer() {
3783 let app = limited_router(1);
3784 for i in 0..3 {
3785 let req = Request::builder()
3786 .uri("/limited")
3787 .body(Body::empty())
3788 .unwrap();
3789 let resp = app.clone().oneshot(req).await.unwrap();
3790 assert_eq!(
3791 resp.status(),
3792 StatusCode::OK,
3793 "request {i} should fail open"
3794 );
3795 }
3796 }
3797
3798 #[tokio::test]
3799 async fn extra_route_limiter_extracts_tls_conn_info() {
3800 let app = limited_router(2);
3801 let mk = || {
3802 let addr: SocketAddr = "192.168.9.9:55555".parse().unwrap();
3803 Request::builder()
3804 .uri("/limited")
3805 .extension(ConnectInfo(TlsConnInfo::new(addr, None)))
3806 .body(Body::empty())
3807 .unwrap()
3808 };
3809 for _ in 0..2 {
3810 assert_eq!(
3811 app.clone().oneshot(mk()).await.unwrap().status(),
3812 StatusCode::OK
3813 );
3814 }
3815 let resp = app.clone().oneshot(mk()).await.unwrap();
3816 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3817 }
3818
3819 #[test]
3820 fn validate_rejects_zero_extra_route_rate_limit() {
3821 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3822 .with_extra_route_rate_limit(0);
3823 let err = cfg.validate().expect_err("zero extra route rate limit");
3824 assert!(err.to_string().contains("extra_route_rate_limit"));
3825 }
3826
3827 #[tokio::test]
3828 async fn extra_route_limiter_burst_allows_initial_spike() {
3829 let app = limited_router_with_burst(1, Some(3));
3830 for i in 0..3 {
3831 let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
3832 assert_eq!(resp.status(), StatusCode::OK, "burst request {i}");
3833 }
3834 let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
3835 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3836 }
3837
3838 #[tokio::test]
3839 async fn extra_route_limiter_deny_sets_retry_after() {
3840 let app = limited_router(1);
3841 let ok = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
3842 assert_eq!(ok.status(), StatusCode::OK);
3843 let denied = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
3844 assert_eq!(denied.status(), StatusCode::TOO_MANY_REQUESTS);
3845 let retry_after = denied
3846 .headers()
3847 .get(header::RETRY_AFTER)
3848 .expect("Retry-After present")
3849 .to_str()
3850 .unwrap()
3851 .parse::<u64>()
3852 .unwrap();
3853 assert!(retry_after >= 1, "delta-seconds must be >= 1");
3854 }
3855
3856 #[test]
3857 fn validate_rejects_zero_burst_knobs() {
3858 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3859 .with_tool_rate_limit(10)
3860 .with_tool_rate_limit_burst(0)
3861 .validate()
3862 .expect_err("zero tool burst");
3863 assert!(err.to_string().contains("tool_rate_limit_burst"));
3864
3865 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3866 .with_extra_route_rate_limit(10)
3867 .with_extra_route_rate_limit_burst(0)
3868 .validate()
3869 .expect_err("zero extra route burst");
3870 assert!(err.to_string().contains("extra_route_rate_limit_burst"));
3871 }
3872
3873 #[test]
3874 fn validate_rejects_orphan_burst_knobs() {
3875 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3876 .with_tool_rate_limit_burst(5)
3877 .validate()
3878 .expect_err("orphan tool burst");
3879 assert!(err.to_string().contains("requires tool_rate_limit"));
3880
3881 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3882 .with_extra_route_rate_limit_burst(5)
3883 .validate()
3884 .expect_err("orphan extra route burst");
3885 assert!(err.to_string().contains("requires extra_route_rate_limit"));
3886 }
3887
3888 #[test]
3889 fn validate_rejects_zero_auth_bursts() {
3890 let auth = AuthConfig::with_keys(vec![])
3891 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_burst(0));
3892 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3893 .with_auth(auth)
3894 .validate()
3895 .expect_err("zero auth burst");
3896 assert!(err.to_string().contains("rate_limit.burst"));
3897
3898 let auth = AuthConfig::with_keys(vec![])
3899 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0));
3900 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3901 .with_auth(auth)
3902 .validate()
3903 .expect_err("zero pre-auth burst");
3904 assert!(err.to_string().contains("pre_auth_burst"));
3905 }
3906
3907 #[test]
3910 fn validate_accepts_pre_auth_burst_without_explicit_pre_auth_rate() {
3911 let auth = AuthConfig::with_keys(vec![])
3912 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(50));
3913 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_auth(auth);
3914 assert!(cfg.validate().is_ok(), "pre_auth_burst has no orphan rule");
3915 }
3916
3917 fn forward_resolver(trusted: &[&str], mode: ForwardedHeaderMode) -> Arc<ForwardResolver> {
3920 Arc::new(ForwardResolver {
3921 trusted: trusted.iter().map(|s| s.parse().unwrap()).collect(),
3922 mode,
3923 })
3924 }
3925
3926 fn forwarded_probe_router(resolver: Option<Arc<ForwardResolver>>) -> axum::Router {
3928 async fn probe(req: Request<Body>) -> String {
3929 let pa = req
3930 .extensions()
3931 .get::<PeerAddr>()
3932 .map(|p| p.addr.ip().to_string())
3933 .unwrap_or_default();
3934 let ci = req
3935 .extensions()
3936 .get::<ClientIp>()
3937 .map(|c| c.ip.to_string())
3938 .unwrap_or_default();
3939 format!("{pa}|{ci}")
3940 }
3941 axum::Router::new()
3942 .route("/probe", axum::routing::get(probe))
3943 .layer(axum::middleware::from_fn(move |req, next| {
3944 let r = resolver.clone();
3945 normalize_peer_addr_middleware(r, req, next)
3946 }))
3947 }
3948
3949 fn probe_req(peer: &str, header: Option<(&str, &str)>) -> Request<Body> {
3950 let addr: SocketAddr = peer.parse().unwrap();
3951 let mut builder = Request::builder()
3952 .uri("/probe")
3953 .extension(ConnectInfo(addr));
3954 if let Some((name, value)) = header {
3955 builder = builder.header(name, value);
3956 }
3957 builder.body(Body::empty()).unwrap()
3958 }
3959
3960 #[tokio::test]
3961 async fn client_ip_equals_direct_without_resolver() {
3962 let app = forwarded_probe_router(None);
3963 let resp = app
3964 .oneshot(probe_req(
3965 "10.1.2.3:4444",
3966 Some(("x-forwarded-for", "203.0.113.7")),
3967 ))
3968 .await
3969 .unwrap();
3970 assert_eq!(
3971 body_string(resp).await,
3972 "10.1.2.3|10.1.2.3",
3973 "feature off: header ignored, ClientIp == direct"
3974 );
3975 }
3976
3977 #[tokio::test]
3978 async fn client_ip_resolved_for_trusted_peer() {
3979 let app = forwarded_probe_router(Some(forward_resolver(
3980 &["10.0.0.0/8"],
3981 ForwardedHeaderMode::XForwardedFor,
3982 )));
3983 let resp = app
3984 .oneshot(probe_req(
3985 "10.0.0.1:9999",
3986 Some(("x-forwarded-for", "203.0.113.7")),
3987 ))
3988 .await
3989 .unwrap();
3990 assert_eq!(
3991 body_string(resp).await,
3992 "10.0.0.1|203.0.113.7",
3993 "PeerAddr stays direct while ClientIp resolves"
3994 );
3995 }
3996
3997 #[tokio::test]
3998 async fn client_ip_falls_back_to_direct_on_malformed_header() {
3999 let app = forwarded_probe_router(Some(forward_resolver(
4000 &["10.0.0.0/8"],
4001 ForwardedHeaderMode::XForwardedFor,
4002 )));
4003 let resp = app
4004 .oneshot(probe_req(
4005 "10.0.0.1:9999",
4006 Some(("x-forwarded-for", "not-an-ip")),
4007 ))
4008 .await
4009 .unwrap();
4010 assert_eq!(
4011 body_string(resp).await,
4012 "10.0.0.1|10.0.0.1",
4013 "malformed chain falls back to the direct peer"
4014 );
4015 }
4016
4017 #[test]
4018 fn forwarded_header_mode_deserializes_kebab_case() {
4019 #[derive(serde::Deserialize)]
4020 struct Wrapper {
4021 mode: ForwardedHeaderMode,
4022 }
4023 let w: Wrapper = toml::from_str(r#"mode = "x-forwarded-for""#).unwrap();
4024 assert_eq!(w.mode, ForwardedHeaderMode::XForwardedFor);
4025 let w: Wrapper = toml::from_str(r#"mode = "forwarded""#).unwrap();
4026 assert_eq!(w.mode, ForwardedHeaderMode::Forwarded);
4027 assert!(
4028 toml::from_str::<Wrapper>(r#"mode = "XForwardedFor""#).is_err(),
4029 "PascalCase wire value must be rejected"
4030 );
4031 }
4032
4033 #[test]
4034 fn validate_rejects_bad_trusted_proxy_entry() {
4035 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4036 .with_trusted_proxies(["not-a-cidr"]);
4037 let err = cfg.validate().expect_err("bad CIDR");
4038 assert!(err.to_string().contains("trusted_proxies"));
4039 }
4040
4041 #[test]
4042 fn validate_accepts_cidr_and_bare_ip_proxy_entries() {
4043 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_trusted_proxies([
4044 "10.0.0.0/8",
4045 "192.0.2.1",
4046 "2001:db8::1",
4047 ]);
4048 assert!(cfg.validate().is_ok(), "CIDRs and bare IPs are accepted");
4049 }
4050
4051 #[test]
4052 fn validate_rejects_forwarded_header_without_proxies() {
4053 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4054 .with_forwarded_header(ForwardedHeaderMode::Forwarded);
4055 let err = cfg.validate().expect_err("mode without proxies");
4056 assert!(err.to_string().contains("requires trusted_proxies"));
4057 }
4058
4059 fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
4063 let allowed: Arc<[String]> = Arc::from(origins);
4064 axum::Router::new()
4065 .route("/test", axum::routing::get(|| async { "ok" }))
4066 .layer(axum::middleware::from_fn(move |req, next| {
4067 let a = Arc::clone(&allowed);
4068 origin_check_middleware(a, log_request_headers, req, next)
4069 }))
4070 }
4071
4072 #[tokio::test]
4073 async fn origin_allowed_passes() {
4074 let app = origin_router(vec!["http://localhost:3000".into()], false);
4075 let req = Request::builder()
4076 .uri("/test")
4077 .header(header::ORIGIN, "http://localhost:3000")
4078 .body(Body::empty())
4079 .unwrap();
4080 let resp = app.oneshot(req).await.unwrap();
4081 assert_eq!(resp.status(), StatusCode::OK);
4082 }
4083
4084 #[tokio::test]
4085 async fn origin_rejected_returns_403() {
4086 let app = origin_router(vec!["http://localhost:3000".into()], false);
4087 let req = Request::builder()
4088 .uri("/test")
4089 .header(header::ORIGIN, "http://evil.com")
4090 .body(Body::empty())
4091 .unwrap();
4092 let resp = app.oneshot(req).await.unwrap();
4093 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
4094 }
4095
4096 #[tokio::test]
4097 async fn no_origin_header_passes() {
4098 let app = origin_router(vec!["http://localhost:3000".into()], false);
4099 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4100 let resp = app.oneshot(req).await.unwrap();
4101 assert_eq!(resp.status(), StatusCode::OK);
4102 }
4103
4104 #[tokio::test]
4105 async fn empty_allowlist_rejects_any_origin() {
4106 let app = origin_router(vec![], false);
4107 let req = Request::builder()
4108 .uri("/test")
4109 .header(header::ORIGIN, "http://anything.com")
4110 .body(Body::empty())
4111 .unwrap();
4112 let resp = app.oneshot(req).await.unwrap();
4113 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
4114 }
4115
4116 #[tokio::test]
4117 async fn empty_allowlist_passes_without_origin() {
4118 let app = origin_router(vec![], false);
4119 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4120 let resp = app.oneshot(req).await.unwrap();
4121 assert_eq!(resp.status(), StatusCode::OK);
4122 }
4123
/// `format_request_headers_for_log` must mask credential-bearing headers while
/// leaving ordinary headers readable.
#[test]
fn format_request_headers_redacts_sensitive_values() {
    let mut headers = axum::http::HeaderMap::new();
    headers.insert("authorization", "Bearer secret-token".parse().unwrap());
    headers.insert("cookie", "sid=abc".parse().unwrap());
    headers.insert("x-request-id", "req-123".parse().unwrap());

    let out = format_request_headers_for_log(&headers);

    // Sensitive headers keep their name but have the value replaced.
    assert!(out.contains("authorization: [REDACTED]"));
    assert!(out.contains("cookie: [REDACTED]"));
    // Non-sensitive headers are logged verbatim.
    assert!(out.contains("x-request-id: req-123"));
    // Neither secret value may appear anywhere in the rendered output.
    assert!(!out.contains("secret-token"), "authorization value leaked: {out}");
    assert!(!out.contains("sid=abc"), "cookie value leaked: {out}");
}
4137
4138 fn security_router(is_tls: bool) -> axum::Router {
4141 security_router_with(is_tls, SecurityHeadersConfig::default())
4142 }
4143
4144 fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
4145 let cfg = Arc::new(cfg);
4146 axum::Router::new()
4147 .route("/test", axum::routing::get(|| async { "ok" }))
4148 .layer(axum::middleware::from_fn(move |req, next| {
4149 let c = Arc::clone(&cfg);
4150 security_headers_middleware(is_tls, c, req, next)
4151 }))
4152 }
4153
4154 #[tokio::test]
4155 async fn security_headers_set_on_response() {
4156 let app = security_router(false);
4157 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4158 let resp = app.oneshot(req).await.unwrap();
4159 assert_eq!(resp.status(), StatusCode::OK);
4160
4161 let h = resp.headers();
4162 assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
4163 assert_eq!(h.get("x-frame-options").unwrap(), "deny");
4164 assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
4165 assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
4166 assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
4167 assert_eq!(
4168 h.get("cross-origin-resource-policy").unwrap(),
4169 "same-origin"
4170 );
4171 assert_eq!(
4172 h.get("cross-origin-embedder-policy").unwrap(),
4173 "require-corp"
4174 );
4175 assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
4176 assert!(
4177 h.get("permissions-policy")
4178 .unwrap()
4179 .to_str()
4180 .unwrap()
4181 .contains("camera=()"),
4182 "permissions-policy must restrict browser features"
4183 );
4184 assert_eq!(
4185 h.get("content-security-policy").unwrap(),
4186 "default-src 'none'; frame-ancestors 'none'"
4187 );
4188 assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
4189 assert!(h.get("strict-transport-security").is_none());
4191 }
4192
4193 #[tokio::test]
4194 async fn hsts_set_when_tls_enabled() {
4195 let app = security_router(true);
4196 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4197 let resp = app.oneshot(req).await.unwrap();
4198
4199 let hsts = resp.headers().get("strict-transport-security").unwrap();
4200 assert!(
4201 hsts.to_str().unwrap().contains("max-age=63072000"),
4202 "HSTS must set 2-year max-age"
4203 );
4204 }
4205
4206 fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
4212 let cfg =
4213 McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
4214 cfg.check()
4215 }
4216
/// The built-in defaults must always pass configuration validation.
#[test]
fn security_headers_config_default_validates() {
    let outcome = check_with_security_headers(SecurityHeadersConfig::default());
    outcome.expect("default SecurityHeadersConfig must validate");
}
4222
/// Setting every field to `Some("")` (omit that header) is a valid configuration.
#[test]
fn security_headers_config_validate_accepts_empty_string() {
    // `Some("")` is the "do not emit this header" marker.
    let omit = || Some(String::new());
    let h = SecurityHeadersConfig {
        x_content_type_options: omit(),
        x_frame_options: omit(),
        cache_control: omit(),
        referrer_policy: omit(),
        cross_origin_opener_policy: omit(),
        cross_origin_resource_policy: omit(),
        cross_origin_embedder_policy: omit(),
        permissions_policy: omit(),
        x_permitted_cross_domain_policies: omit(),
        content_security_policy: omit(),
        x_dns_prefetch_control: omit(),
        strict_transport_security: omit(),
    };
    check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
}
4242
/// A header value containing a control character (BEL) fails validation, and
/// the error names the offending field.
#[test]
fn security_headers_config_validate_rejects_bad_value() {
    let bell = "\u{0007}";
    let h = SecurityHeadersConfig {
        referrer_policy: Some(bell.to_owned()),
        ..SecurityHeadersConfig::default()
    };
    let msg = check_with_security_headers(h)
        .expect_err("control char in referrer_policy must reject")
        .to_string();
    assert!(
        msg.contains("referrer_policy"),
        "error must name the offending field, got: {msg}"
    );
}
4258
/// An HSTS override that carries the `preload` directive is rejected. The error
/// must identify both the field and the directive.
#[test]
fn security_headers_config_validate_rejects_hsts_preload() {
    let preload_hsts = "max-age=63072000; includeSubDomains; preload";
    let cfg = SecurityHeadersConfig {
        strict_transport_security: Some(preload_hsts.to_owned()),
        ..SecurityHeadersConfig::default()
    };
    let message = check_with_security_headers(cfg)
        .expect_err("HSTS with preload must reject")
        .to_string();
    assert!(
        message.contains("strict_transport_security"),
        "error must name the field, got: {message}"
    );
    let lowered = message.to_lowercase();
    assert!(
        lowered.contains("preload"),
        "error must mention `preload`, got: {message}"
    );
}
4276
/// The `preload` detection must ignore ASCII case.
#[test]
fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
    let cfg = SecurityHeadersConfig {
        strict_transport_security: Some(String::from("max-age=600; PRELOAD")),
        ..SecurityHeadersConfig::default()
    };
    let outcome = check_with_security_headers(cfg);
    outcome.expect_err("HSTS preload check must be case-insensitive");
}
4286
4287 #[tokio::test]
4288 async fn security_headers_override_honored() {
4289 let h = SecurityHeadersConfig {
4291 x_frame_options: Some("SAMEORIGIN".into()),
4292 ..SecurityHeadersConfig::default()
4293 };
4294 let app = security_router_with(false, h);
4295 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4296 let resp = app.oneshot(req).await.unwrap();
4297 assert_eq!(resp.status(), StatusCode::OK);
4298
4299 let xfo = resp.headers().get("x-frame-options").unwrap();
4300 assert_eq!(xfo, "SAMEORIGIN");
4301 }
4302
4303 #[tokio::test]
4304 async fn security_headers_empty_string_omits() {
4305 let h = SecurityHeadersConfig {
4307 referrer_policy: Some(String::new()),
4308 ..SecurityHeadersConfig::default()
4309 };
4310 let app = security_router_with(false, h);
4311 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4312 let resp = app.oneshot(req).await.unwrap();
4313 assert_eq!(resp.status(), StatusCode::OK);
4314
4315 assert!(
4316 resp.headers().get("referrer-policy").is_none(),
4317 "Some(\"\") must omit the header"
4318 );
4319 assert_eq!(
4321 resp.headers().get("x-content-type-options").unwrap(),
4322 "nosniff"
4323 );
4324 }
4325
4326 #[tokio::test]
4327 async fn security_headers_hsts_only_when_tls() {
4328 let h = SecurityHeadersConfig {
4330 strict_transport_security: Some("max-age=600".into()),
4331 ..SecurityHeadersConfig::default()
4332 };
4333 let app = security_router_with(false, h);
4334 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4335 let resp = app.oneshot(req).await.unwrap();
4336 assert!(
4337 resp.headers().get("strict-transport-security").is_none(),
4338 "HSTS must remain absent on plaintext deployments even with override"
4339 );
4340 }
4341
/// The OAuth token-endpoint middleware adds `Pragma: no-cache` and a `Vary`
/// header that names `Authorization`.
#[cfg(feature = "oauth")]
#[tokio::test]
async fn oauth_token_cache_headers_set_pragma_and_vary() {
    let router = axum::Router::new()
        .route("/token", axum::routing::post(|| async { "{}" }))
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        ));
    let request = Request::builder()
        .method("POST")
        .uri("/token")
        .body(Body::from("{}"))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let headers = response.headers();
    assert_eq!(
        headers.get("pragma").unwrap(),
        "no-cache",
        "RFC 6749 §5.1: token responses must set Pragma: no-cache"
    );

    // Collect every `Vary` value that is valid visible ASCII.
    let mut vary_values: Vec<String> = Vec::new();
    for raw in headers.get_all("vary").iter() {
        if let Ok(text) = raw.to_str() {
            vary_values.push(text.to_owned());
        }
    }
    let has_authorization = vary_values
        .iter()
        .any(|v| v.eq_ignore_ascii_case("Authorization"));
    assert!(
        has_authorization,
        "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
    );
}
4378
/// When the handler already set `Vary`, the middleware keeps that value and
/// adds `Authorization` instead of overwriting it.
///
/// `Vary` is a comma-separated list of case-insensitive field names. The
/// middleware may append a new header line or merge into the existing one, so
/// both assertions compare individual list tokens case-insensitively. A raw
/// substring search would also match unrelated names such as
/// `X-Authorization-Foo`.
#[cfg(feature = "oauth")]
#[tokio::test]
async fn oauth_token_cache_headers_preserve_existing_vary() {
    let app = axum::Router::new()
        .route(
            "/token",
            axum::routing::post(|| async {
                axum::response::Response::builder()
                    .header("vary", "Accept-Encoding")
                    .body(axum::body::Body::from("{}"))
                    .unwrap()
            }),
        )
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        ));
    let req = Request::builder()
        .method("POST")
        .uri("/token")
        .body(Body::empty())
        .unwrap();
    let resp = app.oneshot(req).await.unwrap();

    // Flatten every `Vary` line into its individual, trimmed field-name tokens.
    let vary: Vec<String> = resp
        .headers()
        .get_all("vary")
        .iter()
        .filter_map(|v| v.to_str().ok())
        .flat_map(|v| v.split(','))
        .map(|token| token.trim().to_owned())
        .filter(|token| !token.is_empty())
        .collect();
    assert!(
        vary.iter().any(|t| t.eq_ignore_ascii_case("Accept-Encoding")),
        "must preserve pre-existing Vary value, got {vary:?}"
    );
    assert!(
        vary.iter().any(|t| t.eq_ignore_ascii_case("Authorization")),
        "must append Authorization to Vary, got {vary:?}"
    );
}
4419
/// The version payload echoes the server name and version, and exposes the
/// build metadata fields as strings.
#[test]
fn version_payload_contains_expected_fields() {
    let payload = version_payload("my-server", "1.2.3");
    assert_eq!(payload["name"], "my-server");
    assert_eq!(payload["version"], "1.2.3");
    for key in [
        "build_git_sha",
        "build_timestamp",
        "rust_version",
        "mcpx_version",
    ] {
        assert!(payload[key].is_string(), "{key} must be a string");
    }
}
4432
4433 #[tokio::test]
4436 async fn concurrency_limit_layer_composes_and_serves() {
4437 let app = axum::Router::new()
4441 .route("/ok", axum::routing::get(|| async { "ok" }))
4442 .layer(
4443 tower::ServiceBuilder::new()
4444 .layer(axum::error_handling::HandleErrorLayer::new(
4445 |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
4446 ))
4447 .layer(tower::load_shed::LoadShedLayer::new())
4448 .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
4449 );
4450 let resp = app
4451 .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
4452 .await
4453 .unwrap();
4454 assert_eq!(resp.status(), StatusCode::OK);
4455 }
4456
4457 #[tokio::test]
4460 async fn compression_layer_gzip_encodes_response() {
4461 use tower_http::compression::Predicate as _;
4462
4463 let big_body = "a".repeat(4096);
4464 let app = axum::Router::new()
4465 .route(
4466 "/big",
4467 axum::routing::get(move || {
4468 let body = big_body.clone();
4469 async move { body }
4470 }),
4471 )
4472 .layer(
4473 tower_http::compression::CompressionLayer::new()
4474 .gzip(true)
4475 .br(true)
4476 .compress_when(
4477 tower_http::compression::DefaultPredicate::new()
4478 .and(tower_http::compression::predicate::SizeAbove::new(1024)),
4479 ),
4480 );
4481
4482 let req = Request::builder()
4483 .uri("/big")
4484 .header(header::ACCEPT_ENCODING, "gzip")
4485 .body(Body::empty())
4486 .unwrap();
4487 let resp = app.oneshot(req).await.unwrap();
4488 assert_eq!(resp.status(), StatusCode::OK);
4489 assert_eq!(
4490 resp.headers().get(header::CONTENT_ENCODING).unwrap(),
4491 "gzip"
4492 );
4493 }
4494
4495 #[tokio::test]
4498 async fn tls_handshake_timeout_reaps_idle_connections() {
4499 use tokio::io::AsyncReadExt as _;
4500
4501 let _ = rustls::crypto::ring::default_provider().install_default();
4502
4503 let key = rcgen::KeyPair::generate().expect("generate key");
4505 let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
4506 .expect("cert params")
4507 .self_signed(&key)
4508 .expect("self-signed cert");
4509 let dir = std::env::temp_dir().join(format!(
4510 "rmcp-server-kit-hs-timeout-{}",
4511 std::time::SystemTime::now()
4512 .duration_since(std::time::UNIX_EPOCH)
4513 .expect("clock after epoch")
4514 .as_nanos()
4515 ));
4516 tokio::fs::create_dir_all(&dir).await.expect("temp dir");
4517 let cert_path = dir.join("server.crt");
4518 let key_path = dir.join("server.key");
4519 tokio::fs::write(&cert_path, cert.pem())
4520 .await
4521 .expect("write cert");
4522 tokio::fs::write(&key_path, key.serialize_pem())
4523 .await
4524 .expect("write key");
4525
4526 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
4527 let tls = TlsListener::new(
4528 listener,
4529 &cert_path,
4530 &key_path,
4531 None,
4532 None,
4533 Duration::from_millis(200),
4534 8, )
4536 .expect("tls listener");
4537 let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
4538
4539 let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
4543 let mut buf = [0_u8; 16];
4544 let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
4545 .await
4546 .expect("server must reap the idle handshake within its timeout");
4547 match read {
4548 Ok(0) | Err(_) => {} Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
4550 }
4551
4552 drop(tls);
4553 }
4554}