Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13    ServerHandler,
14    transport::streamable_http_server::{
15        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16    },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23    auth::{
24        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25        build_rate_limiter, extract_mtls_identity,
26    },
27    error::McpxError,
28    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
33/// at the public API boundary, flattening the chain via the alternate
34/// formatter so callers see the full causal path.
35#[allow(
36    clippy::needless_pass_by_value,
37    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40    McpxError::Startup(format!("{e:#}"))
41}
42
43/// Map a `std::io::Error` produced during server startup into a public
44/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
45/// `From` impl here because startup-phase IO errors (bind, listener) are
46/// semantically distinct from request-time IO errors and should surface
47/// the originating operation in the message.
48#[allow(
49    clippy::needless_pass_by_value,
50    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53    McpxError::Startup(format!("{op}: {e}"))
54}
55
56/// Async readiness check callback for the `/readyz` endpoint.
57///
58/// Returns a JSON object with at least a `"ready"` boolean.
59/// When `ready` is false, the endpoint returns HTTP 503.
60pub type ReadinessCheck =
61    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
63/// Configuration for the MCP server.
64#[allow(
65    missing_debug_implementations,
66    reason = "contains callback/trait objects that don't impl Debug"
67)]
68#[allow(
69    clippy::struct_excessive_bools,
70    reason = "server configuration naturally has many boolean feature flags"
71)]
72#[non_exhaustive]
73pub struct McpServerConfig {
74    /// Socket address the MCP HTTP server binds to.
75    #[deprecated(
76        since = "0.13.0",
77        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in 1.0"
78    )]
79    pub bind_addr: String,
80    /// Server name advertised via MCP `initialize`.
81    #[deprecated(
82        since = "0.13.0",
83        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
84    )]
85    pub name: String,
86    /// Server version advertised via MCP `initialize`.
87    #[deprecated(
88        since = "0.13.0",
89        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
90    )]
91    pub version: String,
92    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
93    #[deprecated(
94        since = "0.13.0",
95        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
96    )]
97    pub tls_cert_path: Option<PathBuf>,
98    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
99    #[deprecated(
100        since = "0.13.0",
101        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
102    )]
103    pub tls_key_path: Option<PathBuf>,
104    /// Optional authentication config. When `Some` and `enabled`, auth
105    /// is enforced on `/mcp`. `/healthz` is always open.
106    #[deprecated(
107        since = "0.13.0",
108        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in 1.0"
109    )]
110    pub auth: Option<AuthConfig>,
111    /// Optional RBAC policy. When present and enabled, tool calls are
112    /// checked against the policy after authentication.
113    #[deprecated(
114        since = "0.13.0",
115        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in 1.0"
116    )]
117    pub rbac: Option<Arc<RbacPolicy>>,
118    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
119    /// When empty and `public_url` is set, the origin is auto-derived from
120    /// the public URL. When both are empty, only requests with no Origin
121    /// header are accepted.
122    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
123    #[deprecated(
124        since = "0.13.0",
125        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in 1.0"
126    )]
127    pub allowed_origins: Vec<String>,
128    /// Maximum tool invocations per source IP per minute.
129    /// When set, enforced on every `tools/call` request.
130    #[deprecated(
131        since = "0.13.0",
132        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in 1.0"
133    )]
134    pub tool_rate_limit: Option<u32>,
135    /// Optional readiness probe for `/readyz`.
136    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
137    #[deprecated(
138        since = "0.13.0",
139        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in 1.0"
140    )]
141    pub readiness_check: Option<ReadinessCheck>,
142    /// Maximum request body size in bytes. Default: 1 MiB.
143    /// Protects against oversized payloads causing OOM.
144    #[deprecated(
145        since = "0.13.0",
146        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in 1.0"
147    )]
148    pub max_request_body: usize,
149    /// Request processing timeout. Default: 120s.
150    /// Requests exceeding this duration receive 408 Request Timeout.
151    #[deprecated(
152        since = "0.13.0",
153        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in 1.0"
154    )]
155    pub request_timeout: Duration,
156    /// Graceful shutdown timeout. Default: 30s.
157    /// After the shutdown signal, in-flight requests have this long to finish.
158    #[deprecated(
159        since = "0.13.0",
160        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in 1.0"
161    )]
162    pub shutdown_timeout: Duration,
163    /// Idle timeout for MCP sessions. Sessions with no activity for this
164    /// duration are closed automatically. Default: 20 minutes.
165    #[deprecated(
166        since = "0.13.0",
167        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in 1.0"
168    )]
169    pub session_idle_timeout: Duration,
170    /// Interval for SSE keep-alive pings. Prevents proxies and load
171    /// balancers from killing idle connections. Default: 15 seconds.
172    #[deprecated(
173        since = "0.13.0",
174        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in 1.0"
175    )]
176    pub sse_keep_alive: Duration,
177    /// Callback invoked once the server is built, delivering a
178    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
179    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
180    #[deprecated(
181        since = "0.13.0",
182        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in 1.0"
183    )]
184    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
185    /// Additional application-specific routes merged into the top-level
186    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
187    /// so the application is responsible for its own auth on them.
188    #[deprecated(
189        since = "0.13.0",
190        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in 1.0"
191    )]
192    pub extra_router: Option<axum::Router>,
193    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
194    /// When set, OAuth metadata endpoints advertise this URL instead of
195    /// the listen address. Required when binding `0.0.0.0` behind a
196    /// reverse proxy or inside a container.
197    #[deprecated(
198        since = "0.13.0",
199        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in 1.0"
200    )]
201    pub public_url: Option<String>,
202    /// Log inbound HTTP request headers at DEBUG level.
203    /// Sensitive values remain redacted.
204    #[deprecated(
205        since = "0.13.0",
206        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in 1.0"
207    )]
208    pub log_request_headers: bool,
209    /// Enable gzip/br response compression on MCP responses.
210    /// Defaults to `false` to preserve existing behaviour.
211    #[deprecated(
212        since = "0.13.0",
213        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
214    )]
215    pub compression_enabled: bool,
216    /// Minimum response body size (in bytes) before compression kicks in.
217    /// Only used when `compression_enabled` is true. Default: 1024.
218    #[deprecated(
219        since = "0.13.0",
220        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
221    )]
222    pub compression_min_size: u16,
223    /// Global cap on in-flight HTTP requests across the whole server.
224    /// When `Some`, requests over the cap receive 503 Service Unavailable
225    /// via `tower::load_shed`. Default: `None` (unlimited).
226    #[deprecated(
227        since = "0.13.0",
228        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in 1.0"
229    )]
230    pub max_concurrent_requests: Option<usize>,
231    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
232    /// configured and `enabled`. Default: `false`.
233    #[deprecated(
234        since = "0.13.0",
235        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
236    )]
237    pub admin_enabled: bool,
238    /// RBAC role required to access admin endpoints. Default: `"admin"`.
239    #[deprecated(
240        since = "0.13.0",
241        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
242    )]
243    pub admin_role: String,
244    /// Enable Prometheus metrics endpoint on a separate listener.
245    /// Requires the `metrics` crate feature.
246    #[cfg(feature = "metrics")]
247    #[deprecated(
248        since = "0.13.0",
249        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
250    )]
251    pub metrics_enabled: bool,
252    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
253    #[cfg(feature = "metrics")]
254    #[deprecated(
255        since = "0.13.0",
256        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
257    )]
258    pub metrics_bind: String,
259}
260
/// Marker that wraps a value proven to satisfy its validation
/// contract.
///
/// The only way to obtain `Validated<McpServerConfig>` is by calling
/// [`McpServerConfig::validate`], which is the contract enforced at
/// the type level by [`serve`] and [`serve_with_listener`]. The
/// inner field is private, so downstream code cannot bypass
/// validation by hand-constructing the wrapper.
///
/// `Validated<T>` derefs to `&T` for read-only access. To mutate,
/// recover the raw value with [`Validated::into_inner`] and
/// re-validate.
///
/// `Debug` is implemented for every `T` (including types without a
/// `Debug` impl) and never prints the wrapped value, so configuration
/// contents cannot leak into logs through `{:?}`.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Forgetting `.validate()?` is a compile error:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
// No `#[allow(missing_debug_implementations)]` here: the manual `Debug`
// impl below satisfies that lint, and a stale allow would silently hide
// the lint if that impl were ever removed.
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Always renders `Validated { .. }`; the inner value is deliberately
    /// not formatted (and `T` need not implement `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Validated").finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrow the inner value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        &self.0
    }

    /// Recover the raw value, discarding the validation proof.
    ///
    /// Re-validate before re-using the value with [`serve`] or
    /// [`serve_with_listener`].
    #[must_use]
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T> std::ops::Deref for Validated<T> {
    type Target = T;

    /// Read-only access to the validated value; there is deliberately no
    /// `DerefMut`, since mutation could invalidate the proof.
    fn deref(&self) -> &T {
        &self.0
    }
}
354
355#[allow(
356    deprecated,
357    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
358)]
359impl McpServerConfig {
360    /// Create a new server configuration with the given bind address,
361    /// server name, and version. All other fields use safe defaults.
362    ///
363    /// Use the chainable `with_*` / `enable_*` builder methods to
364    /// customize. Call [`McpServerConfig::validate`] to obtain a
365    /// [`Validated<McpServerConfig>`] proof token, which is required by
366    /// [`serve`] and [`serve_with_listener`].
367    #[must_use]
368    pub fn new(
369        bind_addr: impl Into<String>,
370        name: impl Into<String>,
371        version: impl Into<String>,
372    ) -> Self {
373        Self {
374            bind_addr: bind_addr.into(),
375            name: name.into(),
376            version: version.into(),
377            tls_cert_path: None,
378            tls_key_path: None,
379            auth: None,
380            rbac: None,
381            allowed_origins: Vec::new(),
382            tool_rate_limit: None,
383            readiness_check: None,
384            max_request_body: 1024 * 1024,
385            request_timeout: Duration::from_mins(2),
386            shutdown_timeout: Duration::from_secs(30),
387            session_idle_timeout: Duration::from_mins(20),
388            sse_keep_alive: Duration::from_secs(15),
389            on_reload_ready: None,
390            extra_router: None,
391            public_url: None,
392            log_request_headers: false,
393            compression_enabled: false,
394            compression_min_size: 1024,
395            max_concurrent_requests: None,
396            admin_enabled: false,
397            admin_role: "admin".to_owned(),
398            #[cfg(feature = "metrics")]
399            metrics_enabled: false,
400            #[cfg(feature = "metrics")]
401            metrics_bind: "127.0.0.1:9090".into(),
402        }
403    }
404
405    // ---------------------------------------------------------------
406    // Builder methods (fluent, consume + return self).
407    //
408    // Each method is `#[must_use]` because dropping the returned
409    // `McpServerConfig` discards the configuration change.
410    // ---------------------------------------------------------------
411
412    /// Attach an authentication configuration. Required for
413    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
414    #[must_use]
415    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
416        self.auth = Some(auth);
417        self
418    }
419
420    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
421    /// final port is only known after pre-binding an ephemeral listener
422    /// (tests, dynamic-port deployments).
423    #[must_use]
424    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
425        self.bind_addr = addr.into();
426        self
427    }
428
429    /// Attach an RBAC policy. Tool calls are checked against the policy
430    /// after authentication.
431    #[must_use]
432    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
433        self.rbac = Some(rbac);
434        self
435    }
436
437    /// Configure TLS by providing the certificate and private key paths
438    /// (PEM). Both must be readable at startup. Without this call, the
439    /// server runs plain HTTP.
440    #[must_use]
441    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
442        self.tls_cert_path = Some(cert_path.into());
443        self.tls_key_path = Some(key_path.into());
444        self
445    }
446
447    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
448    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
449    /// a container so OAuth metadata and auto-derived origins resolve correctly.
450    #[must_use]
451    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
452        self.public_url = Some(url.into());
453        self
454    }
455
456    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
457    /// When empty and [`with_public_url`](Self::with_public_url) is set,
458    /// the origin is auto-derived.
459    #[must_use]
460    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
461    where
462        I: IntoIterator<Item = S>,
463        S: Into<String>,
464    {
465        self.allowed_origins = origins.into_iter().map(Into::into).collect();
466        self
467    }
468
469    /// Merge an additional axum router at the top level. Routes added
470    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
471    /// for its own protection.
472    #[must_use]
473    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
474        self.extra_router = Some(router);
475        self
476    }
477
478    /// Install an async readiness probe for `/readyz`. Without this call,
479    /// `/readyz` mirrors `/healthz` (always 200 OK).
480    #[must_use]
481    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
482        self.readiness_check = Some(check);
483        self
484    }
485
486    /// Override the maximum request body (bytes). Must be `> 0`.
487    /// Default: 1 MiB.
488    #[must_use]
489    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
490        self.max_request_body = bytes;
491        self
492    }
493
494    /// Override the per-request processing timeout. Default: 2 minutes.
495    #[must_use]
496    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
497        self.request_timeout = timeout;
498        self
499    }
500
501    /// Override the graceful shutdown grace period. Default: 30 seconds.
502    #[must_use]
503    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
504        self.shutdown_timeout = timeout;
505        self
506    }
507
508    /// Override the MCP session idle timeout. Default: 20 minutes.
509    #[must_use]
510    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
511        self.session_idle_timeout = timeout;
512        self
513    }
514
515    /// Override the SSE keep-alive interval. Default: 15 seconds.
516    #[must_use]
517    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
518        self.sse_keep_alive = interval;
519        self
520    }
521
522    /// Cap the global number of in-flight HTTP requests via
523    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
524    /// Default: unlimited.
525    #[must_use]
526    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
527        self.max_concurrent_requests = Some(limit);
528        self
529    }
530
531    /// Cap tool invocations per source IP per minute. Enforced on every
532    /// `tools/call` request.
533    #[must_use]
534    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
535        self.tool_rate_limit = Some(per_minute);
536        self
537    }
538
539    /// Register a callback that receives the [`ReloadHandle`] after the
540    /// server is built. Use it to wire SIGHUP-style hot reloads of API
541    /// keys and RBAC policy.
542    #[must_use]
543    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
544    where
545        F: FnOnce(ReloadHandle) + Send + 'static,
546    {
547        self.on_reload_ready = Some(Box::new(callback));
548        self
549    }
550
551    /// Enable gzip/brotli response compression on MCP responses.
552    /// `min_size` is the smallest body size (bytes) eligible for
553    /// compression. Default min size: 1024.
554    #[must_use]
555    pub fn enable_compression(mut self, min_size: u16) -> Self {
556        self.compression_enabled = true;
557        self.compression_min_size = min_size;
558        self
559    }
560
561    /// Enable `/admin/*` diagnostic endpoints. Requires
562    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
563    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
564    /// role gate (default: `"admin"`).
565    #[must_use]
566    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
567        self.admin_enabled = true;
568        self.admin_role = role.into();
569        self
570    }
571
572    /// Log inbound HTTP request headers at DEBUG level. Sensitive
573    /// values remain redacted by the logging layer.
574    #[must_use]
575    pub fn enable_request_header_logging(mut self) -> Self {
576        self.log_request_headers = true;
577        self
578    }
579
580    /// Enable the Prometheus metrics listener on `bind` (e.g.
581    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
582    #[cfg(feature = "metrics")]
583    #[must_use]
584    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
585        self.metrics_enabled = true;
586        self.metrics_bind = bind.into();
587        self
588    }
589
590    /// Validate the configuration and consume `self`, returning a
591    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
592    /// and [`serve_with_listener`]. This is the only way to construct
593    /// `Validated<McpServerConfig>`, so the type system guarantees
594    /// validation has run before the server starts.
595    ///
596    /// Checks:
597    ///
598    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
599    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
600    ///    be unset.
601    /// 3. `bind_addr` must parse as a [`SocketAddr`].
602    /// 4. `public_url`, when set, must start with `http://` or `https://`.
603    /// 5. Each entry in `allowed_origins` must start with `http://` or
604    ///    `https://`.
605    /// 6. `max_request_body` must be greater than zero.
606    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
607    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
608    ///    `proxy.token_url`, `proxy.introspection_url`,
609    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
610    ///    and use the `https` scheme. Set
611    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
612    ///    targets (strongly discouraged in production - see the field-level
613    ///    docs for the threat model).
614    ///
615    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
616    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
617    ///
618    /// # Errors
619    ///
620    /// Returns [`McpxError::Config`] with a human-readable message on
621    /// the first validation failure.
622    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
623        self.check()?;
624        Ok(Validated(self))
625    }
626
627    /// Run the validation checks without consuming `self`. Used by
628    /// internal call sites (e.g. tests) that need to inspect a config
629    /// without taking ownership.
630    fn check(&self) -> Result<(), McpxError> {
631        // 1. admin <-> auth dependency. Mirrors the runtime check in
632        //    `build_app_router`: admin endpoints require an auth state,
633        //    which is built only when `auth` is `Some` *and* `enabled`.
634        if self.admin_enabled {
635            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
636            if !auth_enabled {
637                return Err(McpxError::Config(
638                    "admin_enabled=true requires auth to be configured and enabled".into(),
639                ));
640            }
641        }
642
643        // 2. TLS cert / key must be paired
644        match (&self.tls_cert_path, &self.tls_key_path) {
645            (Some(_), None) => {
646                return Err(McpxError::Config(
647                    "tls_cert_path is set but tls_key_path is missing".into(),
648                ));
649            }
650            (None, Some(_)) => {
651                return Err(McpxError::Config(
652                    "tls_key_path is set but tls_cert_path is missing".into(),
653                ));
654            }
655            _ => {}
656        }
657
658        // 3. bind_addr parses
659        if self.bind_addr.parse::<SocketAddr>().is_err() {
660            return Err(McpxError::Config(format!(
661                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
662                self.bind_addr
663            )));
664        }
665
666        // 4. public_url scheme
667        if let Some(ref url) = self.public_url
668            && !(url.starts_with("http://") || url.starts_with("https://"))
669        {
670            return Err(McpxError::Config(format!(
671                "public_url {url:?} must start with http:// or https://"
672            )));
673        }
674
675        // 5. allowed_origins scheme
676        for origin in &self.allowed_origins {
677            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
678                return Err(McpxError::Config(format!(
679                    "allowed_origins entry {origin:?} must start with http:// or https://"
680                )));
681            }
682        }
683
684        // 6. max_request_body > 0
685        if self.max_request_body == 0 {
686            return Err(McpxError::Config(
687                "max_request_body must be greater than zero".into(),
688            ));
689        }
690
691        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
692        #[cfg(feature = "oauth")]
693        if let Some(auth_cfg) = &self.auth
694            && let Some(oauth_cfg) = &auth_cfg.oauth
695        {
696            oauth_cfg.validate()?;
697        }
698
699        Ok(())
700    }
701}
702
703/// Handle for hot-reloading server configuration without restart.
704///
705/// Obtained via [`McpServerConfig::on_reload_ready`].
706/// All swap operations are lock-free and wait-free -- in-flight requests
707/// finish with the old values while new requests see the update immediately.
708#[allow(
709    missing_debug_implementations,
710    reason = "contains Arc<AuthState> with non-Debug fields"
711)]
712pub struct ReloadHandle {
713    auth: Option<Arc<AuthState>>,
714    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
715    crl_set: Option<Arc<CrlSet>>,
716}
717
718impl ReloadHandle {
719    /// Atomically replace the API key list used by the auth middleware.
720    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
721        if let Some(ref auth) = self.auth {
722            auth.reload_keys(keys);
723        }
724    }
725
726    /// Atomically replace the RBAC policy used by the RBAC middleware.
727    pub fn reload_rbac(&self, policy: RbacPolicy) {
728        if let Some(ref rbac) = self.rbac {
729            rbac.store(Arc::new(policy));
730            tracing::info!("RBAC policy reloaded");
731        }
732    }
733
734    /// Force an immediate refresh of all cached mTLS CRLs.
735    ///
736    /// # Errors
737    ///
738    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
739    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
740        let Some(ref crl_set) = self.crl_set else {
741            return Err(McpxError::Config(
742                "CRL refresh requested but mTLS CRL support is not configured".into(),
743            ));
744        };
745
746        crl_set.force_refresh().await
747    }
748}
749
750/// Generic MCP HTTP server.
751///
752/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
753/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
754/// with TLS (rustls). Optionally supports mTLS client certificate auth.
755///
756/// # Errors
757///
758/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
759/// or the server fails.
760// TODO(refactor): cognitive complexity reduced from 111/25 to 83/25 by
761// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
762// Remaining flow is a linear router builder: middleware layering, feature-
763// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
764// would require threading many `&mut Router` helpers and hurt readability
765// of the layer order (which is security-relevant and must stay visible).
766#[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
767/// Internal bundle of values produced by [`build_app_router`] and
768/// consumed by [`serve`] / [`serve_with_listener`] when driving the
769/// HTTP listener.
770struct AppRunParams {
771    /// TLS cert/key paths when TLS is configured.
772    tls_paths: Option<(PathBuf, PathBuf)>,
773    /// mTLS configuration when mutual-TLS auth is enabled.
774    mtls_config: Option<MtlsConfig>,
775    /// Graceful shutdown drain window.
776    shutdown_timeout: Duration,
777    /// Shared auth state used by hot-reload callbacks.
778    auth_state: Option<Arc<AuthState>>,
779    /// Hot-reloadable RBAC state used by reload callbacks.
780    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
781    /// Optional callback that receives the final [`ReloadHandle`].
782    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
783    /// Server-internal cancellation token. Cancelled by [`run_server`]
784    /// once the shutdown trigger fires (so rmcp's child token also
785    /// fires, terminating in-flight MCP sessions).
786    ct: CancellationToken,
787    /// `"http"` or `"https"` -- used only for boot-time logging.
788    scheme: &'static str,
789    /// Server name -- used only for boot-time logging.
790    name: String,
791}
792
793/// Build the full application axum [`axum::Router`] (MCP route +
794/// middleware stack + admin + OAuth + health endpoints + security
795/// headers + CORS + compression + concurrency limit + origin check)
796/// and the [`AppRunParams`] needed to drive it.
797///
798/// This is the shared core of [`serve`] and [`serve_with_listener`].
799/// It performs *no* network I/O: callers are responsible for binding
800/// (or accepting a pre-bound) [`TcpListener`] and invoking
801/// [`run_server`].
802#[allow(
803    clippy::cognitive_complexity,
804    reason = "router assembly is intrinsically sequential; splitting harms readability"
805)]
806#[allow(
807    deprecated,
808    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
809)]
810fn build_app_router<H, F>(
811    mut config: McpServerConfig,
812    handler_factory: F,
813) -> anyhow::Result<(axum::Router, AppRunParams)>
814where
815    H: ServerHandler + 'static,
816    F: Fn() -> H + Send + Sync + Clone + 'static,
817{
818    let ct = CancellationToken::new();
819
820    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
821    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
822
823    let mcp_service = StreamableHttpService::new(
824        move || Ok(handler_factory()),
825        {
826            let mut mgr = LocalSessionManager::default();
827            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
828            mgr.into()
829        },
830        StreamableHttpServerConfig::default()
831            .with_allowed_hosts(allowed_hosts)
832            .with_sse_keep_alive(Some(config.sse_keep_alive))
833            .with_cancellation_token(ct.child_token()),
834    );
835
836    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
837    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
838
839    // Build auth state eagerly when auth is configured so we can wire both
840    // the auth middleware *and* the optional admin router against the same
841    // state. The middleware itself is installed further down in layer order.
842    let auth_state: Option<Arc<AuthState>> = match config.auth {
843        Some(ref auth_config) if auth_config.enabled => {
844            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
845            let pre_auth_limiter = auth_config
846                .rate_limit
847                .as_ref()
848                .map(crate::auth::build_pre_auth_limiter);
849
850            #[cfg(feature = "oauth")]
851            let jwks_cache = auth_config
852                .oauth
853                .as_ref()
854                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
855                .transpose()
856                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
857
858            Some(Arc::new(AuthState {
859                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
860                rate_limiter,
861                pre_auth_limiter,
862                #[cfg(feature = "oauth")]
863                jwks_cache,
864                seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
865                counters: crate::auth::AuthCounters::default(),
866            }))
867        }
868        _ => None,
869    };
870
871    // Build the RBAC policy swap early so the admin router and the later
872    // RBAC middleware layer share the same hot-reloadable state.
873    let rbac_swap = Arc::new(ArcSwap::new(
874        config
875            .rbac
876            .clone()
877            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
878    ));
879
880    // Optional /admin/* diagnostic routes. Merged BEFORE the
881    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
882    if config.admin_enabled {
883        let Some(ref auth_state_ref) = auth_state else {
884            return Err(anyhow::anyhow!(
885                "admin_enabled=true requires auth to be configured and enabled"
886            ));
887        };
888        let admin_state = crate::admin::AdminState {
889            started_at: std::time::Instant::now(),
890            name: config.name.clone(),
891            version: config.version.clone(),
892            auth: Some(Arc::clone(auth_state_ref)),
893            rbac: Arc::clone(&rbac_swap),
894        };
895        let admin_cfg = crate::admin::AdminConfig {
896            role: config.admin_role.clone(),
897        };
898        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
899        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
900    }
901
902    // ----- Middleware order (CRITICAL: read carefully) ------------------
903    //
904    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
905    // added is the OUTERMOST (runs first on a request). To achieve a
906    // request-time flow of:
907    //
908    //   body-limit -> timeout -> auth -> rbac -> handler
909    //
910    // we add layers in the REVERSE order:
911    //
912    //   1. RBAC               (innermost, runs last before handler)
913    //   2. auth               (parses identity, sets extension for RBAC)
914    //   3. timeout            (bounds total request time)
915    //   4. body-limit         (outermost on /mcp; caps payload before
916    //                          anything else reads/buffers it)
917    //
918    // Origin validation is installed on the OUTER router (after the
919    // /mcp router is merged in), so it also protects /healthz, /readyz,
920    // /version, and any OAuth proxy endpoints.
921    //
922    // Rationale:
923    // - Body-limit must be outermost on /mcp so RBAC (which reads the
924    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
925    // - Auth must run before RBAC because RBAC consumes
926    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
927    //   policy.
928    // - Origin runs before auth so we reject cross-origin requests
929    //   without spending Argon2 cycles on unauthenticated callers.
930
931    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
932    // Always installed: even when RBAC is disabled, tool rate limiting may
933    // be active (MCP spec: servers MUST rate limit tool invocations).
934    {
935        let tool_limiter: Option<Arc<ToolRateLimiter>> =
936            config.tool_rate_limit.map(build_tool_rate_limiter);
937
938        if rbac_swap.load().is_enabled() {
939            tracing::info!("RBAC enforcement enabled on /mcp");
940        }
941        if let Some(limit) = config.tool_rate_limit {
942            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
943        }
944
945        let rbac_for_mw = Arc::clone(&rbac_swap);
946        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
947            let p = rbac_for_mw.load_full();
948            let tl = tool_limiter.clone();
949            rbac_middleware(p, tl, req, next)
950        }));
951    }
952
953    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
954    if let Some(ref auth_config) = config.auth
955        && auth_config.enabled
956    {
957        let Some(ref state) = auth_state else {
958            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
959        };
960
961        let methods: Vec<&str> = [
962            auth_config.mtls.is_some().then_some("mTLS"),
963            (!auth_config.api_keys.is_empty()).then_some("bearer"),
964            #[cfg(feature = "oauth")]
965            auth_config.oauth.is_some().then_some("oauth-jwt"),
966        ]
967        .into_iter()
968        .flatten()
969        .collect();
970
971        tracing::info!(
972            methods = %methods.join(", "),
973            api_keys = auth_config.api_keys.len(),
974            "auth enabled on /mcp"
975        );
976
977        let state_for_mw = Arc::clone(state);
978        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
979            let s = Arc::clone(&state_for_mw);
980            auth_middleware(s, req, next)
981        }));
982    }
983
984    // [3] Request timeout (returns 408 on expiry). Bounds total request
985    // duration including auth + handler.
986    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
987        axum::http::StatusCode::REQUEST_TIMEOUT,
988        config.request_timeout,
989    ));
990
991    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
992    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
993    // attempts to buffer or parse the body.
994    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
995        config.max_request_body,
996    ));
997
998    // Compute the effective allowed-origins list for the outer
999    // origin-check layer (installed on the merged router below). When
1000    // `allowed_origins` is empty but `public_url` is set, auto-derive
1001    // the origin from the public URL so MCP clients (e.g. Claude Code)
1002    // that send `Origin: <server-url>` are accepted without explicit
1003    // config.
1004    let mut effective_origins = config.allowed_origins.clone();
1005    if effective_origins.is_empty()
1006        && let Some(ref url) = config.public_url
1007    {
1008        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1009        // Strip any path/query from the public URL.
1010        if let Some(scheme_end) = url.find("://") {
1011            let after_scheme = &url[scheme_end + 3..];
1012            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1013            let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
1014            tracing::info!(
1015                %origin,
1016                "auto-derived allowed origin from public_url"
1017            );
1018            effective_origins.push(origin);
1019        }
1020    }
1021    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1022    let cors_origins = Arc::clone(&allowed_origins);
1023    let log_request_headers = config.log_request_headers;
1024
1025    let readyz_route = if let Some(check) = config.readiness_check.take() {
1026        axum::routing::get(move || readyz(Arc::clone(&check)))
1027    } else {
1028        axum::routing::get(healthz)
1029    };
1030
1031    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1032    let mut router = axum::Router::new()
1033        .route("/healthz", axum::routing::get(healthz))
1034        .route("/readyz", readyz_route)
1035        .route(
1036            "/version",
1037            axum::routing::get({
1038                // Pre-serialize the version payload once at router-build
1039                // time. The handler then serves a cheap `Arc::clone` of the
1040                // immutable bytes per request, avoiding `serde_json::Value`
1041                // allocation + serialization on every `/version` hit.
1042                let payload_bytes: Arc<[u8]> =
1043                    serialize_version_payload(&config.name, &config.version);
1044                move || {
1045                    let p = Arc::clone(&payload_bytes);
1046                    async move {
1047                        (
1048                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1049                            p.to_vec(),
1050                        )
1051                    }
1052                }
1053            }),
1054        )
1055        .merge(mcp_router);
1056
1057    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1058    if let Some(extra) = config.extra_router.take() {
1059        router = router.merge(extra);
1060    }
1061
1062    // RFC 9728: Protected Resource Metadata endpoint.
1063    // When OAuth is configured, serve full metadata with authorization_servers.
1064    // Otherwise, serve a minimal document with just the resource URL and no
1065    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1066    // that the server exists but does NOT require OAuth authentication,
1067    // preventing them from gating the connection behind a broken auth flow.
1068    let server_url = if let Some(ref url) = config.public_url {
1069        url.trim_end_matches('/').to_owned()
1070    } else {
1071        let prm_scheme = if config.tls_cert_path.is_some() {
1072            "https"
1073        } else {
1074            "http"
1075        };
1076        format!("{prm_scheme}://{}", config.bind_addr)
1077    };
1078    let resource_url = format!("{server_url}/mcp");
1079
1080    #[cfg(feature = "oauth")]
1081    let prm_metadata = if let Some(ref auth_config) = config.auth
1082        && let Some(ref oauth_config) = auth_config.oauth
1083    {
1084        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1085    } else {
1086        serde_json::json!({ "resource": resource_url })
1087    };
1088    #[cfg(not(feature = "oauth"))]
1089    let prm_metadata = serde_json::json!({ "resource": resource_url });
1090
1091    router = router.route(
1092        "/.well-known/oauth-protected-resource",
1093        axum::routing::get(move || {
1094            let m = prm_metadata.clone();
1095            async move { axum::Json(m) }
1096        }),
1097    );
1098
1099    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1100    // /authorize, /token, /register, and authorization server metadata so
1101    // MCP clients can perform Authorization Code + PKCE against the upstream
1102    // IdP (e.g. Keycloak) transparently.
1103    #[cfg(feature = "oauth")]
1104    if let Some(ref auth_config) = config.auth
1105        && let Some(ref oauth_config) = auth_config.oauth
1106        && oauth_config.proxy.is_some()
1107    {
1108        router =
1109            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1110    }
1111
1112    // OWASP security response headers (applied to all responses).
1113    // HSTS is conditional on TLS being configured.
1114    let is_tls = config.tls_cert_path.is_some();
1115    router = router.layer(axum::middleware::from_fn(move |req, next| {
1116        security_headers_middleware(is_tls, req, next)
1117    }));
1118
1119    // CORS preflight layer (required for browser-based MCP clients).
1120    // Uses the same effective origins as the origin check middleware
1121    // (including auto-derived origin from public_url).
1122    if !cors_origins.is_empty() {
1123        let cors = tower_http::cors::CorsLayer::new()
1124            .allow_origin(
1125                cors_origins
1126                    .iter()
1127                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1128                    .collect::<Vec<_>>(),
1129            )
1130            .allow_methods([
1131                axum::http::Method::GET,
1132                axum::http::Method::POST,
1133                axum::http::Method::OPTIONS,
1134            ])
1135            .allow_headers([
1136                axum::http::header::CONTENT_TYPE,
1137                axum::http::header::AUTHORIZATION,
1138            ]);
1139        router = router.layer(cors);
1140    }
1141
1142    // Optional response compression (gzip + brotli). Skips small bodies
1143    // to avoid overhead. Applied after CORS so preflight responses remain
1144    // uncompressed.
1145    if config.compression_enabled {
1146        use tower_http::compression::Predicate as _;
1147        let predicate = tower_http::compression::DefaultPredicate::new().and(
1148            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1149        );
1150        router = router.layer(
1151            tower_http::compression::CompressionLayer::new()
1152                .gzip(true)
1153                .br(true)
1154                .compress_when(predicate),
1155        );
1156        tracing::info!(
1157            min_size = config.compression_min_size,
1158            "response compression enabled (gzip, br)"
1159        );
1160    }
1161
1162    // Optional global concurrency cap. `load_shed` converts the
1163    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1164    if let Some(max) = config.max_concurrent_requests {
1165        let overload_handler = tower::ServiceBuilder::new()
1166            .layer(axum::error_handling::HandleErrorLayer::new(
1167                |_err: tower::BoxError| async {
1168                    (
1169                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1170                        axum::Json(serde_json::json!({
1171                            "error": "overloaded",
1172                            "error_description": "server is at capacity, retry later"
1173                        })),
1174                    )
1175                },
1176            ))
1177            .layer(tower::load_shed::LoadShedLayer::new())
1178            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1179        router = router.layer(overload_handler);
1180        tracing::info!(max, "global concurrency limit enabled");
1181    }
1182
1183    // JSON fallback for unmatched routes. Without this, axum returns
1184    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1185    // when they probe OAuth endpoints like /authorize or /token.
1186    router = router.fallback(|| async {
1187        (
1188            axum::http::StatusCode::NOT_FOUND,
1189            axum::Json(serde_json::json!({
1190                "error": "not_found",
1191                "error_description": "The requested endpoint does not exist"
1192            })),
1193        )
1194    });
1195
1196    // Prometheus metrics: recording middleware + separate listener.
1197    #[cfg(feature = "metrics")]
1198    if config.metrics_enabled {
1199        let metrics = Arc::new(
1200            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1201        );
1202        let m = Arc::clone(&metrics);
1203        router = router.layer(axum::middleware::from_fn(
1204            move |req: Request<Body>, next: Next| {
1205                let m = Arc::clone(&m);
1206                metrics_middleware(m, req, next)
1207            },
1208        ));
1209        let metrics_bind = config.metrics_bind.clone();
1210        tokio::spawn(async move {
1211            if let Err(e) = crate::metrics::serve_metrics(metrics_bind, metrics).await {
1212                tracing::error!("metrics listener failed: {e}");
1213            }
1214        });
1215    }
1216
1217    // Origin validation layer (MCP spec: servers MUST validate the
1218    // Origin header to prevent DNS rebinding attacks). Installed as the
1219    // OUTERMOST layer on the OUTER router so it protects ALL routes
1220    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1221    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1222    // reject cross-origin attackers without spending Argon2 cycles.
1223    //
1224    // Origin-less requests (e.g. server-to-server probes, curl, native
1225    // MCP clients) are permitted; only requests with an Origin header
1226    // that does not match `effective_origins` are rejected.
1227    router = router.layer(axum::middleware::from_fn(move |req, next| {
1228        let origins = Arc::clone(&allowed_origins);
1229        origin_check_middleware(origins, log_request_headers, req, next)
1230    }));
1231
1232    let scheme = if config.tls_cert_path.is_some() {
1233        "https"
1234    } else {
1235        "http"
1236    };
1237
1238    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1239        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1240        _ => None,
1241    };
1242    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1243
1244    Ok((
1245        router,
1246        AppRunParams {
1247            tls_paths,
1248            mtls_config,
1249            shutdown_timeout: config.shutdown_timeout,
1250            auth_state,
1251            rbac_swap,
1252            on_reload_ready: config.on_reload_ready.take(),
1253            ct,
1254            scheme,
1255            name: config.name.clone(),
1256        },
1257    ))
1258}
1259
1260/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1261/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1262///
1263/// This is the standard entry point for production deployments. For
1264/// deterministic shutdown control (e.g. integration tests), see
1265/// [`serve_with_listener`].
1266///
1267/// The configuration must be validated first via
1268/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1269/// token. This typestate guarantees, at compile time, that the server
1270/// never starts with an invalid configuration.
1271///
1272/// # Errors
1273///
1274/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1275/// fails, or if the underlying axum server returns an error.
1276pub async fn serve<H, F>(
1277    config: Validated<McpServerConfig>,
1278    handler_factory: F,
1279) -> Result<(), McpxError>
1280where
1281    H: ServerHandler + 'static,
1282    F: Fn() -> H + Send + Sync + Clone + 'static,
1283{
1284    let config = config.into_inner();
1285    #[allow(
1286        deprecated,
1287        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1288    )]
1289    let bind_addr = config.bind_addr.clone();
1290    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1291
1292    let listener = TcpListener::bind(&bind_addr)
1293        .await
1294        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1295    log_listening(&params.name, params.scheme, &bind_addr);
1296
1297    run_server(
1298        router,
1299        listener,
1300        params.tls_paths,
1301        params.mtls_config,
1302        params.shutdown_timeout,
1303        params.auth_state,
1304        params.rbac_swap,
1305        params.on_reload_ready,
1306        params.ct,
1307    )
1308    .await
1309    .map_err(anyhow_to_startup)
1310}
1311
1312/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1313/// readiness signalling and external shutdown control.
1314///
1315/// This variant is intended for **deterministic integration tests** and
1316/// for embedders that need to bind the listening socket themselves
1317/// (e.g. systemd socket activation). Compared to [`serve`]:
1318///
1319/// * The caller passes a `TcpListener` that is already bound. This
1320///   eliminates the bind race in tests that previously required
1321///   poll-the-`/healthz`-loop start-up detection.
1322/// * `ready_tx`, when `Some`, receives the socket's
1323///   [`SocketAddr`] *after* the router is built and immediately before
1324///   the server starts accepting connections. Tests can `await` the
1325///   matching `oneshot::Receiver` to know exactly when it is safe to
1326///   issue requests.
1327/// * `shutdown`, when `Some`, gives the caller a
1328///   [`CancellationToken`] that triggers the same graceful-shutdown
1329///   path as a real OS signal. This avoids cross-platform issues with
1330///   sending real `SIGTERM` from tests on Windows.
1331///
1332/// All three optional parameters degrade gracefully: if `ready_tx` is
1333/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1334/// stops on an OS signal (just like [`serve`]).
1335///
1336/// # Errors
1337///
1338/// Returns [`McpxError::Startup`] if router construction fails, if reading
1339/// the listener's `local_addr()` fails, or if the underlying axum
1340/// server returns an error.
1341pub async fn serve_with_listener<H, F>(
1342    listener: TcpListener,
1343    config: Validated<McpServerConfig>,
1344    handler_factory: F,
1345    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1346    shutdown: Option<CancellationToken>,
1347) -> Result<(), McpxError>
1348where
1349    H: ServerHandler + 'static,
1350    F: Fn() -> H + Send + Sync + Clone + 'static,
1351{
1352    let config = config.into_inner();
1353    let local_addr = listener
1354        .local_addr()
1355        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1356    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1357
1358    log_listening(&params.name, params.scheme, &local_addr.to_string());
1359
1360    // Forward external shutdown into the server-internal cancellation
1361    // token so `run_server`'s shutdown trigger picks it up alongside
1362    // any real OS signal.
1363    if let Some(external) = shutdown {
1364        let internal = params.ct.clone();
1365        tokio::spawn(async move {
1366            external.cancelled().await;
1367            internal.cancel();
1368        });
1369    }
1370
1371    // Signal readiness *after* the router is fully built and external
1372    // shutdown is wired, but *before* run_server takes ownership of
1373    // the listener. The receiver can immediately issue requests.
1374    if let Some(tx) = ready_tx {
1375        // Receiver may have been dropped (test gave up). That's fine.
1376        let _ = tx.send(local_addr);
1377    }
1378
1379    run_server(
1380        router,
1381        listener,
1382        params.tls_paths,
1383        params.mtls_config,
1384        params.shutdown_timeout,
1385        params.auth_state,
1386        params.rbac_swap,
1387        params.on_reload_ready,
1388        params.ct,
1389    )
1390    .await
1391    .map_err(anyhow_to_startup)
1392}
1393
1394/// Emit the standard "listening on …" log lines used by both
1395/// [`serve`] and [`serve_with_listener`].
1396#[allow(
1397    clippy::cognitive_complexity,
1398    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1399)]
1400fn log_listening(name: &str, scheme: &str, addr: &str) {
1401    tracing::info!("{name} listening on {addr}");
1402    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
1403    tracing::info!("  Health check: {scheme}://{addr}/healthz");
1404    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
1405}
1406
1407/// Drive the chosen axum server variant (TLS or plain) with a graceful
1408/// shutdown window. Consumes the router and listener.
1409///
1410/// # Shutdown semantics
1411///
1412/// A single shutdown trigger (the FIRST of: OS signal via
1413/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
1414///
1415/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
1416///    accepting new connections and waits for in-flight requests to
1417///    drain;
1418/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
1419///    drainage exceeds `shutdown_timeout`.
1420///
1421/// Previously this function awaited `shutdown_signal()` independently
1422/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
1423/// resolves once per future and consumes one signal, the force-exit
1424/// timer was tied to a SECOND signal (a second SIGTERM the operator
1425/// would never send). Under a single SIGTERM the graceful drain could
1426/// hang indefinitely. The current implementation derives both branches
1427/// from a single shared trigger so the timeout race is anchored to the
1428/// FIRST (and only) signal.
1429#[allow(
1430    clippy::too_many_arguments,
1431    clippy::cognitive_complexity,
1432    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1433)]
1434async fn run_server(
1435    router: axum::Router,
1436    listener: TcpListener,
1437    tls_paths: Option<(PathBuf, PathBuf)>,
1438    mtls_config: Option<MtlsConfig>,
1439    shutdown_timeout: Duration,
1440    auth_state: Option<Arc<AuthState>>,
1441    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1442    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1443    ct: CancellationToken,
1444) -> anyhow::Result<()> {
1445    // `shutdown_trigger` fires when the FIRST source resolves: either
1446    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
1447    // (which the test harness uses for deterministic shutdown).
1448    let shutdown_trigger = CancellationToken::new();
1449    {
1450        let trigger = shutdown_trigger.clone();
1451        let parent = ct.clone();
1452        tokio::spawn(async move {
1453            tokio::select! {
1454                () = shutdown_signal() => {}
1455                () = parent.cancelled() => {}
1456            }
1457            trigger.cancel();
1458        });
1459    }
1460
1461    let graceful = {
1462        let trigger = shutdown_trigger.clone();
1463        let ct = ct.clone();
1464        async move {
1465            trigger.cancelled().await;
1466            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1467            ct.cancel();
1468        }
1469    };
1470
1471    let force_exit_timer = {
1472        let trigger = shutdown_trigger.clone();
1473        async move {
1474            trigger.cancelled().await;
1475            tokio::time::sleep(shutdown_timeout).await;
1476        }
1477    };
1478
1479    if let Some((cert_path, key_path)) = tls_paths {
1480        let crl_set = if let Some(mtls) = mtls_config.as_ref()
1481            && mtls.crl_enabled
1482        {
1483            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1484            let (crl_set, discover_rx) =
1485                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1486                    .await
1487                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1488            tokio::spawn(mtls_revocation::run_crl_refresher(
1489                Arc::clone(&crl_set),
1490                discover_rx,
1491                ct.clone(),
1492            ));
1493            Some(crl_set)
1494        } else {
1495            None
1496        };
1497
1498        if let Some(cb) = on_reload_ready.take() {
1499            cb(ReloadHandle {
1500                auth: auth_state.clone(),
1501                rbac: Some(Arc::clone(&rbac_swap)),
1502                crl_set: crl_set.clone(),
1503            });
1504        }
1505
1506        let tls_listener = TlsListener::new(
1507            listener,
1508            &cert_path,
1509            &key_path,
1510            mtls_config.as_ref(),
1511            crl_set,
1512        )?;
1513        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1514        tokio::select! {
1515            result = axum::serve(tls_listener, make_svc)
1516                .with_graceful_shutdown(graceful) => { result?; }
1517            () = force_exit_timer => {
1518                tracing::warn!("shutdown timeout exceeded, forcing exit");
1519            }
1520        }
1521    } else {
1522        if let Some(cb) = on_reload_ready.take() {
1523            cb(ReloadHandle {
1524                auth: auth_state,
1525                rbac: Some(rbac_swap),
1526                crl_set: None,
1527            });
1528        }
1529
1530        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1531        tokio::select! {
1532            result = axum::serve(listener, make_svc)
1533                .with_graceful_shutdown(graceful) => { result?; }
1534            () = force_exit_timer => {
1535                tracing::warn!("shutdown timeout exceeded, forcing exit");
1536            }
1537        }
1538    }
1539
1540    Ok(())
1541}
1542
/// Install the OAuth 2.1 proxy endpoints (`/authorize`, `/token`,
/// `/register`, and authorization server metadata) on `router`.
///
/// If `oauth_config.proxy` is `None`, `router` is returned unchanged.
///
/// When the proxy sets `expose_admin_endpoints`, `/introspect` and
/// `/revoke` are additionally mounted for whichever of
/// `introspection_url` / `revocation_url` is configured. With
/// `require_auth_on_admin_endpoints` those admin routes are wrapped in
/// [`auth_middleware`]. Otherwise a startup warning is logged, but only
/// when at least one admin route is actually mounted.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if the shared
/// [`crate::oauth::OauthHttpClient`] cannot be initialized, or if admin
/// routes are mounted with `require_auth_on_admin_endpoints` set but no
/// `auth_state` was supplied.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    // Single shared HTTP client for all proxy endpoints. Cloning is
    // cheap (refcounted) and shares the underlying connection pool.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        }),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        }),
    );

    // Admin routes are mounted only when explicitly exposed AND at least one
    // upstream endpoint is configured.
    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Warn only when unauthenticated admin routes are actually mounted;
    // `expose_admin_endpoints` with neither URL configured mounts nothing.
    if admin_routes_enabled && !proxy.require_auth_on_admin_endpoints {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated; consider setting require_auth_on_admin_endpoints = true"
        );
    }

    let admin_router = if admin_routes_enabled {
        let mut admin_router = axum::Router::new();
        if proxy.introspection_url.is_some() {
            let proxy_introspect = proxy.clone();
            let introspect_http = http.clone();
            admin_router = admin_router.route(
                "/introspect",
                axum::routing::post(move |body: String| {
                    let p = proxy_introspect.clone();
                    let h = introspect_http.clone();
                    async move { crate::oauth::handle_introspect(&h, &p, &body).await }
                }),
            );
        }
        if proxy.revocation_url.is_some() {
            let proxy_revoke = proxy.clone();
            let revoke_http = http;
            admin_router = admin_router.route(
                "/revoke",
                axum::routing::post(move |body: String| {
                    let p = proxy_revoke.clone();
                    let h = revoke_http.clone();
                    async move { crate::oauth::handle_revoke(&h, &p, &body).await }
                }),
            );
        }

        if proxy.require_auth_on_admin_endpoints {
            let Some(state) = auth_state else {
                return Err(McpxError::Startup(
                    "oauth proxy admin endpoints require auth state".into(),
                ));
            };
            let state_for_mw = Arc::clone(state);
            admin_router.layer(axum::middleware::from_fn(move |req, next| {
                let s = Arc::clone(&state_for_mw);
                auth_middleware(s, req, next)
            }))
        } else {
            admin_router
        }
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1668
1669/// Build the host allow-list for rmcp's DNS rebinding protection.
1670///
1671/// Includes loopback hosts by default, then augments with host/authority
1672/// derived from `public_url` and the server bind address.
1673fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1674    let mut hosts = vec![
1675        "localhost".to_owned(),
1676        "127.0.0.1".to_owned(),
1677        "::1".to_owned(),
1678    ];
1679
1680    if let Some(url) = public_url
1681        && let Ok(uri) = url.parse::<axum::http::Uri>()
1682        && let Some(authority) = uri.authority()
1683    {
1684        let host = authority.host().to_owned();
1685        if !hosts.iter().any(|h| h == &host) {
1686            hosts.push(host);
1687        }
1688
1689        let authority = authority.as_str().to_owned();
1690        if !hosts.iter().any(|h| h == &authority) {
1691            hosts.push(authority);
1692        }
1693    }
1694
1695    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1696        && let Some(authority) = uri.authority()
1697    {
1698        let host = authority.host().to_owned();
1699        if !hosts.iter().any(|h| h == &host) {
1700            hosts.push(host);
1701        }
1702
1703        let authority = authority.as_str().to_owned();
1704        if !hosts.iter().any(|h| h == &authority) {
1705            hosts.push(authority);
1706        }
1707    }
1708
1709    hosts
1710}
1711
1712// - TLS support -
1713
1714/// Implement axum's `Connected` trait for `TlsConnInfo` so that
1715/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
1716/// over our custom `TlsListener`.
1717///
1718/// The identity is read directly from the wrapping
1719/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
1720/// between the TLS connection and its mTLS identity. This eliminates the
1721/// previous shared-map approach which was vulnerable to ephemeral-port
1722/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
1723/// pair could alias a stale entry).
1724impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1725    for TlsConnInfo
1726{
1727    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1728        let addr = *target.remote_addr();
1729        let identity = target.io().identity().cloned();
1730        TlsConnInfo::new(addr, identity)
1731    }
1732}
1733
/// A TLS-wrapping listener that implements axum's `Listener` trait.
///
/// When mTLS is configured, verifies client certificates against the
/// configured CA and extracts the client identity at handshake time.
/// The extracted identity is bound to the connection itself via the
/// returned [`AuthenticatedTlsStream`], so it is impossible for an
/// unrelated connection to observe it.
struct TlsListener {
    /// Underlying TCP listener; each accepted socket is handed to `acceptor`
    /// for the TLS handshake.
    inner: TcpListener,
    /// rustls acceptor built by `TlsListener::new` from the server
    /// certificate/key and, when mTLS is configured, a client-cert verifier.
    acceptor: tokio_rustls::TlsAcceptor,
    /// Role passed to `extract_mtls_identity` for presented client
    /// certificates: `MtlsConfig::default_role` when mTLS is configured,
    /// otherwise `"viewer"`.
    mtls_default_role: String,
}
1746
1747impl TlsListener {
1748    fn new(
1749        inner: TcpListener,
1750        cert_path: &Path,
1751        key_path: &Path,
1752        mtls_config: Option<&MtlsConfig>,
1753        crl_set: Option<Arc<CrlSet>>,
1754    ) -> anyhow::Result<Self> {
1755        // Install the ring crypto provider (ok to call multiple times).
1756        rustls::crypto::ring::default_provider()
1757            .install_default()
1758            .ok();
1759
1760        let certs = load_certs(cert_path)?;
1761        let key = load_key(key_path)?;
1762
1763        let mtls_default_role;
1764
1765        let tls_config = if let Some(mtls) = mtls_config {
1766            mtls_default_role = mtls.default_role.clone();
1767            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1768            {
1769                let Some(crl_set) = crl_set else {
1770                    return Err(anyhow::anyhow!(
1771                        "mTLS CRL verifier requested but CRL state was not initialized"
1772                    ));
1773                };
1774                Arc::new(DynamicClientCertVerifier::new(crl_set))
1775            } else {
1776                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1777                if mtls.required {
1778                    rustls::server::WebPkiClientVerifier::builder(root_store)
1779                        .build()
1780                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1781                } else {
1782                    rustls::server::WebPkiClientVerifier::builder(root_store)
1783                        .allow_unauthenticated()
1784                        .build()
1785                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1786                }
1787            };
1788
1789            tracing::info!(
1790                ca = %mtls.ca_cert_path.display(),
1791                required = mtls.required,
1792                crl_enabled = mtls.crl_enabled,
1793                "mTLS client auth configured"
1794            );
1795
1796            rustls::ServerConfig::builder_with_protocol_versions(&[
1797                &rustls::version::TLS12,
1798                &rustls::version::TLS13,
1799            ])
1800            .with_client_cert_verifier(verifier)
1801            .with_single_cert(certs, key)?
1802        } else {
1803            mtls_default_role = "viewer".to_owned();
1804            rustls::ServerConfig::builder_with_protocol_versions(&[
1805                &rustls::version::TLS12,
1806                &rustls::version::TLS13,
1807            ])
1808            .with_no_client_auth()
1809            .with_single_cert(certs, key)?
1810        };
1811
1812        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1813        tracing::info!(
1814            "TLS enabled (cert: {}, key: {})",
1815            cert_path.display(),
1816            key_path.display()
1817        );
1818        Ok(Self {
1819            inner,
1820            acceptor,
1821            mtls_default_role,
1822        })
1823    }
1824
1825    /// Extract the mTLS client cert identity from a completed TLS handshake.
1826    /// Returns `None` if no client certificate was presented or if the
1827    /// certificate could not be parsed into an [`AuthIdentity`].
1828    fn extract_handshake_identity(
1829        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1830        default_role: &str,
1831        addr: SocketAddr,
1832    ) -> Option<AuthIdentity> {
1833        let (_, server_conn) = tls_stream.get_ref();
1834        let cert_der = server_conn.peer_certificates()?.first()?;
1835        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1836        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1837        Some(id)
1838    }
1839}
1840
/// A TLS stream paired with the mTLS identity extracted at handshake time.
///
/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
/// identity travels with the connection itself. This replaces the previous
/// shared `MtlsIdentities` map, eliminating the
/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
/// reuse and removing the need for an LRU eviction policy.
///
/// The wrapper is `Unpin` (its inner stream is `Unpin` because
/// [`tokio::net::TcpStream`] is `Unpin`), so `AsyncRead`/`AsyncWrite`
/// delegation uses safe pin projection via `Pin::new(&mut self.inner)`.
pub(crate) struct AuthenticatedTlsStream {
    /// Established server-side TLS stream; all reads and writes are
    /// forwarded to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity derived from the peer's first presented certificate during
    /// the handshake. It is `None` when no certificate was presented or it
    /// could not be mapped to an [`AuthIdentity`].
    identity: Option<AuthIdentity>,
}

impl AuthenticatedTlsStream {
    /// Returns the verified mTLS client identity, if any.
    ///
    /// The value is fixed when the connection is accepted; it is never
    /// updated afterwards.
    #[must_use]
    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
        self.identity.as_ref()
    }
}
1864
1865impl std::fmt::Debug for AuthenticatedTlsStream {
1866    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1867        f.debug_struct("AuthenticatedTlsStream")
1868            .field("identity", &self.identity.as_ref().map(|id| &id.name))
1869            .finish_non_exhaustive()
1870    }
1871}
1872
1873impl tokio::io::AsyncRead for AuthenticatedTlsStream {
1874    fn poll_read(
1875        mut self: Pin<&mut Self>,
1876        cx: &mut std::task::Context<'_>,
1877        buf: &mut tokio::io::ReadBuf<'_>,
1878    ) -> std::task::Poll<std::io::Result<()>> {
1879        Pin::new(&mut self.inner).poll_read(cx, buf)
1880    }
1881}
1882
1883impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
1884    fn poll_write(
1885        mut self: Pin<&mut Self>,
1886        cx: &mut std::task::Context<'_>,
1887        buf: &[u8],
1888    ) -> std::task::Poll<std::io::Result<usize>> {
1889        Pin::new(&mut self.inner).poll_write(cx, buf)
1890    }
1891
1892    fn poll_flush(
1893        mut self: Pin<&mut Self>,
1894        cx: &mut std::task::Context<'_>,
1895    ) -> std::task::Poll<std::io::Result<()>> {
1896        Pin::new(&mut self.inner).poll_flush(cx)
1897    }
1898
1899    fn poll_shutdown(
1900        mut self: Pin<&mut Self>,
1901        cx: &mut std::task::Context<'_>,
1902    ) -> std::task::Poll<std::io::Result<()>> {
1903        Pin::new(&mut self.inner).poll_shutdown(cx)
1904    }
1905
1906    fn poll_write_vectored(
1907        mut self: Pin<&mut Self>,
1908        cx: &mut std::task::Context<'_>,
1909        bufs: &[std::io::IoSlice<'_>],
1910    ) -> std::task::Poll<std::io::Result<usize>> {
1911        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
1912    }
1913
1914    fn is_write_vectored(&self) -> bool {
1915        self.inner.is_write_vectored()
1916    }
1917}
1918
1919impl axum::serve::Listener for TlsListener {
1920    type Io = AuthenticatedTlsStream;
1921    type Addr = SocketAddr;
1922
1923    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
1924        loop {
1925            let (stream, addr) = match self.inner.accept().await {
1926                Ok(pair) => pair,
1927                Err(e) => {
1928                    tracing::debug!("TCP accept error: {e}");
1929                    continue;
1930                }
1931            };
1932            let tls_stream = match self.acceptor.accept(stream).await {
1933                Ok(s) => s,
1934                Err(e) => {
1935                    tracing::debug!("TLS handshake failed from {addr}: {e}");
1936                    continue;
1937                }
1938            };
1939            let identity =
1940                Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
1941            let wrapped = AuthenticatedTlsStream {
1942                inner: tls_stream,
1943                identity,
1944            };
1945            return (wrapped, addr);
1946        }
1947    }
1948
1949    fn local_addr(&self) -> std::io::Result<Self::Addr> {
1950        self.inner.local_addr()
1951    }
1952}
1953
1954fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
1955    use rustls::pki_types::pem::PemObject;
1956    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
1957        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
1958        .collect::<Result<_, _>>()
1959        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
1960    anyhow::ensure!(
1961        !certs.is_empty(),
1962        "no certificates found in {}",
1963        path.display()
1964    );
1965    Ok(certs)
1966}
1967
1968fn load_client_auth_roots(
1969    path: &Path,
1970) -> anyhow::Result<(
1971    Vec<rustls::pki_types::CertificateDer<'static>>,
1972    Arc<RootCertStore>,
1973)> {
1974    let ca_certs = load_certs(path)?;
1975    let mut root_store = RootCertStore::empty();
1976    for cert in &ca_certs {
1977        root_store
1978            .add(cert.clone())
1979            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
1980    }
1981
1982    Ok((ca_certs, Arc::new(root_store)))
1983}
1984
1985fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
1986    use rustls::pki_types::pem::PemObject;
1987    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
1988        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
1989}
1990
1991#[allow(clippy::unused_async)]
1992async fn healthz() -> impl IntoResponse {
1993    axum::Json(serde_json::json!({
1994        "status": "ok",
1995    }))
1996}
1997
1998/// Build the `/version` JSON payload for a given server name and version.
1999///
2000/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2001/// read at compile time from the `MCPX_BUILD_SHA`, `MCPX_BUILD_TIME`, and
2002/// `MCPX_RUSTC_VERSION` env vars. Unset values resolve to `"unknown"`.
2003fn version_payload(name: &str, version: &str) -> serde_json::Value {
2004    serde_json::json!({
2005        "name": name,
2006        "version": version,
2007        "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or("unknown"),
2008        "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or("unknown"),
2009        "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or("unknown"),
2010        "mcpx_version": env!("CARGO_PKG_VERSION"),
2011    })
2012}
2013
2014/// Pre-serialize the `/version` payload to immutable bytes.
2015///
2016/// This is called once at router-build time so per-request handling can
2017/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2018/// [`serde_json::Value`] on every hit.
2019///
2020/// Serialization of a flat `serde_json::Value` of static-string fields
2021/// cannot fail in practice; the fallback to `b"{}"` exists only to
2022/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2023fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2024    let value = version_payload(name, version);
2025    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2026}
2027
2028async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2029    let status = check().await;
2030    let ready = status
2031        .get("ready")
2032        .and_then(serde_json::Value::as_bool)
2033        .unwrap_or(false);
2034    let code = if ready {
2035        axum::http::StatusCode::OK
2036    } else {
2037        axum::http::StatusCode::SERVICE_UNAVAILABLE
2038    };
2039    (code, axum::Json(status))
2040}
2041
2042/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2043///
2044/// On non-Unix platforms, only SIGINT is handled.
2045async fn shutdown_signal() {
2046    let ctrl_c = tokio::signal::ctrl_c();
2047
2048    #[cfg(unix)]
2049    {
2050        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2051            Ok(mut term) => {
2052                tokio::select! {
2053                    _ = ctrl_c => {}
2054                    _ = term.recv() => {}
2055                }
2056            }
2057            Err(e) => {
2058                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2059                ctrl_c.await.ok();
2060            }
2061        }
2062    }
2063
2064    #[cfg(not(unix))]
2065    {
2066        ctrl_c.await.ok();
2067    }
2068}
2069
// -- HTTP middleware: metrics, security headers, Origin validation --
// (Origin validation follows the MCP 2025-11-25 spec, section 2.0.1.)
2071
/// Record HTTP request metrics (method, path, status, duration).
///
/// After the inner service responds, this:
/// - increments `http_requests_total` with labels `[method, path, status]`;
/// - observes the elapsed time in seconds on
///   `http_request_duration_seconds` with labels `[method, path]`.
///
/// The response is returned unchanged.
///
/// NOTE(review): `path` is the raw request path, so every distinct URL,
/// including unrouted ones sent by scanners, creates a new label series.
/// Consider a bounded label (e.g. axum's `MatchedPath` with a constant
/// fallback) if this middleware sees untrusted traffic.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Capture labels before `req` is moved into the inner service.
    let method = req.method().to_string();
    let path = req.uri().path().to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2101
2102/// OWASP security header hardening applied to every response.
2103///
2104/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2105/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2106/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2107/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2108/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2109async fn security_headers_middleware(
2110    is_tls: bool,
2111    req: Request<Body>,
2112    next: Next,
2113) -> axum::response::Response {
2114    use axum::http::{HeaderName, HeaderValue, header};
2115
2116    let mut resp = next.run(req).await;
2117    let headers = resp.headers_mut();
2118
2119    // Strip server identity headers to reduce information leakage.
2120    headers.remove(header::SERVER);
2121    headers.remove(HeaderName::from_static("x-powered-by"));
2122
2123    headers.insert(
2124        header::X_CONTENT_TYPE_OPTIONS,
2125        HeaderValue::from_static("nosniff"),
2126    );
2127    headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("deny"));
2128    headers.insert(
2129        header::CACHE_CONTROL,
2130        HeaderValue::from_static("no-store, max-age=0"),
2131    );
2132    headers.insert(
2133        header::REFERRER_POLICY,
2134        HeaderValue::from_static("no-referrer"),
2135    );
2136    headers.insert(
2137        HeaderName::from_static("cross-origin-opener-policy"),
2138        HeaderValue::from_static("same-origin"),
2139    );
2140    headers.insert(
2141        HeaderName::from_static("cross-origin-resource-policy"),
2142        HeaderValue::from_static("same-origin"),
2143    );
2144    headers.insert(
2145        HeaderName::from_static("cross-origin-embedder-policy"),
2146        HeaderValue::from_static("require-corp"),
2147    );
2148    headers.insert(
2149        HeaderName::from_static("permissions-policy"),
2150        HeaderValue::from_static("accelerometer=(), camera=(), geolocation=(), microphone=()"),
2151    );
2152    headers.insert(
2153        HeaderName::from_static("x-permitted-cross-domain-policies"),
2154        HeaderValue::from_static("none"),
2155    );
2156    headers.insert(
2157        HeaderName::from_static("content-security-policy"),
2158        HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
2159    );
2160    headers.insert(
2161        HeaderName::from_static("x-dns-prefetch-control"),
2162        HeaderValue::from_static("off"),
2163    );
2164
2165    if is_tls {
2166        headers.insert(
2167            header::STRICT_TRANSPORT_SECURITY,
2168            HeaderValue::from_static("max-age=63072000; includeSubDomains"),
2169        );
2170    }
2171
2172    resp
2173}
2174
2175/// Per the MCP spec: if the Origin header is present and its value is not in
2176/// the allowed list, respond with 403 Forbidden. Requests without an Origin
2177/// header are allowed through (e.g. non-browser clients like curl, SDKs).
2178async fn origin_check_middleware(
2179    allowed: Arc<[String]>,
2180    log_request_headers: bool,
2181    req: Request<Body>,
2182    next: Next,
2183) -> axum::response::Response {
2184    let method = req.method().clone();
2185    let path = req.uri().path().to_owned();
2186
2187    log_incoming_request(&method, &path, req.headers(), log_request_headers);
2188
2189    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2190        let origin_str = origin.to_str().unwrap_or("");
2191        if !allowed.iter().any(|a| a == origin_str) {
2192            tracing::warn!(
2193                origin = origin_str,
2194                %method,
2195                %path,
2196                allowed = ?&*allowed,
2197                "rejected request: Origin not allowed"
2198            );
2199            return (
2200                axum::http::StatusCode::FORBIDDEN,
2201                "Forbidden: Origin not allowed",
2202            )
2203                .into_response();
2204        }
2205    }
2206    next.run(req).await
2207}
2208
2209/// Emit a DEBUG log for an incoming request, optionally including the full
2210/// (redacted) header set.
2211fn log_incoming_request(
2212    method: &axum::http::Method,
2213    path: &str,
2214    headers: &axum::http::HeaderMap,
2215    log_request_headers: bool,
2216) {
2217    if log_request_headers {
2218        tracing::debug!(
2219            %method,
2220            %path,
2221            headers = %format_request_headers_for_log(headers),
2222            "incoming request"
2223        );
2224    } else {
2225        tracing::debug!(%method, %path, "incoming request");
2226    }
2227}
2228
2229fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2230    headers
2231        .iter()
2232        .map(|(k, v)| {
2233            let name = k.as_str();
2234            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2235                format!("{name}: [REDACTED]")
2236            } else {
2237                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2238            }
2239        })
2240        .collect::<Vec<_>>()
2241        .join(", ")
2242}
2243
2244// -- stdio transport --
2245
2246/// Serve an MCP server over stdin/stdout (stdio transport).
2247///
2248/// # Security warnings
2249///
2250/// - **No authentication**: the parent process has full, unrestricted access.
2251/// - **No RBAC**: all tools are available regardless of policy.
2252/// - **No TLS**: messages travel over OS pipes in plaintext.
2253/// - **Single client**: only the parent process can connect.
2254/// - **No Origin validation**: not applicable to stdio.
2255///
2256/// Use this only when the MCP client spawns the server as a trusted subprocess
2257/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
2258/// use `serve()` (Streamable HTTP) instead.
2259///
2260/// # Errors
2261///
2262/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
2263/// transport disconnects unexpectedly.
2264// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
2265// macro expansion in this 18-line function (info/warn/info + two matches).
2266// There is nothing meaningful to extract; the allow stays.
2267#[allow(clippy::cognitive_complexity)]
2268pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2269where
2270    H: ServerHandler + 'static,
2271{
2272    use rmcp::ServiceExt as _;
2273
2274    tracing::info!("stdio transport: serving on stdin/stdout");
2275    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2276
2277    let transport = rmcp::transport::io::stdio();
2278
2279    let service = handler
2280        .serve(transport)
2281        .await
2282        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2283
2284    if let Err(e) = service.waiting().await {
2285        tracing::warn!(error = %e, "stdio session ended with error");
2286    }
2287    tracing::info!("stdio session ended");
2288    Ok(())
2289}
2290
2291#[cfg(test)]
2292mod tests {
2293    #![allow(
2294        clippy::unwrap_used,
2295        clippy::expect_used,
2296        clippy::panic,
2297        clippy::indexing_slicing,
2298        clippy::unwrap_in_result,
2299        clippy::print_stdout,
2300        clippy::print_stderr,
2301        deprecated,
2302        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2303    )]
2304    use std::sync::Arc;
2305
2306    use axum::{
2307        body::Body,
2308        http::{Request, StatusCode, header},
2309        response::IntoResponse,
2310    };
2311    use http_body_util::BodyExt;
2312    use tower::ServiceExt as _;
2313
2314    use super::*;
2315
2316    // -- McpServerConfig --
2317
2318    #[test]
2319    fn server_config_new_defaults() {
2320        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2321        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2322        assert_eq!(cfg.name, "test-server");
2323        assert_eq!(cfg.version, "1.0.0");
2324        assert!(cfg.tls_cert_path.is_none());
2325        assert!(cfg.tls_key_path.is_none());
2326        assert!(cfg.auth.is_none());
2327        assert!(cfg.rbac.is_none());
2328        assert!(cfg.allowed_origins.is_empty());
2329        assert!(cfg.tool_rate_limit.is_none());
2330        assert!(cfg.readiness_check.is_none());
2331        assert_eq!(cfg.max_request_body, 1024 * 1024);
2332        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
2333        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
2334        assert!(!cfg.log_request_headers);
2335    }
2336
2337    #[test]
2338    fn validate_consumes_and_proves() {
2339        // Valid config -> Validated wrapper, original is consumed.
2340        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2341        let validated = cfg.validate().expect("valid config");
2342        // Deref gives read-only access to inner fields.
2343        assert_eq!(validated.name, "test-server");
2344        // into_inner recovers the raw value.
2345        let raw = validated.into_inner();
2346        assert_eq!(raw.name, "test-server");
2347
2348        // Invalid config (zero max_request_body) -> Err.
2349        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2350        bad.max_request_body = 0;
2351        assert!(bad.validate().is_err(), "zero body cap must fail validate");
2352    }
2353
2354    #[test]
2355    fn derive_allowed_hosts_includes_public_host() {
2356        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
2357        assert!(
2358            hosts.iter().any(|h| h == "mcp.example.com"),
2359            "public_url host must be allowed"
2360        );
2361    }
2362
2363    #[test]
2364    fn derive_allowed_hosts_includes_bind_authority() {
2365        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
2366        assert!(
2367            hosts.iter().any(|h| h == "127.0.0.1"),
2368            "bind host must be allowed"
2369        );
2370        assert!(
2371            hosts.iter().any(|h| h == "127.0.0.1:8080"),
2372            "bind authority must be allowed"
2373        );
2374    }
2375
2376    // -- healthz --
2377
2378    #[tokio::test]
2379    async fn healthz_returns_ok_json() {
2380        let resp = healthz().await.into_response();
2381        assert_eq!(resp.status(), StatusCode::OK);
2382        let body = resp.into_body().collect().await.unwrap().to_bytes();
2383        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2384        assert_eq!(json["status"], "ok");
2385        assert!(
2386            json.get("name").is_none(),
2387            "healthz must not expose server name"
2388        );
2389        assert!(
2390            json.get("version").is_none(),
2391            "healthz must not expose version"
2392        );
2393    }
2394
2395    // -- readyz --
2396
2397    #[tokio::test]
2398    async fn readyz_returns_ok_when_ready() {
2399        let check: ReadinessCheck =
2400            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
2401        let resp = readyz(check).await.into_response();
2402        assert_eq!(resp.status(), StatusCode::OK);
2403        let body = resp.into_body().collect().await.unwrap().to_bytes();
2404        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2405        assert_eq!(json["ready"], true);
2406        assert!(
2407            json.get("name").is_none(),
2408            "readyz must not expose server name"
2409        );
2410        assert!(
2411            json.get("version").is_none(),
2412            "readyz must not expose version"
2413        );
2414        assert_eq!(json["db"], "connected");
2415    }
2416
2417    #[tokio::test]
2418    async fn readyz_returns_503_when_not_ready() {
2419        let check: ReadinessCheck =
2420            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
2421        let resp = readyz(check).await.into_response();
2422        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2423    }
2424
2425    #[tokio::test]
2426    async fn readyz_returns_503_when_ready_missing() {
2427        let check: ReadinessCheck =
2428            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
2429        let resp = readyz(check).await.into_response();
2430        // Missing "ready" field defaults to false -> 503
2431        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2432    }
2433
2434    // -- origin_check_middleware --
2435
2436    /// Build a test router with origin check middleware and a simple handler.
2437    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
2438        let allowed: Arc<[String]> = Arc::from(origins);
2439        axum::Router::new()
2440            .route("/test", axum::routing::get(|| async { "ok" }))
2441            .layer(axum::middleware::from_fn(move |req, next| {
2442                let a = Arc::clone(&allowed);
2443                origin_check_middleware(a, log_request_headers, req, next)
2444            }))
2445    }
2446
2447    #[tokio::test]
2448    async fn origin_allowed_passes() {
2449        let app = origin_router(vec!["http://localhost:3000".into()], false);
2450        let req = Request::builder()
2451            .uri("/test")
2452            .header(header::ORIGIN, "http://localhost:3000")
2453            .body(Body::empty())
2454            .unwrap();
2455        let resp = app.oneshot(req).await.unwrap();
2456        assert_eq!(resp.status(), StatusCode::OK);
2457    }
2458
2459    #[tokio::test]
2460    async fn origin_rejected_returns_403() {
2461        let app = origin_router(vec!["http://localhost:3000".into()], false);
2462        let req = Request::builder()
2463            .uri("/test")
2464            .header(header::ORIGIN, "http://evil.com")
2465            .body(Body::empty())
2466            .unwrap();
2467        let resp = app.oneshot(req).await.unwrap();
2468        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2469    }
2470
2471    #[tokio::test]
2472    async fn no_origin_header_passes() {
2473        let app = origin_router(vec!["http://localhost:3000".into()], false);
2474        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2475        let resp = app.oneshot(req).await.unwrap();
2476        assert_eq!(resp.status(), StatusCode::OK);
2477    }
2478
2479    #[tokio::test]
2480    async fn empty_allowlist_rejects_any_origin() {
2481        let app = origin_router(vec![], false);
2482        let req = Request::builder()
2483            .uri("/test")
2484            .header(header::ORIGIN, "http://anything.com")
2485            .body(Body::empty())
2486            .unwrap();
2487        let resp = app.oneshot(req).await.unwrap();
2488        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2489    }
2490
2491    #[tokio::test]
2492    async fn empty_allowlist_passes_without_origin() {
2493        let app = origin_router(vec![], false);
2494        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2495        let resp = app.oneshot(req).await.unwrap();
2496        assert_eq!(resp.status(), StatusCode::OK);
2497    }
2498
2499    #[test]
2500    fn format_request_headers_redacts_sensitive_values() {
2501        let mut headers = axum::http::HeaderMap::new();
2502        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
2503        headers.insert("cookie", "sid=abc".parse().unwrap());
2504        headers.insert("x-request-id", "req-123".parse().unwrap());
2505
2506        let out = format_request_headers_for_log(&headers);
2507        assert!(out.contains("authorization: [REDACTED]"));
2508        assert!(out.contains("cookie: [REDACTED]"));
2509        assert!(out.contains("x-request-id: req-123"));
2510        assert!(!out.contains("secret-token"));
2511    }
2512
2513    // -- security_headers_middleware --
2514
2515    fn security_router(is_tls: bool) -> axum::Router {
2516        axum::Router::new()
2517            .route("/test", axum::routing::get(|| async { "ok" }))
2518            .layer(axum::middleware::from_fn(move |req, next| {
2519                security_headers_middleware(is_tls, req, next)
2520            }))
2521    }
2522
2523    #[tokio::test]
2524    async fn security_headers_set_on_response() {
2525        let app = security_router(false);
2526        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2527        let resp = app.oneshot(req).await.unwrap();
2528        assert_eq!(resp.status(), StatusCode::OK);
2529
2530        let h = resp.headers();
2531        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
2532        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
2533        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
2534        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
2535        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
2536        assert_eq!(
2537            h.get("cross-origin-resource-policy").unwrap(),
2538            "same-origin"
2539        );
2540        assert_eq!(
2541            h.get("cross-origin-embedder-policy").unwrap(),
2542            "require-corp"
2543        );
2544        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
2545        assert!(
2546            h.get("permissions-policy")
2547                .unwrap()
2548                .to_str()
2549                .unwrap()
2550                .contains("camera=()"),
2551            "permissions-policy must restrict browser features"
2552        );
2553        assert_eq!(
2554            h.get("content-security-policy").unwrap(),
2555            "default-src 'none'; frame-ancestors 'none'"
2556        );
2557        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
2558        // No HSTS when TLS is off.
2559        assert!(h.get("strict-transport-security").is_none());
2560    }
2561
2562    #[tokio::test]
2563    async fn hsts_set_when_tls_enabled() {
2564        let app = security_router(true);
2565        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2566        let resp = app.oneshot(req).await.unwrap();
2567
2568        let hsts = resp.headers().get("strict-transport-security").unwrap();
2569        assert!(
2570            hsts.to_str().unwrap().contains("max-age=63072000"),
2571            "HSTS must set 2-year max-age"
2572        );
2573    }
2574
2575    // -- version endpoint --
2576
2577    #[test]
2578    fn version_payload_contains_expected_fields() {
2579        let v = version_payload("my-server", "1.2.3");
2580        assert_eq!(v["name"], "my-server");
2581        assert_eq!(v["version"], "1.2.3");
2582        assert!(v["build_git_sha"].is_string());
2583        assert!(v["build_timestamp"].is_string());
2584        assert!(v["rust_version"].is_string());
2585        assert!(v["mcpx_version"].is_string());
2586    }
2587
2588    // -- concurrency limit layer --
2589
2590    #[tokio::test]
2591    async fn concurrency_limit_layer_composes_and_serves() {
2592        // We only assert the layer stack compiles and a single request
2593        // below the cap still succeeds. True back-pressure behaviour
2594        // requires a live HTTP server and is covered by integration tests.
2595        let app = axum::Router::new()
2596            .route("/ok", axum::routing::get(|| async { "ok" }))
2597            .layer(
2598                tower::ServiceBuilder::new()
2599                    .layer(axum::error_handling::HandleErrorLayer::new(
2600                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
2601                    ))
2602                    .layer(tower::load_shed::LoadShedLayer::new())
2603                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
2604            );
2605        let resp = app
2606            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
2607            .await
2608            .unwrap();
2609        assert_eq!(resp.status(), StatusCode::OK);
2610    }
2611
2612    // -- compression layer --
2613
2614    #[tokio::test]
2615    async fn compression_layer_gzip_encodes_response() {
2616        use tower_http::compression::Predicate as _;
2617
2618        let big_body = "a".repeat(4096);
2619        let app = axum::Router::new()
2620            .route(
2621                "/big",
2622                axum::routing::get(move || {
2623                    let body = big_body.clone();
2624                    async move { body }
2625                }),
2626            )
2627            .layer(
2628                tower_http::compression::CompressionLayer::new()
2629                    .gzip(true)
2630                    .br(true)
2631                    .compress_when(
2632                        tower_http::compression::DefaultPredicate::new()
2633                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
2634                    ),
2635            );
2636
2637        let req = Request::builder()
2638            .uri("/big")
2639            .header(header::ACCEPT_ENCODING, "gzip")
2640            .body(Body::empty())
2641            .unwrap();
2642        let resp = app.oneshot(req).await.unwrap();
2643        assert_eq!(resp.status(), StatusCode::OK);
2644        assert_eq!(
2645            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
2646            "gzip"
2647        );
2648    }
2649}