Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13    ServerHandler,
14    transport::streamable_http_server::{
15        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16    },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23    auth::{
24        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25        build_rate_limiter, extract_mtls_identity,
26    },
27    error::McpxError,
28    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
33/// at the public API boundary, flattening the chain via the alternate
34/// formatter so callers see the full causal path.
35#[allow(
36    clippy::needless_pass_by_value,
37    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40    McpxError::Startup(format!("{e:#}"))
41}
42
43/// Map a `std::io::Error` produced during server startup into a public
44/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
45/// `From` impl here because startup-phase IO errors (bind, listener) are
46/// semantically distinct from request-time IO errors and should surface
47/// the originating operation in the message.
48#[allow(
49    clippy::needless_pass_by_value,
50    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53    McpxError::Startup(format!("{op}: {e}"))
54}
55
/// Async readiness check callback for the `/readyz` endpoint.
///
/// Returns a JSON object with at least a `"ready"` boolean.
/// When `ready` is false, the endpoint returns HTTP 503.
///
/// The alias is an [`Arc`]-shared, `Send + Sync` closure that takes no
/// arguments and returns a boxed, pinned, `Send` future resolving to a
/// `serde_json::Value`. Install one with
/// [`McpServerConfig::with_readiness_check`].
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
/// Configuration for the MCP server.
///
/// Build one with [`McpServerConfig::new`], customise it through the
/// chainable `with_*` / `enable_*` builder methods, then call
/// [`McpServerConfig::validate`] to obtain the
/// [`Validated<McpServerConfig>`] proof token that [`serve`] and
/// [`serve_with_listener`] require.
///
/// Every field is still `pub` but has been deprecated since 0.13.0:
/// direct field access is slated to become `pub(crate)` in 1.0, so new
/// code should go through the builder methods named in each field's
/// deprecation note.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Socket address the MCP HTTP server binds to.
    /// Must parse as a [`SocketAddr`] (e.g. `127.0.0.1:8080`); this is
    /// checked by [`McpServerConfig::validate`].
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in 1.0"
    )]
    pub bind_addr: String,
    /// Server name advertised via MCP `initialize`.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
    )]
    pub name: String,
    /// Server version advertised via MCP `initialize`.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
    )]
    pub version: String,
    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
    /// Must be set together with `tls_key_path`; validation rejects a
    /// certificate without a key. Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
    /// Must be set together with `tls_cert_path`; validation rejects a
    /// key without a certificate. Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Optional authentication config. When `Some` and `enabled`, auth
    /// is enforced on `/mcp`. `/healthz` is always open.
    /// Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in 1.0"
    )]
    pub auth: Option<AuthConfig>,
    /// Optional RBAC policy. When present and enabled, tool calls are
    /// checked against the policy after authentication.
    /// Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in 1.0"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
    /// When empty and `public_url` is set, the origin is auto-derived from
    /// the public URL. When both are empty, only requests with no Origin
    /// header are accepted.
    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
    /// Validation requires every entry to start with `http://` or
    /// `https://`. Default: empty.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in 1.0"
    )]
    pub allowed_origins: Vec<String>,
    /// Maximum tool invocations per source IP per minute.
    /// When set, enforced on every `tools/call` request.
    /// Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in 1.0"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Optional readiness probe for `/readyz`.
    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in 1.0"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes. Default: 1 MiB.
    /// Protects against oversized payloads causing OOM.
    /// Validation rejects a value of zero.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in 1.0"
    )]
    pub max_request_body: usize,
    /// Request processing timeout. Default: 120s.
    /// Requests exceeding this duration receive 408 Request Timeout.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub request_timeout: Duration,
    /// Graceful shutdown timeout. Default: 30s.
    /// After the shutdown signal, in-flight requests have this long to finish.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub shutdown_timeout: Duration,
    /// Idle timeout for MCP sessions. Sessions with no activity for this
    /// duration are closed automatically. Default: 20 minutes.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in 1.0"
    )]
    pub session_idle_timeout: Duration,
    /// Interval for SSE keep-alive pings. Prevents proxies and load
    /// balancers from killing idle connections. Default: 15 seconds.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in 1.0"
    )]
    pub sse_keep_alive: Duration,
    /// Callback invoked once the server is built, delivering a
    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
    /// Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in 1.0"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Additional application-specific routes merged into the top-level
    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
    /// so the application is responsible for its own auth on them.
    /// Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in 1.0"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
    /// When set, OAuth metadata endpoints advertise this URL instead of
    /// the listen address. Required when binding `0.0.0.0` behind a
    /// reverse proxy or inside a container.
    /// When set, validation requires it to start with `http://` or
    /// `https://`. Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in 1.0"
    )]
    pub public_url: Option<String>,
    /// Log inbound HTTP request headers at DEBUG level.
    /// Sensitive values remain redacted. Default: `false`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in 1.0"
    )]
    pub log_request_headers: bool,
    /// Enable gzip/br response compression on MCP responses.
    /// Defaults to `false` to preserve existing behaviour.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
    )]
    pub compression_enabled: bool,
    /// Minimum response body size (in bytes) before compression kicks in.
    /// Only used when `compression_enabled` is true. Default: 1024.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
    )]
    pub compression_min_size: u16,
    /// Global cap on in-flight HTTP requests across the whole server.
    /// When `Some`, requests over the cap receive 503 Service Unavailable
    /// via `tower::load_shed`. Default: `None` (unlimited).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in 1.0"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
    /// configured and `enabled` (enforced by validation).
    /// Default: `false`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
    )]
    pub admin_enabled: bool,
    /// RBAC role required to access admin endpoints. Default: `"admin"`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
    )]
    pub admin_role: String,
    /// Enable Prometheus metrics endpoint on a separate listener.
    /// Requires the `metrics` crate feature. Default: `false`.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
    )]
    pub metrics_enabled: bool,
    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
    )]
    pub metrics_bind: String,
}
260
/// Proof-carrying wrapper: a `Validated<T>` holds a `T` that has
/// passed its validation contract.
///
/// For [`McpServerConfig`] the only constructor is
/// [`McpServerConfig::validate`]. [`serve`] and [`serve_with_listener`]
/// accept nothing but `Validated<McpServerConfig>`, so the requirement
/// is enforced by the type system. The wrapped field is private, which
/// prevents downstream code from skipping validation by building the
/// wrapper by hand.
///
/// Read-only access is available through `Deref` (`&Validated<T>`
/// derefs to `&T`) or [`Validated::as_inner`]. To change the value,
/// take it back out with [`Validated::into_inner`], modify it, and
/// validate it again.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Forgetting `.validate()?` is a compile error:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Emits only the opaque placeholder `Validated { .. }`.
    ///
    /// The inner value is never formatted, so `T` needs no `Debug` impl
    /// and its contents cannot end up in logs through this impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Validated { .. }")
    }
}

impl<T> Validated<T> {
    /// Shared reference to the wrapped, already-validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Unwrap the value, giving up the validation proof.
    ///
    /// The returned `T` must be validated again before it can be handed
    /// to [`serve`] or [`serve_with_listener`].
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

impl<T> std::ops::Deref for Validated<T> {
    type Target = T;

    /// Read-only view of the wrapped value; no `DerefMut` is provided,
    /// so the value cannot be mutated through the wrapper.
    fn deref(&self) -> &T {
        self.as_inner()
    }
}
354
355#[allow(
356    deprecated,
357    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
358)]
359impl McpServerConfig {
360    /// Create a new server configuration with the given bind address,
361    /// server name, and version. All other fields use safe defaults.
362    ///
363    /// Use the chainable `with_*` / `enable_*` builder methods to
364    /// customize. Call [`McpServerConfig::validate`] to obtain a
365    /// [`Validated<McpServerConfig>`] proof token, which is required by
366    /// [`serve`] and [`serve_with_listener`].
367    #[must_use]
368    pub fn new(
369        bind_addr: impl Into<String>,
370        name: impl Into<String>,
371        version: impl Into<String>,
372    ) -> Self {
373        Self {
374            bind_addr: bind_addr.into(),
375            name: name.into(),
376            version: version.into(),
377            tls_cert_path: None,
378            tls_key_path: None,
379            auth: None,
380            rbac: None,
381            allowed_origins: Vec::new(),
382            tool_rate_limit: None,
383            readiness_check: None,
384            max_request_body: 1024 * 1024,
385            request_timeout: Duration::from_mins(2),
386            shutdown_timeout: Duration::from_secs(30),
387            session_idle_timeout: Duration::from_mins(20),
388            sse_keep_alive: Duration::from_secs(15),
389            on_reload_ready: None,
390            extra_router: None,
391            public_url: None,
392            log_request_headers: false,
393            compression_enabled: false,
394            compression_min_size: 1024,
395            max_concurrent_requests: None,
396            admin_enabled: false,
397            admin_role: "admin".to_owned(),
398            #[cfg(feature = "metrics")]
399            metrics_enabled: false,
400            #[cfg(feature = "metrics")]
401            metrics_bind: "127.0.0.1:9090".into(),
402        }
403    }
404
405    // ---------------------------------------------------------------
406    // Builder methods (fluent, consume + return self).
407    //
408    // Each method is `#[must_use]` because dropping the returned
409    // `McpServerConfig` discards the configuration change.
410    // ---------------------------------------------------------------
411
412    /// Attach an authentication configuration. Required for
413    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
414    #[must_use]
415    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
416        self.auth = Some(auth);
417        self
418    }
419
420    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
421    /// final port is only known after pre-binding an ephemeral listener
422    /// (tests, dynamic-port deployments).
423    #[must_use]
424    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
425        self.bind_addr = addr.into();
426        self
427    }
428
429    /// Attach an RBAC policy. Tool calls are checked against the policy
430    /// after authentication.
431    #[must_use]
432    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
433        self.rbac = Some(rbac);
434        self
435    }
436
437    /// Configure TLS by providing the certificate and private key paths
438    /// (PEM). Both must be readable at startup. Without this call, the
439    /// server runs plain HTTP.
440    #[must_use]
441    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
442        self.tls_cert_path = Some(cert_path.into());
443        self.tls_key_path = Some(key_path.into());
444        self
445    }
446
447    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
448    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
449    /// a container so OAuth metadata and auto-derived origins resolve correctly.
450    #[must_use]
451    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
452        self.public_url = Some(url.into());
453        self
454    }
455
456    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
457    /// When empty and [`with_public_url`](Self::with_public_url) is set,
458    /// the origin is auto-derived.
459    #[must_use]
460    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
461    where
462        I: IntoIterator<Item = S>,
463        S: Into<String>,
464    {
465        self.allowed_origins = origins.into_iter().map(Into::into).collect();
466        self
467    }
468
469    /// Merge an additional axum router at the top level. Routes added
470    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
471    /// for its own protection.
472    #[must_use]
473    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
474        self.extra_router = Some(router);
475        self
476    }
477
478    /// Install an async readiness probe for `/readyz`. Without this call,
479    /// `/readyz` mirrors `/healthz` (always 200 OK).
480    #[must_use]
481    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
482        self.readiness_check = Some(check);
483        self
484    }
485
486    /// Override the maximum request body (bytes). Must be `> 0`.
487    /// Default: 1 MiB.
488    #[must_use]
489    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
490        self.max_request_body = bytes;
491        self
492    }
493
494    /// Override the per-request processing timeout. Default: 2 minutes.
495    #[must_use]
496    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
497        self.request_timeout = timeout;
498        self
499    }
500
501    /// Override the graceful shutdown grace period. Default: 30 seconds.
502    #[must_use]
503    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
504        self.shutdown_timeout = timeout;
505        self
506    }
507
508    /// Override the MCP session idle timeout. Default: 20 minutes.
509    #[must_use]
510    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
511        self.session_idle_timeout = timeout;
512        self
513    }
514
515    /// Override the SSE keep-alive interval. Default: 15 seconds.
516    #[must_use]
517    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
518        self.sse_keep_alive = interval;
519        self
520    }
521
522    /// Cap the global number of in-flight HTTP requests via
523    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
524    /// Default: unlimited.
525    #[must_use]
526    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
527        self.max_concurrent_requests = Some(limit);
528        self
529    }
530
531    /// Cap tool invocations per source IP per minute. Enforced on every
532    /// `tools/call` request.
533    #[must_use]
534    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
535        self.tool_rate_limit = Some(per_minute);
536        self
537    }
538
539    /// Register a callback that receives the [`ReloadHandle`] after the
540    /// server is built. Use it to wire SIGHUP-style hot reloads of API
541    /// keys and RBAC policy.
542    #[must_use]
543    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
544    where
545        F: FnOnce(ReloadHandle) + Send + 'static,
546    {
547        self.on_reload_ready = Some(Box::new(callback));
548        self
549    }
550
551    /// Enable gzip/brotli response compression on MCP responses.
552    /// `min_size` is the smallest body size (bytes) eligible for
553    /// compression. Default min size: 1024.
554    #[must_use]
555    pub fn enable_compression(mut self, min_size: u16) -> Self {
556        self.compression_enabled = true;
557        self.compression_min_size = min_size;
558        self
559    }
560
561    /// Enable `/admin/*` diagnostic endpoints. Requires
562    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
563    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
564    /// role gate (default: `"admin"`).
565    #[must_use]
566    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
567        self.admin_enabled = true;
568        self.admin_role = role.into();
569        self
570    }
571
572    /// Log inbound HTTP request headers at DEBUG level. Sensitive
573    /// values remain redacted by the logging layer.
574    #[must_use]
575    pub fn enable_request_header_logging(mut self) -> Self {
576        self.log_request_headers = true;
577        self
578    }
579
580    /// Enable the Prometheus metrics listener on `bind` (e.g.
581    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
582    #[cfg(feature = "metrics")]
583    #[must_use]
584    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
585        self.metrics_enabled = true;
586        self.metrics_bind = bind.into();
587        self
588    }
589
590    /// Validate the configuration and consume `self`, returning a
591    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
592    /// and [`serve_with_listener`]. This is the only way to construct
593    /// `Validated<McpServerConfig>`, so the type system guarantees
594    /// validation has run before the server starts.
595    ///
596    /// Checks:
597    ///
598    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
599    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
600    ///    be unset.
601    /// 3. `bind_addr` must parse as a [`SocketAddr`].
602    /// 4. `public_url`, when set, must start with `http://` or `https://`.
603    /// 5. Each entry in `allowed_origins` must start with `http://` or
604    ///    `https://`.
605    /// 6. `max_request_body` must be greater than zero.
606    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
607    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
608    ///    `proxy.token_url`, `proxy.introspection_url`,
609    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
610    ///    and use the `https` scheme. Set
611    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
612    ///    targets (strongly discouraged in production - see the field-level
613    ///    docs for the threat model).
614    ///
615    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
616    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
617    ///
618    /// # Errors
619    ///
620    /// Returns [`McpxError::Config`] with a human-readable message on
621    /// the first validation failure.
622    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
623        self.check()?;
624        Ok(Validated(self))
625    }
626
627    /// Run the validation checks without consuming `self`. Used by
628    /// internal call sites (e.g. tests) that need to inspect a config
629    /// without taking ownership.
630    fn check(&self) -> Result<(), McpxError> {
631        // 1. admin <-> auth dependency. Mirrors the runtime check in
632        //    `build_app_router`: admin endpoints require an auth state,
633        //    which is built only when `auth` is `Some` *and* `enabled`.
634        if self.admin_enabled {
635            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
636            if !auth_enabled {
637                return Err(McpxError::Config(
638                    "admin_enabled=true requires auth to be configured and enabled".into(),
639                ));
640            }
641        }
642
643        // 2. TLS cert / key must be paired
644        match (&self.tls_cert_path, &self.tls_key_path) {
645            (Some(_), None) => {
646                return Err(McpxError::Config(
647                    "tls_cert_path is set but tls_key_path is missing".into(),
648                ));
649            }
650            (None, Some(_)) => {
651                return Err(McpxError::Config(
652                    "tls_key_path is set but tls_cert_path is missing".into(),
653                ));
654            }
655            _ => {}
656        }
657
658        // 3. bind_addr parses
659        if self.bind_addr.parse::<SocketAddr>().is_err() {
660            return Err(McpxError::Config(format!(
661                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
662                self.bind_addr
663            )));
664        }
665
666        // 4. public_url scheme
667        if let Some(ref url) = self.public_url
668            && !(url.starts_with("http://") || url.starts_with("https://"))
669        {
670            return Err(McpxError::Config(format!(
671                "public_url {url:?} must start with http:// or https://"
672            )));
673        }
674
675        // 5. allowed_origins scheme
676        for origin in &self.allowed_origins {
677            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
678                return Err(McpxError::Config(format!(
679                    "allowed_origins entry {origin:?} must start with http:// or https://"
680                )));
681            }
682        }
683
684        // 6. max_request_body > 0
685        if self.max_request_body == 0 {
686            return Err(McpxError::Config(
687                "max_request_body must be greater than zero".into(),
688            ));
689        }
690
691        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
692        #[cfg(feature = "oauth")]
693        if let Some(auth_cfg) = &self.auth
694            && let Some(oauth_cfg) = &auth_cfg.oauth
695        {
696            oauth_cfg.validate()?;
697        }
698
699        Ok(())
700    }
701}
702
/// Handle for hot-reloading server configuration without restart.
///
/// Delivered to the callback registered with
/// [`McpServerConfig::with_reload_callback`] (which is stored in the
/// deprecated [`McpServerConfig::on_reload_ready`] field).
/// All swap operations are lock-free and wait-free -- in-flight requests
/// finish with the old values while new requests see the update immediately.
///
/// Every component is optional. When one is absent, the matching
/// `reload_*` method silently does nothing, and
/// [`ReloadHandle::refresh_crls`] returns an error.
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    /// Shared auth state targeted by `reload_auth_keys`, if any.
    auth: Option<Arc<AuthState>>,
    /// Swappable RBAC policy slot targeted by `reload_rbac`, if any.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    /// mTLS CRL set targeted by `refresh_crls`, if any.
    crl_set: Option<Arc<CrlSet>>,
}
717
718impl ReloadHandle {
719    /// Atomically replace the API key list used by the auth middleware.
720    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
721        if let Some(ref auth) = self.auth {
722            auth.reload_keys(keys);
723        }
724    }
725
726    /// Atomically replace the RBAC policy used by the RBAC middleware.
727    pub fn reload_rbac(&self, policy: RbacPolicy) {
728        if let Some(ref rbac) = self.rbac {
729            rbac.store(Arc::new(policy));
730            tracing::info!("RBAC policy reloaded");
731        }
732    }
733
734    /// Force an immediate refresh of all cached mTLS CRLs.
735    ///
736    /// # Errors
737    ///
738    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
739    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
740        let Some(ref crl_set) = self.crl_set else {
741            return Err(McpxError::Config(
742                "CRL refresh requested but mTLS CRL support is not configured".into(),
743            ));
744        };
745
746        crl_set.force_refresh().await
747    }
748}
749
750/// Generic MCP HTTP server.
751///
752/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
753/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
754/// with TLS (rustls). Optionally supports mTLS client certificate auth.
755///
756/// # Errors
757///
758/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
759/// or the server fails.
760// TODO(refactor): cognitive complexity reduced from 111/25 to 83/25 by
761// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
762// Remaining flow is a linear router builder: middleware layering, feature-
763// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
764// would require threading many `&mut Router` helpers and hurt readability
765// of the layer order (which is security-relevant and must stay visible).
766#[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
767/// Internal bundle of values produced by [`build_app_router`] and
768/// consumed by [`serve`] / [`serve_with_listener`] when driving the
769/// HTTP listener.
770struct AppRunParams {
771    /// TLS cert/key paths when TLS is configured.
772    tls_paths: Option<(PathBuf, PathBuf)>,
773    /// mTLS configuration when mutual-TLS auth is enabled.
774    mtls_config: Option<MtlsConfig>,
775    /// Graceful shutdown drain window.
776    shutdown_timeout: Duration,
777    /// Shared auth state used by hot-reload callbacks.
778    auth_state: Option<Arc<AuthState>>,
779    /// Hot-reloadable RBAC state used by reload callbacks.
780    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
781    /// Optional callback that receives the final [`ReloadHandle`].
782    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
783    /// Server-internal cancellation token. Cancelled by [`run_server`]
784    /// once the shutdown trigger fires (so rmcp's child token also
785    /// fires, terminating in-flight MCP sessions).
786    ct: CancellationToken,
787    /// `"http"` or `"https"` -- used only for boot-time logging.
788    scheme: &'static str,
789    /// Server name -- used only for boot-time logging.
790    name: String,
791}
792
793/// Build the full application axum [`axum::Router`] (MCP route +
794/// middleware stack + admin + OAuth + health endpoints + security
795/// headers + CORS + compression + concurrency limit + origin check)
796/// and the [`AppRunParams`] needed to drive it.
797///
798/// This is the shared core of [`serve`] and [`serve_with_listener`].
799/// It performs *no* network I/O: callers are responsible for binding
800/// (or accepting a pre-bound) [`TcpListener`] and invoking
801/// [`run_server`].
802#[allow(
803    clippy::cognitive_complexity,
804    reason = "router assembly is intrinsically sequential; splitting harms readability"
805)]
806#[allow(
807    deprecated,
808    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
809)]
810fn build_app_router<H, F>(
811    mut config: McpServerConfig,
812    handler_factory: F,
813) -> anyhow::Result<(axum::Router, AppRunParams)>
814where
815    H: ServerHandler + 'static,
816    F: Fn() -> H + Send + Sync + Clone + 'static,
817{
818    let ct = CancellationToken::new();
819
820    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
821    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
822
823    let mcp_service = StreamableHttpService::new(
824        move || Ok(handler_factory()),
825        {
826            let mut mgr = LocalSessionManager::default();
827            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
828            mgr.into()
829        },
830        StreamableHttpServerConfig::default()
831            .with_allowed_hosts(allowed_hosts)
832            .with_sse_keep_alive(Some(config.sse_keep_alive))
833            .with_cancellation_token(ct.child_token()),
834    );
835
836    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
837    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
838
839    // Build auth state eagerly when auth is configured so we can wire both
840    // the auth middleware *and* the optional admin router against the same
841    // state. The middleware itself is installed further down in layer order.
842    let auth_state: Option<Arc<AuthState>> = match config.auth {
843        Some(ref auth_config) if auth_config.enabled => {
844            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
845            let pre_auth_limiter = auth_config
846                .rate_limit
847                .as_ref()
848                .map(crate::auth::build_pre_auth_limiter);
849
850            #[cfg(feature = "oauth")]
851            let jwks_cache = auth_config
852                .oauth
853                .as_ref()
854                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
855                .transpose()
856                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
857
858            Some(Arc::new(AuthState {
859                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
860                rate_limiter,
861                pre_auth_limiter,
862                #[cfg(feature = "oauth")]
863                jwks_cache,
864                seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
865                counters: crate::auth::AuthCounters::default(),
866            }))
867        }
868        _ => None,
869    };
870
871    // Build the RBAC policy swap early so the admin router and the later
872    // RBAC middleware layer share the same hot-reloadable state.
873    let rbac_swap = Arc::new(ArcSwap::new(
874        config
875            .rbac
876            .clone()
877            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
878    ));
879
880    // Optional /admin/* diagnostic routes. Merged BEFORE the
881    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
882    if config.admin_enabled {
883        let Some(ref auth_state_ref) = auth_state else {
884            return Err(anyhow::anyhow!(
885                "admin_enabled=true requires auth to be configured and enabled"
886            ));
887        };
888        let admin_state = crate::admin::AdminState {
889            started_at: std::time::Instant::now(),
890            name: config.name.clone(),
891            version: config.version.clone(),
892            auth: Some(Arc::clone(auth_state_ref)),
893            rbac: Arc::clone(&rbac_swap),
894        };
895        let admin_cfg = crate::admin::AdminConfig {
896            role: config.admin_role.clone(),
897        };
898        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
899        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
900    }
901
902    // ----- Middleware order (CRITICAL: read carefully) ------------------
903    //
904    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
905    // added is the OUTERMOST (runs first on a request). To achieve a
906    // request-time flow of:
907    //
908    //   body-limit -> timeout -> auth -> rbac -> handler
909    //
910    // we add layers in the REVERSE order:
911    //
912    //   1. RBAC               (innermost, runs last before handler)
913    //   2. auth               (parses identity, sets extension for RBAC)
914    //   3. timeout            (bounds total request time)
915    //   4. body-limit         (outermost on /mcp; caps payload before
916    //                          anything else reads/buffers it)
917    //
918    // Origin validation is installed on the OUTER router (after the
919    // /mcp router is merged in), so it also protects /healthz, /readyz,
920    // /version, and any OAuth proxy endpoints.
921    //
922    // Rationale:
923    // - Body-limit must be outermost on /mcp so RBAC (which reads the
924    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
925    // - Auth must run before RBAC because RBAC consumes
926    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
927    //   policy.
928    // - Origin runs before auth so we reject cross-origin requests
929    //   without spending Argon2 cycles on unauthenticated callers.
930
931    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
932    // Always installed: even when RBAC is disabled, tool rate limiting may
933    // be active (MCP spec: servers MUST rate limit tool invocations).
934    {
935        let tool_limiter: Option<Arc<ToolRateLimiter>> =
936            config.tool_rate_limit.map(build_tool_rate_limiter);
937
938        if rbac_swap.load().is_enabled() {
939            tracing::info!("RBAC enforcement enabled on /mcp");
940        }
941        if let Some(limit) = config.tool_rate_limit {
942            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
943        }
944
945        let rbac_for_mw = Arc::clone(&rbac_swap);
946        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
947            let p = rbac_for_mw.load_full();
948            let tl = tool_limiter.clone();
949            rbac_middleware(p, tl, req, next)
950        }));
951    }
952
953    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
954    if let Some(ref auth_config) = config.auth
955        && auth_config.enabled
956    {
957        let Some(ref state) = auth_state else {
958            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
959        };
960
961        let methods: Vec<&str> = [
962            auth_config.mtls.is_some().then_some("mTLS"),
963            (!auth_config.api_keys.is_empty()).then_some("bearer"),
964            #[cfg(feature = "oauth")]
965            auth_config.oauth.is_some().then_some("oauth-jwt"),
966        ]
967        .into_iter()
968        .flatten()
969        .collect();
970
971        tracing::info!(
972            methods = %methods.join(", "),
973            api_keys = auth_config.api_keys.len(),
974            "auth enabled on /mcp"
975        );
976
977        let state_for_mw = Arc::clone(state);
978        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
979            let s = Arc::clone(&state_for_mw);
980            auth_middleware(s, req, next)
981        }));
982    }
983
984    // [3] Request timeout (returns 408 on expiry). Bounds total request
985    // duration including auth + handler.
986    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
987        axum::http::StatusCode::REQUEST_TIMEOUT,
988        config.request_timeout,
989    ));
990
991    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
992    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
993    // attempts to buffer or parse the body.
994    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
995        config.max_request_body,
996    ));
997
998    // Compute the effective allowed-origins list for the outer
999    // origin-check layer (installed on the merged router below). When
1000    // `allowed_origins` is empty but `public_url` is set, auto-derive
1001    // the origin from the public URL so MCP clients (e.g. Claude Code)
1002    // that send `Origin: <server-url>` are accepted without explicit
1003    // config.
1004    let mut effective_origins = config.allowed_origins.clone();
1005    if effective_origins.is_empty()
1006        && let Some(ref url) = config.public_url
1007    {
1008        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1009        // Strip any path/query from the public URL.
1010        if let Some(scheme_end) = url.find("://") {
1011            let after_scheme = &url[scheme_end + 3..];
1012            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1013            let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
1014            tracing::info!(
1015                %origin,
1016                "auto-derived allowed origin from public_url"
1017            );
1018            effective_origins.push(origin);
1019        }
1020    }
1021    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1022    let cors_origins = Arc::clone(&allowed_origins);
1023    let log_request_headers = config.log_request_headers;
1024
1025    let readyz_route = if let Some(check) = config.readiness_check.take() {
1026        axum::routing::get(move || readyz(Arc::clone(&check)))
1027    } else {
1028        axum::routing::get(healthz)
1029    };
1030
1031    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1032    let mut router = axum::Router::new()
1033        .route("/healthz", axum::routing::get(healthz))
1034        .route("/readyz", readyz_route)
1035        .route(
1036            "/version",
1037            axum::routing::get({
1038                // Pre-serialize the version payload once at router-build
1039                // time. The handler then serves a cheap `Arc::clone` of the
1040                // immutable bytes per request, avoiding `serde_json::Value`
1041                // allocation + serialization on every `/version` hit.
1042                let payload_bytes: Arc<[u8]> =
1043                    serialize_version_payload(&config.name, &config.version);
1044                move || {
1045                    let p = Arc::clone(&payload_bytes);
1046                    async move {
1047                        (
1048                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1049                            p.to_vec(),
1050                        )
1051                    }
1052                }
1053            }),
1054        )
1055        .merge(mcp_router);
1056
1057    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1058    if let Some(extra) = config.extra_router.take() {
1059        router = router.merge(extra);
1060    }
1061
1062    // RFC 9728: Protected Resource Metadata endpoint.
1063    // When OAuth is configured, serve full metadata with authorization_servers.
1064    // Otherwise, serve a minimal document with just the resource URL and no
1065    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1066    // that the server exists but does NOT require OAuth authentication,
1067    // preventing them from gating the connection behind a broken auth flow.
1068    let server_url = if let Some(ref url) = config.public_url {
1069        url.trim_end_matches('/').to_owned()
1070    } else {
1071        let prm_scheme = if config.tls_cert_path.is_some() {
1072            "https"
1073        } else {
1074            "http"
1075        };
1076        format!("{prm_scheme}://{}", config.bind_addr)
1077    };
1078    let resource_url = format!("{server_url}/mcp");
1079
1080    #[cfg(feature = "oauth")]
1081    let prm_metadata = if let Some(ref auth_config) = config.auth
1082        && let Some(ref oauth_config) = auth_config.oauth
1083    {
1084        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1085    } else {
1086        serde_json::json!({ "resource": resource_url })
1087    };
1088    #[cfg(not(feature = "oauth"))]
1089    let prm_metadata = serde_json::json!({ "resource": resource_url });
1090
1091    router = router.route(
1092        "/.well-known/oauth-protected-resource",
1093        axum::routing::get(move || {
1094            let m = prm_metadata.clone();
1095            async move { axum::Json(m) }
1096        }),
1097    );
1098
1099    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1100    // /authorize, /token, /register, and authorization server metadata so
1101    // MCP clients can perform Authorization Code + PKCE against the upstream
1102    // IdP (e.g. Keycloak) transparently.
1103    #[cfg(feature = "oauth")]
1104    if let Some(ref auth_config) = config.auth
1105        && let Some(ref oauth_config) = auth_config.oauth
1106        && oauth_config.proxy.is_some()
1107    {
1108        router = install_oauth_proxy_routes(router, &server_url, oauth_config)?;
1109    }
1110
1111    // OWASP security response headers (applied to all responses).
1112    // HSTS is conditional on TLS being configured.
1113    let is_tls = config.tls_cert_path.is_some();
1114    router = router.layer(axum::middleware::from_fn(move |req, next| {
1115        security_headers_middleware(is_tls, req, next)
1116    }));
1117
1118    // CORS preflight layer (required for browser-based MCP clients).
1119    // Uses the same effective origins as the origin check middleware
1120    // (including auto-derived origin from public_url).
1121    if !cors_origins.is_empty() {
1122        let cors = tower_http::cors::CorsLayer::new()
1123            .allow_origin(
1124                cors_origins
1125                    .iter()
1126                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1127                    .collect::<Vec<_>>(),
1128            )
1129            .allow_methods([
1130                axum::http::Method::GET,
1131                axum::http::Method::POST,
1132                axum::http::Method::OPTIONS,
1133            ])
1134            .allow_headers([
1135                axum::http::header::CONTENT_TYPE,
1136                axum::http::header::AUTHORIZATION,
1137            ]);
1138        router = router.layer(cors);
1139    }
1140
1141    // Optional response compression (gzip + brotli). Skips small bodies
1142    // to avoid overhead. Applied after CORS so preflight responses remain
1143    // uncompressed.
1144    if config.compression_enabled {
1145        use tower_http::compression::Predicate as _;
1146        let predicate = tower_http::compression::DefaultPredicate::new().and(
1147            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1148        );
1149        router = router.layer(
1150            tower_http::compression::CompressionLayer::new()
1151                .gzip(true)
1152                .br(true)
1153                .compress_when(predicate),
1154        );
1155        tracing::info!(
1156            min_size = config.compression_min_size,
1157            "response compression enabled (gzip, br)"
1158        );
1159    }
1160
1161    // Optional global concurrency cap. `load_shed` converts the
1162    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1163    if let Some(max) = config.max_concurrent_requests {
1164        let overload_handler = tower::ServiceBuilder::new()
1165            .layer(axum::error_handling::HandleErrorLayer::new(
1166                |_err: tower::BoxError| async {
1167                    (
1168                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1169                        axum::Json(serde_json::json!({
1170                            "error": "overloaded",
1171                            "error_description": "server is at capacity, retry later"
1172                        })),
1173                    )
1174                },
1175            ))
1176            .layer(tower::load_shed::LoadShedLayer::new())
1177            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1178        router = router.layer(overload_handler);
1179        tracing::info!(max, "global concurrency limit enabled");
1180    }
1181
1182    // JSON fallback for unmatched routes. Without this, axum returns
1183    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1184    // when they probe OAuth endpoints like /authorize or /token.
1185    router = router.fallback(|| async {
1186        (
1187            axum::http::StatusCode::NOT_FOUND,
1188            axum::Json(serde_json::json!({
1189                "error": "not_found",
1190                "error_description": "The requested endpoint does not exist"
1191            })),
1192        )
1193    });
1194
1195    // Prometheus metrics: recording middleware + separate listener.
1196    #[cfg(feature = "metrics")]
1197    if config.metrics_enabled {
1198        let metrics = Arc::new(
1199            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1200        );
1201        let m = Arc::clone(&metrics);
1202        router = router.layer(axum::middleware::from_fn(
1203            move |req: Request<Body>, next: Next| {
1204                let m = Arc::clone(&m);
1205                metrics_middleware(m, req, next)
1206            },
1207        ));
1208        let metrics_bind = config.metrics_bind.clone();
1209        tokio::spawn(async move {
1210            if let Err(e) = crate::metrics::serve_metrics(metrics_bind, metrics).await {
1211                tracing::error!("metrics listener failed: {e}");
1212            }
1213        });
1214    }
1215
1216    // Origin validation layer (MCP spec: servers MUST validate the
1217    // Origin header to prevent DNS rebinding attacks). Installed as the
1218    // OUTERMOST layer on the OUTER router so it protects ALL routes
1219    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1220    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1221    // reject cross-origin attackers without spending Argon2 cycles.
1222    //
1223    // Origin-less requests (e.g. server-to-server probes, curl, native
1224    // MCP clients) are permitted; only requests with an Origin header
1225    // that does not match `effective_origins` are rejected.
1226    router = router.layer(axum::middleware::from_fn(move |req, next| {
1227        let origins = Arc::clone(&allowed_origins);
1228        origin_check_middleware(origins, log_request_headers, req, next)
1229    }));
1230
1231    let scheme = if config.tls_cert_path.is_some() {
1232        "https"
1233    } else {
1234        "http"
1235    };
1236
1237    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1238        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1239        _ => None,
1240    };
1241    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1242
1243    Ok((
1244        router,
1245        AppRunParams {
1246            tls_paths,
1247            mtls_config,
1248            shutdown_timeout: config.shutdown_timeout,
1249            auth_state,
1250            rbac_swap,
1251            on_reload_ready: config.on_reload_ready.take(),
1252            ct,
1253            scheme,
1254            name: config.name.clone(),
1255        },
1256    ))
1257}
1258
1259/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1260/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1261///
1262/// This is the standard entry point for production deployments. For
1263/// deterministic shutdown control (e.g. integration tests), see
1264/// [`serve_with_listener`].
1265///
1266/// The configuration must be validated first via
1267/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1268/// token. This typestate guarantees, at compile time, that the server
1269/// never starts with an invalid configuration.
1270///
1271/// # Errors
1272///
1273/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1274/// fails, or if the underlying axum server returns an error.
1275pub async fn serve<H, F>(
1276    config: Validated<McpServerConfig>,
1277    handler_factory: F,
1278) -> Result<(), McpxError>
1279where
1280    H: ServerHandler + 'static,
1281    F: Fn() -> H + Send + Sync + Clone + 'static,
1282{
1283    let config = config.into_inner();
1284    #[allow(
1285        deprecated,
1286        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1287    )]
1288    let bind_addr = config.bind_addr.clone();
1289    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1290
1291    let listener = TcpListener::bind(&bind_addr)
1292        .await
1293        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1294    log_listening(&params.name, params.scheme, &bind_addr);
1295
1296    run_server(
1297        router,
1298        listener,
1299        params.tls_paths,
1300        params.mtls_config,
1301        params.shutdown_timeout,
1302        params.auth_state,
1303        params.rbac_swap,
1304        params.on_reload_ready,
1305        params.ct,
1306    )
1307    .await
1308    .map_err(anyhow_to_startup)
1309}
1310
1311/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1312/// readiness signalling and external shutdown control.
1313///
1314/// This variant is intended for **deterministic integration tests** and
1315/// for embedders that need to bind the listening socket themselves
1316/// (e.g. systemd socket activation). Compared to [`serve`]:
1317///
1318/// * The caller passes a `TcpListener` that is already bound. This
1319///   eliminates the bind race in tests that previously required
1320///   poll-the-`/healthz`-loop start-up detection.
1321/// * `ready_tx`, when `Some`, receives the socket's
1322///   [`SocketAddr`] *after* the router is built and immediately before
1323///   the server starts accepting connections. Tests can `await` the
1324///   matching `oneshot::Receiver` to know exactly when it is safe to
1325///   issue requests.
1326/// * `shutdown`, when `Some`, gives the caller a
1327///   [`CancellationToken`] that triggers the same graceful-shutdown
1328///   path as a real OS signal. This avoids cross-platform issues with
1329///   sending real `SIGTERM` from tests on Windows.
1330///
1331/// All three optional parameters degrade gracefully: if `ready_tx` is
1332/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1333/// stops on an OS signal (just like [`serve`]).
1334///
1335/// # Errors
1336///
1337/// Returns [`McpxError::Startup`] if router construction fails, if reading
1338/// the listener's `local_addr()` fails, or if the underlying axum
1339/// server returns an error.
1340pub async fn serve_with_listener<H, F>(
1341    listener: TcpListener,
1342    config: Validated<McpServerConfig>,
1343    handler_factory: F,
1344    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1345    shutdown: Option<CancellationToken>,
1346) -> Result<(), McpxError>
1347where
1348    H: ServerHandler + 'static,
1349    F: Fn() -> H + Send + Sync + Clone + 'static,
1350{
1351    let config = config.into_inner();
1352    let local_addr = listener
1353        .local_addr()
1354        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1355    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1356
1357    log_listening(&params.name, params.scheme, &local_addr.to_string());
1358
1359    // Forward external shutdown into the server-internal cancellation
1360    // token so `run_server`'s shutdown trigger picks it up alongside
1361    // any real OS signal.
1362    if let Some(external) = shutdown {
1363        let internal = params.ct.clone();
1364        tokio::spawn(async move {
1365            external.cancelled().await;
1366            internal.cancel();
1367        });
1368    }
1369
1370    // Signal readiness *after* the router is fully built and external
1371    // shutdown is wired, but *before* run_server takes ownership of
1372    // the listener. The receiver can immediately issue requests.
1373    if let Some(tx) = ready_tx {
1374        // Receiver may have been dropped (test gave up). That's fine.
1375        let _ = tx.send(local_addr);
1376    }
1377
1378    run_server(
1379        router,
1380        listener,
1381        params.tls_paths,
1382        params.mtls_config,
1383        params.shutdown_timeout,
1384        params.auth_state,
1385        params.rbac_swap,
1386        params.on_reload_ready,
1387        params.ct,
1388    )
1389    .await
1390    .map_err(anyhow_to_startup)
1391}
1392
1393/// Emit the standard "listening on …" log lines used by both
1394/// [`serve`] and [`serve_with_listener`].
1395#[allow(
1396    clippy::cognitive_complexity,
1397    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1398)]
1399fn log_listening(name: &str, scheme: &str, addr: &str) {
1400    tracing::info!("{name} listening on {addr}");
1401    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
1402    tracing::info!("  Health check: {scheme}://{addr}/healthz");
1403    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
1404}
1405
1406/// Drive the chosen axum server variant (TLS or plain) with a graceful
1407/// shutdown window. Consumes the router and listener.
1408///
1409/// # Shutdown semantics
1410///
1411/// A single shutdown trigger (the FIRST of: OS signal via
1412/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
1413///
1414/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
1415///    accepting new connections and waits for in-flight requests to
1416///    drain;
1417/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
1418///    drainage exceeds `shutdown_timeout`.
1419///
1420/// Previously this function awaited `shutdown_signal()` independently
1421/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
1422/// resolves once per future and consumes one signal, the force-exit
1423/// timer was tied to a SECOND signal (a second SIGTERM the operator
1424/// would never send). Under a single SIGTERM the graceful drain could
1425/// hang indefinitely. The current implementation derives both branches
1426/// from a single shared trigger so the timeout race is anchored to the
1427/// FIRST (and only) signal.
1428#[allow(
1429    clippy::too_many_arguments,
1430    clippy::cognitive_complexity,
1431    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1432)]
1433async fn run_server(
1434    router: axum::Router,
1435    listener: TcpListener,
1436    tls_paths: Option<(PathBuf, PathBuf)>,
1437    mtls_config: Option<MtlsConfig>,
1438    shutdown_timeout: Duration,
1439    auth_state: Option<Arc<AuthState>>,
1440    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1441    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1442    ct: CancellationToken,
1443) -> anyhow::Result<()> {
1444    // `shutdown_trigger` fires when the FIRST source resolves: either
1445    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
1446    // (which the test harness uses for deterministic shutdown).
1447    let shutdown_trigger = CancellationToken::new();
1448    {
1449        let trigger = shutdown_trigger.clone();
1450        let parent = ct.clone();
1451        tokio::spawn(async move {
1452            tokio::select! {
1453                () = shutdown_signal() => {}
1454                () = parent.cancelled() => {}
1455            }
1456            trigger.cancel();
1457        });
1458    }
1459
1460    let graceful = {
1461        let trigger = shutdown_trigger.clone();
1462        let ct = ct.clone();
1463        async move {
1464            trigger.cancelled().await;
1465            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1466            ct.cancel();
1467        }
1468    };
1469
1470    let force_exit_timer = {
1471        let trigger = shutdown_trigger.clone();
1472        async move {
1473            trigger.cancelled().await;
1474            tokio::time::sleep(shutdown_timeout).await;
1475        }
1476    };
1477
1478    if let Some((cert_path, key_path)) = tls_paths {
1479        let crl_set = if let Some(mtls) = mtls_config.as_ref()
1480            && mtls.crl_enabled
1481        {
1482            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1483            let (crl_set, discover_rx) =
1484                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1485                    .await
1486                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1487            tokio::spawn(mtls_revocation::run_crl_refresher(
1488                Arc::clone(&crl_set),
1489                discover_rx,
1490                ct.clone(),
1491            ));
1492            Some(crl_set)
1493        } else {
1494            None
1495        };
1496
1497        if let Some(cb) = on_reload_ready.take() {
1498            cb(ReloadHandle {
1499                auth: auth_state.clone(),
1500                rbac: Some(Arc::clone(&rbac_swap)),
1501                crl_set: crl_set.clone(),
1502            });
1503        }
1504
1505        let tls_listener = TlsListener::new(
1506            listener,
1507            &cert_path,
1508            &key_path,
1509            mtls_config.as_ref(),
1510            crl_set,
1511        )?;
1512        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1513        tokio::select! {
1514            result = axum::serve(tls_listener, make_svc)
1515                .with_graceful_shutdown(graceful) => { result?; }
1516            () = force_exit_timer => {
1517                tracing::warn!("shutdown timeout exceeded, forcing exit");
1518            }
1519        }
1520    } else {
1521        if let Some(cb) = on_reload_ready.take() {
1522            cb(ReloadHandle {
1523                auth: auth_state,
1524                rbac: Some(rbac_swap),
1525                crl_set: None,
1526            });
1527        }
1528
1529        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1530        tokio::select! {
1531            result = axum::serve(listener, make_svc)
1532                .with_graceful_shutdown(graceful) => { result?; }
1533            () = force_exit_timer => {
1534                tracing::warn!("shutdown timeout exceeded, forcing exit");
1535            }
1536        }
1537    }
1538
1539    Ok(())
1540}
1541
/// Mount the OAuth 2.1 proxy endpoints on `router`.
///
/// When `oauth_config.proxy` is `None` the router is handed back untouched
/// and no HTTP client is created. Otherwise these routes are registered:
///
/// - `GET /.well-known/oauth-authorization-server` (authorization server
///   metadata built from `server_url` and `oauth_config`)
/// - `GET /authorize`
/// - `POST /token`
/// - `POST /register`
/// - `POST /introspect` and `POST /revoke`, each only when
///   `expose_admin_endpoints` is set and the matching upstream URL is
///   configured.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if the shared
/// [`crate::oauth::OauthHttpClient`] cannot be initialized.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
) -> Result<axum::Router, McpxError> {
    use axum::routing::{get, post};

    let Some(proxy) = oauth_config.proxy.as_ref() else {
        return Ok(router);
    };

    // One client for every proxy endpoint; each handler keeps its own
    // clone of it.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    // The admin endpoints are opt-in and additionally require an upstream
    // URL to forward to. Computed once, used for routing and for the log.
    let introspect_enabled = proxy.expose_admin_endpoints && proxy.introspection_url.is_some();
    let revoke_enabled = proxy.expose_admin_endpoints && proxy.revocation_url.is_some();

    let metadata = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let mut router = router.route(
        "/.well-known/oauth-authorization-server",
        get(move || {
            let document = metadata.clone();
            async move { axum::Json(document) }
        }),
    );

    let authorize_proxy = proxy.clone();
    router = router.route(
        "/authorize",
        get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let cfg = authorize_proxy.clone();
                async move { crate::oauth::handle_authorize(&cfg, &query.unwrap_or_default()) }
            },
        ),
    );

    let token_proxy = proxy.clone();
    let token_client = http.clone();
    router = router.route(
        "/token",
        post(move |body: String| {
            let cfg = token_proxy.clone();
            let client = token_client.clone();
            async move { crate::oauth::handle_token(&client, &cfg, &body).await }
        }),
    );

    let register_proxy = proxy.clone();
    router = router.route(
        "/register",
        post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            // The captured config is moved straight into the future.
            async move { axum::Json(crate::oauth::handle_register(&register_proxy, &body)) }
        }),
    );

    if introspect_enabled {
        let introspect_proxy = proxy.clone();
        let introspect_client = http.clone();
        router = router.route(
            "/introspect",
            post(move |body: String| {
                let cfg = introspect_proxy.clone();
                let client = introspect_client.clone();
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if revoke_enabled {
        let revoke_proxy = proxy.clone();
        // Last user of the client: take ownership instead of cloning.
        let revoke_client = http;
        router = router.route(
            "/revoke",
            post(move |body: String| {
                let cfg = revoke_proxy.clone();
                let client = revoke_client.clone();
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    tracing::info!(
        introspect = introspect_enabled,
        revoke = revoke_enabled,
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1641
1642/// Build the host allow-list for rmcp's DNS rebinding protection.
1643///
1644/// Includes loopback hosts by default, then augments with host/authority
1645/// derived from `public_url` and the server bind address.
1646fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1647    let mut hosts = vec![
1648        "localhost".to_owned(),
1649        "127.0.0.1".to_owned(),
1650        "::1".to_owned(),
1651    ];
1652
1653    if let Some(url) = public_url
1654        && let Ok(uri) = url.parse::<axum::http::Uri>()
1655        && let Some(authority) = uri.authority()
1656    {
1657        let host = authority.host().to_owned();
1658        if !hosts.iter().any(|h| h == &host) {
1659            hosts.push(host);
1660        }
1661
1662        let authority = authority.as_str().to_owned();
1663        if !hosts.iter().any(|h| h == &authority) {
1664            hosts.push(authority);
1665        }
1666    }
1667
1668    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1669        && let Some(authority) = uri.authority()
1670    {
1671        let host = authority.host().to_owned();
1672        if !hosts.iter().any(|h| h == &host) {
1673            hosts.push(host);
1674        }
1675
1676        let authority = authority.as_str().to_owned();
1677        if !hosts.iter().any(|h| h == &authority) {
1678            hosts.push(authority);
1679        }
1680    }
1681
1682    hosts
1683}
1684
1685// - TLS support -
1686
1687/// Implement axum's `Connected` trait for `TlsConnInfo` so that
1688/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
1689/// over our custom `TlsListener`.
1690///
1691/// The identity is read directly from the wrapping
1692/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
1693/// between the TLS connection and its mTLS identity. This eliminates the
1694/// previous shared-map approach which was vulnerable to ephemeral-port
1695/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
1696/// pair could alias a stale entry).
1697impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1698    for TlsConnInfo
1699{
1700    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1701        let addr = *target.remote_addr();
1702        let identity = target.io().identity().cloned();
1703        TlsConnInfo::new(addr, identity)
1704    }
1705}
1706
1707/// A TLS-wrapping listener that implements axum's `Listener` trait.
1708///
1709/// When mTLS is configured, verifies client certificates against the
1710/// configured CA and extracts the client identity at handshake time.
1711/// The extracted identity is bound to the connection itself via the
1712/// returned [`AuthenticatedTlsStream`], so it is impossible for an
1713/// unrelated connection to observe it.
1714struct TlsListener {
1715    inner: TcpListener,
1716    acceptor: tokio_rustls::TlsAcceptor,
1717    mtls_default_role: String,
1718}
1719
1720impl TlsListener {
1721    fn new(
1722        inner: TcpListener,
1723        cert_path: &Path,
1724        key_path: &Path,
1725        mtls_config: Option<&MtlsConfig>,
1726        crl_set: Option<Arc<CrlSet>>,
1727    ) -> anyhow::Result<Self> {
1728        // Install the ring crypto provider (ok to call multiple times).
1729        rustls::crypto::ring::default_provider()
1730            .install_default()
1731            .ok();
1732
1733        let certs = load_certs(cert_path)?;
1734        let key = load_key(key_path)?;
1735
1736        let mtls_default_role;
1737
1738        let tls_config = if let Some(mtls) = mtls_config {
1739            mtls_default_role = mtls.default_role.clone();
1740            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1741            {
1742                let Some(crl_set) = crl_set else {
1743                    return Err(anyhow::anyhow!(
1744                        "mTLS CRL verifier requested but CRL state was not initialized"
1745                    ));
1746                };
1747                Arc::new(DynamicClientCertVerifier::new(crl_set))
1748            } else {
1749                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1750                if mtls.required {
1751                    rustls::server::WebPkiClientVerifier::builder(root_store)
1752                        .build()
1753                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1754                } else {
1755                    rustls::server::WebPkiClientVerifier::builder(root_store)
1756                        .allow_unauthenticated()
1757                        .build()
1758                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1759                }
1760            };
1761
1762            tracing::info!(
1763                ca = %mtls.ca_cert_path.display(),
1764                required = mtls.required,
1765                crl_enabled = mtls.crl_enabled,
1766                "mTLS client auth configured"
1767            );
1768
1769            rustls::ServerConfig::builder_with_protocol_versions(&[
1770                &rustls::version::TLS12,
1771                &rustls::version::TLS13,
1772            ])
1773            .with_client_cert_verifier(verifier)
1774            .with_single_cert(certs, key)?
1775        } else {
1776            mtls_default_role = "viewer".to_owned();
1777            rustls::ServerConfig::builder_with_protocol_versions(&[
1778                &rustls::version::TLS12,
1779                &rustls::version::TLS13,
1780            ])
1781            .with_no_client_auth()
1782            .with_single_cert(certs, key)?
1783        };
1784
1785        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1786        tracing::info!(
1787            "TLS enabled (cert: {}, key: {})",
1788            cert_path.display(),
1789            key_path.display()
1790        );
1791        Ok(Self {
1792            inner,
1793            acceptor,
1794            mtls_default_role,
1795        })
1796    }
1797
1798    /// Extract the mTLS client cert identity from a completed TLS handshake.
1799    /// Returns `None` if no client certificate was presented or if the
1800    /// certificate could not be parsed into an [`AuthIdentity`].
1801    fn extract_handshake_identity(
1802        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1803        default_role: &str,
1804        addr: SocketAddr,
1805    ) -> Option<AuthIdentity> {
1806        let (_, server_conn) = tls_stream.get_ref();
1807        let cert_der = server_conn.peer_certificates()?.first()?;
1808        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1809        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1810        Some(id)
1811    }
1812}
1813
1814/// A TLS stream paired with the mTLS identity extracted at handshake time.
1815///
1816/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
1817/// identity travels with the connection itself. This replaces the previous
1818/// shared `MtlsIdentities` map, eliminating the
1819/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
1820/// reuse and removing the need for an LRU eviction policy.
1821///
1822/// The wrapper is `Unpin` (its inner stream is `Unpin` because
1823/// [`tokio::net::TcpStream`] is `Unpin`), so `AsyncRead`/`AsyncWrite`
1824/// delegation uses safe pin projection via `Pin::new(&mut self.inner)`.
1825pub(crate) struct AuthenticatedTlsStream {
1826    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1827    identity: Option<AuthIdentity>,
1828}
1829
1830impl AuthenticatedTlsStream {
1831    /// Returns the verified mTLS client identity, if any.
1832    #[must_use]
1833    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
1834        self.identity.as_ref()
1835    }
1836}
1837
1838impl std::fmt::Debug for AuthenticatedTlsStream {
1839    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1840        f.debug_struct("AuthenticatedTlsStream")
1841            .field("identity", &self.identity.as_ref().map(|id| &id.name))
1842            .finish_non_exhaustive()
1843    }
1844}
1845
1846impl tokio::io::AsyncRead for AuthenticatedTlsStream {
1847    fn poll_read(
1848        mut self: Pin<&mut Self>,
1849        cx: &mut std::task::Context<'_>,
1850        buf: &mut tokio::io::ReadBuf<'_>,
1851    ) -> std::task::Poll<std::io::Result<()>> {
1852        Pin::new(&mut self.inner).poll_read(cx, buf)
1853    }
1854}
1855
1856impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
1857    fn poll_write(
1858        mut self: Pin<&mut Self>,
1859        cx: &mut std::task::Context<'_>,
1860        buf: &[u8],
1861    ) -> std::task::Poll<std::io::Result<usize>> {
1862        Pin::new(&mut self.inner).poll_write(cx, buf)
1863    }
1864
1865    fn poll_flush(
1866        mut self: Pin<&mut Self>,
1867        cx: &mut std::task::Context<'_>,
1868    ) -> std::task::Poll<std::io::Result<()>> {
1869        Pin::new(&mut self.inner).poll_flush(cx)
1870    }
1871
1872    fn poll_shutdown(
1873        mut self: Pin<&mut Self>,
1874        cx: &mut std::task::Context<'_>,
1875    ) -> std::task::Poll<std::io::Result<()>> {
1876        Pin::new(&mut self.inner).poll_shutdown(cx)
1877    }
1878
1879    fn poll_write_vectored(
1880        mut self: Pin<&mut Self>,
1881        cx: &mut std::task::Context<'_>,
1882        bufs: &[std::io::IoSlice<'_>],
1883    ) -> std::task::Poll<std::io::Result<usize>> {
1884        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
1885    }
1886
1887    fn is_write_vectored(&self) -> bool {
1888        self.inner.is_write_vectored()
1889    }
1890}
1891
1892impl axum::serve::Listener for TlsListener {
1893    type Io = AuthenticatedTlsStream;
1894    type Addr = SocketAddr;
1895
1896    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
1897        loop {
1898            let (stream, addr) = match self.inner.accept().await {
1899                Ok(pair) => pair,
1900                Err(e) => {
1901                    tracing::debug!("TCP accept error: {e}");
1902                    continue;
1903                }
1904            };
1905            let tls_stream = match self.acceptor.accept(stream).await {
1906                Ok(s) => s,
1907                Err(e) => {
1908                    tracing::debug!("TLS handshake failed from {addr}: {e}");
1909                    continue;
1910                }
1911            };
1912            let identity =
1913                Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
1914            let wrapped = AuthenticatedTlsStream {
1915                inner: tls_stream,
1916                identity,
1917            };
1918            return (wrapped, addr);
1919        }
1920    }
1921
1922    fn local_addr(&self) -> std::io::Result<Self::Addr> {
1923        self.inner.local_addr()
1924    }
1925}
1926
1927fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
1928    use rustls::pki_types::pem::PemObject;
1929    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
1930        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
1931        .collect::<Result<_, _>>()
1932        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
1933    anyhow::ensure!(
1934        !certs.is_empty(),
1935        "no certificates found in {}",
1936        path.display()
1937    );
1938    Ok(certs)
1939}
1940
1941fn load_client_auth_roots(
1942    path: &Path,
1943) -> anyhow::Result<(
1944    Vec<rustls::pki_types::CertificateDer<'static>>,
1945    Arc<RootCertStore>,
1946)> {
1947    let ca_certs = load_certs(path)?;
1948    let mut root_store = RootCertStore::empty();
1949    for cert in &ca_certs {
1950        root_store
1951            .add(cert.clone())
1952            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
1953    }
1954
1955    Ok((ca_certs, Arc::new(root_store)))
1956}
1957
1958fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
1959    use rustls::pki_types::pem::PemObject;
1960    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
1961        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
1962}
1963
1964#[allow(clippy::unused_async)]
1965async fn healthz() -> impl IntoResponse {
1966    axum::Json(serde_json::json!({
1967        "status": "ok",
1968    }))
1969}
1970
1971/// Build the `/version` JSON payload for a given server name and version.
1972///
1973/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
1974/// read at compile time from the `MCPX_BUILD_SHA`, `MCPX_BUILD_TIME`, and
1975/// `MCPX_RUSTC_VERSION` env vars. Unset values resolve to `"unknown"`.
1976fn version_payload(name: &str, version: &str) -> serde_json::Value {
1977    serde_json::json!({
1978        "name": name,
1979        "version": version,
1980        "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or("unknown"),
1981        "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or("unknown"),
1982        "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or("unknown"),
1983        "mcpx_version": env!("CARGO_PKG_VERSION"),
1984    })
1985}
1986
1987/// Pre-serialize the `/version` payload to immutable bytes.
1988///
1989/// This is called once at router-build time so per-request handling can
1990/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
1991/// [`serde_json::Value`] on every hit.
1992///
1993/// Serialization of a flat `serde_json::Value` of static-string fields
1994/// cannot fail in practice; the fallback to `b"{}"` exists only to
1995/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
1996fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
1997    let value = version_payload(name, version);
1998    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
1999}
2000
2001async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2002    let status = check().await;
2003    let ready = status
2004        .get("ready")
2005        .and_then(serde_json::Value::as_bool)
2006        .unwrap_or(false);
2007    let code = if ready {
2008        axum::http::StatusCode::OK
2009    } else {
2010        axum::http::StatusCode::SERVICE_UNAVAILABLE
2011    };
2012    (code, axum::Json(status))
2013}
2014
2015/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2016///
2017/// On non-Unix platforms, only SIGINT is handled.
2018async fn shutdown_signal() {
2019    let ctrl_c = tokio::signal::ctrl_c();
2020
2021    #[cfg(unix)]
2022    {
2023        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2024            Ok(mut term) => {
2025                tokio::select! {
2026                    _ = ctrl_c => {}
2027                    _ = term.recv() => {}
2028                }
2029            }
2030            Err(e) => {
2031                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2032                ctrl_c.await.ok();
2033            }
2034        }
2035    }
2036
2037    #[cfg(not(unix))]
2038    {
2039        ctrl_c.await.ok();
2040    }
2041}
2042
2043// -- Origin validation (MCP 2025-11-25 spec, section 2.0.1) --
2044
/// Middleware that records HTTP request metrics.
///
/// For every request it increments `http_requests_total` (labels: method,
/// path, status) and observes the elapsed time in seconds in
/// `http_request_duration_seconds` (labels: method, path). The response is
/// passed through unchanged.
///
/// NOTE(review): the `path` label is the raw request path, so the number of
/// label combinations grows with the number of distinct paths requested.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Capture the labels before the request is moved into the inner service.
    let method = req.method().to_string();
    let path = req.uri().path().to_owned();
    let started = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let elapsed_secs = started.elapsed().as_secs_f64();

    let request_labels = [method.as_str(), path.as_str(), status.as_str()];
    metrics
        .http_requests_total
        .with_label_values(&request_labels)
        .inc();

    let duration_labels = [method.as_str(), path.as_str()];
    metrics
        .http_request_duration_seconds
        .with_label_values(&duration_labels)
        .observe(elapsed_secs);

    response
}
2074
2075/// OWASP security header hardening applied to every response.
2076///
2077/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2078/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2079/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2080/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2081/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2082async fn security_headers_middleware(
2083    is_tls: bool,
2084    req: Request<Body>,
2085    next: Next,
2086) -> axum::response::Response {
2087    use axum::http::{HeaderName, HeaderValue, header};
2088
2089    let mut resp = next.run(req).await;
2090    let headers = resp.headers_mut();
2091
2092    // Strip server identity headers to reduce information leakage.
2093    headers.remove(header::SERVER);
2094    headers.remove(HeaderName::from_static("x-powered-by"));
2095
2096    headers.insert(
2097        header::X_CONTENT_TYPE_OPTIONS,
2098        HeaderValue::from_static("nosniff"),
2099    );
2100    headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("deny"));
2101    headers.insert(
2102        header::CACHE_CONTROL,
2103        HeaderValue::from_static("no-store, max-age=0"),
2104    );
2105    headers.insert(
2106        header::REFERRER_POLICY,
2107        HeaderValue::from_static("no-referrer"),
2108    );
2109    headers.insert(
2110        HeaderName::from_static("cross-origin-opener-policy"),
2111        HeaderValue::from_static("same-origin"),
2112    );
2113    headers.insert(
2114        HeaderName::from_static("cross-origin-resource-policy"),
2115        HeaderValue::from_static("same-origin"),
2116    );
2117    headers.insert(
2118        HeaderName::from_static("cross-origin-embedder-policy"),
2119        HeaderValue::from_static("require-corp"),
2120    );
2121    headers.insert(
2122        HeaderName::from_static("permissions-policy"),
2123        HeaderValue::from_static("accelerometer=(), camera=(), geolocation=(), microphone=()"),
2124    );
2125    headers.insert(
2126        HeaderName::from_static("x-permitted-cross-domain-policies"),
2127        HeaderValue::from_static("none"),
2128    );
2129    headers.insert(
2130        HeaderName::from_static("content-security-policy"),
2131        HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
2132    );
2133    headers.insert(
2134        HeaderName::from_static("x-dns-prefetch-control"),
2135        HeaderValue::from_static("off"),
2136    );
2137
2138    if is_tls {
2139        headers.insert(
2140            header::STRICT_TRANSPORT_SECURITY,
2141            HeaderValue::from_static("max-age=63072000; includeSubDomains"),
2142        );
2143    }
2144
2145    resp
2146}
2147
2148/// Per the MCP spec: if the Origin header is present and its value is not in
2149/// the allowed list, respond with 403 Forbidden. Requests without an Origin
2150/// header are allowed through (e.g. non-browser clients like curl, SDKs).
2151async fn origin_check_middleware(
2152    allowed: Arc<[String]>,
2153    log_request_headers: bool,
2154    req: Request<Body>,
2155    next: Next,
2156) -> axum::response::Response {
2157    let method = req.method().clone();
2158    let path = req.uri().path().to_owned();
2159
2160    log_incoming_request(&method, &path, req.headers(), log_request_headers);
2161
2162    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2163        let origin_str = origin.to_str().unwrap_or("");
2164        if !allowed.iter().any(|a| a == origin_str) {
2165            tracing::warn!(
2166                origin = origin_str,
2167                %method,
2168                %path,
2169                allowed = ?&*allowed,
2170                "rejected request: Origin not allowed"
2171            );
2172            return (
2173                axum::http::StatusCode::FORBIDDEN,
2174                "Forbidden: Origin not allowed",
2175            )
2176                .into_response();
2177        }
2178    }
2179    next.run(req).await
2180}
2181
2182/// Emit a DEBUG log for an incoming request, optionally including the full
2183/// (redacted) header set.
2184fn log_incoming_request(
2185    method: &axum::http::Method,
2186    path: &str,
2187    headers: &axum::http::HeaderMap,
2188    log_request_headers: bool,
2189) {
2190    if log_request_headers {
2191        tracing::debug!(
2192            %method,
2193            %path,
2194            headers = %format_request_headers_for_log(headers),
2195            "incoming request"
2196        );
2197    } else {
2198        tracing::debug!(%method, %path, "incoming request");
2199    }
2200}
2201
2202fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2203    headers
2204        .iter()
2205        .map(|(k, v)| {
2206            let name = k.as_str();
2207            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2208                format!("{name}: [REDACTED]")
2209            } else {
2210                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2211            }
2212        })
2213        .collect::<Vec<_>>()
2214        .join(", ")
2215}
2216
2217// -- stdio transport --
2218
2219/// Serve an MCP server over stdin/stdout (stdio transport).
2220///
2221/// # Security warnings
2222///
2223/// - **No authentication**: the parent process has full, unrestricted access.
2224/// - **No RBAC**: all tools are available regardless of policy.
2225/// - **No TLS**: messages travel over OS pipes in plaintext.
2226/// - **Single client**: only the parent process can connect.
2227/// - **No Origin validation**: not applicable to stdio.
2228///
2229/// Use this only when the MCP client spawns the server as a trusted subprocess
2230/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
2231/// use `serve()` (Streamable HTTP) instead.
2232///
2233/// # Errors
2234///
2235/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
2236/// transport disconnects unexpectedly.
2237// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
2238// macro expansion in this 18-line function (info/warn/info + two matches).
2239// There is nothing meaningful to extract; the allow stays.
2240#[allow(clippy::cognitive_complexity)]
2241pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2242where
2243    H: ServerHandler + 'static,
2244{
2245    use rmcp::ServiceExt as _;
2246
2247    tracing::info!("stdio transport: serving on stdin/stdout");
2248    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2249
2250    let transport = rmcp::transport::io::stdio();
2251
2252    let service = handler
2253        .serve(transport)
2254        .await
2255        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2256
2257    if let Err(e) = service.waiting().await {
2258        tracing::warn!(error = %e, "stdio session ended with error");
2259    }
2260    tracing::info!("stdio session ended");
2261    Ok(())
2262}
2263
2264#[cfg(test)]
2265mod tests {
2266    #![allow(
2267        clippy::unwrap_used,
2268        clippy::expect_used,
2269        clippy::panic,
2270        clippy::indexing_slicing,
2271        clippy::unwrap_in_result,
2272        clippy::print_stdout,
2273        clippy::print_stderr,
2274        deprecated,
2275        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2276    )]
2277    use std::sync::Arc;
2278
2279    use axum::{
2280        body::Body,
2281        http::{Request, StatusCode, header},
2282        response::IntoResponse,
2283    };
2284    use http_body_util::BodyExt;
2285    use tower::ServiceExt as _;
2286
2287    use super::*;
2288
2289    // -- McpServerConfig --
2290
2291    #[test]
2292    fn server_config_new_defaults() {
2293        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2294        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2295        assert_eq!(cfg.name, "test-server");
2296        assert_eq!(cfg.version, "1.0.0");
2297        assert!(cfg.tls_cert_path.is_none());
2298        assert!(cfg.tls_key_path.is_none());
2299        assert!(cfg.auth.is_none());
2300        assert!(cfg.rbac.is_none());
2301        assert!(cfg.allowed_origins.is_empty());
2302        assert!(cfg.tool_rate_limit.is_none());
2303        assert!(cfg.readiness_check.is_none());
2304        assert_eq!(cfg.max_request_body, 1024 * 1024);
2305        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
2306        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
2307        assert!(!cfg.log_request_headers);
2308    }
2309
2310    #[test]
2311    fn validate_consumes_and_proves() {
2312        // Valid config -> Validated wrapper, original is consumed.
2313        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2314        let validated = cfg.validate().expect("valid config");
2315        // Deref gives read-only access to inner fields.
2316        assert_eq!(validated.name, "test-server");
2317        // into_inner recovers the raw value.
2318        let raw = validated.into_inner();
2319        assert_eq!(raw.name, "test-server");
2320
2321        // Invalid config (zero max_request_body) -> Err.
2322        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2323        bad.max_request_body = 0;
2324        assert!(bad.validate().is_err(), "zero body cap must fail validate");
2325    }
2326
2327    #[test]
2328    fn derive_allowed_hosts_includes_public_host() {
2329        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
2330        assert!(
2331            hosts.iter().any(|h| h == "mcp.example.com"),
2332            "public_url host must be allowed"
2333        );
2334    }
2335
2336    #[test]
2337    fn derive_allowed_hosts_includes_bind_authority() {
2338        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
2339        assert!(
2340            hosts.iter().any(|h| h == "127.0.0.1"),
2341            "bind host must be allowed"
2342        );
2343        assert!(
2344            hosts.iter().any(|h| h == "127.0.0.1:8080"),
2345            "bind authority must be allowed"
2346        );
2347    }
2348
2349    // -- healthz --
2350
2351    #[tokio::test]
2352    async fn healthz_returns_ok_json() {
2353        let resp = healthz().await.into_response();
2354        assert_eq!(resp.status(), StatusCode::OK);
2355        let body = resp.into_body().collect().await.unwrap().to_bytes();
2356        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2357        assert_eq!(json["status"], "ok");
2358        assert!(
2359            json.get("name").is_none(),
2360            "healthz must not expose server name"
2361        );
2362        assert!(
2363            json.get("version").is_none(),
2364            "healthz must not expose version"
2365        );
2366    }
2367
2368    // -- readyz --
2369
2370    #[tokio::test]
2371    async fn readyz_returns_ok_when_ready() {
2372        let check: ReadinessCheck =
2373            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
2374        let resp = readyz(check).await.into_response();
2375        assert_eq!(resp.status(), StatusCode::OK);
2376        let body = resp.into_body().collect().await.unwrap().to_bytes();
2377        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2378        assert_eq!(json["ready"], true);
2379        assert!(
2380            json.get("name").is_none(),
2381            "readyz must not expose server name"
2382        );
2383        assert!(
2384            json.get("version").is_none(),
2385            "readyz must not expose version"
2386        );
2387        assert_eq!(json["db"], "connected");
2388    }
2389
2390    #[tokio::test]
2391    async fn readyz_returns_503_when_not_ready() {
2392        let check: ReadinessCheck =
2393            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
2394        let resp = readyz(check).await.into_response();
2395        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2396    }
2397
2398    #[tokio::test]
2399    async fn readyz_returns_503_when_ready_missing() {
2400        let check: ReadinessCheck =
2401            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
2402        let resp = readyz(check).await.into_response();
2403        // Missing "ready" field defaults to false -> 503
2404        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2405    }
2406
2407    // -- origin_check_middleware --
2408
2409    /// Build a test router with origin check middleware and a simple handler.
2410    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
2411        let allowed: Arc<[String]> = Arc::from(origins);
2412        axum::Router::new()
2413            .route("/test", axum::routing::get(|| async { "ok" }))
2414            .layer(axum::middleware::from_fn(move |req, next| {
2415                let a = Arc::clone(&allowed);
2416                origin_check_middleware(a, log_request_headers, req, next)
2417            }))
2418    }
2419
2420    #[tokio::test]
2421    async fn origin_allowed_passes() {
2422        let app = origin_router(vec!["http://localhost:3000".into()], false);
2423        let req = Request::builder()
2424            .uri("/test")
2425            .header(header::ORIGIN, "http://localhost:3000")
2426            .body(Body::empty())
2427            .unwrap();
2428        let resp = app.oneshot(req).await.unwrap();
2429        assert_eq!(resp.status(), StatusCode::OK);
2430    }
2431
2432    #[tokio::test]
2433    async fn origin_rejected_returns_403() {
2434        let app = origin_router(vec!["http://localhost:3000".into()], false);
2435        let req = Request::builder()
2436            .uri("/test")
2437            .header(header::ORIGIN, "http://evil.com")
2438            .body(Body::empty())
2439            .unwrap();
2440        let resp = app.oneshot(req).await.unwrap();
2441        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2442    }
2443
2444    #[tokio::test]
2445    async fn no_origin_header_passes() {
2446        let app = origin_router(vec!["http://localhost:3000".into()], false);
2447        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2448        let resp = app.oneshot(req).await.unwrap();
2449        assert_eq!(resp.status(), StatusCode::OK);
2450    }
2451
2452    #[tokio::test]
2453    async fn empty_allowlist_rejects_any_origin() {
2454        let app = origin_router(vec![], false);
2455        let req = Request::builder()
2456            .uri("/test")
2457            .header(header::ORIGIN, "http://anything.com")
2458            .body(Body::empty())
2459            .unwrap();
2460        let resp = app.oneshot(req).await.unwrap();
2461        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2462    }
2463
2464    #[tokio::test]
2465    async fn empty_allowlist_passes_without_origin() {
2466        let app = origin_router(vec![], false);
2467        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2468        let resp = app.oneshot(req).await.unwrap();
2469        assert_eq!(resp.status(), StatusCode::OK);
2470    }
2471
2472    #[test]
2473    fn format_request_headers_redacts_sensitive_values() {
2474        let mut headers = axum::http::HeaderMap::new();
2475        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
2476        headers.insert("cookie", "sid=abc".parse().unwrap());
2477        headers.insert("x-request-id", "req-123".parse().unwrap());
2478
2479        let out = format_request_headers_for_log(&headers);
2480        assert!(out.contains("authorization: [REDACTED]"));
2481        assert!(out.contains("cookie: [REDACTED]"));
2482        assert!(out.contains("x-request-id: req-123"));
2483        assert!(!out.contains("secret-token"));
2484    }
2485
2486    // -- security_headers_middleware --
2487
2488    fn security_router(is_tls: bool) -> axum::Router {
2489        axum::Router::new()
2490            .route("/test", axum::routing::get(|| async { "ok" }))
2491            .layer(axum::middleware::from_fn(move |req, next| {
2492                security_headers_middleware(is_tls, req, next)
2493            }))
2494    }
2495
2496    #[tokio::test]
2497    async fn security_headers_set_on_response() {
2498        let app = security_router(false);
2499        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2500        let resp = app.oneshot(req).await.unwrap();
2501        assert_eq!(resp.status(), StatusCode::OK);
2502
2503        let h = resp.headers();
2504        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
2505        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
2506        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
2507        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
2508        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
2509        assert_eq!(
2510            h.get("cross-origin-resource-policy").unwrap(),
2511            "same-origin"
2512        );
2513        assert_eq!(
2514            h.get("cross-origin-embedder-policy").unwrap(),
2515            "require-corp"
2516        );
2517        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
2518        assert!(
2519            h.get("permissions-policy")
2520                .unwrap()
2521                .to_str()
2522                .unwrap()
2523                .contains("camera=()"),
2524            "permissions-policy must restrict browser features"
2525        );
2526        assert_eq!(
2527            h.get("content-security-policy").unwrap(),
2528            "default-src 'none'; frame-ancestors 'none'"
2529        );
2530        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
2531        // No HSTS when TLS is off.
2532        assert!(h.get("strict-transport-security").is_none());
2533    }
2534
2535    #[tokio::test]
2536    async fn hsts_set_when_tls_enabled() {
2537        let app = security_router(true);
2538        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2539        let resp = app.oneshot(req).await.unwrap();
2540
2541        let hsts = resp.headers().get("strict-transport-security").unwrap();
2542        assert!(
2543            hsts.to_str().unwrap().contains("max-age=63072000"),
2544            "HSTS must set 2-year max-age"
2545        );
2546    }
2547
2548    // -- version endpoint --
2549
2550    #[test]
2551    fn version_payload_contains_expected_fields() {
2552        let v = version_payload("my-server", "1.2.3");
2553        assert_eq!(v["name"], "my-server");
2554        assert_eq!(v["version"], "1.2.3");
2555        assert!(v["build_git_sha"].is_string());
2556        assert!(v["build_timestamp"].is_string());
2557        assert!(v["rust_version"].is_string());
2558        assert!(v["mcpx_version"].is_string());
2559    }
2560
2561    // -- concurrency limit layer --
2562
2563    #[tokio::test]
2564    async fn concurrency_limit_layer_composes_and_serves() {
2565        // We only assert the layer stack compiles and a single request
2566        // below the cap still succeeds. True back-pressure behaviour
2567        // requires a live HTTP server and is covered by integration tests.
2568        let app = axum::Router::new()
2569            .route("/ok", axum::routing::get(|| async { "ok" }))
2570            .layer(
2571                tower::ServiceBuilder::new()
2572                    .layer(axum::error_handling::HandleErrorLayer::new(
2573                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
2574                    ))
2575                    .layer(tower::load_shed::LoadShedLayer::new())
2576                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
2577            );
2578        let resp = app
2579            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
2580            .await
2581            .unwrap();
2582        assert_eq!(resp.status(), StatusCode::OK);
2583    }
2584
2585    // -- compression layer --
2586
2587    #[tokio::test]
2588    async fn compression_layer_gzip_encodes_response() {
2589        use tower_http::compression::Predicate as _;
2590
2591        let big_body = "a".repeat(4096);
2592        let app = axum::Router::new()
2593            .route(
2594                "/big",
2595                axum::routing::get(move || {
2596                    let body = big_body.clone();
2597                    async move { body }
2598                }),
2599            )
2600            .layer(
2601                tower_http::compression::CompressionLayer::new()
2602                    .gzip(true)
2603                    .br(true)
2604                    .compress_when(
2605                        tower_http::compression::DefaultPredicate::new()
2606                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
2607                    ),
2608            );
2609
2610        let req = Request::builder()
2611            .uri("/big")
2612            .header(header::ACCEPT_ENCODING, "gzip")
2613            .body(Body::empty())
2614            .unwrap();
2615        let resp = app.oneshot(req).await.unwrap();
2616        assert_eq!(resp.status(), StatusCode::OK);
2617        assert_eq!(
2618            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
2619            "gzip"
2620        );
2621    }
2622}