1use std::{
2 future::Future,
3 net::{IpAddr, SocketAddr},
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12 body::Body,
13 extract::{ConnectInfo, Request},
14 middleware::Next,
15 response::IntoResponse,
16};
17use rmcp::{
18 ServerHandler,
19 transport::streamable_http_server::{
20 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21 },
22};
23use rustls::RootCertStore;
24use tokio::{
25 net::TcpListener,
26 sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31 auth::{
32 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33 build_rate_limiter, extract_mtls_identity,
34 },
35 bounded_limiter::BoundedKeyedLimiter,
36 error::McpxError,
37 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
38 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
39};
40
41#[allow(
45 clippy::needless_pass_by_value,
46 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
47)]
48fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
49 McpxError::Startup(format!("{e:#}"))
50}
51
52#[allow(
58 clippy::needless_pass_by_value,
59 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
60)]
61fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
62 McpxError::Startup(format!("{op}: {e}"))
63}
64
/// Async readiness probe installed with [`McpServerConfig::with_readiness_check`].
///
/// Each invocation returns a boxed, `Send` future that resolves to a JSON
/// value. When a check is configured, the router serves `/readyz` through the
/// `readyz` handler with this check. Otherwise `/readyz` falls back to the
/// same handler as `/healthz`.
/// NOTE(review): how the JSON value is mapped to an HTTP status is decided by
/// `readyz`, which is defined outside this view. Confirm there.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
71
/// Socket address of the client connection that a request arrived on.
///
/// Handlers obtain it by using `PeerAddr` as an axum extractor. The value is
/// read from the request extensions, which the server populates.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerAddr {
    /// The peer's socket address.
    pub addr: SocketAddr,
}

impl PeerAddr {
    /// Wraps `addr` as a [`PeerAddr`].
    #[must_use]
    pub(crate) const fn new(addr: SocketAddr) -> Self {
        Self { addr }
    }
}
131
132impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
141 type Rejection = (axum::http::StatusCode, &'static str);
142
143 async fn from_request_parts(
144 parts: &mut axum::http::request::Parts,
145 _state: &S,
146 ) -> Result<Self, Self::Rejection> {
147 parts.extensions.get::<Self>().copied().ok_or((
148 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
149 "peer address unavailable: not running under rmcp-server-kit serve()",
150 ))
151 }
152}
153
/// Per-header values used by the security-headers middleware.
///
/// Every field is an optional header value, and `Default` leaves all of them
/// `None`. `McpServerConfig::validate` passes this struct to
/// `validate_security_headers`.
/// NOTE(review): whether `None` means "send a built-in default" or "omit the
/// header" is decided by the middleware, which is outside this view. Confirm there.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` value.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` value.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` value.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` value.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` value.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` value.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` value.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` value.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` value.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` value.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` value.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` value. The middleware receives an `is_tls`
    /// flag alongside this config; presumably HSTS is only relevant over TLS.
    /// TODO confirm.
    pub strict_transport_security: Option<String>,
}
207
/// Configuration for an MCP Streamable HTTP server.
///
/// Build it with [`McpServerConfig::new`] and the `with_*` / `enable_*`
/// builder methods, then check it with [`McpServerConfig::validate`]. The
/// fields are still `pub` for backward compatibility, but direct access is
/// deprecated (see each field's `#[deprecated]` note).
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Listen address, e.g. `127.0.0.1:8080`. `validate()` rejects values
    /// that do not parse as a `SocketAddr`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name, reported by `/version` and by the admin state.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version, reported by `/version` and by the admin state.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// TLS certificate path. Must be set together with `tls_key_path`
    /// (enforced by `validate()`). When set, the server URL uses `https`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private-key path. Must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings. Auth is only wired up when this is `Some`
    /// and its `enabled` flag is true.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy for `/mcp`. `None` installs `RbacPolicy::disabled()`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Allowed browser origins, used by the origin check and the CORS layer.
    /// Each entry must start with `http://` or `https://`. When the list is
    /// empty and `public_url` is set, one origin is derived from `public_url`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Tool-call rate limit in calls per minute per IP. `None` disables it.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Burst size for the tool-call limiter. Must be non-zero and requires
    /// `tool_rate_limit`.
    #[deprecated(
        since = "1.12.0",
        note = "use McpServerConfig::with_tool_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit_burst: Option<u32>,
    /// Per-IP rate limit (per minute) applied to the `extra_router` routes.
    /// `Some(0)` is rejected by `validate()`.
    #[deprecated(
        since = "1.11.0",
        note = "use McpServerConfig::with_extra_route_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_route_rate_limit: Option<u32>,
    /// Burst size for the extra-route limiter. Must be non-zero and requires
    /// `extra_route_rate_limit`.
    #[deprecated(
        since = "1.12.0",
        note = "use McpServerConfig::with_extra_route_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_route_rate_limit_burst: Option<u32>,
    /// Readiness probe behind `/readyz`. When `None`, `/readyz` uses the
    /// `/healthz` handler.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request-body size in bytes for the `/mcp` routes. Must be
    /// non-zero.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Timeout for `/mcp` requests. Requests that exceed it get
    /// `408 Request Timeout`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Shutdown timeout. It is handed on to the serving code outside this
    /// view, presumably as the graceful-shutdown grace period. TODO confirm.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Keep-alive applied to MCP sessions in the local session manager.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// Server-sent-events keep-alive interval for the Streamable HTTP service.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// One-shot callback that receives a [`ReloadHandle`]. When it is invoked
    /// is decided by the serving code outside this view.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Extra routes merged into the top-level router. They do not pass
    /// through the `/mcp` auth, RBAC, timeout or body-limit layers. They do
    /// get the router-wide layers (security headers, CORS, origin check, etc.).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL. It is used to derive allowed hosts, the
    /// default allowed origin and the OAuth protected-resource metadata. Must
    /// start with `http://` or `https://`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Flag passed to the origin-check middleware to enable request-header
    /// logging.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Minimum response size in bytes before compression applies.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global in-flight request cap. Excess requests are shed with
    /// `503 Service Unavailable`. `Some(0)` is rejected by `validate()`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Mounts the `/admin/*` endpoints. Requires auth to be configured and
    /// enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// Role passed to the admin router configuration.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enables the metrics middleware and the separate metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Bind address for the metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Security-header settings. They are checked by `validate()` and
    /// applied router-wide.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
    /// TLS handshake timeout. Must be non-zero.
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_handshake_timeout: Duration,
    /// Maximum number of concurrent TLS handshakes. Must be non-zero.
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_tls_handshakes: usize,
}
475
/// A value that has passed validation.
///
/// [`McpServerConfig::validate`] produces these. The tuple field is private,
/// so code outside this module cannot construct a `Validated` directly.
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

/// Always prints `Validated { .. }` and never the wrapped value. A validated
/// config can therefore be logged without exposing its contents.
impl<T> std::fmt::Debug for Validated<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Validated");
        out.finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the validated value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
561
562#[allow(
563 deprecated,
564 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
565)]
566impl McpServerConfig {
567 #[must_use]
575 pub fn new(
576 bind_addr: impl Into<String>,
577 name: impl Into<String>,
578 version: impl Into<String>,
579 ) -> Self {
580 Self {
581 bind_addr: bind_addr.into(),
582 name: name.into(),
583 version: version.into(),
584 tls_cert_path: None,
585 tls_key_path: None,
586 auth: None,
587 rbac: None,
588 allowed_origins: Vec::new(),
589 tool_rate_limit: None,
590 readiness_check: None,
591 max_request_body: 1024 * 1024,
592 request_timeout: Duration::from_mins(2),
593 shutdown_timeout: Duration::from_secs(30),
594 session_idle_timeout: Duration::from_mins(20),
595 sse_keep_alive: Duration::from_secs(15),
596 on_reload_ready: None,
597 extra_router: None,
598 public_url: None,
599 log_request_headers: false,
600 compression_enabled: false,
601 compression_min_size: 1024,
602 max_concurrent_requests: None,
603 admin_enabled: false,
604 admin_role: "admin".to_owned(),
605 #[cfg(feature = "metrics")]
606 metrics_enabled: false,
607 #[cfg(feature = "metrics")]
608 metrics_bind: "127.0.0.1:9090".into(),
609 security_headers: SecurityHeadersConfig::default(),
610 tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
611 max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
612 extra_route_rate_limit: None,
613 tool_rate_limit_burst: None,
614 extra_route_rate_limit_burst: None,
615 }
616 }
617
618 #[must_use]
628 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
629 self.auth = Some(auth);
630 self
631 }
632
633 #[must_use]
638 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
639 self.security_headers = headers;
640 self
641 }
642
643 #[must_use]
647 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
648 self.bind_addr = addr.into();
649 self
650 }
651
652 #[must_use]
655 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
656 self.rbac = Some(rbac);
657 self
658 }
659
660 #[must_use]
664 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
665 self.tls_cert_path = Some(cert_path.into());
666 self.tls_key_path = Some(key_path.into());
667 self
668 }
669
670 #[must_use]
674 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
675 self.public_url = Some(url.into());
676 self
677 }
678
679 #[must_use]
683 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
684 where
685 I: IntoIterator<Item = S>,
686 S: Into<String>,
687 {
688 self.allowed_origins = origins.into_iter().map(Into::into).collect();
689 self
690 }
691
692 #[must_use]
705 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
706 self.extra_router = Some(router);
707 self
708 }
709
710 #[must_use]
713 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
714 self.readiness_check = Some(check);
715 self
716 }
717
718 #[must_use]
721 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
722 self.max_request_body = bytes;
723 self
724 }
725
726 #[must_use]
728 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
729 self.request_timeout = timeout;
730 self
731 }
732
733 #[must_use]
735 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
736 self.shutdown_timeout = timeout;
737 self
738 }
739
740 #[must_use]
742 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
743 self.session_idle_timeout = timeout;
744 self
745 }
746
747 #[must_use]
749 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
750 self.sse_keep_alive = interval;
751 self
752 }
753
754 #[must_use]
758 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
759 self.max_concurrent_requests = Some(limit);
760 self
761 }
762
763 #[must_use]
771 pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
772 self.tls_handshake_timeout = timeout;
773 self
774 }
775
776 #[must_use]
785 pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
786 self.max_concurrent_tls_handshakes = limit;
787 self
788 }
789
790 #[must_use]
793 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
794 self.tool_rate_limit = Some(per_minute);
795 self
796 }
797
798 #[must_use]
809 pub fn with_extra_route_rate_limit(mut self, per_minute: u32) -> Self {
810 self.extra_route_rate_limit = Some(per_minute);
811 self
812 }
813
814 #[must_use]
819 pub fn with_tool_rate_limit_burst(mut self, burst: u32) -> Self {
820 self.tool_rate_limit_burst = Some(burst);
821 self
822 }
823
824 #[must_use]
830 pub fn with_extra_route_rate_limit_burst(mut self, burst: u32) -> Self {
831 self.extra_route_rate_limit_burst = Some(burst);
832 self
833 }
834
835 #[must_use]
839 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
840 where
841 F: FnOnce(ReloadHandle) + Send + 'static,
842 {
843 self.on_reload_ready = Some(Box::new(callback));
844 self
845 }
846
847 #[must_use]
851 pub fn enable_compression(mut self, min_size: u16) -> Self {
852 self.compression_enabled = true;
853 self.compression_min_size = min_size;
854 self
855 }
856
857 #[must_use]
862 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
863 self.admin_enabled = true;
864 self.admin_role = role.into();
865 self
866 }
867
868 #[must_use]
871 pub fn enable_request_header_logging(mut self) -> Self {
872 self.log_request_headers = true;
873 self
874 }
875
876 #[cfg(feature = "metrics")]
879 #[must_use]
880 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
881 self.metrics_enabled = true;
882 self.metrics_bind = bind.into();
883 self
884 }
885
886 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
919 self.check()?;
920 Ok(Validated(self))
921 }
922
923 fn check_burst_knobs(&self) -> Result<(), McpxError> {
930 if self.tool_rate_limit_burst == Some(0) {
931 return Err(McpxError::Config(
932 "tool_rate_limit_burst must be greater than zero".into(),
933 ));
934 }
935 if self.extra_route_rate_limit_burst == Some(0) {
936 return Err(McpxError::Config(
937 "extra_route_rate_limit_burst must be greater than zero".into(),
938 ));
939 }
940 if self.tool_rate_limit_burst.is_some() && self.tool_rate_limit.is_none() {
941 return Err(McpxError::Config(
942 "tool_rate_limit_burst requires tool_rate_limit to be set".into(),
943 ));
944 }
945 if self.extra_route_rate_limit_burst.is_some() && self.extra_route_rate_limit.is_none() {
946 return Err(McpxError::Config(
947 "extra_route_rate_limit_burst requires extra_route_rate_limit to be set".into(),
948 ));
949 }
950 if let Some(rl) = self.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
951 if rl.burst == Some(0) {
952 return Err(McpxError::Config(
953 "auth rate_limit.burst must be greater than zero".into(),
954 ));
955 }
956 if rl.pre_auth_burst == Some(0) {
957 return Err(McpxError::Config(
958 "auth rate_limit.pre_auth_burst must be greater than zero".into(),
959 ));
960 }
961 }
962 Ok(())
963 }
964
965 fn check(&self) -> Result<(), McpxError> {
969 if self.admin_enabled {
973 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
974 if !auth_enabled {
975 return Err(McpxError::Config(
976 "admin_enabled=true requires auth to be configured and enabled".into(),
977 ));
978 }
979 }
980
981 match (&self.tls_cert_path, &self.tls_key_path) {
983 (Some(_), None) => {
984 return Err(McpxError::Config(
985 "tls_cert_path is set but tls_key_path is missing".into(),
986 ));
987 }
988 (None, Some(_)) => {
989 return Err(McpxError::Config(
990 "tls_key_path is set but tls_cert_path is missing".into(),
991 ));
992 }
993 _ => {}
994 }
995
996 if self.bind_addr.parse::<SocketAddr>().is_err() {
998 return Err(McpxError::Config(format!(
999 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
1000 self.bind_addr
1001 )));
1002 }
1003
1004 if let Some(ref url) = self.public_url
1006 && !(url.starts_with("http://") || url.starts_with("https://"))
1007 {
1008 return Err(McpxError::Config(format!(
1009 "public_url {url:?} must start with http:// or https://"
1010 )));
1011 }
1012
1013 for origin in &self.allowed_origins {
1015 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
1016 return Err(McpxError::Config(format!(
1017 "allowed_origins entry {origin:?} must start with http:// or https://"
1018 )));
1019 }
1020 }
1021
1022 if self.max_request_body == 0 {
1024 return Err(McpxError::Config(
1025 "max_request_body must be greater than zero".into(),
1026 ));
1027 }
1028
1029 if self.extra_route_rate_limit == Some(0) {
1033 return Err(McpxError::Config(
1034 "extra_route_rate_limit must be greater than zero".into(),
1035 ));
1036 }
1037
1038 self.check_burst_knobs()?;
1040
1041 #[cfg(feature = "oauth")]
1043 if let Some(auth_cfg) = &self.auth
1044 && let Some(oauth_cfg) = &auth_cfg.oauth
1045 {
1046 oauth_cfg.validate()?;
1047 }
1048
1049 validate_security_headers(&self.security_headers)?;
1052
1053 if let Some(0) = self.max_concurrent_requests {
1057 return Err(McpxError::Config(
1058 "max_concurrent_requests must be greater than zero when set".into(),
1059 ));
1060 }
1061
1062 if let Some(auth_cfg) = &self.auth
1066 && let Some(rl) = &auth_cfg.rate_limit
1067 && rl.max_tracked_keys == 0
1068 {
1069 return Err(McpxError::Config(
1070 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
1071 ));
1072 }
1073
1074 if self.tls_handshake_timeout == Duration::ZERO {
1079 return Err(McpxError::Config(
1080 "tls_handshake_timeout must be greater than zero".into(),
1081 ));
1082 }
1083
1084 if self.max_concurrent_tls_handshakes == 0 {
1089 return Err(McpxError::Config(
1090 "max_concurrent_tls_handshakes must be greater than zero".into(),
1091 ));
1092 }
1093
1094 Ok(())
1095 }
1096}
1097
1098#[allow(
1104 missing_debug_implementations,
1105 reason = "contains Arc<AuthState> with non-Debug fields"
1106)]
1107pub struct ReloadHandle {
1108 auth: Option<Arc<AuthState>>,
1109 rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
1110 crl_set: Option<Arc<CrlSet>>,
1111}
1112
1113impl ReloadHandle {
1114 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
1116 if let Some(ref auth) = self.auth {
1117 auth.reload_keys(keys);
1118 }
1119 }
1120
1121 pub fn reload_rbac(&self, policy: RbacPolicy) {
1123 if let Some(ref rbac) = self.rbac {
1124 rbac.store(Arc::new(policy));
1125 tracing::info!("RBAC policy reloaded");
1126 }
1127 }
1128
1129 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1135 let Some(ref crl_set) = self.crl_set else {
1136 return Err(McpxError::Config(
1137 "CRL refresh requested but mTLS CRL support is not configured".into(),
1138 ));
1139 };
1140
1141 crl_set.force_refresh().await
1142 }
1143}
1144
/// Runtime parameters returned by `build_app_router` alongside the router and
/// consumed by the serving code.
///
/// NOTE(review): `too_many_lines` / `cognitive_complexity` have no effect on a
/// struct, and the `reason` below describes router assembly. This allow looks
/// like it was meant for a function (e.g. `build_app_router`). Confirm, then
/// move or drop it.
#[allow(
    clippy::too_many_lines,
    clippy::cognitive_complexity,
    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
)]
struct AppRunParams {
    /// TLS file paths, presumably `(cert, key)`. TODO confirm against where
    /// this struct is built.
    tls_paths: Option<(PathBuf, PathBuf)>,
    /// TLS handshake timeout, presumably copied from the config.
    tls_handshake_timeout: Duration,
    /// Maximum number of concurrent TLS handshakes, presumably copied from the config.
    max_concurrent_tls_handshakes: usize,
    /// mTLS settings, if any.
    mtls_config: Option<MtlsConfig>,
    /// Shutdown timeout, presumably copied from the config.
    shutdown_timeout: Duration,
    /// Shared auth state. `build_app_router` only creates one when auth is
    /// configured and enabled.
    auth_state: Option<Arc<AuthState>>,
    /// Hot-swappable RBAC policy. In `build_app_router` the RBAC middleware
    /// loads its policy from an `ArcSwap` like this on every request.
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    /// One-shot reload callback taken from the config.
    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Cancellation token. `build_app_router` derives the MCP service's child
    /// token and the metrics shutdown token from its local token; presumably
    /// this is the same one. TODO confirm.
    ct: CancellationToken,
    /// URL scheme label, presumably `"http"` or `"https"` depending on TLS.
    scheme: &'static str,
    /// Server name.
    name: String,
}
1195
1196#[allow(
1206 clippy::cognitive_complexity,
1207 reason = "router assembly is intrinsically sequential; splitting harms readability"
1208)]
1209#[allow(
1210 deprecated,
1211 reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
1212)]
1213fn build_app_router<H, F>(
1214 mut config: McpServerConfig,
1215 handler_factory: F,
1216) -> anyhow::Result<(axum::Router, AppRunParams)>
1217where
1218 H: ServerHandler + 'static,
1219 F: Fn() -> H + Send + Sync + Clone + 'static,
1220{
1221 let ct = CancellationToken::new();
1222
1223 let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
1224 tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
1225
1226 let mcp_service = StreamableHttpService::new(
1227 move || Ok(handler_factory()),
1228 {
1229 let mut mgr = LocalSessionManager::default();
1230 mgr.session_config.keep_alive = Some(config.session_idle_timeout);
1231 mgr.into()
1232 },
1233 StreamableHttpServerConfig::default()
1234 .with_allowed_hosts(allowed_hosts)
1235 .with_sse_keep_alive(Some(config.sse_keep_alive))
1236 .with_cancellation_token(ct.child_token()),
1237 );
1238
1239 let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
1241
1242 let auth_state: Option<Arc<AuthState>> = match config.auth {
1246 Some(ref auth_config) if auth_config.enabled => {
1247 let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
1248 let pre_auth_limiter = auth_config
1249 .rate_limit
1250 .as_ref()
1251 .map(crate::auth::build_pre_auth_limiter);
1252
1253 #[cfg(feature = "oauth")]
1254 let jwks_cache = auth_config
1255 .oauth
1256 .as_ref()
1257 .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
1258 .transpose()
1259 .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
1260
1261 Some(Arc::new(AuthState {
1262 api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
1263 rate_limiter,
1264 pre_auth_limiter,
1265 #[cfg(feature = "oauth")]
1266 jwks_cache,
1267 seen_identities: crate::auth::SeenIdentitySet::new(),
1268 counters: crate::auth::AuthCounters::default(),
1269 }))
1270 }
1271 _ => None,
1272 };
1273
1274 let rbac_swap = Arc::new(ArcSwap::new(
1277 config
1278 .rbac
1279 .clone()
1280 .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
1281 ));
1282
1283 if config.admin_enabled {
1286 let Some(ref auth_state_ref) = auth_state else {
1287 return Err(anyhow::anyhow!(
1288 "admin_enabled=true requires auth to be configured and enabled"
1289 ));
1290 };
1291 let admin_state = crate::admin::AdminState {
1292 started_at: std::time::Instant::now(),
1293 name: config.name.clone(),
1294 version: config.version.clone(),
1295 auth: Some(Arc::clone(auth_state_ref)),
1296 rbac: Arc::clone(&rbac_swap),
1297 };
1298 let admin_cfg = crate::admin::AdminConfig {
1299 role: config.admin_role.clone(),
1300 };
1301 mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
1302 tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
1303 }
1304
1305 {
1338 let tool_limiter: Option<Arc<ToolRateLimiter>> = config
1339 .tool_rate_limit
1340 .map(|per_minute| build_tool_rate_limiter(per_minute, config.tool_rate_limit_burst));
1341
1342 if rbac_swap.load().is_enabled() {
1343 tracing::info!("RBAC enforcement enabled on /mcp");
1344 }
1345 if let Some(limit) = config.tool_rate_limit {
1346 tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1347 }
1348
1349 let rbac_for_mw = Arc::clone(&rbac_swap);
1350 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1351 let p = rbac_for_mw.load_full();
1352 let tl = tool_limiter.clone();
1353 rbac_middleware(p, tl, req, next)
1354 }));
1355 }
1356
1357 if let Some(ref auth_config) = config.auth
1359 && auth_config.enabled
1360 {
1361 let Some(ref state) = auth_state else {
1362 return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1363 };
1364
1365 let methods: Vec<&str> = [
1366 auth_config.mtls.is_some().then_some("mTLS"),
1367 (!auth_config.api_keys.is_empty()).then_some("bearer"),
1368 #[cfg(feature = "oauth")]
1369 auth_config.oauth.is_some().then_some("oauth-jwt"),
1370 ]
1371 .into_iter()
1372 .flatten()
1373 .collect();
1374
1375 tracing::info!(
1376 methods = %methods.join(", "),
1377 api_keys = auth_config.api_keys.len(),
1378 "auth enabled on /mcp"
1379 );
1380
1381 let state_for_mw = Arc::clone(state);
1382 mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1383 let s = Arc::clone(&state_for_mw);
1384 auth_middleware(s, req, next)
1385 }));
1386 }
1387
1388 mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1391 axum::http::StatusCode::REQUEST_TIMEOUT,
1392 config.request_timeout,
1393 ));
1394
1395 mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1399 config.max_request_body,
1400 ));
1401
1402 let mut effective_origins = config.allowed_origins.clone();
1409 if effective_origins.is_empty()
1410 && let Some(ref url) = config.public_url
1411 {
1412 if let Some(scheme_end) = url.find("://") {
1417 let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1418 let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1419 let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1420 let host = after_scheme.get(..host_end).unwrap_or_default();
1421 let origin = format!("{scheme_with_sep}{host}");
1422 tracing::info!(
1423 %origin,
1424 "auto-derived allowed origin from public_url"
1425 );
1426 effective_origins.push(origin);
1427 }
1428 }
1429 let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1430 let cors_origins = Arc::clone(&allowed_origins);
1431 let log_request_headers = config.log_request_headers;
1432
1433 let readyz_route = if let Some(check) = config.readiness_check.take() {
1434 axum::routing::get(move || readyz(Arc::clone(&check)))
1435 } else {
1436 axum::routing::get(healthz)
1437 };
1438
1439 #[allow(unused_mut)] let mut router = axum::Router::new()
1441 .route("/healthz", axum::routing::get(healthz))
1442 .route("/readyz", readyz_route)
1443 .route(
1444 "/version",
1445 axum::routing::get({
1446 let payload_bytes: Arc<[u8]> =
1451 serialize_version_payload(&config.name, &config.version);
1452 move || {
1453 let p = Arc::clone(&payload_bytes);
1454 async move {
1455 (
1456 [(axum::http::header::CONTENT_TYPE, "application/json")],
1457 p.to_vec(),
1458 )
1459 }
1460 }
1461 }),
1462 )
1463 .merge(mcp_router);
1464
1465 if let Some(extra) = config.extra_router.take() {
1472 let extra = match config.extra_route_rate_limit {
1473 Some(per_minute) => {
1474 let limiter =
1475 build_extra_route_rate_limiter(per_minute, config.extra_route_rate_limit_burst);
1476 tracing::info!(per_minute, "extra-route per-IP rate limit enabled");
1477 extra.layer(axum::middleware::from_fn(move |req, next| {
1478 let l = Arc::clone(&limiter);
1479 extra_route_rate_limit_middleware(l, req, next)
1480 }))
1481 }
1482 None => extra,
1483 };
1484 router = router.merge(extra);
1485 }
1486
1487 let server_url = if let Some(ref url) = config.public_url {
1494 url.trim_end_matches('/').to_owned()
1495 } else {
1496 let prm_scheme = if config.tls_cert_path.is_some() {
1497 "https"
1498 } else {
1499 "http"
1500 };
1501 format!("{prm_scheme}://{}", config.bind_addr)
1502 };
1503 let resource_url = format!("{server_url}/mcp");
1504
1505 #[cfg(feature = "oauth")]
1506 let prm_metadata = if let Some(ref auth_config) = config.auth
1507 && let Some(ref oauth_config) = auth_config.oauth
1508 {
1509 crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1510 } else {
1511 serde_json::json!({ "resource": resource_url })
1512 };
1513 #[cfg(not(feature = "oauth"))]
1514 let prm_metadata = serde_json::json!({ "resource": resource_url });
1515
1516 router = router.route(
1517 "/.well-known/oauth-protected-resource",
1518 axum::routing::get(move || {
1519 let m = prm_metadata.clone();
1520 async move { axum::Json(m) }
1521 }),
1522 );
1523
1524 #[cfg(feature = "oauth")]
1529 if let Some(ref auth_config) = config.auth
1530 && let Some(ref oauth_config) = auth_config.oauth
1531 && oauth_config.proxy.is_some()
1532 {
1533 router =
1534 install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1535 }
1536
1537 let is_tls = config.tls_cert_path.is_some();
1540 let security_headers_cfg = Arc::new(config.security_headers.clone());
1541 router = router.layer(axum::middleware::from_fn(move |req, next| {
1542 let cfg = Arc::clone(&security_headers_cfg);
1543 security_headers_middleware(is_tls, cfg, req, next)
1544 }));
1545
1546 if !cors_origins.is_empty() {
1550 let cors = tower_http::cors::CorsLayer::new()
1551 .allow_origin(
1552 cors_origins
1553 .iter()
1554 .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1555 .collect::<Vec<_>>(),
1556 )
1557 .allow_methods([
1558 axum::http::Method::GET,
1559 axum::http::Method::POST,
1560 axum::http::Method::OPTIONS,
1561 ])
1562 .allow_headers([
1563 axum::http::header::CONTENT_TYPE,
1564 axum::http::header::AUTHORIZATION,
1565 ]);
1566 router = router.layer(cors);
1567 }
1568
1569 if config.compression_enabled {
1573 use tower_http::compression::Predicate as _;
1574 let predicate = tower_http::compression::DefaultPredicate::new().and(
1575 tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1576 );
1577 router = router.layer(
1578 tower_http::compression::CompressionLayer::new()
1579 .gzip(true)
1580 .br(true)
1581 .compress_when(predicate),
1582 );
1583 tracing::info!(
1584 min_size = config.compression_min_size,
1585 "response compression enabled (gzip, br)"
1586 );
1587 }
1588
1589 if let Some(max) = config.max_concurrent_requests {
1592 let overload_handler = tower::ServiceBuilder::new()
1593 .layer(axum::error_handling::HandleErrorLayer::new(
1594 |_err: tower::BoxError| async {
1595 (
1596 axum::http::StatusCode::SERVICE_UNAVAILABLE,
1597 axum::Json(serde_json::json!({
1598 "error": "overloaded",
1599 "error_description": "server is at capacity, retry later"
1600 })),
1601 )
1602 },
1603 ))
1604 .layer(tower::load_shed::LoadShedLayer::new())
1605 .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1606 router = router.layer(overload_handler);
1607 tracing::info!(max, "global concurrency limit enabled");
1608 }
1609
1610 router = router.fallback(|| async {
1614 (
1615 axum::http::StatusCode::NOT_FOUND,
1616 axum::Json(serde_json::json!({
1617 "error": "not_found",
1618 "error_description": "The requested endpoint does not exist"
1619 })),
1620 )
1621 });
1622
1623 #[cfg(feature = "metrics")]
1625 if config.metrics_enabled {
1626 let metrics = Arc::new(
1627 crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1628 );
1629 let m = Arc::clone(&metrics);
1630 router = router.layer(axum::middleware::from_fn(
1631 move |req: Request<Body>, next: Next| {
1632 let m = Arc::clone(&m);
1633 metrics_middleware(m, req, next)
1634 },
1635 ));
1636 let metrics_bind = config.metrics_bind.clone();
1637 let metrics_shutdown = ct.clone();
1638 tokio::spawn(async move {
1639 if let Err(e) =
1640 crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1641 {
1642 tracing::error!("metrics listener failed: {e}");
1643 }
1644 });
1645 }
1646
1647 router = router.layer(axum::middleware::from_fn(normalize_peer_addr_middleware));
1655
1656 router = router.layer(axum::middleware::from_fn(move |req, next| {
1667 let origins = Arc::clone(&allowed_origins);
1668 origin_check_middleware(origins, log_request_headers, req, next)
1669 }));
1670
1671 let scheme = if config.tls_cert_path.is_some() {
1672 "https"
1673 } else {
1674 "http"
1675 };
1676
1677 let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1678 (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1679 _ => None,
1680 };
1681 let tls_handshake_timeout = config.tls_handshake_timeout;
1682 let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
1683 let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1684
1685 Ok((
1686 router,
1687 AppRunParams {
1688 tls_paths,
1689 tls_handshake_timeout,
1690 max_concurrent_tls_handshakes,
1691 mtls_config,
1692 shutdown_timeout: config.shutdown_timeout,
1693 auth_state,
1694 rbac_swap,
1695 on_reload_ready: config.on_reload_ready.take(),
1696 ct,
1697 scheme,
1698 name: config.name.clone(),
1699 },
1700 ))
1701}
1702
1703pub async fn serve<H, F>(
1720 config: Validated<McpServerConfig>,
1721 handler_factory: F,
1722) -> Result<(), McpxError>
1723where
1724 H: ServerHandler + 'static,
1725 F: Fn() -> H + Send + Sync + Clone + 'static,
1726{
1727 let config = config.into_inner();
1728 #[allow(
1729 deprecated,
1730 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1731 )]
1732 let bind_addr = config.bind_addr.clone();
1733 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1734
1735 let listener = TcpListener::bind(&bind_addr)
1736 .await
1737 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1738 log_listening(¶ms.name, params.scheme, &bind_addr);
1739
1740 run_server(
1741 router,
1742 listener,
1743 params.tls_paths,
1744 params.tls_handshake_timeout,
1745 params.max_concurrent_tls_handshakes,
1746 params.mtls_config,
1747 params.shutdown_timeout,
1748 params.auth_state,
1749 params.rbac_swap,
1750 params.on_reload_ready,
1751 params.ct,
1752 )
1753 .await
1754 .map_err(anyhow_to_startup)
1755}
1756
1757pub async fn serve_with_listener<H, F>(
1787 listener: TcpListener,
1788 config: Validated<McpServerConfig>,
1789 handler_factory: F,
1790 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1791 shutdown: Option<CancellationToken>,
1792) -> Result<(), McpxError>
1793where
1794 H: ServerHandler + 'static,
1795 F: Fn() -> H + Send + Sync + Clone + 'static,
1796{
1797 let config = config.into_inner();
1798 let local_addr = listener
1799 .local_addr()
1800 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1801 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1802
1803 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1804
1805 if let Some(external) = shutdown {
1809 let internal = params.ct.clone();
1810 tokio::spawn(async move {
1811 external.cancelled().await;
1812 internal.cancel();
1813 });
1814 }
1815
1816 if let Some(tx) = ready_tx {
1820 let _ = tx.send(local_addr);
1822 }
1823
1824 run_server(
1825 router,
1826 listener,
1827 params.tls_paths,
1828 params.tls_handshake_timeout,
1829 params.max_concurrent_tls_handshakes,
1830 params.mtls_config,
1831 params.shutdown_timeout,
1832 params.auth_state,
1833 params.rbac_swap,
1834 params.on_reload_ready,
1835 params.ct,
1836 )
1837 .await
1838 .map_err(anyhow_to_startup)
1839}
1840
/// Logs, at `info` level, the server `name` and listen `addr`, followed by the
/// `/mcp`, `/healthz` and `/readyz` URLs built from `scheme` and `addr`.
#[allow(
    clippy::cognitive_complexity,
    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
)]
fn log_listening(name: &str, scheme: &str, addr: &str) {
    tracing::info!("{name} listening on {addr}");
    tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
    tracing::info!(" Health check: {scheme}://{addr}/healthz");
    tracing::info!(" Readiness: {scheme}://{addr}/readyz");
}
1853
/// Serves `router` on `listener` until shutdown.
///
/// Shutdown is triggered by an OS signal (see [`shutdown_signal`]) or by
/// cancellation of `ct`. Once triggered, `ct` is cancelled and axum's graceful
/// shutdown begins. If serving has not finished within `shutdown_timeout`
/// after the trigger, the function logs a warning and stops waiting.
///
/// When `tls_paths` is `Some`, connections go through [`TlsListener`]. If mTLS
/// with CRL checking is enabled, the CRL set is bootstrapped and a refresher
/// task is spawned first. In both the TLS and plain-TCP paths,
/// `on_reload_ready` (if present) is invoked once with a [`ReloadHandle`]
/// before serving starts.
///
/// # Errors
///
/// Returns an error if loading the client-auth CA roots fails, or the CRL
/// bootstrap fails, or building the TLS listener fails, or `axum::serve`
/// returns an error.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    tls_handshake_timeout: Duration,
    max_concurrent_tls_handshakes: usize,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // Single local token that both the graceful-shutdown future and the
    // force-exit timer wait on.
    let shutdown_trigger = CancellationToken::new();
    // Fire the trigger on an OS signal or when the caller cancels `ct`.
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    // Passed to axum's graceful shutdown: once triggered, cancel `ct` so that
    // tasks tied to it (e.g. the CRL refresher below) stop as well.
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };

    // Resolves `shutdown_timeout` after the trigger; racing it against the
    // server bounds how long graceful shutdown may take.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // CRL state is only built when mTLS is configured with CRL checking.
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        // Hand the reload handle (including CRL state) to the caller before serving.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
            tls_handshake_timeout,
            max_concurrent_tls_handshakes,
        )?;
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        // Plain TCP: there is no CRL state to expose.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
1993
/// Mounts the OAuth 2.1 proxy endpoints on `router` when
/// `oauth_config.proxy` is set. Otherwise returns `router` unchanged.
///
/// Always mounts `/.well-known/oauth-authorization-server`, `/authorize`,
/// `/token` and `/register`. It also merges the `/introspect` and `/revoke`
/// admin routes when `expose_admin_endpoints` is set and the corresponding
/// upstream URL is configured.
///
/// # Errors
///
/// Propagates failures from building the OAuth HTTP client or the admin
/// router.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(proxy) = oauth_config.proxy.as_ref() else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;
    let metadata = crate::oauth::authorization_server_metadata(server_url, oauth_config);

    // Each handler closure owns its own copy of the proxy config / HTTP client.
    let authorize_proxy = proxy.clone();
    let token_proxy = proxy.clone();
    let token_http = http.clone();
    let register_proxy = proxy.clone();

    let router = router
        .route(
            "/.well-known/oauth-authorization-server",
            axum::routing::get(move || {
                let m = metadata.clone();
                async move { axum::Json(m) }
            }),
        )
        .route(
            "/authorize",
            axum::routing::get(
                move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                    let p = authorize_proxy.clone();
                    async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
                },
            ),
        )
        .route(
            "/token",
            axum::routing::post(move |body: String| {
                let p = token_proxy.clone();
                let h = token_http.clone();
                async move { crate::oauth::handle_token(&h, &p, &body).await }
            })
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            )),
        )
        .route(
            "/register",
            axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
                let p = register_proxy;
                async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
            })
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            )),
        );

    let mounts_admin_routes = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    let admin_auth_opted_out = proxy.expose_admin_endpoints
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints;
    if admin_auth_opted_out {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if mounts_admin_routes {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };
    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
2094
/// Builds the OAuth proxy admin router.
///
/// It contains `/introspect` and/or `/revoke`, each only when the matching
/// upstream URL is configured. Every route gets the token cache-header
/// middleware. When `require_auth_on_admin_endpoints` is set, the whole router
/// is additionally wrapped in [`auth_middleware`].
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if auth is required but no `auth_state` is
/// available.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let introspect_proxy = proxy.clone();
        let introspect_http = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let p = introspect_proxy.clone();
                let h = introspect_http.clone();
                async move { crate::oauth::handle_introspect(&h, &p, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let revoke_proxy = proxy.clone();
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let p = revoke_proxy.clone();
                let h = http.clone();
                async move { crate::oauth::handle_revoke(&h, &p, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.cloned().ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        let s = Arc::clone(&state);
        auth_middleware(s, req, next)
    })))
}
2153
2154fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2159 let mut hosts = vec![
2160 "localhost".to_owned(),
2161 "127.0.0.1".to_owned(),
2162 "::1".to_owned(),
2163 ];
2164
2165 if let Some(url) = public_url
2166 && let Ok(uri) = url.parse::<axum::http::Uri>()
2167 && let Some(authority) = uri.authority()
2168 {
2169 let host = authority.host().to_owned();
2170 if !hosts.iter().any(|h| h == &host) {
2171 hosts.push(host);
2172 }
2173
2174 let authority = authority.as_str().to_owned();
2175 if !hosts.iter().any(|h| h == &authority) {
2176 hosts.push(authority);
2177 }
2178 }
2179
2180 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2181 && let Some(authority) = uri.authority()
2182 {
2183 let host = authority.host().to_owned();
2184 if !hosts.iter().any(|h| h == &host) {
2185 hosts.push(host);
2186 }
2187
2188 let authority = authority.as_str().to_owned();
2189 if !hosts.iter().any(|h| h == &authority) {
2190 hosts.push(authority);
2191 }
2192 }
2193
2194 hosts
2195}
2196
2197impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2210 for TlsConnInfo
2211{
2212 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2213 let addr = *target.remote_addr();
2214 let identity = target.io().identity().cloned();
2215 TlsConnInfo::new(addr, identity)
2216 }
2217}
2218
/// Default TLS handshake timeout (10 seconds).
/// NOTE(review): presumably used as the config default for
/// `tls_handshake_timeout`; the use site is not visible here — confirm.
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Default cap on concurrently in-flight TLS handshakes (256).
/// NOTE(review): presumably the config default for
/// `max_concurrent_tls_handshakes`; the use site is not visible here — confirm.
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;

/// Capacity of the bounded channel that carries completed TLS handshakes
/// from the acceptor task to [`TlsListener::accept`].
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
2241
/// An axum [`Listener`](axum::serve::Listener) that yields connections which
/// have already completed a TLS handshake.
///
/// TCP accepts and handshakes run in a background task (see
/// [`run_tls_acceptor`]). That task sends finished streams over a bounded
/// channel.
struct TlsListener {
    /// Address of the underlying TCP listener, captured at construction.
    local_addr: SocketAddr,
    /// Completed handshakes paired with the peer address.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Handle to the background acceptor task; aborted on drop.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2267
2268impl TlsListener {
2269 fn new(
2270 inner: TcpListener,
2271 cert_path: &Path,
2272 key_path: &Path,
2273 mtls_config: Option<&MtlsConfig>,
2274 crl_set: Option<Arc<CrlSet>>,
2275 handshake_timeout: Duration,
2276 max_concurrent_handshakes: usize,
2277 ) -> anyhow::Result<Self> {
2278 rustls::crypto::ring::default_provider()
2280 .install_default()
2281 .ok();
2282
2283 let certs = load_certs(cert_path)?;
2284 let key = load_key(key_path)?;
2285
2286 let mtls_default_role;
2287
2288 let tls_config = if let Some(mtls) = mtls_config {
2289 mtls_default_role = mtls.default_role.clone();
2290 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2291 {
2292 let Some(crl_set) = crl_set else {
2293 return Err(anyhow::anyhow!(
2294 "mTLS CRL verifier requested but CRL state was not initialized"
2295 ));
2296 };
2297 Arc::new(DynamicClientCertVerifier::new(crl_set))
2298 } else {
2299 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2300 if mtls.required {
2301 rustls::server::WebPkiClientVerifier::builder(root_store)
2302 .build()
2303 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2304 } else {
2305 rustls::server::WebPkiClientVerifier::builder(root_store)
2306 .allow_unauthenticated()
2307 .build()
2308 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2309 }
2310 };
2311
2312 tracing::info!(
2313 ca = %mtls.ca_cert_path.display(),
2314 required = mtls.required,
2315 crl_enabled = mtls.crl_enabled,
2316 "mTLS client auth configured"
2317 );
2318
2319 rustls::ServerConfig::builder_with_protocol_versions(&[
2320 &rustls::version::TLS12,
2321 &rustls::version::TLS13,
2322 ])
2323 .with_client_cert_verifier(verifier)
2324 .with_single_cert(certs, key)?
2325 } else {
2326 mtls_default_role = "viewer".to_owned();
2327 rustls::ServerConfig::builder_with_protocol_versions(&[
2328 &rustls::version::TLS12,
2329 &rustls::version::TLS13,
2330 ])
2331 .with_no_client_auth()
2332 .with_single_cert(certs, key)?
2333 };
2334
2335 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2336 tracing::info!(
2337 "TLS enabled (cert: {}, key: {})",
2338 cert_path.display(),
2339 key_path.display()
2340 );
2341 let local_addr = inner.local_addr()?;
2342 let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2343 let acceptor_task = tokio::spawn(run_tls_acceptor(
2344 inner,
2345 acceptor,
2346 mtls_default_role,
2347 tx,
2348 handshake_timeout,
2349 max_concurrent_handshakes,
2350 ));
2351 Ok(Self {
2352 local_addr,
2353 rx,
2354 acceptor_task,
2355 })
2356 }
2357
2358 fn extract_handshake_identity(
2362 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2363 default_role: &str,
2364 addr: SocketAddr,
2365 ) -> Option<AuthIdentity> {
2366 let (_, server_conn) = tls_stream.get_ref();
2367 let cert_der = server_conn.peer_certificates()?.first()?;
2368 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2369 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2370 Some(id)
2371 }
2372}
2373
2374async fn run_tls_acceptor(
2382 listener: TcpListener,
2383 acceptor: tokio_rustls::TlsAcceptor,
2384 default_role: String,
2385 tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2386 handshake_timeout: Duration,
2387 max_concurrent_handshakes: usize,
2388) {
2389 let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2390 loop {
2391 let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2395 return;
2397 };
2398 let (stream, addr) = match listener.accept().await {
2399 Ok(pair) => pair,
2400 Err(e) => {
2401 tracing::debug!("TCP accept error: {e}");
2402 continue;
2403 }
2404 };
2405 if tx.is_closed() {
2406 return;
2408 }
2409 let acceptor = acceptor.clone();
2410 let default_role = default_role.clone();
2411 let tx = tx.clone();
2412 tokio::spawn(async move {
2413 let _permit = permit;
2414 match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2415 Ok(Ok(tls_stream)) => {
2416 let identity =
2417 TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2418 let wrapped = AuthenticatedTlsStream {
2419 inner: tls_stream,
2420 identity,
2421 };
2422 let _ = tx.send((wrapped, addr)).await;
2425 }
2426 Ok(Err(e)) => {
2427 tracing::debug!("TLS handshake failed from {addr}: {e}");
2428 }
2429 Err(_elapsed) => {
2430 tracing::debug!(
2431 "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2432 );
2433 }
2434 }
2435 });
2436 }
2437}
2438
/// A server-side TLS stream together with the client identity extracted from
/// its peer certificate during the handshake (if any).
///
/// Built by [`run_tls_acceptor`] and yielded by [`TlsListener`]. I/O is
/// delegated to `inner`.
pub(crate) struct AuthenticatedTlsStream {
    /// The established TLS session over the accepted TCP connection.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// mTLS identity from the client's leaf certificate; `None` when the
    /// client presented no usable certificate.
    identity: Option<AuthIdentity>,
}

impl AuthenticatedTlsStream {
    /// Returns the mTLS identity captured at handshake time, if any.
    #[must_use]
    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
        self.identity.as_ref()
    }
}
2462
2463impl std::fmt::Debug for AuthenticatedTlsStream {
2464 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2465 f.debug_struct("AuthenticatedTlsStream")
2466 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2467 .finish_non_exhaustive()
2468 }
2469}
2470
2471impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2472 fn poll_read(
2473 mut self: Pin<&mut Self>,
2474 cx: &mut std::task::Context<'_>,
2475 buf: &mut tokio::io::ReadBuf<'_>,
2476 ) -> std::task::Poll<std::io::Result<()>> {
2477 Pin::new(&mut self.inner).poll_read(cx, buf)
2478 }
2479}
2480
2481impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2482 fn poll_write(
2483 mut self: Pin<&mut Self>,
2484 cx: &mut std::task::Context<'_>,
2485 buf: &[u8],
2486 ) -> std::task::Poll<std::io::Result<usize>> {
2487 Pin::new(&mut self.inner).poll_write(cx, buf)
2488 }
2489
2490 fn poll_flush(
2491 mut self: Pin<&mut Self>,
2492 cx: &mut std::task::Context<'_>,
2493 ) -> std::task::Poll<std::io::Result<()>> {
2494 Pin::new(&mut self.inner).poll_flush(cx)
2495 }
2496
2497 fn poll_shutdown(
2498 mut self: Pin<&mut Self>,
2499 cx: &mut std::task::Context<'_>,
2500 ) -> std::task::Poll<std::io::Result<()>> {
2501 Pin::new(&mut self.inner).poll_shutdown(cx)
2502 }
2503
2504 fn poll_write_vectored(
2505 mut self: Pin<&mut Self>,
2506 cx: &mut std::task::Context<'_>,
2507 bufs: &[std::io::IoSlice<'_>],
2508 ) -> std::task::Poll<std::io::Result<usize>> {
2509 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2510 }
2511
2512 fn is_write_vectored(&self) -> bool {
2513 self.inner.is_write_vectored()
2514 }
2515}
2516
2517impl axum::serve::Listener for TlsListener {
2518 type Io = AuthenticatedTlsStream;
2519 type Addr = SocketAddr;
2520
2521 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2527 if let Some(pair) = self.rx.recv().await {
2528 return pair;
2529 }
2530 tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2536 std::future::pending().await
2537 }
2538
2539 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2540 Ok(self.local_addr)
2541 }
2542}
2543
/// Stops the background acceptor task when the listener is dropped, so it
/// does not keep accepting TCP connections nobody will consume.
impl Drop for TlsListener {
    fn drop(&mut self) {
        self.acceptor_task.abort();
    }
}
2551
2552fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2553 use rustls::pki_types::pem::PemObject;
2554 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2555 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2556 .collect::<Result<_, _>>()
2557 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2558 anyhow::ensure!(
2559 !certs.is_empty(),
2560 "no certificates found in {}",
2561 path.display()
2562 );
2563 Ok(certs)
2564}
2565
2566fn load_client_auth_roots(
2567 path: &Path,
2568) -> anyhow::Result<(
2569 Vec<rustls::pki_types::CertificateDer<'static>>,
2570 Arc<RootCertStore>,
2571)> {
2572 let ca_certs = load_certs(path)?;
2573 let mut root_store = RootCertStore::empty();
2574 for cert in &ca_certs {
2575 root_store
2576 .add(cert.clone())
2577 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2578 }
2579
2580 Ok((ca_certs, Arc::new(root_store)))
2581}
2582
2583fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2584 use rustls::pki_types::pem::PemObject;
2585 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2586 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2587}
2588
2589#[allow(
2590 clippy::unused_async,
2591 reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2592)]
2593async fn healthz() -> impl IntoResponse {
2594 axum::Json(serde_json::json!({
2595 "status": "ok",
2596 }))
2597}
2598
2599fn version_payload(name: &str, version: &str) -> serde_json::Value {
2606 serde_json::json!({
2607 "name": name,
2608 "version": version,
2609 "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2610 "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2611 "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2612 "mcpx_version": env!("CARGO_PKG_VERSION"),
2613 })
2614}
2615
2616fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2626 let value = version_payload(name, version);
2627 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2628}
2629
2630async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2631 let status = check().await;
2632 let ready = status
2633 .get("ready")
2634 .and_then(serde_json::Value::as_bool)
2635 .unwrap_or(false);
2636 let code = if ready {
2637 axum::http::StatusCode::OK
2638 } else {
2639 axum::http::StatusCode::SERVICE_UNAVAILABLE
2640 };
2641 (code, axum::Json(status))
2642}
2643
2644async fn shutdown_signal() {
2648 let ctrl_c = tokio::signal::ctrl_c();
2649
2650 #[cfg(unix)]
2651 {
2652 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2653 Ok(mut term) => {
2654 tokio::select! {
2655 _ = ctrl_c => {}
2656 _ = term.recv() => {}
2657 }
2658 }
2659 Err(e) => {
2660 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2661 ctrl_c.await.ok();
2662 }
2663 }
2664 }
2665
2666 #[cfg(not(unix))]
2667 {
2668 ctrl_c.await.ok();
2669 }
2670}
2671
/// Records a request counter (method, path, status) and a latency histogram
/// (method, path) for every request passing through the router.
///
/// Label values are kept to a bounded set, because both inputs come from
/// untrusted clients and each distinct label combination is a new time series:
/// * `path` is the matched route template (`MatchedPath`), e.g.
///   `/items/{id}` rather than `/items/123`. Requests that matched no route
///   and ended in `404` (router fallback) share the label `"unmatched"`. Any
///   other request without a matched path keeps its raw path.
/// * `method` is kept for the standard HTTP methods; extension methods are
///   reported as `"other"`.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Shared label for fallback 404s, so path scans cannot mint new series.
    const UNMATCHED_PATH_LABEL: &str = "unmatched";
    // Shared label for non-standard (extension) HTTP methods.
    const OTHER_METHOD_LABEL: &str = "other";

    let method = match req.method().as_str() {
        m @ ("GET" | "POST" | "PUT" | "DELETE" | "HEAD" | "OPTIONS" | "PATCH" | "CONNECT"
        | "TRACE") => m.to_owned(),
        _ => OTHER_METHOD_LABEL.to_owned(),
    };
    let matched_path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map(|p| p.as_str().to_owned());
    let raw_path = req.uri().path().to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status_code = response.status();
    let duration = start.elapsed().as_secs_f64();
    let path = match matched_path {
        Some(template) => template,
        None if status_code == axum::http::StatusCode::NOT_FOUND => {
            UNMATCHED_PATH_LABEL.to_owned()
        }
        None => raw_path,
    };
    let status = status_code.as_u16().to_string();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2703
2704async fn security_headers_middleware(
2716 is_tls: bool,
2717 cfg: Arc<SecurityHeadersConfig>,
2718 req: Request<Body>,
2719 next: Next,
2720) -> axum::response::Response {
2721 use axum::http::{HeaderName, header};
2722
2723 let mut resp = next.run(req).await;
2724 let headers = resp.headers_mut();
2725
2726 headers.remove(header::SERVER);
2728 headers.remove(HeaderName::from_static("x-powered-by"));
2729
2730 apply_security_header(
2731 headers,
2732 header::X_CONTENT_TYPE_OPTIONS,
2733 cfg.x_content_type_options.as_deref(),
2734 "nosniff",
2735 );
2736 apply_security_header(
2737 headers,
2738 header::X_FRAME_OPTIONS,
2739 cfg.x_frame_options.as_deref(),
2740 "deny",
2741 );
2742 apply_security_header(
2743 headers,
2744 header::CACHE_CONTROL,
2745 cfg.cache_control.as_deref(),
2746 "no-store, max-age=0",
2747 );
2748 apply_security_header(
2749 headers,
2750 header::REFERRER_POLICY,
2751 cfg.referrer_policy.as_deref(),
2752 "no-referrer",
2753 );
2754 apply_security_header(
2755 headers,
2756 HeaderName::from_static("cross-origin-opener-policy"),
2757 cfg.cross_origin_opener_policy.as_deref(),
2758 "same-origin",
2759 );
2760 apply_security_header(
2761 headers,
2762 HeaderName::from_static("cross-origin-resource-policy"),
2763 cfg.cross_origin_resource_policy.as_deref(),
2764 "same-origin",
2765 );
2766 apply_security_header(
2767 headers,
2768 HeaderName::from_static("cross-origin-embedder-policy"),
2769 cfg.cross_origin_embedder_policy.as_deref(),
2770 "require-corp",
2771 );
2772 apply_security_header(
2773 headers,
2774 HeaderName::from_static("permissions-policy"),
2775 cfg.permissions_policy.as_deref(),
2776 "accelerometer=(), camera=(), geolocation=(), microphone=()",
2777 );
2778 apply_security_header(
2779 headers,
2780 HeaderName::from_static("x-permitted-cross-domain-policies"),
2781 cfg.x_permitted_cross_domain_policies.as_deref(),
2782 "none",
2783 );
2784 apply_security_header(
2785 headers,
2786 HeaderName::from_static("content-security-policy"),
2787 cfg.content_security_policy.as_deref(),
2788 "default-src 'none'; frame-ancestors 'none'",
2789 );
2790 apply_security_header(
2791 headers,
2792 HeaderName::from_static("x-dns-prefetch-control"),
2793 cfg.x_dns_prefetch_control.as_deref(),
2794 "off",
2795 );
2796
2797 if is_tls {
2798 apply_security_header(
2799 headers,
2800 header::STRICT_TRANSPORT_SECURITY,
2801 cfg.strict_transport_security.as_deref(),
2802 "max-age=63072000; includeSubDomains",
2803 );
2804 }
2805
2806 resp
2807}
2808
2809fn apply_security_header(
2820 headers: &mut axum::http::HeaderMap,
2821 name: axum::http::HeaderName,
2822 override_value: Option<&str>,
2823 default: &'static str,
2824) {
2825 use axum::http::HeaderValue;
2826
2827 match override_value {
2828 None => {
2829 headers.insert(name, HeaderValue::from_static(default));
2830 }
2831 Some("") => {
2832 }
2834 Some(v) => match HeaderValue::from_str(v) {
2835 Ok(hv) => {
2836 headers.insert(name, hv);
2837 }
2838 Err(err) => {
2839 tracing::error!(
2840 header = %name,
2841 error = %err,
2842 "invalid security header override reached middleware; using default"
2843 );
2844 headers.insert(name, HeaderValue::from_static(default));
2845 }
2846 },
2847 }
2848}
2849
2850fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2861 use axum::http::HeaderValue;
2862
2863 let fields: &[(&str, Option<&str>)] = &[
2864 (
2865 "x_content_type_options",
2866 cfg.x_content_type_options.as_deref(),
2867 ),
2868 ("x_frame_options", cfg.x_frame_options.as_deref()),
2869 ("cache_control", cfg.cache_control.as_deref()),
2870 ("referrer_policy", cfg.referrer_policy.as_deref()),
2871 (
2872 "cross_origin_opener_policy",
2873 cfg.cross_origin_opener_policy.as_deref(),
2874 ),
2875 (
2876 "cross_origin_resource_policy",
2877 cfg.cross_origin_resource_policy.as_deref(),
2878 ),
2879 (
2880 "cross_origin_embedder_policy",
2881 cfg.cross_origin_embedder_policy.as_deref(),
2882 ),
2883 ("permissions_policy", cfg.permissions_policy.as_deref()),
2884 (
2885 "x_permitted_cross_domain_policies",
2886 cfg.x_permitted_cross_domain_policies.as_deref(),
2887 ),
2888 (
2889 "content_security_policy",
2890 cfg.content_security_policy.as_deref(),
2891 ),
2892 (
2893 "x_dns_prefetch_control",
2894 cfg.x_dns_prefetch_control.as_deref(),
2895 ),
2896 (
2897 "strict_transport_security",
2898 cfg.strict_transport_security.as_deref(),
2899 ),
2900 ];
2901
2902 for (field, value) in fields {
2903 let Some(v) = value else { continue };
2904 if v.is_empty() {
2905 continue;
2906 }
2907 if let Err(err) = HeaderValue::from_str(v) {
2908 return Err(McpxError::Config(format!(
2909 "invalid security_headers.{field}: {err}"
2910 )));
2911 }
2912 }
2913
2914 if let Some(v) = cfg.strict_transport_security.as_deref()
2915 && !v.is_empty()
2916 && v.to_ascii_lowercase().contains("preload")
2917 {
2918 return Err(McpxError::Config(format!(
2919 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2920 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2921 )));
2922 }
2923
2924 Ok(())
2925}
2926
/// Adds cache-suppression headers to OAuth proxy responses.
///
/// It sets `Pragma: no-cache` and appends `Vary: Authorization`, so responses
/// carrying tokens are not reused across callers.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;
    {
        let headers = response.headers_mut();
        headers.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
        headers.append(header::VARY, HeaderValue::from_static("Authorization"));
    }
    response
}
2954
2955async fn normalize_peer_addr_middleware(
2979 mut req: Request<Body>,
2980 next: Next,
2981) -> axum::response::Response {
2982 let direct = req
2983 .extensions()
2984 .get::<ConnectInfo<SocketAddr>>()
2985 .map(|ci| ci.0);
2986 let from_tls = req
2987 .extensions()
2988 .get::<ConnectInfo<TlsConnInfo>>()
2989 .map(|ci| ci.0.addr);
2990 if let Some(addr) = direct.or(from_tls) {
2991 if direct.is_none() {
2992 req.extensions_mut().insert(ConnectInfo(addr));
2993 }
2994 req.extensions_mut().insert(PeerAddr::new(addr));
2995 }
2996 next.run(req).await
2997}
2998
/// Rate limiter keyed by client IP address, used by
/// `extra_route_rate_limit_middleware` to throttle extra (application) routes.
pub(crate) type ExtraRouteRateLimiter = BoundedKeyedLimiter<IpAddr>;
3003
/// Maximum number of client-IP keys the extra-route limiter keeps
/// (passed to `BoundedKeyedLimiter::new` as its key capacity).
const EXTRA_ROUTE_MAX_TRACKED_KEYS: usize = 10_000;

/// Idle period (15 minutes) after which a client IP's limiter state may be
/// evicted (passed to `BoundedKeyedLimiter::new` as its idle-eviction window).
const EXTRA_ROUTE_IDLE_EVICTION: Duration = Duration::from_secs(15 * 60);
3014
3015fn build_extra_route_rate_limiter(
3022 per_minute: u32,
3023 burst: Option<u32>,
3024) -> Arc<ExtraRouteRateLimiter> {
3025 let rate = std::num::NonZeroU32::new(per_minute.max(1)).unwrap_or(std::num::NonZeroU32::MIN);
3026 let mut quota = governor::Quota::per_minute(rate);
3027 if let Some(b) = burst.and_then(std::num::NonZeroU32::new) {
3028 quota = quota.allow_burst(b);
3029 }
3030 Arc::new(BoundedKeyedLimiter::new(
3031 quota,
3032 EXTRA_ROUTE_MAX_TRACKED_KEYS,
3033 EXTRA_ROUTE_IDLE_EVICTION,
3034 ))
3035}
3036
3037async fn extra_route_rate_limit_middleware(
3054 limiter: Arc<ExtraRouteRateLimiter>,
3055 req: Request<Body>,
3056 next: Next,
3057) -> axum::response::Response {
3058 let peer_ip: Option<IpAddr> = req
3059 .extensions()
3060 .get::<ConnectInfo<SocketAddr>>()
3061 .map(|ci| ci.0.ip())
3062 .or_else(|| {
3063 req.extensions()
3064 .get::<ConnectInfo<TlsConnInfo>>()
3065 .map(|ci| ci.0.addr.ip())
3066 });
3067 if let Some(ip) = peer_ip
3068 && let Err(wait) = limiter.check_key_wait(&ip)
3069 {
3070 tracing::warn!(%ip, "extra route request rate limited");
3071 return McpxError::RateLimitedFor {
3072 message: "too many requests to application routes from this source".into(),
3073 retry_after: wait,
3074 }
3075 .into_response();
3076 }
3077 next.run(req).await
3078}
3079
3080async fn origin_check_middleware(
3084 allowed: Arc<[String]>,
3085 log_request_headers: bool,
3086 req: Request<Body>,
3087 next: Next,
3088) -> axum::response::Response {
3089 let method = req.method().clone();
3090 let path = req.uri().path().to_owned();
3091
3092 log_incoming_request(&method, &path, req.headers(), log_request_headers);
3093
3094 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
3095 let origin_str = origin.to_str().unwrap_or("");
3096 if !allowed.iter().any(|a| a == origin_str) {
3097 tracing::warn!(
3098 origin = origin_str,
3099 %method,
3100 %path,
3101 allowed = ?&*allowed,
3102 "rejected request: Origin not allowed"
3103 );
3104 return (
3105 axum::http::StatusCode::FORBIDDEN,
3106 "Forbidden: Origin not allowed",
3107 )
3108 .into_response();
3109 }
3110 }
3111 next.run(req).await
3112}
3113
3114fn log_incoming_request(
3117 method: &axum::http::Method,
3118 path: &str,
3119 headers: &axum::http::HeaderMap,
3120 log_request_headers: bool,
3121) {
3122 if log_request_headers {
3123 tracing::debug!(
3124 %method,
3125 %path,
3126 headers = %format_request_headers_for_log(headers),
3127 "incoming request"
3128 );
3129 } else {
3130 tracing::debug!(%method, %path, "incoming request");
3131 }
3132}
3133
3134fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
3135 headers
3136 .iter()
3137 .map(|(k, v)| {
3138 let name = k.as_str();
3139 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
3140 format!("{name}: [REDACTED]")
3141 } else {
3142 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
3143 }
3144 })
3145 .collect::<Vec<_>>()
3146 .join(", ")
3147}
3148
3149#[allow(
3173 clippy::cognitive_complexity,
3174 reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
3175)]
3176pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
3177where
3178 H: ServerHandler + 'static,
3179{
3180 use rmcp::ServiceExt as _;
3181
3182 tracing::info!("stdio transport: serving on stdin/stdout");
3183 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
3184
3185 let transport = rmcp::transport::io::stdio();
3186
3187 let service = handler
3188 .serve(transport)
3189 .await
3190 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
3191
3192 if let Err(e) = service.waiting().await {
3193 tracing::warn!(error = %e, "stdio session ended with error");
3194 }
3195 tracing::info!("stdio session ended");
3196 Ok(())
3197}
3198
3199#[cfg(test)]
3200mod tests {
3201 #![allow(
3202 clippy::unwrap_used,
3203 clippy::expect_used,
3204 clippy::panic,
3205 clippy::indexing_slicing,
3206 clippy::unwrap_in_result,
3207 clippy::print_stdout,
3208 clippy::print_stderr,
3209 deprecated,
3210 reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
3211 )]
3212 use std::{sync::Arc, time::Duration};
3213
3214 use axum::{
3215 body::Body,
3216 http::{Request, StatusCode, header},
3217 response::IntoResponse,
3218 };
3219 use http_body_util::BodyExt;
3220 use tower::ServiceExt as _;
3221
3222 use super::*;
3223
3224 #[test]
3227 fn server_config_new_defaults() {
3228 let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
3229 assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
3230 assert_eq!(cfg.name, "test-server");
3231 assert_eq!(cfg.version, "1.0.0");
3232 assert!(cfg.tls_cert_path.is_none());
3233 assert!(cfg.tls_key_path.is_none());
3234 assert!(cfg.auth.is_none());
3235 assert!(cfg.rbac.is_none());
3236 assert!(cfg.allowed_origins.is_empty());
3237 assert!(cfg.tool_rate_limit.is_none());
3238 assert!(cfg.readiness_check.is_none());
3239 assert_eq!(cfg.max_request_body, 1024 * 1024);
3240 assert_eq!(cfg.request_timeout, Duration::from_mins(2));
3241 assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
3242 assert!(!cfg.log_request_headers);
3243 assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
3244 assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
3245 }
3246
3247 #[test]
3248 fn tls_handshake_builders_set_fields() {
3249 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3250 .with_tls_handshake_timeout(Duration::from_secs(3))
3251 .with_max_concurrent_tls_handshakes(64);
3252 assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
3253 assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
3254 }
3255
3256 #[test]
3257 fn validate_rejects_zero_tls_handshake_timeout() {
3258 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3259 .with_tls_handshake_timeout(Duration::ZERO);
3260 let err = cfg.validate().expect_err("zero handshake timeout");
3261 assert!(err.to_string().contains("tls_handshake_timeout"));
3262 }
3263
3264 #[test]
3265 fn validate_rejects_zero_max_concurrent_tls_handshakes() {
3266 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3267 .with_max_concurrent_tls_handshakes(0);
3268 let err = cfg.validate().expect_err("zero handshake concurrency");
3269 assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
3270 }
3271
3272 #[test]
3273 fn validate_consumes_and_proves() {
3274 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3276 let validated = cfg.validate().expect("valid config");
3277 assert_eq!(validated.as_inner().name, "test-server");
3279 let raw = validated.into_inner();
3281 assert_eq!(raw.name, "test-server");
3282
3283 let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3285 bad.max_request_body = 0;
3286 assert!(bad.validate().is_err(), "zero body cap must fail validate");
3287 }
3288
3289 #[test]
3290 fn validate_rejects_zero_max_concurrent_requests() {
3291 let cfg =
3292 McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
3293 let err = cfg.validate().expect_err("zero concurrency cap must fail");
3294 assert!(
3295 format!("{err}").contains("max_concurrent_requests"),
3296 "error should mention max_concurrent_requests, got: {err}"
3297 );
3298 }
3299
3300 #[test]
3301 fn validate_rejects_zero_max_tracked_keys() {
3302 let rl = crate::auth::RateLimitConfig {
3305 max_attempts_per_minute: 30,
3306 pre_auth_max_per_minute: None,
3307 max_tracked_keys: 0,
3308 idle_eviction: Duration::from_secs(15 * 60),
3309 burst: None,
3310 pre_auth_burst: None,
3311 };
3312 let auth_cfg = AuthConfig {
3313 enabled: true,
3314 api_keys: Vec::new(),
3315 mtls: None,
3316 rate_limit: Some(rl),
3317 #[cfg(feature = "oauth")]
3318 oauth: None,
3319 };
3320 let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
3321 let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
3322 assert!(
3323 format!("{err}").contains("max_tracked_keys"),
3324 "error should mention max_tracked_keys, got: {err}"
3325 );
3326 }
3327
3328 #[test]
3329 fn derive_allowed_hosts_includes_public_host() {
3330 let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
3331 assert!(
3332 hosts.iter().any(|h| h == "mcp.example.com"),
3333 "public_url host must be allowed"
3334 );
3335 }
3336
3337 #[test]
3338 fn derive_allowed_hosts_includes_bind_authority() {
3339 let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
3340 assert!(
3341 hosts.iter().any(|h| h == "127.0.0.1"),
3342 "bind host must be allowed"
3343 );
3344 assert!(
3345 hosts.iter().any(|h| h == "127.0.0.1:8080"),
3346 "bind authority must be allowed"
3347 );
3348 }
3349
3350 #[tokio::test]
3353 async fn healthz_returns_ok_json() {
3354 let resp = healthz().await.into_response();
3355 assert_eq!(resp.status(), StatusCode::OK);
3356 let body = resp.into_body().collect().await.unwrap().to_bytes();
3357 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3358 assert_eq!(json["status"], "ok");
3359 assert!(
3360 json.get("name").is_none(),
3361 "healthz must not expose server name"
3362 );
3363 assert!(
3364 json.get("version").is_none(),
3365 "healthz must not expose version"
3366 );
3367 }
3368
3369 #[tokio::test]
3372 async fn readyz_returns_ok_when_ready() {
3373 let check: ReadinessCheck =
3374 Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
3375 let resp = readyz(check).await.into_response();
3376 assert_eq!(resp.status(), StatusCode::OK);
3377 let body = resp.into_body().collect().await.unwrap().to_bytes();
3378 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3379 assert_eq!(json["ready"], true);
3380 assert!(
3381 json.get("name").is_none(),
3382 "readyz must not expose server name"
3383 );
3384 assert!(
3385 json.get("version").is_none(),
3386 "readyz must not expose version"
3387 );
3388 assert_eq!(json["db"], "connected");
3389 }
3390
3391 #[tokio::test]
3392 async fn readyz_returns_503_when_not_ready() {
3393 let check: ReadinessCheck =
3394 Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
3395 let resp = readyz(check).await.into_response();
3396 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3397 }
3398
3399 #[tokio::test]
3400 async fn readyz_returns_503_when_ready_missing() {
3401 let check: ReadinessCheck =
3402 Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
3403 let resp = readyz(check).await.into_response();
3404 assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3406 }
3407
3408 fn peer_probe_router() -> axum::Router {
3413 async fn probe(req: Request<Body>) -> String {
3414 let ci = req
3415 .extensions()
3416 .get::<ConnectInfo<SocketAddr>>()
3417 .map(|c| c.0.to_string())
3418 .unwrap_or_default();
3419 let pa = req
3420 .extensions()
3421 .get::<PeerAddr>()
3422 .map(|p| p.addr.to_string())
3423 .unwrap_or_default();
3424 format!("{ci}|{pa}")
3425 }
3426 axum::Router::new()
3427 .route("/probe", axum::routing::get(probe))
3428 .layer(axum::middleware::from_fn(normalize_peer_addr_middleware))
3429 }
3430
3431 async fn body_string(resp: axum::response::Response) -> String {
3432 let bytes = resp.into_body().collect().await.unwrap().to_bytes();
3433 String::from_utf8(bytes.to_vec()).unwrap()
3434 }
3435
3436 #[tokio::test]
3437 async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
3438 let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
3441 let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
3442 let req = Request::builder()
3443 .uri("/probe")
3444 .extension(ConnectInfo(plain))
3445 .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3446 .body(Body::empty())
3447 .unwrap();
3448 let resp = peer_probe_router().oneshot(req).await.unwrap();
3449 assert_eq!(resp.status(), StatusCode::OK);
3450 assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
3451 }
3452
3453 #[tokio::test]
3454 async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
3455 let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
3456 let req = Request::builder()
3457 .uri("/probe")
3458 .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3459 .body(Body::empty())
3460 .unwrap();
3461 let resp = peer_probe_router().oneshot(req).await.unwrap();
3462 assert_eq!(resp.status(), StatusCode::OK);
3463 assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
3464 }
3465
3466 #[tokio::test]
3467 async fn normalize_no_op_without_any_connect_info() {
3468 let req = Request::builder()
3469 .uri("/probe")
3470 .body(Body::empty())
3471 .unwrap();
3472 let resp = peer_probe_router().oneshot(req).await.unwrap();
3473 assert_eq!(resp.status(), StatusCode::OK);
3474 assert_eq!(body_string(resp).await, "|");
3475 }
3476
3477 #[tokio::test]
3478 async fn peer_addr_extractor_rejects_when_absent() {
3479 async fn h(peer: PeerAddr) -> String {
3480 peer.addr.to_string()
3481 }
3482 let app = axum::Router::new().route("/p", axum::routing::get(h));
3483 let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
3484 let resp = app.oneshot(req).await.unwrap();
3485 assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
3486 }
3487
3488 #[tokio::test]
3489 async fn peer_addr_extractor_returns_value_when_present() {
3490 async fn h(peer: PeerAddr) -> String {
3491 peer.addr.to_string()
3492 }
3493 let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
3494 let app = axum::Router::new().route("/p", axum::routing::get(h));
3495 let req = Request::builder()
3496 .uri("/p")
3497 .extension(PeerAddr::new(addr))
3498 .body(Body::empty())
3499 .unwrap();
3500 let resp = app.oneshot(req).await.unwrap();
3501 assert_eq!(resp.status(), StatusCode::OK);
3502 assert_eq!(body_string(resp).await, addr.to_string());
3503 }
3504
3505 #[tokio::test]
3506 async fn peer_addr_via_extension_extractor() {
3507 async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
3508 peer.addr.to_string()
3509 }
3510 let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
3511 let app = axum::Router::new().route("/p", axum::routing::get(h));
3512 let req = Request::builder()
3513 .uri("/p")
3514 .extension(PeerAddr::new(addr))
3515 .body(Body::empty())
3516 .unwrap();
3517 let resp = app.oneshot(req).await.unwrap();
3518 assert_eq!(resp.status(), StatusCode::OK);
3519 assert_eq!(body_string(resp).await, addr.to_string());
3520 }
3521
3522 fn limited_router(per_minute: u32) -> axum::Router {
3527 limited_router_with_burst(per_minute, None)
3528 }
3529
3530 fn limited_router_with_burst(per_minute: u32, burst: Option<u32>) -> axum::Router {
3532 let limiter = build_extra_route_rate_limiter(per_minute, burst);
3533 axum::Router::new()
3534 .route("/limited", axum::routing::get(|| async { "ok" }))
3535 .layer(axum::middleware::from_fn(move |req, next| {
3536 let l = Arc::clone(&limiter);
3537 extra_route_rate_limit_middleware(l, req, next)
3538 }))
3539 }
3540
3541 fn limited_req(ip: &str) -> Request<Body> {
3542 let addr: SocketAddr = format!("{ip}:40000").parse().unwrap();
3543 Request::builder()
3544 .uri("/limited")
3545 .extension(ConnectInfo(addr))
3546 .body(Body::empty())
3547 .unwrap()
3548 }
3549
3550 #[tokio::test]
3551 async fn extra_route_limiter_denies_over_quota() {
3552 let app = limited_router(2);
3553 for i in 0..2 {
3554 let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3555 assert_eq!(resp.status(), StatusCode::OK, "request {i} should pass");
3556 }
3557 let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3558 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3559 let body = body_string(resp).await;
3560 assert!(
3561 body.contains("too many requests to application routes"),
3562 "deny body should match the limiter message, got: {body}"
3563 );
3564 }
3565
3566 #[tokio::test]
3567 async fn extra_route_limiter_isolates_keys() {
3568 let app = limited_router(2);
3569 for _ in 0..2 {
3570 let resp = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3571 assert_eq!(resp.status(), StatusCode::OK);
3572 }
3573 let exhausted = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3574 assert_eq!(exhausted.status(), StatusCode::TOO_MANY_REQUESTS);
3575 let other = app.clone().oneshot(limited_req("10.3.3.3")).await.unwrap();
3577 assert_eq!(other.status(), StatusCode::OK);
3578 }
3579
3580 #[tokio::test]
3581 async fn extra_route_limiter_fails_open_without_peer() {
3582 let app = limited_router(1);
3583 for i in 0..3 {
3584 let req = Request::builder()
3585 .uri("/limited")
3586 .body(Body::empty())
3587 .unwrap();
3588 let resp = app.clone().oneshot(req).await.unwrap();
3589 assert_eq!(
3590 resp.status(),
3591 StatusCode::OK,
3592 "request {i} should fail open"
3593 );
3594 }
3595 }
3596
3597 #[tokio::test]
3598 async fn extra_route_limiter_extracts_tls_conn_info() {
3599 let app = limited_router(2);
3600 let mk = || {
3601 let addr: SocketAddr = "192.168.9.9:55555".parse().unwrap();
3602 Request::builder()
3603 .uri("/limited")
3604 .extension(ConnectInfo(TlsConnInfo::new(addr, None)))
3605 .body(Body::empty())
3606 .unwrap()
3607 };
3608 for _ in 0..2 {
3609 assert_eq!(
3610 app.clone().oneshot(mk()).await.unwrap().status(),
3611 StatusCode::OK
3612 );
3613 }
3614 let resp = app.clone().oneshot(mk()).await.unwrap();
3615 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3616 }
3617
3618 #[test]
3619 fn validate_rejects_zero_extra_route_rate_limit() {
3620 let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3621 .with_extra_route_rate_limit(0);
3622 let err = cfg.validate().expect_err("zero extra route rate limit");
3623 assert!(err.to_string().contains("extra_route_rate_limit"));
3624 }
3625
3626 #[tokio::test]
3627 async fn extra_route_limiter_burst_allows_initial_spike() {
3628 let app = limited_router_with_burst(1, Some(3));
3629 for i in 0..3 {
3630 let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
3631 assert_eq!(resp.status(), StatusCode::OK, "burst request {i}");
3632 }
3633 let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
3634 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3635 }
3636
3637 #[tokio::test]
3638 async fn extra_route_limiter_deny_sets_retry_after() {
3639 let app = limited_router(1);
3640 let ok = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
3641 assert_eq!(ok.status(), StatusCode::OK);
3642 let denied = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
3643 assert_eq!(denied.status(), StatusCode::TOO_MANY_REQUESTS);
3644 let retry_after = denied
3645 .headers()
3646 .get(header::RETRY_AFTER)
3647 .expect("Retry-After present")
3648 .to_str()
3649 .unwrap()
3650 .parse::<u64>()
3651 .unwrap();
3652 assert!(retry_after >= 1, "delta-seconds must be >= 1");
3653 }
3654
3655 #[test]
3656 fn validate_rejects_zero_burst_knobs() {
3657 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3658 .with_tool_rate_limit(10)
3659 .with_tool_rate_limit_burst(0)
3660 .validate()
3661 .expect_err("zero tool burst");
3662 assert!(err.to_string().contains("tool_rate_limit_burst"));
3663
3664 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3665 .with_extra_route_rate_limit(10)
3666 .with_extra_route_rate_limit_burst(0)
3667 .validate()
3668 .expect_err("zero extra route burst");
3669 assert!(err.to_string().contains("extra_route_rate_limit_burst"));
3670 }
3671
3672 #[test]
3673 fn validate_rejects_orphan_burst_knobs() {
3674 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3675 .with_tool_rate_limit_burst(5)
3676 .validate()
3677 .expect_err("orphan tool burst");
3678 assert!(err.to_string().contains("requires tool_rate_limit"));
3679
3680 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3681 .with_extra_route_rate_limit_burst(5)
3682 .validate()
3683 .expect_err("orphan extra route burst");
3684 assert!(err.to_string().contains("requires extra_route_rate_limit"));
3685 }
3686
3687 #[test]
3688 fn validate_rejects_zero_auth_bursts() {
3689 let auth = AuthConfig::with_keys(vec![])
3690 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_burst(0));
3691 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3692 .with_auth(auth)
3693 .validate()
3694 .expect_err("zero auth burst");
3695 assert!(err.to_string().contains("rate_limit.burst"));
3696
3697 let auth = AuthConfig::with_keys(vec![])
3698 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0));
3699 let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3700 .with_auth(auth)
3701 .validate()
3702 .expect_err("zero pre-auth burst");
3703 assert!(err.to_string().contains("pre_auth_burst"));
3704 }
3705
3706 #[test]
3709 fn validate_accepts_pre_auth_burst_without_explicit_pre_auth_rate() {
3710 let auth = AuthConfig::with_keys(vec![])
3711 .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(50));
3712 let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_auth(auth);
3713 assert!(cfg.validate().is_ok(), "pre_auth_burst has no orphan rule");
3714 }
3715
3716 fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
3720 let allowed: Arc<[String]> = Arc::from(origins);
3721 axum::Router::new()
3722 .route("/test", axum::routing::get(|| async { "ok" }))
3723 .layer(axum::middleware::from_fn(move |req, next| {
3724 let a = Arc::clone(&allowed);
3725 origin_check_middleware(a, log_request_headers, req, next)
3726 }))
3727 }
3728
3729 #[tokio::test]
3730 async fn origin_allowed_passes() {
3731 let app = origin_router(vec!["http://localhost:3000".into()], false);
3732 let req = Request::builder()
3733 .uri("/test")
3734 .header(header::ORIGIN, "http://localhost:3000")
3735 .body(Body::empty())
3736 .unwrap();
3737 let resp = app.oneshot(req).await.unwrap();
3738 assert_eq!(resp.status(), StatusCode::OK);
3739 }
3740
3741 #[tokio::test]
3742 async fn origin_rejected_returns_403() {
3743 let app = origin_router(vec!["http://localhost:3000".into()], false);
3744 let req = Request::builder()
3745 .uri("/test")
3746 .header(header::ORIGIN, "http://evil.com")
3747 .body(Body::empty())
3748 .unwrap();
3749 let resp = app.oneshot(req).await.unwrap();
3750 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
3751 }
3752
3753 #[tokio::test]
3754 async fn no_origin_header_passes() {
3755 let app = origin_router(vec!["http://localhost:3000".into()], false);
3756 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3757 let resp = app.oneshot(req).await.unwrap();
3758 assert_eq!(resp.status(), StatusCode::OK);
3759 }
3760
3761 #[tokio::test]
3762 async fn empty_allowlist_rejects_any_origin() {
3763 let app = origin_router(vec![], false);
3764 let req = Request::builder()
3765 .uri("/test")
3766 .header(header::ORIGIN, "http://anything.com")
3767 .body(Body::empty())
3768 .unwrap();
3769 let resp = app.oneshot(req).await.unwrap();
3770 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
3771 }
3772
3773 #[tokio::test]
3774 async fn empty_allowlist_passes_without_origin() {
3775 let app = origin_router(vec![], false);
3776 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3777 let resp = app.oneshot(req).await.unwrap();
3778 assert_eq!(resp.status(), StatusCode::OK);
3779 }
3780
3781 #[test]
3782 fn format_request_headers_redacts_sensitive_values() {
3783 let mut headers = axum::http::HeaderMap::new();
3784 headers.insert("authorization", "Bearer secret-token".parse().unwrap());
3785 headers.insert("cookie", "sid=abc".parse().unwrap());
3786 headers.insert("x-request-id", "req-123".parse().unwrap());
3787
3788 let out = format_request_headers_for_log(&headers);
3789 assert!(out.contains("authorization: [REDACTED]"));
3790 assert!(out.contains("cookie: [REDACTED]"));
3791 assert!(out.contains("x-request-id: req-123"));
3792 assert!(!out.contains("secret-token"));
3793 }
3794
3795 fn security_router(is_tls: bool) -> axum::Router {
3798 security_router_with(is_tls, SecurityHeadersConfig::default())
3799 }
3800
3801 fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
3802 let cfg = Arc::new(cfg);
3803 axum::Router::new()
3804 .route("/test", axum::routing::get(|| async { "ok" }))
3805 .layer(axum::middleware::from_fn(move |req, next| {
3806 let c = Arc::clone(&cfg);
3807 security_headers_middleware(is_tls, c, req, next)
3808 }))
3809 }
3810
3811 #[tokio::test]
3812 async fn security_headers_set_on_response() {
3813 let app = security_router(false);
3814 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3815 let resp = app.oneshot(req).await.unwrap();
3816 assert_eq!(resp.status(), StatusCode::OK);
3817
3818 let h = resp.headers();
3819 assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
3820 assert_eq!(h.get("x-frame-options").unwrap(), "deny");
3821 assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
3822 assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
3823 assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
3824 assert_eq!(
3825 h.get("cross-origin-resource-policy").unwrap(),
3826 "same-origin"
3827 );
3828 assert_eq!(
3829 h.get("cross-origin-embedder-policy").unwrap(),
3830 "require-corp"
3831 );
3832 assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
3833 assert!(
3834 h.get("permissions-policy")
3835 .unwrap()
3836 .to_str()
3837 .unwrap()
3838 .contains("camera=()"),
3839 "permissions-policy must restrict browser features"
3840 );
3841 assert_eq!(
3842 h.get("content-security-policy").unwrap(),
3843 "default-src 'none'; frame-ancestors 'none'"
3844 );
3845 assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
3846 assert!(h.get("strict-transport-security").is_none());
3848 }
3849
3850 #[tokio::test]
3851 async fn hsts_set_when_tls_enabled() {
3852 let app = security_router(true);
3853 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3854 let resp = app.oneshot(req).await.unwrap();
3855
3856 let hsts = resp.headers().get("strict-transport-security").unwrap();
3857 assert!(
3858 hsts.to_str().unwrap().contains("max-age=63072000"),
3859 "HSTS must set 2-year max-age"
3860 );
3861 }
3862
3863 fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
3869 let cfg =
3870 McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
3871 cfg.check()
3872 }
3873
3874 #[test]
3875 fn security_headers_config_default_validates() {
3876 check_with_security_headers(SecurityHeadersConfig::default())
3877 .expect("default SecurityHeadersConfig must validate");
3878 }
3879
3880 #[test]
3881 fn security_headers_config_validate_accepts_empty_string() {
3882 let h = SecurityHeadersConfig {
3884 x_content_type_options: Some(String::new()),
3885 x_frame_options: Some(String::new()),
3886 cache_control: Some(String::new()),
3887 referrer_policy: Some(String::new()),
3888 cross_origin_opener_policy: Some(String::new()),
3889 cross_origin_resource_policy: Some(String::new()),
3890 cross_origin_embedder_policy: Some(String::new()),
3891 permissions_policy: Some(String::new()),
3892 x_permitted_cross_domain_policies: Some(String::new()),
3893 content_security_policy: Some(String::new()),
3894 x_dns_prefetch_control: Some(String::new()),
3895 strict_transport_security: Some(String::new()),
3896 };
3897 check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
3898 }
3899
3900 #[test]
3901 fn security_headers_config_validate_rejects_bad_value() {
3902 let h = SecurityHeadersConfig {
3904 referrer_policy: Some("\u{0007}".into()),
3905 ..SecurityHeadersConfig::default()
3906 };
3907 let err = check_with_security_headers(h)
3908 .expect_err("control char in referrer_policy must reject");
3909 let msg = err.to_string();
3910 assert!(
3911 msg.contains("referrer_policy"),
3912 "error must name the offending field, got: {msg}"
3913 );
3914 }
3915
3916 #[test]
3917 fn security_headers_config_validate_rejects_hsts_preload() {
3918 let h = SecurityHeadersConfig {
3919 strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
3920 ..SecurityHeadersConfig::default()
3921 };
3922 let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
3923 let msg = err.to_string();
3924 assert!(
3925 msg.contains("strict_transport_security"),
3926 "error must name the field, got: {msg}"
3927 );
3928 assert!(
3929 msg.to_lowercase().contains("preload"),
3930 "error must mention `preload`, got: {msg}"
3931 );
3932 }
3933
3934 #[test]
3935 fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
3936 let h = SecurityHeadersConfig {
3938 strict_transport_security: Some("max-age=600; PRELOAD".into()),
3939 ..SecurityHeadersConfig::default()
3940 };
3941 check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
3942 }
3943
3944 #[tokio::test]
3945 async fn security_headers_override_honored() {
3946 let h = SecurityHeadersConfig {
3948 x_frame_options: Some("SAMEORIGIN".into()),
3949 ..SecurityHeadersConfig::default()
3950 };
3951 let app = security_router_with(false, h);
3952 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3953 let resp = app.oneshot(req).await.unwrap();
3954 assert_eq!(resp.status(), StatusCode::OK);
3955
3956 let xfo = resp.headers().get("x-frame-options").unwrap();
3957 assert_eq!(xfo, "SAMEORIGIN");
3958 }
3959
3960 #[tokio::test]
3961 async fn security_headers_empty_string_omits() {
3962 let h = SecurityHeadersConfig {
3964 referrer_policy: Some(String::new()),
3965 ..SecurityHeadersConfig::default()
3966 };
3967 let app = security_router_with(false, h);
3968 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3969 let resp = app.oneshot(req).await.unwrap();
3970 assert_eq!(resp.status(), StatusCode::OK);
3971
3972 assert!(
3973 resp.headers().get("referrer-policy").is_none(),
3974 "Some(\"\") must omit the header"
3975 );
3976 assert_eq!(
3978 resp.headers().get("x-content-type-options").unwrap(),
3979 "nosniff"
3980 );
3981 }
3982
3983 #[tokio::test]
3984 async fn security_headers_hsts_only_when_tls() {
3985 let h = SecurityHeadersConfig {
3987 strict_transport_security: Some("max-age=600".into()),
3988 ..SecurityHeadersConfig::default()
3989 };
3990 let app = security_router_with(false, h);
3991 let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3992 let resp = app.oneshot(req).await.unwrap();
3993 assert!(
3994 resp.headers().get("strict-transport-security").is_none(),
3995 "HSTS must remain absent on plaintext deployments even with override"
3996 );
3997 }
3998
3999 #[cfg(feature = "oauth")]
4002 #[tokio::test]
4003 async fn oauth_token_cache_headers_set_pragma_and_vary() {
4004 let app = axum::Router::new()
4005 .route("/token", axum::routing::post(|| async { "{}" }))
4006 .layer(axum::middleware::from_fn(
4007 oauth_token_cache_headers_middleware,
4008 ));
4009 let req = Request::builder()
4010 .method("POST")
4011 .uri("/token")
4012 .body(Body::from("{}"))
4013 .unwrap();
4014 let resp = app.oneshot(req).await.unwrap();
4015 assert_eq!(resp.status(), StatusCode::OK);
4016
4017 let h = resp.headers();
4018 assert_eq!(
4019 h.get("pragma").unwrap(),
4020 "no-cache",
4021 "RFC 6749 §5.1: token responses must set Pragma: no-cache"
4022 );
4023 let vary_values: Vec<String> = h
4024 .get_all("vary")
4025 .iter()
4026 .filter_map(|v| v.to_str().ok().map(str::to_owned))
4027 .collect();
4028 assert!(
4029 vary_values
4030 .iter()
4031 .any(|v| v.eq_ignore_ascii_case("Authorization")),
4032 "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
4033 );
4034 }
4035
4036 #[cfg(feature = "oauth")]
4037 #[tokio::test]
4038 async fn oauth_token_cache_headers_preserve_existing_vary() {
4039 let app = axum::Router::new()
4042 .route(
4043 "/token",
4044 axum::routing::post(|| async {
4045 axum::response::Response::builder()
4046 .header("vary", "Accept-Encoding")
4047 .body(axum::body::Body::from("{}"))
4048 .unwrap()
4049 }),
4050 )
4051 .layer(axum::middleware::from_fn(
4052 oauth_token_cache_headers_middleware,
4053 ));
4054 let req = Request::builder()
4055 .method("POST")
4056 .uri("/token")
4057 .body(Body::empty())
4058 .unwrap();
4059 let resp = app.oneshot(req).await.unwrap();
4060
4061 let vary: Vec<String> = resp
4062 .headers()
4063 .get_all("vary")
4064 .iter()
4065 .filter_map(|v| v.to_str().ok().map(str::to_owned))
4066 .collect();
4067 assert!(
4068 vary.iter().any(|v| v.contains("Accept-Encoding")),
4069 "must preserve pre-existing Vary value, got {vary:?}"
4070 );
4071 assert!(
4072 vary.iter().any(|v| v.contains("Authorization")),
4073 "must append Authorization to Vary, got {vary:?}"
4074 );
4075 }
4076
4077 #[test]
4080 fn version_payload_contains_expected_fields() {
4081 let v = version_payload("my-server", "1.2.3");
4082 assert_eq!(v["name"], "my-server");
4083 assert_eq!(v["version"], "1.2.3");
4084 assert!(v["build_git_sha"].is_string());
4085 assert!(v["build_timestamp"].is_string());
4086 assert!(v["rust_version"].is_string());
4087 assert!(v["mcpx_version"].is_string());
4088 }
4089
4090 #[tokio::test]
4093 async fn concurrency_limit_layer_composes_and_serves() {
4094 let app = axum::Router::new()
4098 .route("/ok", axum::routing::get(|| async { "ok" }))
4099 .layer(
4100 tower::ServiceBuilder::new()
4101 .layer(axum::error_handling::HandleErrorLayer::new(
4102 |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
4103 ))
4104 .layer(tower::load_shed::LoadShedLayer::new())
4105 .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
4106 );
4107 let resp = app
4108 .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
4109 .await
4110 .unwrap();
4111 assert_eq!(resp.status(), StatusCode::OK);
4112 }
4113
4114 #[tokio::test]
4117 async fn compression_layer_gzip_encodes_response() {
4118 use tower_http::compression::Predicate as _;
4119
4120 let big_body = "a".repeat(4096);
4121 let app = axum::Router::new()
4122 .route(
4123 "/big",
4124 axum::routing::get(move || {
4125 let body = big_body.clone();
4126 async move { body }
4127 }),
4128 )
4129 .layer(
4130 tower_http::compression::CompressionLayer::new()
4131 .gzip(true)
4132 .br(true)
4133 .compress_when(
4134 tower_http::compression::DefaultPredicate::new()
4135 .and(tower_http::compression::predicate::SizeAbove::new(1024)),
4136 ),
4137 );
4138
4139 let req = Request::builder()
4140 .uri("/big")
4141 .header(header::ACCEPT_ENCODING, "gzip")
4142 .body(Body::empty())
4143 .unwrap();
4144 let resp = app.oneshot(req).await.unwrap();
4145 assert_eq!(resp.status(), StatusCode::OK);
4146 assert_eq!(
4147 resp.headers().get(header::CONTENT_ENCODING).unwrap(),
4148 "gzip"
4149 );
4150 }
4151
4152 #[tokio::test]
4155 async fn tls_handshake_timeout_reaps_idle_connections() {
4156 use tokio::io::AsyncReadExt as _;
4157
4158 let _ = rustls::crypto::ring::default_provider().install_default();
4159
4160 let key = rcgen::KeyPair::generate().expect("generate key");
4162 let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
4163 .expect("cert params")
4164 .self_signed(&key)
4165 .expect("self-signed cert");
4166 let dir = std::env::temp_dir().join(format!(
4167 "rmcp-server-kit-hs-timeout-{}",
4168 std::time::SystemTime::now()
4169 .duration_since(std::time::UNIX_EPOCH)
4170 .expect("clock after epoch")
4171 .as_nanos()
4172 ));
4173 tokio::fs::create_dir_all(&dir).await.expect("temp dir");
4174 let cert_path = dir.join("server.crt");
4175 let key_path = dir.join("server.key");
4176 tokio::fs::write(&cert_path, cert.pem())
4177 .await
4178 .expect("write cert");
4179 tokio::fs::write(&key_path, key.serialize_pem())
4180 .await
4181 .expect("write key");
4182
4183 let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
4184 let tls = TlsListener::new(
4185 listener,
4186 &cert_path,
4187 &key_path,
4188 None,
4189 None,
4190 Duration::from_millis(200),
4191 8, )
4193 .expect("tls listener");
4194 let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
4195
4196 let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
4200 let mut buf = [0_u8; 16];
4201 let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
4202 .await
4203 .expect("server must reap the idle handshake within its timeout");
4204 match read {
4205 Ok(0) | Err(_) => {} Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
4207 }
4208
4209 drop(tls);
4210 }
4211}