1use std::{
2 future::Future,
3 net::{IpAddr, SocketAddr},
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12 body::Body,
13 extract::{ConnectInfo, Request},
14 middleware::Next,
15 response::IntoResponse,
16};
17use rmcp::{
18 ServerHandler,
19 transport::streamable_http_server::{
20 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21 },
22};
23use rustls::RootCertStore;
24use tokio::{
25 net::TcpListener,
26 sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31 auth::{
32 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33 build_rate_limiter, extract_mtls_identity,
34 },
35 bounded_limiter::BoundedKeyedLimiter,
36 error::McpxError,
37 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
38 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
39};
40
41#[allow(
45 clippy::needless_pass_by_value,
46 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
47)]
48fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
49 McpxError::Startup(format!("{e:#}"))
50}
51
52#[allow(
58 clippy::needless_pass_by_value,
59 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
60)]
61fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
62 McpxError::Startup(format!("{op}: {e}"))
63}
64
65pub type ReadinessCheck =
70 Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
71
/// Socket address of the TCP peer that sent the current request.
///
/// Expected to be placed in the request extensions by `serve()`; handlers
/// obtain it through the axum extractor implemented for this type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PeerAddr {
    /// The peer's socket address.
    pub addr: SocketAddr,
}

impl PeerAddr {
    /// Wraps `addr`. Crate-private so only the server can mint values.
    #[must_use]
    pub(crate) const fn new(addr: SocketAddr) -> Self {
        Self { addr }
    }
}
131
132impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
141 type Rejection = (axum::http::StatusCode, &'static str);
142
143 async fn from_request_parts(
144 parts: &mut axum::http::request::Parts,
145 _state: &S,
146 ) -> Result<Self, Self::Rejection> {
147 parts.extensions.get::<Self>().copied().ok_or((
148 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
149 "peer address unavailable: not running under rmcp-server-kit serve()",
150 ))
151 }
152}
153
/// IP address attributed to the client of the current request.
///
/// NOTE(review): presumably this is the peer IP unless a trusted proxy's
/// forwarded header is honoured (see `ForwardResolver`); confirm at the
/// insertion site.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct ClientIp {
    /// The resolved client IP address.
    pub ip: IpAddr,
}

impl ClientIp {
    /// Wraps `ip`. Crate-private so only the server can mint values.
    #[must_use]
    pub(crate) const fn new(ip: IpAddr) -> Self {
        Self { ip }
    }
}
191
192#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Deserialize)]
197#[serde(rename_all = "kebab-case")]
198#[non_exhaustive]
199pub enum ForwardedHeaderMode {
200 XForwardedFor,
202 Forwarded,
204}
205
206struct ForwardResolver {
209 trusted: Vec<ipnet::IpNet>,
210 mode: ForwardedHeaderMode,
211}
212
/// Per-header overrides for the security response headers.
///
/// Each field corresponds to the HTTP header of the same name. `Default`
/// leaves every field `None`. Values are checked by
/// `validate_security_headers` during [`McpServerConfig::validate`].
/// NOTE(review): whether `None` means "use the built-in value" or "omit the
/// header" is decided by the middleware outside this view; confirm there.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options`.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options`.
    pub x_frame_options: Option<String>,
    /// `Cache-Control`.
    pub cache_control: Option<String>,
    /// `Referrer-Policy`.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy`.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy`.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy`.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy`.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies`.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy`.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control`.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security`.
    pub strict_transport_security: Option<String>,
}
266
267#[allow(
269 missing_debug_implementations,
270 reason = "contains callback/trait objects that don't impl Debug"
271)]
272#[allow(
273 clippy::struct_excessive_bools,
274 reason = "server configuration naturally has many boolean feature flags"
275)]
276#[non_exhaustive]
277pub struct McpServerConfig {
278 #[deprecated(
280 since = "0.13.0",
281 note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
282 )]
283 pub bind_addr: String,
284 #[deprecated(
286 since = "0.13.0",
287 note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
288 )]
289 pub name: String,
290 #[deprecated(
292 since = "0.13.0",
293 note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
294 )]
295 pub version: String,
296 #[deprecated(
298 since = "0.13.0",
299 note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
300 )]
301 pub tls_cert_path: Option<PathBuf>,
302 #[deprecated(
304 since = "0.13.0",
305 note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
306 )]
307 pub tls_key_path: Option<PathBuf>,
308 #[deprecated(
311 since = "0.13.0",
312 note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
313 )]
314 pub auth: Option<AuthConfig>,
315 #[deprecated(
318 since = "0.13.0",
319 note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
320 )]
321 pub rbac: Option<Arc<RbacPolicy>>,
322 #[deprecated(
328 since = "0.13.0",
329 note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
330 )]
331 pub allowed_origins: Vec<String>,
332 #[deprecated(
335 since = "0.13.0",
336 note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
337 )]
338 pub tool_rate_limit: Option<u32>,
339 #[deprecated(
345 since = "1.12.0",
346 note = "use McpServerConfig::with_tool_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
347 )]
348 pub tool_rate_limit_burst: Option<u32>,
349 #[deprecated(
362 since = "1.11.0",
363 note = "use McpServerConfig::with_extra_route_rate_limit(); direct field access will become pub(crate) in a future major release"
364 )]
365 pub extra_route_rate_limit: Option<u32>,
366 #[deprecated(
373 since = "1.12.0",
374 note = "use McpServerConfig::with_extra_route_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
375 )]
376 pub extra_route_rate_limit_burst: Option<u32>,
377 #[deprecated(
390 since = "1.14.0",
391 note = "use McpServerConfig::with_extra_route_rate_limit_exempt_paths(); direct field access will become pub(crate) in a future major release"
392 )]
393 pub extra_route_rate_limit_exempt_paths: Vec<String>,
394 #[deprecated(
402 since = "1.13.0",
403 note = "use McpServerConfig::with_trusted_proxies(); direct field access will become pub(crate) in a future major release"
404 )]
405 pub trusted_proxies: Vec<String>,
406 #[deprecated(
411 since = "1.13.0",
412 note = "use McpServerConfig::with_forwarded_header(); direct field access will become pub(crate) in a future major release"
413 )]
414 pub forwarded_header: Option<ForwardedHeaderMode>,
415 #[deprecated(
418 since = "0.13.0",
419 note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
420 )]
421 pub readiness_check: Option<ReadinessCheck>,
422 #[deprecated(
425 since = "0.13.0",
426 note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
427 )]
428 pub max_request_body: usize,
429 #[deprecated(
432 since = "0.13.0",
433 note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
434 )]
435 pub request_timeout: Duration,
436 #[deprecated(
439 since = "0.13.0",
440 note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
441 )]
442 pub shutdown_timeout: Duration,
443 #[deprecated(
446 since = "0.13.0",
447 note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
448 )]
449 pub session_idle_timeout: Duration,
450 #[deprecated(
453 since = "0.13.0",
454 note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
455 )]
456 pub sse_keep_alive: Duration,
457 #[deprecated(
461 since = "0.13.0",
462 note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
463 )]
464 pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
465 #[deprecated(
472 since = "0.13.0",
473 note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
474 )]
475 pub extra_router: Option<axum::Router>,
476 #[deprecated(
481 since = "0.13.0",
482 note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
483 )]
484 pub public_url: Option<String>,
485 #[deprecated(
488 since = "0.13.0",
489 note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
490 )]
491 pub log_request_headers: bool,
492 #[deprecated(
495 since = "0.13.0",
496 note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
497 )]
498 pub compression_enabled: bool,
499 #[deprecated(
502 since = "0.13.0",
503 note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
504 )]
505 pub compression_min_size: u16,
506 #[deprecated(
510 since = "0.13.0",
511 note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
512 )]
513 pub max_concurrent_requests: Option<usize>,
514 #[deprecated(
517 since = "0.13.0",
518 note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
519 )]
520 pub admin_enabled: bool,
521 #[deprecated(
523 since = "0.13.0",
524 note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
525 )]
526 pub admin_role: String,
527 #[cfg(feature = "metrics")]
530 #[deprecated(
531 since = "0.13.0",
532 note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
533 )]
534 pub metrics_enabled: bool,
535 #[cfg(feature = "metrics")]
537 #[deprecated(
538 since = "0.13.0",
539 note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
540 )]
541 pub metrics_bind: String,
542 #[deprecated(
546 since = "1.5.0",
547 note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
548 )]
549 pub security_headers: SecurityHeadersConfig,
550 #[deprecated(
556 since = "1.9.0",
557 note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
558 )]
559 pub tls_handshake_timeout: Duration,
560 #[deprecated(
567 since = "1.9.0",
568 note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
569 )]
570 pub max_concurrent_tls_handshakes: usize,
571}
572
/// Marker wrapper proving that a value passed validation.
///
/// The inner field is private, so values can only be constructed inside this
/// module (e.g. by `McpServerConfig::validate`).
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Always renders as `Validated { .. }` so wrapped data never reaches
    /// logs, regardless of whether `T: Debug`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("Validated { .. }")
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the validated value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
658
659#[allow(
660 deprecated,
661 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
662)]
663impl McpServerConfig {
664 #[must_use]
672 pub fn new(
673 bind_addr: impl Into<String>,
674 name: impl Into<String>,
675 version: impl Into<String>,
676 ) -> Self {
677 Self {
678 bind_addr: bind_addr.into(),
679 name: name.into(),
680 version: version.into(),
681 tls_cert_path: None,
682 tls_key_path: None,
683 auth: None,
684 rbac: None,
685 allowed_origins: Vec::new(),
686 tool_rate_limit: None,
687 readiness_check: None,
688 max_request_body: 1024 * 1024,
689 request_timeout: Duration::from_mins(2),
690 shutdown_timeout: Duration::from_secs(30),
691 session_idle_timeout: Duration::from_mins(20),
692 sse_keep_alive: Duration::from_secs(15),
693 on_reload_ready: None,
694 extra_router: None,
695 public_url: None,
696 log_request_headers: false,
697 compression_enabled: false,
698 compression_min_size: 1024,
699 max_concurrent_requests: None,
700 admin_enabled: false,
701 admin_role: "admin".to_owned(),
702 #[cfg(feature = "metrics")]
703 metrics_enabled: false,
704 #[cfg(feature = "metrics")]
705 metrics_bind: "127.0.0.1:9090".into(),
706 security_headers: SecurityHeadersConfig::default(),
707 tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
708 max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
709 extra_route_rate_limit: None,
710 tool_rate_limit_burst: None,
711 extra_route_rate_limit_burst: None,
712 extra_route_rate_limit_exempt_paths: Vec::new(),
713 trusted_proxies: Vec::new(),
714 forwarded_header: None,
715 }
716 }
717
718 #[must_use]
728 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
729 self.auth = Some(auth);
730 self
731 }
732
733 #[must_use]
738 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
739 self.security_headers = headers;
740 self
741 }
742
743 #[must_use]
747 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
748 self.bind_addr = addr.into();
749 self
750 }
751
752 #[must_use]
755 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
756 self.rbac = Some(rbac);
757 self
758 }
759
760 #[must_use]
764 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
765 self.tls_cert_path = Some(cert_path.into());
766 self.tls_key_path = Some(key_path.into());
767 self
768 }
769
770 #[must_use]
774 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
775 self.public_url = Some(url.into());
776 self
777 }
778
779 #[must_use]
783 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
784 where
785 I: IntoIterator<Item = S>,
786 S: Into<String>,
787 {
788 self.allowed_origins = origins.into_iter().map(Into::into).collect();
789 self
790 }
791
792 #[must_use]
805 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
806 self.extra_router = Some(router);
807 self
808 }
809
810 #[must_use]
813 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
814 self.readiness_check = Some(check);
815 self
816 }
817
818 #[must_use]
821 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
822 self.max_request_body = bytes;
823 self
824 }
825
826 #[must_use]
828 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
829 self.request_timeout = timeout;
830 self
831 }
832
833 #[must_use]
835 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
836 self.shutdown_timeout = timeout;
837 self
838 }
839
840 #[must_use]
842 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
843 self.session_idle_timeout = timeout;
844 self
845 }
846
847 #[must_use]
849 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
850 self.sse_keep_alive = interval;
851 self
852 }
853
854 #[must_use]
858 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
859 self.max_concurrent_requests = Some(limit);
860 self
861 }
862
863 #[must_use]
871 pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
872 self.tls_handshake_timeout = timeout;
873 self
874 }
875
876 #[must_use]
885 pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
886 self.max_concurrent_tls_handshakes = limit;
887 self
888 }
889
890 #[must_use]
893 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
894 self.tool_rate_limit = Some(per_minute);
895 self
896 }
897
898 #[must_use]
909 pub fn with_extra_route_rate_limit(mut self, per_minute: u32) -> Self {
910 self.extra_route_rate_limit = Some(per_minute);
911 self
912 }
913
914 #[must_use]
919 pub fn with_tool_rate_limit_burst(mut self, burst: u32) -> Self {
920 self.tool_rate_limit_burst = Some(burst);
921 self
922 }
923
924 #[must_use]
930 pub fn with_extra_route_rate_limit_burst(mut self, burst: u32) -> Self {
931 self.extra_route_rate_limit_burst = Some(burst);
932 self
933 }
934
935 #[must_use]
955 pub fn with_extra_route_rate_limit_exempt_paths<I, S>(mut self, paths: I) -> Self
956 where
957 I: IntoIterator<Item = S>,
958 S: Into<String>,
959 {
960 self.extra_route_rate_limit_exempt_paths = paths.into_iter().map(Into::into).collect();
961 self
962 }
963
964 #[must_use]
976 pub fn with_trusted_proxies<I, S>(mut self, proxies: I) -> Self
977 where
978 I: IntoIterator<Item = S>,
979 S: Into<String>,
980 {
981 self.trusted_proxies = proxies.into_iter().map(Into::into).collect();
982 self
983 }
984
985 #[must_use]
990 pub fn with_forwarded_header(mut self, mode: ForwardedHeaderMode) -> Self {
991 self.forwarded_header = Some(mode);
992 self
993 }
994
995 #[must_use]
999 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
1000 where
1001 F: FnOnce(ReloadHandle) + Send + 'static,
1002 {
1003 self.on_reload_ready = Some(Box::new(callback));
1004 self
1005 }
1006
1007 #[must_use]
1011 pub fn enable_compression(mut self, min_size: u16) -> Self {
1012 self.compression_enabled = true;
1013 self.compression_min_size = min_size;
1014 self
1015 }
1016
1017 #[must_use]
1022 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
1023 self.admin_enabled = true;
1024 self.admin_role = role.into();
1025 self
1026 }
1027
1028 #[must_use]
1031 pub fn enable_request_header_logging(mut self) -> Self {
1032 self.log_request_headers = true;
1033 self
1034 }
1035
1036 #[cfg(feature = "metrics")]
1039 #[must_use]
1040 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
1041 self.metrics_enabled = true;
1042 self.metrics_bind = bind.into();
1043 self
1044 }
1045
1046 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
1079 self.check()?;
1080 Ok(Validated(self))
1081 }
1082
1083 fn check_burst_knobs(&self) -> Result<(), McpxError> {
1090 if self.tool_rate_limit_burst == Some(0) {
1091 return Err(McpxError::Config(
1092 "tool_rate_limit_burst must be greater than zero".into(),
1093 ));
1094 }
1095 if self.extra_route_rate_limit_burst == Some(0) {
1096 return Err(McpxError::Config(
1097 "extra_route_rate_limit_burst must be greater than zero".into(),
1098 ));
1099 }
1100 if self.tool_rate_limit_burst.is_some() && self.tool_rate_limit.is_none() {
1101 return Err(McpxError::Config(
1102 "tool_rate_limit_burst requires tool_rate_limit to be set".into(),
1103 ));
1104 }
1105 if self.extra_route_rate_limit_burst.is_some() && self.extra_route_rate_limit.is_none() {
1106 return Err(McpxError::Config(
1107 "extra_route_rate_limit_burst requires extra_route_rate_limit to be set".into(),
1108 ));
1109 }
1110 if !self.extra_route_rate_limit_exempt_paths.is_empty()
1111 && self.extra_route_rate_limit.is_none()
1112 {
1113 return Err(McpxError::Config(
1114 "extra_route_rate_limit_exempt_paths requires extra_route_rate_limit to be set"
1115 .into(),
1116 ));
1117 }
1118 for path in &self.extra_route_rate_limit_exempt_paths {
1119 if path.is_empty() || !path.starts_with('/') {
1120 return Err(McpxError::Config(format!(
1121 "extra_route_rate_limit_exempt_paths entries must be non-empty and start with '/': {path:?}"
1122 )));
1123 }
1124 }
1125 if let Some(rl) = self.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
1126 if rl.burst == Some(0) {
1127 return Err(McpxError::Config(
1128 "auth rate_limit.burst must be greater than zero".into(),
1129 ));
1130 }
1131 if rl.pre_auth_burst == Some(0) {
1132 return Err(McpxError::Config(
1133 "auth rate_limit.pre_auth_burst must be greater than zero".into(),
1134 ));
1135 }
1136 }
1137 Ok(())
1138 }
1139
1140 fn check_trusted_forwarder(&self) -> Result<(), McpxError> {
1145 for entry in &self.trusted_proxies {
1146 if parse_proxy_net(entry).is_none() {
1147 return Err(McpxError::Config(format!(
1148 "trusted_proxies entry {entry:?} is neither a CIDR nor an IP address"
1149 )));
1150 }
1151 }
1152 if self.forwarded_header.is_some() && self.trusted_proxies.is_empty() {
1153 return Err(McpxError::Config(
1154 "forwarded_header requires trusted_proxies to be nonempty".into(),
1155 ));
1156 }
1157 Ok(())
1158 }
1159
1160 fn check(&self) -> Result<(), McpxError> {
1164 if self.admin_enabled {
1168 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
1169 if !auth_enabled {
1170 return Err(McpxError::Config(
1171 "admin_enabled=true requires auth to be configured and enabled".into(),
1172 ));
1173 }
1174 }
1175
1176 match (&self.tls_cert_path, &self.tls_key_path) {
1178 (Some(_), None) => {
1179 return Err(McpxError::Config(
1180 "tls_cert_path is set but tls_key_path is missing".into(),
1181 ));
1182 }
1183 (None, Some(_)) => {
1184 return Err(McpxError::Config(
1185 "tls_key_path is set but tls_cert_path is missing".into(),
1186 ));
1187 }
1188 _ => {}
1189 }
1190
1191 if self.bind_addr.parse::<SocketAddr>().is_err() {
1193 return Err(McpxError::Config(format!(
1194 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
1195 self.bind_addr
1196 )));
1197 }
1198
1199 if let Some(ref url) = self.public_url
1201 && !(url.starts_with("http://") || url.starts_with("https://"))
1202 {
1203 return Err(McpxError::Config(format!(
1204 "public_url {url:?} must start with http:// or https://"
1205 )));
1206 }
1207
1208 for origin in &self.allowed_origins {
1210 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
1211 return Err(McpxError::Config(format!(
1212 "allowed_origins entry {origin:?} must start with http:// or https://"
1213 )));
1214 }
1215 }
1216
1217 if self.max_request_body == 0 {
1219 return Err(McpxError::Config(
1220 "max_request_body must be greater than zero".into(),
1221 ));
1222 }
1223
1224 if self.extra_route_rate_limit == Some(0) {
1228 return Err(McpxError::Config(
1229 "extra_route_rate_limit must be greater than zero".into(),
1230 ));
1231 }
1232
1233 self.check_burst_knobs()?;
1235
1236 self.check_trusted_forwarder()?;
1238
1239 #[cfg(feature = "oauth")]
1241 if let Some(auth_cfg) = &self.auth
1242 && let Some(oauth_cfg) = &auth_cfg.oauth
1243 {
1244 oauth_cfg.validate()?;
1245 }
1246
1247 validate_security_headers(&self.security_headers)?;
1250
1251 if self.max_concurrent_requests == Some(0) {
1255 return Err(McpxError::Config(
1256 "max_concurrent_requests must be greater than zero when set".into(),
1257 ));
1258 }
1259
1260 if let Some(auth_cfg) = &self.auth
1264 && let Some(rl) = &auth_cfg.rate_limit
1265 && rl.max_tracked_keys == 0
1266 {
1267 return Err(McpxError::Config(
1268 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
1269 ));
1270 }
1271
1272 if self.tls_handshake_timeout == Duration::ZERO {
1277 return Err(McpxError::Config(
1278 "tls_handshake_timeout must be greater than zero".into(),
1279 ));
1280 }
1281
1282 if self.max_concurrent_tls_handshakes == 0 {
1287 return Err(McpxError::Config(
1288 "max_concurrent_tls_handshakes must be greater than zero".into(),
1289 ));
1290 }
1291
1292 Ok(())
1293 }
1294}
1295
1296#[allow(
1302 missing_debug_implementations,
1303 reason = "contains Arc<AuthState> with non-Debug fields"
1304)]
1305pub struct ReloadHandle {
1306 auth: Option<Arc<AuthState>>,
1307 rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
1308 crl_set: Option<Arc<CrlSet>>,
1309}
1310
1311impl ReloadHandle {
1312 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
1314 if let Some(ref auth) = self.auth {
1315 auth.reload_keys(keys);
1316 }
1317 }
1318
1319 pub fn reload_rbac(&self, policy: RbacPolicy) {
1321 if let Some(ref rbac) = self.rbac {
1322 rbac.store(Arc::new(policy));
1323 tracing::info!("RBAC policy reloaded");
1324 }
1325 }
1326
1327 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1333 let Some(ref crl_set) = self.crl_set else {
1334 return Err(McpxError::Config(
1335 "CRL refresh requested but mTLS CRL support is not configured".into(),
1336 ));
1337 };
1338
1339 crl_set.force_refresh().await
1340 }
1341}
1342
1343#[allow(
1360 clippy::too_many_lines,
1361 clippy::cognitive_complexity,
1362 reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
1363)]
1364struct AppRunParams {
1368 tls_paths: Option<(PathBuf, PathBuf)>,
1370 tls_handshake_timeout: Duration,
1372 max_concurrent_tls_handshakes: usize,
1374 mtls_config: Option<MtlsConfig>,
1376 shutdown_timeout: Duration,
1378 auth_state: Option<Arc<AuthState>>,
1380 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1382 on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1384 ct: CancellationToken,
1388 scheme: &'static str,
1390 name: String,
1392}
1393
1394#[allow(
1404 clippy::cognitive_complexity,
1405 reason = "router assembly is intrinsically sequential; splitting harms readability"
1406)]
1407#[allow(
1408 deprecated,
1409 reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
1410)]
1411fn build_app_router<H, F>(
1412 mut config: McpServerConfig,
1413 handler_factory: F,
1414) -> anyhow::Result<(axum::Router, AppRunParams)>
1415where
1416 H: ServerHandler + 'static,
1417 F: Fn() -> H + Send + Sync + Clone + 'static,
1418{
1419 let ct = CancellationToken::new();
1420
1421 let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
1422 tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
1423
1424 let mcp_service = StreamableHttpService::new(
1425 move || Ok(handler_factory()),
1426 {
1427 let mut mgr = LocalSessionManager::default();
1428 mgr.session_config.keep_alive = Some(config.session_idle_timeout);
1429 mgr.into()
1430 },
1431 StreamableHttpServerConfig::default()
1432 .with_allowed_hosts(allowed_hosts)
1433 .with_sse_keep_alive(Some(config.sse_keep_alive))
1434 .with_cancellation_token(ct.child_token()),
1435 );
1436
1437 let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
1439
1440 let auth_state: Option<Arc<AuthState>> = match config.auth {
1444 Some(ref auth_config) if auth_config.enabled => {
1445 let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
1446 let pre_auth_limiter = auth_config
1447 .rate_limit
1448 .as_ref()
1449 .map(crate::auth::build_pre_auth_limiter);
1450
1451 #[cfg(feature = "oauth")]
1452 let jwks_cache = auth_config
1453 .oauth
1454 .as_ref()
1455 .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
1456 .transpose()
1457 .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
1458
1459 Some(Arc::new(AuthState {
1460 api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
1461 rate_limiter,
1462 pre_auth_limiter,
1463 #[cfg(feature = "oauth")]
1464 jwks_cache,
1465 seen_identities: crate::auth::SeenIdentitySet::new(),
1466 counters: crate::auth::AuthCounters::default(),
1467 }))
1468 }
1469 _ => None,
1470 };
1471
1472 let rbac_swap = Arc::new(ArcSwap::new(
1475 config
1476 .rbac
1477 .clone()
1478 .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
1479 ));
1480
1481 if config.admin_enabled {
1484 let Some(ref auth_state_ref) = auth_state else {
1485 return Err(anyhow::anyhow!(
1486 "admin_enabled=true requires auth to be configured and enabled"
1487 ));
1488 };
1489 let admin_state = crate::admin::AdminState {
1490 started_at: std::time::Instant::now(),
1491 name: config.name.clone(),
1492 version: config.version.clone(),
1493 auth: Some(Arc::clone(auth_state_ref)),
1494 rbac: Arc::clone(&rbac_swap),
1495 };
1496 let admin_cfg = crate::admin::AdminConfig {
1497 role: config.admin_role.clone(),
1498 };
1499 mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
1500 tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
1501 }
1502
1503 {
1536 let tool_limiter: Option<Arc<ToolRateLimiter>> = config
1537 .tool_rate_limit
1538 .map(|per_minute| build_tool_rate_limiter(per_minute, config.tool_rate_limit_burst));
1539
1540 if rbac_swap.load().is_enabled() {
1541 tracing::info!("RBAC enforcement enabled on /mcp");
1542 }
1543 if let Some(limit) = config.tool_rate_limit {
1544 tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1545 }
1546
1547 let rbac_for_mw = Arc::clone(&rbac_swap);
1548 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1549 let p = rbac_for_mw.load_full();
1550 let tl = tool_limiter.clone();
1551 rbac_middleware(p, tl, req, next)
1552 }));
1553 }
1554
1555 if let Some(ref auth_config) = config.auth
1557 && auth_config.enabled
1558 {
1559 let Some(ref state) = auth_state else {
1560 return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1561 };
1562
1563 let methods: Vec<&str> = [
1564 auth_config.mtls.is_some().then_some("mTLS"),
1565 (!auth_config.api_keys.is_empty()).then_some("bearer"),
1566 #[cfg(feature = "oauth")]
1567 auth_config.oauth.is_some().then_some("oauth-jwt"),
1568 ]
1569 .into_iter()
1570 .flatten()
1571 .collect();
1572
1573 tracing::info!(
1574 methods = %methods.join(", "),
1575 api_keys = auth_config.api_keys.len(),
1576 "auth enabled on /mcp"
1577 );
1578
1579 let state_for_mw = Arc::clone(state);
1580 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1581 let s = Arc::clone(&state_for_mw);
1582 auth_middleware(s, req, next)
1583 }));
1584 }
1585
1586 mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1589 axum::http::StatusCode::REQUEST_TIMEOUT,
1590 config.request_timeout,
1591 ));
1592
1593 mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1597 config.max_request_body,
1598 ));
1599
1600 let mut effective_origins = config.allowed_origins.clone();
1607 if effective_origins.is_empty()
1608 && let Some(ref url) = config.public_url
1609 {
1610 if let Some(scheme_end) = url.find("://") {
1615 let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1616 let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1617 let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1618 let host = after_scheme.get(..host_end).unwrap_or_default();
1619 let origin = format!("{scheme_with_sep}{host}");
1620 tracing::info!(
1621 %origin,
1622 "auto-derived allowed origin from public_url"
1623 );
1624 effective_origins.push(origin);
1625 }
1626 }
1627 let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1628 let cors_origins = Arc::clone(&allowed_origins);
1629 let log_request_headers = config.log_request_headers;
1630
1631 let readyz_route = if let Some(check) = config.readiness_check.take() {
1632 axum::routing::get(move || readyz(Arc::clone(&check)))
1633 } else {
1634 axum::routing::get(healthz)
1635 };
1636
1637 #[allow(unused_mut)] let mut router = axum::Router::new()
1639 .route("/healthz", axum::routing::get(healthz))
1640 .route("/readyz", readyz_route)
1641 .route(
1642 "/version",
1643 axum::routing::get({
1644 let payload_bytes: Arc<[u8]> =
1649 serialize_version_payload(&config.name, &config.version);
1650 move || {
1651 let p = Arc::clone(&payload_bytes);
1652 async move {
1653 (
1654 [(axum::http::header::CONTENT_TYPE, "application/json")],
1655 p.to_vec(),
1656 )
1657 }
1658 }
1659 }),
1660 )
1661 .merge(mcp_router);
1662
1663 if let Some(extra) = config.extra_router.take() {
1670 let extra = match config.extra_route_rate_limit {
1671 Some(per_minute) => {
1672 let limiter =
1673 build_extra_route_rate_limiter(per_minute, config.extra_route_rate_limit_burst);
1674 let exempt: Arc<std::collections::HashSet<String>> = Arc::new(
1675 config
1676 .extra_route_rate_limit_exempt_paths
1677 .iter()
1678 .cloned()
1679 .collect(),
1680 );
1681 tracing::info!(
1682 per_minute,
1683 exempt_paths = exempt.len(),
1684 "extra-route per-IP rate limit enabled"
1685 );
1686 extra.layer(axum::middleware::from_fn(move |req, next| {
1687 let l = Arc::clone(&limiter);
1688 let e = Arc::clone(&exempt);
1689 extra_route_rate_limit_middleware(l, e, req, next)
1690 }))
1691 }
1692 None => extra,
1693 };
1694 router = router.merge(extra);
1695 }
1696
1697 let server_url = if let Some(ref url) = config.public_url {
1704 url.trim_end_matches('/').to_owned()
1705 } else {
1706 let prm_scheme = if config.tls_cert_path.is_some() {
1707 "https"
1708 } else {
1709 "http"
1710 };
1711 format!("{prm_scheme}://{}", config.bind_addr)
1712 };
1713 let resource_url = format!("{server_url}/mcp");
1714
1715 #[cfg(feature = "oauth")]
1716 let prm_metadata = if let Some(ref auth_config) = config.auth
1717 && let Some(ref oauth_config) = auth_config.oauth
1718 {
1719 crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1720 } else {
1721 serde_json::json!({ "resource": resource_url })
1722 };
1723 #[cfg(not(feature = "oauth"))]
1724 let prm_metadata = serde_json::json!({ "resource": resource_url });
1725
1726 router = router.route(
1727 "/.well-known/oauth-protected-resource",
1728 axum::routing::get(move || {
1729 let m = prm_metadata.clone();
1730 async move { axum::Json(m) }
1731 }),
1732 );
1733
1734 #[cfg(feature = "oauth")]
1739 if let Some(ref auth_config) = config.auth
1740 && let Some(ref oauth_config) = auth_config.oauth
1741 && oauth_config.proxy.is_some()
1742 {
1743 router =
1744 install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1745 }
1746
1747 let is_tls = config.tls_cert_path.is_some();
1750 let security_headers_cfg = Arc::new(config.security_headers.clone());
1751 router = router.layer(axum::middleware::from_fn(move |req, next| {
1752 let cfg = Arc::clone(&security_headers_cfg);
1753 security_headers_middleware(is_tls, cfg, req, next)
1754 }));
1755
1756 if !cors_origins.is_empty() {
1760 let cors = tower_http::cors::CorsLayer::new()
1761 .allow_origin(
1762 cors_origins
1763 .iter()
1764 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1765 .collect::<Vec<_>>(),
1766 )
1767 .allow_methods([
1768 axum::http::Method::GET,
1769 axum::http::Method::POST,
1770 axum::http::Method::OPTIONS,
1771 ])
1772 .allow_headers([
1773 axum::http::header::CONTENT_TYPE,
1774 axum::http::header::AUTHORIZATION,
1775 ]);
1776 router = router.layer(cors);
1777 }
1778
1779 if config.compression_enabled {
1783 use tower_http::compression::Predicate as _;
1784 let predicate = tower_http::compression::DefaultPredicate::new().and(
1785 tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1786 );
1787 router = router.layer(
1788 tower_http::compression::CompressionLayer::new()
1789 .gzip(true)
1790 .br(true)
1791 .compress_when(predicate),
1792 );
1793 tracing::info!(
1794 min_size = config.compression_min_size,
1795 "response compression enabled (gzip, br)"
1796 );
1797 }
1798
1799 if let Some(max) = config.max_concurrent_requests {
1802 let overload_handler = tower::ServiceBuilder::new()
1803 .layer(axum::error_handling::HandleErrorLayer::new(
1804 |_err: tower::BoxError| async {
1805 (
1806 axum::http::StatusCode::SERVICE_UNAVAILABLE,
1807 axum::Json(serde_json::json!({
1808 "error": "overloaded",
1809 "error_description": "server is at capacity, retry later"
1810 })),
1811 )
1812 },
1813 ))
1814 .layer(tower::load_shed::LoadShedLayer::new())
1815 .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1816 router = router.layer(overload_handler);
1817 tracing::info!(max, "global concurrency limit enabled");
1818 }
1819
1820 router = router.fallback(|| async {
1824 (
1825 axum::http::StatusCode::NOT_FOUND,
1826 axum::Json(serde_json::json!({
1827 "error": "not_found",
1828 "error_description": "The requested endpoint does not exist"
1829 })),
1830 )
1831 });
1832
1833 #[cfg(feature = "metrics")]
1835 if config.metrics_enabled {
1836 let metrics = Arc::new(
1837 crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1838 );
1839 let m = Arc::clone(&metrics);
1840 router = router.layer(axum::middleware::from_fn(
1841 move |req: Request<Body>, next: Next| {
1842 let m = Arc::clone(&m);
1843 metrics_middleware(m, req, next)
1844 },
1845 ));
1846 let metrics_bind = config.metrics_bind.clone();
1847 let metrics_shutdown = ct.clone();
1848 tokio::spawn(async move {
1849 if let Err(e) =
1850 crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1851 {
1852 tracing::error!("metrics listener failed: {e}");
1853 }
1854 });
1855 }
1856
1857 let forward_resolver: Option<Arc<ForwardResolver>> = if config.trusted_proxies.is_empty() {
1865 None
1866 } else {
1867 Some(Arc::new(ForwardResolver {
1870 trusted: config
1871 .trusted_proxies
1872 .iter()
1873 .filter_map(|entry| parse_proxy_net(entry))
1874 .collect(),
1875 mode: config
1876 .forwarded_header
1877 .unwrap_or(ForwardedHeaderMode::XForwardedFor),
1878 }))
1879 };
1880 if forward_resolver.is_some() {
1881 tracing::info!(
1882 proxies = config.trusted_proxies.len(),
1883 "trusted-forwarder mode enabled: limiters key by resolved client IP"
1884 );
1885 }
1886 router = router.layer(axum::middleware::from_fn(move |req, next| {
1887 let r = forward_resolver.clone();
1888 normalize_peer_addr_middleware(r, req, next)
1889 }));
1890
1891 router = router.layer(axum::middleware::from_fn(move |req, next| {
1902 let origins = Arc::clone(&allowed_origins);
1903 origin_check_middleware(origins, log_request_headers, req, next)
1904 }));
1905
1906 let scheme = if config.tls_cert_path.is_some() {
1907 "https"
1908 } else {
1909 "http"
1910 };
1911
1912 let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1913 (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1914 _ => None,
1915 };
1916 let tls_handshake_timeout = config.tls_handshake_timeout;
1917 let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
1918 let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1919
1920 Ok((
1921 router,
1922 AppRunParams {
1923 tls_paths,
1924 tls_handshake_timeout,
1925 max_concurrent_tls_handshakes,
1926 mtls_config,
1927 shutdown_timeout: config.shutdown_timeout,
1928 auth_state,
1929 rbac_swap,
1930 on_reload_ready: config.on_reload_ready.take(),
1931 ct,
1932 scheme,
1933 name: config.name.clone(),
1934 },
1935 ))
1936}
1937
1938pub async fn serve<H, F>(
1955 config: Validated<McpServerConfig>,
1956 handler_factory: F,
1957) -> Result<(), McpxError>
1958where
1959 H: ServerHandler + 'static,
1960 F: Fn() -> H + Send + Sync + Clone + 'static,
1961{
1962 let config = config.into_inner();
1963 #[allow(
1964 deprecated,
1965 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1966 )]
1967 let bind_addr = config.bind_addr.clone();
1968 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1969
1970 let listener = TcpListener::bind(&bind_addr)
1971 .await
1972 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1973 log_listening(¶ms.name, params.scheme, &bind_addr);
1974
1975 run_server(
1976 router,
1977 listener,
1978 params.tls_paths,
1979 params.tls_handshake_timeout,
1980 params.max_concurrent_tls_handshakes,
1981 params.mtls_config,
1982 params.shutdown_timeout,
1983 params.auth_state,
1984 params.rbac_swap,
1985 params.on_reload_ready,
1986 params.ct,
1987 )
1988 .await
1989 .map_err(anyhow_to_startup)
1990}
1991
1992pub async fn serve_with_listener<H, F>(
2022 listener: TcpListener,
2023 config: Validated<McpServerConfig>,
2024 handler_factory: F,
2025 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
2026 shutdown: Option<CancellationToken>,
2027) -> Result<(), McpxError>
2028where
2029 H: ServerHandler + 'static,
2030 F: Fn() -> H + Send + Sync + Clone + 'static,
2031{
2032 let config = config.into_inner();
2033 let local_addr = listener
2034 .local_addr()
2035 .map_err(|e| io_to_startup("listener.local_addr", e))?;
2036 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
2037
2038 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
2039
2040 if let Some(external) = shutdown {
2044 let internal = params.ct.clone();
2045 tokio::spawn(async move {
2046 external.cancelled().await;
2047 internal.cancel();
2048 });
2049 }
2050
2051 if let Some(tx) = ready_tx {
2055 let _ = tx.send(local_addr);
2057 }
2058
2059 run_server(
2060 router,
2061 listener,
2062 params.tls_paths,
2063 params.tls_handshake_timeout,
2064 params.max_concurrent_tls_handshakes,
2065 params.mtls_config,
2066 params.shutdown_timeout,
2067 params.auth_state,
2068 params.rbac_swap,
2069 params.on_reload_ready,
2070 params.ct,
2071 )
2072 .await
2073 .map_err(anyhow_to_startup)
2074}
2075
2076#[allow(
2079 clippy::cognitive_complexity,
2080 reason = "tracing::info! macro expansions inflate the score; logic is trivial"
2081)]
2082fn log_listening(name: &str, scheme: &str, addr: &str) {
2083 tracing::info!("{name} listening on {addr}");
2084 tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
2085 tracing::info!(" Health check: {scheme}://{addr}/healthz");
2086 tracing::info!(" Readiness: {scheme}://{addr}/readyz");
2087}
2088
/// Serves `router` on `listener` until shutdown, over TLS when `tls_paths` is set.
///
/// Shutdown is triggered by either an OS signal (`shutdown_signal`) or
/// cancellation of `ct`. Once triggered:
/// * the `graceful` future logs the grace period and cancels `ct`, and is handed
///   to axum's `with_graceful_shutdown`;
/// * `force_exit_timer` completes `shutdown_timeout` after the trigger. Whichever
///   of "server finished" and "timer elapsed" comes first ends this function. If
///   the timer wins, the serve future is dropped and a warning is logged.
///
/// In both modes, `on_reload_ready` (if provided) is invoked once with a
/// [`ReloadHandle`] before serving starts. In TLS mode this happens after CRL
/// bootstrap but before the TLS listener is built.
///
/// # Errors
///
/// Fails if loading the mTLS CA roots, the CRL bootstrap fetch, or TLS listener
/// construction fails, or if `axum::serve` returns an error.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    tls_handshake_timeout: Duration,
    max_concurrent_tls_handshakes: usize,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // One internal trigger, fired by whichever comes first: an OS signal or
    // cancellation of the parent token `ct`.
    let shutdown_trigger = CancellationToken::new();
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    // Graceful-shutdown future for axum: waits for the trigger, then cancels
    // `ct` so every task holding a clone of it is told to stop too.
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };

    // Upper bound on the drain: resolves `shutdown_timeout` after the trigger.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // CRL state exists only when mTLS is configured with `crl_enabled`. The
        // refresher task is given a clone of `ct`, presumably so it stops on
        // shutdown (see mtls_revocation::run_crl_refresher).
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        // Hand out hot-reload handles (auth, RBAC, CRL) before serving begins.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
            tls_handshake_timeout,
            max_concurrent_tls_handshakes,
        )?;
        // TlsConnInfo carries the peer address plus any mTLS identity to handlers.
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        // Plain TCP: there is no CRL state to expose.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
2234
#[cfg(feature = "oauth")]
/// Mounts the OAuth 2.1 proxy endpoints on `router` when `oauth_config.proxy` is set.
///
/// Routes added whenever a proxy is configured:
/// * `GET /.well-known/oauth-authorization-server` (authorization-server metadata)
/// * `GET /authorize`
/// * `POST /token` and `POST /register` (both wrapped with the token cache-header middleware)
///
/// `/introspect` and `/revoke` come from [`build_oauth_admin_router`]. They are
/// added only when `expose_admin_endpoints` is set and at least one of the
/// upstream introspection/revocation URLs is configured.
///
/// Returns `router` unchanged when no proxy is configured.
///
/// # Errors
///
/// Propagates failures from building the outbound OAuth HTTP client or the
/// admin router.
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(proxy) = oauth_config.proxy.as_ref() else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let as_metadata = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let metadata_handler = axum::routing::get(move || {
        let document = as_metadata.clone();
        async move { axum::Json(document) }
    });

    let authorize_cfg = proxy.clone();
    let authorize_handler = axum::routing::get(
        move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
            let cfg = authorize_cfg.clone();
            async move { crate::oauth::handle_authorize(&cfg, &query.unwrap_or_default()) }
        },
    );

    let token_cfg = proxy.clone();
    let token_client = http.clone();
    let token_handler = axum::routing::post(move |body: String| {
        let cfg = token_cfg.clone();
        let client = token_client.clone();
        async move { crate::oauth::handle_token(&client, &cfg, &body).await }
    })
    .layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    let register_cfg = proxy.clone();
    let register_handler = axum::routing::post(
        move |axum::Json(body): axum::Json<serde_json::Value>| {
            let cfg = register_cfg;
            async move { axum::Json(crate::oauth::handle_register(&cfg, &body)) }
        },
    )
    .layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    let router = router
        .route("/.well-known/oauth-authorization-server", metadata_handler)
        .route("/authorize", authorize_handler)
        .route("/token", token_handler)
        .route("/register", register_handler);

    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    let admin_auth_opted_out = proxy.expose_admin_endpoints
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints;
    if admin_auth_opted_out {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    // An empty router is merged when admin routes are disabled, which leaves
    // `router` as it was.
    let admin_router = match admin_routes_enabled {
        true => build_oauth_admin_router(proxy, http, auth_state)?,
        false => axum::Router::new(),
    };
    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
2335
#[cfg(feature = "oauth")]
/// Builds the OAuth proxy "admin" routes: `POST /introspect` and `POST /revoke`.
///
/// Each route is added only when its upstream URL is configured. All of them
/// are wrapped with the token cache-header middleware. When
/// `require_auth_on_admin_endpoints` is set, they are also wrapped with
/// `auth_middleware` using `auth_state`.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if admin auth is required but `auth_state` is `None`.
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let cfg = proxy.clone();
        let client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let cfg = cfg.clone();
                let client = client.clone();
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let cfg = proxy.clone();
        let client = http;
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let cfg = cfg.clone();
                let client = client.clone();
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.map(Arc::clone).ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&state), req, next)
    })))
}
2394
2395fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2400 let mut hosts = vec![
2401 "localhost".to_owned(),
2402 "127.0.0.1".to_owned(),
2403 "::1".to_owned(),
2404 ];
2405
2406 if let Some(url) = public_url
2407 && let Ok(uri) = url.parse::<axum::http::Uri>()
2408 && let Some(authority) = uri.authority()
2409 {
2410 let host = authority.host().to_owned();
2411 if !hosts.iter().any(|h| h == &host) {
2412 hosts.push(host);
2413 }
2414
2415 let authority = authority.as_str().to_owned();
2416 if !hosts.iter().any(|h| h == &authority) {
2417 hosts.push(authority);
2418 }
2419 }
2420
2421 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2422 && let Some(authority) = uri.authority()
2423 {
2424 let host = authority.host().to_owned();
2425 if !hosts.iter().any(|h| h == &host) {
2426 hosts.push(host);
2427 }
2428
2429 let authority = authority.as_str().to_owned();
2430 if !hosts.iter().any(|h| h == &authority) {
2431 hosts.push(authority);
2432 }
2433 }
2434
2435 hosts
2436}
2437
2438impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2451 for TlsConnInfo
2452{
2453 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2454 let addr = *target.remote_addr();
2455 let identity = target.io().identity().cloned();
2456 Self::new(addr, identity)
2457 }
2458}
2459
/// Ten-second TLS handshake time budget. By its name, this is the default for
/// the handshake-timeout setting; that setting is configured outside this section.
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Limit of 256 simultaneous in-flight TLS handshakes. By its name, this is the
/// default for the `max_concurrent_tls_handshakes` setting.
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 1 << 8;

/// Capacity (32) of the channel that carries completed TLS streams from
/// `run_tls_acceptor` to `TlsListener::accept`.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 1 << 5;
2482
/// TLS-terminating listener handed to `axum::serve`.
///
/// TCP accepts and TLS handshakes run in a background task
/// (`run_tls_acceptor`). Finished streams arrive over `rx` and are returned
/// from `accept`.
struct TlsListener {
    /// Address of the wrapped TCP listener, captured when this value was built.
    local_addr: SocketAddr,
    /// Receives handshaken streams, with their peer address, from the acceptor task.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Handle to the acceptor task; aborted when the listener is dropped.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2508
2509impl TlsListener {
2510 fn new(
2511 inner: TcpListener,
2512 cert_path: &Path,
2513 key_path: &Path,
2514 mtls_config: Option<&MtlsConfig>,
2515 crl_set: Option<Arc<CrlSet>>,
2516 handshake_timeout: Duration,
2517 max_concurrent_handshakes: usize,
2518 ) -> anyhow::Result<Self> {
2519 rustls::crypto::ring::default_provider()
2521 .install_default()
2522 .ok();
2523
2524 let certs = load_certs(cert_path)?;
2525 let key = load_key(key_path)?;
2526
2527 let mtls_default_role;
2528
2529 let tls_config = if let Some(mtls) = mtls_config {
2530 mtls_default_role = mtls.default_role.clone();
2531 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2532 {
2533 let Some(crl_set) = crl_set else {
2534 return Err(anyhow::anyhow!(
2535 "mTLS CRL verifier requested but CRL state was not initialized"
2536 ));
2537 };
2538 Arc::new(DynamicClientCertVerifier::new(crl_set))
2539 } else {
2540 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2541 if mtls.required {
2542 rustls::server::WebPkiClientVerifier::builder(root_store)
2543 .build()
2544 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2545 } else {
2546 rustls::server::WebPkiClientVerifier::builder(root_store)
2547 .allow_unauthenticated()
2548 .build()
2549 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2550 }
2551 };
2552
2553 tracing::info!(
2554 ca = %mtls.ca_cert_path.display(),
2555 required = mtls.required,
2556 crl_enabled = mtls.crl_enabled,
2557 "mTLS client auth configured"
2558 );
2559
2560 rustls::ServerConfig::builder_with_protocol_versions(&[
2561 &rustls::version::TLS12,
2562 &rustls::version::TLS13,
2563 ])
2564 .with_client_cert_verifier(verifier)
2565 .with_single_cert(certs, key)?
2566 } else {
2567 mtls_default_role = "viewer".to_owned();
2568 rustls::ServerConfig::builder_with_protocol_versions(&[
2569 &rustls::version::TLS12,
2570 &rustls::version::TLS13,
2571 ])
2572 .with_no_client_auth()
2573 .with_single_cert(certs, key)?
2574 };
2575
2576 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2577 tracing::info!(
2578 "TLS enabled (cert: {}, key: {})",
2579 cert_path.display(),
2580 key_path.display()
2581 );
2582 let local_addr = inner.local_addr()?;
2583 let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2584 let acceptor_task = tokio::spawn(run_tls_acceptor(
2585 inner,
2586 acceptor,
2587 mtls_default_role,
2588 tx,
2589 handshake_timeout,
2590 max_concurrent_handshakes,
2591 ));
2592 Ok(Self {
2593 local_addr,
2594 rx,
2595 acceptor_task,
2596 })
2597 }
2598
2599 fn extract_handshake_identity(
2603 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2604 default_role: &str,
2605 addr: SocketAddr,
2606 ) -> Option<AuthIdentity> {
2607 let (_, server_conn) = tls_stream.get_ref();
2608 let cert_der = server_conn.peer_certificates()?.first()?;
2609 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2610 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2611 Some(id)
2612 }
2613}
2614
2615async fn run_tls_acceptor(
2623 listener: TcpListener,
2624 acceptor: tokio_rustls::TlsAcceptor,
2625 default_role: String,
2626 tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2627 handshake_timeout: Duration,
2628 max_concurrent_handshakes: usize,
2629) {
2630 let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2631 loop {
2632 let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2636 return;
2638 };
2639 let (stream, addr) = match listener.accept().await {
2640 Ok(pair) => pair,
2641 Err(e) => {
2642 tracing::debug!("TCP accept error: {e}");
2643 continue;
2644 }
2645 };
2646 if tx.is_closed() {
2647 return;
2649 }
2650 let acceptor = acceptor.clone();
2651 let default_role = default_role.clone();
2652 let tx = tx.clone();
2653 tokio::spawn(async move {
2654 let _permit = permit;
2655 match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2656 Ok(Ok(tls_stream)) => {
2657 let identity =
2658 TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2659 let wrapped = AuthenticatedTlsStream {
2660 inner: tls_stream,
2661 identity,
2662 };
2663 let _ = tx.send((wrapped, addr)).await;
2666 }
2667 Ok(Err(e)) => {
2668 tracing::debug!("TLS handshake failed from {addr}: {e}");
2669 }
2670 Err(_elapsed) => {
2671 tracing::debug!(
2672 "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2673 );
2674 }
2675 }
2676 });
2677 }
2678}
2679
/// Server-side TLS stream paired with the client identity that was extracted
/// from its certificate during the handshake, if any.
///
/// All I/O is delegated to `inner`.
pub(crate) struct AuthenticatedTlsStream {
    /// The underlying rustls-over-TCP stream.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity derived from the peer's leaf certificate. `None` when no
    /// certificate was presented or it could not be mapped to an identity.
    identity: Option<AuthIdentity>,
}
2695
2696impl AuthenticatedTlsStream {
2697 #[must_use]
2699 pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
2700 self.identity.as_ref()
2701 }
2702}
2703
2704impl std::fmt::Debug for AuthenticatedTlsStream {
2705 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2706 f.debug_struct("AuthenticatedTlsStream")
2707 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2708 .finish_non_exhaustive()
2709 }
2710}
2711
2712impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2713 fn poll_read(
2714 mut self: Pin<&mut Self>,
2715 cx: &mut std::task::Context<'_>,
2716 buf: &mut tokio::io::ReadBuf<'_>,
2717 ) -> std::task::Poll<std::io::Result<()>> {
2718 Pin::new(&mut self.inner).poll_read(cx, buf)
2719 }
2720}
2721
2722impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2723 fn poll_write(
2724 mut self: Pin<&mut Self>,
2725 cx: &mut std::task::Context<'_>,
2726 buf: &[u8],
2727 ) -> std::task::Poll<std::io::Result<usize>> {
2728 Pin::new(&mut self.inner).poll_write(cx, buf)
2729 }
2730
2731 fn poll_flush(
2732 mut self: Pin<&mut Self>,
2733 cx: &mut std::task::Context<'_>,
2734 ) -> std::task::Poll<std::io::Result<()>> {
2735 Pin::new(&mut self.inner).poll_flush(cx)
2736 }
2737
2738 fn poll_shutdown(
2739 mut self: Pin<&mut Self>,
2740 cx: &mut std::task::Context<'_>,
2741 ) -> std::task::Poll<std::io::Result<()>> {
2742 Pin::new(&mut self.inner).poll_shutdown(cx)
2743 }
2744
2745 fn poll_write_vectored(
2746 mut self: Pin<&mut Self>,
2747 cx: &mut std::task::Context<'_>,
2748 bufs: &[std::io::IoSlice<'_>],
2749 ) -> std::task::Poll<std::io::Result<usize>> {
2750 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2751 }
2752
2753 fn is_write_vectored(&self) -> bool {
2754 self.inner.is_write_vectored()
2755 }
2756}
2757
2758impl axum::serve::Listener for TlsListener {
2759 type Io = AuthenticatedTlsStream;
2760 type Addr = SocketAddr;
2761
2762 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2768 if let Some(pair) = self.rx.recv().await {
2769 return pair;
2770 }
2771 tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2777 std::future::pending().await
2778 }
2779
2780 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2781 Ok(self.local_addr)
2782 }
2783}
2784
2785impl Drop for TlsListener {
2786 fn drop(&mut self) {
2787 self.acceptor_task.abort();
2790 }
2791}
2792
2793fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2794 use rustls::pki_types::pem::PemObject;
2795 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2796 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2797 .collect::<Result<_, _>>()
2798 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2799 anyhow::ensure!(
2800 !certs.is_empty(),
2801 "no certificates found in {}",
2802 path.display()
2803 );
2804 Ok(certs)
2805}
2806
2807fn load_client_auth_roots(
2808 path: &Path,
2809) -> anyhow::Result<(
2810 Vec<rustls::pki_types::CertificateDer<'static>>,
2811 Arc<RootCertStore>,
2812)> {
2813 let ca_certs = load_certs(path)?;
2814 let mut root_store = RootCertStore::empty();
2815 for cert in &ca_certs {
2816 root_store
2817 .add(cert.clone())
2818 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2819 }
2820
2821 Ok((ca_certs, Arc::new(root_store)))
2822}
2823
2824fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2825 use rustls::pki_types::pem::PemObject;
2826 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2827 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2828}
2829
2830#[allow(
2831 clippy::unused_async,
2832 reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2833)]
2834async fn healthz() -> impl IntoResponse {
2835 axum::Json(serde_json::json!({
2836 "status": "ok",
2837 }))
2838}
2839
2840fn version_payload(name: &str, version: &str) -> serde_json::Value {
2847 serde_json::json!({
2848 "name": name,
2849 "version": version,
2850 "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2851 "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2852 "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2853 "mcpx_version": env!("CARGO_PKG_VERSION"),
2854 })
2855}
2856
2857fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2867 let value = version_payload(name, version);
2868 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2869}
2870
2871async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2872 let status = check().await;
2873 let ready = status
2874 .get("ready")
2875 .and_then(serde_json::Value::as_bool)
2876 .unwrap_or(false);
2877 let code = if ready {
2878 axum::http::StatusCode::OK
2879 } else {
2880 axum::http::StatusCode::SERVICE_UNAVAILABLE
2881 };
2882 (code, axum::Json(status))
2883}
2884
2885async fn shutdown_signal() {
2889 let ctrl_c = tokio::signal::ctrl_c();
2890
2891 #[cfg(unix)]
2892 {
2893 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2894 Ok(mut term) => {
2895 tokio::select! {
2898 _ = ctrl_c => {}
2899 _ = term.recv() => {}
2900 }
2901 }
2902 Err(e) => {
2903 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2904 ctrl_c.await.ok();
2905 }
2906 }
2907 }
2908
2909 #[cfg(not(unix))]
2910 {
2911 ctrl_c.await.ok();
2912 }
2913}
2914
#[cfg(feature = "metrics")]
/// Records per-request Prometheus metrics (request count and latency). It also
/// makes the metrics registry available to downstream handlers through request
/// extensions.
///
/// The label values come from client-controlled request data, so they are
/// normalised to keep the number of time series bounded:
/// * `path` is the matched route template ([`axum::extract::MatchedPath`]), or
///   `"unmatched"` when no route matched (e.g. fallback 404s). Using the raw
///   URI path would create a new series for every distinct probed URL.
/// * `method` keeps the standard HTTP methods and collapses any extension
///   method to `"OTHER"`.
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    mut req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::Method;

    const UNMATCHED_PATH: &str = "unmatched";
    const OTHER_METHOD: &str = "OTHER";

    let standard_methods = [
        Method::GET,
        Method::HEAD,
        Method::POST,
        Method::PUT,
        Method::DELETE,
        Method::CONNECT,
        Method::OPTIONS,
        Method::TRACE,
        Method::PATCH,
    ];
    let method = if standard_methods.contains(req.method()) {
        req.method().as_str().to_owned()
    } else {
        OTHER_METHOD.to_owned()
    };
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or_else(
            || UNMATCHED_PATH.to_owned(),
            |matched| matched.as_str().to_owned(),
        );
    let start = std::time::Instant::now();

    req.extensions_mut().insert(Arc::clone(&metrics));
    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2952
2953async fn security_headers_middleware(
2965 is_tls: bool,
2966 cfg: Arc<SecurityHeadersConfig>,
2967 req: Request<Body>,
2968 next: Next,
2969) -> axum::response::Response {
2970 use axum::http::{HeaderName, header};
2971
2972 let mut resp = next.run(req).await;
2973 let headers = resp.headers_mut();
2974
2975 headers.remove(header::SERVER);
2977 headers.remove(HeaderName::from_static("x-powered-by"));
2978
2979 apply_security_header(
2980 headers,
2981 header::X_CONTENT_TYPE_OPTIONS,
2982 cfg.x_content_type_options.as_deref(),
2983 "nosniff",
2984 );
2985 apply_security_header(
2986 headers,
2987 header::X_FRAME_OPTIONS,
2988 cfg.x_frame_options.as_deref(),
2989 "deny",
2990 );
2991 apply_security_header(
2992 headers,
2993 header::CACHE_CONTROL,
2994 cfg.cache_control.as_deref(),
2995 "no-store, max-age=0",
2996 );
2997 apply_security_header(
2998 headers,
2999 header::REFERRER_POLICY,
3000 cfg.referrer_policy.as_deref(),
3001 "no-referrer",
3002 );
3003 apply_security_header(
3004 headers,
3005 HeaderName::from_static("cross-origin-opener-policy"),
3006 cfg.cross_origin_opener_policy.as_deref(),
3007 "same-origin",
3008 );
3009 apply_security_header(
3010 headers,
3011 HeaderName::from_static("cross-origin-resource-policy"),
3012 cfg.cross_origin_resource_policy.as_deref(),
3013 "same-origin",
3014 );
3015 apply_security_header(
3016 headers,
3017 HeaderName::from_static("cross-origin-embedder-policy"),
3018 cfg.cross_origin_embedder_policy.as_deref(),
3019 "require-corp",
3020 );
3021 apply_security_header(
3022 headers,
3023 HeaderName::from_static("permissions-policy"),
3024 cfg.permissions_policy.as_deref(),
3025 "accelerometer=(), camera=(), geolocation=(), microphone=()",
3026 );
3027 apply_security_header(
3028 headers,
3029 HeaderName::from_static("x-permitted-cross-domain-policies"),
3030 cfg.x_permitted_cross_domain_policies.as_deref(),
3031 "none",
3032 );
3033 apply_security_header(
3034 headers,
3035 HeaderName::from_static("content-security-policy"),
3036 cfg.content_security_policy.as_deref(),
3037 "default-src 'none'; frame-ancestors 'none'",
3038 );
3039 apply_security_header(
3040 headers,
3041 HeaderName::from_static("x-dns-prefetch-control"),
3042 cfg.x_dns_prefetch_control.as_deref(),
3043 "off",
3044 );
3045
3046 if is_tls {
3047 apply_security_header(
3048 headers,
3049 header::STRICT_TRANSPORT_SECURITY,
3050 cfg.strict_transport_security.as_deref(),
3051 "max-age=63072000; includeSubDomains",
3052 );
3053 }
3054
3055 resp
3056}
3057
3058fn apply_security_header(
3069 headers: &mut axum::http::HeaderMap,
3070 name: axum::http::HeaderName,
3071 override_value: Option<&str>,
3072 default: &'static str,
3073) {
3074 use axum::http::HeaderValue;
3075
3076 match override_value {
3077 None => {
3078 headers.insert(name, HeaderValue::from_static(default));
3079 }
3080 Some("") => {
3081 }
3083 Some(v) => match HeaderValue::from_str(v) {
3084 Ok(hv) => {
3085 headers.insert(name, hv);
3086 }
3087 Err(err) => {
3088 tracing::error!(
3089 header = %name,
3090 error = %err,
3091 "invalid security header override reached middleware; using default"
3092 );
3093 headers.insert(name, HeaderValue::from_static(default));
3094 }
3095 },
3096 }
3097}
3098
3099fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
3110 use axum::http::HeaderValue;
3111
3112 let fields: &[(&str, Option<&str>)] = &[
3113 (
3114 "x_content_type_options",
3115 cfg.x_content_type_options.as_deref(),
3116 ),
3117 ("x_frame_options", cfg.x_frame_options.as_deref()),
3118 ("cache_control", cfg.cache_control.as_deref()),
3119 ("referrer_policy", cfg.referrer_policy.as_deref()),
3120 (
3121 "cross_origin_opener_policy",
3122 cfg.cross_origin_opener_policy.as_deref(),
3123 ),
3124 (
3125 "cross_origin_resource_policy",
3126 cfg.cross_origin_resource_policy.as_deref(),
3127 ),
3128 (
3129 "cross_origin_embedder_policy",
3130 cfg.cross_origin_embedder_policy.as_deref(),
3131 ),
3132 ("permissions_policy", cfg.permissions_policy.as_deref()),
3133 (
3134 "x_permitted_cross_domain_policies",
3135 cfg.x_permitted_cross_domain_policies.as_deref(),
3136 ),
3137 (
3138 "content_security_policy",
3139 cfg.content_security_policy.as_deref(),
3140 ),
3141 (
3142 "x_dns_prefetch_control",
3143 cfg.x_dns_prefetch_control.as_deref(),
3144 ),
3145 (
3146 "strict_transport_security",
3147 cfg.strict_transport_security.as_deref(),
3148 ),
3149 ];
3150
3151 for (field, value) in fields {
3152 let Some(v) = value else { continue };
3153 if v.is_empty() {
3154 continue;
3155 }
3156 if let Err(err) = HeaderValue::from_str(v) {
3157 return Err(McpxError::Config(format!(
3158 "invalid security_headers.{field}: {err}"
3159 )));
3160 }
3161 }
3162
3163 if let Some(v) = cfg.strict_transport_security.as_deref()
3164 && !v.is_empty()
3165 && v.to_ascii_lowercase().contains("preload")
3166 {
3167 return Err(McpxError::Config(format!(
3168 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
3169 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
3170 )));
3171 }
3172
3173 Ok(())
3174}
3175
/// Add anti-caching headers to responses passing through this layer.
///
/// Presumably mounted on the OAuth token endpoint, where RFC 6749 §5.1
/// requires token responses to be uncacheable. `Pragma: no-cache` replaces
/// any existing value. `Vary: Authorization` is *appended*, so any `Vary`
/// values set by the inner service are kept. `Cache-Control` is not touched
/// here.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;
    {
        let out = response.headers_mut();
        // Legacy HTTP/1.0 caches ignore Cache-Control; send Pragma as well.
        out.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
        // The response depends on the caller's credentials.
        out.append(header::VARY, HeaderValue::from_static("Authorization"));
    }
    response
}
3203
3204async fn normalize_peer_addr_middleware(
3233 resolver: Option<Arc<ForwardResolver>>,
3234 mut req: Request<Body>,
3235 next: Next,
3236) -> axum::response::Response {
3237 let direct = req
3238 .extensions()
3239 .get::<ConnectInfo<SocketAddr>>()
3240 .map(|ci| ci.0);
3241 let from_tls = req
3242 .extensions()
3243 .get::<ConnectInfo<TlsConnInfo>>()
3244 .map(|ci| ci.0.addr);
3245 if let Some(addr) = direct.or(from_tls) {
3246 if direct.is_none() {
3247 req.extensions_mut().insert(ConnectInfo(addr));
3248 }
3249 req.extensions_mut().insert(PeerAddr::new(addr));
3250 let client_ip = match &resolver {
3251 Some(r) => {
3252 crate::forwarded::resolve_client_ip(addr.ip(), req.headers(), &r.trusted, r.mode)
3253 .unwrap_or_else(|reason| {
3254 tracing::debug!(
3255 reason = ?reason,
3256 "forwarded-header resolution fell back to direct peer"
3257 );
3258 addr.ip()
3259 })
3260 }
3261 None => addr.ip(),
3262 };
3263 req.extensions_mut().insert(ClientIp::new(client_ip));
3264 }
3265 next.run(req).await
3266}
3267
3268fn parse_proxy_net(entry: &str) -> Option<ipnet::IpNet> {
3271 if let Ok(net) = entry.parse::<ipnet::IpNet>() {
3272 return Some(net);
3273 }
3274 entry.parse::<IpAddr>().ok().map(ipnet::IpNet::from)
3275}
3276
3277pub(crate) fn limiter_client_ip(extensions: &axum::http::Extensions) -> Option<IpAddr> {
3281 if let Some(client) = extensions.get::<ClientIp>() {
3282 return Some(client.ip);
3283 }
3284 extensions
3285 .get::<ConnectInfo<SocketAddr>>()
3286 .map(|ci| ci.0.ip())
3287 .or_else(|| {
3288 extensions
3289 .get::<ConnectInfo<TlsConnInfo>>()
3290 .map(|ci| ci.0.addr.ip())
3291 })
3292}
3293
/// Rate limiter for extra (application) routes, keyed by client IP address.
pub(crate) type ExtraRouteRateLimiter = BoundedKeyedLimiter<IpAddr>;

/// Key-count bound passed to `BoundedKeyedLimiter::new` when building the
/// extra-route limiter. Presumably the maximum number of distinct IPs tracked
/// at once; confirm against `BoundedKeyedLimiter`.
const EXTRA_ROUTE_MAX_TRACKED_KEYS: usize = 10_000;

/// Idle duration passed to `BoundedKeyedLimiter::new` when building the
/// extra-route limiter. Presumably keys idle longer than this may be evicted;
/// confirm against `BoundedKeyedLimiter`.
const EXTRA_ROUTE_IDLE_EVICTION: Duration = Duration::from_mins(15);
3309
3310fn build_extra_route_rate_limiter(
3317 per_minute: u32,
3318 burst: Option<u32>,
3319) -> Arc<ExtraRouteRateLimiter> {
3320 let rate = std::num::NonZeroU32::new(per_minute.max(1)).unwrap_or(std::num::NonZeroU32::MIN);
3321 let mut quota = governor::Quota::per_minute(rate);
3322 if let Some(b) = burst.and_then(std::num::NonZeroU32::new) {
3323 quota = quota.allow_burst(b);
3324 }
3325 Arc::new(BoundedKeyedLimiter::new(
3326 quota,
3327 EXTRA_ROUTE_MAX_TRACKED_KEYS,
3328 EXTRA_ROUTE_IDLE_EVICTION,
3329 ))
3330}
3331
3332async fn extra_route_rate_limit_middleware(
3354 limiter: Arc<ExtraRouteRateLimiter>,
3355 exempt: Arc<std::collections::HashSet<String>>,
3356 req: Request<Body>,
3357 next: Next,
3358) -> axum::response::Response {
3359 if exempt.contains(req.uri().path()) {
3360 return next.run(req).await;
3361 }
3362 let peer_ip: Option<IpAddr> = limiter_client_ip(req.extensions());
3363 if let Some(ip) = peer_ip
3364 && let Err(wait) = limiter.check_key_wait(&ip)
3365 {
3366 #[cfg(feature = "metrics")]
3367 crate::metrics::record_rate_limit_deny(req.extensions(), "extra_route");
3368 tracing::warn!(%ip, "extra route request rate limited");
3369 return McpxError::RateLimitedFor {
3370 message: "too many requests to application routes from this source".into(),
3371 retry_after: wait,
3372 }
3373 .into_response();
3374 }
3375 next.run(req).await
3376}
3377
3378async fn origin_check_middleware(
3382 allowed: Arc<[String]>,
3383 log_request_headers: bool,
3384 req: Request<Body>,
3385 next: Next,
3386) -> axum::response::Response {
3387 let method = req.method().clone();
3388 let path = req.uri().path().to_owned();
3389
3390 log_incoming_request(&method, &path, req.headers(), log_request_headers);
3391
3392 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
3393 let origin_str = origin.to_str().unwrap_or("");
3394 if !allowed.iter().any(|a| a == origin_str) {
3395 tracing::warn!(
3396 origin = origin_str,
3397 %method,
3398 %path,
3399 allowed = ?&*allowed,
3400 "rejected request: Origin not allowed"
3401 );
3402 return (
3403 axum::http::StatusCode::FORBIDDEN,
3404 "Forbidden: Origin not allowed",
3405 )
3406 .into_response();
3407 }
3408 }
3409 next.run(req).await
3410}
3411
3412fn log_incoming_request(
3415 method: &axum::http::Method,
3416 path: &str,
3417 headers: &axum::http::HeaderMap,
3418 log_request_headers: bool,
3419) {
3420 if log_request_headers {
3421 tracing::debug!(
3422 %method,
3423 %path,
3424 headers = %format_request_headers_for_log(headers),
3425 "incoming request"
3426 );
3427 } else {
3428 tracing::debug!(%method, %path, "incoming request");
3429 }
3430}
3431
3432fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
3433 headers
3434 .iter()
3435 .map(|(k, v)| {
3436 let name = k.as_str();
3437 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
3438 format!("{name}: [REDACTED]")
3439 } else {
3440 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
3441 }
3442 })
3443 .collect::<Vec<_>>()
3444 .join(", ")
3445}
3446
3447#[allow(
3471 clippy::cognitive_complexity,
3472 reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
3473)]
3474pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
3475where
3476 H: ServerHandler + 'static,
3477{
3478 use rmcp::ServiceExt as _;
3479
3480 tracing::info!("stdio transport: serving on stdin/stdout");
3481 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
3482
3483 let transport = rmcp::transport::io::stdio();
3484
3485 let service = handler
3486 .serve(transport)
3487 .await
3488 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
3489
3490 if let Err(e) = service.waiting().await {
3491 tracing::warn!(error = %e, "stdio session ended with error");
3492 }
3493 tracing::info!("stdio session ended");
3494 Ok(())
3495}
3496
3497#[cfg(test)]
3498mod tests {
3499 #![allow(
3500 clippy::unwrap_used,
3501 clippy::expect_used,
3502 clippy::panic,
3503 clippy::indexing_slicing,
3504 clippy::unwrap_in_result,
3505 clippy::print_stdout,
3506 clippy::print_stderr,
3507 deprecated,
3508 reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
3509 )]
3510 use std::{sync::Arc, time::Duration};
3511
3512 use axum::{
3513 body::Body,
3514 http::{Request, StatusCode, header},
3515 response::IntoResponse,
3516 };
3517 use http_body_util::BodyExt;
3518 use tower::ServiceExt as _;
3519
3520 use super::*;
3521
3522 #[test]
3525 fn server_config_new_defaults() {
3526 let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
3527 assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
3528 assert_eq!(cfg.name, "test-server");
3529 assert_eq!(cfg.version, "1.0.0");
3530 assert!(cfg.tls_cert_path.is_none());
3531 assert!(cfg.tls_key_path.is_none());
3532 assert!(cfg.auth.is_none());
3533 assert!(cfg.rbac.is_none());
3534 assert!(cfg.allowed_origins.is_empty());
3535 assert!(cfg.tool_rate_limit.is_none());
3536 assert!(cfg.readiness_check.is_none());
3537 assert_eq!(cfg.max_request_body, 1024 * 1024);
3538 assert_eq!(cfg.request_timeout, Duration::from_mins(2));
3539 assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
3540 assert!(!cfg.log_request_headers);
3541 assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
3542 assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
3543 }
3544
3545 #[test]
3546 fn tls_handshake_builders_set_fields() {
3547 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3548 .with_tls_handshake_timeout(Duration::from_secs(3))
3549 .with_max_concurrent_tls_handshakes(64);
3550 assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
3551 assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
3552 }
3553
3554 #[test]
3555 fn validate_rejects_zero_tls_handshake_timeout() {
3556 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3557 .with_tls_handshake_timeout(Duration::ZERO);
3558 let err = cfg.validate().expect_err("zero handshake timeout");
3559 assert!(err.to_string().contains("tls_handshake_timeout"));
3560 }
3561
3562 #[test]
3563 fn validate_rejects_zero_max_concurrent_tls_handshakes() {
3564 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3565 .with_max_concurrent_tls_handshakes(0);
3566 let err = cfg.validate().expect_err("zero handshake concurrency");
3567 assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
3568 }
3569
3570 #[test]
3571 fn validate_consumes_and_proves() {
3572 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3574 let validated = cfg.validate().expect("valid config");
3575 assert_eq!(validated.as_inner().name, "test-server");
3577 let raw = validated.into_inner();
3579 assert_eq!(raw.name, "test-server");
3580
3581 let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3583 bad.max_request_body = 0;
3584 assert!(bad.validate().is_err(), "zero body cap must fail validate");
3585 }
3586
3587 #[test]
3588 fn validate_rejects_zero_max_concurrent_requests() {
3589 let cfg =
3590 McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
3591 let err = cfg.validate().expect_err("zero concurrency cap must fail");
3592 assert!(
3593 format!("{err}").contains("max_concurrent_requests"),
3594 "error should mention max_concurrent_requests, got: {err}"
3595 );
3596 }
3597
3598 #[test]
3599 fn validate_rejects_zero_max_tracked_keys() {
3600 let rl = crate::auth::RateLimitConfig {
3603 max_attempts_per_minute: 30,
3604 pre_auth_max_per_minute: None,
3605 max_tracked_keys: 0,
3606 idle_eviction: Duration::from_secs(15 * 60),
3607 burst: None,
3608 pre_auth_burst: None,
3609 };
3610 let auth_cfg = AuthConfig {
3611 enabled: true,
3612 api_keys: Vec::new(),
3613 mtls: None,
3614 rate_limit: Some(rl),
3615 #[cfg(feature = "oauth")]
3616 oauth: None,
3617 };
3618 let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
3619 let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
3620 assert!(
3621 format!("{err}").contains("max_tracked_keys"),
3622 "error should mention max_tracked_keys, got: {err}"
3623 );
3624 }
3625
3626 #[test]
3627 fn derive_allowed_hosts_includes_public_host() {
3628 let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
3629 assert!(
3630 hosts.iter().any(|h| h == "mcp.example.com"),
3631 "public_url host must be allowed"
3632 );
3633 }
3634
3635 #[test]
3636 fn derive_allowed_hosts_includes_bind_authority() {
3637 let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
3638 assert!(
3639 hosts.iter().any(|h| h == "127.0.0.1"),
3640 "bind host must be allowed"
3641 );
3642 assert!(
3643 hosts.iter().any(|h| h == "127.0.0.1:8080"),
3644 "bind authority must be allowed"
3645 );
3646 }
3647
3648 #[tokio::test]
3651 async fn healthz_returns_ok_json() {
3652 let resp = healthz().await.into_response();
3653 assert_eq!(resp.status(), StatusCode::OK);
3654 let body = resp.into_body().collect().await.unwrap().to_bytes();
3655 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3656 assert_eq!(json["status"], "ok");
3657 assert!(
3658 json.get("name").is_none(),
3659 "healthz must not expose server name"
3660 );
3661 assert!(
3662 json.get("version").is_none(),
3663 "healthz must not expose version"
3664 );
3665 }
3666
3667 #[tokio::test]
3670 async fn readyz_returns_ok_when_ready() {
3671 let check: ReadinessCheck =
3672 Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
3673 let resp = readyz(check).await.into_response();
3674 assert_eq!(resp.status(), StatusCode::OK);
3675 let body = resp.into_body().collect().await.unwrap().to_bytes();
3676 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3677 assert_eq!(json["ready"], true);
3678 assert!(
3679 json.get("name").is_none(),
3680 "readyz must not expose server name"
3681 );
3682 assert!(
3683 json.get("version").is_none(),
3684 "readyz must not expose version"
3685 );
3686 assert_eq!(json["db"], "connected");
3687 }
3688
3689 #[tokio::test]
3690 async fn readyz_returns_503_when_not_ready() {
3691 let check: ReadinessCheck =
3692 Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
3693 let resp = readyz(check).await.into_response();
3694 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3695 }
3696
3697 #[tokio::test]
3698 async fn readyz_returns_503_when_ready_missing() {
3699 let check: ReadinessCheck =
3700 Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
3701 let resp = readyz(check).await.into_response();
3702 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3704 }
3705
3706 fn peer_probe_router() -> axum::Router {
3711 async fn probe(req: Request<Body>) -> String {
3712 let ci = req
3713 .extensions()
3714 .get::<ConnectInfo<SocketAddr>>()
3715 .map(|c| c.0.to_string())
3716 .unwrap_or_default();
3717 let pa = req
3718 .extensions()
3719 .get::<PeerAddr>()
3720 .map(|p| p.addr.to_string())
3721 .unwrap_or_default();
3722 format!("{ci}|{pa}")
3723 }
3724 axum::Router::new()
3725 .route("/probe", axum::routing::get(probe))
3726 .layer(axum::middleware::from_fn(|req, next| {
3727 normalize_peer_addr_middleware(None, req, next)
3728 }))
3729 }
3730
3731 async fn body_string(resp: axum::response::Response) -> String {
3732 let bytes = resp.into_body().collect().await.unwrap().to_bytes();
3733 String::from_utf8(bytes.to_vec()).unwrap()
3734 }
3735
3736 #[tokio::test]
3737 async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
3738 let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
3741 let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
3742 let req = Request::builder()
3743 .uri("/probe")
3744 .extension(ConnectInfo(plain))
3745 .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3746 .body(Body::empty())
3747 .unwrap();
3748 let resp = peer_probe_router().oneshot(req).await.unwrap();
3749 assert_eq!(resp.status(), StatusCode::OK);
3750 assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
3751 }
3752
3753 #[tokio::test]
3754 async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
3755 let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
3756 let req = Request::builder()
3757 .uri("/probe")
3758 .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3759 .body(Body::empty())
3760 .unwrap();
3761 let resp = peer_probe_router().oneshot(req).await.unwrap();
3762 assert_eq!(resp.status(), StatusCode::OK);
3763 assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
3764 }
3765
3766 #[tokio::test]
3767 async fn normalize_no_op_without_any_connect_info() {
3768 let req = Request::builder()
3769 .uri("/probe")
3770 .body(Body::empty())
3771 .unwrap();
3772 let resp = peer_probe_router().oneshot(req).await.unwrap();
3773 assert_eq!(resp.status(), StatusCode::OK);
3774 assert_eq!(body_string(resp).await, "|");
3775 }
3776
3777 #[tokio::test]
3778 async fn peer_addr_extractor_rejects_when_absent() {
3779 async fn h(peer: PeerAddr) -> String {
3780 peer.addr.to_string()
3781 }
3782 let app = axum::Router::new().route("/p", axum::routing::get(h));
3783 let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
3784 let resp = app.oneshot(req).await.unwrap();
3785 assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
3786 }
3787
3788 #[tokio::test]
3789 async fn peer_addr_extractor_returns_value_when_present() {
3790 async fn h(peer: PeerAddr) -> String {
3791 peer.addr.to_string()
3792 }
3793 let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
3794 let app = axum::Router::new().route("/p", axum::routing::get(h));
3795 let req = Request::builder()
3796 .uri("/p")
3797 .extension(PeerAddr::new(addr))
3798 .body(Body::empty())
3799 .unwrap();
3800 let resp = app.oneshot(req).await.unwrap();
3801 assert_eq!(resp.status(), StatusCode::OK);
3802 assert_eq!(body_string(resp).await, addr.to_string());
3803 }
3804
3805 #[tokio::test]
3806 async fn peer_addr_via_extension_extractor() {
3807 async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
3808 peer.addr.to_string()
3809 }
3810 let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
3811 let app = axum::Router::new().route("/p", axum::routing::get(h));
3812 let req = Request::builder()
3813 .uri("/p")
3814 .extension(PeerAddr::new(addr))
3815 .body(Body::empty())
3816 .unwrap();
3817 let resp = app.oneshot(req).await.unwrap();
3818 assert_eq!(resp.status(), StatusCode::OK);
3819 assert_eq!(body_string(resp).await, addr.to_string());
3820 }
3821
3822 fn limited_router(per_minute: u32) -> axum::Router {
3827 limited_router_with_burst(per_minute, None)
3828 }
3829
3830 fn limited_router_with_burst(per_minute: u32, burst: Option<u32>) -> axum::Router {
3832 limited_router_full(per_minute, burst, &[])
3833 }
3834
3835 fn limited_router_full(
3839 per_minute: u32,
3840 burst: Option<u32>,
3841 exempt_paths: &[&str],
3842 ) -> axum::Router {
3843 let limiter = build_extra_route_rate_limiter(per_minute, burst);
3844 let exempt: Arc<std::collections::HashSet<String>> =
3845 Arc::new(exempt_paths.iter().map(|s| (*s).to_owned()).collect());
3846 axum::Router::new()
3847 .route("/limited", axum::routing::get(|| async { "ok" }))
3848 .route("/exempt", axum::routing::get(|| async { "ok" }))
3849 .layer(axum::middleware::from_fn(move |req, next| {
3850 let l = Arc::clone(&limiter);
3851 let e = Arc::clone(&exempt);
3852 extra_route_rate_limit_middleware(l, e, req, next)
3853 }))
3854 }
3855
3856 fn limited_req(ip: &str) -> Request<Body> {
3857 limited_req_to(ip, "/limited")
3858 }
3859
3860 fn limited_req_to(ip: &str, path: &str) -> Request<Body> {
3861 let addr: SocketAddr = format!("{ip}:40000").parse().unwrap();
3862 Request::builder()
3863 .uri(path)
3864 .extension(ConnectInfo(addr))
3865 .body(Body::empty())
3866 .unwrap()
3867 }
3868
3869 #[tokio::test]
3870 async fn extra_route_limiter_denies_over_quota() {
3871 let app = limited_router(2);
3872 for i in 0..2 {
3873 let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3874 assert_eq!(resp.status(), StatusCode::OK, "request {i} should pass");
3875 }
3876 let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3877 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3878 let body = body_string(resp).await;
3879 assert!(
3880 body.contains("too many requests to application routes"),
3881 "deny body should match the limiter message, got: {body}"
3882 );
3883 }
3884
3885 #[tokio::test]
3886 async fn extra_route_limiter_isolates_keys() {
3887 let app = limited_router(2);
3888 for _ in 0..2 {
3889 let resp = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3890 assert_eq!(resp.status(), StatusCode::OK);
3891 }
3892 let exhausted = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3893 assert_eq!(exhausted.status(), StatusCode::TOO_MANY_REQUESTS);
3894 let other = app.clone().oneshot(limited_req("10.3.3.3")).await.unwrap();
3896 assert_eq!(other.status(), StatusCode::OK);
3897 }
3898
3899 #[tokio::test]
3900 async fn extra_route_limiter_fails_open_without_peer() {
3901 let app = limited_router(1);
3902 for i in 0..3 {
3903 let req = Request::builder()
3904 .uri("/limited")
3905 .body(Body::empty())
3906 .unwrap();
3907 let resp = app.clone().oneshot(req).await.unwrap();
3908 assert_eq!(
3909 resp.status(),
3910 StatusCode::OK,
3911 "request {i} should fail open"
3912 );
3913 }
3914 }
3915
3916 #[tokio::test]
3917 async fn extra_route_limiter_extracts_tls_conn_info() {
3918 let app = limited_router(2);
3919 let mk = || {
3920 let addr: SocketAddr = "192.168.9.9:55555".parse().unwrap();
3921 Request::builder()
3922 .uri("/limited")
3923 .extension(ConnectInfo(TlsConnInfo::new(addr, None)))
3924 .body(Body::empty())
3925 .unwrap()
3926 };
3927 for _ in 0..2 {
3928 assert_eq!(
3929 app.clone().oneshot(mk()).await.unwrap().status(),
3930 StatusCode::OK
3931 );
3932 }
3933 let resp = app.clone().oneshot(mk()).await.unwrap();
3934 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3935 }
3936
3937 #[tokio::test]
3938 async fn extra_route_limiter_exempt_path_bypasses_quota() {
3939 let app = limited_router_full(1, None, &["/exempt"]);
3942 for i in 0..5 {
3943 let resp = app
3944 .clone()
3945 .oneshot(limited_req_to("10.6.6.6", "/exempt"))
3946 .await
3947 .unwrap();
3948 assert_eq!(resp.status(), StatusCode::OK, "exempt request {i}");
3949 }
3950 let resp = app.clone().oneshot(limited_req("10.6.6.6")).await.unwrap();
3952 assert_eq!(resp.status(), StatusCode::OK);
3953 let resp = app.clone().oneshot(limited_req("10.6.6.6")).await.unwrap();
3955 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3956 }
3957
3958 #[tokio::test]
3959 async fn extra_route_limiter_exemption_is_raw_exact_match() {
3960 let app = limited_router_full(1, None, &["/exempt"]);
3963 let ok = app
3964 .clone()
3965 .oneshot(limited_req_to("10.7.7.7", "/exempt/"))
3966 .await
3967 .unwrap();
3968 assert_eq!(
3969 ok.status(),
3970 StatusCode::NOT_FOUND,
3971 "variant path routes 404"
3972 );
3973 let denied = app
3975 .clone()
3976 .oneshot(limited_req_to("10.7.7.7", "/limited"))
3977 .await
3978 .unwrap();
3979 assert_eq!(denied.status(), StatusCode::TOO_MANY_REQUESTS);
3980 }
3981
3982 #[cfg(feature = "metrics")]
3983 #[tokio::test]
3984 async fn extra_route_limiter_deny_increments_counter_exempt_does_not() {
3985 let metrics = Arc::new(crate::metrics::McpMetrics::new().unwrap());
3986 let app = limited_router_full(1, None, &["/exempt"]);
3987 let mk = |path: &str| {
3988 let addr: SocketAddr = "10.8.8.8:40000".parse().unwrap();
3989 Request::builder()
3990 .uri(path)
3991 .extension(ConnectInfo(addr))
3992 .extension(Arc::clone(&metrics))
3993 .body(Body::empty())
3994 .unwrap()
3995 };
3996 let counter = || {
3997 metrics
3998 .rate_limited_total
3999 .with_label_values(&["extra_route"])
4000 .get()
4001 };
4002 for _ in 0..3 {
4004 assert_eq!(
4005 app.clone().oneshot(mk("/exempt")).await.unwrap().status(),
4006 StatusCode::OK
4007 );
4008 }
4009 assert_eq!(counter(), 0, "exempt requests must not count as denies");
4010 assert_eq!(
4012 app.clone().oneshot(mk("/limited")).await.unwrap().status(),
4013 StatusCode::OK
4014 );
4015 assert_eq!(counter(), 0);
4016 assert_eq!(
4017 app.clone().oneshot(mk("/limited")).await.unwrap().status(),
4018 StatusCode::TOO_MANY_REQUESTS
4019 );
4020 assert_eq!(counter(), 1, "deny must increment the extra_route label");
4021 }
4022
4023 #[test]
4024 fn validate_rejects_exempt_paths_without_base_knob() {
4025 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4026 .with_extra_route_rate_limit_exempt_paths(["/ok"]);
4027 let err = cfg.validate().expect_err("exempt paths without rate limit");
4028 assert!(err.to_string().contains("requires extra_route_rate_limit"));
4029 }
4030
4031 #[test]
4032 fn validate_rejects_malformed_exempt_paths() {
4033 for bad in ["", "no-slash"] {
4034 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4035 .with_extra_route_rate_limit(10)
4036 .with_extra_route_rate_limit_exempt_paths([bad]);
4037 let err = cfg.validate().expect_err("malformed exempt path");
4038 assert!(
4039 err.to_string()
4040 .contains("must be non-empty and start with '/'"),
4041 "entry {bad:?}: {err}"
4042 );
4043 }
4044 }
4045
4046 #[test]
4047 fn validate_accepts_wellformed_exempt_paths() {
4048 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4049 .with_extra_route_rate_limit(10)
4050 .with_extra_route_rate_limit_exempt_paths(["/.well-known/oauth-authorization-server"]);
4051 assert!(cfg.validate().is_ok());
4052 }
4053
4054 #[test]
4055 fn validate_rejects_zero_extra_route_rate_limit() {
4056 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4057 .with_extra_route_rate_limit(0);
4058 let err = cfg.validate().expect_err("zero extra route rate limit");
4059 assert!(err.to_string().contains("extra_route_rate_limit"));
4060 }
4061
4062 #[tokio::test]
4063 async fn extra_route_limiter_burst_allows_initial_spike() {
4064 let app = limited_router_with_burst(1, Some(3));
4065 for i in 0..3 {
4066 let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
4067 assert_eq!(resp.status(), StatusCode::OK, "burst request {i}");
4068 }
4069 let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
4070 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
4071 }
4072
4073 #[tokio::test]
4074 async fn extra_route_limiter_deny_sets_retry_after() {
4075 let app = limited_router(1);
4076 let ok = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
4077 assert_eq!(ok.status(), StatusCode::OK);
4078 let denied = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
4079 assert_eq!(denied.status(), StatusCode::TOO_MANY_REQUESTS);
4080 let retry_after = denied
4081 .headers()
4082 .get(header::RETRY_AFTER)
4083 .expect("Retry-After present")
4084 .to_str()
4085 .unwrap()
4086 .parse::<u64>()
4087 .unwrap();
4088 assert!(retry_after >= 1, "delta-seconds must be >= 1");
4089 }
4090
4091 #[test]
4092 fn validate_rejects_zero_burst_knobs() {
4093 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4094 .with_tool_rate_limit(10)
4095 .with_tool_rate_limit_burst(0)
4096 .validate()
4097 .expect_err("zero tool burst");
4098 assert!(err.to_string().contains("tool_rate_limit_burst"));
4099
4100 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4101 .with_extra_route_rate_limit(10)
4102 .with_extra_route_rate_limit_burst(0)
4103 .validate()
4104 .expect_err("zero extra route burst");
4105 assert!(err.to_string().contains("extra_route_rate_limit_burst"));
4106 }
4107
4108 #[test]
4109 fn validate_rejects_orphan_burst_knobs() {
4110 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4111 .with_tool_rate_limit_burst(5)
4112 .validate()
4113 .expect_err("orphan tool burst");
4114 assert!(err.to_string().contains("requires tool_rate_limit"));
4115
4116 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4117 .with_extra_route_rate_limit_burst(5)
4118 .validate()
4119 .expect_err("orphan extra route burst");
4120 assert!(err.to_string().contains("requires extra_route_rate_limit"));
4121 }
4122
4123 #[test]
4124 fn validate_rejects_zero_auth_bursts() {
4125 let auth = AuthConfig::with_keys(vec![])
4126 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_burst(0));
4127 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4128 .with_auth(auth)
4129 .validate()
4130 .expect_err("zero auth burst");
4131 assert!(err.to_string().contains("rate_limit.burst"));
4132
4133 let auth = AuthConfig::with_keys(vec![])
4134 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0));
4135 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4136 .with_auth(auth)
4137 .validate()
4138 .expect_err("zero pre-auth burst");
4139 assert!(err.to_string().contains("pre_auth_burst"));
4140 }
4141
4142 #[test]
4145 fn validate_accepts_pre_auth_burst_without_explicit_pre_auth_rate() {
4146 let auth = AuthConfig::with_keys(vec![])
4147 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(50));
4148 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_auth(auth);
4149 assert!(cfg.validate().is_ok(), "pre_auth_burst has no orphan rule");
4150 }
4151
4152 fn forward_resolver(trusted: &[&str], mode: ForwardedHeaderMode) -> Arc<ForwardResolver> {
4155 Arc::new(ForwardResolver {
4156 trusted: trusted.iter().map(|s| s.parse().unwrap()).collect(),
4157 mode,
4158 })
4159 }
4160
4161 fn forwarded_probe_router(resolver: Option<Arc<ForwardResolver>>) -> axum::Router {
4163 async fn probe(req: Request<Body>) -> String {
4164 let pa = req
4165 .extensions()
4166 .get::<PeerAddr>()
4167 .map(|p| p.addr.ip().to_string())
4168 .unwrap_or_default();
4169 let ci = req
4170 .extensions()
4171 .get::<ClientIp>()
4172 .map(|c| c.ip.to_string())
4173 .unwrap_or_default();
4174 format!("{pa}|{ci}")
4175 }
4176 axum::Router::new()
4177 .route("/probe", axum::routing::get(probe))
4178 .layer(axum::middleware::from_fn(move |req, next| {
4179 let r = resolver.clone();
4180 normalize_peer_addr_middleware(r, req, next)
4181 }))
4182 }
4183
4184 fn probe_req(peer: &str, header: Option<(&str, &str)>) -> Request<Body> {
4185 let addr: SocketAddr = peer.parse().unwrap();
4186 let mut builder = Request::builder()
4187 .uri("/probe")
4188 .extension(ConnectInfo(addr));
4189 if let Some((name, value)) = header {
4190 builder = builder.header(name, value);
4191 }
4192 builder.body(Body::empty()).unwrap()
4193 }
4194
4195 #[tokio::test]
4196 async fn client_ip_equals_direct_without_resolver() {
4197 let app = forwarded_probe_router(None);
4198 let resp = app
4199 .oneshot(probe_req(
4200 "10.1.2.3:4444",
4201 Some(("x-forwarded-for", "203.0.113.7")),
4202 ))
4203 .await
4204 .unwrap();
4205 assert_eq!(
4206 body_string(resp).await,
4207 "10.1.2.3|10.1.2.3",
4208 "feature off: header ignored, ClientIp == direct"
4209 );
4210 }
4211
4212 #[tokio::test]
4213 async fn client_ip_resolved_for_trusted_peer() {
4214 let app = forwarded_probe_router(Some(forward_resolver(
4215 &["10.0.0.0/8"],
4216 ForwardedHeaderMode::XForwardedFor,
4217 )));
4218 let resp = app
4219 .oneshot(probe_req(
4220 "10.0.0.1:9999",
4221 Some(("x-forwarded-for", "203.0.113.7")),
4222 ))
4223 .await
4224 .unwrap();
4225 assert_eq!(
4226 body_string(resp).await,
4227 "10.0.0.1|203.0.113.7",
4228 "PeerAddr stays direct while ClientIp resolves"
4229 );
4230 }
4231
4232 #[tokio::test]
4233 async fn client_ip_falls_back_to_direct_on_malformed_header() {
4234 let app = forwarded_probe_router(Some(forward_resolver(
4235 &["10.0.0.0/8"],
4236 ForwardedHeaderMode::XForwardedFor,
4237 )));
4238 let resp = app
4239 .oneshot(probe_req(
4240 "10.0.0.1:9999",
4241 Some(("x-forwarded-for", "not-an-ip")),
4242 ))
4243 .await
4244 .unwrap();
4245 assert_eq!(
4246 body_string(resp).await,
4247 "10.0.0.1|10.0.0.1",
4248 "malformed chain falls back to the direct peer"
4249 );
4250 }
4251
4252 #[test]
4253 fn forwarded_header_mode_deserializes_kebab_case() {
4254 #[derive(serde::Deserialize)]
4255 struct Wrapper {
4256 mode: ForwardedHeaderMode,
4257 }
4258 let w: Wrapper = toml::from_str(r#"mode = "x-forwarded-for""#).unwrap();
4259 assert_eq!(w.mode, ForwardedHeaderMode::XForwardedFor);
4260 let w: Wrapper = toml::from_str(r#"mode = "forwarded""#).unwrap();
4261 assert_eq!(w.mode, ForwardedHeaderMode::Forwarded);
4262 assert!(
4263 toml::from_str::<Wrapper>(r#"mode = "XForwardedFor""#).is_err(),
4264 "PascalCase wire value must be rejected"
4265 );
4266 }
4267
4268 #[test]
4269 fn validate_rejects_bad_trusted_proxy_entry() {
4270 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4271 .with_trusted_proxies(["not-a-cidr"]);
4272 let err = cfg.validate().expect_err("bad CIDR");
4273 assert!(err.to_string().contains("trusted_proxies"));
4274 }
4275
4276 #[test]
4277 fn validate_accepts_cidr_and_bare_ip_proxy_entries() {
4278 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_trusted_proxies([
4279 "10.0.0.0/8",
4280 "192.0.2.1",
4281 "2001:db8::1",
4282 ]);
4283 assert!(cfg.validate().is_ok(), "CIDRs and bare IPs are accepted");
4284 }
4285
4286 #[test]
4287 fn validate_rejects_forwarded_header_without_proxies() {
4288 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4289 .with_forwarded_header(ForwardedHeaderMode::Forwarded);
4290 let err = cfg.validate().expect_err("mode without proxies");
4291 assert!(err.to_string().contains("requires trusted_proxies"));
4292 }
4293
4294 fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
4298 let allowed: Arc<[String]> = Arc::from(origins);
4299 axum::Router::new()
4300 .route("/test", axum::routing::get(|| async { "ok" }))
4301 .layer(axum::middleware::from_fn(move |req, next| {
4302 let a = Arc::clone(&allowed);
4303 origin_check_middleware(a, log_request_headers, req, next)
4304 }))
4305 }
4306
4307 #[tokio::test]
4308 async fn origin_allowed_passes() {
4309 let app = origin_router(vec!["http://localhost:3000".into()], false);
4310 let req = Request::builder()
4311 .uri("/test")
4312 .header(header::ORIGIN, "http://localhost:3000")
4313 .body(Body::empty())
4314 .unwrap();
4315 let resp = app.oneshot(req).await.unwrap();
4316 assert_eq!(resp.status(), StatusCode::OK);
4317 }
4318
4319 #[tokio::test]
4320 async fn origin_rejected_returns_403() {
4321 let app = origin_router(vec!["http://localhost:3000".into()], false);
4322 let req = Request::builder()
4323 .uri("/test")
4324 .header(header::ORIGIN, "http://evil.com")
4325 .body(Body::empty())
4326 .unwrap();
4327 let resp = app.oneshot(req).await.unwrap();
4328 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
4329 }
4330
4331 #[tokio::test]
4332 async fn no_origin_header_passes() {
4333 let app = origin_router(vec!["http://localhost:3000".into()], false);
4334 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4335 let resp = app.oneshot(req).await.unwrap();
4336 assert_eq!(resp.status(), StatusCode::OK);
4337 }
4338
4339 #[tokio::test]
4340 async fn empty_allowlist_rejects_any_origin() {
4341 let app = origin_router(vec![], false);
4342 let req = Request::builder()
4343 .uri("/test")
4344 .header(header::ORIGIN, "http://anything.com")
4345 .body(Body::empty())
4346 .unwrap();
4347 let resp = app.oneshot(req).await.unwrap();
4348 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
4349 }
4350
4351 #[tokio::test]
4352 async fn empty_allowlist_passes_without_origin() {
4353 let app = origin_router(vec![], false);
4354 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4355 let resp = app.oneshot(req).await.unwrap();
4356 assert_eq!(resp.status(), StatusCode::OK);
4357 }
4358
4359 #[test]
4360 fn format_request_headers_redacts_sensitive_values() {
4361 let mut headers = axum::http::HeaderMap::new();
4362 headers.insert("authorization", "Bearer secret-token".parse().unwrap());
4363 headers.insert("cookie", "sid=abc".parse().unwrap());
4364 headers.insert("x-request-id", "req-123".parse().unwrap());
4365
4366 let out = format_request_headers_for_log(&headers);
4367 assert!(out.contains("authorization: [REDACTED]"));
4368 assert!(out.contains("cookie: [REDACTED]"));
4369 assert!(out.contains("x-request-id: req-123"));
4370 assert!(!out.contains("secret-token"));
4371 }
4372
4373 fn security_router(is_tls: bool) -> axum::Router {
4376 security_router_with(is_tls, SecurityHeadersConfig::default())
4377 }
4378
4379 fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
4380 let cfg = Arc::new(cfg);
4381 axum::Router::new()
4382 .route("/test", axum::routing::get(|| async { "ok" }))
4383 .layer(axum::middleware::from_fn(move |req, next| {
4384 let c = Arc::clone(&cfg);
4385 security_headers_middleware(is_tls, c, req, next)
4386 }))
4387 }
4388
4389 #[tokio::test]
4390 async fn security_headers_set_on_response() {
4391 let app = security_router(false);
4392 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4393 let resp = app.oneshot(req).await.unwrap();
4394 assert_eq!(resp.status(), StatusCode::OK);
4395
4396 let h = resp.headers();
4397 assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
4398 assert_eq!(h.get("x-frame-options").unwrap(), "deny");
4399 assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
4400 assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
4401 assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
4402 assert_eq!(
4403 h.get("cross-origin-resource-policy").unwrap(),
4404 "same-origin"
4405 );
4406 assert_eq!(
4407 h.get("cross-origin-embedder-policy").unwrap(),
4408 "require-corp"
4409 );
4410 assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
4411 assert!(
4412 h.get("permissions-policy")
4413 .unwrap()
4414 .to_str()
4415 .unwrap()
4416 .contains("camera=()"),
4417 "permissions-policy must restrict browser features"
4418 );
4419 assert_eq!(
4420 h.get("content-security-policy").unwrap(),
4421 "default-src 'none'; frame-ancestors 'none'"
4422 );
4423 assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
4424 assert!(h.get("strict-transport-security").is_none());
4426 }
4427
4428 #[tokio::test]
4429 async fn hsts_set_when_tls_enabled() {
4430 let app = security_router(true);
4431 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4432 let resp = app.oneshot(req).await.unwrap();
4433
4434 let hsts = resp.headers().get("strict-transport-security").unwrap();
4435 assert!(
4436 hsts.to_str().unwrap().contains("max-age=63072000"),
4437 "HSTS must set 2-year max-age"
4438 );
4439 }
4440
4441 fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
4447 let cfg =
4448 McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
4449 cfg.check()
4450 }
4451
4452 #[test]
4453 fn security_headers_config_default_validates() {
4454 check_with_security_headers(SecurityHeadersConfig::default())
4455 .expect("default SecurityHeadersConfig must validate");
4456 }
4457
4458 #[test]
4459 fn security_headers_config_validate_accepts_empty_string() {
4460 let h = SecurityHeadersConfig {
4462 x_content_type_options: Some(String::new()),
4463 x_frame_options: Some(String::new()),
4464 cache_control: Some(String::new()),
4465 referrer_policy: Some(String::new()),
4466 cross_origin_opener_policy: Some(String::new()),
4467 cross_origin_resource_policy: Some(String::new()),
4468 cross_origin_embedder_policy: Some(String::new()),
4469 permissions_policy: Some(String::new()),
4470 x_permitted_cross_domain_policies: Some(String::new()),
4471 content_security_policy: Some(String::new()),
4472 x_dns_prefetch_control: Some(String::new()),
4473 strict_transport_security: Some(String::new()),
4474 };
4475 check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
4476 }
4477
4478 #[test]
4479 fn security_headers_config_validate_rejects_bad_value() {
4480 let h = SecurityHeadersConfig {
4482 referrer_policy: Some("\u{0007}".into()),
4483 ..SecurityHeadersConfig::default()
4484 };
4485 let err = check_with_security_headers(h)
4486 .expect_err("control char in referrer_policy must reject");
4487 let msg = err.to_string();
4488 assert!(
4489 msg.contains("referrer_policy"),
4490 "error must name the offending field, got: {msg}"
4491 );
4492 }
4493
4494 #[test]
4495 fn security_headers_config_validate_rejects_hsts_preload() {
4496 let h = SecurityHeadersConfig {
4497 strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
4498 ..SecurityHeadersConfig::default()
4499 };
4500 let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
4501 let msg = err.to_string();
4502 assert!(
4503 msg.contains("strict_transport_security"),
4504 "error must name the field, got: {msg}"
4505 );
4506 assert!(
4507 msg.to_lowercase().contains("preload"),
4508 "error must mention `preload`, got: {msg}"
4509 );
4510 }
4511
4512 #[test]
4513 fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
4514 let h = SecurityHeadersConfig {
4516 strict_transport_security: Some("max-age=600; PRELOAD".into()),
4517 ..SecurityHeadersConfig::default()
4518 };
4519 check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
4520 }
4521
4522 #[tokio::test]
4523 async fn security_headers_override_honored() {
4524 let h = SecurityHeadersConfig {
4526 x_frame_options: Some("SAMEORIGIN".into()),
4527 ..SecurityHeadersConfig::default()
4528 };
4529 let app = security_router_with(false, h);
4530 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4531 let resp = app.oneshot(req).await.unwrap();
4532 assert_eq!(resp.status(), StatusCode::OK);
4533
4534 let xfo = resp.headers().get("x-frame-options").unwrap();
4535 assert_eq!(xfo, "SAMEORIGIN");
4536 }
4537
4538 #[tokio::test]
4539 async fn security_headers_empty_string_omits() {
4540 let h = SecurityHeadersConfig {
4542 referrer_policy: Some(String::new()),
4543 ..SecurityHeadersConfig::default()
4544 };
4545 let app = security_router_with(false, h);
4546 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4547 let resp = app.oneshot(req).await.unwrap();
4548 assert_eq!(resp.status(), StatusCode::OK);
4549
4550 assert!(
4551 resp.headers().get("referrer-policy").is_none(),
4552 "Some(\"\") must omit the header"
4553 );
4554 assert_eq!(
4556 resp.headers().get("x-content-type-options").unwrap(),
4557 "nosniff"
4558 );
4559 }
4560
4561 #[tokio::test]
4562 async fn security_headers_hsts_only_when_tls() {
4563 let h = SecurityHeadersConfig {
4565 strict_transport_security: Some("max-age=600".into()),
4566 ..SecurityHeadersConfig::default()
4567 };
4568 let app = security_router_with(false, h);
4569 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4570 let resp = app.oneshot(req).await.unwrap();
4571 assert!(
4572 resp.headers().get("strict-transport-security").is_none(),
4573 "HSTS must remain absent on plaintext deployments even with override"
4574 );
4575 }
4576
4577 #[cfg(feature = "oauth")]
4580 #[tokio::test]
4581 async fn oauth_token_cache_headers_set_pragma_and_vary() {
4582 let app = axum::Router::new()
4583 .route("/token", axum::routing::post(|| async { "{}" }))
4584 .layer(axum::middleware::from_fn(
4585 oauth_token_cache_headers_middleware,
4586 ));
4587 let req = Request::builder()
4588 .method("POST")
4589 .uri("/token")
4590 .body(Body::from("{}"))
4591 .unwrap();
4592 let resp = app.oneshot(req).await.unwrap();
4593 assert_eq!(resp.status(), StatusCode::OK);
4594
4595 let h = resp.headers();
4596 assert_eq!(
4597 h.get("pragma").unwrap(),
4598 "no-cache",
4599 "RFC 6749 §5.1: token responses must set Pragma: no-cache"
4600 );
4601 let vary_values: Vec<String> = h
4602 .get_all("vary")
4603 .iter()
4604 .filter_map(|v| v.to_str().ok().map(str::to_owned))
4605 .collect();
4606 assert!(
4607 vary_values
4608 .iter()
4609 .any(|v| v.eq_ignore_ascii_case("Authorization")),
4610 "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
4611 );
4612 }
4613
4614 #[cfg(feature = "oauth")]
4615 #[tokio::test]
4616 async fn oauth_token_cache_headers_preserve_existing_vary() {
4617 let app = axum::Router::new()
4620 .route(
4621 "/token",
4622 axum::routing::post(|| async {
4623 axum::response::Response::builder()
4624 .header("vary", "Accept-Encoding")
4625 .body(axum::body::Body::from("{}"))
4626 .unwrap()
4627 }),
4628 )
4629 .layer(axum::middleware::from_fn(
4630 oauth_token_cache_headers_middleware,
4631 ));
4632 let req = Request::builder()
4633 .method("POST")
4634 .uri("/token")
4635 .body(Body::empty())
4636 .unwrap();
4637 let resp = app.oneshot(req).await.unwrap();
4638
4639 let vary: Vec<String> = resp
4640 .headers()
4641 .get_all("vary")
4642 .iter()
4643 .filter_map(|v| v.to_str().ok().map(str::to_owned))
4644 .collect();
4645 assert!(
4646 vary.iter().any(|v| v.contains("Accept-Encoding")),
4647 "must preserve pre-existing Vary value, got {vary:?}"
4648 );
4649 assert!(
4650 vary.iter().any(|v| v.contains("Authorization")),
4651 "must append Authorization to Vary, got {vary:?}"
4652 );
4653 }
4654
4655 #[test]
4658 fn version_payload_contains_expected_fields() {
4659 let v = version_payload("my-server", "1.2.3");
4660 assert_eq!(v["name"], "my-server");
4661 assert_eq!(v["version"], "1.2.3");
4662 assert!(v["build_git_sha"].is_string());
4663 assert!(v["build_timestamp"].is_string());
4664 assert!(v["rust_version"].is_string());
4665 assert!(v["mcpx_version"].is_string());
4666 }
4667
4668 #[tokio::test]
4671 async fn concurrency_limit_layer_composes_and_serves() {
4672 let app = axum::Router::new()
4676 .route("/ok", axum::routing::get(|| async { "ok" }))
4677 .layer(
4678 tower::ServiceBuilder::new()
4679 .layer(axum::error_handling::HandleErrorLayer::new(
4680 |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
4681 ))
4682 .layer(tower::load_shed::LoadShedLayer::new())
4683 .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
4684 );
4685 let resp = app
4686 .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
4687 .await
4688 .unwrap();
4689 assert_eq!(resp.status(), StatusCode::OK);
4690 }
4691
4692 #[tokio::test]
4695 async fn compression_layer_gzip_encodes_response() {
4696 use tower_http::compression::Predicate as _;
4697
4698 let big_body = "a".repeat(4096);
4699 let app = axum::Router::new()
4700 .route(
4701 "/big",
4702 axum::routing::get(move || {
4703 let body = big_body.clone();
4704 async move { body }
4705 }),
4706 )
4707 .layer(
4708 tower_http::compression::CompressionLayer::new()
4709 .gzip(true)
4710 .br(true)
4711 .compress_when(
4712 tower_http::compression::DefaultPredicate::new()
4713 .and(tower_http::compression::predicate::SizeAbove::new(1024)),
4714 ),
4715 );
4716
4717 let req = Request::builder()
4718 .uri("/big")
4719 .header(header::ACCEPT_ENCODING, "gzip")
4720 .body(Body::empty())
4721 .unwrap();
4722 let resp = app.oneshot(req).await.unwrap();
4723 assert_eq!(resp.status(), StatusCode::OK);
4724 assert_eq!(
4725 resp.headers().get(header::CONTENT_ENCODING).unwrap(),
4726 "gzip"
4727 );
4728 }
4729
4730 #[tokio::test]
4733 async fn tls_handshake_timeout_reaps_idle_connections() {
4734 use tokio::io::AsyncReadExt as _;
4735
4736 let _ = rustls::crypto::ring::default_provider().install_default();
4737
4738 let key = rcgen::KeyPair::generate().expect("generate key");
4740 let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
4741 .expect("cert params")
4742 .self_signed(&key)
4743 .expect("self-signed cert");
4744 let dir = std::env::temp_dir().join(format!(
4745 "rmcp-server-kit-hs-timeout-{}",
4746 std::time::SystemTime::now()
4747 .duration_since(std::time::UNIX_EPOCH)
4748 .expect("clock after epoch")
4749 .as_nanos()
4750 ));
4751 tokio::fs::create_dir_all(&dir).await.expect("temp dir");
4752 let cert_path = dir.join("server.crt");
4753 let key_path = dir.join("server.key");
4754 tokio::fs::write(&cert_path, cert.pem())
4755 .await
4756 .expect("write cert");
4757 tokio::fs::write(&key_path, key.serialize_pem())
4758 .await
4759 .expect("write key");
4760
4761 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
4762 let tls = TlsListener::new(
4763 listener,
4764 &cert_path,
4765 &key_path,
4766 None,
4767 None,
4768 Duration::from_millis(200),
4769 8, )
4771 .expect("tls listener");
4772 let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
4773
4774 let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
4778 let mut buf = [0_u8; 16];
4779 let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
4780 .await
4781 .expect("server must reap the idle handshake within its timeout");
4782 match read {
4783 Ok(0) | Err(_) => {} Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
4785 }
4786
4787 drop(tls);
4788 }
4789}