1use std::{
2 future::Future,
3 net::SocketAddr,
4 path::{Path, PathBuf},
5 pin::Pin,
6 sync::Arc,
7 time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13 ServerHandler,
14 transport::streamable_http_server::{
15 StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16 },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23 auth::{
24 AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25 build_rate_limiter, extract_mtls_identity,
26 },
27 error::McpxError,
28 mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29 rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32#[allow(
36 clippy::needless_pass_by_value,
37 reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40 McpxError::Startup(format!("{e:#}"))
41}
42
43#[allow(
49 clippy::needless_pass_by_value,
50 reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53 McpxError::Startup(format!("{op}: {e}"))
54}
55
56pub type ReadinessCheck =
61 Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
/// Configuration for [`serve`] / [`serve_with_listener`].
///
/// Build it with [`McpServerConfig::new`] plus the `with_*` / `enable_*`
/// builder methods, then call [`McpServerConfig::validate`] to obtain the
/// [`Validated`] wrapper that the serve functions require. The `pub` fields
/// are deprecated since 0.13.0 in favour of those builders.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Address the server listens on. `validate` requires it to parse as a
    /// `SocketAddr` (e.g. `127.0.0.1:8080`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in 1.0"
    )]
    pub bind_addr: String,
    /// Server name; used for `/version`, the admin state and startup logs.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
    )]
    pub name: String,
    /// Server version; used for `/version` and the admin state.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
    )]
    pub version: String,
    /// TLS certificate file. Must be set together with `tls_key_path`; when
    /// both are set the server is served over HTTPS.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// TLS private-key file. Must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Authentication settings. Only take effect when `AuthConfig::enabled`
    /// is true.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in 1.0"
    )]
    pub auth: Option<AuthConfig>,
    /// RBAC policy for `/mcp`; `None` means a disabled policy is installed.
    /// Can be replaced at runtime through `ReloadHandle::reload_rbac`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in 1.0"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Browser origins accepted by the origin-check and CORS layers. Each
    /// entry must start with `http://` or `https://`. When empty and
    /// `public_url` is set, one origin is derived from `public_url`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in 1.0"
    )]
    pub allowed_origins: Vec<String>,
    /// Tool-call rate limit, logged as "calls/min per IP"; `None` disables it.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in 1.0"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Custom probe for `/readyz`; when `None`, `/readyz` uses the `/healthz`
    /// handler.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in 1.0"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Request-body size limit on the MCP routes, in bytes (default 1 MiB).
    /// Must be non-zero.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in 1.0"
    )]
    pub max_request_body: usize,
    /// Per-request timeout on the MCP routes; timed-out requests get
    /// `408 Request Timeout` (default 2 min).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub request_timeout: Duration,
    /// Grace period after shutdown is triggered before the server stops
    /// waiting for in-flight work (default 30 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub shutdown_timeout: Duration,
    /// Passed as the session manager's `keep_alive` (default 20 min).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub session_idle_timeout: Duration,
    /// SSE keep-alive interval for the Streamable HTTP service (default 15 s).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in 1.0"
    )]
    pub sse_keep_alive: Duration,
    /// Callback invoked at most once with a `ReloadHandle` before the server
    /// starts serving.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in 1.0"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Extra routes merged into the top-level router.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in 1.0"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally visible base URL (must start with `http://` or `https://`).
    /// Used for allowed hosts, OAuth metadata URLs and origin derivation.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in 1.0"
    )]
    pub public_url: Option<String>,
    /// Flag handed to the origin-check middleware to turn on request-header
    /// logging.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in 1.0"
    )]
    pub log_request_headers: bool,
    /// Enables gzip/brotli response compression.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
    )]
    pub compression_enabled: bool,
    /// Responses are only compressed above this size in bytes (default 1024).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
    )]
    pub compression_min_size: u16,
    /// Global cap on concurrent requests; excess requests get `503`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in 1.0"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Mounts the `/admin/*` endpoints; requires auth to be enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
    )]
    pub admin_enabled: bool,
    /// Role handed to the admin router's config (default `"admin"`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
    )]
    pub admin_role: String,
    /// Enables the metrics middleware and separate metrics listener.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
    )]
    pub metrics_enabled: bool,
    /// Bind address of the metrics listener (default `127.0.0.1:9090`).
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
    )]
    pub metrics_bind: String,
}
260
/// A value that has passed validation.
///
/// The tuple field is private, so a `Validated<T>` can only be built inside
/// this module (e.g. by [`McpServerConfig::validate`]); holding one proves the
/// check ran. The wrapper is read-only: it derefs to `&T` and can be unwrapped
/// with [`Validated::into_inner`].
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Prints `Validated { .. }` without the wrapped value, so it works for
    /// every `T` and never exposes the inner contents.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Validated");
        builder.finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrows the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and hands back the validated value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

impl<T> std::ops::Deref for Validated<T> {
    type Target = T;

    fn deref(&self) -> &T {
        self.as_inner()
    }
}
354
355#[allow(
356 deprecated,
357 reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
358)]
359impl McpServerConfig {
360 #[must_use]
368 pub fn new(
369 bind_addr: impl Into<String>,
370 name: impl Into<String>,
371 version: impl Into<String>,
372 ) -> Self {
373 Self {
374 bind_addr: bind_addr.into(),
375 name: name.into(),
376 version: version.into(),
377 tls_cert_path: None,
378 tls_key_path: None,
379 auth: None,
380 rbac: None,
381 allowed_origins: Vec::new(),
382 tool_rate_limit: None,
383 readiness_check: None,
384 max_request_body: 1024 * 1024,
385 request_timeout: Duration::from_mins(2),
386 shutdown_timeout: Duration::from_secs(30),
387 session_idle_timeout: Duration::from_mins(20),
388 sse_keep_alive: Duration::from_secs(15),
389 on_reload_ready: None,
390 extra_router: None,
391 public_url: None,
392 log_request_headers: false,
393 compression_enabled: false,
394 compression_min_size: 1024,
395 max_concurrent_requests: None,
396 admin_enabled: false,
397 admin_role: "admin".to_owned(),
398 #[cfg(feature = "metrics")]
399 metrics_enabled: false,
400 #[cfg(feature = "metrics")]
401 metrics_bind: "127.0.0.1:9090".into(),
402 }
403 }
404
405 #[must_use]
415 pub fn with_auth(mut self, auth: AuthConfig) -> Self {
416 self.auth = Some(auth);
417 self
418 }
419
420 #[must_use]
424 pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
425 self.bind_addr = addr.into();
426 self
427 }
428
429 #[must_use]
432 pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
433 self.rbac = Some(rbac);
434 self
435 }
436
437 #[must_use]
441 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
442 self.tls_cert_path = Some(cert_path.into());
443 self.tls_key_path = Some(key_path.into());
444 self
445 }
446
447 #[must_use]
451 pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
452 self.public_url = Some(url.into());
453 self
454 }
455
456 #[must_use]
460 pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
461 where
462 I: IntoIterator<Item = S>,
463 S: Into<String>,
464 {
465 self.allowed_origins = origins.into_iter().map(Into::into).collect();
466 self
467 }
468
469 #[must_use]
473 pub fn with_extra_router(mut self, router: axum::Router) -> Self {
474 self.extra_router = Some(router);
475 self
476 }
477
478 #[must_use]
481 pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
482 self.readiness_check = Some(check);
483 self
484 }
485
486 #[must_use]
489 pub fn with_max_request_body(mut self, bytes: usize) -> Self {
490 self.max_request_body = bytes;
491 self
492 }
493
494 #[must_use]
496 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
497 self.request_timeout = timeout;
498 self
499 }
500
501 #[must_use]
503 pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
504 self.shutdown_timeout = timeout;
505 self
506 }
507
508 #[must_use]
510 pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
511 self.session_idle_timeout = timeout;
512 self
513 }
514
515 #[must_use]
517 pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
518 self.sse_keep_alive = interval;
519 self
520 }
521
522 #[must_use]
526 pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
527 self.max_concurrent_requests = Some(limit);
528 self
529 }
530
531 #[must_use]
534 pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
535 self.tool_rate_limit = Some(per_minute);
536 self
537 }
538
539 #[must_use]
543 pub fn with_reload_callback<F>(mut self, callback: F) -> Self
544 where
545 F: FnOnce(ReloadHandle) + Send + 'static,
546 {
547 self.on_reload_ready = Some(Box::new(callback));
548 self
549 }
550
551 #[must_use]
555 pub fn enable_compression(mut self, min_size: u16) -> Self {
556 self.compression_enabled = true;
557 self.compression_min_size = min_size;
558 self
559 }
560
561 #[must_use]
566 pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
567 self.admin_enabled = true;
568 self.admin_role = role.into();
569 self
570 }
571
572 #[must_use]
575 pub fn enable_request_header_logging(mut self) -> Self {
576 self.log_request_headers = true;
577 self
578 }
579
580 #[cfg(feature = "metrics")]
583 #[must_use]
584 pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
585 self.metrics_enabled = true;
586 self.metrics_bind = bind.into();
587 self
588 }
589
590 pub fn validate(self) -> Result<Validated<Self>, McpxError> {
623 self.check()?;
624 Ok(Validated(self))
625 }
626
627 fn check(&self) -> Result<(), McpxError> {
631 if self.admin_enabled {
635 let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
636 if !auth_enabled {
637 return Err(McpxError::Config(
638 "admin_enabled=true requires auth to be configured and enabled".into(),
639 ));
640 }
641 }
642
643 match (&self.tls_cert_path, &self.tls_key_path) {
645 (Some(_), None) => {
646 return Err(McpxError::Config(
647 "tls_cert_path is set but tls_key_path is missing".into(),
648 ));
649 }
650 (None, Some(_)) => {
651 return Err(McpxError::Config(
652 "tls_key_path is set but tls_cert_path is missing".into(),
653 ));
654 }
655 _ => {}
656 }
657
658 if self.bind_addr.parse::<SocketAddr>().is_err() {
660 return Err(McpxError::Config(format!(
661 "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
662 self.bind_addr
663 )));
664 }
665
666 if let Some(ref url) = self.public_url
668 && !(url.starts_with("http://") || url.starts_with("https://"))
669 {
670 return Err(McpxError::Config(format!(
671 "public_url {url:?} must start with http:// or https://"
672 )));
673 }
674
675 for origin in &self.allowed_origins {
677 if !(origin.starts_with("http://") || origin.starts_with("https://")) {
678 return Err(McpxError::Config(format!(
679 "allowed_origins entry {origin:?} must start with http:// or https://"
680 )));
681 }
682 }
683
684 if self.max_request_body == 0 {
686 return Err(McpxError::Config(
687 "max_request_body must be greater than zero".into(),
688 ));
689 }
690
691 #[cfg(feature = "oauth")]
693 if let Some(auth_cfg) = &self.auth
694 && let Some(oauth_cfg) = &auth_cfg.oauth
695 {
696 oauth_cfg.validate()?;
697 }
698
699 Ok(())
700 }
701}
702
703#[allow(
709 missing_debug_implementations,
710 reason = "contains Arc<AuthState> with non-Debug fields"
711)]
712pub struct ReloadHandle {
713 auth: Option<Arc<AuthState>>,
714 rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
715 crl_set: Option<Arc<CrlSet>>,
716}
717
718impl ReloadHandle {
719 pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
721 if let Some(ref auth) = self.auth {
722 auth.reload_keys(keys);
723 }
724 }
725
726 pub fn reload_rbac(&self, policy: RbacPolicy) {
728 if let Some(ref rbac) = self.rbac {
729 rbac.store(Arc::new(policy));
730 tracing::info!("RBAC policy reloaded");
731 }
732 }
733
734 pub async fn refresh_crls(&self) -> Result<(), McpxError> {
740 let Some(ref crl_set) = self.crl_set else {
741 return Err(McpxError::Config(
742 "CRL refresh requested but mTLS CRL support is not configured".into(),
743 ));
744 };
745
746 crl_set.force_refresh().await
747 }
748}
749
750#[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
767struct AppRunParams {
771 tls_paths: Option<(PathBuf, PathBuf)>,
773 mtls_config: Option<MtlsConfig>,
775 shutdown_timeout: Duration,
777 auth_state: Option<Arc<AuthState>>,
779 rbac_swap: Arc<ArcSwap<RbacPolicy>>,
781 on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
783 ct: CancellationToken,
787 scheme: &'static str,
789 name: String,
791}
792
/// Assembles the complete axum router and collects the parameters that
/// `run_server` needs.
///
/// Consumes `config`; the readiness check, extra router and reload callback
/// are moved out of it with `Option::take`.
///
/// Middleware is attached through successive `Router::layer` calls, so the
/// layer added last is the outermost one and sees a request first. The order
/// of the statements below therefore defines the request pipeline: origin
/// check → (metrics) → concurrency limit → compression → CORS → security
/// headers → routes; on `/mcp` additionally body limit → timeout → auth →
/// RBAC → MCP service.
///
/// # Errors
///
/// Fails when admin endpoints are enabled without an auth state, when the
/// auth state is unexpectedly missing, or (feature-dependent) when the JWKS
/// HTTP client, the OAuth proxy routes or the metrics registry fail to
/// initialise.
#[allow(
    clippy::cognitive_complexity,
    reason = "router assembly is intrinsically sequential; splitting harms readability"
)]
#[allow(
    deprecated,
    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
)]
fn build_app_router<H, F>(
    mut config: McpServerConfig,
    handler_factory: F,
) -> anyhow::Result<(axum::Router, AppRunParams)>
where
    H: ServerHandler + 'static,
    F: Fn() -> H + Send + Sync + Clone + 'static,
{
    // Root token for the whole server; the MCP service receives a child token
    // so it observes the shutdown cancellation issued by `run_server`.
    let ct = CancellationToken::new();

    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");

    // Streamable HTTP MCP service: one handler instance per `handler_factory`
    // call, sessions managed in-process by `LocalSessionManager`.
    let mcp_service = StreamableHttpService::new(
        move || Ok(handler_factory()),
        {
            let mut mgr = LocalSessionManager::default();
            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
            mgr.into()
        },
        StreamableHttpServerConfig::default()
            .with_allowed_hosts(allowed_hosts)
            .with_sse_keep_alive(Some(config.sse_keep_alive))
            .with_cancellation_token(ct.child_token()),
    );

    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);

    // Shared auth state exists only when auth is configured AND enabled.
    let auth_state: Option<Arc<AuthState>> = match config.auth {
        Some(ref auth_config) if auth_config.enabled => {
            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
            let pre_auth_limiter = auth_config
                .rate_limit
                .as_ref()
                .map(crate::auth::build_pre_auth_limiter);

            // JWKS client construction errors are wrapped as io::Error and
            // then converted into anyhow by `?`.
            #[cfg(feature = "oauth")]
            let jwks_cache = auth_config
                .oauth
                .as_ref()
                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
                .transpose()
                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;

            Some(Arc::new(AuthState {
                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
                rate_limiter,
                pre_auth_limiter,
                #[cfg(feature = "oauth")]
                jwks_cache,
                seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
                counters: crate::auth::AuthCounters::default(),
            }))
        }
        _ => None,
    };

    // Hot-swappable policy slot; without a configured policy a disabled one
    // is installed so the RBAC middleware always has something to load.
    let rbac_swap = Arc::new(ArcSwap::new(
        config
            .rbac
            .clone()
            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
    ));

    if config.admin_enabled {
        // Mirrors the `check()` rule; kept as a guard because this function
        // only sees the auth state it just built.
        let Some(ref auth_state_ref) = auth_state else {
            return Err(anyhow::anyhow!(
                "admin_enabled=true requires auth to be configured and enabled"
            ));
        };
        let admin_state = crate::admin::AdminState {
            started_at: std::time::Instant::now(),
            name: config.name.clone(),
            version: config.version.clone(),
            auth: Some(Arc::clone(auth_state_ref)),
            rbac: Arc::clone(&rbac_swap),
        };
        let admin_cfg = crate::admin::AdminConfig {
            role: config.admin_role.clone(),
        };
        // Merged before the RBAC/auth layers below, so those layers wrap the
        // admin routes as well.
        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
    }

    // RBAC + tool rate limiting (innermost layer on the MCP routes).
    {
        let tool_limiter: Option<Arc<ToolRateLimiter>> =
            config.tool_rate_limit.map(build_tool_rate_limiter);

        if rbac_swap.load().is_enabled() {
            tracing::info!("RBAC enforcement enabled on /mcp");
        }
        if let Some(limit) = config.tool_rate_limit {
            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
        }

        let rbac_for_mw = Arc::clone(&rbac_swap);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            // Loaded per request so `ReloadHandle::reload_rbac` takes effect
            // without rebuilding the router.
            let p = rbac_for_mw.load_full();
            let tl = tool_limiter.clone();
            rbac_middleware(p, tl, req, next)
        }));
    }

    // Auth is layered after RBAC, i.e. it wraps RBAC and runs before it.
    if let Some(ref auth_config) = config.auth
        && auth_config.enabled
    {
        let Some(ref state) = auth_state else {
            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
        };

        let methods: Vec<&str> = [
            auth_config.mtls.is_some().then_some("mTLS"),
            (!auth_config.api_keys.is_empty()).then_some("bearer"),
            #[cfg(feature = "oauth")]
            auth_config.oauth.is_some().then_some("oauth-jwt"),
        ]
        .into_iter()
        .flatten()
        .collect();

        tracing::info!(
            methods = %methods.join(", "),
            api_keys = auth_config.api_keys.len(),
            "auth enabled on /mcp"
        );

        let state_for_mw = Arc::clone(state);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            let s = Arc::clone(&state_for_mw);
            auth_middleware(s, req, next)
        }));
    }

    // Requests exceeding the timeout are answered with 408.
    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
        axum::http::StatusCode::REQUEST_TIMEOUT,
        config.request_timeout,
    ));

    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
        config.max_request_body,
    ));

    // Without explicit origins, derive one (scheme + authority) from
    // public_url by cutting at the first '/' after "://".
    // NOTE(review): a public_url with a query/fragment but no path (e.g.
    // "https://h?x") keeps the "?x" in the derived origin.
    let mut effective_origins = config.allowed_origins.clone();
    if effective_origins.is_empty()
        && let Some(ref url) = config.public_url
    {
        if let Some(scheme_end) = url.find("://") {
            let after_scheme = &url[scheme_end + 3..];
            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
            let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
            tracing::info!(
                %origin,
                "auto-derived allowed origin from public_url"
            );
            effective_origins.push(origin);
        }
    }
    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
    let cors_origins = Arc::clone(&allowed_origins);
    let log_request_headers = config.log_request_headers;

    // `/readyz` falls back to the liveness handler when no probe is given.
    let readyz_route = if let Some(check) = config.readiness_check.take() {
        axum::routing::get(move || readyz(Arc::clone(&check)))
    } else {
        axum::routing::get(healthz)
    };

    // NOTE(review): `router` is reassigned unconditionally further down
    // (`router = router.route(...)`), so this `unused_mut` allow looks stale.
    #[allow(unused_mut)]
    let mut router = axum::Router::new()
        .route("/healthz", axum::routing::get(healthz))
        .route("/readyz", readyz_route)
        .route(
            "/version",
            axum::routing::get({
                // Serialized once at startup; each request only clones the bytes.
                let payload_bytes: Arc<[u8]> =
                    serialize_version_payload(&config.name, &config.version);
                move || {
                    let p = Arc::clone(&payload_bytes);
                    async move {
                        (
                            [(axum::http::header::CONTENT_TYPE, "application/json")],
                            p.to_vec(),
                        )
                    }
                }
            }),
        )
        .merge(mcp_router);

    if let Some(extra) = config.extra_router.take() {
        router = router.merge(extra);
    }

    // Base URL for OAuth protected-resource metadata: public_url without a
    // trailing '/', or scheme://bind_addr when no public URL is configured.
    let server_url = if let Some(ref url) = config.public_url {
        url.trim_end_matches('/').to_owned()
    } else {
        let prm_scheme = if config.tls_cert_path.is_some() {
            "https"
        } else {
            "http"
        };
        format!("{prm_scheme}://{}", config.bind_addr)
    };
    let resource_url = format!("{server_url}/mcp");

    #[cfg(feature = "oauth")]
    let prm_metadata = if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
    {
        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
    } else {
        serde_json::json!({ "resource": resource_url })
    };
    #[cfg(not(feature = "oauth"))]
    let prm_metadata = serde_json::json!({ "resource": resource_url });

    router = router.route(
        "/.well-known/oauth-protected-resource",
        axum::routing::get(move || {
            let m = prm_metadata.clone();
            async move { axum::Json(m) }
        }),
    );

    #[cfg(feature = "oauth")]
    if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
        && oauth_config.proxy.is_some()
    {
        router =
            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
    }

    let is_tls = config.tls_cert_path.is_some();
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        security_headers_middleware(is_tls, req, next)
    }));

    if !cors_origins.is_empty() {
        // Origins that are not valid header values are silently dropped here.
        let cors = tower_http::cors::CorsLayer::new()
            .allow_origin(
                cors_origins
                    .iter()
                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
                    .collect::<Vec<_>>(),
            )
            .allow_methods([
                axum::http::Method::GET,
                axum::http::Method::POST,
                axum::http::Method::OPTIONS,
            ])
            .allow_headers([
                axum::http::header::CONTENT_TYPE,
                axum::http::header::AUTHORIZATION,
            ]);
        router = router.layer(cors);
    }

    if config.compression_enabled {
        use tower_http::compression::Predicate as _;
        let predicate = tower_http::compression::DefaultPredicate::new().and(
            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
        );
        router = router.layer(
            tower_http::compression::CompressionLayer::new()
                .gzip(true)
                .br(true)
                .compress_when(predicate),
        );
        tracing::info!(
            min_size = config.compression_min_size,
            "response compression enabled (gzip, br)"
        );
    }

    // Load shedding turns "no permit available" into an error immediately
    // instead of queueing; HandleErrorLayer maps that error to a 503 JSON body.
    if let Some(max) = config.max_concurrent_requests {
        let overload_handler = tower::ServiceBuilder::new()
            .layer(axum::error_handling::HandleErrorLayer::new(
                |_err: tower::BoxError| async {
                    (
                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
                        axum::Json(serde_json::json!({
                            "error": "overloaded",
                            "error_description": "server is at capacity, retry later"
                        })),
                    )
                },
            ))
            .layer(tower::load_shed::LoadShedLayer::new())
            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
        router = router.layer(overload_handler);
        tracing::info!(max, "global concurrency limit enabled");
    }

    // NOTE(review): this fallback is registered after the layers above; axum
    // documents that `Router::layer` only wraps routes/fallback that already
    // exist, so the security-header/CORS/compression/concurrency layers likely
    // do not apply to these 404 responses — confirm against the axum version.
    router = router.fallback(|| async {
        (
            axum::http::StatusCode::NOT_FOUND,
            axum::Json(serde_json::json!({
                "error": "not_found",
                "error_description": "The requested endpoint does not exist"
            })),
        )
    });

    // Metrics: a middleware on the main router plus a detached listener task;
    // listener failures are only logged.
    #[cfg(feature = "metrics")]
    if config.metrics_enabled {
        let metrics = Arc::new(
            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
        );
        let m = Arc::clone(&metrics);
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let m = Arc::clone(&m);
                metrics_middleware(m, req, next)
            },
        ));
        let metrics_bind = config.metrics_bind.clone();
        tokio::spawn(async move {
            if let Err(e) = crate::metrics::serve_metrics(metrics_bind, metrics).await {
                tracing::error!("metrics listener failed: {e}");
            }
        });
    }

    // Origin check is added last, making it the outermost layer.
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        let origins = Arc::clone(&allowed_origins);
        origin_check_middleware(origins, log_request_headers, req, next)
    }));

    let scheme = if config.tls_cert_path.is_some() {
        "https"
    } else {
        "http"
    };

    // `check()` guarantees the TLS paths are either both set or both unset.
    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
        _ => None,
    };
    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();

    Ok((
        router,
        AppRunParams {
            tls_paths,
            mtls_config,
            shutdown_timeout: config.shutdown_timeout,
            auth_state,
            rbac_swap,
            on_reload_ready: config.on_reload_ready.take(),
            ct,
            scheme,
            name: config.name.clone(),
        },
    ))
}
1259
1260pub async fn serve<H, F>(
1277 config: Validated<McpServerConfig>,
1278 handler_factory: F,
1279) -> Result<(), McpxError>
1280where
1281 H: ServerHandler + 'static,
1282 F: Fn() -> H + Send + Sync + Clone + 'static,
1283{
1284 let config = config.into_inner();
1285 #[allow(
1286 deprecated,
1287 reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1288 )]
1289 let bind_addr = config.bind_addr.clone();
1290 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1291
1292 let listener = TcpListener::bind(&bind_addr)
1293 .await
1294 .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1295 log_listening(¶ms.name, params.scheme, &bind_addr);
1296
1297 run_server(
1298 router,
1299 listener,
1300 params.tls_paths,
1301 params.mtls_config,
1302 params.shutdown_timeout,
1303 params.auth_state,
1304 params.rbac_swap,
1305 params.on_reload_ready,
1306 params.ct,
1307 )
1308 .await
1309 .map_err(anyhow_to_startup)
1310}
1311
1312pub async fn serve_with_listener<H, F>(
1342 listener: TcpListener,
1343 config: Validated<McpServerConfig>,
1344 handler_factory: F,
1345 ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1346 shutdown: Option<CancellationToken>,
1347) -> Result<(), McpxError>
1348where
1349 H: ServerHandler + 'static,
1350 F: Fn() -> H + Send + Sync + Clone + 'static,
1351{
1352 let config = config.into_inner();
1353 let local_addr = listener
1354 .local_addr()
1355 .map_err(|e| io_to_startup("listener.local_addr", e))?;
1356 let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1357
1358 log_listening(¶ms.name, params.scheme, &local_addr.to_string());
1359
1360 if let Some(external) = shutdown {
1364 let internal = params.ct.clone();
1365 tokio::spawn(async move {
1366 external.cancelled().await;
1367 internal.cancel();
1368 });
1369 }
1370
1371 if let Some(tx) = ready_tx {
1375 let _ = tx.send(local_addr);
1377 }
1378
1379 run_server(
1380 router,
1381 listener,
1382 params.tls_paths,
1383 params.mtls_config,
1384 params.shutdown_timeout,
1385 params.auth_state,
1386 params.rbac_swap,
1387 params.on_reload_ready,
1388 params.ct,
1389 )
1390 .await
1391 .map_err(anyhow_to_startup)
1392}
1393
/// Logs the listen address followed by the MCP, health and readiness endpoint
/// URLs at startup.
///
/// `scheme` is the URL scheme to display (callers pass `"http"` or
/// `"https"`); `addr` is the `host:port` string shown in the URLs.
#[allow(
    clippy::cognitive_complexity,
    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
)]
fn log_listening(name: &str, scheme: &str, addr: &str) {
    tracing::info!("{name} listening on {addr}");
    tracing::info!(" MCP endpoint: {scheme}://{addr}/mcp");
    tracing::info!(" Health check: {scheme}://{addr}/healthz");
    tracing::info!(" Readiness: {scheme}://{addr}/readyz");
}
1406
/// Serves `router` on `listener` until shutdown, optionally over TLS/mTLS.
///
/// Shutdown is triggered either by `shutdown_signal()` or by cancellation of
/// `ct`. Once triggered, `ct` is cancelled and axum's graceful shutdown starts.
/// If that has not finished within `shutdown_timeout`, the serve future is
/// dropped, a warning is logged and the function returns `Ok(())`.
///
/// With `tls_paths` set, connections are accepted through `TlsListener`
/// (with CRL revocation checking when `mtls_config.crl_enabled`) and handlers
/// get `TlsConnInfo` connect info. Otherwise plain TCP is used with
/// `SocketAddr` connect info. `on_reload_ready`, if present, is called at most
/// once with a `ReloadHandle` before serving starts.
///
/// # Errors
///
/// Fails if loading the client-auth CA roots, the initial CRL fetch or creating
/// the TLS listener fails, or if `axum::serve` returns an error.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // Single trigger shared by the graceful-shutdown future and the force-exit
    // timer, fired by an OS shutdown signal or by cancellation of `ct`.
    let shutdown_trigger = CancellationToken::new();
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    // Handed to axum's graceful shutdown: once triggered, cancel `ct` so that
    // its children (MCP service token, CRL refresher) stop as well.
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };

    // Upper bound on graceful shutdown: completes `shutdown_timeout` after the
    // trigger fires; whichever of this and the serve future finishes first wins.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // CRL support: fetch CRLs once up front, then keep refreshing them in
        // a background task bound to `ct`.
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        // `take()` guarantees the callback runs in at most one branch.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        // NOTE(review): if this fails, the CRL refresher spawned above keeps
        // running until `ct` is cancelled — confirm whether callers cancel it.
        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
        )?;
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        // Plain HTTP: no CRL set, so the state can be moved into the handle.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
1542
/// Mounts the OAuth 2.1 proxy endpoints on `router` when `oauth_config.proxy`
/// is configured; otherwise returns `router` unchanged.
///
/// Always mounted when a proxy is configured:
/// - `GET /.well-known/oauth-authorization-server` (metadata built by
///   `crate::oauth::authorization_server_metadata`),
/// - `GET /authorize`, `POST /token`, `POST /register`.
///
/// `POST /introspect` and `POST /revoke` are mounted only when
/// `proxy.expose_admin_endpoints` is set AND the matching upstream URL
/// (`introspection_url` / `revocation_url`) is configured. When
/// `proxy.require_auth_on_admin_endpoints` is set, those admin routes are
/// wrapped in `auth_middleware`.
///
/// # Errors
///
/// - Propagates the error from `crate::oauth::OauthHttpClient::with_config`.
/// - Returns `McpxError::Startup` when admin routes are mounted with
///   `require_auth_on_admin_endpoints` set but `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        }),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            // Clone per request like the other handlers so the closure is `Fn`.
            let p = proxy_register.clone();
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        }),
    );

    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Warn only when admin routes are actually mounted. With
    // `expose_admin_endpoints` set but no upstream URL configured, nothing is
    // exposed, and claiming "unauthenticated endpoints" would be misleading.
    if admin_routes_enabled && !proxy.require_auth_on_admin_endpoints {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated; consider setting require_auth_on_admin_endpoints = true"
        );
    }

    let admin_router = if admin_routes_enabled {
        let mut admin_router = axum::Router::new();
        if proxy.introspection_url.is_some() {
            let proxy_introspect = proxy.clone();
            let introspect_http = http.clone();
            admin_router = admin_router.route(
                "/introspect",
                axum::routing::post(move |body: String| {
                    let p = proxy_introspect.clone();
                    let h = introspect_http.clone();
                    async move { crate::oauth::handle_introspect(&h, &p, &body).await }
                }),
            );
        }
        if proxy.revocation_url.is_some() {
            let proxy_revoke = proxy.clone();
            let revoke_http = http;
            admin_router = admin_router.route(
                "/revoke",
                axum::routing::post(move |body: String| {
                    let p = proxy_revoke.clone();
                    let h = revoke_http.clone();
                    async move { crate::oauth::handle_revoke(&h, &p, &body).await }
                }),
            );
        }

        if proxy.require_auth_on_admin_endpoints {
            // Refuse to start rather than silently exposing admin routes
            // without the authentication the config asked for.
            let Some(state) = auth_state else {
                return Err(McpxError::Startup(
                    "oauth proxy admin endpoints require auth state".into(),
                ));
            };
            let state_for_mw = Arc::clone(state);
            admin_router.layer(axum::middleware::from_fn(move |req, next| {
                let s = Arc::clone(&state_for_mw);
                auth_middleware(s, req, next)
            }))
        } else {
            admin_router
        }
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1668
1669fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1674 let mut hosts = vec![
1675 "localhost".to_owned(),
1676 "127.0.0.1".to_owned(),
1677 "::1".to_owned(),
1678 ];
1679
1680 if let Some(url) = public_url
1681 && let Ok(uri) = url.parse::<axum::http::Uri>()
1682 && let Some(authority) = uri.authority()
1683 {
1684 let host = authority.host().to_owned();
1685 if !hosts.iter().any(|h| h == &host) {
1686 hosts.push(host);
1687 }
1688
1689 let authority = authority.as_str().to_owned();
1690 if !hosts.iter().any(|h| h == &authority) {
1691 hosts.push(authority);
1692 }
1693 }
1694
1695 if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1696 && let Some(authority) = uri.authority()
1697 {
1698 let host = authority.host().to_owned();
1699 if !hosts.iter().any(|h| h == &host) {
1700 hosts.push(host);
1701 }
1702
1703 let authority = authority.as_str().to_owned();
1704 if !hosts.iter().any(|h| h == &authority) {
1705 hosts.push(authority);
1706 }
1707 }
1708
1709 hosts
1710}
1711
1712impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1725 for TlsConnInfo
1726{
1727 fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1728 let addr = *target.remote_addr();
1729 let identity = target.io().identity().cloned();
1730 TlsConnInfo::new(addr, identity)
1731 }
1732}
1733
/// `axum::serve::Listener` that accepts TCP connections from `inner` and
/// completes a rustls server handshake on each before handing it to axum.
struct TlsListener {
    /// Plain TCP listener that raw connections are accepted from.
    inner: TcpListener,
    /// rustls acceptor built from the configured certificate, key and
    /// (optionally) client-certificate verifier.
    acceptor: tokio_rustls::TlsAcceptor,
    /// Role string passed to `extract_mtls_identity` for each accepted client
    /// certificate. It is `MtlsConfig::default_role` when mTLS is configured,
    /// otherwise `"viewer"`. NOTE(review): presumably the role granted when
    /// the certificate carries no explicit role — confirm in `auth`.
    mtls_default_role: String,
}
1746
1747impl TlsListener {
1748 fn new(
1749 inner: TcpListener,
1750 cert_path: &Path,
1751 key_path: &Path,
1752 mtls_config: Option<&MtlsConfig>,
1753 crl_set: Option<Arc<CrlSet>>,
1754 ) -> anyhow::Result<Self> {
1755 rustls::crypto::ring::default_provider()
1757 .install_default()
1758 .ok();
1759
1760 let certs = load_certs(cert_path)?;
1761 let key = load_key(key_path)?;
1762
1763 let mtls_default_role;
1764
1765 let tls_config = if let Some(mtls) = mtls_config {
1766 mtls_default_role = mtls.default_role.clone();
1767 let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1768 {
1769 let Some(crl_set) = crl_set else {
1770 return Err(anyhow::anyhow!(
1771 "mTLS CRL verifier requested but CRL state was not initialized"
1772 ));
1773 };
1774 Arc::new(DynamicClientCertVerifier::new(crl_set))
1775 } else {
1776 let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1777 if mtls.required {
1778 rustls::server::WebPkiClientVerifier::builder(root_store)
1779 .build()
1780 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1781 } else {
1782 rustls::server::WebPkiClientVerifier::builder(root_store)
1783 .allow_unauthenticated()
1784 .build()
1785 .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1786 }
1787 };
1788
1789 tracing::info!(
1790 ca = %mtls.ca_cert_path.display(),
1791 required = mtls.required,
1792 crl_enabled = mtls.crl_enabled,
1793 "mTLS client auth configured"
1794 );
1795
1796 rustls::ServerConfig::builder_with_protocol_versions(&[
1797 &rustls::version::TLS12,
1798 &rustls::version::TLS13,
1799 ])
1800 .with_client_cert_verifier(verifier)
1801 .with_single_cert(certs, key)?
1802 } else {
1803 mtls_default_role = "viewer".to_owned();
1804 rustls::ServerConfig::builder_with_protocol_versions(&[
1805 &rustls::version::TLS12,
1806 &rustls::version::TLS13,
1807 ])
1808 .with_no_client_auth()
1809 .with_single_cert(certs, key)?
1810 };
1811
1812 let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1813 tracing::info!(
1814 "TLS enabled (cert: {}, key: {})",
1815 cert_path.display(),
1816 key_path.display()
1817 );
1818 Ok(Self {
1819 inner,
1820 acceptor,
1821 mtls_default_role,
1822 })
1823 }
1824
1825 fn extract_handshake_identity(
1829 tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1830 default_role: &str,
1831 addr: SocketAddr,
1832 ) -> Option<AuthIdentity> {
1833 let (_, server_conn) = tls_stream.get_ref();
1834 let cert_der = server_conn.peer_certificates()?.first()?;
1835 let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1836 tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1837 Some(id)
1838 }
1839}
1840
/// Server-side TLS stream produced by `TlsListener`, paired with the identity
/// extracted from the client certificate at handshake time.
///
/// I/O is delegated unchanged to `inner` through the `AsyncRead` /
/// `AsyncWrite` impls.
pub(crate) struct AuthenticatedTlsStream {
    /// The established rustls server stream over TCP.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity from `TlsListener::extract_handshake_identity`; `None` when no
    /// client certificate yielded an identity.
    identity: Option<AuthIdentity>,
}
1856
impl AuthenticatedTlsStream {
    /// Identity derived from the client certificate during the TLS handshake,
    /// or `None` if no certificate produced one.
    #[must_use]
    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
        self.identity.as_ref()
    }
}
1864
1865impl std::fmt::Debug for AuthenticatedTlsStream {
1866 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1867 f.debug_struct("AuthenticatedTlsStream")
1868 .field("identity", &self.identity.as_ref().map(|id| &id.name))
1869 .finish_non_exhaustive()
1870 }
1871}
1872
1873impl tokio::io::AsyncRead for AuthenticatedTlsStream {
1874 fn poll_read(
1875 mut self: Pin<&mut Self>,
1876 cx: &mut std::task::Context<'_>,
1877 buf: &mut tokio::io::ReadBuf<'_>,
1878 ) -> std::task::Poll<std::io::Result<()>> {
1879 Pin::new(&mut self.inner).poll_read(cx, buf)
1880 }
1881}
1882
1883impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
1884 fn poll_write(
1885 mut self: Pin<&mut Self>,
1886 cx: &mut std::task::Context<'_>,
1887 buf: &[u8],
1888 ) -> std::task::Poll<std::io::Result<usize>> {
1889 Pin::new(&mut self.inner).poll_write(cx, buf)
1890 }
1891
1892 fn poll_flush(
1893 mut self: Pin<&mut Self>,
1894 cx: &mut std::task::Context<'_>,
1895 ) -> std::task::Poll<std::io::Result<()>> {
1896 Pin::new(&mut self.inner).poll_flush(cx)
1897 }
1898
1899 fn poll_shutdown(
1900 mut self: Pin<&mut Self>,
1901 cx: &mut std::task::Context<'_>,
1902 ) -> std::task::Poll<std::io::Result<()>> {
1903 Pin::new(&mut self.inner).poll_shutdown(cx)
1904 }
1905
1906 fn poll_write_vectored(
1907 mut self: Pin<&mut Self>,
1908 cx: &mut std::task::Context<'_>,
1909 bufs: &[std::io::IoSlice<'_>],
1910 ) -> std::task::Poll<std::io::Result<usize>> {
1911 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
1912 }
1913
1914 fn is_write_vectored(&self) -> bool {
1915 self.inner.is_write_vectored()
1916 }
1917}
1918
1919impl axum::serve::Listener for TlsListener {
1920 type Io = AuthenticatedTlsStream;
1921 type Addr = SocketAddr;
1922
1923 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
1924 loop {
1925 let (stream, addr) = match self.inner.accept().await {
1926 Ok(pair) => pair,
1927 Err(e) => {
1928 tracing::debug!("TCP accept error: {e}");
1929 continue;
1930 }
1931 };
1932 let tls_stream = match self.acceptor.accept(stream).await {
1933 Ok(s) => s,
1934 Err(e) => {
1935 tracing::debug!("TLS handshake failed from {addr}: {e}");
1936 continue;
1937 }
1938 };
1939 let identity =
1940 Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
1941 let wrapped = AuthenticatedTlsStream {
1942 inner: tls_stream,
1943 identity,
1944 };
1945 return (wrapped, addr);
1946 }
1947 }
1948
1949 fn local_addr(&self) -> std::io::Result<Self::Addr> {
1950 self.inner.local_addr()
1951 }
1952}
1953
1954fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
1955 use rustls::pki_types::pem::PemObject;
1956 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
1957 .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
1958 .collect::<Result<_, _>>()
1959 .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
1960 anyhow::ensure!(
1961 !certs.is_empty(),
1962 "no certificates found in {}",
1963 path.display()
1964 );
1965 Ok(certs)
1966}
1967
1968fn load_client_auth_roots(
1969 path: &Path,
1970) -> anyhow::Result<(
1971 Vec<rustls::pki_types::CertificateDer<'static>>,
1972 Arc<RootCertStore>,
1973)> {
1974 let ca_certs = load_certs(path)?;
1975 let mut root_store = RootCertStore::empty();
1976 for cert in &ca_certs {
1977 root_store
1978 .add(cert.clone())
1979 .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
1980 }
1981
1982 Ok((ca_certs, Arc::new(root_store)))
1983}
1984
1985fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
1986 use rustls::pki_types::pem::PemObject;
1987 rustls::pki_types::PrivateKeyDer::from_pem_file(path)
1988 .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
1989}
1990
1991#[allow(clippy::unused_async)]
1992async fn healthz() -> impl IntoResponse {
1993 axum::Json(serde_json::json!({
1994 "status": "ok",
1995 }))
1996}
1997
1998fn version_payload(name: &str, version: &str) -> serde_json::Value {
2004 serde_json::json!({
2005 "name": name,
2006 "version": version,
2007 "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or("unknown"),
2008 "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or("unknown"),
2009 "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or("unknown"),
2010 "mcpx_version": env!("CARGO_PKG_VERSION"),
2011 })
2012}
2013
2014fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2024 let value = version_payload(name, version);
2025 serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2026}
2027
2028async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2029 let status = check().await;
2030 let ready = status
2031 .get("ready")
2032 .and_then(serde_json::Value::as_bool)
2033 .unwrap_or(false);
2034 let code = if ready {
2035 axum::http::StatusCode::OK
2036 } else {
2037 axum::http::StatusCode::SERVICE_UNAVAILABLE
2038 };
2039 (code, axum::Json(status))
2040}
2041
2042async fn shutdown_signal() {
2046 let ctrl_c = tokio::signal::ctrl_c();
2047
2048 #[cfg(unix)]
2049 {
2050 match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2051 Ok(mut term) => {
2052 tokio::select! {
2053 _ = ctrl_c => {}
2054 _ = term.recv() => {}
2055 }
2056 }
2057 Err(e) => {
2058 tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2059 ctrl_c.await.ok();
2060 }
2061 }
2062 }
2063
2064 #[cfg(not(unix))]
2065 {
2066 ctrl_c.await.ok();
2067 }
2068}
2069
/// Records per-request Prometheus metrics: `http_requests_total`
/// (method, path, status) and `http_request_duration_seconds` (method, path).
///
/// Label values are bounded so clients cannot inflate metric cardinality.
/// Every distinct label value creates a new time series that lives for the
/// process lifetime.
/// - `path` is the matched route template (`axum::extract::MatchedPath`), or
///   `"unmatched"` when routing did not match (e.g. the 404 fallback).
/// - `method` is the method name for standard HTTP methods, or `"OTHER"` for
///   arbitrary extension methods.
///
/// NOTE(review): `MatchedPath` is only present when this middleware runs after
/// routing (installed via `Router::layer` / `route_layer` after the routes).
/// Otherwise every request is labelled `"unmatched"` — confirm at the call site.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // RFC 9110 methods plus PATCH (RFC 5789). Anything else is a client-chosen
    // token and is collapsed into a single label value.
    const STANDARD_METHODS: [axum::http::Method; 9] = [
        axum::http::Method::GET,
        axum::http::Method::HEAD,
        axum::http::Method::POST,
        axum::http::Method::PUT,
        axum::http::Method::DELETE,
        axum::http::Method::CONNECT,
        axum::http::Method::OPTIONS,
        axum::http::Method::TRACE,
        axum::http::Method::PATCH,
    ];

    let method = if STANDARD_METHODS.contains(req.method()) {
        req.method().to_string()
    } else {
        "OTHER".to_owned()
    };
    // Never label by the raw URI path: it is fully client-controlled.
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or_else(
            || "unmatched".to_owned(),
            |matched| matched.as_str().to_owned(),
        );
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2101
2102async fn security_headers_middleware(
2110 is_tls: bool,
2111 req: Request<Body>,
2112 next: Next,
2113) -> axum::response::Response {
2114 use axum::http::{HeaderName, HeaderValue, header};
2115
2116 let mut resp = next.run(req).await;
2117 let headers = resp.headers_mut();
2118
2119 headers.remove(header::SERVER);
2121 headers.remove(HeaderName::from_static("x-powered-by"));
2122
2123 headers.insert(
2124 header::X_CONTENT_TYPE_OPTIONS,
2125 HeaderValue::from_static("nosniff"),
2126 );
2127 headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("deny"));
2128 headers.insert(
2129 header::CACHE_CONTROL,
2130 HeaderValue::from_static("no-store, max-age=0"),
2131 );
2132 headers.insert(
2133 header::REFERRER_POLICY,
2134 HeaderValue::from_static("no-referrer"),
2135 );
2136 headers.insert(
2137 HeaderName::from_static("cross-origin-opener-policy"),
2138 HeaderValue::from_static("same-origin"),
2139 );
2140 headers.insert(
2141 HeaderName::from_static("cross-origin-resource-policy"),
2142 HeaderValue::from_static("same-origin"),
2143 );
2144 headers.insert(
2145 HeaderName::from_static("cross-origin-embedder-policy"),
2146 HeaderValue::from_static("require-corp"),
2147 );
2148 headers.insert(
2149 HeaderName::from_static("permissions-policy"),
2150 HeaderValue::from_static("accelerometer=(), camera=(), geolocation=(), microphone=()"),
2151 );
2152 headers.insert(
2153 HeaderName::from_static("x-permitted-cross-domain-policies"),
2154 HeaderValue::from_static("none"),
2155 );
2156 headers.insert(
2157 HeaderName::from_static("content-security-policy"),
2158 HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
2159 );
2160 headers.insert(
2161 HeaderName::from_static("x-dns-prefetch-control"),
2162 HeaderValue::from_static("off"),
2163 );
2164
2165 if is_tls {
2166 headers.insert(
2167 header::STRICT_TRANSPORT_SECURITY,
2168 HeaderValue::from_static("max-age=63072000; includeSubDomains"),
2169 );
2170 }
2171
2172 resp
2173}
2174
2175async fn origin_check_middleware(
2179 allowed: Arc<[String]>,
2180 log_request_headers: bool,
2181 req: Request<Body>,
2182 next: Next,
2183) -> axum::response::Response {
2184 let method = req.method().clone();
2185 let path = req.uri().path().to_owned();
2186
2187 log_incoming_request(&method, &path, req.headers(), log_request_headers);
2188
2189 if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2190 let origin_str = origin.to_str().unwrap_or("");
2191 if !allowed.iter().any(|a| a == origin_str) {
2192 tracing::warn!(
2193 origin = origin_str,
2194 %method,
2195 %path,
2196 allowed = ?&*allowed,
2197 "rejected request: Origin not allowed"
2198 );
2199 return (
2200 axum::http::StatusCode::FORBIDDEN,
2201 "Forbidden: Origin not allowed",
2202 )
2203 .into_response();
2204 }
2205 }
2206 next.run(req).await
2207}
2208
2209fn log_incoming_request(
2212 method: &axum::http::Method,
2213 path: &str,
2214 headers: &axum::http::HeaderMap,
2215 log_request_headers: bool,
2216) {
2217 if log_request_headers {
2218 tracing::debug!(
2219 %method,
2220 %path,
2221 headers = %format_request_headers_for_log(headers),
2222 "incoming request"
2223 );
2224 } else {
2225 tracing::debug!(%method, %path, "incoming request");
2226 }
2227}
2228
2229fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2230 headers
2231 .iter()
2232 .map(|(k, v)| {
2233 let name = k.as_str();
2234 if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2235 format!("{name}: [REDACTED]")
2236 } else {
2237 format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2238 }
2239 })
2240 .collect::<Vec<_>>()
2241 .join(", ")
2242}
2243
2244#[allow(clippy::cognitive_complexity)]
2268pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2269where
2270 H: ServerHandler + 'static,
2271{
2272 use rmcp::ServiceExt as _;
2273
2274 tracing::info!("stdio transport: serving on stdin/stdout");
2275 tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2276
2277 let transport = rmcp::transport::io::stdio();
2278
2279 let service = handler
2280 .serve(transport)
2281 .await
2282 .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2283
2284 if let Err(e) = service.waiting().await {
2285 tracing::warn!(error = %e, "stdio session ended with error");
2286 }
2287 tracing::info!("stdio session ended");
2288 Ok(())
2289}
2290
#[cfg(test)]
mod tests {
    // Unit tests for config defaults, host derivation, probes, and the HTTP
    // middlewares defined in this module.
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        deprecated,
        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
    )]
    use std::sync::Arc;

    use axum::{
        body::Body,
        http::{Request, StatusCode, header},
        response::IntoResponse,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt as _;

    use super::*;

    #[test]
    /// `McpServerConfig::new` sets only bind address, name and version; every
    /// optional feature starts disabled and the limits take the defaults
    /// asserted below (1 MiB body, 2 min request timeout, 30 s shutdown).
    fn server_config_new_defaults() {
        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
        assert_eq!(cfg.name, "test-server");
        assert_eq!(cfg.version, "1.0.0");
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert!(cfg.auth.is_none());
        assert!(cfg.rbac.is_none());
        assert!(cfg.allowed_origins.is_empty());
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.readiness_check.is_none());
        assert_eq!(cfg.max_request_body, 1024 * 1024);
        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
        assert!(!cfg.log_request_headers);
    }

    #[test]
    fn validate_consumes_and_proves() {
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        // `validate` takes the config by value and returns a validated wrapper.
        let validated = cfg.validate().expect("valid config");
        // Fields remain readable through the validated wrapper.
        assert_eq!(validated.name, "test-server");
        // `into_inner` hands the underlying config back.
        let raw = validated.into_inner();
        assert_eq!(raw.name, "test-server");

        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        // A zero request-body cap is an invalid configuration.
        bad.max_request_body = 0;
        assert!(bad.validate().is_err(), "zero body cap must fail validate");
    }

    // The host from `public_url` is added to the allow-list.
    #[test]
    fn derive_allowed_hosts_includes_public_host() {
        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
        assert!(
            hosts.iter().any(|h| h == "mcp.example.com"),
            "public_url host must be allowed"
        );
    }

    // Both the bare bind host and the `host:port` authority are allowed.
    #[test]
    fn derive_allowed_hosts_includes_bind_authority() {
        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1"),
            "bind host must be allowed"
        );
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1:8080"),
            "bind authority must be allowed"
        );
    }

    #[tokio::test]
    /// The liveness probe returns `{"status":"ok"}` and nothing identifying the
    /// server.
    async fn healthz_returns_ok_json() {
        let resp = healthz().await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
        assert!(
            json.get("name").is_none(),
            "healthz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "healthz must not expose version"
        );
    }

    #[tokio::test]
    /// A `"ready": true` payload yields 200 and is passed through verbatim,
    /// with no server identity added.
    async fn readyz_returns_ok_when_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["ready"], true);
        assert!(
            json.get("name").is_none(),
            "readyz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "readyz must not expose version"
        );
        assert_eq!(json["db"], "connected");
    }

    // An explicit `"ready": false` yields 503.
    #[tokio::test]
    async fn readyz_returns_503_when_not_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn readyz_returns_503_when_ready_missing() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        // A payload without a boolean `ready` field counts as not ready.
    }

    /// Builds a one-route router (`GET /test` -> "ok") wrapped in
    /// `origin_check_middleware` with the given Origin allow-list.
    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
        let allowed: Arc<[String]> = Arc::from(origins);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let a = Arc::clone(&allowed);
                origin_check_middleware(a, log_request_headers, req, next)
            }))
    }

    // An Origin that exactly matches an allow-list entry is passed through.
    #[tokio::test]
    async fn origin_allowed_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://localhost:3000")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // An Origin that is not in the allow-list is rejected with 403.
    #[tokio::test]
    async fn origin_rejected_returns_403() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // Requests with no Origin header (e.g. non-browser clients) are not checked.
    #[tokio::test]
    async fn no_origin_header_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // An empty allow-list rejects every Origin rather than allowing all.
    #[tokio::test]
    async fn empty_allowlist_rejects_any_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://anything.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // With an empty allow-list, requests without an Origin still pass.
    #[tokio::test]
    async fn empty_allowlist_passes_without_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // Credential-bearing headers are redacted; ordinary headers are shown.
    #[test]
    fn format_request_headers_redacts_sensitive_values() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
        headers.insert("cookie", "sid=abc".parse().unwrap());
        headers.insert("x-request-id", "req-123".parse().unwrap());

        let out = format_request_headers_for_log(&headers);
        assert!(out.contains("authorization: [REDACTED]"));
        assert!(out.contains("cookie: [REDACTED]"));
        assert!(out.contains("x-request-id: req-123"));
        assert!(!out.contains("secret-token"));
    }

    /// Builds a one-route router (`GET /test` -> "ok") wrapped in
    /// `security_headers_middleware`, with TLS reported as `is_tls`.
    fn security_router(is_tls: bool) -> axum::Router {
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                security_headers_middleware(is_tls, req, next)
            }))
    }

    // Every hardening header is present with its exact value on a plain-HTTP
    // response.
    #[tokio::test]
    async fn security_headers_set_on_response() {
        let app = security_router(false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let h = resp.headers();
        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
        assert_eq!(
            h.get("cross-origin-resource-policy").unwrap(),
            "same-origin"
        );
        assert_eq!(
            h.get("cross-origin-embedder-policy").unwrap(),
            "require-corp"
        );
        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
        assert!(
            h.get("permissions-policy")
                .unwrap()
                .to_str()
                .unwrap()
                .contains("camera=()"),
            "permissions-policy must restrict browser features"
        );
        assert_eq!(
            h.get("content-security-policy").unwrap(),
            "default-src 'none'; frame-ancestors 'none'"
        );
        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
        assert!(h.get("strict-transport-security").is_none());
        // HSTS must be absent when not serving over TLS.
    }

    // With TLS enabled, HSTS is added with a two-year max-age.
    #[tokio::test]
    async fn hsts_set_when_tls_enabled() {
        let app = security_router(true);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let hsts = resp.headers().get("strict-transport-security").unwrap();
        assert!(
            hsts.to_str().unwrap().contains("max-age=63072000"),
            "HSTS must set 2-year max-age"
        );
    }

    #[test]
    /// The version document echoes name/version and always has string build
    /// metadata (falling back to "unknown" when not provided at build time).
    fn version_payload_contains_expected_fields() {
        let v = version_payload("my-server", "1.2.3");
        assert_eq!(v["name"], "my-server");
        assert_eq!(v["version"], "1.2.3");
        assert!(v["build_git_sha"].is_string());
        assert!(v["build_timestamp"].is_string());
        assert!(v["rust_version"].is_string());
        assert!(v["mcpx_version"].is_string());
    }

    #[tokio::test]
    /// A load-shed + concurrency-limit stack (with a 503 error handler)
    /// composes with an axum router and serves a request under the limit.
    /// NOTE(review): presumably mirrors the server's own layer stack — confirm.
    async fn concurrency_limit_layer_composes_and_serves() {
        let app = axum::Router::new()
            // Single trivial route under the limit.
            .route("/ok", axum::routing::get(|| async { "ok" }))
            .layer(
                tower::ServiceBuilder::new()
                    .layer(axum::error_handling::HandleErrorLayer::new(
                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
                    ))
                    .layer(tower::load_shed::LoadShedLayer::new())
                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
            );
        let resp = app
            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    /// With gzip/brotli compression gated on bodies above 1024 bytes, a
    /// 4096-byte response to a gzip-accepting client is gzip-encoded.
    async fn compression_layer_gzip_encodes_response() {
        use tower_http::compression::Predicate as _;

        let big_body = "a".repeat(4096);
        let app = axum::Router::new()
            .route(
                "/big",
                axum::routing::get(move || {
                    let body = big_body.clone();
                    async move { body }
                }),
            )
            .layer(
                tower_http::compression::CompressionLayer::new()
                    .gzip(true)
                    .br(true)
                    .compress_when(
                        tower_http::compression::DefaultPredicate::new()
                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
                    ),
            );

        let req = Request::builder()
            .uri("/big")
            .header(header::ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }
}