1use std::{
2 future::Future,
3 net::{IpAddr, SocketAddr},
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12 body::Body,
13 extract::{ConnectInfo, Request},
14 middleware::Next,
15 response::IntoResponse,
16};
17use rmcp::{
18 ServerHandler,
19 transport::streamable_http_server::{
20 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21 },
22};
23use rustls::RootCertStore;
24use tokio::{
25 net::TcpListener,
26 sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31 auth::{
32 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33 build_rate_limiter, extract_mtls_identity,
34 },
35 bounded_limiter::BoundedKeyedLimiter,
36 error::McpxError,
37 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
38 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
39};
40
41#[allow(
45 clippy::needless_pass_by_value,
46 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
47)]
48fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
49 McpxError::Startup(format!("{e:#}"))
50}
51
52#[allow(
58 clippy::needless_pass_by_value,
59 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
60)]
61fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
62 McpxError::Startup(format!("{op}: {e}"))
63}
64
/// Async readiness probe used by the `/readyz` route.
///
/// Each call produces a boxed `Send` future that resolves to a JSON value. It
/// is installed with `McpServerConfig::with_readiness_check`. When no check is
/// configured, `/readyz` is served by the same handler as `/healthz`.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
71
/// Socket address of the remote peer for the current request.
///
/// Usable as an axum extractor: it is read from the request extensions, and
/// extraction fails with `500 Internal Server Error` when the extension is
/// missing.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PeerAddr {
    /// Remote socket address (IP and port).
    pub addr: SocketAddr,
}
122
123impl PeerAddr {
124 #[must_use]
127 pub(crate) const fn new(addr: SocketAddr) -> Self {
128 Self { addr }
129 }
130}
131
132impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
141 type Rejection = (axum::http::StatusCode, &'static str);
142
143 async fn from_request_parts(
144 parts: &mut axum::http::request::Parts,
145 _state: &S,
146 ) -> Result<Self, Self::Rejection> {
147 parts.extensions.get::<Self>().copied().ok_or((
148 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
149 "peer address unavailable: not running under rmcp-server-kit serve()",
150 ))
151 }
152}
153
/// Client IP address associated with the current request.
///
/// NOTE(review): presumably this is the forwarding-aware client IP (resolved
/// via `trusted_proxies` / `forwarded_header`), as opposed to the raw socket
/// peer carried by `PeerAddr`. Confirm against the code that inserts it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct ClientIp {
    /// The client IP address.
    pub ip: IpAddr,
}
182
183impl ClientIp {
184 #[must_use]
187 pub(crate) const fn new(ip: IpAddr) -> Self {
188 Self { ip }
189 }
190}
191
/// Which forwarding header is consulted to derive the client address.
///
/// Deserialized in kebab-case: `"x-forwarded-for"` or `"forwarded"`.
/// `McpServerConfig` validation rejects a configured mode when
/// `trusted_proxies` is empty.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum ForwardedHeaderMode {
    /// The de-facto `X-Forwarded-For` header.
    XForwardedFor,
    /// The standardized `Forwarded` header (RFC 7239).
    Forwarded,
}
205
/// Internal state for forwarded-header client-address resolution.
///
/// NOTE(review): construction and use are outside this view. Presumably it is
/// built from `trusted_proxies` (parsed with `parse_proxy_net`) and
/// `forwarded_header`; confirm.
struct ForwardResolver {
    /// Proxy networks whose forwarding headers are (presumably) trusted.
    trusted: Vec<ipnet::IpNet>,
    /// Which forwarding header to read.
    mode: ForwardedHeaderMode,
}
212
/// Per-header overrides for HTTP security response headers.
///
/// `Default` leaves every field `None`. The configuration is checked by
/// `validate_security_headers` during `McpServerConfig::validate`.
/// NOTE(review): whether `None` means "use the built-in default value" or
/// "omit the header" is decided by the code that applies these (not visible
/// here); confirm before relying on either reading.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` header value.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` header value.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` header value.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` header value.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` header value.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` header value.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` header value.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` header value.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` header value.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` header value.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` header value.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` header value.
    pub strict_transport_security: Option<String>,
}
266
/// Configuration for the MCP Streamable HTTP server.
///
/// Build it with `McpServerConfig::new` plus the `with_*` / `enable_*`
/// builder methods. Then call `validate()` to obtain a
/// `Validated<McpServerConfig>`. Direct field access is deprecated; the
/// fields are slated to become `pub(crate)`.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Listen address; validation requires it to parse as a `SocketAddr`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name (also reported by `/version` and the admin state).
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version (also reported by `/version` and the admin state).
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// TLS certificate path; must be set together with `tls_key_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private key path; must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings. The auth layer is only installed when
    /// `enabled` is true.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy; when `None`, a disabled policy is used.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Origin allow-list; each entry must start with `http://` or `https://`.
    /// When empty and `public_url` is set, an origin is derived from
    /// `public_url`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Per-IP tool-call rate limit, in calls per minute.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Burst size for the tool-call limiter. Must be non-zero and requires
    /// `tool_rate_limit`.
    #[deprecated(
        since = "1.12.0",
        note = "use McpServerConfig::with_tool_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit_burst: Option<u32>,
    /// Per-IP rate limit (per minute) applied to `extra_router`; must be
    /// non-zero.
    #[deprecated(
        since = "1.11.0",
        note = "use McpServerConfig::with_extra_route_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_route_rate_limit: Option<u32>,
    /// Burst size for the extra-route limiter. Must be non-zero and requires
    /// `extra_route_rate_limit`.
    #[deprecated(
        since = "1.12.0",
        note = "use McpServerConfig::with_extra_route_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_route_rate_limit_burst: Option<u32>,
    /// Paths excluded from the extra-route limiter. Entries must be non-empty
    /// and start with `/`; requires `extra_route_rate_limit`.
    #[deprecated(
        since = "1.14.0",
        note = "use McpServerConfig::with_extra_route_rate_limit_exempt_paths(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_route_rate_limit_exempt_paths: Vec<String>,
    /// Trusted proxies, each given as a CIDR or a bare IP address.
    #[deprecated(
        since = "1.13.0",
        note = "use McpServerConfig::with_trusted_proxies(); direct field access will become pub(crate) in a future major release"
    )]
    pub trusted_proxies: Vec<String>,
    /// Forwarding header to honor; requires non-empty `trusted_proxies`.
    #[deprecated(
        since = "1.13.0",
        note = "use McpServerConfig::with_forwarded_header(); direct field access will become pub(crate) in a future major release"
    )]
    pub forwarded_header: Option<ForwardedHeaderMode>,
    /// Custom `/readyz` probe; when `None`, `/readyz` uses the `/healthz`
    /// handler.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum `/mcp` request body size in bytes (default 1 MiB); must be
    /// non-zero.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// `/mcp` request timeout (default 2 min); expiry yields
    /// `408 Request Timeout`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Shutdown timeout (default 30 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Session keep-alive passed to the session manager (default 20 min).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval (default 15 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// One-shot callback that receives a `ReloadHandle`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Additional routes merged into the server router.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL; must start with `http://` or `https://`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Enables request-header logging (default off).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enables response compression (default off).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Compression size threshold in bytes (default 1024).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Optional concurrency cap; `Some(0)` is rejected by validation.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Mounts the `/admin/*` endpoints; requires auth to be configured and
    /// enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// Role passed to the admin router's config (default `"admin"`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enables the metrics endpoint (default off).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Metrics listen address (default `127.0.0.1:9090`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Security response header configuration.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
    /// TLS handshake timeout; must be non-zero.
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_handshake_timeout: Duration,
    /// Maximum number of concurrent TLS handshakes; must be non-zero.
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_tls_handshakes: usize,
}
572
/// Marker wrapper proving a value passed validation.
///
/// The tuple field is private, so a `Validated<T>` can only be constructed
/// inside this module (e.g. by `McpServerConfig::validate` after all checks
/// succeed).
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Always prints `Validated { .. }`. The inner value is never formatted,
    /// so this works for any `T` and does not expose its contents in logs.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Validated");
        builder.finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the validated value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
658
659#[allow(
660 deprecated,
661 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
662)]
663impl McpServerConfig {
664 #[must_use]
672 pub fn new(
673 bind_addr: impl Into<String>,
674 name: impl Into<String>,
675 version: impl Into<String>,
676 ) -> Self {
677 Self {
678 bind_addr: bind_addr.into(),
679 name: name.into(),
680 version: version.into(),
681 tls_cert_path: None,
682 tls_key_path: None,
683 auth: None,
684 rbac: None,
685 allowed_origins: Vec::new(),
686 tool_rate_limit: None,
687 readiness_check: None,
688 max_request_body: 1024 * 1024,
689 request_timeout: Duration::from_mins(2),
690 shutdown_timeout: Duration::from_secs(30),
691 session_idle_timeout: Duration::from_mins(20),
692 sse_keep_alive: Duration::from_secs(15),
693 on_reload_ready: None,
694 extra_router: None,
695 public_url: None,
696 log_request_headers: false,
697 compression_enabled: false,
698 compression_min_size: 1024,
699 max_concurrent_requests: None,
700 admin_enabled: false,
701 admin_role: "admin".to_owned(),
702 #[cfg(feature = "metrics")]
703 metrics_enabled: false,
704 #[cfg(feature = "metrics")]
705 metrics_bind: "127.0.0.1:9090".into(),
706 security_headers: SecurityHeadersConfig::default(),
707 tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
708 max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
709 extra_route_rate_limit: None,
710 tool_rate_limit_burst: None,
711 extra_route_rate_limit_burst: None,
712 extra_route_rate_limit_exempt_paths: Vec::new(),
713 trusted_proxies: Vec::new(),
714 forwarded_header: None,
715 }
716 }
717
718 #[must_use]
728 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
729 self.auth = Some(auth);
730 self
731 }
732
733 #[must_use]
738 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
739 self.security_headers = headers;
740 self
741 }
742
743 #[must_use]
747 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
748 self.bind_addr = addr.into();
749 self
750 }
751
752 #[must_use]
755 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
756 self.rbac = Some(rbac);
757 self
758 }
759
760 #[must_use]
764 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
765 self.tls_cert_path = Some(cert_path.into());
766 self.tls_key_path = Some(key_path.into());
767 self
768 }
769
770 #[must_use]
774 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
775 self.public_url = Some(url.into());
776 self
777 }
778
779 #[must_use]
783 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
784 where
785 I: IntoIterator<Item = S>,
786 S: Into<String>,
787 {
788 self.allowed_origins = origins.into_iter().map(Into::into).collect();
789 self
790 }
791
792 #[must_use]
805 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
806 self.extra_router = Some(router);
807 self
808 }
809
810 #[must_use]
813 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
814 self.readiness_check = Some(check);
815 self
816 }
817
818 #[must_use]
821 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
822 self.max_request_body = bytes;
823 self
824 }
825
826 #[must_use]
828 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
829 self.request_timeout = timeout;
830 self
831 }
832
833 #[must_use]
835 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
836 self.shutdown_timeout = timeout;
837 self
838 }
839
840 #[must_use]
842 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
843 self.session_idle_timeout = timeout;
844 self
845 }
846
847 #[must_use]
849 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
850 self.sse_keep_alive = interval;
851 self
852 }
853
854 #[must_use]
858 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
859 self.max_concurrent_requests = Some(limit);
860 self
861 }
862
863 #[must_use]
871 pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
872 self.tls_handshake_timeout = timeout;
873 self
874 }
875
876 #[must_use]
885 pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
886 self.max_concurrent_tls_handshakes = limit;
887 self
888 }
889
890 #[must_use]
893 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
894 self.tool_rate_limit = Some(per_minute);
895 self
896 }
897
898 #[must_use]
909 pub fn with_extra_route_rate_limit(mut self, per_minute: u32) -> Self {
910 self.extra_route_rate_limit = Some(per_minute);
911 self
912 }
913
914 #[must_use]
919 pub fn with_tool_rate_limit_burst(mut self, burst: u32) -> Self {
920 self.tool_rate_limit_burst = Some(burst);
921 self
922 }
923
924 #[must_use]
930 pub fn with_extra_route_rate_limit_burst(mut self, burst: u32) -> Self {
931 self.extra_route_rate_limit_burst = Some(burst);
932 self
933 }
934
935 #[must_use]
955 pub fn with_extra_route_rate_limit_exempt_paths<I, S>(mut self, paths: I) -> Self
956 where
957 I: IntoIterator<Item = S>,
958 S: Into<String>,
959 {
960 self.extra_route_rate_limit_exempt_paths = paths.into_iter().map(Into::into).collect();
961 self
962 }
963
964 #[must_use]
976 pub fn with_trusted_proxies<I, S>(mut self, proxies: I) -> Self
977 where
978 I: IntoIterator<Item = S>,
979 S: Into<String>,
980 {
981 self.trusted_proxies = proxies.into_iter().map(Into::into).collect();
982 self
983 }
984
985 #[must_use]
990 pub fn with_forwarded_header(mut self, mode: ForwardedHeaderMode) -> Self {
991 self.forwarded_header = Some(mode);
992 self
993 }
994
995 #[must_use]
999 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
1000 where
1001 F: FnOnce(ReloadHandle) + Send + 'static,
1002 {
1003 self.on_reload_ready = Some(Box::new(callback));
1004 self
1005 }
1006
1007 #[must_use]
1011 pub fn enable_compression(mut self, min_size: u16) -> Self {
1012 self.compression_enabled = true;
1013 self.compression_min_size = min_size;
1014 self
1015 }
1016
1017 #[must_use]
1022 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
1023 self.admin_enabled = true;
1024 self.admin_role = role.into();
1025 self
1026 }
1027
1028 #[must_use]
1031 pub fn enable_request_header_logging(mut self) -> Self {
1032 self.log_request_headers = true;
1033 self
1034 }
1035
1036 #[cfg(feature = "metrics")]
1039 #[must_use]
1040 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
1041 self.metrics_enabled = true;
1042 self.metrics_bind = bind.into();
1043 self
1044 }
1045
1046 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
1079 self.check()?;
1080 Ok(Validated(self))
1081 }
1082
1083 fn check_burst_knobs(&self) -> Result<(), McpxError> {
1090 if self.tool_rate_limit_burst == Some(0) {
1091 return Err(McpxError::Config(
1092 "tool_rate_limit_burst must be greater than zero".into(),
1093 ));
1094 }
1095 if self.extra_route_rate_limit_burst == Some(0) {
1096 return Err(McpxError::Config(
1097 "extra_route_rate_limit_burst must be greater than zero".into(),
1098 ));
1099 }
1100 if self.tool_rate_limit_burst.is_some() && self.tool_rate_limit.is_none() {
1101 return Err(McpxError::Config(
1102 "tool_rate_limit_burst requires tool_rate_limit to be set".into(),
1103 ));
1104 }
1105 if self.extra_route_rate_limit_burst.is_some() && self.extra_route_rate_limit.is_none() {
1106 return Err(McpxError::Config(
1107 "extra_route_rate_limit_burst requires extra_route_rate_limit to be set".into(),
1108 ));
1109 }
1110 if !self.extra_route_rate_limit_exempt_paths.is_empty()
1111 && self.extra_route_rate_limit.is_none()
1112 {
1113 return Err(McpxError::Config(
1114 "extra_route_rate_limit_exempt_paths requires extra_route_rate_limit to be set"
1115 .into(),
1116 ));
1117 }
1118 for path in &self.extra_route_rate_limit_exempt_paths {
1119 if path.is_empty() || !path.starts_with('/') {
1120 return Err(McpxError::Config(format!(
1121 "extra_route_rate_limit_exempt_paths entries must be non-empty and start with '/': {path:?}"
1122 )));
1123 }
1124 }
1125 if let Some(rl) = self.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
1126 if rl.burst == Some(0) {
1127 return Err(McpxError::Config(
1128 "auth rate_limit.burst must be greater than zero".into(),
1129 ));
1130 }
1131 if rl.pre_auth_burst == Some(0) {
1132 return Err(McpxError::Config(
1133 "auth rate_limit.pre_auth_burst must be greater than zero".into(),
1134 ));
1135 }
1136 }
1137 Ok(())
1138 }
1139
1140 fn check_trusted_forwarder(&self) -> Result<(), McpxError> {
1145 for entry in &self.trusted_proxies {
1146 if parse_proxy_net(entry).is_none() {
1147 return Err(McpxError::Config(format!(
1148 "trusted_proxies entry {entry:?} is neither a CIDR nor an IP address"
1149 )));
1150 }
1151 }
1152 if self.forwarded_header.is_some() && self.trusted_proxies.is_empty() {
1153 return Err(McpxError::Config(
1154 "forwarded_header requires trusted_proxies to be nonempty".into(),
1155 ));
1156 }
1157 Ok(())
1158 }
1159
1160 fn check(&self) -> Result<(), McpxError> {
1164 if self.admin_enabled {
1168 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
1169 if !auth_enabled {
1170 return Err(McpxError::Config(
1171 "admin_enabled=true requires auth to be configured and enabled".into(),
1172 ));
1173 }
1174 }
1175
1176 match (&self.tls_cert_path, &self.tls_key_path) {
1178 (Some(_), None) => {
1179 return Err(McpxError::Config(
1180 "tls_cert_path is set but tls_key_path is missing".into(),
1181 ));
1182 }
1183 (None, Some(_)) => {
1184 return Err(McpxError::Config(
1185 "tls_key_path is set but tls_cert_path is missing".into(),
1186 ));
1187 }
1188 _ => {}
1189 }
1190
1191 if self.bind_addr.parse::<SocketAddr>().is_err() {
1193 return Err(McpxError::Config(format!(
1194 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
1195 self.bind_addr
1196 )));
1197 }
1198
1199 if let Some(ref url) = self.public_url
1201 && !(url.starts_with("http://") || url.starts_with("https://"))
1202 {
1203 return Err(McpxError::Config(format!(
1204 "public_url {url:?} must start with http:// or https://"
1205 )));
1206 }
1207
1208 for origin in &self.allowed_origins {
1210 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
1211 return Err(McpxError::Config(format!(
1212 "allowed_origins entry {origin:?} must start with http:// or https://"
1213 )));
1214 }
1215 }
1216
1217 if self.max_request_body == 0 {
1219 return Err(McpxError::Config(
1220 "max_request_body must be greater than zero".into(),
1221 ));
1222 }
1223
1224 if self.extra_route_rate_limit == Some(0) {
1228 return Err(McpxError::Config(
1229 "extra_route_rate_limit must be greater than zero".into(),
1230 ));
1231 }
1232
1233 self.check_burst_knobs()?;
1235
1236 self.check_trusted_forwarder()?;
1238
1239 #[cfg(feature = "oauth")]
1241 if let Some(auth_cfg) = &self.auth
1242 && let Some(oauth_cfg) = &auth_cfg.oauth
1243 {
1244 oauth_cfg.validate()?;
1245 }
1246
1247 validate_security_headers(&self.security_headers)?;
1250
1251 if self.max_concurrent_requests == Some(0) {
1255 return Err(McpxError::Config(
1256 "max_concurrent_requests must be greater than zero when set".into(),
1257 ));
1258 }
1259
1260 if let Some(auth_cfg) = &self.auth
1264 && let Some(rl) = &auth_cfg.rate_limit
1265 && rl.max_tracked_keys == 0
1266 {
1267 return Err(McpxError::Config(
1268 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
1269 ));
1270 }
1271
1272 if self.tls_handshake_timeout == Duration::ZERO {
1277 return Err(McpxError::Config(
1278 "tls_handshake_timeout must be greater than zero".into(),
1279 ));
1280 }
1281
1282 if self.max_concurrent_tls_handshakes == 0 {
1287 return Err(McpxError::Config(
1288 "max_concurrent_tls_handshakes must be greater than zero".into(),
1289 ));
1290 }
1291
1292 Ok(())
1293 }
1294}
1295
/// Handle for applying configuration changes to a running server.
///
/// Delivered to the callback registered with
/// `McpServerConfig::with_reload_callback`. A field is `None` when the
/// corresponding feature has nothing to reload.
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    /// Shared auth state whose API keys can be replaced.
    auth: Option<Arc<AuthState>>,
    /// Swappable RBAC policy. NOTE(review): presumably the same swap the RBAC
    /// middleware loads per request; confirm at construction site.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    /// CRL set used for mTLS revocation checks, when configured.
    crl_set: Option<Arc<CrlSet>>,
}
1310
1311impl ReloadHandle {
1312 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
1314 if let Some(ref auth) = self.auth {
1315 auth.reload_keys(keys);
1316 }
1317 }
1318
1319 pub fn reload_rbac(&self, policy: RbacPolicy) {
1321 if let Some(ref rbac) = self.rbac {
1322 rbac.store(Arc::new(policy));
1323 tracing::info!("RBAC policy reloaded");
1324 }
1325 }
1326
1327 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1333 let Some(ref crl_set) = self.crl_set else {
1334 return Err(McpxError::Config(
1335 "CRL refresh requested but mTLS CRL support is not configured".into(),
1336 ));
1337 };
1338
1339 crl_set.force_refresh().await
1340 }
1341}
1342
1343#[allow(
1360 clippy::too_many_lines,
1361 clippy::cognitive_complexity,
1362 reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
1363)]
1364struct AppRunParams {
1368 tls_paths: Option<(PathBuf, PathBuf)>,
1370 tls_handshake_timeout: Duration,
1372 max_concurrent_tls_handshakes: usize,
1374 mtls_config: Option<MtlsConfig>,
1376 shutdown_timeout: Duration,
1378 auth_state: Option<Arc<AuthState>>,
1380 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1382 on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1384 ct: CancellationToken,
1388 scheme: &'static str,
1390 name: String,
1392}
1393
1394#[allow(
1404 clippy::cognitive_complexity,
1405 reason = "router assembly is intrinsically sequential; splitting harms readability"
1406)]
1407#[allow(
1408 deprecated,
1409 reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
1410)]
1411fn build_app_router<H, F>(
1412 mut config: McpServerConfig,
1413 handler_factory: F,
1414) -> anyhow::Result<(axum::Router, AppRunParams)>
1415where
1416 H: ServerHandler + 'static,
1417 F: Fn() -> H + Send + Sync + Clone + 'static,
1418{
1419 let ct = CancellationToken::new();
1420
1421 let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
1422 tracing::info!(allowed_hosts = %allowed_hosts.join(", "), "configured Streamable HTTP allowed hosts");
1423
1424 let mcp_service = StreamableHttpService::new(
1425 move || Ok(handler_factory()),
1426 {
1427 let mut mgr = LocalSessionManager::default();
1428 mgr.session_config.keep_alive = Some(config.session_idle_timeout);
1429 mgr.into()
1430 },
1431 StreamableHttpServerConfig::default()
1432 .with_allowed_hosts(allowed_hosts)
1433 .with_sse_keep_alive(Some(config.sse_keep_alive))
1434 .with_cancellation_token(ct.child_token()),
1435 );
1436
1437 let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
1439
1440 let auth_state: Option<Arc<AuthState>> = match config.auth {
1444 Some(ref auth_config) if auth_config.enabled => {
1445 let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
1446 let pre_auth_limiter = auth_config
1447 .rate_limit
1448 .as_ref()
1449 .map(crate::auth::build_pre_auth_limiter);
1450
1451 #[cfg(feature = "oauth")]
1452 let jwks_cache = auth_config
1453 .oauth
1454 .as_ref()
1455 .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
1456 .transpose()
1457 .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
1458
1459 Some(Arc::new(AuthState {
1460 api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
1461 rate_limiter,
1462 pre_auth_limiter,
1463 #[cfg(feature = "oauth")]
1464 jwks_cache,
1465 seen_identities: crate::auth::SeenIdentitySet::new(),
1466 counters: crate::auth::AuthCounters::default(),
1467 }))
1468 }
1469 _ => None,
1470 };
1471
1472 let rbac_swap = Arc::new(ArcSwap::new(
1475 config
1476 .rbac
1477 .clone()
1478 .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
1479 ));
1480
1481 if config.admin_enabled {
1484 let Some(ref auth_state_ref) = auth_state else {
1485 return Err(anyhow::anyhow!(
1486 "admin_enabled=true requires auth to be configured and enabled"
1487 ));
1488 };
1489 let admin_state = crate::admin::AdminState {
1490 started_at: std::time::Instant::now(),
1491 name: config.name.clone(),
1492 version: config.version.clone(),
1493 auth: Some(Arc::clone(auth_state_ref)),
1494 rbac: Arc::clone(&rbac_swap),
1495 };
1496 let admin_cfg = crate::admin::AdminConfig {
1497 role: config.admin_role.clone(),
1498 };
1499 mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
1500 tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
1501 }
1502
1503 {
1536 let tool_limiter: Option<Arc<ToolRateLimiter>> = config
1537 .tool_rate_limit
1538 .map(|per_minute| build_tool_rate_limiter(per_minute, config.tool_rate_limit_burst));
1539
1540 if rbac_swap.load().is_enabled() {
1541 tracing::info!("RBAC enforcement enabled on /mcp");
1542 }
1543 if let Some(limit) = config.tool_rate_limit {
1544 tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1545 }
1546
1547 let rbac_for_mw = Arc::clone(&rbac_swap);
1548 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1549 let p = rbac_for_mw.load_full();
1550 let tl = tool_limiter.clone();
1551 rbac_middleware(p, tl, req, next)
1552 }));
1553 }
1554
1555 if let Some(ref auth_config) = config.auth
1557 && auth_config.enabled
1558 {
1559 let Some(ref state) = auth_state else {
1560 return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1561 };
1562
1563 let methods: Vec<&str> = [
1564 auth_config.mtls.is_some().then_some("mTLS"),
1565 (!auth_config.api_keys.is_empty()).then_some("bearer"),
1566 #[cfg(feature = "oauth")]
1567 auth_config.oauth.is_some().then_some("oauth-jwt"),
1568 ]
1569 .into_iter()
1570 .flatten()
1571 .collect();
1572
1573 tracing::info!(
1574 methods = %methods.join(", "),
1575 api_keys = auth_config.api_keys.len(),
1576 "auth enabled on /mcp"
1577 );
1578
1579 let state_for_mw = Arc::clone(state);
1580 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1581 let s = Arc::clone(&state_for_mw);
1582 auth_middleware(s, req, next)
1583 }));
1584 }
1585
1586 mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1589 axum::http::StatusCode::REQUEST_TIMEOUT,
1590 config.request_timeout,
1591 ));
1592
1593 mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1597 config.max_request_body,
1598 ));
1599
1600 let mut effective_origins = config.allowed_origins.clone();
1607 if effective_origins.is_empty()
1608 && let Some(ref url) = config.public_url
1609 {
1610 if let Some(scheme_end) = url.find("://") {
1615 let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1616 let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1617 let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1618 let host = after_scheme.get(..host_end).unwrap_or_default();
1619 let origin = format!("{scheme_with_sep}{host}");
1620 tracing::info!(
1621 %origin,
1622 "auto-derived allowed origin from public_url"
1623 );
1624 effective_origins.push(origin);
1625 }
1626 }
1627 let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1628 let cors_origins = Arc::clone(&allowed_origins);
1629 let log_request_headers = config.log_request_headers;
1630
1631 let readyz_route = if let Some(check) = config.readiness_check.take() {
1632 axum::routing::get(move || readyz(Arc::clone(&check)))
1633 } else {
1634 axum::routing::get(healthz)
1635 };
1636
1637 #[allow(unused_mut)] let mut router = axum::Router::new()
1639 .route("/healthz", axum::routing::get(healthz))
1640 .route("/readyz", readyz_route)
1641 .route(
1642 "/version",
1643 axum::routing::get({
1644 let payload_bytes: Arc<[u8]> =
1649 serialize_version_payload(&config.name, &config.version);
1650 move || {
1651 let p = Arc::clone(&payload_bytes);
1652 async move {
1653 (
1654 [(axum::http::header::CONTENT_TYPE, "application/json")],
1655 p.to_vec(),
1656 )
1657 }
1658 }
1659 }),
1660 )
1661 .merge(mcp_router);
1662
1663 if let Some(extra) = config.extra_router.take() {
1670 let extra = match config.extra_route_rate_limit {
1671 Some(per_minute) => {
1672 let limiter =
1673 build_extra_route_rate_limiter(per_minute, config.extra_route_rate_limit_burst);
1674 let exempt: Arc<std::collections::HashSet<String>> = Arc::new(
1675 config
1676 .extra_route_rate_limit_exempt_paths
1677 .iter()
1678 .cloned()
1679 .collect(),
1680 );
1681 tracing::info!(
1682 per_minute,
1683 exempt_paths = exempt.len(),
1684 "extra-route per-IP rate limit enabled"
1685 );
1686 extra.layer(axum::middleware::from_fn(move |req, next| {
1687 let l = Arc::clone(&limiter);
1688 let e = Arc::clone(&exempt);
1689 extra_route_rate_limit_middleware(l, e, req, next)
1690 }))
1691 }
1692 None => extra,
1693 };
1694 router = router.merge(extra);
1695 }
1696
1697 let server_url = if let Some(ref url) = config.public_url {
1704 url.trim_end_matches('/').to_owned()
1705 } else {
1706 let prm_scheme = if config.tls_cert_path.is_some() {
1707 "https"
1708 } else {
1709 "http"
1710 };
1711 format!("{prm_scheme}://{}", config.bind_addr)
1712 };
1713 let resource_url = format!("{server_url}/mcp");
1714
1715 #[cfg(feature = "oauth")]
1716 let prm_metadata = if let Some(ref auth_config) = config.auth
1717 && let Some(ref oauth_config) = auth_config.oauth
1718 {
1719 crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1720 } else {
1721 serde_json::json!({ "resource": resource_url })
1722 };
1723 #[cfg(not(feature = "oauth"))]
1724 let prm_metadata = serde_json::json!({ "resource": resource_url });
1725
1726 router = router.route(
1727 "/.well-known/oauth-protected-resource",
1728 axum::routing::get(move || {
1729 let m = prm_metadata.clone();
1730 async move { axum::Json(m) }
1731 }),
1732 );
1733
1734 #[cfg(feature = "oauth")]
1739 if let Some(ref auth_config) = config.auth
1740 && let Some(ref oauth_config) = auth_config.oauth
1741 && oauth_config.proxy.is_some()
1742 {
1743 router = install_oauth_proxy_routes(
1744 router,
1745 &server_url,
1746 oauth_config,
1747 auth_state.as_ref(),
1748 config.max_request_body,
1749 )?;
1750 }
1751
1752 let is_tls = config.tls_cert_path.is_some();
1755 let security_headers_cfg = Arc::new(config.security_headers.clone());
1756 router = router.layer(axum::middleware::from_fn(move |req, next| {
1757 let cfg = Arc::clone(&security_headers_cfg);
1758 security_headers_middleware(is_tls, cfg, req, next)
1759 }));
1760
1761 if !cors_origins.is_empty() {
1765 let cors = tower_http::cors::CorsLayer::new()
1766 .allow_origin(
1767 cors_origins
1768 .iter()
1769 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1770 .collect::<Vec<_>>(),
1771 )
1772 .allow_methods([
1773 axum::http::Method::GET,
1774 axum::http::Method::POST,
1775 axum::http::Method::OPTIONS,
1776 ])
1777 .allow_headers([
1778 axum::http::header::CONTENT_TYPE,
1779 axum::http::header::AUTHORIZATION,
1780 ]);
1781 router = router.layer(cors);
1782 }
1783
1784 if config.compression_enabled {
1788 use tower_http::compression::Predicate as _;
1789 let predicate = tower_http::compression::DefaultPredicate::new().and(
1790 tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1791 );
1792 router = router.layer(
1793 tower_http::compression::CompressionLayer::new()
1794 .gzip(true)
1795 .br(true)
1796 .compress_when(predicate),
1797 );
1798 tracing::info!(
1799 min_size = config.compression_min_size,
1800 "response compression enabled (gzip, br)"
1801 );
1802 }
1803
1804 if let Some(max) = config.max_concurrent_requests {
1807 let overload_handler = tower::ServiceBuilder::new()
1808 .layer(axum::error_handling::HandleErrorLayer::new(
1809 |_err: tower::BoxError| async {
1810 (
1811 axum::http::StatusCode::SERVICE_UNAVAILABLE,
1812 axum::Json(serde_json::json!({
1813 "error": "overloaded",
1814 "error_description": "server is at capacity, retry later"
1815 })),
1816 )
1817 },
1818 ))
1819 .layer(tower::load_shed::LoadShedLayer::new())
1820 .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1821 router = router.layer(overload_handler);
1822 tracing::info!(max, "global concurrency limit enabled");
1823 }
1824
1825 router = router.fallback(|| async {
1829 (
1830 axum::http::StatusCode::NOT_FOUND,
1831 axum::Json(serde_json::json!({
1832 "error": "not_found",
1833 "error_description": "The requested endpoint does not exist"
1834 })),
1835 )
1836 });
1837
1838 #[cfg(feature = "metrics")]
1840 if config.metrics_enabled {
1841 let metrics = Arc::new(
1842 crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1843 );
1844 let m = Arc::clone(&metrics);
1845 router = router.layer(axum::middleware::from_fn(
1846 move |req: Request<Body>, next: Next| {
1847 let m = Arc::clone(&m);
1848 metrics_middleware(m, req, next)
1849 },
1850 ));
1851 let metrics_bind = config.metrics_bind.clone();
1852 let metrics_shutdown = ct.clone();
1853 tokio::spawn(async move {
1854 if let Err(e) =
1855 crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1856 {
1857 tracing::error!("metrics listener failed: {e}");
1858 }
1859 });
1860 }
1861
1862 let forward_resolver: Option<Arc<ForwardResolver>> = if config.trusted_proxies.is_empty() {
1870 None
1871 } else {
1872 Some(Arc::new(ForwardResolver {
1875 trusted: config
1876 .trusted_proxies
1877 .iter()
1878 .filter_map(|entry| parse_proxy_net(entry))
1879 .collect(),
1880 mode: config
1881 .forwarded_header
1882 .unwrap_or(ForwardedHeaderMode::XForwardedFor),
1883 }))
1884 };
1885 if forward_resolver.is_some() {
1886 tracing::info!(
1887 proxies = config.trusted_proxies.len(),
1888 "trusted-forwarder mode enabled: limiters key by resolved client IP"
1889 );
1890 }
1891 router = router.layer(axum::middleware::from_fn(move |req, next| {
1892 let r = forward_resolver.clone();
1893 normalize_peer_addr_middleware(r, req, next)
1894 }));
1895
1896 router = router.layer(axum::middleware::from_fn(move |req, next| {
1907 let origins = Arc::clone(&allowed_origins);
1908 origin_check_middleware(origins, log_request_headers, req, next)
1909 }));
1910
1911 let scheme = if config.tls_cert_path.is_some() {
1912 "https"
1913 } else {
1914 "http"
1915 };
1916
1917 let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1918 (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1919 _ => None,
1920 };
1921 let tls_handshake_timeout = config.tls_handshake_timeout;
1922 let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
1923 let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1924
1925 Ok((
1926 router,
1927 AppRunParams {
1928 tls_paths,
1929 tls_handshake_timeout,
1930 max_concurrent_tls_handshakes,
1931 mtls_config,
1932 shutdown_timeout: config.shutdown_timeout,
1933 auth_state,
1934 rbac_swap,
1935 on_reload_ready: config.on_reload_ready.take(),
1936 ct,
1937 scheme,
1938 name: config.name.clone(),
1939 },
1940 ))
1941}
1942
1943pub async fn serve<H, F>(
1960 config: Validated<McpServerConfig>,
1961 handler_factory: F,
1962) -> Result<(), McpxError>
1963where
1964 H: ServerHandler + 'static,
1965 F: Fn() -> H + Send + Sync + Clone + 'static,
1966{
1967 let config = config.into_inner();
1968 #[allow(
1969 deprecated,
1970 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1971 )]
1972 let bind_addr = config.bind_addr.clone();
1973 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1974
1975 let listener = TcpListener::bind(&bind_addr)
1976 .await
1977 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1978 log_listening(¶ms.name, params.scheme, &bind_addr);
1979
1980 run_server(
1981 router,
1982 listener,
1983 params.tls_paths,
1984 params.tls_handshake_timeout,
1985 params.max_concurrent_tls_handshakes,
1986 params.mtls_config,
1987 params.shutdown_timeout,
1988 params.auth_state,
1989 params.rbac_swap,
1990 params.on_reload_ready,
1991 params.ct,
1992 )
1993 .await
1994 .map_err(anyhow_to_startup)
1995}
1996
1997pub async fn serve_with_listener<H, F>(
2027 listener: TcpListener,
2028 config: Validated<McpServerConfig>,
2029 handler_factory: F,
2030 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
2031 shutdown: Option<CancellationToken>,
2032) -> Result<(), McpxError>
2033where
2034 H: ServerHandler + 'static,
2035 F: Fn() -> H + Send + Sync + Clone + 'static,
2036{
2037 let config = config.into_inner();
2038 let local_addr = listener
2039 .local_addr()
2040 .map_err(|e| io_to_startup("listener.local_addr", e))?;
2041 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
2042
2043 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
2044
2045 if let Some(external) = shutdown {
2049 let internal = params.ct.clone();
2050 tokio::spawn(async move {
2051 external.cancelled().await;
2052 internal.cancel();
2053 });
2054 }
2055
2056 if let Some(tx) = ready_tx {
2060 let _ = tx.send(local_addr);
2062 }
2063
2064 run_server(
2065 router,
2066 listener,
2067 params.tls_paths,
2068 params.tls_handshake_timeout,
2069 params.max_concurrent_tls_handshakes,
2070 params.mtls_config,
2071 params.shutdown_timeout,
2072 params.auth_state,
2073 params.rbac_swap,
2074 params.on_reload_ready,
2075 params.ct,
2076 )
2077 .await
2078 .map_err(anyhow_to_startup)
2079}
2080
#[allow(
    clippy::cognitive_complexity,
    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
)]
/// Logs, at info level, where the server is listening and the URLs of its
/// well-known endpoints (`/mcp`, `/healthz`, `/readyz`).
///
/// `scheme` (`"http"` or `"https"` at the visible call sites) and `addr` are
/// interpolated verbatim into the logged URLs.
fn log_listening(name: &str, scheme: &str, addr: &str) {
    tracing::info!("{name} listening on {addr}");
    tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
    tracing::info!(" Health check: {scheme}://{addr}/healthz");
    tracing::info!(" Readiness: {scheme}://{addr}/readyz");
}
2093
/// Serves `router` on `listener` until shutdown. Uses TLS when `tls_paths` is
/// set and plain HTTP otherwise.
///
/// Shutdown flow:
/// * A spawned task cancels the local `shutdown_trigger` on the first of a
///   process signal (`shutdown_signal`) or cancellation of `ct`.
/// * The graceful-shutdown future waits for that trigger, logs the grace
///   period, cancels `ct`, and lets axum drain in-flight connections.
/// * In parallel, a force-exit timer waits `shutdown_timeout` after the
///   trigger. If it fires first, the serve future is dropped (remaining
///   connections are abandoned), a warning is logged, and `Ok(())` is returned.
///
/// Before serving, `on_reload_ready` (if present) is invoked with a
/// `ReloadHandle` exposing the auth state, the RBAC policy swap and, on the
/// TLS path with CRL checking enabled, the shared CRL set. It is not invoked
/// if TLS/CRL setup fails first.
///
/// # Errors
///
/// Returns an error if loading the mTLS CA roots or the initial CRL fetch
/// fails, if the TLS listener cannot be built (certificate, key or verifier
/// problems), or if `axum::serve` itself returns an I/O error.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    tls_handshake_timeout: Duration,
    max_concurrent_tls_handshakes: usize,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // Local trigger fired by either an OS signal or external cancellation of `ct`.
    let shutdown_trigger = CancellationToken::new();
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    // Handed to axum's graceful shutdown. Cancelling `ct` here also stops
    // background tasks tied to it (e.g. the CRL refresher below).
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };

    // Upper bound on the drain phase: completes `shutdown_timeout` after the trigger.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // With mTLS CRL checking enabled, fetch CRLs before accepting any
        // connection and keep them refreshed in the background until `ct` is cancelled.
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
            tls_handshake_timeout,
            max_concurrent_tls_handshakes,
        )?;
        // `TlsConnInfo` carries the peer address plus the mTLS identity to handlers.
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
2239
#[cfg(feature = "oauth")]
/// Mounts the OAuth 2.1 proxy endpoints on `router` when `oauth_config.proxy`
/// is configured. Otherwise returns `router` unchanged.
///
/// Routes that are always mounted when the proxy is configured:
/// * `GET /.well-known/oauth-authorization-server`: authorization-server
///   metadata built for `server_url`.
/// * `GET /authorize`: passes the raw query string to the proxy handler.
/// * `POST /token` and `POST /register`: proxied token exchange and client
///   registration, each wrapped in `oauth_token_cache_headers_middleware`.
///
/// `/introspect` and `/revoke` are added only when `expose_admin_endpoints`
/// is set and the matching upstream URL is configured (see
/// `build_oauth_admin_router`). All proxy routes share a request-body limit
/// of `max_request_body` bytes.
///
/// # Errors
///
/// Propagates failures from building the upstream HTTP client and from
/// assembling the admin router (e.g. admin auth is required but
/// `auth_state` is `None`).
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
    max_request_body: usize,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let proxy_router = axum::Router::new();

    // Metadata is computed once and cloned per request.
    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let proxy_router = proxy_router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let proxy_router = proxy_router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let proxy_router = proxy_router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let proxy_router = proxy_router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    // Admin routes exist only if exposure is enabled AND at least one upstream
    // endpoint is configured.
    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Warn only when unauthenticated admin routes will actually be mounted.
    // Otherwise the warning would describe endpoints that do not exist.
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let proxy_router =
        proxy_router
            .merge(admin_router)
            .layer(tower_http::limit::RequestBodyLimitLayer::new(
                max_request_body,
            ));

    let router = router.merge(proxy_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        max_request_body,
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
2359
#[cfg(feature = "oauth")]
/// Builds the optional OAuth proxy admin routes, `/introspect` and `/revoke`.
///
/// Each route is added only when the proxy has the matching upstream URL.
/// The whole router is wrapped in `oauth_token_cache_headers_middleware`. When
/// `require_auth_on_admin_endpoints` is set, it is additionally wrapped in
/// `auth_middleware` using `auth_state`.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when admin authentication is required but
/// `auth_state` is `None`.
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let cfg = proxy.clone();
        let client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let (p, h) = (cfg.clone(), client.clone());
                async move { crate::oauth::handle_introspect(&h, &p, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let cfg = proxy.clone();
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let (p, h) = (cfg.clone(), http.clone());
                async move { crate::oauth::handle_revoke(&h, &p, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    let shared_state = Arc::clone(state);
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&shared_state), req, next)
    })))
}
2418
2419fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2424 let mut hosts = vec![
2425 "localhost".to_owned(),
2426 "127.0.0.1".to_owned(),
2427 "::1".to_owned(),
2428 ];
2429
2430 if let Some(url) = public_url
2431 && let Ok(uri) = url.parse::<axum::http::Uri>()
2432 && let Some(authority) = uri.authority()
2433 {
2434 let host = authority.host().to_owned();
2435 if !hosts.iter().any(|h| h == &host) {
2436 hosts.push(host);
2437 }
2438
2439 let authority = authority.as_str().to_owned();
2440 if !hosts.iter().any(|h| h == &authority) {
2441 hosts.push(authority);
2442 }
2443 }
2444
2445 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2446 && let Some(authority) = uri.authority()
2447 {
2448 let host = authority.host().to_owned();
2449 if !hosts.iter().any(|h| h == &host) {
2450 hosts.push(host);
2451 }
2452
2453 let authority = authority.as_str().to_owned();
2454 if !hosts.iter().any(|h| h == &authority) {
2455 hosts.push(authority);
2456 }
2457 }
2458
2459 hosts
2460}
2461
2462impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2475 for TlsConnInfo
2476{
2477 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2478 let addr = *target.remote_addr();
2479 let identity = target.io().identity().cloned();
2480 Self::new(addr, identity)
2481 }
2482}
2483
/// Default upper bound on the duration of a single TLS handshake.
/// NOTE(review): presumably the default for the `tls_handshake_timeout`
/// config value; confirm at the definition site.
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Default cap on concurrently in-flight TLS handshakes.
/// NOTE(review): presumably the default for `max_concurrent_tls_handshakes`;
/// confirm at the definition site.
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;

/// Capacity of the channel that carries completed handshakes from the acceptor
/// task to `TlsListener::accept`. When it is full, handshake tasks wait in
/// `send(..).await` before handing their stream over.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
2506
/// An `axum::serve::Listener` that yields TLS streams whose handshakes were
/// completed by a background acceptor task (`run_tls_acceptor`).
struct TlsListener {
    /// Address of the underlying TCP listener, captured at construction and
    /// reported by `Listener::local_addr`.
    local_addr: SocketAddr,
    /// Receives handshaken streams, paired with the peer address, from the
    /// acceptor task.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Handle of the acceptor task; aborted when the listener is dropped.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2532
2533impl TlsListener {
2534 fn new(
2535 inner: TcpListener,
2536 cert_path: &Path,
2537 key_path: &Path,
2538 mtls_config: Option<&MtlsConfig>,
2539 crl_set: Option<Arc<CrlSet>>,
2540 handshake_timeout: Duration,
2541 max_concurrent_handshakes: usize,
2542 ) -> anyhow::Result<Self> {
2543 rustls::crypto::ring::default_provider()
2545 .install_default()
2546 .ok();
2547
2548 let certs = load_certs(cert_path)?;
2549 let key = load_key(key_path)?;
2550
2551 let mtls_default_role;
2552
2553 let tls_config = if let Some(mtls) = mtls_config {
2554 mtls_default_role = mtls.default_role.clone();
2555 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2556 {
2557 let Some(crl_set) = crl_set else {
2558 return Err(anyhow::anyhow!(
2559 "mTLS CRL verifier requested but CRL state was not initialized"
2560 ));
2561 };
2562 Arc::new(DynamicClientCertVerifier::new(crl_set))
2563 } else {
2564 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2565 if mtls.required {
2566 rustls::server::WebPkiClientVerifier::builder(root_store)
2567 .build()
2568 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2569 } else {
2570 rustls::server::WebPkiClientVerifier::builder(root_store)
2571 .allow_unauthenticated()
2572 .build()
2573 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2574 }
2575 };
2576
2577 tracing::info!(
2578 ca = %mtls.ca_cert_path.display(),
2579 required = mtls.required,
2580 crl_enabled = mtls.crl_enabled,
2581 "mTLS client auth configured"
2582 );
2583
2584 rustls::ServerConfig::builder_with_protocol_versions(&[
2585 &rustls::version::TLS12,
2586 &rustls::version::TLS13,
2587 ])
2588 .with_client_cert_verifier(verifier)
2589 .with_single_cert(certs, key)?
2590 } else {
2591 mtls_default_role = "viewer".to_owned();
2592 rustls::ServerConfig::builder_with_protocol_versions(&[
2593 &rustls::version::TLS12,
2594 &rustls::version::TLS13,
2595 ])
2596 .with_no_client_auth()
2597 .with_single_cert(certs, key)?
2598 };
2599
2600 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2601 tracing::info!(
2602 "TLS enabled (cert: {}, key: {})",
2603 cert_path.display(),
2604 key_path.display()
2605 );
2606 let local_addr = inner.local_addr()?;
2607 let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2608 let acceptor_task = tokio::spawn(run_tls_acceptor(
2609 inner,
2610 acceptor,
2611 mtls_default_role,
2612 tx,
2613 handshake_timeout,
2614 max_concurrent_handshakes,
2615 ));
2616 Ok(Self {
2617 local_addr,
2618 rx,
2619 acceptor_task,
2620 })
2621 }
2622
2623 fn extract_handshake_identity(
2627 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2628 default_role: &str,
2629 addr: SocketAddr,
2630 ) -> Option<AuthIdentity> {
2631 let (_, server_conn) = tls_stream.get_ref();
2632 let cert_der = server_conn.peer_certificates()?.first()?;
2633 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2634 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2635 Some(id)
2636 }
2637}
2638
2639async fn run_tls_acceptor(
2647 listener: TcpListener,
2648 acceptor: tokio_rustls::TlsAcceptor,
2649 default_role: String,
2650 tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2651 handshake_timeout: Duration,
2652 max_concurrent_handshakes: usize,
2653) {
2654 let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2655 loop {
2656 let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2660 return;
2662 };
2663 let (stream, addr) = match listener.accept().await {
2664 Ok(pair) => pair,
2665 Err(e) => {
2666 tracing::debug!("TCP accept error: {e}");
2667 continue;
2668 }
2669 };
2670 if tx.is_closed() {
2671 return;
2673 }
2674 let acceptor = acceptor.clone();
2675 let default_role = default_role.clone();
2676 let tx = tx.clone();
2677 tokio::spawn(async move {
2678 let _permit = permit;
2679 match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2680 Ok(Ok(tls_stream)) => {
2681 let identity =
2682 TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2683 let wrapped = AuthenticatedTlsStream {
2684 inner: tls_stream,
2685 identity,
2686 };
2687 let _ = tx.send((wrapped, addr)).await;
2690 }
2691 Ok(Err(e)) => {
2692 tracing::debug!("TLS handshake failed from {addr}: {e}");
2693 }
2694 Err(_elapsed) => {
2695 tracing::debug!(
2696 "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2697 );
2698 }
2699 }
2700 });
2701 }
2702}
2703
/// A server-side TLS stream paired with the mTLS identity extracted from the
/// client certificate during the handshake, if any.
///
/// Reads and writes are delegated to the inner stream. The identity is exposed
/// through `identity()` so it can be attached to the connection's connect info.
pub(crate) struct AuthenticatedTlsStream {
    /// The completed rustls server stream over TCP.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity derived from the peer's leaf certificate. `None` when the
    /// client presented no certificate or it could not be mapped to an identity.
    identity: Option<AuthIdentity>,
}
2719
2720impl AuthenticatedTlsStream {
2721 #[must_use]
2723 pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
2724 self.identity.as_ref()
2725 }
2726}
2727
2728impl std::fmt::Debug for AuthenticatedTlsStream {
2729 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2730 f.debug_struct("AuthenticatedTlsStream")
2731 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2732 .finish_non_exhaustive()
2733 }
2734}
2735
2736impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2737 fn poll_read(
2738 mut self: Pin<&mut Self>,
2739 cx: &mut std::task::Context<'_>,
2740 buf: &mut tokio::io::ReadBuf<'_>,
2741 ) -> std::task::Poll<std::io::Result<()>> {
2742 Pin::new(&mut self.inner).poll_read(cx, buf)
2743 }
2744}
2745
2746impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2747 fn poll_write(
2748 mut self: Pin<&mut Self>,
2749 cx: &mut std::task::Context<'_>,
2750 buf: &[u8],
2751 ) -> std::task::Poll<std::io::Result<usize>> {
2752 Pin::new(&mut self.inner).poll_write(cx, buf)
2753 }
2754
2755 fn poll_flush(
2756 mut self: Pin<&mut Self>,
2757 cx: &mut std::task::Context<'_>,
2758 ) -> std::task::Poll<std::io::Result<()>> {
2759 Pin::new(&mut self.inner).poll_flush(cx)
2760 }
2761
2762 fn poll_shutdown(
2763 mut self: Pin<&mut Self>,
2764 cx: &mut std::task::Context<'_>,
2765 ) -> std::task::Poll<std::io::Result<()>> {
2766 Pin::new(&mut self.inner).poll_shutdown(cx)
2767 }
2768
2769 fn poll_write_vectored(
2770 mut self: Pin<&mut Self>,
2771 cx: &mut std::task::Context<'_>,
2772 bufs: &[std::io::IoSlice<'_>],
2773 ) -> std::task::Poll<std::io::Result<usize>> {
2774 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2775 }
2776
2777 fn is_write_vectored(&self) -> bool {
2778 self.inner.is_write_vectored()
2779 }
2780}
2781
2782impl axum::serve::Listener for TlsListener {
2783 type Io = AuthenticatedTlsStream;
2784 type Addr = SocketAddr;
2785
2786 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2792 if let Some(pair) = self.rx.recv().await {
2793 return pair;
2794 }
2795 tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2801 std::future::pending().await
2802 }
2803
2804 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2805 Ok(self.local_addr)
2806 }
2807}
2808
/// Stops the background acceptor task when the listener goes away.
impl Drop for TlsListener {
    fn drop(&mut self) {
        // Aborting drops the acceptor future, which owns the TCP listener, so
        // no new connections are accepted. Handshake tasks already spawned are
        // separate tasks; their final `send` fails once `rx` is dropped.
        self.acceptor_task.abort();
    }
}
2816
2817fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2818 use rustls::pki_types::pem::PemObject;
2819 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2820 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2821 .collect::<Result<_, _>>()
2822 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2823 anyhow::ensure!(
2824 !certs.is_empty(),
2825 "no certificates found in {}",
2826 path.display()
2827 );
2828 Ok(certs)
2829}
2830
2831fn load_client_auth_roots(
2832 path: &Path,
2833) -> anyhow::Result<(
2834 Vec<rustls::pki_types::CertificateDer<'static>>,
2835 Arc<RootCertStore>,
2836)> {
2837 let ca_certs = load_certs(path)?;
2838 let mut root_store = RootCertStore::empty();
2839 for cert in &ca_certs {
2840 root_store
2841 .add(cert.clone())
2842 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2843 }
2844
2845 Ok((ca_certs, Arc::new(root_store)))
2846}
2847
2848fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2849 use rustls::pki_types::pem::PemObject;
2850 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2851 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2852}
2853
2854#[allow(
2855 clippy::unused_async,
2856 reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2857)]
2858async fn healthz() -> impl IntoResponse {
2859 axum::Json(serde_json::json!({
2860 "status": "ok",
2861 }))
2862}
2863
2864fn version_payload(name: &str, version: &str) -> serde_json::Value {
2871 serde_json::json!({
2872 "name": name,
2873 "version": version,
2874 "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2875 "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2876 "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2877 "mcpx_version": env!("CARGO_PKG_VERSION"),
2878 })
2879}
2880
2881fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2891 let value = version_payload(name, version);
2892 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2893}
2894
2895async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2896 let status = check().await;
2897 let ready = status
2898 .get("ready")
2899 .and_then(serde_json::Value::as_bool)
2900 .unwrap_or(false);
2901 let code = if ready {
2902 axum::http::StatusCode::OK
2903 } else {
2904 axum::http::StatusCode::SERVICE_UNAVAILABLE
2905 };
2906 (code, axum::Json(status))
2907}
2908
2909async fn shutdown_signal() {
2913 let ctrl_c = tokio::signal::ctrl_c();
2914
2915 #[cfg(unix)]
2916 {
2917 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2918 Ok(mut term) => {
2919 tokio::select! {
2922 _ = ctrl_c => {}
2923 _ = term.recv() => {}
2924 }
2925 }
2926 Err(e) => {
2927 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2928 ctrl_c.await.ok();
2929 }
2930 }
2931 }
2932
2933 #[cfg(not(unix))]
2934 {
2935 ctrl_c.await.ok();
2936 }
2937}
2938
#[cfg(feature = "metrics")]
/// Records per-request HTTP metrics and exposes the metrics handle to
/// downstream handlers through request extensions.
///
/// Updates `http_requests_total` (labels: method, path, status) and
/// `http_request_duration_seconds` (labels: method, path).
///
/// The `path` label comes from the matched route template (axum
/// `MatchedPath`) when one exists, so parameterised routes share one series.
/// Requests that matched no route and were answered with `404 Not Found` are
/// all recorded under a single `<unmatched>` label. Otherwise any client could
/// create an unbounded number of label values, and so memory growth in the
/// metrics registry, just by requesting arbitrary paths. Unmatched requests
/// that do not end in 404 keep their raw path, as before.
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    mut req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    /// Path label shared by all unrouted requests that were answered with 404.
    const UNMATCHED_PATH_LABEL: &str = "<unmatched>";

    let method = req.method().to_string();
    let matched_path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map(|p| p.as_str().to_owned());
    let raw_path = req.uri().path().to_owned();
    let start = std::time::Instant::now();

    req.extensions_mut().insert(Arc::clone(&metrics));
    let response = next.run(req).await;

    let status_code = response.status();
    let path = match matched_path {
        Some(template) => template,
        None if status_code == axum::http::StatusCode::NOT_FOUND => {
            UNMATCHED_PATH_LABEL.to_owned()
        }
        None => raw_path,
    };
    let status = status_code.as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2976
2977async fn security_headers_middleware(
2989 is_tls: bool,
2990 cfg: Arc<SecurityHeadersConfig>,
2991 req: Request<Body>,
2992 next: Next,
2993) -> axum::response::Response {
2994 use axum::http::{HeaderName, header};
2995
2996 let mut resp = next.run(req).await;
2997 let headers = resp.headers_mut();
2998
2999 headers.remove(header::SERVER);
3001 headers.remove(HeaderName::from_static("x-powered-by"));
3002
3003 apply_security_header(
3004 headers,
3005 header::X_CONTENT_TYPE_OPTIONS,
3006 cfg.x_content_type_options.as_deref(),
3007 "nosniff",
3008 );
3009 apply_security_header(
3010 headers,
3011 header::X_FRAME_OPTIONS,
3012 cfg.x_frame_options.as_deref(),
3013 "deny",
3014 );
3015 apply_security_header(
3016 headers,
3017 header::CACHE_CONTROL,
3018 cfg.cache_control.as_deref(),
3019 "no-store, max-age=0",
3020 );
3021 apply_security_header(
3022 headers,
3023 header::REFERRER_POLICY,
3024 cfg.referrer_policy.as_deref(),
3025 "no-referrer",
3026 );
3027 apply_security_header(
3028 headers,
3029 HeaderName::from_static("cross-origin-opener-policy"),
3030 cfg.cross_origin_opener_policy.as_deref(),
3031 "same-origin",
3032 );
3033 apply_security_header(
3034 headers,
3035 HeaderName::from_static("cross-origin-resource-policy"),
3036 cfg.cross_origin_resource_policy.as_deref(),
3037 "same-origin",
3038 );
3039 apply_security_header(
3040 headers,
3041 HeaderName::from_static("cross-origin-embedder-policy"),
3042 cfg.cross_origin_embedder_policy.as_deref(),
3043 "require-corp",
3044 );
3045 apply_security_header(
3046 headers,
3047 HeaderName::from_static("permissions-policy"),
3048 cfg.permissions_policy.as_deref(),
3049 "accelerometer=(), camera=(), geolocation=(), microphone=()",
3050 );
3051 apply_security_header(
3052 headers,
3053 HeaderName::from_static("x-permitted-cross-domain-policies"),
3054 cfg.x_permitted_cross_domain_policies.as_deref(),
3055 "none",
3056 );
3057 apply_security_header(
3058 headers,
3059 HeaderName::from_static("content-security-policy"),
3060 cfg.content_security_policy.as_deref(),
3061 "default-src 'none'; frame-ancestors 'none'",
3062 );
3063 apply_security_header(
3064 headers,
3065 HeaderName::from_static("x-dns-prefetch-control"),
3066 cfg.x_dns_prefetch_control.as_deref(),
3067 "off",
3068 );
3069
3070 if is_tls {
3071 apply_security_header(
3072 headers,
3073 header::STRICT_TRANSPORT_SECURITY,
3074 cfg.strict_transport_security.as_deref(),
3075 "max-age=63072000; includeSubDomains",
3076 );
3077 }
3078
3079 resp
3080}
3081
3082fn apply_security_header(
3093 headers: &mut axum::http::HeaderMap,
3094 name: axum::http::HeaderName,
3095 override_value: Option<&str>,
3096 default: &'static str,
3097) {
3098 use axum::http::HeaderValue;
3099
3100 match override_value {
3101 None => {
3102 headers.insert(name, HeaderValue::from_static(default));
3103 }
3104 Some("") => {
3105 }
3107 Some(v) => match HeaderValue::from_str(v) {
3108 Ok(hv) => {
3109 headers.insert(name, hv);
3110 }
3111 Err(err) => {
3112 tracing::error!(
3113 header = %name,
3114 error = %err,
3115 "invalid security header override reached middleware; using default"
3116 );
3117 headers.insert(name, HeaderValue::from_static(default));
3118 }
3119 },
3120 }
3121}
3122
3123fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
3134 use axum::http::HeaderValue;
3135
3136 let fields: &[(&str, Option<&str>)] = &[
3137 (
3138 "x_content_type_options",
3139 cfg.x_content_type_options.as_deref(),
3140 ),
3141 ("x_frame_options", cfg.x_frame_options.as_deref()),
3142 ("cache_control", cfg.cache_control.as_deref()),
3143 ("referrer_policy", cfg.referrer_policy.as_deref()),
3144 (
3145 "cross_origin_opener_policy",
3146 cfg.cross_origin_opener_policy.as_deref(),
3147 ),
3148 (
3149 "cross_origin_resource_policy",
3150 cfg.cross_origin_resource_policy.as_deref(),
3151 ),
3152 (
3153 "cross_origin_embedder_policy",
3154 cfg.cross_origin_embedder_policy.as_deref(),
3155 ),
3156 ("permissions_policy", cfg.permissions_policy.as_deref()),
3157 (
3158 "x_permitted_cross_domain_policies",
3159 cfg.x_permitted_cross_domain_policies.as_deref(),
3160 ),
3161 (
3162 "content_security_policy",
3163 cfg.content_security_policy.as_deref(),
3164 ),
3165 (
3166 "x_dns_prefetch_control",
3167 cfg.x_dns_prefetch_control.as_deref(),
3168 ),
3169 (
3170 "strict_transport_security",
3171 cfg.strict_transport_security.as_deref(),
3172 ),
3173 ];
3174
3175 for (field, value) in fields {
3176 let Some(v) = value else { continue };
3177 if v.is_empty() {
3178 continue;
3179 }
3180 if let Err(err) = HeaderValue::from_str(v) {
3181 return Err(McpxError::Config(format!(
3182 "invalid security_headers.{field}: {err}"
3183 )));
3184 }
3185 }
3186
3187 if let Some(v) = cfg.strict_transport_security.as_deref()
3188 && !v.is_empty()
3189 && v.to_ascii_lowercase().contains("preload")
3190 {
3191 return Err(McpxError::Config(format!(
3192 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
3193 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
3194 )));
3195 }
3196
3197 Ok(())
3198}
3199
#[cfg(feature = "oauth")]
/// Adds anti-caching headers to OAuth token responses.
///
/// After the inner service responds, this sets `Pragma: no-cache`, replacing
/// any existing `Pragma`. It also appends `Authorization` to `Vary`; existing
/// `Vary` values are kept, because `append` adds rather than replaces.
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;

    let pragma = HeaderValue::from_static("no-cache");
    let vary = HeaderValue::from_static("Authorization");

    response.headers_mut().insert(header::PRAGMA, pragma);
    response.headers_mut().append(header::VARY, vary);

    response
}
3227
3228async fn normalize_peer_addr_middleware(
3257 resolver: Option<Arc<ForwardResolver>>,
3258 mut req: Request<Body>,
3259 next: Next,
3260) -> axum::response::Response {
3261 let direct = req
3262 .extensions()
3263 .get::<ConnectInfo<SocketAddr>>()
3264 .map(|ci| ci.0);
3265 let from_tls = req
3266 .extensions()
3267 .get::<ConnectInfo<TlsConnInfo>>()
3268 .map(|ci| ci.0.addr);
3269 if let Some(addr) = direct.or(from_tls) {
3270 if direct.is_none() {
3271 req.extensions_mut().insert(ConnectInfo(addr));
3272 }
3273 req.extensions_mut().insert(PeerAddr::new(addr));
3274 let client_ip = match &resolver {
3275 Some(r) => {
3276 crate::forwarded::resolve_client_ip(addr.ip(), req.headers(), &r.trusted, r.mode)
3277 .unwrap_or_else(|reason| {
3278 tracing::debug!(
3279 reason = ?reason,
3280 "forwarded-header resolution fell back to direct peer"
3281 );
3282 addr.ip()
3283 })
3284 }
3285 None => addr.ip(),
3286 };
3287 req.extensions_mut().insert(ClientIp::new(client_ip));
3288 }
3289 next.run(req).await
3290}
3291
3292fn parse_proxy_net(entry: &str) -> Option<ipnet::IpNet> {
3295 if let Ok(net) = entry.parse::<ipnet::IpNet>() {
3296 return Some(net);
3297 }
3298 entry.parse::<IpAddr>().ok().map(ipnet::IpNet::from)
3299}
3300
3301pub(crate) fn limiter_client_ip(extensions: &axum::http::Extensions) -> Option<IpAddr> {
3305 if let Some(client) = extensions.get::<ClientIp>() {
3306 return Some(client.ip);
3307 }
3308 extensions
3309 .get::<ConnectInfo<SocketAddr>>()
3310 .map(|ci| ci.0.ip())
3311 .or_else(|| {
3312 extensions
3313 .get::<ConnectInfo<TlsConnInfo>>()
3314 .map(|ci| ci.0.addr.ip())
3315 })
3316}
3317
/// Keyed rate limiter for application ("extra") routes, keyed by client IP.
///
/// Built by `build_extra_route_rate_limiter` and consulted by
/// `extra_route_rate_limit_middleware` with the IP from `limiter_client_ip`.
pub(crate) type ExtraRouteRateLimiter = BoundedKeyedLimiter<IpAddr>;
3322
/// Key-count bound passed to `BoundedKeyedLimiter::new` for the extra-route
/// limiter. It presumably caps how many distinct client IPs are tracked at
/// once; confirm the overflow policy in `BoundedKeyedLimiter`.
const EXTRA_ROUTE_MAX_TRACKED_KEYS: usize = 10_000;

/// Idle window (15 minutes) passed to `BoundedKeyedLimiter::new`. It is
/// presumably how long a key may sit unused before it becomes evictable;
/// confirm against `BoundedKeyedLimiter`.
const EXTRA_ROUTE_IDLE_EVICTION: Duration = Duration::from_secs(15 * 60);
3333
3334fn build_extra_route_rate_limiter(
3341 per_minute: u32,
3342 burst: Option<u32>,
3343) -> Arc<ExtraRouteRateLimiter> {
3344 let rate = std::num::NonZeroU32::new(per_minute.max(1)).unwrap_or(std::num::NonZeroU32::MIN);
3345 let mut quota = governor::Quota::per_minute(rate);
3346 if let Some(b) = burst.and_then(std::num::NonZeroU32::new) {
3347 quota = quota.allow_burst(b);
3348 }
3349 Arc::new(BoundedKeyedLimiter::new(
3350 quota,
3351 EXTRA_ROUTE_MAX_TRACKED_KEYS,
3352 EXTRA_ROUTE_IDLE_EVICTION,
3353 ))
3354}
3355
3356async fn extra_route_rate_limit_middleware(
3378 limiter: Arc<ExtraRouteRateLimiter>,
3379 exempt: Arc<std::collections::HashSet<String>>,
3380 req: Request<Body>,
3381 next: Next,
3382) -> axum::response::Response {
3383 if exempt.contains(req.uri().path()) {
3384 return next.run(req).await;
3385 }
3386 let peer_ip: Option<IpAddr> = limiter_client_ip(req.extensions());
3387 if let Some(ip) = peer_ip
3388 && let Err(wait) = limiter.check_key_wait(&ip)
3389 {
3390 #[cfg(feature = "metrics")]
3391 crate::metrics::record_rate_limit_deny(req.extensions(), "extra_route");
3392 tracing::warn!(%ip, "extra route request rate limited");
3393 return McpxError::RateLimitedFor {
3394 message: "too many requests to application routes from this source".into(),
3395 retry_after: wait,
3396 }
3397 .into_response();
3398 }
3399 next.run(req).await
3400}
3401
3402async fn origin_check_middleware(
3406 allowed: Arc<[String]>,
3407 log_request_headers: bool,
3408 req: Request<Body>,
3409 next: Next,
3410) -> axum::response::Response {
3411 let method = req.method().clone();
3412 let path = req.uri().path().to_owned();
3413
3414 log_incoming_request(&method, &path, req.headers(), log_request_headers);
3415
3416 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
3417 let origin_str = origin.to_str().unwrap_or("");
3418 if !allowed.iter().any(|a| a == origin_str) {
3419 tracing::warn!(
3420 origin = origin_str,
3421 %method,
3422 %path,
3423 allowed = ?&*allowed,
3424 "rejected request: Origin not allowed"
3425 );
3426 return (
3427 axum::http::StatusCode::FORBIDDEN,
3428 "Forbidden: Origin not allowed",
3429 )
3430 .into_response();
3431 }
3432 }
3433 next.run(req).await
3434}
3435
3436fn log_incoming_request(
3439 method: &axum::http::Method,
3440 path: &str,
3441 headers: &axum::http::HeaderMap,
3442 log_request_headers: bool,
3443) {
3444 if log_request_headers {
3445 tracing::debug!(
3446 %method,
3447 %path,
3448 headers = %format_request_headers_for_log(headers),
3449 "incoming request"
3450 );
3451 } else {
3452 tracing::debug!(%method, %path, "incoming request");
3453 }
3454}
3455
3456fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
3457 headers
3458 .iter()
3459 .map(|(k, v)| {
3460 let name = k.as_str();
3461 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
3462 format!("{name}: [REDACTED]")
3463 } else {
3464 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
3465 }
3466 })
3467 .collect::<Vec<_>>()
3468 .join(", ")
3469}
3470
3471#[allow(
3495 clippy::cognitive_complexity,
3496 reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
3497)]
3498pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
3499where
3500 H: ServerHandler + 'static,
3501{
3502 use rmcp::ServiceExt as _;
3503
3504 tracing::info!("stdio transport: serving on stdin/stdout");
3505 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
3506
3507 let transport = rmcp::transport::io::stdio();
3508
3509 let service = handler
3510 .serve(transport)
3511 .await
3512 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
3513
3514 if let Err(e) = service.waiting().await {
3515 tracing::warn!(error = %e, "stdio session ended with error");
3516 }
3517 tracing::info!("stdio session ended");
3518 Ok(())
3519}
3520
3521#[cfg(test)]
3522mod tests {
3523 #![allow(
3524 clippy::unwrap_used,
3525 clippy::expect_used,
3526 clippy::panic,
3527 clippy::indexing_slicing,
3528 clippy::unwrap_in_result,
3529 clippy::print_stdout,
3530 clippy::print_stderr,
3531 deprecated,
3532 reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
3533 )]
3534 use std::{sync::Arc, time::Duration};
3535
3536 use axum::{
3537 body::Body,
3538 http::{Request, StatusCode, header},
3539 response::IntoResponse,
3540 };
3541 use http_body_util::BodyExt;
3542 use tower::ServiceExt as _;
3543
3544 use super::*;
3545
3546 #[test]
3549 fn server_config_new_defaults() {
3550 let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
3551 assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
3552 assert_eq!(cfg.name, "test-server");
3553 assert_eq!(cfg.version, "1.0.0");
3554 assert!(cfg.tls_cert_path.is_none());
3555 assert!(cfg.tls_key_path.is_none());
3556 assert!(cfg.auth.is_none());
3557 assert!(cfg.rbac.is_none());
3558 assert!(cfg.allowed_origins.is_empty());
3559 assert!(cfg.tool_rate_limit.is_none());
3560 assert!(cfg.readiness_check.is_none());
3561 assert_eq!(cfg.max_request_body, 1024 * 1024);
3562 assert_eq!(cfg.request_timeout, Duration::from_mins(2));
3563 assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
3564 assert!(!cfg.log_request_headers);
3565 assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
3566 assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
3567 }
3568
3569 #[test]
3570 fn tls_handshake_builders_set_fields() {
3571 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3572 .with_tls_handshake_timeout(Duration::from_secs(3))
3573 .with_max_concurrent_tls_handshakes(64);
3574 assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
3575 assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
3576 }
3577
3578 #[test]
3579 fn validate_rejects_zero_tls_handshake_timeout() {
3580 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3581 .with_tls_handshake_timeout(Duration::ZERO);
3582 let err = cfg.validate().expect_err("zero handshake timeout");
3583 assert!(err.to_string().contains("tls_handshake_timeout"));
3584 }
3585
3586 #[test]
3587 fn validate_rejects_zero_max_concurrent_tls_handshakes() {
3588 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3589 .with_max_concurrent_tls_handshakes(0);
3590 let err = cfg.validate().expect_err("zero handshake concurrency");
3591 assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
3592 }
3593
3594 #[test]
3595 fn validate_consumes_and_proves() {
3596 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3598 let validated = cfg.validate().expect("valid config");
3599 assert_eq!(validated.as_inner().name, "test-server");
3601 let raw = validated.into_inner();
3603 assert_eq!(raw.name, "test-server");
3604
3605 let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3607 bad.max_request_body = 0;
3608 assert!(bad.validate().is_err(), "zero body cap must fail validate");
3609 }
3610
3611 #[test]
3612 fn validate_rejects_zero_max_concurrent_requests() {
3613 let cfg =
3614 McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
3615 let err = cfg.validate().expect_err("zero concurrency cap must fail");
3616 assert!(
3617 format!("{err}").contains("max_concurrent_requests"),
3618 "error should mention max_concurrent_requests, got: {err}"
3619 );
3620 }
3621
3622 #[test]
3623 fn validate_rejects_zero_max_tracked_keys() {
3624 let rl = crate::auth::RateLimitConfig {
3627 max_attempts_per_minute: 30,
3628 pre_auth_max_per_minute: None,
3629 max_tracked_keys: 0,
3630 idle_eviction: Duration::from_secs(15 * 60),
3631 burst: None,
3632 pre_auth_burst: None,
3633 };
3634 let auth_cfg = AuthConfig {
3635 enabled: true,
3636 api_keys: Vec::new(),
3637 mtls: None,
3638 rate_limit: Some(rl),
3639 #[cfg(feature = "oauth")]
3640 oauth: None,
3641 };
3642 let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
3643 let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
3644 assert!(
3645 format!("{err}").contains("max_tracked_keys"),
3646 "error should mention max_tracked_keys, got: {err}"
3647 );
3648 }
3649
3650 #[test]
3651 fn derive_allowed_hosts_includes_public_host() {
3652 let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
3653 assert!(
3654 hosts.iter().any(|h| h == "mcp.example.com"),
3655 "public_url host must be allowed"
3656 );
3657 }
3658
3659 #[test]
3660 fn derive_allowed_hosts_includes_bind_authority() {
3661 let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
3662 assert!(
3663 hosts.iter().any(|h| h == "127.0.0.1"),
3664 "bind host must be allowed"
3665 );
3666 assert!(
3667 hosts.iter().any(|h| h == "127.0.0.1:8080"),
3668 "bind authority must be allowed"
3669 );
3670 }
3671
3672 #[tokio::test]
3675 async fn healthz_returns_ok_json() {
3676 let resp = healthz().await.into_response();
3677 assert_eq!(resp.status(), StatusCode::OK);
3678 let body = resp.into_body().collect().await.unwrap().to_bytes();
3679 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3680 assert_eq!(json["status"], "ok");
3681 assert!(
3682 json.get("name").is_none(),
3683 "healthz must not expose server name"
3684 );
3685 assert!(
3686 json.get("version").is_none(),
3687 "healthz must not expose version"
3688 );
3689 }
3690
3691 #[tokio::test]
3694 async fn readyz_returns_ok_when_ready() {
3695 let check: ReadinessCheck =
3696 Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
3697 let resp = readyz(check).await.into_response();
3698 assert_eq!(resp.status(), StatusCode::OK);
3699 let body = resp.into_body().collect().await.unwrap().to_bytes();
3700 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3701 assert_eq!(json["ready"], true);
3702 assert!(
3703 json.get("name").is_none(),
3704 "readyz must not expose server name"
3705 );
3706 assert!(
3707 json.get("version").is_none(),
3708 "readyz must not expose version"
3709 );
3710 assert_eq!(json["db"], "connected");
3711 }
3712
3713 #[tokio::test]
3714 async fn readyz_returns_503_when_not_ready() {
3715 let check: ReadinessCheck =
3716 Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
3717 let resp = readyz(check).await.into_response();
3718 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3719 }
3720
3721 #[tokio::test]
3722 async fn readyz_returns_503_when_ready_missing() {
3723 let check: ReadinessCheck =
3724 Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
3725 let resp = readyz(check).await.into_response();
3726 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3728 }
3729
3730 fn peer_probe_router() -> axum::Router {
3735 async fn probe(req: Request<Body>) -> String {
3736 let ci = req
3737 .extensions()
3738 .get::<ConnectInfo<SocketAddr>>()
3739 .map(|c| c.0.to_string())
3740 .unwrap_or_default();
3741 let pa = req
3742 .extensions()
3743 .get::<PeerAddr>()
3744 .map(|p| p.addr.to_string())
3745 .unwrap_or_default();
3746 format!("{ci}|{pa}")
3747 }
3748 axum::Router::new()
3749 .route("/probe", axum::routing::get(probe))
3750 .layer(axum::middleware::from_fn(|req, next| {
3751 normalize_peer_addr_middleware(None, req, next)
3752 }))
3753 }
3754
3755 async fn body_string(resp: axum::response::Response) -> String {
3756 let bytes = resp.into_body().collect().await.unwrap().to_bytes();
3757 String::from_utf8(bytes.to_vec()).unwrap()
3758 }
3759
3760 #[tokio::test]
3761 async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
3762 let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
3765 let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
3766 let req = Request::builder()
3767 .uri("/probe")
3768 .extension(ConnectInfo(plain))
3769 .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3770 .body(Body::empty())
3771 .unwrap();
3772 let resp = peer_probe_router().oneshot(req).await.unwrap();
3773 assert_eq!(resp.status(), StatusCode::OK);
3774 assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
3775 }
3776
3777 #[tokio::test]
3778 async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
3779 let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
3780 let req = Request::builder()
3781 .uri("/probe")
3782 .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3783 .body(Body::empty())
3784 .unwrap();
3785 let resp = peer_probe_router().oneshot(req).await.unwrap();
3786 assert_eq!(resp.status(), StatusCode::OK);
3787 assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
3788 }
3789
3790 #[tokio::test]
3791 async fn normalize_no_op_without_any_connect_info() {
3792 let req = Request::builder()
3793 .uri("/probe")
3794 .body(Body::empty())
3795 .unwrap();
3796 let resp = peer_probe_router().oneshot(req).await.unwrap();
3797 assert_eq!(resp.status(), StatusCode::OK);
3798 assert_eq!(body_string(resp).await, "|");
3799 }
3800
3801 #[tokio::test]
3802 async fn peer_addr_extractor_rejects_when_absent() {
3803 async fn h(peer: PeerAddr) -> String {
3804 peer.addr.to_string()
3805 }
3806 let app = axum::Router::new().route("/p", axum::routing::get(h));
3807 let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
3808 let resp = app.oneshot(req).await.unwrap();
3809 assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
3810 }
3811
3812 #[tokio::test]
3813 async fn peer_addr_extractor_returns_value_when_present() {
3814 async fn h(peer: PeerAddr) -> String {
3815 peer.addr.to_string()
3816 }
3817 let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
3818 let app = axum::Router::new().route("/p", axum::routing::get(h));
3819 let req = Request::builder()
3820 .uri("/p")
3821 .extension(PeerAddr::new(addr))
3822 .body(Body::empty())
3823 .unwrap();
3824 let resp = app.oneshot(req).await.unwrap();
3825 assert_eq!(resp.status(), StatusCode::OK);
3826 assert_eq!(body_string(resp).await, addr.to_string());
3827 }
3828
3829 #[tokio::test]
3830 async fn peer_addr_via_extension_extractor() {
3831 async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
3832 peer.addr.to_string()
3833 }
3834 let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
3835 let app = axum::Router::new().route("/p", axum::routing::get(h));
3836 let req = Request::builder()
3837 .uri("/p")
3838 .extension(PeerAddr::new(addr))
3839 .body(Body::empty())
3840 .unwrap();
3841 let resp = app.oneshot(req).await.unwrap();
3842 assert_eq!(resp.status(), StatusCode::OK);
3843 assert_eq!(body_string(resp).await, addr.to_string());
3844 }
3845
3846 fn limited_router(per_minute: u32) -> axum::Router {
3851 limited_router_with_burst(per_minute, None)
3852 }
3853
3854 fn limited_router_with_burst(per_minute: u32, burst: Option<u32>) -> axum::Router {
3856 limited_router_full(per_minute, burst, &[])
3857 }
3858
3859 fn limited_router_full(
3863 per_minute: u32,
3864 burst: Option<u32>,
3865 exempt_paths: &[&str],
3866 ) -> axum::Router {
3867 let limiter = build_extra_route_rate_limiter(per_minute, burst);
3868 let exempt: Arc<std::collections::HashSet<String>> =
3869 Arc::new(exempt_paths.iter().map(|s| (*s).to_owned()).collect());
3870 axum::Router::new()
3871 .route("/limited", axum::routing::get(|| async { "ok" }))
3872 .route("/exempt", axum::routing::get(|| async { "ok" }))
3873 .layer(axum::middleware::from_fn(move |req, next| {
3874 let l = Arc::clone(&limiter);
3875 let e = Arc::clone(&exempt);
3876 extra_route_rate_limit_middleware(l, e, req, next)
3877 }))
3878 }
3879
3880 fn limited_req(ip: &str) -> Request<Body> {
3881 limited_req_to(ip, "/limited")
3882 }
3883
3884 fn limited_req_to(ip: &str, path: &str) -> Request<Body> {
3885 let addr: SocketAddr = format!("{ip}:40000").parse().unwrap();
3886 Request::builder()
3887 .uri(path)
3888 .extension(ConnectInfo(addr))
3889 .body(Body::empty())
3890 .unwrap()
3891 }
3892
3893 #[tokio::test]
3894 async fn extra_route_limiter_denies_over_quota() {
3895 let app = limited_router(2);
3896 for i in 0..2 {
3897 let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3898 assert_eq!(resp.status(), StatusCode::OK, "request {i} should pass");
3899 }
3900 let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3901 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3902 let body = body_string(resp).await;
3903 assert!(
3904 body.contains("too many requests to application routes"),
3905 "deny body should match the limiter message, got: {body}"
3906 );
3907 }
3908
3909 #[tokio::test]
3910 async fn extra_route_limiter_isolates_keys() {
3911 let app = limited_router(2);
3912 for _ in 0..2 {
3913 let resp = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3914 assert_eq!(resp.status(), StatusCode::OK);
3915 }
3916 let exhausted = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3917 assert_eq!(exhausted.status(), StatusCode::TOO_MANY_REQUESTS);
3918 let other = app.clone().oneshot(limited_req("10.3.3.3")).await.unwrap();
3920 assert_eq!(other.status(), StatusCode::OK);
3921 }
3922
3923 #[tokio::test]
3924 async fn extra_route_limiter_fails_open_without_peer() {
3925 let app = limited_router(1);
3926 for i in 0..3 {
3927 let req = Request::builder()
3928 .uri("/limited")
3929 .body(Body::empty())
3930 .unwrap();
3931 let resp = app.clone().oneshot(req).await.unwrap();
3932 assert_eq!(
3933 resp.status(),
3934 StatusCode::OK,
3935 "request {i} should fail open"
3936 );
3937 }
3938 }
3939
3940 #[tokio::test]
3941 async fn extra_route_limiter_extracts_tls_conn_info() {
3942 let app = limited_router(2);
3943 let mk = || {
3944 let addr: SocketAddr = "192.168.9.9:55555".parse().unwrap();
3945 Request::builder()
3946 .uri("/limited")
3947 .extension(ConnectInfo(TlsConnInfo::new(addr, None)))
3948 .body(Body::empty())
3949 .unwrap()
3950 };
3951 for _ in 0..2 {
3952 assert_eq!(
3953 app.clone().oneshot(mk()).await.unwrap().status(),
3954 StatusCode::OK
3955 );
3956 }
3957 let resp = app.clone().oneshot(mk()).await.unwrap();
3958 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3959 }
3960
3961 #[tokio::test]
3962 async fn extra_route_limiter_exempt_path_bypasses_quota() {
3963 let app = limited_router_full(1, None, &["/exempt"]);
3966 for i in 0..5 {
3967 let resp = app
3968 .clone()
3969 .oneshot(limited_req_to("10.6.6.6", "/exempt"))
3970 .await
3971 .unwrap();
3972 assert_eq!(resp.status(), StatusCode::OK, "exempt request {i}");
3973 }
3974 let resp = app.clone().oneshot(limited_req("10.6.6.6")).await.unwrap();
3976 assert_eq!(resp.status(), StatusCode::OK);
3977 let resp = app.clone().oneshot(limited_req("10.6.6.6")).await.unwrap();
3979 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3980 }
3981
3982 #[tokio::test]
3983 async fn extra_route_limiter_exemption_is_raw_exact_match() {
3984 let app = limited_router_full(1, None, &["/exempt"]);
3987 let ok = app
3988 .clone()
3989 .oneshot(limited_req_to("10.7.7.7", "/exempt/"))
3990 .await
3991 .unwrap();
3992 assert_eq!(
3993 ok.status(),
3994 StatusCode::NOT_FOUND,
3995 "variant path routes 404"
3996 );
3997 let denied = app
3999 .clone()
4000 .oneshot(limited_req_to("10.7.7.7", "/limited"))
4001 .await
4002 .unwrap();
4003 assert_eq!(denied.status(), StatusCode::TOO_MANY_REQUESTS);
4004 }
4005
4006 #[cfg(feature = "metrics")]
4007 #[tokio::test]
4008 async fn extra_route_limiter_deny_increments_counter_exempt_does_not() {
4009 let metrics = Arc::new(crate::metrics::McpMetrics::new().unwrap());
4010 let app = limited_router_full(1, None, &["/exempt"]);
4011 let mk = |path: &str| {
4012 let addr: SocketAddr = "10.8.8.8:40000".parse().unwrap();
4013 Request::builder()
4014 .uri(path)
4015 .extension(ConnectInfo(addr))
4016 .extension(Arc::clone(&metrics))
4017 .body(Body::empty())
4018 .unwrap()
4019 };
4020 let counter = || {
4021 metrics
4022 .rate_limited_total
4023 .with_label_values(&["extra_route"])
4024 .get()
4025 };
4026 for _ in 0..3 {
4028 assert_eq!(
4029 app.clone().oneshot(mk("/exempt")).await.unwrap().status(),
4030 StatusCode::OK
4031 );
4032 }
4033 assert_eq!(counter(), 0, "exempt requests must not count as denies");
4034 assert_eq!(
4036 app.clone().oneshot(mk("/limited")).await.unwrap().status(),
4037 StatusCode::OK
4038 );
4039 assert_eq!(counter(), 0);
4040 assert_eq!(
4041 app.clone().oneshot(mk("/limited")).await.unwrap().status(),
4042 StatusCode::TOO_MANY_REQUESTS
4043 );
4044 assert_eq!(counter(), 1, "deny must increment the extra_route label");
4045 }
4046
4047 #[test]
4048 fn validate_rejects_exempt_paths_without_base_knob() {
4049 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4050 .with_extra_route_rate_limit_exempt_paths(["/ok"]);
4051 let err = cfg.validate().expect_err("exempt paths without rate limit");
4052 assert!(err.to_string().contains("requires extra_route_rate_limit"));
4053 }
4054
4055 #[test]
4056 fn validate_rejects_malformed_exempt_paths() {
4057 for bad in ["", "no-slash"] {
4058 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4059 .with_extra_route_rate_limit(10)
4060 .with_extra_route_rate_limit_exempt_paths([bad]);
4061 let err = cfg.validate().expect_err("malformed exempt path");
4062 assert!(
4063 err.to_string()
4064 .contains("must be non-empty and start with '/'"),
4065 "entry {bad:?}: {err}"
4066 );
4067 }
4068 }
4069
4070 #[test]
4071 fn validate_accepts_wellformed_exempt_paths() {
4072 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4073 .with_extra_route_rate_limit(10)
4074 .with_extra_route_rate_limit_exempt_paths(["/.well-known/oauth-authorization-server"]);
4075 assert!(cfg.validate().is_ok());
4076 }
4077
4078 #[test]
4079 fn validate_rejects_zero_extra_route_rate_limit() {
4080 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4081 .with_extra_route_rate_limit(0);
4082 let err = cfg.validate().expect_err("zero extra route rate limit");
4083 assert!(err.to_string().contains("extra_route_rate_limit"));
4084 }
4085
4086 #[tokio::test]
4087 async fn extra_route_limiter_burst_allows_initial_spike() {
4088 let app = limited_router_with_burst(1, Some(3));
4089 for i in 0..3 {
4090 let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
4091 assert_eq!(resp.status(), StatusCode::OK, "burst request {i}");
4092 }
4093 let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
4094 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
4095 }
4096
4097 #[tokio::test]
4098 async fn extra_route_limiter_deny_sets_retry_after() {
4099 let app = limited_router(1);
4100 let ok = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
4101 assert_eq!(ok.status(), StatusCode::OK);
4102 let denied = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
4103 assert_eq!(denied.status(), StatusCode::TOO_MANY_REQUESTS);
4104 let retry_after = denied
4105 .headers()
4106 .get(header::RETRY_AFTER)
4107 .expect("Retry-After present")
4108 .to_str()
4109 .unwrap()
4110 .parse::<u64>()
4111 .unwrap();
4112 assert!(retry_after >= 1, "delta-seconds must be >= 1");
4113 }
4114
4115 #[test]
4116 fn validate_rejects_zero_burst_knobs() {
4117 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4118 .with_tool_rate_limit(10)
4119 .with_tool_rate_limit_burst(0)
4120 .validate()
4121 .expect_err("zero tool burst");
4122 assert!(err.to_string().contains("tool_rate_limit_burst"));
4123
4124 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4125 .with_extra_route_rate_limit(10)
4126 .with_extra_route_rate_limit_burst(0)
4127 .validate()
4128 .expect_err("zero extra route burst");
4129 assert!(err.to_string().contains("extra_route_rate_limit_burst"));
4130 }
4131
4132 #[test]
4133 fn validate_rejects_orphan_burst_knobs() {
4134 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4135 .with_tool_rate_limit_burst(5)
4136 .validate()
4137 .expect_err("orphan tool burst");
4138 assert!(err.to_string().contains("requires tool_rate_limit"));
4139
4140 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4141 .with_extra_route_rate_limit_burst(5)
4142 .validate()
4143 .expect_err("orphan extra route burst");
4144 assert!(err.to_string().contains("requires extra_route_rate_limit"));
4145 }
4146
4147 #[test]
4148 fn validate_rejects_zero_auth_bursts() {
4149 let auth = AuthConfig::with_keys(vec![])
4150 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_burst(0));
4151 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4152 .with_auth(auth)
4153 .validate()
4154 .expect_err("zero auth burst");
4155 assert!(err.to_string().contains("rate_limit.burst"));
4156
4157 let auth = AuthConfig::with_keys(vec![])
4158 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0));
4159 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4160 .with_auth(auth)
4161 .validate()
4162 .expect_err("zero pre-auth burst");
4163 assert!(err.to_string().contains("pre_auth_burst"));
4164 }
4165
4166 #[test]
4169 fn validate_accepts_pre_auth_burst_without_explicit_pre_auth_rate() {
4170 let auth = AuthConfig::with_keys(vec![])
4171 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(50));
4172 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_auth(auth);
4173 assert!(cfg.validate().is_ok(), "pre_auth_burst has no orphan rule");
4174 }
4175
4176 fn forward_resolver(trusted: &[&str], mode: ForwardedHeaderMode) -> Arc<ForwardResolver> {
4179 Arc::new(ForwardResolver {
4180 trusted: trusted.iter().map(|s| s.parse().unwrap()).collect(),
4181 mode,
4182 })
4183 }
4184
4185 fn forwarded_probe_router(resolver: Option<Arc<ForwardResolver>>) -> axum::Router {
4187 async fn probe(req: Request<Body>) -> String {
4188 let pa = req
4189 .extensions()
4190 .get::<PeerAddr>()
4191 .map(|p| p.addr.ip().to_string())
4192 .unwrap_or_default();
4193 let ci = req
4194 .extensions()
4195 .get::<ClientIp>()
4196 .map(|c| c.ip.to_string())
4197 .unwrap_or_default();
4198 format!("{pa}|{ci}")
4199 }
4200 axum::Router::new()
4201 .route("/probe", axum::routing::get(probe))
4202 .layer(axum::middleware::from_fn(move |req, next| {
4203 let r = resolver.clone();
4204 normalize_peer_addr_middleware(r, req, next)
4205 }))
4206 }
4207
4208 fn probe_req(peer: &str, header: Option<(&str, &str)>) -> Request<Body> {
4209 let addr: SocketAddr = peer.parse().unwrap();
4210 let mut builder = Request::builder()
4211 .uri("/probe")
4212 .extension(ConnectInfo(addr));
4213 if let Some((name, value)) = header {
4214 builder = builder.header(name, value);
4215 }
4216 builder.body(Body::empty()).unwrap()
4217 }
4218
4219 #[tokio::test]
4220 async fn client_ip_equals_direct_without_resolver() {
4221 let app = forwarded_probe_router(None);
4222 let resp = app
4223 .oneshot(probe_req(
4224 "10.1.2.3:4444",
4225 Some(("x-forwarded-for", "203.0.113.7")),
4226 ))
4227 .await
4228 .unwrap();
4229 assert_eq!(
4230 body_string(resp).await,
4231 "10.1.2.3|10.1.2.3",
4232 "feature off: header ignored, ClientIp == direct"
4233 );
4234 }
4235
4236 #[tokio::test]
4237 async fn client_ip_resolved_for_trusted_peer() {
4238 let app = forwarded_probe_router(Some(forward_resolver(
4239 &["10.0.0.0/8"],
4240 ForwardedHeaderMode::XForwardedFor,
4241 )));
4242 let resp = app
4243 .oneshot(probe_req(
4244 "10.0.0.1:9999",
4245 Some(("x-forwarded-for", "203.0.113.7")),
4246 ))
4247 .await
4248 .unwrap();
4249 assert_eq!(
4250 body_string(resp).await,
4251 "10.0.0.1|203.0.113.7",
4252 "PeerAddr stays direct while ClientIp resolves"
4253 );
4254 }
4255
4256 #[tokio::test]
4257 async fn client_ip_falls_back_to_direct_on_malformed_header() {
4258 let app = forwarded_probe_router(Some(forward_resolver(
4259 &["10.0.0.0/8"],
4260 ForwardedHeaderMode::XForwardedFor,
4261 )));
4262 let resp = app
4263 .oneshot(probe_req(
4264 "10.0.0.1:9999",
4265 Some(("x-forwarded-for", "not-an-ip")),
4266 ))
4267 .await
4268 .unwrap();
4269 assert_eq!(
4270 body_string(resp).await,
4271 "10.0.0.1|10.0.0.1",
4272 "malformed chain falls back to the direct peer"
4273 );
4274 }
4275
4276 #[test]
4277 fn forwarded_header_mode_deserializes_kebab_case() {
4278 #[derive(serde::Deserialize)]
4279 struct Wrapper {
4280 mode: ForwardedHeaderMode,
4281 }
4282 let w: Wrapper = toml::from_str(r#"mode = "x-forwarded-for""#).unwrap();
4283 assert_eq!(w.mode, ForwardedHeaderMode::XForwardedFor);
4284 let w: Wrapper = toml::from_str(r#"mode = "forwarded""#).unwrap();
4285 assert_eq!(w.mode, ForwardedHeaderMode::Forwarded);
4286 assert!(
4287 toml::from_str::<Wrapper>(r#"mode = "XForwardedFor""#).is_err(),
4288 "PascalCase wire value must be rejected"
4289 );
4290 }
4291
4292 #[test]
4293 fn validate_rejects_bad_trusted_proxy_entry() {
4294 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4295 .with_trusted_proxies(["not-a-cidr"]);
4296 let err = cfg.validate().expect_err("bad CIDR");
4297 assert!(err.to_string().contains("trusted_proxies"));
4298 }
4299
4300 #[test]
4301 fn validate_accepts_cidr_and_bare_ip_proxy_entries() {
4302 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_trusted_proxies([
4303 "10.0.0.0/8",
4304 "192.0.2.1",
4305 "2001:db8::1",
4306 ]);
4307 assert!(cfg.validate().is_ok(), "CIDRs and bare IPs are accepted");
4308 }
4309
4310 #[test]
4311 fn validate_rejects_forwarded_header_without_proxies() {
4312 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4313 .with_forwarded_header(ForwardedHeaderMode::Forwarded);
4314 let err = cfg.validate().expect_err("mode without proxies");
4315 assert!(err.to_string().contains("requires trusted_proxies"));
4316 }
4317
4318 fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
4322 let allowed: Arc<[String]> = Arc::from(origins);
4323 axum::Router::new()
4324 .route("/test", axum::routing::get(|| async { "ok" }))
4325 .layer(axum::middleware::from_fn(move |req, next| {
4326 let a = Arc::clone(&allowed);
4327 origin_check_middleware(a, log_request_headers, req, next)
4328 }))
4329 }
4330
4331 #[tokio::test]
4332 async fn origin_allowed_passes() {
4333 let app = origin_router(vec!["http://localhost:3000".into()], false);
4334 let req = Request::builder()
4335 .uri("/test")
4336 .header(header::ORIGIN, "http://localhost:3000")
4337 .body(Body::empty())
4338 .unwrap();
4339 let resp = app.oneshot(req).await.unwrap();
4340 assert_eq!(resp.status(), StatusCode::OK);
4341 }
4342
4343 #[tokio::test]
4344 async fn origin_rejected_returns_403() {
4345 let app = origin_router(vec!["http://localhost:3000".into()], false);
4346 let req = Request::builder()
4347 .uri("/test")
4348 .header(header::ORIGIN, "http://evil.com")
4349 .body(Body::empty())
4350 .unwrap();
4351 let resp = app.oneshot(req).await.unwrap();
4352 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
4353 }
4354
4355 #[tokio::test]
4356 async fn no_origin_header_passes() {
4357 let app = origin_router(vec!["http://localhost:3000".into()], false);
4358 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4359 let resp = app.oneshot(req).await.unwrap();
4360 assert_eq!(resp.status(), StatusCode::OK);
4361 }
4362
4363 #[tokio::test]
4364 async fn empty_allowlist_rejects_any_origin() {
4365 let app = origin_router(vec![], false);
4366 let req = Request::builder()
4367 .uri("/test")
4368 .header(header::ORIGIN, "http://anything.com")
4369 .body(Body::empty())
4370 .unwrap();
4371 let resp = app.oneshot(req).await.unwrap();
4372 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
4373 }
4374
4375 #[tokio::test]
4376 async fn empty_allowlist_passes_without_origin() {
4377 let app = origin_router(vec![], false);
4378 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4379 let resp = app.oneshot(req).await.unwrap();
4380 assert_eq!(resp.status(), StatusCode::OK);
4381 }
4382
4383 #[test]
4384 fn format_request_headers_redacts_sensitive_values() {
4385 let mut headers = axum::http::HeaderMap::new();
4386 headers.insert("authorization", "Bearer secret-token".parse().unwrap());
4387 headers.insert("cookie", "sid=abc".parse().unwrap());
4388 headers.insert("x-request-id", "req-123".parse().unwrap());
4389
4390 let out = format_request_headers_for_log(&headers);
4391 assert!(out.contains("authorization: [REDACTED]"));
4392 assert!(out.contains("cookie: [REDACTED]"));
4393 assert!(out.contains("x-request-id: req-123"));
4394 assert!(!out.contains("secret-token"));
4395 }
4396
4397 fn security_router(is_tls: bool) -> axum::Router {
4400 security_router_with(is_tls, SecurityHeadersConfig::default())
4401 }
4402
4403 fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
4404 let cfg = Arc::new(cfg);
4405 axum::Router::new()
4406 .route("/test", axum::routing::get(|| async { "ok" }))
4407 .layer(axum::middleware::from_fn(move |req, next| {
4408 let c = Arc::clone(&cfg);
4409 security_headers_middleware(is_tls, c, req, next)
4410 }))
4411 }
4412
4413 #[tokio::test]
4414 async fn security_headers_set_on_response() {
4415 let app = security_router(false);
4416 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4417 let resp = app.oneshot(req).await.unwrap();
4418 assert_eq!(resp.status(), StatusCode::OK);
4419
4420 let h = resp.headers();
4421 assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
4422 assert_eq!(h.get("x-frame-options").unwrap(), "deny");
4423 assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
4424 assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
4425 assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
4426 assert_eq!(
4427 h.get("cross-origin-resource-policy").unwrap(),
4428 "same-origin"
4429 );
4430 assert_eq!(
4431 h.get("cross-origin-embedder-policy").unwrap(),
4432 "require-corp"
4433 );
4434 assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
4435 assert!(
4436 h.get("permissions-policy")
4437 .unwrap()
4438 .to_str()
4439 .unwrap()
4440 .contains("camera=()"),
4441 "permissions-policy must restrict browser features"
4442 );
4443 assert_eq!(
4444 h.get("content-security-policy").unwrap(),
4445 "default-src 'none'; frame-ancestors 'none'"
4446 );
4447 assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
4448 assert!(h.get("strict-transport-security").is_none());
4450 }
4451
4452 #[tokio::test]
4453 async fn hsts_set_when_tls_enabled() {
4454 let app = security_router(true);
4455 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4456 let resp = app.oneshot(req).await.unwrap();
4457
4458 let hsts = resp.headers().get("strict-transport-security").unwrap();
4459 assert!(
4460 hsts.to_str().unwrap().contains("max-age=63072000"),
4461 "HSTS must set 2-year max-age"
4462 );
4463 }
4464
4465 fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
4471 let cfg =
4472 McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
4473 cfg.check()
4474 }
4475
4476 #[test]
4477 fn security_headers_config_default_validates() {
4478 check_with_security_headers(SecurityHeadersConfig::default())
4479 .expect("default SecurityHeadersConfig must validate");
4480 }
4481
4482 #[test]
4483 fn security_headers_config_validate_accepts_empty_string() {
4484 let h = SecurityHeadersConfig {
4486 x_content_type_options: Some(String::new()),
4487 x_frame_options: Some(String::new()),
4488 cache_control: Some(String::new()),
4489 referrer_policy: Some(String::new()),
4490 cross_origin_opener_policy: Some(String::new()),
4491 cross_origin_resource_policy: Some(String::new()),
4492 cross_origin_embedder_policy: Some(String::new()),
4493 permissions_policy: Some(String::new()),
4494 x_permitted_cross_domain_policies: Some(String::new()),
4495 content_security_policy: Some(String::new()),
4496 x_dns_prefetch_control: Some(String::new()),
4497 strict_transport_security: Some(String::new()),
4498 };
4499 check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
4500 }
4501
4502 #[test]
4503 fn security_headers_config_validate_rejects_bad_value() {
4504 let h = SecurityHeadersConfig {
4506 referrer_policy: Some("\u{0007}".into()),
4507 ..SecurityHeadersConfig::default()
4508 };
4509 let err = check_with_security_headers(h)
4510 .expect_err("control char in referrer_policy must reject");
4511 let msg = err.to_string();
4512 assert!(
4513 msg.contains("referrer_policy"),
4514 "error must name the offending field, got: {msg}"
4515 );
4516 }
4517
4518 #[test]
4519 fn security_headers_config_validate_rejects_hsts_preload() {
4520 let h = SecurityHeadersConfig {
4521 strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
4522 ..SecurityHeadersConfig::default()
4523 };
4524 let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
4525 let msg = err.to_string();
4526 assert!(
4527 msg.contains("strict_transport_security"),
4528 "error must name the field, got: {msg}"
4529 );
4530 assert!(
4531 msg.to_lowercase().contains("preload"),
4532 "error must mention `preload`, got: {msg}"
4533 );
4534 }
4535
4536 #[test]
4537 fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
4538 let h = SecurityHeadersConfig {
4540 strict_transport_security: Some("max-age=600; PRELOAD".into()),
4541 ..SecurityHeadersConfig::default()
4542 };
4543 check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
4544 }
4545
4546 #[tokio::test]
4547 async fn security_headers_override_honored() {
4548 let h = SecurityHeadersConfig {
4550 x_frame_options: Some("SAMEORIGIN".into()),
4551 ..SecurityHeadersConfig::default()
4552 };
4553 let app = security_router_with(false, h);
4554 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4555 let resp = app.oneshot(req).await.unwrap();
4556 assert_eq!(resp.status(), StatusCode::OK);
4557
4558 let xfo = resp.headers().get("x-frame-options").unwrap();
4559 assert_eq!(xfo, "SAMEORIGIN");
4560 }
4561
4562 #[tokio::test]
4563 async fn security_headers_empty_string_omits() {
4564 let h = SecurityHeadersConfig {
4566 referrer_policy: Some(String::new()),
4567 ..SecurityHeadersConfig::default()
4568 };
4569 let app = security_router_with(false, h);
4570 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4571 let resp = app.oneshot(req).await.unwrap();
4572 assert_eq!(resp.status(), StatusCode::OK);
4573
4574 assert!(
4575 resp.headers().get("referrer-policy").is_none(),
4576 "Some(\"\") must omit the header"
4577 );
4578 assert_eq!(
4580 resp.headers().get("x-content-type-options").unwrap(),
4581 "nosniff"
4582 );
4583 }
4584
4585 #[tokio::test]
4586 async fn security_headers_hsts_only_when_tls() {
4587 let h = SecurityHeadersConfig {
4589 strict_transport_security: Some("max-age=600".into()),
4590 ..SecurityHeadersConfig::default()
4591 };
4592 let app = security_router_with(false, h);
4593 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4594 let resp = app.oneshot(req).await.unwrap();
4595 assert!(
4596 resp.headers().get("strict-transport-security").is_none(),
4597 "HSTS must remain absent on plaintext deployments even with override"
4598 );
4599 }
4600
4601 #[cfg(feature = "oauth")]
4604 #[tokio::test]
4605 async fn oauth_token_cache_headers_set_pragma_and_vary() {
4606 let app = axum::Router::new()
4607 .route("/token", axum::routing::post(|| async { "{}" }))
4608 .layer(axum::middleware::from_fn(
4609 oauth_token_cache_headers_middleware,
4610 ));
4611 let req = Request::builder()
4612 .method("POST")
4613 .uri("/token")
4614 .body(Body::from("{}"))
4615 .unwrap();
4616 let resp = app.oneshot(req).await.unwrap();
4617 assert_eq!(resp.status(), StatusCode::OK);
4618
4619 let h = resp.headers();
4620 assert_eq!(
4621 h.get("pragma").unwrap(),
4622 "no-cache",
4623 "RFC 6749 §5.1: token responses must set Pragma: no-cache"
4624 );
4625 let vary_values: Vec<String> = h
4626 .get_all("vary")
4627 .iter()
4628 .filter_map(|v| v.to_str().ok().map(str::to_owned))
4629 .collect();
4630 assert!(
4631 vary_values
4632 .iter()
4633 .any(|v| v.eq_ignore_ascii_case("Authorization")),
4634 "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
4635 );
4636 }
4637
4638 #[cfg(feature = "oauth")]
4639 #[tokio::test]
4640 async fn oauth_token_cache_headers_preserve_existing_vary() {
4641 let app = axum::Router::new()
4644 .route(
4645 "/token",
4646 axum::routing::post(|| async {
4647 axum::response::Response::builder()
4648 .header("vary", "Accept-Encoding")
4649 .body(axum::body::Body::from("{}"))
4650 .unwrap()
4651 }),
4652 )
4653 .layer(axum::middleware::from_fn(
4654 oauth_token_cache_headers_middleware,
4655 ));
4656 let req = Request::builder()
4657 .method("POST")
4658 .uri("/token")
4659 .body(Body::empty())
4660 .unwrap();
4661 let resp = app.oneshot(req).await.unwrap();
4662
4663 let vary: Vec<String> = resp
4664 .headers()
4665 .get_all("vary")
4666 .iter()
4667 .filter_map(|v| v.to_str().ok().map(str::to_owned))
4668 .collect();
4669 assert!(
4670 vary.iter().any(|v| v.contains("Accept-Encoding")),
4671 "must preserve pre-existing Vary value, got {vary:?}"
4672 );
4673 assert!(
4674 vary.iter().any(|v| v.contains("Authorization")),
4675 "must append Authorization to Vary, got {vary:?}"
4676 );
4677 }
4678
4679 #[test]
4682 fn version_payload_contains_expected_fields() {
4683 let v = version_payload("my-server", "1.2.3");
4684 assert_eq!(v["name"], "my-server");
4685 assert_eq!(v["version"], "1.2.3");
4686 assert!(v["build_git_sha"].is_string());
4687 assert!(v["build_timestamp"].is_string());
4688 assert!(v["rust_version"].is_string());
4689 assert!(v["mcpx_version"].is_string());
4690 }
4691
4692 #[tokio::test]
4695 async fn concurrency_limit_layer_composes_and_serves() {
4696 let app = axum::Router::new()
4700 .route("/ok", axum::routing::get(|| async { "ok" }))
4701 .layer(
4702 tower::ServiceBuilder::new()
4703 .layer(axum::error_handling::HandleErrorLayer::new(
4704 |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
4705 ))
4706 .layer(tower::load_shed::LoadShedLayer::new())
4707 .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
4708 );
4709 let resp = app
4710 .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
4711 .await
4712 .unwrap();
4713 assert_eq!(resp.status(), StatusCode::OK);
4714 }
4715
4716 #[tokio::test]
4719 async fn compression_layer_gzip_encodes_response() {
4720 use tower_http::compression::Predicate as _;
4721
4722 let big_body = "a".repeat(4096);
4723 let app = axum::Router::new()
4724 .route(
4725 "/big",
4726 axum::routing::get(move || {
4727 let body = big_body.clone();
4728 async move { body }
4729 }),
4730 )
4731 .layer(
4732 tower_http::compression::CompressionLayer::new()
4733 .gzip(true)
4734 .br(true)
4735 .compress_when(
4736 tower_http::compression::DefaultPredicate::new()
4737 .and(tower_http::compression::predicate::SizeAbove::new(1024)),
4738 ),
4739 );
4740
4741 let req = Request::builder()
4742 .uri("/big")
4743 .header(header::ACCEPT_ENCODING, "gzip")
4744 .body(Body::empty())
4745 .unwrap();
4746 let resp = app.oneshot(req).await.unwrap();
4747 assert_eq!(resp.status(), StatusCode::OK);
4748 assert_eq!(
4749 resp.headers().get(header::CONTENT_ENCODING).unwrap(),
4750 "gzip"
4751 );
4752 }
4753
4754 #[tokio::test]
4757 async fn tls_handshake_timeout_reaps_idle_connections() {
4758 use tokio::io::AsyncReadExt as _;
4759
4760 let _ = rustls::crypto::ring::default_provider().install_default();
4761
4762 let key = rcgen::KeyPair::generate().expect("generate key");
4764 let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
4765 .expect("cert params")
4766 .self_signed(&key)
4767 .expect("self-signed cert");
4768 let dir = std::env::temp_dir().join(format!(
4769 "rmcp-server-kit-hs-timeout-{}",
4770 std::time::SystemTime::now()
4771 .duration_since(std::time::UNIX_EPOCH)
4772 .expect("clock after epoch")
4773 .as_nanos()
4774 ));
4775 tokio::fs::create_dir_all(&dir).await.expect("temp dir");
4776 let cert_path = dir.join("server.crt");
4777 let key_path = dir.join("server.key");
4778 tokio::fs::write(&cert_path, cert.pem())
4779 .await
4780 .expect("write cert");
4781 tokio::fs::write(&key_path, key.serialize_pem())
4782 .await
4783 .expect("write key");
4784
4785 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
4786 let tls = TlsListener::new(
4787 listener,
4788 &cert_path,
4789 &key_path,
4790 None,
4791 None,
4792 Duration::from_millis(200),
4793 8, )
4795 .expect("tls listener");
4796 let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
4797
4798 let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
4802 let mut buf = [0_u8; 16];
4803 let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
4804 .await
4805 .expect("server must reap the idle handshake within its timeout");
4806 match read {
4807 Ok(0) | Err(_) => {} Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
4809 }
4810
4811 drop(tls);
4812 }
4813}