1use std::{
2 future::Future,
3 net::SocketAddr,
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13 ServerHandler,
14 transport::streamable_http_server::{
15 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16 },
17};
18use rustls::RootCertStore;
19use tokio::{
20 net::TcpListener,
21 sync::{Semaphore, mpsc},
22};
23use tokio_util::sync::CancellationToken;
24
25use crate::{
26 auth::{
27 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
28 build_rate_limiter, extract_mtls_identity,
29 },
30 error::McpxError,
31 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
32 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
33};
34
35#[allow(
39 clippy::needless_pass_by_value,
40 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
41)]
42fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
43 McpxError::Startup(format!("{e:#}"))
44}
45
46#[allow(
52 clippy::needless_pass_by_value,
53 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
54)]
55fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
56 McpxError::Startup(format!("{op}: {e}"))
57}
58
/// Asynchronous readiness probe backing the `/readyz` route.
///
/// Each call returns a boxed, `Send` future that resolves to a JSON value.
/// The closure itself must be `Send + Sync` because it is shared (via `Arc`)
/// by the route handler. When no check is configured, `/readyz` is served by
/// the same `healthz` handler as `/healthz`.
///
/// NOTE(review): how the returned JSON maps to an HTTP status code is decided
/// by `readyz`, which is not shown here — confirm there.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
65
/// Overrides for the security response headers applied to every response by
/// `security_headers_middleware`.
///
/// Every field is an `Option<String>`; `Default` leaves all of them `None`.
/// Each field presumably corresponds to the response header of the same name
/// (`x_content_type_options` → `X-Content-Type-Options`, ...).
///
/// NOTE(review): the meaning of `None` (presumably "use the built-in value")
/// and any special values are defined by `security_headers_middleware` and
/// `validate_security_headers`, which are not shown here — confirm there.
///
/// Because the struct is `#[non_exhaustive]`, code outside this crate must
/// start from `SecurityHeadersConfig::default()` and assign fields, rather
/// than using a struct literal.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` override.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` override.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` override.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` override.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` override.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` override.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` override.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` override.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` override.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` override.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` override.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` override. The middleware also receives an
    /// `is_tls` flag, so HSTS is presumably only emitted on TLS listeners —
    /// NOTE(review): confirm in `security_headers_middleware`.
    pub strict_transport_security: Option<String>,
}
119
/// Configuration for [`serve`] / [`serve_with_listener`].
///
/// Build it with [`McpServerConfig::new`] plus the `with_*` / `enable_*`
/// builders, then call [`McpServerConfig::validate`] to obtain the
/// [`Validated`] wrapper that the serve functions require. Direct field
/// access is deprecated.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Listen address; `validate` requires it to parse as a `SocketAddr`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name, used in start-up logs, the `/version` payload and the
    /// admin state.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version, used in the `/version` payload and the admin state.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// TLS certificate path; must be set together with `tls_key_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private-key path; must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings. The auth middleware is installed on `/mcp`
    /// only when this is `Some` and its `enabled` flag is true.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy for `/mcp`; `None` means a disabled policy is used.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Allowed browser origins (each must start with `http://` or
    /// `https://`). When empty and `public_url` is set, one origin is derived
    /// from `public_url`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Tool-call rate limit, in calls per minute per client IP.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Custom probe for `/readyz`; without it `/readyz` behaves like `/healthz`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum `/mcp` request body size in bytes (default 1 MiB, must be > 0).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Per-request timeout on `/mcp`; expiry yields 408 (default 2 min).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Graceful-shutdown timeout handed to `run_server` (default 30 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Keep-alive for idle MCP sessions in the session manager (default 20 min).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval for Streamable HTTP streams (default 15 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// One-shot callback receiving a [`ReloadHandle`]; forwarded to
    /// `run_server`, which decides when to invoke it.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Extra routes merged into the top-level router. They get the
    /// server-wide layers, but not the `/mcp` auth/RBAC layers.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL (must start with `http://` or `https://`).
    /// It is used for:
    /// - deriving the allowed hosts;
    /// - the protected-resource metadata URL;
    /// - the default allowed origin.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Passed to the origin-check middleware to enable request-header logging.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Only responses above this many bytes are compressed (default 1024).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global in-flight request cap; excess requests get 503. `Some(0)` is rejected.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Mounts the `/admin/*` endpoints; requires auth to be enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// Role required by the admin endpoints (default `"admin"`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enables the metrics middleware and a separate metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Listen address of the metrics listener (default `127.0.0.1:9090`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Security-header overrides, checked by `validate_security_headers`.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
    /// TLS handshake timeout forwarded to `run_server` (must be > 0).
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_handshake_timeout: Duration,
    /// Cap on concurrent TLS handshakes forwarded to `run_server` (must be > 0).
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_tls_handshakes: usize,
}
346
/// Marker wrapper proving that the inner value passed validation.
///
/// The tuple field is private, so a `Validated<T>` can only be built inside
/// this module (by [`McpServerConfig::validate`]). Its `Debug` output never
/// shows the wrapped value.
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Renders as `Validated { .. }` regardless of `T`, so configuration
    /// contents (paths, keys, callbacks) never reach log output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Validated");
        builder.finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the validated value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
432
433#[allow(
434 deprecated,
435 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
436)]
437impl McpServerConfig {
438 #[must_use]
446 pub fn new(
447 bind_addr: impl Into<String>,
448 name: impl Into<String>,
449 version: impl Into<String>,
450 ) -> Self {
451 Self {
452 bind_addr: bind_addr.into(),
453 name: name.into(),
454 version: version.into(),
455 tls_cert_path: None,
456 tls_key_path: None,
457 auth: None,
458 rbac: None,
459 allowed_origins: Vec::new(),
460 tool_rate_limit: None,
461 readiness_check: None,
462 max_request_body: 1024 * 1024,
463 request_timeout: Duration::from_mins(2),
464 shutdown_timeout: Duration::from_secs(30),
465 session_idle_timeout: Duration::from_mins(20),
466 sse_keep_alive: Duration::from_secs(15),
467 on_reload_ready: None,
468 extra_router: None,
469 public_url: None,
470 log_request_headers: false,
471 compression_enabled: false,
472 compression_min_size: 1024,
473 max_concurrent_requests: None,
474 admin_enabled: false,
475 admin_role: "admin".to_owned(),
476 #[cfg(feature = "metrics")]
477 metrics_enabled: false,
478 #[cfg(feature = "metrics")]
479 metrics_bind: "127.0.0.1:9090".into(),
480 security_headers: SecurityHeadersConfig::default(),
481 tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
482 max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
483 }
484 }
485
486 #[must_use]
496 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
497 self.auth = Some(auth);
498 self
499 }
500
501 #[must_use]
506 pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
507 self.security_headers = headers;
508 self
509 }
510
511 #[must_use]
515 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
516 self.bind_addr = addr.into();
517 self
518 }
519
520 #[must_use]
523 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
524 self.rbac = Some(rbac);
525 self
526 }
527
528 #[must_use]
532 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
533 self.tls_cert_path = Some(cert_path.into());
534 self.tls_key_path = Some(key_path.into());
535 self
536 }
537
538 #[must_use]
542 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
543 self.public_url = Some(url.into());
544 self
545 }
546
547 #[must_use]
551 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
552 where
553 I: IntoIterator<Item = S>,
554 S: Into<String>,
555 {
556 self.allowed_origins = origins.into_iter().map(Into::into).collect();
557 self
558 }
559
560 #[must_use]
564 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
565 self.extra_router = Some(router);
566 self
567 }
568
569 #[must_use]
572 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
573 self.readiness_check = Some(check);
574 self
575 }
576
577 #[must_use]
580 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
581 self.max_request_body = bytes;
582 self
583 }
584
585 #[must_use]
587 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
588 self.request_timeout = timeout;
589 self
590 }
591
592 #[must_use]
594 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
595 self.shutdown_timeout = timeout;
596 self
597 }
598
599 #[must_use]
601 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
602 self.session_idle_timeout = timeout;
603 self
604 }
605
606 #[must_use]
608 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
609 self.sse_keep_alive = interval;
610 self
611 }
612
613 #[must_use]
617 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
618 self.max_concurrent_requests = Some(limit);
619 self
620 }
621
622 #[must_use]
630 pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
631 self.tls_handshake_timeout = timeout;
632 self
633 }
634
635 #[must_use]
644 pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
645 self.max_concurrent_tls_handshakes = limit;
646 self
647 }
648
649 #[must_use]
652 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
653 self.tool_rate_limit = Some(per_minute);
654 self
655 }
656
657 #[must_use]
661 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
662 where
663 F: FnOnce(ReloadHandle) + Send + 'static,
664 {
665 self.on_reload_ready = Some(Box::new(callback));
666 self
667 }
668
669 #[must_use]
673 pub fn enable_compression(mut self, min_size: u16) -> Self {
674 self.compression_enabled = true;
675 self.compression_min_size = min_size;
676 self
677 }
678
679 #[must_use]
684 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
685 self.admin_enabled = true;
686 self.admin_role = role.into();
687 self
688 }
689
690 #[must_use]
693 pub fn enable_request_header_logging(mut self) -> Self {
694 self.log_request_headers = true;
695 self
696 }
697
698 #[cfg(feature = "metrics")]
701 #[must_use]
702 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
703 self.metrics_enabled = true;
704 self.metrics_bind = bind.into();
705 self
706 }
707
708 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
741 self.check()?;
742 Ok(Validated(self))
743 }
744
745 fn check(&self) -> Result<(), McpxError> {
749 if self.admin_enabled {
753 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
754 if !auth_enabled {
755 return Err(McpxError::Config(
756 "admin_enabled=true requires auth to be configured and enabled".into(),
757 ));
758 }
759 }
760
761 match (&self.tls_cert_path, &self.tls_key_path) {
763 (Some(_), None) => {
764 return Err(McpxError::Config(
765 "tls_cert_path is set but tls_key_path is missing".into(),
766 ));
767 }
768 (None, Some(_)) => {
769 return Err(McpxError::Config(
770 "tls_key_path is set but tls_cert_path is missing".into(),
771 ));
772 }
773 _ => {}
774 }
775
776 if self.bind_addr.parse::<SocketAddr>().is_err() {
778 return Err(McpxError::Config(format!(
779 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
780 self.bind_addr
781 )));
782 }
783
784 if let Some(ref url) = self.public_url
786 && !(url.starts_with("http://") || url.starts_with("https://"))
787 {
788 return Err(McpxError::Config(format!(
789 "public_url {url:?} must start with http:// or https://"
790 )));
791 }
792
793 for origin in &self.allowed_origins {
795 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
796 return Err(McpxError::Config(format!(
797 "allowed_origins entry {origin:?} must start with http:// or https://"
798 )));
799 }
800 }
801
802 if self.max_request_body == 0 {
804 return Err(McpxError::Config(
805 "max_request_body must be greater than zero".into(),
806 ));
807 }
808
809 #[cfg(feature = "oauth")]
811 if let Some(auth_cfg) = &self.auth
812 && let Some(oauth_cfg) = &auth_cfg.oauth
813 {
814 oauth_cfg.validate()?;
815 }
816
817 validate_security_headers(&self.security_headers)?;
820
821 if let Some(0) = self.max_concurrent_requests {
825 return Err(McpxError::Config(
826 "max_concurrent_requests must be greater than zero when set".into(),
827 ));
828 }
829
830 if let Some(auth_cfg) = &self.auth
834 && let Some(rl) = &auth_cfg.rate_limit
835 && rl.max_tracked_keys == 0
836 {
837 return Err(McpxError::Config(
838 "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
839 ));
840 }
841
842 if self.tls_handshake_timeout == Duration::ZERO {
847 return Err(McpxError::Config(
848 "tls_handshake_timeout must be greater than zero".into(),
849 ));
850 }
851
852 if self.max_concurrent_tls_handshakes == 0 {
857 return Err(McpxError::Config(
858 "max_concurrent_tls_handshakes must be greater than zero".into(),
859 ));
860 }
861
862 Ok(())
863 }
864}
865
/// Handle for hot-reloading server state while the server is running.
///
/// Each field is `None` when the corresponding feature is not active. In
/// that case `reload_auth_keys` and `reload_rbac` do nothing, and
/// `refresh_crls` returns a config error.
///
/// NOTE(review): the handle is presumably built by `run_server` and passed to
/// the `on_reload_ready` callback — that code is not fully visible here; confirm.
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    // Shared auth state whose API keys `reload_auth_keys` replaces.
    auth: Option<Arc<AuthState>>,
    // Swappable RBAC policy read by the RBAC middleware on each request.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    // mTLS certificate revocation lists; refreshed by `refresh_crls`.
    crl_set: Option<Arc<CrlSet>>,
}
880
881impl ReloadHandle {
882 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
884 if let Some(ref auth) = self.auth {
885 auth.reload_keys(keys);
886 }
887 }
888
889 pub fn reload_rbac(&self, policy: RbacPolicy) {
891 if let Some(ref rbac) = self.rbac {
892 rbac.store(Arc::new(policy));
893 tracing::info!("RBAC policy reloaded");
894 }
895 }
896
897 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
903 let Some(ref crl_set) = self.crl_set else {
904 return Err(McpxError::Config(
905 "CRL refresh requested but mTLS CRL support is not configured".into(),
906 ));
907 };
908
909 crl_set.force_refresh().await
910 }
911}
912
/// Everything `run_server` needs besides the router. It is produced by
/// `build_app_router` and unpacked by `serve` / `serve_with_listener`.
//
// NOTE(review): `clippy::too_many_lines` and `clippy::cognitive_complexity`
// are function-level lints, so this `#[allow]` has no effect on a struct, and
// its `reason` describes router assembly. It looks intended for
// `build_app_router` — confirm, then move or drop it.
#[allow(
    clippy::too_many_lines,
    clippy::cognitive_complexity,
    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
)]
struct AppRunParams {
    // `(cert_path, key_path)` when TLS is configured; `None` means plain HTTP.
    tls_paths: Option<(PathBuf, PathBuf)>,
    // Copied from `McpServerConfig::tls_handshake_timeout`.
    tls_handshake_timeout: Duration,
    // Copied from `McpServerConfig::max_concurrent_tls_handshakes`.
    max_concurrent_tls_handshakes: usize,
    // Cloned from `config.auth.mtls`. It is taken whether or not
    // `auth.enabled` is set.
    mtls_config: Option<MtlsConfig>,
    // Graceful-shutdown timeout from the config.
    shutdown_timeout: Duration,
    // `Some` only when auth is configured and enabled.
    auth_state: Option<Arc<AuthState>>,
    // Live RBAC policy; holds a disabled policy when none was configured.
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    // User callback taken out of the config.
    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    // Root cancellation token. The Streamable HTTP service holds a child of
    // it, and the metrics listener (if any) holds a clone.
    ct: CancellationToken,
    // "https" when a TLS certificate is configured, otherwise "http"; used for logs.
    scheme: &'static str,
    // Server name, used for start-up logs.
    name: String,
}
963
/// Assembles the full axum router and the parameters `run_server` needs.
///
/// Route and layer layout (axum layers added later wrap earlier ones, so the
/// last layer added runs first on a request):
///
/// - `/mcp` (plus `/admin/*` when enabled, which is merged in before the
///   layers below). Layers from innermost to outermost:
///   1. RBAC and tool rate limiting;
///   2. auth (when enabled);
///   3. request timeout (408);
///   4. request body limit.
/// - `/healthz`, `/readyz`, `/version`, the extra router and
///   `/.well-known/oauth-protected-resource`, plus the OAuth proxy routes
///   when configured.
/// - Server-wide layers, from innermost to outermost:
///   1. security headers;
///   2. CORS (only when origins exist);
///   3. compression;
///   4. concurrency limit (503);
///   5. metrics;
///   6. origin check.
///
/// With the `metrics` feature enabled, this also spawns the metrics
/// listener task tied to the returned cancellation token.
///
/// # Errors
///
/// Fails if JWKS client, metrics or OAuth proxy setup fails, or if admin/auth
/// is enabled without a usable auth state.
#[allow(
    clippy::cognitive_complexity,
    reason = "router assembly is intrinsically sequential; splitting harms readability"
)]
#[allow(
    deprecated,
    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
)]
fn build_app_router<H, F>(
    mut config: McpServerConfig,
    handler_factory: F,
) -> anyhow::Result<(axum::Router, AppRunParams)>
where
    H: ServerHandler + 'static,
    F: Fn() -> H + Send + Sync + Clone + 'static,
{
    // Root token; everything long-lived below is tied to it for shutdown.
    let ct = CancellationToken::new();

    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");

    let mcp_service = StreamableHttpService::new(
        move || Ok(handler_factory()),
        {
            // Idle sessions are kept alive for `session_idle_timeout`.
            let mut mgr = LocalSessionManager::default();
            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
            mgr.into()
        },
        StreamableHttpServerConfig::default()
            .with_allowed_hosts(allowed_hosts)
            .with_sse_keep_alive(Some(config.sse_keep_alive))
            .with_cancellation_token(ct.child_token()),
    );

    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);

    // Shared auth state exists only when auth is configured AND enabled.
    let auth_state: Option<Arc<AuthState>> = match config.auth {
        Some(ref auth_config) if auth_config.enabled => {
            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
            let pre_auth_limiter = auth_config
                .rate_limit
                .as_ref()
                .map(crate::auth::build_pre_auth_limiter);

            #[cfg(feature = "oauth")]
            let jwks_cache = auth_config
                .oauth
                .as_ref()
                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
                .transpose()
                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;

            Some(Arc::new(AuthState {
                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
                rate_limiter,
                pre_auth_limiter,
                #[cfg(feature = "oauth")]
                jwks_cache,
                seen_identities: crate::auth::SeenIdentitySet::new(),
                counters: crate::auth::AuthCounters::default(),
            }))
        }
        _ => None,
    };

    // Always present so the RBAC middleware and ReloadHandle have a swap target.
    let rbac_swap = Arc::new(ArcSwap::new(
        config
            .rbac
            .clone()
            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
    ));

    // Admin routes are merged into `mcp_router` BEFORE the RBAC/auth/timeout/
    // body-limit layers below, so those layers also wrap /admin/*.
    if config.admin_enabled {
        // Defense in depth: validate() already rejects this combination.
        let Some(ref auth_state_ref) = auth_state else {
            return Err(anyhow::anyhow!(
                "admin_enabled=true requires auth to be configured and enabled"
            ));
        };
        let admin_state = crate::admin::AdminState {
            started_at: std::time::Instant::now(),
            name: config.name.clone(),
            version: config.version.clone(),
            auth: Some(Arc::clone(auth_state_ref)),
            rbac: Arc::clone(&rbac_swap),
        };
        let admin_cfg = crate::admin::AdminConfig {
            role: config.admin_role.clone(),
        };
        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
    }

    // RBAC + tool rate limiting: innermost layer, so it runs after auth.
    {
        let tool_limiter: Option<Arc<ToolRateLimiter>> =
            config.tool_rate_limit.map(build_tool_rate_limiter);

        if rbac_swap.load().is_enabled() {
            tracing::info!("RBAC enforcement enabled on /mcp");
        }
        if let Some(limit) = config.tool_rate_limit {
            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
        }

        let rbac_for_mw = Arc::clone(&rbac_swap);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            // Load per request so policies stored by ReloadHandle take effect.
            let p = rbac_for_mw.load_full();
            let tl = tool_limiter.clone();
            rbac_middleware(p, tl, req, next)
        }));
    }

    // Auth wraps RBAC, so identities are established before RBAC checks.
    if let Some(ref auth_config) = config.auth
        && auth_config.enabled
    {
        let Some(ref state) = auth_state else {
            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
        };

        let methods: Vec<&str> = [
            auth_config.mtls.is_some().then_some("mTLS"),
            (!auth_config.api_keys.is_empty()).then_some("bearer"),
            #[cfg(feature = "oauth")]
            auth_config.oauth.is_some().then_some("oauth-jwt"),
        ]
        .into_iter()
        .flatten()
        .collect();

        tracing::info!(
            methods = %methods.join(", "),
            api_keys = auth_config.api_keys.len(),
            "auth enabled on /mcp"
        );

        let state_for_mw = Arc::clone(state);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            let s = Arc::clone(&state_for_mw);
            auth_middleware(s, req, next)
        }));
    }

    // Timeout and body limit apply to /mcp (and /admin/*) only.
    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
        axum::http::StatusCode::REQUEST_TIMEOUT,
        config.request_timeout,
    ));

    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
        config.max_request_body,
    ));

    // With no explicit origins, derive "scheme://host[:port]" from public_url.
    // NOTE(review): the host part is everything up to the first '/', so a
    // query/fragment with no path would be kept verbatim in the origin.
    let mut effective_origins = config.allowed_origins.clone();
    if effective_origins.is_empty()
        && let Some(ref url) = config.public_url
    {
        if let Some(scheme_end) = url.find("://") {
            let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
            let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
            let host = after_scheme.get(..host_end).unwrap_or_default();
            let origin = format!("{scheme_with_sep}{host}");
            tracing::info!(
                %origin,
                "auto-derived allowed origin from public_url"
            );
            effective_origins.push(origin);
        }
    }
    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
    let cors_origins = Arc::clone(&allowed_origins);
    let log_request_headers = config.log_request_headers;

    // Without a custom check, /readyz reuses the liveness handler.
    let readyz_route = if let Some(check) = config.readiness_check.take() {
        axum::routing::get(move || readyz(Arc::clone(&check)))
    } else {
        axum::routing::get(healthz)
    };

    #[allow(unused_mut)]
    let mut router = axum::Router::new()
        .route("/healthz", axum::routing::get(healthz))
        .route("/readyz", readyz_route)
        .route(
            "/version",
            axum::routing::get({
                // Serialized once at build time; each request copies the bytes.
                let payload_bytes: Arc<[u8]> =
                    serialize_version_payload(&config.name, &config.version);
                move || {
                    let p = Arc::clone(&payload_bytes);
                    async move {
                        (
                            [(axum::http::header::CONTENT_TYPE, "application/json")],
                            p.to_vec(),
                        )
                    }
                }
            }),
        )
        .merge(mcp_router);

    // Extra routes get the server-wide layers below but not /mcp's auth/RBAC.
    if let Some(extra) = config.extra_router.take() {
        router = router.merge(extra);
    }

    // Base URL advertised in protected-resource metadata: public_url without
    // a trailing '/', otherwise scheme + bind address.
    let server_url = if let Some(ref url) = config.public_url {
        url.trim_end_matches('/').to_owned()
    } else {
        let prm_scheme = if config.tls_cert_path.is_some() {
            "https"
        } else {
            "http"
        };
        format!("{prm_scheme}://{}", config.bind_addr)
    };
    let resource_url = format!("{server_url}/mcp");

    #[cfg(feature = "oauth")]
    let prm_metadata = if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
    {
        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
    } else {
        serde_json::json!({ "resource": resource_url })
    };
    #[cfg(not(feature = "oauth"))]
    let prm_metadata = serde_json::json!({ "resource": resource_url });

    router = router.route(
        "/.well-known/oauth-protected-resource",
        axum::routing::get(move || {
            let m = prm_metadata.clone();
            async move { axum::Json(m) }
        }),
    );

    #[cfg(feature = "oauth")]
    if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
        && oauth_config.proxy.is_some()
    {
        router =
            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
    }

    // Server-wide layers start here; each one wraps every route above.
    let is_tls = config.tls_cert_path.is_some();
    let security_headers_cfg = Arc::new(config.security_headers.clone());
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        let cfg = Arc::clone(&security_headers_cfg);
        security_headers_middleware(is_tls, cfg, req, next)
    }));

    if !cors_origins.is_empty() {
        // Origins that are not valid header values are silently dropped.
        let cors = tower_http::cors::CorsLayer::new()
            .allow_origin(
                cors_origins
                    .iter()
                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
                    .collect::<Vec<_>>(),
            )
            .allow_methods([
                axum::http::Method::GET,
                axum::http::Method::POST,
                axum::http::Method::OPTIONS,
            ])
            .allow_headers([
                axum::http::header::CONTENT_TYPE,
                axum::http::header::AUTHORIZATION,
            ]);
        router = router.layer(cors);
    }

    if config.compression_enabled {
        use tower_http::compression::Predicate as _;
        let predicate = tower_http::compression::DefaultPredicate::new().and(
            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
        );
        router = router.layer(
            tower_http::compression::CompressionLayer::new()
                .gzip(true)
                .br(true)
                .compress_when(predicate),
        );
        tracing::info!(
            min_size = config.compression_min_size,
            "response compression enabled (gzip, br)"
        );
    }

    // Load shedding turns "no permit available" into an error that
    // HandleErrorLayer maps to 503, instead of queueing the request.
    if let Some(max) = config.max_concurrent_requests {
        let overload_handler = tower::ServiceBuilder::new()
            .layer(axum::error_handling::HandleErrorLayer::new(
                |_err: tower::BoxError| async {
                    (
                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
                        axum::Json(serde_json::json!({
                            "error": "overloaded",
                            "error_description": "server is at capacity, retry later"
                        })),
                    )
                },
            ))
            .layer(tower::load_shed::LoadShedLayer::new())
            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
        router = router.layer(overload_handler);
        tracing::info!(max, "global concurrency limit enabled");
    }

    router = router.fallback(|| async {
        (
            axum::http::StatusCode::NOT_FOUND,
            axum::Json(serde_json::json!({
                "error": "not_found",
                "error_description": "The requested endpoint does not exist"
            })),
        )
    });

    #[cfg(feature = "metrics")]
    if config.metrics_enabled {
        let metrics = Arc::new(
            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
        );
        let m = Arc::clone(&metrics);
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let m = Arc::clone(&m);
                metrics_middleware(m, req, next)
            },
        ));
        // The listener task stops when `ct` is cancelled; callers must cancel
        // it if start-up fails after this point.
        let metrics_bind = config.metrics_bind.clone();
        let metrics_shutdown = ct.clone();
        tokio::spawn(async move {
            if let Err(e) =
                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
            {
                tracing::error!("metrics listener failed: {e}");
            }
        });
    }

    // Added last, so it is the outermost layer and runs before everything else.
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        let origins = Arc::clone(&allowed_origins);
        origin_check_middleware(origins, log_request_headers, req, next)
    }));

    let scheme = if config.tls_cert_path.is_some() {
        "https"
    } else {
        "http"
    };

    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
        _ => None,
    };
    let tls_handshake_timeout = config.tls_handshake_timeout;
    let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();

    Ok((
        router,
        AppRunParams {
            tls_paths,
            tls_handshake_timeout,
            max_concurrent_tls_handshakes,
            mtls_config,
            shutdown_timeout: config.shutdown_timeout,
            auth_state,
            rbac_swap,
            on_reload_ready: config.on_reload_ready.take(),
            ct,
            scheme,
            name: config.name.clone(),
        },
    ))
}
1443
1444pub async fn serve<H, F>(
1461 config: Validated<McpServerConfig>,
1462 handler_factory: F,
1463) -> Result<(), McpxError>
1464where
1465 H: ServerHandler + 'static,
1466 F: Fn() -> H + Send + Sync + Clone + 'static,
1467{
1468 let config = config.into_inner();
1469 #[allow(
1470 deprecated,
1471 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1472 )]
1473 let bind_addr = config.bind_addr.clone();
1474 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1475
1476 let listener = TcpListener::bind(&bind_addr)
1477 .await
1478 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1479 log_listening(¶ms.name, params.scheme, &bind_addr);
1480
1481 run_server(
1482 router,
1483 listener,
1484 params.tls_paths,
1485 params.tls_handshake_timeout,
1486 params.max_concurrent_tls_handshakes,
1487 params.mtls_config,
1488 params.shutdown_timeout,
1489 params.auth_state,
1490 params.rbac_swap,
1491 params.on_reload_ready,
1492 params.ct,
1493 )
1494 .await
1495 .map_err(anyhow_to_startup)
1496}
1497
1498pub async fn serve_with_listener<H, F>(
1528 listener: TcpListener,
1529 config: Validated<McpServerConfig>,
1530 handler_factory: F,
1531 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1532 shutdown: Option<CancellationToken>,
1533) -> Result<(), McpxError>
1534where
1535 H: ServerHandler + 'static,
1536 F: Fn() -> H + Send + Sync + Clone + 'static,
1537{
1538 let config = config.into_inner();
1539 let local_addr = listener
1540 .local_addr()
1541 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1542 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1543
1544 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1545
1546 if let Some(external) = shutdown {
1550 let internal = params.ct.clone();
1551 tokio::spawn(async move {
1552 external.cancelled().await;
1553 internal.cancel();
1554 });
1555 }
1556
1557 if let Some(tx) = ready_tx {
1561 let _ = tx.send(local_addr);
1563 }
1564
1565 run_server(
1566 router,
1567 listener,
1568 params.tls_paths,
1569 params.tls_handshake_timeout,
1570 params.max_concurrent_tls_handshakes,
1571 params.mtls_config,
1572 params.shutdown_timeout,
1573 params.auth_state,
1574 params.rbac_swap,
1575 params.on_reload_ready,
1576 params.ct,
1577 )
1578 .await
1579 .map_err(anyhow_to_startup)
1580}
1581
/// Log the listen address and the URLs of the built-in endpoints
/// (`/mcp`, `/healthz`, `/readyz`) at info level.
///
/// `scheme` is used verbatim as the URL scheme and `addr` as the authority.
#[allow(
    clippy::cognitive_complexity,
    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
)]
fn log_listening(name: &str, scheme: &str, addr: &str) {
    tracing::info!("{name} listening on {addr}");
    tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
    tracing::info!(" Health check: {scheme}://{addr}/healthz");
    tracing::info!(" Readiness: {scheme}://{addr}/readyz");
}
1594
/// Run the server on an already-bound `listener` until shutdown completes.
///
/// Shutdown is requested either by a process signal (`shutdown_signal`) or
/// by cancelling `ct`. Once requested, `ct` is cancelled and axum's graceful
/// shutdown starts. If the graceful phase has not finished within
/// `shutdown_timeout`, the serve future is dropped and a warning is logged;
/// that case still returns `Ok(())`.
///
/// When `tls_paths` is `Some((cert, key))`, connections go through a
/// [`TlsListener`]. If mTLS with CRL checking is configured, the initial CRLs
/// are fetched first and a refresher task is spawned with `ct`. Without
/// `tls_paths`, plain TCP is served with `SocketAddr` connect info. In both
/// cases `on_reload_ready`, if present, is handed a [`ReloadHandle`] before
/// `axum::serve` is started.
///
/// # Errors
///
/// Returns an error if loading the CA bundle, the initial CRL fetch,
/// building the TLS listener, or `axum::serve` itself fails.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    tls_handshake_timeout: Duration,
    max_concurrent_tls_handshakes: usize,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // Cancelled exactly when shutdown has been requested from either source.
    let shutdown_trigger = CancellationToken::new();
    // Translate "OS signal OR caller cancelled `ct`" into `shutdown_trigger`.
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    // Future passed to axum's graceful shutdown: once triggered, log and
    // cancel `ct` so tasks holding it (e.g. the CRL refresher) are notified.
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };

    // Upper bound on the graceful phase: resolves `shutdown_timeout` after the
    // trigger fires. Raced against `axum::serve` below, winning drops the
    // server future instead of waiting for in-flight connections.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // mTLS with CRL checking: fetch the initial CRLs before serving and
        // keep them refreshed in the background (the refresher gets `ct`).
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        // Hand out live handles to the hot-swappable auth/RBAC/CRL state.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
            tls_handshake_timeout,
            max_concurrent_tls_handshakes,
        )?;
        // `TlsConnInfo` exposes the peer address plus any mTLS identity.
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        // Plain TCP: no CRL state to expose.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
1734
/// Mount the OAuth 2.1 proxy endpoints on `router` when a proxy is configured.
///
/// Without `oauth_config.proxy` the router is returned unchanged. Otherwise
/// this adds `/.well-known/oauth-authorization-server`, `/authorize`,
/// `/token` and `/register`. It also adds `/introspect` and `/revoke` when
/// admin endpoints are exposed and the matching upstream URL is configured
/// (built by `build_oauth_admin_router`).
///
/// # Errors
///
/// Fails if the OAuth HTTP client cannot be built, or if the admin router
/// needs auth state that was not provided.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            // Clone per request like the other handlers instead of moving the
            // captured config out of the closure.
            let p = proxy_register.clone();
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    // Admin routes exist only when exposed AND at least one upstream
    // (introspection or revocation) is configured.
    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Warn only when unauthenticated admin routes are actually mounted;
    // otherwise the warning would describe endpoints that do not exist.
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1835
/// Build the router holding the optional `/introspect` and `/revoke` proxy
/// endpoints.
///
/// Each route is added only when its upstream URL is configured. All routes
/// get the OAuth cache-header middleware. When
/// `require_auth_on_admin_endpoints` is set, they are also wrapped in
/// `auth_middleware`.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when auth is required but `auth_state` is
/// `None`.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let cfg = proxy.clone();
        let client = http.clone();
        let handler = axum::routing::post(move |body: String| {
            let cfg = cfg.clone();
            let client = client.clone();
            async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
        });
        routes = routes.route("/introspect", handler);
    }

    if proxy.revocation_url.is_some() {
        let cfg = proxy.clone();
        let client = http;
        let handler = axum::routing::post(move |body: String| {
            let cfg = cfg.clone();
            let client = client.clone();
            async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
        });
        routes = routes.route("/revoke", handler);
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.map(Arc::clone).ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&state), req, next)
    })))
}
1894
1895fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1900 let mut hosts = vec![
1901 "localhost".to_owned(),
1902 "127.0.0.1".to_owned(),
1903 "::1".to_owned(),
1904 ];
1905
1906 if let Some(url) = public_url
1907 && let Ok(uri) = url.parse::<axum::http::Uri>()
1908 && let Some(authority) = uri.authority()
1909 {
1910 let host = authority.host().to_owned();
1911 if !hosts.iter().any(|h| h == &host) {
1912 hosts.push(host);
1913 }
1914
1915 let authority = authority.as_str().to_owned();
1916 if !hosts.iter().any(|h| h == &authority) {
1917 hosts.push(authority);
1918 }
1919 }
1920
1921 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1922 && let Some(authority) = uri.authority()
1923 {
1924 let host = authority.host().to_owned();
1925 if !hosts.iter().any(|h| h == &host) {
1926 hosts.push(host);
1927 }
1928
1929 let authority = authority.as_str().to_owned();
1930 if !hosts.iter().any(|h| h == &authority) {
1931 hosts.push(authority);
1932 }
1933 }
1934
1935 hosts
1936}
1937
1938impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1951 for TlsConnInfo
1952{
1953 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1954 let addr = *target.remote_addr();
1955 let identity = target.io().identity().cloned();
1956 TlsConnInfo::new(addr, identity)
1957 }
1958}
1959
/// Default time budget for a single TLS handshake (10 seconds).
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Default cap on TLS handshakes that may be in progress at the same time.
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;

/// Capacity of the channel that carries completed TLS streams from the
/// acceptor task to `TlsListener::accept`.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
1982
/// `axum::serve` listener that yields already-handshaken TLS streams.
///
/// TCP accepts and TLS handshakes run in a background task
/// (`run_tls_acceptor`); completed streams arrive through `rx`.
struct TlsListener {
    /// Address the underlying TCP listener is bound to; reported by
    /// `Listener::local_addr`.
    local_addr: SocketAddr,
    /// Receiving end of the channel the acceptor task feeds with each
    /// handshaken stream and its peer address.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Handle to the acceptor task; aborted when the listener is dropped.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2008
2009impl TlsListener {
2010 fn new(
2011 inner: TcpListener,
2012 cert_path: &Path,
2013 key_path: &Path,
2014 mtls_config: Option<&MtlsConfig>,
2015 crl_set: Option<Arc<CrlSet>>,
2016 handshake_timeout: Duration,
2017 max_concurrent_handshakes: usize,
2018 ) -> anyhow::Result<Self> {
2019 rustls::crypto::ring::default_provider()
2021 .install_default()
2022 .ok();
2023
2024 let certs = load_certs(cert_path)?;
2025 let key = load_key(key_path)?;
2026
2027 let mtls_default_role;
2028
2029 let tls_config = if let Some(mtls) = mtls_config {
2030 mtls_default_role = mtls.default_role.clone();
2031 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2032 {
2033 let Some(crl_set) = crl_set else {
2034 return Err(anyhow::anyhow!(
2035 "mTLS CRL verifier requested but CRL state was not initialized"
2036 ));
2037 };
2038 Arc::new(DynamicClientCertVerifier::new(crl_set))
2039 } else {
2040 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2041 if mtls.required {
2042 rustls::server::WebPkiClientVerifier::builder(root_store)
2043 .build()
2044 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2045 } else {
2046 rustls::server::WebPkiClientVerifier::builder(root_store)
2047 .allow_unauthenticated()
2048 .build()
2049 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2050 }
2051 };
2052
2053 tracing::info!(
2054 ca = %mtls.ca_cert_path.display(),
2055 required = mtls.required,
2056 crl_enabled = mtls.crl_enabled,
2057 "mTLS client auth configured"
2058 );
2059
2060 rustls::ServerConfig::builder_with_protocol_versions(&[
2061 &rustls::version::TLS12,
2062 &rustls::version::TLS13,
2063 ])
2064 .with_client_cert_verifier(verifier)
2065 .with_single_cert(certs, key)?
2066 } else {
2067 mtls_default_role = "viewer".to_owned();
2068 rustls::ServerConfig::builder_with_protocol_versions(&[
2069 &rustls::version::TLS12,
2070 &rustls::version::TLS13,
2071 ])
2072 .with_no_client_auth()
2073 .with_single_cert(certs, key)?
2074 };
2075
2076 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2077 tracing::info!(
2078 "TLS enabled (cert: {}, key: {})",
2079 cert_path.display(),
2080 key_path.display()
2081 );
2082 let local_addr = inner.local_addr()?;
2083 let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2084 let acceptor_task = tokio::spawn(run_tls_acceptor(
2085 inner,
2086 acceptor,
2087 mtls_default_role,
2088 tx,
2089 handshake_timeout,
2090 max_concurrent_handshakes,
2091 ));
2092 Ok(Self {
2093 local_addr,
2094 rx,
2095 acceptor_task,
2096 })
2097 }
2098
2099 fn extract_handshake_identity(
2103 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2104 default_role: &str,
2105 addr: SocketAddr,
2106 ) -> Option<AuthIdentity> {
2107 let (_, server_conn) = tls_stream.get_ref();
2108 let cert_der = server_conn.peer_certificates()?.first()?;
2109 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2110 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2111 Some(id)
2112 }
2113}
2114
2115async fn run_tls_acceptor(
2123 listener: TcpListener,
2124 acceptor: tokio_rustls::TlsAcceptor,
2125 default_role: String,
2126 tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2127 handshake_timeout: Duration,
2128 max_concurrent_handshakes: usize,
2129) {
2130 let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2131 loop {
2132 let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2136 return;
2138 };
2139 let (stream, addr) = match listener.accept().await {
2140 Ok(pair) => pair,
2141 Err(e) => {
2142 tracing::debug!("TCP accept error: {e}");
2143 continue;
2144 }
2145 };
2146 if tx.is_closed() {
2147 return;
2149 }
2150 let acceptor = acceptor.clone();
2151 let default_role = default_role.clone();
2152 let tx = tx.clone();
2153 tokio::spawn(async move {
2154 let _permit = permit;
2155 match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2156 Ok(Ok(tls_stream)) => {
2157 let identity =
2158 TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2159 let wrapped = AuthenticatedTlsStream {
2160 inner: tls_stream,
2161 identity,
2162 };
2163 let _ = tx.send((wrapped, addr)).await;
2166 }
2167 Ok(Err(e)) => {
2168 tracing::debug!("TLS handshake failed from {addr}: {e}");
2169 }
2170 Err(_elapsed) => {
2171 tracing::debug!(
2172 "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2173 );
2174 }
2175 }
2176 });
2177 }
2178}
2179
/// TLS stream handed to `axum::serve`, carrying the mTLS identity extracted
/// during the handshake.
pub(crate) struct AuthenticatedTlsStream {
    /// The underlying rustls server stream; all I/O is delegated to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity derived from the client certificate, or `None` when no
    /// certificate was presented or no identity could be extracted from it.
    identity: Option<AuthIdentity>,
}
2195
2196impl AuthenticatedTlsStream {
2197 #[must_use]
2199 pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
2200 self.identity.as_ref()
2201 }
2202}
2203
2204impl std::fmt::Debug for AuthenticatedTlsStream {
2205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2206 f.debug_struct("AuthenticatedTlsStream")
2207 .field("identity", &self.identity.as_ref().map(|id| &id.name))
2208 .finish_non_exhaustive()
2209 }
2210}
2211
2212impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2213 fn poll_read(
2214 mut self: Pin<&mut Self>,
2215 cx: &mut std::task::Context<'_>,
2216 buf: &mut tokio::io::ReadBuf<'_>,
2217 ) -> std::task::Poll<std::io::Result<()>> {
2218 Pin::new(&mut self.inner).poll_read(cx, buf)
2219 }
2220}
2221
2222impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2223 fn poll_write(
2224 mut self: Pin<&mut Self>,
2225 cx: &mut std::task::Context<'_>,
2226 buf: &[u8],
2227 ) -> std::task::Poll<std::io::Result<usize>> {
2228 Pin::new(&mut self.inner).poll_write(cx, buf)
2229 }
2230
2231 fn poll_flush(
2232 mut self: Pin<&mut Self>,
2233 cx: &mut std::task::Context<'_>,
2234 ) -> std::task::Poll<std::io::Result<()>> {
2235 Pin::new(&mut self.inner).poll_flush(cx)
2236 }
2237
2238 fn poll_shutdown(
2239 mut self: Pin<&mut Self>,
2240 cx: &mut std::task::Context<'_>,
2241 ) -> std::task::Poll<std::io::Result<()>> {
2242 Pin::new(&mut self.inner).poll_shutdown(cx)
2243 }
2244
2245 fn poll_write_vectored(
2246 mut self: Pin<&mut Self>,
2247 cx: &mut std::task::Context<'_>,
2248 bufs: &[std::io::IoSlice<'_>],
2249 ) -> std::task::Poll<std::io::Result<usize>> {
2250 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2251 }
2252
2253 fn is_write_vectored(&self) -> bool {
2254 self.inner.is_write_vectored()
2255 }
2256}
2257
2258impl axum::serve::Listener for TlsListener {
2259 type Io = AuthenticatedTlsStream;
2260 type Addr = SocketAddr;
2261
2262 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2268 if let Some(pair) = self.rx.recv().await {
2269 return pair;
2270 }
2271 tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2277 std::future::pending().await
2278 }
2279
2280 fn local_addr(&self) -> std::io::Result<Self::Addr> {
2281 Ok(self.local_addr)
2282 }
2283}
2284
2285impl Drop for TlsListener {
2286 fn drop(&mut self) {
2287 self.acceptor_task.abort();
2290 }
2291}
2292
2293fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2294 use rustls::pki_types::pem::PemObject;
2295 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2296 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2297 .collect::<Result<_, _>>()
2298 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2299 anyhow::ensure!(
2300 !certs.is_empty(),
2301 "no certificates found in {}",
2302 path.display()
2303 );
2304 Ok(certs)
2305}
2306
2307fn load_client_auth_roots(
2308 path: &Path,
2309) -> anyhow::Result<(
2310 Vec<rustls::pki_types::CertificateDer<'static>>,
2311 Arc<RootCertStore>,
2312)> {
2313 let ca_certs = load_certs(path)?;
2314 let mut root_store = RootCertStore::empty();
2315 for cert in &ca_certs {
2316 root_store
2317 .add(cert.clone())
2318 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2319 }
2320
2321 Ok((ca_certs, Arc::new(root_store)))
2322}
2323
2324fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2325 use rustls::pki_types::pem::PemObject;
2326 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2327 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2328}
2329
2330#[allow(
2331 clippy::unused_async,
2332 reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2333)]
2334async fn healthz() -> impl IntoResponse {
2335 axum::Json(serde_json::json!({
2336 "status": "ok",
2337 }))
2338}
2339
2340fn version_payload(name: &str, version: &str) -> serde_json::Value {
2347 serde_json::json!({
2348 "name": name,
2349 "version": version,
2350 "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2351 "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2352 "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2353 "mcpx_version": env!("CARGO_PKG_VERSION"),
2354 })
2355}
2356
2357fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2367 let value = version_payload(name, version);
2368 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2369}
2370
2371async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2372 let status = check().await;
2373 let ready = status
2374 .get("ready")
2375 .and_then(serde_json::Value::as_bool)
2376 .unwrap_or(false);
2377 let code = if ready {
2378 axum::http::StatusCode::OK
2379 } else {
2380 axum::http::StatusCode::SERVICE_UNAVAILABLE
2381 };
2382 (code, axum::Json(status))
2383}
2384
2385async fn shutdown_signal() {
2389 let ctrl_c = tokio::signal::ctrl_c();
2390
2391 #[cfg(unix)]
2392 {
2393 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2394 Ok(mut term) => {
2395 tokio::select! {
2396 _ = ctrl_c => {}
2397 _ = term.recv() => {}
2398 }
2399 }
2400 Err(e) => {
2401 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2402 ctrl_c.await.ok();
2403 }
2404 }
2405 }
2406
2407 #[cfg(not(unix))]
2408 {
2409 ctrl_c.await.ok();
2410 }
2411}
2412
/// Record request count and latency for every HTTP request.
///
/// Label values are kept to a bounded set so that arbitrary client input
/// cannot create an unlimited number of Prometheus series:
/// * `path` is the matched route template from axum's `MatchedPath`, or
///   `"unmatched"` when no route matched (e.g. a 404).
/// * `method` is kept for standard HTTP methods; any extension method is
///   recorded as `"OTHER"`.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    /// `path` label for requests that did not match any route.
    const UNMATCHED_PATH: &str = "unmatched";
    /// `method` label for non-standard (extension) HTTP methods.
    const OTHER_METHOD: &str = "OTHER";

    let method = match req.method().as_str() {
        known @ ("GET" | "HEAD" | "POST" | "PUT" | "DELETE" | "CONNECT" | "OPTIONS"
        | "TRACE" | "PATCH") => known.to_owned(),
        _ => OTHER_METHOD.to_owned(),
    };
    // Route template rather than the raw URI path, so scanners probing random
    // URLs do not mint new time series.
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or_else(
            || UNMATCHED_PATH.to_owned(),
            |matched| matched.as_str().to_owned(),
        );
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2444
2445async fn security_headers_middleware(
2457 is_tls: bool,
2458 cfg: Arc<SecurityHeadersConfig>,
2459 req: Request<Body>,
2460 next: Next,
2461) -> axum::response::Response {
2462 use axum::http::{HeaderName, header};
2463
2464 let mut resp = next.run(req).await;
2465 let headers = resp.headers_mut();
2466
2467 headers.remove(header::SERVER);
2469 headers.remove(HeaderName::from_static("x-powered-by"));
2470
2471 apply_security_header(
2472 headers,
2473 header::X_CONTENT_TYPE_OPTIONS,
2474 cfg.x_content_type_options.as_deref(),
2475 "nosniff",
2476 );
2477 apply_security_header(
2478 headers,
2479 header::X_FRAME_OPTIONS,
2480 cfg.x_frame_options.as_deref(),
2481 "deny",
2482 );
2483 apply_security_header(
2484 headers,
2485 header::CACHE_CONTROL,
2486 cfg.cache_control.as_deref(),
2487 "no-store, max-age=0",
2488 );
2489 apply_security_header(
2490 headers,
2491 header::REFERRER_POLICY,
2492 cfg.referrer_policy.as_deref(),
2493 "no-referrer",
2494 );
2495 apply_security_header(
2496 headers,
2497 HeaderName::from_static("cross-origin-opener-policy"),
2498 cfg.cross_origin_opener_policy.as_deref(),
2499 "same-origin",
2500 );
2501 apply_security_header(
2502 headers,
2503 HeaderName::from_static("cross-origin-resource-policy"),
2504 cfg.cross_origin_resource_policy.as_deref(),
2505 "same-origin",
2506 );
2507 apply_security_header(
2508 headers,
2509 HeaderName::from_static("cross-origin-embedder-policy"),
2510 cfg.cross_origin_embedder_policy.as_deref(),
2511 "require-corp",
2512 );
2513 apply_security_header(
2514 headers,
2515 HeaderName::from_static("permissions-policy"),
2516 cfg.permissions_policy.as_deref(),
2517 "accelerometer=(), camera=(), geolocation=(), microphone=()",
2518 );
2519 apply_security_header(
2520 headers,
2521 HeaderName::from_static("x-permitted-cross-domain-policies"),
2522 cfg.x_permitted_cross_domain_policies.as_deref(),
2523 "none",
2524 );
2525 apply_security_header(
2526 headers,
2527 HeaderName::from_static("content-security-policy"),
2528 cfg.content_security_policy.as_deref(),
2529 "default-src 'none'; frame-ancestors 'none'",
2530 );
2531 apply_security_header(
2532 headers,
2533 HeaderName::from_static("x-dns-prefetch-control"),
2534 cfg.x_dns_prefetch_control.as_deref(),
2535 "off",
2536 );
2537
2538 if is_tls {
2539 apply_security_header(
2540 headers,
2541 header::STRICT_TRANSPORT_SECURITY,
2542 cfg.strict_transport_security.as_deref(),
2543 "max-age=63072000; includeSubDomains",
2544 );
2545 }
2546
2547 resp
2548}
2549
2550fn apply_security_header(
2561 headers: &mut axum::http::HeaderMap,
2562 name: axum::http::HeaderName,
2563 override_value: Option<&str>,
2564 default: &'static str,
2565) {
2566 use axum::http::HeaderValue;
2567
2568 match override_value {
2569 None => {
2570 headers.insert(name, HeaderValue::from_static(default));
2571 }
2572 Some("") => {
2573 }
2575 Some(v) => match HeaderValue::from_str(v) {
2576 Ok(hv) => {
2577 headers.insert(name, hv);
2578 }
2579 Err(err) => {
2580 tracing::error!(
2581 header = %name,
2582 error = %err,
2583 "invalid security header override reached middleware; using default"
2584 );
2585 headers.insert(name, HeaderValue::from_static(default));
2586 }
2587 },
2588 }
2589}
2590
2591fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2602 use axum::http::HeaderValue;
2603
2604 let fields: &[(&str, Option<&str>)] = &[
2605 (
2606 "x_content_type_options",
2607 cfg.x_content_type_options.as_deref(),
2608 ),
2609 ("x_frame_options", cfg.x_frame_options.as_deref()),
2610 ("cache_control", cfg.cache_control.as_deref()),
2611 ("referrer_policy", cfg.referrer_policy.as_deref()),
2612 (
2613 "cross_origin_opener_policy",
2614 cfg.cross_origin_opener_policy.as_deref(),
2615 ),
2616 (
2617 "cross_origin_resource_policy",
2618 cfg.cross_origin_resource_policy.as_deref(),
2619 ),
2620 (
2621 "cross_origin_embedder_policy",
2622 cfg.cross_origin_embedder_policy.as_deref(),
2623 ),
2624 ("permissions_policy", cfg.permissions_policy.as_deref()),
2625 (
2626 "x_permitted_cross_domain_policies",
2627 cfg.x_permitted_cross_domain_policies.as_deref(),
2628 ),
2629 (
2630 "content_security_policy",
2631 cfg.content_security_policy.as_deref(),
2632 ),
2633 (
2634 "x_dns_prefetch_control",
2635 cfg.x_dns_prefetch_control.as_deref(),
2636 ),
2637 (
2638 "strict_transport_security",
2639 cfg.strict_transport_security.as_deref(),
2640 ),
2641 ];
2642
2643 for (field, value) in fields {
2644 let Some(v) = value else { continue };
2645 if v.is_empty() {
2646 continue;
2647 }
2648 if let Err(err) = HeaderValue::from_str(v) {
2649 return Err(McpxError::Config(format!(
2650 "invalid security_headers.{field}: {err}"
2651 )));
2652 }
2653 }
2654
2655 if let Some(v) = cfg.strict_transport_security.as_deref()
2656 && !v.is_empty()
2657 && v.to_ascii_lowercase().contains("preload")
2658 {
2659 return Err(McpxError::Config(format!(
2660 "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2661 HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2662 )));
2663 }
2664
2665 Ok(())
2666}
2667
/// Add cache-related headers to OAuth proxy responses.
///
/// Sets `Pragma: no-cache` and appends `Vary: Authorization`, so that any
/// shared cache keys responses on the caller's credentials. `Cache-Control`
/// is not touched here. Presumably the security-headers layer provides it;
/// NOTE(review): confirm the layering.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let (mut parts, body) = next.run(req).await.into_parts();
    parts
        .headers
        .insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
    parts
        .headers
        .append(header::VARY, HeaderValue::from_static("Authorization"));
    axum::response::Response::from_parts(parts, body)
}
2695
2696async fn origin_check_middleware(
2700 allowed: Arc<[String]>,
2701 log_request_headers: bool,
2702 req: Request<Body>,
2703 next: Next,
2704) -> axum::response::Response {
2705 let method = req.method().clone();
2706 let path = req.uri().path().to_owned();
2707
2708 log_incoming_request(&method, &path, req.headers(), log_request_headers);
2709
2710 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2711 let origin_str = origin.to_str().unwrap_or("");
2712 if !allowed.iter().any(|a| a == origin_str) {
2713 tracing::warn!(
2714 origin = origin_str,
2715 %method,
2716 %path,
2717 allowed = ?&*allowed,
2718 "rejected request: Origin not allowed"
2719 );
2720 return (
2721 axum::http::StatusCode::FORBIDDEN,
2722 "Forbidden: Origin not allowed",
2723 )
2724 .into_response();
2725 }
2726 }
2727 next.run(req).await
2728}
2729
2730fn log_incoming_request(
2733 method: &axum::http::Method,
2734 path: &str,
2735 headers: &axum::http::HeaderMap,
2736 log_request_headers: bool,
2737) {
2738 if log_request_headers {
2739 tracing::debug!(
2740 %method,
2741 %path,
2742 headers = %format_request_headers_for_log(headers),
2743 "incoming request"
2744 );
2745 } else {
2746 tracing::debug!(%method, %path, "incoming request");
2747 }
2748}
2749
2750fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2751 headers
2752 .iter()
2753 .map(|(k, v)| {
2754 let name = k.as_str();
2755 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2756 format!("{name}: [REDACTED]")
2757 } else {
2758 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2759 }
2760 })
2761 .collect::<Vec<_>>()
2762 .join(", ")
2763}
2764
2765#[allow(
2789 clippy::cognitive_complexity,
2790 reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
2791)]
2792pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2793where
2794 H: ServerHandler + 'static,
2795{
2796 use rmcp::ServiceExt as _;
2797
2798 tracing::info!("stdio transport: serving on stdin/stdout");
2799 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2800
2801 let transport = rmcp::transport::io::stdio();
2802
2803 let service = handler
2804 .serve(transport)
2805 .await
2806 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2807
2808 if let Err(e) = service.waiting().await {
2809 tracing::warn!(error = %e, "stdio session ended with error");
2810 }
2811 tracing::info!("stdio session ended");
2812 Ok(())
2813}
2814
// Unit tests for server configuration validation, the health/readiness
// handlers, origin and security-header middleware, OAuth token-endpoint
// cache headers, the concurrency/compression layer stacks, and TLS
// handshake-timeout reaping.
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        deprecated,
        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
    )]
    use std::{sync::Arc, time::Duration};

    use axum::{
        body::Body,
        http::{Request, StatusCode, header},
        response::IntoResponse,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt as _;

    use super::*;

    // `McpServerConfig::new` must populate every optional field with its
    // documented default.
    #[test]
    fn server_config_new_defaults() {
        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
        assert_eq!(cfg.name, "test-server");
        assert_eq!(cfg.version, "1.0.0");
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert!(cfg.auth.is_none());
        assert!(cfg.rbac.is_none());
        assert!(cfg.allowed_origins.is_empty());
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.readiness_check.is_none());
        assert_eq!(cfg.max_request_body, 1024 * 1024);
        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
        assert!(!cfg.log_request_headers);
        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
    }

    // The TLS handshake builder methods must write through to the config.
    #[test]
    fn tls_handshake_builders_set_fields() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
            .with_tls_handshake_timeout(Duration::from_secs(3))
            .with_max_concurrent_tls_handshakes(64);
        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
        assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
    }

    // A zero handshake timeout is rejected, and the error names the field.
    #[test]
    fn validate_rejects_zero_tls_handshake_timeout() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
            .with_tls_handshake_timeout(Duration::ZERO);
        let err = cfg.validate().expect_err("zero handshake timeout");
        assert!(err.to_string().contains("tls_handshake_timeout"));
    }

    // A zero handshake concurrency cap is rejected, and the error names the field.
    #[test]
    fn validate_rejects_zero_max_concurrent_tls_handshakes() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
            .with_max_concurrent_tls_handshakes(0);
        let err = cfg.validate().expect_err("zero handshake concurrency");
        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
    }

    // `validate` consumes the config and returns a wrapper that exposes it
    // both by reference (`as_inner`) and by value (`into_inner`); an invalid
    // body cap must fail.
    #[test]
    fn validate_consumes_and_proves() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        let validated = cfg.validate().expect("valid config");
        assert_eq!(validated.as_inner().name, "test-server");
        let raw = validated.into_inner();
        assert_eq!(raw.name, "test-server");

        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        bad.max_request_body = 0;
        assert!(bad.validate().is_err(), "zero body cap must fail validate");
    }

    // A zero request concurrency cap is rejected, and the error names the field.
    #[test]
    fn validate_rejects_zero_max_concurrent_requests() {
        let cfg =
            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
        let err = cfg.validate().expect_err("zero concurrency cap must fail");
        assert!(
            format!("{err}").contains("max_concurrent_requests"),
            "error should mention max_concurrent_requests, got: {err}"
        );
    }

    // An auth rate limiter that tracks zero keys is rejected during validation.
    #[test]
    fn validate_rejects_zero_max_tracked_keys() {
        let rl = crate::auth::RateLimitConfig {
            max_attempts_per_minute: 30,
            pre_auth_max_per_minute: None,
            max_tracked_keys: 0,
            idle_eviction: Duration::from_secs(15 * 60),
        };
        let auth_cfg = AuthConfig {
            enabled: true,
            api_keys: Vec::new(),
            mtls: None,
            rate_limit: Some(rl),
            #[cfg(feature = "oauth")]
            oauth: None,
        };
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
        assert!(
            format!("{err}").contains("max_tracked_keys"),
            "error should mention max_tracked_keys, got: {err}"
        );
    }

    // The host from the public URL is added to the allowed-host list.
    #[test]
    fn derive_allowed_hosts_includes_public_host() {
        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
        assert!(
            hosts.iter().any(|h| h == "mcp.example.com"),
            "public_url host must be allowed"
        );
    }

    // Both the bare bind host and the `host:port` authority are allowed.
    #[test]
    fn derive_allowed_hosts_includes_bind_authority() {
        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1"),
            "bind host must be allowed"
        );
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1:8080"),
            "bind authority must be allowed"
        );
    }

    // `/healthz` reports ok without leaking the server name or version.
    #[tokio::test]
    async fn healthz_returns_ok_json() {
        let resp = healthz().await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
        assert!(
            json.get("name").is_none(),
            "healthz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "healthz must not expose version"
        );
    }

    // `/readyz` returns 200 with the check's payload when `ready` is true,
    // and does not inject the server name or version.
    #[tokio::test]
    async fn readyz_returns_ok_when_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["ready"], true);
        assert!(
            json.get("name").is_none(),
            "readyz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "readyz must not expose version"
        );
        assert_eq!(json["db"], "connected");
    }

    // An explicit `ready: false` maps to 503.
    #[tokio::test]
    async fn readyz_returns_503_when_not_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    // A payload without a `ready` key is treated as not ready (503).
    #[tokio::test]
    async fn readyz_returns_503_when_ready_missing() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    // Builds a one-route router (`GET /test`) wrapped in
    // `origin_check_middleware` with the given allow-list.
    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
        let allowed: Arc<[String]> = Arc::from(origins);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let a = Arc::clone(&allowed);
                origin_check_middleware(a, log_request_headers, req, next)
            }))
    }

    #[tokio::test]
    async fn origin_allowed_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://localhost:3000")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_rejected_returns_403() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // Requests without an `Origin` header (non-browser clients) are let through.
    #[tokio::test]
    async fn no_origin_header_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // An empty allow-list rejects any request that does carry an `Origin`.
    #[tokio::test]
    async fn empty_allowlist_rejects_any_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://anything.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn empty_allowlist_passes_without_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // Credential-bearing headers are redacted in the log rendering, while
    // non-sensitive headers keep their values.
    #[test]
    fn format_request_headers_redacts_sensitive_values() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
        headers.insert("cookie", "sid=abc".parse().unwrap());
        headers.insert("x-request-id", "req-123".parse().unwrap());

        let out = format_request_headers_for_log(&headers);
        assert!(out.contains("authorization: [REDACTED]"));
        assert!(out.contains("cookie: [REDACTED]"));
        assert!(out.contains("x-request-id: req-123"));
        assert!(!out.contains("secret-token"));
        // The cookie value must be scrubbed just like the bearer token.
        assert!(!out.contains("sid=abc"));
    }

    // Router wrapped in `security_headers_middleware` using default header values.
    fn security_router(is_tls: bool) -> axum::Router {
        security_router_with(is_tls, SecurityHeadersConfig::default())
    }

    // Router wrapped in `security_headers_middleware` using the given overrides.
    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
        let cfg = Arc::new(cfg);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let c = Arc::clone(&cfg);
                security_headers_middleware(is_tls, c, req, next)
            }))
    }

    // The default header set is emitted on plaintext responses; HSTS is absent.
    #[tokio::test]
    async fn security_headers_set_on_response() {
        let app = security_router(false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let h = resp.headers();
        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
        assert_eq!(
            h.get("cross-origin-resource-policy").unwrap(),
            "same-origin"
        );
        assert_eq!(
            h.get("cross-origin-embedder-policy").unwrap(),
            "require-corp"
        );
        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
        assert!(
            h.get("permissions-policy")
                .unwrap()
                .to_str()
                .unwrap()
                .contains("camera=()"),
            "permissions-policy must restrict browser features"
        );
        assert_eq!(
            h.get("content-security-policy").unwrap(),
            "default-src 'none'; frame-ancestors 'none'"
        );
        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
        assert!(h.get("strict-transport-security").is_none());
    }

    // On TLS deployments the default HSTS header carries a 2-year max-age.
    #[tokio::test]
    async fn hsts_set_when_tls_enabled() {
        let app = security_router(true);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let hsts = resp.headers().get("strict-transport-security").unwrap();
        assert!(
            hsts.to_str().unwrap().contains("max-age=63072000"),
            "HSTS must set 2-year max-age"
        );
    }

    // Runs the non-consuming config check with the given header overrides.
    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
        let cfg =
            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
        cfg.check()
    }

    #[test]
    fn security_headers_config_default_validates() {
        check_with_security_headers(SecurityHeadersConfig::default())
            .expect("default SecurityHeadersConfig must validate");
    }

    // `Some("")` means "omit this header", so it must pass validation on every field.
    #[test]
    fn security_headers_config_validate_accepts_empty_string() {
        let h = SecurityHeadersConfig {
            x_content_type_options: Some(String::new()),
            x_frame_options: Some(String::new()),
            cache_control: Some(String::new()),
            referrer_policy: Some(String::new()),
            cross_origin_opener_policy: Some(String::new()),
            cross_origin_resource_policy: Some(String::new()),
            cross_origin_embedder_policy: Some(String::new()),
            permissions_policy: Some(String::new()),
            x_permitted_cross_domain_policies: Some(String::new()),
            content_security_policy: Some(String::new()),
            x_dns_prefetch_control: Some(String::new()),
            strict_transport_security: Some(String::new()),
        };
        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
    }

    // A value that is not a legal header value (here a BEL control char) is
    // rejected, and the error names the offending field.
    #[test]
    fn security_headers_config_validate_rejects_bad_value() {
        let h = SecurityHeadersConfig {
            referrer_policy: Some("\u{0007}".into()),
            ..SecurityHeadersConfig::default()
        };
        let err = check_with_security_headers(h)
            .expect_err("control char in referrer_policy must reject");
        let msg = err.to_string();
        assert!(
            msg.contains("referrer_policy"),
            "error must name the offending field, got: {msg}"
        );
    }

    // An HSTS override containing `preload` is refused.
    #[test]
    fn security_headers_config_validate_rejects_hsts_preload() {
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
            ..SecurityHeadersConfig::default()
        };
        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
        let msg = err.to_string();
        assert!(
            msg.contains("strict_transport_security"),
            "error must name the field, got: {msg}"
        );
        assert!(
            msg.to_lowercase().contains("preload"),
            "error must mention `preload`, got: {msg}"
        );
    }

    #[test]
    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=600; PRELOAD".into()),
            ..SecurityHeadersConfig::default()
        };
        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
    }

    // A non-empty override replaces the default header value.
    #[tokio::test]
    async fn security_headers_override_honored() {
        let h = SecurityHeadersConfig {
            x_frame_options: Some("SAMEORIGIN".into()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let xfo = resp.headers().get("x-frame-options").unwrap();
        assert_eq!(xfo, "SAMEORIGIN");
    }

    // An empty-string override drops only that header; the others keep their defaults.
    #[tokio::test]
    async fn security_headers_empty_string_omits() {
        let h = SecurityHeadersConfig {
            referrer_policy: Some(String::new()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        assert!(
            resp.headers().get("referrer-policy").is_none(),
            "Some(\"\") must omit the header"
        );
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    // An explicit HSTS override must still not be sent over plaintext.
    #[tokio::test]
    async fn security_headers_hsts_only_when_tls() {
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=600".into()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert!(
            resp.headers().get("strict-transport-security").is_none(),
            "HSTS must remain absent on plaintext deployments even with override"
        );
    }

    // The token-endpoint middleware adds `Pragma: no-cache` and
    // `Vary: Authorization`.
    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn oauth_token_cache_headers_set_pragma_and_vary() {
        let app = axum::Router::new()
            .route("/token", axum::routing::post(|| async { "{}" }))
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            ));
        let req = Request::builder()
            .method("POST")
            .uri("/token")
            .body(Body::from("{}"))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let h = resp.headers();
        assert_eq!(
            h.get("pragma").unwrap(),
            "no-cache",
            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
        );
        let vary_values: Vec<String> = h
            .get_all("vary")
            .iter()
            .filter_map(|v| v.to_str().ok().map(str::to_owned))
            .collect();
        assert!(
            vary_values
                .iter()
                .any(|v| v.eq_ignore_ascii_case("Authorization")),
            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
        );
    }

    // A `Vary` value set by the handler is kept, and `Authorization` is added.
    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn oauth_token_cache_headers_preserve_existing_vary() {
        let app = axum::Router::new()
            .route(
                "/token",
                axum::routing::post(|| async {
                    axum::response::Response::builder()
                        .header("vary", "Accept-Encoding")
                        .body(axum::body::Body::from("{}"))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            ));
        let req = Request::builder()
            .method("POST")
            .uri("/token")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        // Vary lists header field names, which HTTP treats as
        // case-insensitive, so compare lowercased. The names may arrive as
        // separate header lines or merged into one comma-separated value.
        let vary: Vec<String> = resp
            .headers()
            .get_all("vary")
            .iter()
            .filter_map(|v| v.to_str().ok().map(str::to_ascii_lowercase))
            .collect();
        assert!(
            vary.iter().any(|v| v.contains("accept-encoding")),
            "must preserve pre-existing Vary value, got {vary:?}"
        );
        assert!(
            vary.iter().any(|v| v.contains("authorization")),
            "must append Authorization to Vary, got {vary:?}"
        );
    }

    // The version payload echoes name/version and carries build metadata strings.
    #[test]
    fn version_payload_contains_expected_fields() {
        let v = version_payload("my-server", "1.2.3");
        assert_eq!(v["name"], "my-server");
        assert_eq!(v["version"], "1.2.3");
        assert!(v["build_git_sha"].is_string());
        assert!(v["build_timestamp"].is_string());
        assert!(v["rust_version"].is_string());
        assert!(v["mcpx_version"].is_string());
    }

    // The load-shed + concurrency-limit stack type-checks with an axum router
    // and serves a request under the cap.
    #[tokio::test]
    async fn concurrency_limit_layer_composes_and_serves() {
        let app = axum::Router::new()
            .route("/ok", axum::routing::get(|| async { "ok" }))
            .layer(
                tower::ServiceBuilder::new()
                    .layer(axum::error_handling::HandleErrorLayer::new(
                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
                    ))
                    .layer(tower::load_shed::LoadShedLayer::new())
                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
            );
        let resp = app
            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // Bodies above the 1 KiB threshold are gzip-encoded when the client accepts gzip.
    #[tokio::test]
    async fn compression_layer_gzip_encodes_response() {
        use tower_http::compression::Predicate as _;

        let big_body = "a".repeat(4096);
        let app = axum::Router::new()
            .route(
                "/big",
                axum::routing::get(move || {
                    let body = big_body.clone();
                    async move { body }
                }),
            )
            .layer(
                tower_http::compression::CompressionLayer::new()
                    .gzip(true)
                    .br(true)
                    .compress_when(
                        tower_http::compression::DefaultPredicate::new()
                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
                    ),
            );

        let req = Request::builder()
            .uri("/big")
            .header(header::ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }

    // RAII guard that removes a temporary fixture directory on drop, so test
    // artifacts are cleaned up even when the test panics part-way through.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best-effort: a cleanup failure must not mask the test's own outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    // A client that opens a TCP connection but never completes the TLS
    // handshake is disconnected by the server once the handshake timeout
    // elapses.
    #[tokio::test]
    async fn tls_handshake_timeout_reaps_idle_connections() {
        use tokio::io::AsyncReadExt as _;

        let _ = rustls::crypto::ring::default_provider().install_default();

        let key = rcgen::KeyPair::generate().expect("generate key");
        let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
            .expect("cert params")
            .self_signed(&key)
            .expect("self-signed cert");
        let dir = std::env::temp_dir().join(format!(
            "rmcp-server-kit-hs-timeout-{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("clock after epoch")
                .as_nanos()
        ));
        tokio::fs::create_dir_all(&dir).await.expect("temp dir");
        // Removes the directory (cert + key) when the test ends, pass or fail.
        let _dir_guard = TempDirGuard(dir.clone());
        let cert_path = dir.join("server.crt");
        let key_path = dir.join("server.key");
        tokio::fs::write(&cert_path, cert.pem())
            .await
            .expect("write cert");
        tokio::fs::write(&key_path, key.serialize_pem())
            .await
            .expect("write key");

        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let tls = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            None,
            None,
            // Short handshake timeout so the reap happens well inside the
            // 2 s read deadline below.
            Duration::from_millis(200),
            // NOTE(review): presumably the handshake concurrency cap; confirm
            // against `TlsListener::new`.
            8,
        )
        .expect("tls listener");
        let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");

        // Connect but never send a ClientHello: the server-side handshake
        // stalls until its timeout fires.
        let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
        let mut buf = [0_u8; 16];
        let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
            .await
            .expect("server must reap the idle handshake within its timeout");
        match read {
            // EOF or a connection error both mean the server closed the socket.
            Ok(0) | Err(_) => {}
            Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
        }

        drop(tls);
    }
}