Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13    ServerHandler,
14    transport::streamable_http_server::{
15        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16    },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23    auth::{
24        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25        build_rate_limiter, extract_mtls_identity,
26    },
27    error::McpxError,
28    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
33/// at the public API boundary, flattening the chain via the alternate
34/// formatter so callers see the full causal path.
35#[allow(
36    clippy::needless_pass_by_value,
37    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40    McpxError::Startup(format!("{e:#}"))
41}
42
43/// Map a `std::io::Error` produced during server startup into a public
44/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
45/// `From` impl here because startup-phase IO errors (bind, listener) are
46/// semantically distinct from request-time IO errors and should surface
47/// the originating operation in the message.
48#[allow(
49    clippy::needless_pass_by_value,
50    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53    McpxError::Startup(format!("{op}: {e}"))
54}
55
56/// Async readiness check callback for the `/readyz` endpoint.
57///
58/// Returns a JSON object with at least a `"ready"` boolean.
59/// When `ready` is false, the endpoint returns HTTP 503.
60pub type ReadinessCheck =
61    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
/// Overrides for the individual OWASP security headers that the global
/// response middleware attaches to every response.
///
/// Every field uses the same three-way encoding:
///
/// | Value         | Effect                                                   |
/// |---------------|----------------------------------------------------------|
/// | `None`        | Keep the built-in default value.                         |
/// | `Some("")`    | Drop the header from responses altogether.               |
/// | `Some(value)` | Send `header: value` (checked when the config is loaded).|
///
/// Non-empty values are checked with
/// [`axum::http::HeaderValue::from_str`] as part of
/// [`McpServerConfig::validate`], so a malformed value stops the server
/// from starting instead of surfacing at request time.
///
/// `Strict-Transport-Security` gets one extra rule: any occurrence of
/// `preload` (compared case-insensitively) is refused. Opting into the
/// HSTS preload list is meant to go through a dedicated builder method
/// in the future, not through this override.
///
/// # Example
///
/// ```
/// use rmcp_server_kit::transport::SecurityHeadersConfig;
///
/// // The struct is `#[non_exhaustive]`: start from `Default` and assign
/// // only the fields you want to change.
/// let mut headers = SecurityHeadersConfig::default();
/// headers.x_frame_options = Some("sameorigin".to_owned());
/// headers.x_dns_prefetch_control = Some(String::new()); // omit the header
/// ```
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` override (built-in value: `nosniff`).
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` override (built-in value: `deny`).
    pub x_frame_options: Option<String>,
    /// `Cache-Control` override (built-in value: `no-store, max-age=0`).
    pub cache_control: Option<String>,
    /// `Referrer-Policy` override (built-in value: `no-referrer`).
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` override (built-in value: `same-origin`).
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` override (built-in value: `same-origin`).
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` override (built-in value: `require-corp`).
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` override (built-in value:
    /// `accelerometer=(), camera=(), geolocation=(), microphone=()`).
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` override (built-in value: `none`).
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` override (built-in value:
    /// `default-src 'none'; frame-ancestors 'none'`).
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` override (built-in value: `off`).
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` override (built-in value, TLS only:
    /// `max-age=63072000; includeSubDomains`). The header is sent only
    /// when TLS is active, so this override has no effect on plaintext
    /// deployments. Values containing `preload` (any case) are refused
    /// by the validator.
    pub strict_transport_security: Option<String>,
}
116
117/// Configuration for the MCP server.
118#[allow(
119    missing_debug_implementations,
120    reason = "contains callback/trait objects that don't impl Debug"
121)]
122#[allow(
123    clippy::struct_excessive_bools,
124    reason = "server configuration naturally has many boolean feature flags"
125)]
126#[non_exhaustive]
127pub struct McpServerConfig {
128    /// Socket address the MCP HTTP server binds to.
129    #[deprecated(
130        since = "0.13.0",
131        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in 1.0"
132    )]
133    pub bind_addr: String,
134    /// Server name advertised via MCP `initialize`.
135    #[deprecated(
136        since = "0.13.0",
137        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
138    )]
139    pub name: String,
140    /// Server version advertised via MCP `initialize`.
141    #[deprecated(
142        since = "0.13.0",
143        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in 1.0"
144    )]
145    pub version: String,
146    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
147    #[deprecated(
148        since = "0.13.0",
149        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
150    )]
151    pub tls_cert_path: Option<PathBuf>,
152    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
153    #[deprecated(
154        since = "0.13.0",
155        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in 1.0"
156    )]
157    pub tls_key_path: Option<PathBuf>,
158    /// Optional authentication config. When `Some` and `enabled`, auth
159    /// is enforced on `/mcp`. `/healthz` is always open.
160    #[deprecated(
161        since = "0.13.0",
162        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in 1.0"
163    )]
164    pub auth: Option<AuthConfig>,
165    /// Optional RBAC policy. When present and enabled, tool calls are
166    /// checked against the policy after authentication.
167    #[deprecated(
168        since = "0.13.0",
169        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in 1.0"
170    )]
171    pub rbac: Option<Arc<RbacPolicy>>,
172    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
173    /// When empty and `public_url` is set, the origin is auto-derived from
174    /// the public URL. When both are empty, only requests with no Origin
175    /// header are accepted.
176    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
177    #[deprecated(
178        since = "0.13.0",
179        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in 1.0"
180    )]
181    pub allowed_origins: Vec<String>,
182    /// Maximum tool invocations per source IP per minute.
183    /// When set, enforced on every `tools/call` request.
184    #[deprecated(
185        since = "0.13.0",
186        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in 1.0"
187    )]
188    pub tool_rate_limit: Option<u32>,
189    /// Optional readiness probe for `/readyz`.
190    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
191    #[deprecated(
192        since = "0.13.0",
193        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in 1.0"
194    )]
195    pub readiness_check: Option<ReadinessCheck>,
196    /// Maximum request body size in bytes. Default: 1 MiB.
197    /// Protects against oversized payloads causing OOM.
198    #[deprecated(
199        since = "0.13.0",
200        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in 1.0"
201    )]
202    pub max_request_body: usize,
203    /// Request processing timeout. Default: 120s.
204    /// Requests exceeding this duration receive 408 Request Timeout.
205    #[deprecated(
206        since = "0.13.0",
207        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in 1.0"
208    )]
209    pub request_timeout: Duration,
210    /// Graceful shutdown timeout. Default: 30s.
211    /// After the shutdown signal, in-flight requests have this long to finish.
212    #[deprecated(
213        since = "0.13.0",
214        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in 1.0"
215    )]
216    pub shutdown_timeout: Duration,
217    /// Idle timeout for MCP sessions. Sessions with no activity for this
218    /// duration are closed automatically. Default: 20 minutes.
219    #[deprecated(
220        since = "0.13.0",
221        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in 1.0"
222    )]
223    pub session_idle_timeout: Duration,
224    /// Interval for SSE keep-alive pings. Prevents proxies and load
225    /// balancers from killing idle connections. Default: 15 seconds.
226    #[deprecated(
227        since = "0.13.0",
228        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in 1.0"
229    )]
230    pub sse_keep_alive: Duration,
231    /// Callback invoked once the server is built, delivering a
232    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
233    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
234    #[deprecated(
235        since = "0.13.0",
236        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in 1.0"
237    )]
238    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
239    /// Additional application-specific routes merged into the top-level
240    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
241    /// so the application is responsible for its own auth on them.
242    #[deprecated(
243        since = "0.13.0",
244        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in 1.0"
245    )]
246    pub extra_router: Option<axum::Router>,
247    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
248    /// When set, OAuth metadata endpoints advertise this URL instead of
249    /// the listen address. Required when binding `0.0.0.0` behind a
250    /// reverse proxy or inside a container.
251    #[deprecated(
252        since = "0.13.0",
253        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in 1.0"
254    )]
255    pub public_url: Option<String>,
256    /// Log inbound HTTP request headers at DEBUG level.
257    /// Sensitive values remain redacted.
258    #[deprecated(
259        since = "0.13.0",
260        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in 1.0"
261    )]
262    pub log_request_headers: bool,
263    /// Enable gzip/br response compression on MCP responses.
264    /// Defaults to `false` to preserve existing behaviour.
265    #[deprecated(
266        since = "0.13.0",
267        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
268    )]
269    pub compression_enabled: bool,
270    /// Minimum response body size (in bytes) before compression kicks in.
271    /// Only used when `compression_enabled` is true. Default: 1024.
272    #[deprecated(
273        since = "0.13.0",
274        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in 1.0"
275    )]
276    pub compression_min_size: u16,
277    /// Global cap on in-flight HTTP requests across the whole server.
278    /// When `Some`, requests over the cap receive 503 Service Unavailable
279    /// via `tower::load_shed`. Default: `None` (unlimited).
280    #[deprecated(
281        since = "0.13.0",
282        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in 1.0"
283    )]
284    pub max_concurrent_requests: Option<usize>,
285    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
286    /// configured and `enabled`. Default: `false`.
287    #[deprecated(
288        since = "0.13.0",
289        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
290    )]
291    pub admin_enabled: bool,
292    /// RBAC role required to access admin endpoints. Default: `"admin"`.
293    #[deprecated(
294        since = "0.13.0",
295        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in 1.0"
296    )]
297    pub admin_role: String,
298    /// Enable Prometheus metrics endpoint on a separate listener.
299    /// Requires the `metrics` crate feature.
300    #[cfg(feature = "metrics")]
301    #[deprecated(
302        since = "0.13.0",
303        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
304    )]
305    pub metrics_enabled: bool,
306    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
307    #[cfg(feature = "metrics")]
308    #[deprecated(
309        since = "0.13.0",
310        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in 1.0"
311    )]
312    pub metrics_bind: String,
313    /// Per-header overrides for the OWASP security headers emitted by
314    /// the global response middleware. See [`SecurityHeadersConfig`]
315    /// for the three-state semantic and validation rules.
316    #[deprecated(
317        since = "1.5.0",
318        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in 1.0"
319    )]
320    pub security_headers: SecurityHeadersConfig,
321}
322
/// Proof-carrying wrapper: a `Validated<T>` exists only once `T` has
/// passed its validation checks.
///
/// For [`McpServerConfig`], the sole constructor is
/// [`McpServerConfig::validate`], and [`serve`] / [`serve_with_listener`]
/// accept nothing else. Skipping validation is therefore a type error
/// rather than a runtime surprise. The wrapped field is private, so code
/// outside this module cannot forge the wrapper by hand.
///
/// Reads go through `Deref` (or [`Validated::as_inner`]). To change the
/// value, take it back out with [`Validated::into_inner`], edit it, and
/// validate it again.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Passing an unvalidated config does not compile:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);
385
386impl<T> std::fmt::Debug for Validated<T> {
387    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388        f.debug_struct("Validated").finish_non_exhaustive()
389    }
390}
391
392impl<T> Validated<T> {
393    /// Borrow the inner value.
394    #[must_use]
395    pub fn as_inner(&self) -> &T {
396        &self.0
397    }
398
399    /// Recover the raw value, discarding the validation proof.
400    ///
401    /// Re-validate before re-using the value with [`serve`] or
402    /// [`serve_with_listener`].
403    #[must_use]
404    pub fn into_inner(self) -> T {
405        self.0
406    }
407}
408
409impl<T> std::ops::Deref for Validated<T> {
410    type Target = T;
411
412    fn deref(&self) -> &T {
413        &self.0
414    }
415}
416
417#[allow(
418    deprecated,
419    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
420)]
421impl McpServerConfig {
422    /// Create a new server configuration with the given bind address,
423    /// server name, and version. All other fields use safe defaults.
424    ///
425    /// Use the chainable `with_*` / `enable_*` builder methods to
426    /// customize. Call [`McpServerConfig::validate`] to obtain a
427    /// [`Validated<McpServerConfig>`] proof token, which is required by
428    /// [`serve`] and [`serve_with_listener`].
429    #[must_use]
430    pub fn new(
431        bind_addr: impl Into<String>,
432        name: impl Into<String>,
433        version: impl Into<String>,
434    ) -> Self {
435        Self {
436            bind_addr: bind_addr.into(),
437            name: name.into(),
438            version: version.into(),
439            tls_cert_path: None,
440            tls_key_path: None,
441            auth: None,
442            rbac: None,
443            allowed_origins: Vec::new(),
444            tool_rate_limit: None,
445            readiness_check: None,
446            max_request_body: 1024 * 1024,
447            request_timeout: Duration::from_mins(2),
448            shutdown_timeout: Duration::from_secs(30),
449            session_idle_timeout: Duration::from_mins(20),
450            sse_keep_alive: Duration::from_secs(15),
451            on_reload_ready: None,
452            extra_router: None,
453            public_url: None,
454            log_request_headers: false,
455            compression_enabled: false,
456            compression_min_size: 1024,
457            max_concurrent_requests: None,
458            admin_enabled: false,
459            admin_role: "admin".to_owned(),
460            #[cfg(feature = "metrics")]
461            metrics_enabled: false,
462            #[cfg(feature = "metrics")]
463            metrics_bind: "127.0.0.1:9090".into(),
464            security_headers: SecurityHeadersConfig::default(),
465        }
466    }
467
468    // ---------------------------------------------------------------
469    // Builder methods (fluent, consume + return self).
470    //
471    // Each method is `#[must_use]` because dropping the returned
472    // `McpServerConfig` discards the configuration change.
473    // ---------------------------------------------------------------
474
475    /// Attach an authentication configuration. Required for
476    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
477    #[must_use]
478    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
479        self.auth = Some(auth);
480        self
481    }
482
483    /// Override one or more of the OWASP security headers emitted on
484    /// every response. See [`SecurityHeadersConfig`] for the three-state
485    /// semantic (`None` = default, `Some("")` = omit, `Some(v)` =
486    /// override). Values are validated by [`Self::validate`].
487    #[must_use]
488    pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
489        self.security_headers = headers;
490        self
491    }
492
493    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
494    /// final port is only known after pre-binding an ephemeral listener
495    /// (tests, dynamic-port deployments).
496    #[must_use]
497    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
498        self.bind_addr = addr.into();
499        self
500    }
501
502    /// Attach an RBAC policy. Tool calls are checked against the policy
503    /// after authentication.
504    #[must_use]
505    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
506        self.rbac = Some(rbac);
507        self
508    }
509
510    /// Configure TLS by providing the certificate and private key paths
511    /// (PEM). Both must be readable at startup. Without this call, the
512    /// server runs plain HTTP.
513    #[must_use]
514    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
515        self.tls_cert_path = Some(cert_path.into());
516        self.tls_key_path = Some(key_path.into());
517        self
518    }
519
520    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
521    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
522    /// a container so OAuth metadata and auto-derived origins resolve correctly.
523    #[must_use]
524    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
525        self.public_url = Some(url.into());
526        self
527    }
528
529    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
530    /// When empty and [`with_public_url`](Self::with_public_url) is set,
531    /// the origin is auto-derived.
532    #[must_use]
533    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
534    where
535        I: IntoIterator<Item = S>,
536        S: Into<String>,
537    {
538        self.allowed_origins = origins.into_iter().map(Into::into).collect();
539        self
540    }
541
542    /// Merge an additional axum router at the top level. Routes added
543    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
544    /// for its own protection.
545    #[must_use]
546    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
547        self.extra_router = Some(router);
548        self
549    }
550
551    /// Install an async readiness probe for `/readyz`. Without this call,
552    /// `/readyz` mirrors `/healthz` (always 200 OK).
553    #[must_use]
554    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
555        self.readiness_check = Some(check);
556        self
557    }
558
559    /// Override the maximum request body (bytes). Must be `> 0`.
560    /// Default: 1 MiB.
561    #[must_use]
562    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
563        self.max_request_body = bytes;
564        self
565    }
566
567    /// Override the per-request processing timeout. Default: 2 minutes.
568    #[must_use]
569    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
570        self.request_timeout = timeout;
571        self
572    }
573
574    /// Override the graceful shutdown grace period. Default: 30 seconds.
575    #[must_use]
576    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
577        self.shutdown_timeout = timeout;
578        self
579    }
580
581    /// Override the MCP session idle timeout. Default: 20 minutes.
582    #[must_use]
583    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
584        self.session_idle_timeout = timeout;
585        self
586    }
587
588    /// Override the SSE keep-alive interval. Default: 15 seconds.
589    #[must_use]
590    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
591        self.sse_keep_alive = interval;
592        self
593    }
594
595    /// Cap the global number of in-flight HTTP requests via
596    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
597    /// Default: unlimited.
598    #[must_use]
599    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
600        self.max_concurrent_requests = Some(limit);
601        self
602    }
603
604    /// Cap tool invocations per source IP per minute. Enforced on every
605    /// `tools/call` request.
606    #[must_use]
607    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
608        self.tool_rate_limit = Some(per_minute);
609        self
610    }
611
612    /// Register a callback that receives the [`ReloadHandle`] after the
613    /// server is built. Use it to wire SIGHUP-style hot reloads of API
614    /// keys and RBAC policy.
615    #[must_use]
616    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
617    where
618        F: FnOnce(ReloadHandle) + Send + 'static,
619    {
620        self.on_reload_ready = Some(Box::new(callback));
621        self
622    }
623
624    /// Enable gzip/brotli response compression on MCP responses.
625    /// `min_size` is the smallest body size (bytes) eligible for
626    /// compression. Default min size: 1024.
627    #[must_use]
628    pub fn enable_compression(mut self, min_size: u16) -> Self {
629        self.compression_enabled = true;
630        self.compression_min_size = min_size;
631        self
632    }
633
634    /// Enable `/admin/*` diagnostic endpoints. Requires
635    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
636    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
637    /// role gate (default: `"admin"`).
638    #[must_use]
639    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
640        self.admin_enabled = true;
641        self.admin_role = role.into();
642        self
643    }
644
645    /// Log inbound HTTP request headers at DEBUG level. Sensitive
646    /// values remain redacted by the logging layer.
647    #[must_use]
648    pub fn enable_request_header_logging(mut self) -> Self {
649        self.log_request_headers = true;
650        self
651    }
652
653    /// Enable the Prometheus metrics listener on `bind` (e.g.
654    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
655    #[cfg(feature = "metrics")]
656    #[must_use]
657    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
658        self.metrics_enabled = true;
659        self.metrics_bind = bind.into();
660        self
661    }
662
663    /// Validate the configuration and consume `self`, returning a
664    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
665    /// and [`serve_with_listener`]. This is the only way to construct
666    /// `Validated<McpServerConfig>`, so the type system guarantees
667    /// validation has run before the server starts.
668    ///
669    /// Checks:
670    ///
671    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
672    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
673    ///    be unset.
674    /// 3. `bind_addr` must parse as a [`SocketAddr`].
675    /// 4. `public_url`, when set, must start with `http://` or `https://`.
676    /// 5. Each entry in `allowed_origins` must start with `http://` or
677    ///    `https://`.
678    /// 6. `max_request_body` must be greater than zero.
679    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
680    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
681    ///    `proxy.token_url`, `proxy.introspection_url`,
682    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
683    ///    and use the `https` scheme. Set
684    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
685    ///    targets (strongly discouraged in production - see the field-level
686    ///    docs for the threat model).
687    ///
688    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
689    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
690    ///
691    /// # Errors
692    ///
693    /// Returns [`McpxError::Config`] with a human-readable message on
694    /// the first validation failure.
695    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
696        self.check()?;
697        Ok(Validated(self))
698    }
699
700    /// Run the validation checks without consuming `self`. Used by
701    /// internal call sites (e.g. tests) that need to inspect a config
702    /// without taking ownership.
703    fn check(&self) -> Result<(), McpxError> {
704        // 1. admin <-> auth dependency. Mirrors the runtime check in
705        //    `build_app_router`: admin endpoints require an auth state,
706        //    which is built only when `auth` is `Some` *and* `enabled`.
707        if self.admin_enabled {
708            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
709            if !auth_enabled {
710                return Err(McpxError::Config(
711                    "admin_enabled=true requires auth to be configured and enabled".into(),
712                ));
713            }
714        }
715
716        // 2. TLS cert / key must be paired
717        match (&self.tls_cert_path, &self.tls_key_path) {
718            (Some(_), None) => {
719                return Err(McpxError::Config(
720                    "tls_cert_path is set but tls_key_path is missing".into(),
721                ));
722            }
723            (None, Some(_)) => {
724                return Err(McpxError::Config(
725                    "tls_key_path is set but tls_cert_path is missing".into(),
726                ));
727            }
728            _ => {}
729        }
730
731        // 3. bind_addr parses
732        if self.bind_addr.parse::<SocketAddr>().is_err() {
733            return Err(McpxError::Config(format!(
734                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
735                self.bind_addr
736            )));
737        }
738
739        // 4. public_url scheme
740        if let Some(ref url) = self.public_url
741            && !(url.starts_with("http://") || url.starts_with("https://"))
742        {
743            return Err(McpxError::Config(format!(
744                "public_url {url:?} must start with http:// or https://"
745            )));
746        }
747
748        // 5. allowed_origins scheme
749        for origin in &self.allowed_origins {
750            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
751                return Err(McpxError::Config(format!(
752                    "allowed_origins entry {origin:?} must start with http:// or https://"
753                )));
754            }
755        }
756
757        // 6. max_request_body > 0
758        if self.max_request_body == 0 {
759            return Err(McpxError::Config(
760                "max_request_body must be greater than zero".into(),
761            ));
762        }
763
764        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
765        #[cfg(feature = "oauth")]
766        if let Some(auth_cfg) = &self.auth
767            && let Some(oauth_cfg) = &auth_cfg.oauth
768        {
769            oauth_cfg.validate()?;
770        }
771
772        // 8. Security-header overrides parse as valid HTTP header values,
773        //    and HSTS does not smuggle in a `preload` directive.
774        validate_security_headers(&self.security_headers)?;
775
776        Ok(())
777    }
778}
779
780/// Handle for hot-reloading server configuration without restart.
781///
782/// Obtained via [`McpServerConfig::on_reload_ready`].
783/// All swap operations are lock-free and wait-free -- in-flight requests
784/// finish with the old values while new requests see the update immediately.
785#[allow(
786    missing_debug_implementations,
787    reason = "contains Arc<AuthState> with non-Debug fields"
788)]
789pub struct ReloadHandle {
790    auth: Option<Arc<AuthState>>,
791    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
792    crl_set: Option<Arc<CrlSet>>,
793}
794
795impl ReloadHandle {
796    /// Atomically replace the API key list used by the auth middleware.
797    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
798        if let Some(ref auth) = self.auth {
799            auth.reload_keys(keys);
800        }
801    }
802
803    /// Atomically replace the RBAC policy used by the RBAC middleware.
804    pub fn reload_rbac(&self, policy: RbacPolicy) {
805        if let Some(ref rbac) = self.rbac {
806            rbac.store(Arc::new(policy));
807            tracing::info!("RBAC policy reloaded");
808        }
809    }
810
811    /// Force an immediate refresh of all cached mTLS CRLs.
812    ///
813    /// # Errors
814    ///
815    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
816    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
817        let Some(ref crl_set) = self.crl_set else {
818            return Err(McpxError::Config(
819                "CRL refresh requested but mTLS CRL support is not configured".into(),
820            ));
821        };
822
823        crl_set.force_refresh().await
824    }
825}
826
827/// Generic MCP HTTP server.
828///
829/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
830/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
831/// with TLS (rustls). Optionally supports mTLS client certificate auth.
832///
833/// # Errors
834///
835/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
836/// or the server fails.
837// TODO(refactor): cognitive complexity reduced from 111/25 to 83/25 by
838// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
839// Remaining flow is a linear router builder: middleware layering, feature-
840// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
841// would require threading many `&mut Router` helpers and hurt readability
842// of the layer order (which is security-relevant and must stay visible).
843#[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
844/// Internal bundle of values produced by [`build_app_router`] and
845/// consumed by [`serve`] / [`serve_with_listener`] when driving the
846/// HTTP listener.
847struct AppRunParams {
848    /// TLS cert/key paths when TLS is configured.
849    tls_paths: Option<(PathBuf, PathBuf)>,
850    /// mTLS configuration when mutual-TLS auth is enabled.
851    mtls_config: Option<MtlsConfig>,
852    /// Graceful shutdown drain window.
853    shutdown_timeout: Duration,
854    /// Shared auth state used by hot-reload callbacks.
855    auth_state: Option<Arc<AuthState>>,
856    /// Hot-reloadable RBAC state used by reload callbacks.
857    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
858    /// Optional callback that receives the final [`ReloadHandle`].
859    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
860    /// Server-internal cancellation token. Cancelled by [`run_server`]
861    /// once the shutdown trigger fires (so rmcp's child token also
862    /// fires, terminating in-flight MCP sessions).
863    ct: CancellationToken,
864    /// `"http"` or `"https"` -- used only for boot-time logging.
865    scheme: &'static str,
866    /// Server name -- used only for boot-time logging.
867    name: String,
868}
869
870/// Build the full application axum [`axum::Router`] (MCP route +
871/// middleware stack + admin + OAuth + health endpoints + security
872/// headers + CORS + compression + concurrency limit + origin check)
873/// and the [`AppRunParams`] needed to drive it.
874///
875/// This is the shared core of [`serve`] and [`serve_with_listener`].
876/// It performs *no* network I/O: callers are responsible for binding
877/// (or accepting a pre-bound) [`TcpListener`] and invoking
878/// [`run_server`].
879#[allow(
880    clippy::cognitive_complexity,
881    reason = "router assembly is intrinsically sequential; splitting harms readability"
882)]
883#[allow(
884    deprecated,
885    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
886)]
887fn build_app_router<H, F>(
888    mut config: McpServerConfig,
889    handler_factory: F,
890) -> anyhow::Result<(axum::Router, AppRunParams)>
891where
892    H: ServerHandler + 'static,
893    F: Fn() -> H + Send + Sync + Clone + 'static,
894{
895    let ct = CancellationToken::new();
896
897    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
898    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
899
900    let mcp_service = StreamableHttpService::new(
901        move || Ok(handler_factory()),
902        {
903            let mut mgr = LocalSessionManager::default();
904            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
905            mgr.into()
906        },
907        StreamableHttpServerConfig::default()
908            .with_allowed_hosts(allowed_hosts)
909            .with_sse_keep_alive(Some(config.sse_keep_alive))
910            .with_cancellation_token(ct.child_token()),
911    );
912
913    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
914    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
915
916    // Build auth state eagerly when auth is configured so we can wire both
917    // the auth middleware *and* the optional admin router against the same
918    // state. The middleware itself is installed further down in layer order.
919    let auth_state: Option<Arc<AuthState>> = match config.auth {
920        Some(ref auth_config) if auth_config.enabled => {
921            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
922            let pre_auth_limiter = auth_config
923                .rate_limit
924                .as_ref()
925                .map(crate::auth::build_pre_auth_limiter);
926
927            #[cfg(feature = "oauth")]
928            let jwks_cache = auth_config
929                .oauth
930                .as_ref()
931                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
932                .transpose()
933                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
934
935            Some(Arc::new(AuthState {
936                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
937                rate_limiter,
938                pre_auth_limiter,
939                #[cfg(feature = "oauth")]
940                jwks_cache,
941                seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
942                counters: crate::auth::AuthCounters::default(),
943            }))
944        }
945        _ => None,
946    };
947
948    // Build the RBAC policy swap early so the admin router and the later
949    // RBAC middleware layer share the same hot-reloadable state.
950    let rbac_swap = Arc::new(ArcSwap::new(
951        config
952            .rbac
953            .clone()
954            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
955    ));
956
957    // Optional /admin/* diagnostic routes. Merged BEFORE the
958    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
959    if config.admin_enabled {
960        let Some(ref auth_state_ref) = auth_state else {
961            return Err(anyhow::anyhow!(
962                "admin_enabled=true requires auth to be configured and enabled"
963            ));
964        };
965        let admin_state = crate::admin::AdminState {
966            started_at: std::time::Instant::now(),
967            name: config.name.clone(),
968            version: config.version.clone(),
969            auth: Some(Arc::clone(auth_state_ref)),
970            rbac: Arc::clone(&rbac_swap),
971        };
972        let admin_cfg = crate::admin::AdminConfig {
973            role: config.admin_role.clone(),
974        };
975        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
976        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
977    }
978
979    // ----- Middleware order (CRITICAL: read carefully) ------------------
980    //
981    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
982    // added is the OUTERMOST (runs first on a request). To achieve a
983    // request-time flow of:
984    //
985    //   body-limit -> timeout -> auth -> rbac -> handler
986    //
987    // we add layers in the REVERSE order:
988    //
989    //   1. RBAC               (innermost, runs last before handler)
990    //   2. auth               (parses identity, sets extension for RBAC)
991    //   3. timeout            (bounds total request time)
992    //   4. body-limit         (outermost on /mcp; caps payload before
993    //                          anything else reads/buffers it)
994    //
995    // Origin validation is installed on the OUTER router (after the
996    // /mcp router is merged in), so it also protects /healthz, /readyz,
997    // /version, and any OAuth proxy endpoints.
998    //
999    // Rationale:
1000    // - Body-limit must be outermost on /mcp so RBAC (which reads the
1001    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
1002    // - Auth must run before RBAC because RBAC consumes
1003    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
1004    //   policy.
1005    // - Origin runs before auth so we reject cross-origin requests
1006    //   without spending Argon2 cycles on unauthenticated callers.
1007
1008    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
1009    // Always installed: even when RBAC is disabled, tool rate limiting may
1010    // be active (MCP spec: servers MUST rate limit tool invocations).
1011    {
1012        let tool_limiter: Option<Arc<ToolRateLimiter>> =
1013            config.tool_rate_limit.map(build_tool_rate_limiter);
1014
1015        if rbac_swap.load().is_enabled() {
1016            tracing::info!("RBAC enforcement enabled on /mcp");
1017        }
1018        if let Some(limit) = config.tool_rate_limit {
1019            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1020        }
1021
1022        let rbac_for_mw = Arc::clone(&rbac_swap);
1023        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1024            let p = rbac_for_mw.load_full();
1025            let tl = tool_limiter.clone();
1026            rbac_middleware(p, tl, req, next)
1027        }));
1028    }
1029
1030    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
1031    if let Some(ref auth_config) = config.auth
1032        && auth_config.enabled
1033    {
1034        let Some(ref state) = auth_state else {
1035            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1036        };
1037
1038        let methods: Vec<&str> = [
1039            auth_config.mtls.is_some().then_some("mTLS"),
1040            (!auth_config.api_keys.is_empty()).then_some("bearer"),
1041            #[cfg(feature = "oauth")]
1042            auth_config.oauth.is_some().then_some("oauth-jwt"),
1043        ]
1044        .into_iter()
1045        .flatten()
1046        .collect();
1047
1048        tracing::info!(
1049            methods = %methods.join(", "),
1050            api_keys = auth_config.api_keys.len(),
1051            "auth enabled on /mcp"
1052        );
1053
1054        let state_for_mw = Arc::clone(state);
1055        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1056            let s = Arc::clone(&state_for_mw);
1057            auth_middleware(s, req, next)
1058        }));
1059    }
1060
1061    // [3] Request timeout (returns 408 on expiry). Bounds total request
1062    // duration including auth + handler.
1063    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1064        axum::http::StatusCode::REQUEST_TIMEOUT,
1065        config.request_timeout,
1066    ));
1067
1068    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
1069    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
1070    // attempts to buffer or parse the body.
1071    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1072        config.max_request_body,
1073    ));
1074
1075    // Compute the effective allowed-origins list for the outer
1076    // origin-check layer (installed on the merged router below). When
1077    // `allowed_origins` is empty but `public_url` is set, auto-derive
1078    // the origin from the public URL so MCP clients (e.g. Claude Code)
1079    // that send `Origin: <server-url>` are accepted without explicit
1080    // config.
1081    let mut effective_origins = config.allowed_origins.clone();
1082    if effective_origins.is_empty()
1083        && let Some(ref url) = config.public_url
1084    {
1085        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1086        // Strip any path/query from the public URL.
1087        if let Some(scheme_end) = url.find("://") {
1088            let after_scheme = &url[scheme_end + 3..];
1089            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1090            let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
1091            tracing::info!(
1092                %origin,
1093                "auto-derived allowed origin from public_url"
1094            );
1095            effective_origins.push(origin);
1096        }
1097    }
1098    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1099    let cors_origins = Arc::clone(&allowed_origins);
1100    let log_request_headers = config.log_request_headers;
1101
1102    let readyz_route = if let Some(check) = config.readiness_check.take() {
1103        axum::routing::get(move || readyz(Arc::clone(&check)))
1104    } else {
1105        axum::routing::get(healthz)
1106    };
1107
1108    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1109    let mut router = axum::Router::new()
1110        .route("/healthz", axum::routing::get(healthz))
1111        .route("/readyz", readyz_route)
1112        .route(
1113            "/version",
1114            axum::routing::get({
1115                // Pre-serialize the version payload once at router-build
1116                // time. The handler then serves a cheap `Arc::clone` of the
1117                // immutable bytes per request, avoiding `serde_json::Value`
1118                // allocation + serialization on every `/version` hit.
1119                let payload_bytes: Arc<[u8]> =
1120                    serialize_version_payload(&config.name, &config.version);
1121                move || {
1122                    let p = Arc::clone(&payload_bytes);
1123                    async move {
1124                        (
1125                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1126                            p.to_vec(),
1127                        )
1128                    }
1129                }
1130            }),
1131        )
1132        .merge(mcp_router);
1133
1134    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1135    if let Some(extra) = config.extra_router.take() {
1136        router = router.merge(extra);
1137    }
1138
1139    // RFC 9728: Protected Resource Metadata endpoint.
1140    // When OAuth is configured, serve full metadata with authorization_servers.
1141    // Otherwise, serve a minimal document with just the resource URL and no
1142    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1143    // that the server exists but does NOT require OAuth authentication,
1144    // preventing them from gating the connection behind a broken auth flow.
1145    let server_url = if let Some(ref url) = config.public_url {
1146        url.trim_end_matches('/').to_owned()
1147    } else {
1148        let prm_scheme = if config.tls_cert_path.is_some() {
1149            "https"
1150        } else {
1151            "http"
1152        };
1153        format!("{prm_scheme}://{}", config.bind_addr)
1154    };
1155    let resource_url = format!("{server_url}/mcp");
1156
1157    #[cfg(feature = "oauth")]
1158    let prm_metadata = if let Some(ref auth_config) = config.auth
1159        && let Some(ref oauth_config) = auth_config.oauth
1160    {
1161        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1162    } else {
1163        serde_json::json!({ "resource": resource_url })
1164    };
1165    #[cfg(not(feature = "oauth"))]
1166    let prm_metadata = serde_json::json!({ "resource": resource_url });
1167
1168    router = router.route(
1169        "/.well-known/oauth-protected-resource",
1170        axum::routing::get(move || {
1171            let m = prm_metadata.clone();
1172            async move { axum::Json(m) }
1173        }),
1174    );
1175
1176    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1177    // /authorize, /token, /register, and authorization server metadata so
1178    // MCP clients can perform Authorization Code + PKCE against the upstream
1179    // IdP (e.g. Keycloak) transparently.
1180    #[cfg(feature = "oauth")]
1181    if let Some(ref auth_config) = config.auth
1182        && let Some(ref oauth_config) = auth_config.oauth
1183        && oauth_config.proxy.is_some()
1184    {
1185        router =
1186            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1187    }
1188
1189    // OWASP security response headers (applied to all responses).
1190    // HSTS is conditional on TLS being configured.
1191    let is_tls = config.tls_cert_path.is_some();
1192    let security_headers_cfg = Arc::new(config.security_headers.clone());
1193    router = router.layer(axum::middleware::from_fn(move |req, next| {
1194        let cfg = Arc::clone(&security_headers_cfg);
1195        security_headers_middleware(is_tls, cfg, req, next)
1196    }));
1197
1198    // CORS preflight layer (required for browser-based MCP clients).
1199    // Uses the same effective origins as the origin check middleware
1200    // (including auto-derived origin from public_url).
1201    if !cors_origins.is_empty() {
1202        let cors = tower_http::cors::CorsLayer::new()
1203            .allow_origin(
1204                cors_origins
1205                    .iter()
1206                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1207                    .collect::<Vec<_>>(),
1208            )
1209            .allow_methods([
1210                axum::http::Method::GET,
1211                axum::http::Method::POST,
1212                axum::http::Method::OPTIONS,
1213            ])
1214            .allow_headers([
1215                axum::http::header::CONTENT_TYPE,
1216                axum::http::header::AUTHORIZATION,
1217            ]);
1218        router = router.layer(cors);
1219    }
1220
1221    // Optional response compression (gzip + brotli). Skips small bodies
1222    // to avoid overhead. Applied after CORS so preflight responses remain
1223    // uncompressed.
1224    if config.compression_enabled {
1225        use tower_http::compression::Predicate as _;
1226        let predicate = tower_http::compression::DefaultPredicate::new().and(
1227            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1228        );
1229        router = router.layer(
1230            tower_http::compression::CompressionLayer::new()
1231                .gzip(true)
1232                .br(true)
1233                .compress_when(predicate),
1234        );
1235        tracing::info!(
1236            min_size = config.compression_min_size,
1237            "response compression enabled (gzip, br)"
1238        );
1239    }
1240
1241    // Optional global concurrency cap. `load_shed` converts the
1242    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1243    if let Some(max) = config.max_concurrent_requests {
1244        let overload_handler = tower::ServiceBuilder::new()
1245            .layer(axum::error_handling::HandleErrorLayer::new(
1246                |_err: tower::BoxError| async {
1247                    (
1248                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1249                        axum::Json(serde_json::json!({
1250                            "error": "overloaded",
1251                            "error_description": "server is at capacity, retry later"
1252                        })),
1253                    )
1254                },
1255            ))
1256            .layer(tower::load_shed::LoadShedLayer::new())
1257            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1258        router = router.layer(overload_handler);
1259        tracing::info!(max, "global concurrency limit enabled");
1260    }
1261
1262    // JSON fallback for unmatched routes. Without this, axum returns
1263    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1264    // when they probe OAuth endpoints like /authorize or /token.
1265    router = router.fallback(|| async {
1266        (
1267            axum::http::StatusCode::NOT_FOUND,
1268            axum::Json(serde_json::json!({
1269                "error": "not_found",
1270                "error_description": "The requested endpoint does not exist"
1271            })),
1272        )
1273    });
1274
1275    // Prometheus metrics: recording middleware + separate listener.
1276    #[cfg(feature = "metrics")]
1277    if config.metrics_enabled {
1278        let metrics = Arc::new(
1279            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1280        );
1281        let m = Arc::clone(&metrics);
1282        router = router.layer(axum::middleware::from_fn(
1283            move |req: Request<Body>, next: Next| {
1284                let m = Arc::clone(&m);
1285                metrics_middleware(m, req, next)
1286            },
1287        ));
1288        let metrics_bind = config.metrics_bind.clone();
1289        tokio::spawn(async move {
1290            if let Err(e) = crate::metrics::serve_metrics(metrics_bind, metrics).await {
1291                tracing::error!("metrics listener failed: {e}");
1292            }
1293        });
1294    }
1295
1296    // Origin validation layer (MCP spec: servers MUST validate the
1297    // Origin header to prevent DNS rebinding attacks). Installed as the
1298    // OUTERMOST layer on the OUTER router so it protects ALL routes
1299    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1300    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1301    // reject cross-origin attackers without spending Argon2 cycles.
1302    //
1303    // Origin-less requests (e.g. server-to-server probes, curl, native
1304    // MCP clients) are permitted; only requests with an Origin header
1305    // that does not match `effective_origins` are rejected.
1306    router = router.layer(axum::middleware::from_fn(move |req, next| {
1307        let origins = Arc::clone(&allowed_origins);
1308        origin_check_middleware(origins, log_request_headers, req, next)
1309    }));
1310
1311    let scheme = if config.tls_cert_path.is_some() {
1312        "https"
1313    } else {
1314        "http"
1315    };
1316
1317    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1318        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1319        _ => None,
1320    };
1321    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1322
1323    Ok((
1324        router,
1325        AppRunParams {
1326            tls_paths,
1327            mtls_config,
1328            shutdown_timeout: config.shutdown_timeout,
1329            auth_state,
1330            rbac_swap,
1331            on_reload_ready: config.on_reload_ready.take(),
1332            ct,
1333            scheme,
1334            name: config.name.clone(),
1335        },
1336    ))
1337}
1338
1339/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1340/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1341///
1342/// This is the standard entry point for production deployments. For
1343/// deterministic shutdown control (e.g. integration tests), see
1344/// [`serve_with_listener`].
1345///
1346/// The configuration must be validated first via
1347/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1348/// token. This typestate guarantees, at compile time, that the server
1349/// never starts with an invalid configuration.
1350///
1351/// # Errors
1352///
1353/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1354/// fails, or if the underlying axum server returns an error.
1355pub async fn serve<H, F>(
1356    config: Validated<McpServerConfig>,
1357    handler_factory: F,
1358) -> Result<(), McpxError>
1359where
1360    H: ServerHandler + 'static,
1361    F: Fn() -> H + Send + Sync + Clone + 'static,
1362{
1363    let config = config.into_inner();
1364    #[allow(
1365        deprecated,
1366        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1367    )]
1368    let bind_addr = config.bind_addr.clone();
1369    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1370
1371    let listener = TcpListener::bind(&bind_addr)
1372        .await
1373        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1374    log_listening(&params.name, params.scheme, &bind_addr);
1375
1376    run_server(
1377        router,
1378        listener,
1379        params.tls_paths,
1380        params.mtls_config,
1381        params.shutdown_timeout,
1382        params.auth_state,
1383        params.rbac_swap,
1384        params.on_reload_ready,
1385        params.ct,
1386    )
1387    .await
1388    .map_err(anyhow_to_startup)
1389}
1390
1391/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1392/// readiness signalling and external shutdown control.
1393///
1394/// This variant is intended for **deterministic integration tests** and
1395/// for embedders that need to bind the listening socket themselves
1396/// (e.g. systemd socket activation). Compared to [`serve`]:
1397///
1398/// * The caller passes a `TcpListener` that is already bound. This
1399///   eliminates the bind race in tests that previously required
1400///   poll-the-`/healthz`-loop start-up detection.
1401/// * `ready_tx`, when `Some`, receives the socket's
1402///   [`SocketAddr`] *after* the router is built and immediately before
1403///   the server starts accepting connections. Tests can `await` the
1404///   matching `oneshot::Receiver` to know exactly when it is safe to
1405///   issue requests.
1406/// * `shutdown`, when `Some`, gives the caller a
1407///   [`CancellationToken`] that triggers the same graceful-shutdown
1408///   path as a real OS signal. This avoids cross-platform issues with
1409///   sending real `SIGTERM` from tests on Windows.
1410///
1411/// All three optional parameters degrade gracefully: if `ready_tx` is
1412/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1413/// stops on an OS signal (just like [`serve`]).
1414///
1415/// # Errors
1416///
1417/// Returns [`McpxError::Startup`] if router construction fails, if reading
1418/// the listener's `local_addr()` fails, or if the underlying axum
1419/// server returns an error.
1420pub async fn serve_with_listener<H, F>(
1421    listener: TcpListener,
1422    config: Validated<McpServerConfig>,
1423    handler_factory: F,
1424    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1425    shutdown: Option<CancellationToken>,
1426) -> Result<(), McpxError>
1427where
1428    H: ServerHandler + 'static,
1429    F: Fn() -> H + Send + Sync + Clone + 'static,
1430{
1431    let config = config.into_inner();
1432    let local_addr = listener
1433        .local_addr()
1434        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1435    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1436
1437    log_listening(&params.name, params.scheme, &local_addr.to_string());
1438
1439    // Forward external shutdown into the server-internal cancellation
1440    // token so `run_server`'s shutdown trigger picks it up alongside
1441    // any real OS signal.
1442    if let Some(external) = shutdown {
1443        let internal = params.ct.clone();
1444        tokio::spawn(async move {
1445            external.cancelled().await;
1446            internal.cancel();
1447        });
1448    }
1449
1450    // Signal readiness *after* the router is fully built and external
1451    // shutdown is wired, but *before* run_server takes ownership of
1452    // the listener. The receiver can immediately issue requests.
1453    if let Some(tx) = ready_tx {
1454        // Receiver may have been dropped (test gave up). That's fine.
1455        let _ = tx.send(local_addr);
1456    }
1457
1458    run_server(
1459        router,
1460        listener,
1461        params.tls_paths,
1462        params.mtls_config,
1463        params.shutdown_timeout,
1464        params.auth_state,
1465        params.rbac_swap,
1466        params.on_reload_ready,
1467        params.ct,
1468    )
1469    .await
1470    .map_err(anyhow_to_startup)
1471}
1472
1473/// Emit the standard "listening on …" log lines used by both
1474/// [`serve`] and [`serve_with_listener`].
1475#[allow(
1476    clippy::cognitive_complexity,
1477    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1478)]
1479fn log_listening(name: &str, scheme: &str, addr: &str) {
1480    tracing::info!("{name} listening on {addr}");
1481    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
1482    tracing::info!("  Health check: {scheme}://{addr}/healthz");
1483    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
1484}
1485
1486/// Drive the chosen axum server variant (TLS or plain) with a graceful
1487/// shutdown window. Consumes the router and listener.
1488///
1489/// # Shutdown semantics
1490///
1491/// A single shutdown trigger (the FIRST of: OS signal via
1492/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
1493///
1494/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
1495///    accepting new connections and waits for in-flight requests to
1496///    drain;
1497/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
1498///    drainage exceeds `shutdown_timeout`.
1499///
1500/// Previously this function awaited `shutdown_signal()` independently
1501/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
1502/// resolves once per future and consumes one signal, the force-exit
1503/// timer was tied to a SECOND signal (a second SIGTERM the operator
1504/// would never send). Under a single SIGTERM the graceful drain could
1505/// hang indefinitely. The current implementation derives both branches
1506/// from a single shared trigger so the timeout race is anchored to the
1507/// FIRST (and only) signal.
1508#[allow(
1509    clippy::too_many_arguments,
1510    clippy::cognitive_complexity,
1511    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1512)]
1513async fn run_server(
1514    router: axum::Router,
1515    listener: TcpListener,
1516    tls_paths: Option<(PathBuf, PathBuf)>,
1517    mtls_config: Option<MtlsConfig>,
1518    shutdown_timeout: Duration,
1519    auth_state: Option<Arc<AuthState>>,
1520    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1521    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1522    ct: CancellationToken,
1523) -> anyhow::Result<()> {
1524    // `shutdown_trigger` fires when the FIRST source resolves: either
1525    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
1526    // (which the test harness uses for deterministic shutdown).
1527    let shutdown_trigger = CancellationToken::new();
1528    {
1529        let trigger = shutdown_trigger.clone();
1530        let parent = ct.clone();
1531        tokio::spawn(async move {
1532            tokio::select! {
1533                () = shutdown_signal() => {}
1534                () = parent.cancelled() => {}
1535            }
1536            trigger.cancel();
1537        });
1538    }
1539
1540    let graceful = {
1541        let trigger = shutdown_trigger.clone();
1542        let ct = ct.clone();
1543        async move {
1544            trigger.cancelled().await;
1545            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1546            ct.cancel();
1547        }
1548    };
1549
1550    let force_exit_timer = {
1551        let trigger = shutdown_trigger.clone();
1552        async move {
1553            trigger.cancelled().await;
1554            tokio::time::sleep(shutdown_timeout).await;
1555        }
1556    };
1557
1558    if let Some((cert_path, key_path)) = tls_paths {
1559        let crl_set = if let Some(mtls) = mtls_config.as_ref()
1560            && mtls.crl_enabled
1561        {
1562            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1563            let (crl_set, discover_rx) =
1564                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1565                    .await
1566                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1567            tokio::spawn(mtls_revocation::run_crl_refresher(
1568                Arc::clone(&crl_set),
1569                discover_rx,
1570                ct.clone(),
1571            ));
1572            Some(crl_set)
1573        } else {
1574            None
1575        };
1576
1577        if let Some(cb) = on_reload_ready.take() {
1578            cb(ReloadHandle {
1579                auth: auth_state.clone(),
1580                rbac: Some(Arc::clone(&rbac_swap)),
1581                crl_set: crl_set.clone(),
1582            });
1583        }
1584
1585        let tls_listener = TlsListener::new(
1586            listener,
1587            &cert_path,
1588            &key_path,
1589            mtls_config.as_ref(),
1590            crl_set,
1591        )?;
1592        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1593        tokio::select! {
1594            result = axum::serve(tls_listener, make_svc)
1595                .with_graceful_shutdown(graceful) => { result?; }
1596            () = force_exit_timer => {
1597                tracing::warn!("shutdown timeout exceeded, forcing exit");
1598            }
1599        }
1600    } else {
1601        if let Some(cb) = on_reload_ready.take() {
1602            cb(ReloadHandle {
1603                auth: auth_state,
1604                rbac: Some(rbac_swap),
1605                crl_set: None,
1606            });
1607        }
1608
1609        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1610        tokio::select! {
1611            result = axum::serve(listener, make_svc)
1612                .with_graceful_shutdown(graceful) => { result?; }
1613            () = force_exit_timer => {
1614                tracing::warn!("shutdown timeout exceeded, forcing exit");
1615            }
1616        }
1617    }
1618
1619    Ok(())
1620}
1621
/// Install the OAuth 2.1 proxy endpoints on `router`.
///
/// The endpoints are `/authorize`, `/token`, `/register`, and
/// `/.well-known/oauth-authorization-server`. The optional `/introspect`
/// and `/revoke` admin endpoints are also mounted when
/// `proxy.expose_admin_endpoints` is set and the matching upstream URL is
/// configured. If `oauth_config.proxy` is `None`, `router` is returned
/// unchanged.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] in either of these cases:
/// - the shared [`crate::oauth::OauthHttpClient`] cannot be initialized;
/// - the admin endpoints require authentication but `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    // Single shared HTTP client for all proxy endpoints. Cloning is
    // cheap (refcounted) and shares the underlying connection pool.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // Warn only when an admin route is actually mounted. If
    // `expose_admin_endpoints` is set but neither upstream URL is configured,
    // nothing is exposed and the warning would be a false alarm.
    if admin_routes_enabled && !proxy.require_auth_on_admin_endpoints {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated; consider setting require_auth_on_admin_endpoints = true"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1714
/// Build the optional `/introspect` + `/revoke` admin sub-router.
///
/// Only the routes whose upstream URL is configured are mounted. Every
/// mounted route is wrapped in [`oauth_token_cache_headers_middleware`],
/// which emits the RFC 6749 §5.1 / RFC 6750 §5.4 cache headers. When
/// `proxy.require_auth_on_admin_endpoints` is set, the whole sub-router is
/// additionally wrapped in the auth middleware.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if authentication is required on the admin
/// endpoints but no `auth_state` was supplied.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    // Resolve the auth requirement first. Assembling routes has no side
    // effects, so failing before or after assembly is indistinguishable.
    let required_auth = match (proxy.require_auth_on_admin_endpoints, auth_state) {
        (false, _) => None,
        (true, Some(state)) => Some(Arc::clone(state)),
        (true, None) => {
            return Err(McpxError::Startup(
                "oauth proxy admin endpoints require auth state".into(),
            ));
        }
    };

    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let introspect_cfg = proxy.clone();
        let introspect_client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let cfg = introspect_cfg.clone();
                let client = introspect_client.clone();
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let revoke_cfg = proxy.clone();
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let cfg = revoke_cfg.clone();
                let client = http.clone();
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    Ok(match required_auth {
        Some(state) => routes.layer(axum::middleware::from_fn(move |req, next| {
            auth_middleware(Arc::clone(&state), req, next)
        })),
        None => routes,
    })
}
1773
1774/// Build the host allow-list for rmcp's DNS rebinding protection.
1775///
1776/// Includes loopback hosts by default, then augments with host/authority
1777/// derived from `public_url` and the server bind address.
1778fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1779    let mut hosts = vec![
1780        "localhost".to_owned(),
1781        "127.0.0.1".to_owned(),
1782        "::1".to_owned(),
1783    ];
1784
1785    if let Some(url) = public_url
1786        && let Ok(uri) = url.parse::<axum::http::Uri>()
1787        && let Some(authority) = uri.authority()
1788    {
1789        let host = authority.host().to_owned();
1790        if !hosts.iter().any(|h| h == &host) {
1791            hosts.push(host);
1792        }
1793
1794        let authority = authority.as_str().to_owned();
1795        if !hosts.iter().any(|h| h == &authority) {
1796            hosts.push(authority);
1797        }
1798    }
1799
1800    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1801        && let Some(authority) = uri.authority()
1802    {
1803        let host = authority.host().to_owned();
1804        if !hosts.iter().any(|h| h == &host) {
1805            hosts.push(host);
1806        }
1807
1808        let authority = authority.as_str().to_owned();
1809        if !hosts.iter().any(|h| h == &authority) {
1810            hosts.push(authority);
1811        }
1812    }
1813
1814    hosts
1815}
1816
1817// - TLS support -
1818
1819/// Implement axum's `Connected` trait for `TlsConnInfo` so that
1820/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
1821/// over our custom `TlsListener`.
1822///
1823/// The identity is read directly from the wrapping
1824/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
1825/// between the TLS connection and its mTLS identity. This eliminates the
1826/// previous shared-map approach which was vulnerable to ephemeral-port
1827/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
1828/// pair could alias a stale entry).
1829impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1830    for TlsConnInfo
1831{
1832    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1833        let addr = *target.remote_addr();
1834        let identity = target.io().identity().cloned();
1835        TlsConnInfo::new(addr, identity)
1836    }
1837}
1838
1839/// A TLS-wrapping listener that implements axum's `Listener` trait.
1840///
1841/// When mTLS is configured, verifies client certificates against the
1842/// configured CA and extracts the client identity at handshake time.
1843/// The extracted identity is bound to the connection itself via the
1844/// returned [`AuthenticatedTlsStream`], so it is impossible for an
1845/// unrelated connection to observe it.
1846struct TlsListener {
1847    inner: TcpListener,
1848    acceptor: tokio_rustls::TlsAcceptor,
1849    mtls_default_role: String,
1850}
1851
1852impl TlsListener {
1853    fn new(
1854        inner: TcpListener,
1855        cert_path: &Path,
1856        key_path: &Path,
1857        mtls_config: Option<&MtlsConfig>,
1858        crl_set: Option<Arc<CrlSet>>,
1859    ) -> anyhow::Result<Self> {
1860        // Install the ring crypto provider (ok to call multiple times).
1861        rustls::crypto::ring::default_provider()
1862            .install_default()
1863            .ok();
1864
1865        let certs = load_certs(cert_path)?;
1866        let key = load_key(key_path)?;
1867
1868        let mtls_default_role;
1869
1870        let tls_config = if let Some(mtls) = mtls_config {
1871            mtls_default_role = mtls.default_role.clone();
1872            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1873            {
1874                let Some(crl_set) = crl_set else {
1875                    return Err(anyhow::anyhow!(
1876                        "mTLS CRL verifier requested but CRL state was not initialized"
1877                    ));
1878                };
1879                Arc::new(DynamicClientCertVerifier::new(crl_set))
1880            } else {
1881                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1882                if mtls.required {
1883                    rustls::server::WebPkiClientVerifier::builder(root_store)
1884                        .build()
1885                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1886                } else {
1887                    rustls::server::WebPkiClientVerifier::builder(root_store)
1888                        .allow_unauthenticated()
1889                        .build()
1890                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1891                }
1892            };
1893
1894            tracing::info!(
1895                ca = %mtls.ca_cert_path.display(),
1896                required = mtls.required,
1897                crl_enabled = mtls.crl_enabled,
1898                "mTLS client auth configured"
1899            );
1900
1901            rustls::ServerConfig::builder_with_protocol_versions(&[
1902                &rustls::version::TLS12,
1903                &rustls::version::TLS13,
1904            ])
1905            .with_client_cert_verifier(verifier)
1906            .with_single_cert(certs, key)?
1907        } else {
1908            mtls_default_role = "viewer".to_owned();
1909            rustls::ServerConfig::builder_with_protocol_versions(&[
1910                &rustls::version::TLS12,
1911                &rustls::version::TLS13,
1912            ])
1913            .with_no_client_auth()
1914            .with_single_cert(certs, key)?
1915        };
1916
1917        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1918        tracing::info!(
1919            "TLS enabled (cert: {}, key: {})",
1920            cert_path.display(),
1921            key_path.display()
1922        );
1923        Ok(Self {
1924            inner,
1925            acceptor,
1926            mtls_default_role,
1927        })
1928    }
1929
1930    /// Extract the mTLS client cert identity from a completed TLS handshake.
1931    /// Returns `None` if no client certificate was presented or if the
1932    /// certificate could not be parsed into an [`AuthIdentity`].
1933    fn extract_handshake_identity(
1934        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1935        default_role: &str,
1936        addr: SocketAddr,
1937    ) -> Option<AuthIdentity> {
1938        let (_, server_conn) = tls_stream.get_ref();
1939        let cert_der = server_conn.peer_certificates()?.first()?;
1940        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1941        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1942        Some(id)
1943    }
1944}
1945
1946/// A TLS stream paired with the mTLS identity extracted at handshake time.
1947///
1948/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
1949/// identity travels with the connection itself. This replaces the previous
1950/// shared `MtlsIdentities` map, eliminating the
1951/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
1952/// reuse and removing the need for an LRU eviction policy.
1953///
1954/// The wrapper is `Unpin` (its inner stream is `Unpin` because
1955/// [`tokio::net::TcpStream`] is `Unpin`), so `AsyncRead`/`AsyncWrite`
1956/// delegation uses safe pin projection via `Pin::new(&mut self.inner)`.
1957pub(crate) struct AuthenticatedTlsStream {
1958    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1959    identity: Option<AuthIdentity>,
1960}
1961
1962impl AuthenticatedTlsStream {
1963    /// Returns the verified mTLS client identity, if any.
1964    #[must_use]
1965    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
1966        self.identity.as_ref()
1967    }
1968}
1969
1970impl std::fmt::Debug for AuthenticatedTlsStream {
1971    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1972        f.debug_struct("AuthenticatedTlsStream")
1973            .field("identity", &self.identity.as_ref().map(|id| &id.name))
1974            .finish_non_exhaustive()
1975    }
1976}
1977
1978impl tokio::io::AsyncRead for AuthenticatedTlsStream {
1979    fn poll_read(
1980        mut self: Pin<&mut Self>,
1981        cx: &mut std::task::Context<'_>,
1982        buf: &mut tokio::io::ReadBuf<'_>,
1983    ) -> std::task::Poll<std::io::Result<()>> {
1984        Pin::new(&mut self.inner).poll_read(cx, buf)
1985    }
1986}
1987
1988impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
1989    fn poll_write(
1990        mut self: Pin<&mut Self>,
1991        cx: &mut std::task::Context<'_>,
1992        buf: &[u8],
1993    ) -> std::task::Poll<std::io::Result<usize>> {
1994        Pin::new(&mut self.inner).poll_write(cx, buf)
1995    }
1996
1997    fn poll_flush(
1998        mut self: Pin<&mut Self>,
1999        cx: &mut std::task::Context<'_>,
2000    ) -> std::task::Poll<std::io::Result<()>> {
2001        Pin::new(&mut self.inner).poll_flush(cx)
2002    }
2003
2004    fn poll_shutdown(
2005        mut self: Pin<&mut Self>,
2006        cx: &mut std::task::Context<'_>,
2007    ) -> std::task::Poll<std::io::Result<()>> {
2008        Pin::new(&mut self.inner).poll_shutdown(cx)
2009    }
2010
2011    fn poll_write_vectored(
2012        mut self: Pin<&mut Self>,
2013        cx: &mut std::task::Context<'_>,
2014        bufs: &[std::io::IoSlice<'_>],
2015    ) -> std::task::Poll<std::io::Result<usize>> {
2016        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2017    }
2018
2019    fn is_write_vectored(&self) -> bool {
2020        self.inner.is_write_vectored()
2021    }
2022}
2023
2024impl axum::serve::Listener for TlsListener {
2025    type Io = AuthenticatedTlsStream;
2026    type Addr = SocketAddr;
2027
2028    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2029        loop {
2030            let (stream, addr) = match self.inner.accept().await {
2031                Ok(pair) => pair,
2032                Err(e) => {
2033                    tracing::debug!("TCP accept error: {e}");
2034                    continue;
2035                }
2036            };
2037            let tls_stream = match self.acceptor.accept(stream).await {
2038                Ok(s) => s,
2039                Err(e) => {
2040                    tracing::debug!("TLS handshake failed from {addr}: {e}");
2041                    continue;
2042                }
2043            };
2044            let identity =
2045                Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
2046            let wrapped = AuthenticatedTlsStream {
2047                inner: tls_stream,
2048                identity,
2049            };
2050            return (wrapped, addr);
2051        }
2052    }
2053
2054    fn local_addr(&self) -> std::io::Result<Self::Addr> {
2055        self.inner.local_addr()
2056    }
2057}
2058
2059fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2060    use rustls::pki_types::pem::PemObject;
2061    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2062        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2063        .collect::<Result<_, _>>()
2064        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2065    anyhow::ensure!(
2066        !certs.is_empty(),
2067        "no certificates found in {}",
2068        path.display()
2069    );
2070    Ok(certs)
2071}
2072
2073fn load_client_auth_roots(
2074    path: &Path,
2075) -> anyhow::Result<(
2076    Vec<rustls::pki_types::CertificateDer<'static>>,
2077    Arc<RootCertStore>,
2078)> {
2079    let ca_certs = load_certs(path)?;
2080    let mut root_store = RootCertStore::empty();
2081    for cert in &ca_certs {
2082        root_store
2083            .add(cert.clone())
2084            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2085    }
2086
2087    Ok((ca_certs, Arc::new(root_store)))
2088}
2089
2090fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2091    use rustls::pki_types::pem::PemObject;
2092    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2093        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2094}
2095
2096#[allow(clippy::unused_async)]
2097async fn healthz() -> impl IntoResponse {
2098    axum::Json(serde_json::json!({
2099        "status": "ok",
2100    }))
2101}
2102
2103/// Build the `/version` JSON payload for a given server name and version.
2104///
2105/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2106/// read at compile time from the `MCPX_BUILD_SHA`, `MCPX_BUILD_TIME`, and
2107/// `MCPX_RUSTC_VERSION` env vars. Unset values resolve to `"unknown"`.
2108fn version_payload(name: &str, version: &str) -> serde_json::Value {
2109    serde_json::json!({
2110        "name": name,
2111        "version": version,
2112        "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or("unknown"),
2113        "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or("unknown"),
2114        "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or("unknown"),
2115        "mcpx_version": env!("CARGO_PKG_VERSION"),
2116    })
2117}
2118
2119/// Pre-serialize the `/version` payload to immutable bytes.
2120///
2121/// This is called once at router-build time so per-request handling can
2122/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2123/// [`serde_json::Value`] on every hit.
2124///
2125/// Serialization of a flat `serde_json::Value` of static-string fields
2126/// cannot fail in practice; the fallback to `b"{}"` exists only to
2127/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2128fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2129    let value = version_payload(name, version);
2130    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2131}
2132
2133async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2134    let status = check().await;
2135    let ready = status
2136        .get("ready")
2137        .and_then(serde_json::Value::as_bool)
2138        .unwrap_or(false);
2139    let code = if ready {
2140        axum::http::StatusCode::OK
2141    } else {
2142        axum::http::StatusCode::SERVICE_UNAVAILABLE
2143    };
2144    (code, axum::Json(status))
2145}
2146
2147/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2148///
2149/// On non-Unix platforms, only SIGINT is handled.
2150async fn shutdown_signal() {
2151    let ctrl_c = tokio::signal::ctrl_c();
2152
2153    #[cfg(unix)]
2154    {
2155        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2156            Ok(mut term) => {
2157                tokio::select! {
2158                    _ = ctrl_c => {}
2159                    _ = term.recv() => {}
2160                }
2161            }
2162            Err(e) => {
2163                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2164                ctrl_c.await.ok();
2165            }
2166        }
2167    }
2168
2169    #[cfg(not(unix))]
2170    {
2171        ctrl_c.await.ok();
2172    }
2173}
2174
2175// -- Origin validation (MCP 2025-11-25 spec, section 2.0.1) --
2176
/// Middleware that records HTTP request metrics.
///
/// After the inner service responds, this middleware does two things:
/// - increments `http_requests_total`, labelled by method, request path and
///   numeric status code;
/// - observes the elapsed time in seconds in
///   `http_request_duration_seconds`, labelled by method and request path.
///
/// NOTE(review): the `path` label is the raw request path, so every distinct
/// URL a client sends creates a new label set. If this layer sees
/// unauthenticated or unmatched requests, consider labelling with the
/// matched route (e.g. axum's `MatchedPath`) to bound metric cardinality.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Capture request attributes up front; `req` is moved into `next.run`.
    let method = req.method().to_string();
    let path = req.uri().path().to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2206
2207/// OWASP security header hardening applied to every response.
2208///
2209/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2210/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2211/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2212/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2213/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2214///
2215/// Each header's value can be customised via [`SecurityHeadersConfig`]
2216/// on [`McpServerConfig`]. See that type for the three-state semantic
2217/// (`None` = default, `Some("")` = omit, `Some(v)` = override).
2218async fn security_headers_middleware(
2219    is_tls: bool,
2220    cfg: Arc<SecurityHeadersConfig>,
2221    req: Request<Body>,
2222    next: Next,
2223) -> axum::response::Response {
2224    use axum::http::{HeaderName, header};
2225
2226    let mut resp = next.run(req).await;
2227    let headers = resp.headers_mut();
2228
2229    // Strip server identity headers to reduce information leakage.
2230    headers.remove(header::SERVER);
2231    headers.remove(HeaderName::from_static("x-powered-by"));
2232
2233    apply_security_header(
2234        headers,
2235        header::X_CONTENT_TYPE_OPTIONS,
2236        cfg.x_content_type_options.as_deref(),
2237        "nosniff",
2238    );
2239    apply_security_header(
2240        headers,
2241        header::X_FRAME_OPTIONS,
2242        cfg.x_frame_options.as_deref(),
2243        "deny",
2244    );
2245    apply_security_header(
2246        headers,
2247        header::CACHE_CONTROL,
2248        cfg.cache_control.as_deref(),
2249        "no-store, max-age=0",
2250    );
2251    apply_security_header(
2252        headers,
2253        header::REFERRER_POLICY,
2254        cfg.referrer_policy.as_deref(),
2255        "no-referrer",
2256    );
2257    apply_security_header(
2258        headers,
2259        HeaderName::from_static("cross-origin-opener-policy"),
2260        cfg.cross_origin_opener_policy.as_deref(),
2261        "same-origin",
2262    );
2263    apply_security_header(
2264        headers,
2265        HeaderName::from_static("cross-origin-resource-policy"),
2266        cfg.cross_origin_resource_policy.as_deref(),
2267        "same-origin",
2268    );
2269    apply_security_header(
2270        headers,
2271        HeaderName::from_static("cross-origin-embedder-policy"),
2272        cfg.cross_origin_embedder_policy.as_deref(),
2273        "require-corp",
2274    );
2275    apply_security_header(
2276        headers,
2277        HeaderName::from_static("permissions-policy"),
2278        cfg.permissions_policy.as_deref(),
2279        "accelerometer=(), camera=(), geolocation=(), microphone=()",
2280    );
2281    apply_security_header(
2282        headers,
2283        HeaderName::from_static("x-permitted-cross-domain-policies"),
2284        cfg.x_permitted_cross_domain_policies.as_deref(),
2285        "none",
2286    );
2287    apply_security_header(
2288        headers,
2289        HeaderName::from_static("content-security-policy"),
2290        cfg.content_security_policy.as_deref(),
2291        "default-src 'none'; frame-ancestors 'none'",
2292    );
2293    apply_security_header(
2294        headers,
2295        HeaderName::from_static("x-dns-prefetch-control"),
2296        cfg.x_dns_prefetch_control.as_deref(),
2297        "off",
2298    );
2299
2300    if is_tls {
2301        apply_security_header(
2302            headers,
2303            header::STRICT_TRANSPORT_SECURITY,
2304            cfg.strict_transport_security.as_deref(),
2305            "max-age=63072000; includeSubDomains",
2306        );
2307    }
2308
2309    resp
2310}
2311
2312/// Set a single security header on the response, honouring the
2313/// three-state override semantic (None = default, Some("") = omit,
2314/// Some(value) = override).
2315///
2316/// Defence-in-depth: if an override value somehow reaches this point
2317/// despite [`validate_security_headers`] having approved it (e.g. a
2318/// runtime mutation on a non-`Validated` field), we log at error level
2319/// and fall back to the static default rather than panicking. The
2320/// `Validated<McpServerConfig>` type makes that path unreachable in
2321/// well-typed code paths.
2322fn apply_security_header(
2323    headers: &mut axum::http::HeaderMap,
2324    name: axum::http::HeaderName,
2325    override_value: Option<&str>,
2326    default: &'static str,
2327) {
2328    use axum::http::HeaderValue;
2329
2330    match override_value {
2331        None => {
2332            headers.insert(name, HeaderValue::from_static(default));
2333        }
2334        Some("") => {
2335            // Operator explicitly opted out of this header.
2336        }
2337        Some(v) => match HeaderValue::from_str(v) {
2338            Ok(hv) => {
2339                headers.insert(name, hv);
2340            }
2341            Err(err) => {
2342                tracing::error!(
2343                    header = %name,
2344                    error = %err,
2345                    "invalid security header override reached middleware; using default"
2346                );
2347                headers.insert(name, HeaderValue::from_static(default));
2348            }
2349        },
2350    }
2351}
2352
2353/// Validate every non-empty entry in a [`SecurityHeadersConfig`].
2354///
2355/// - `None` and `Some("")` are accepted unconditionally (use-default and
2356///   omit, respectively).
2357/// - `Some(v)` is rejected if `axum::http::HeaderValue::from_str(v)` fails.
2358/// - `strict_transport_security` additionally rejects any value
2359///   containing `preload` (case-insensitive). Operators who genuinely
2360///   want to commit to the HSTS preload list must do so via a future
2361///   explicit `with_hsts_preload(true)` builder, not by smuggling
2362///   `preload` through this knob.
2363fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2364    use axum::http::HeaderValue;
2365
2366    let fields: &[(&str, Option<&str>)] = &[
2367        (
2368            "x_content_type_options",
2369            cfg.x_content_type_options.as_deref(),
2370        ),
2371        ("x_frame_options", cfg.x_frame_options.as_deref()),
2372        ("cache_control", cfg.cache_control.as_deref()),
2373        ("referrer_policy", cfg.referrer_policy.as_deref()),
2374        (
2375            "cross_origin_opener_policy",
2376            cfg.cross_origin_opener_policy.as_deref(),
2377        ),
2378        (
2379            "cross_origin_resource_policy",
2380            cfg.cross_origin_resource_policy.as_deref(),
2381        ),
2382        (
2383            "cross_origin_embedder_policy",
2384            cfg.cross_origin_embedder_policy.as_deref(),
2385        ),
2386        ("permissions_policy", cfg.permissions_policy.as_deref()),
2387        (
2388            "x_permitted_cross_domain_policies",
2389            cfg.x_permitted_cross_domain_policies.as_deref(),
2390        ),
2391        (
2392            "content_security_policy",
2393            cfg.content_security_policy.as_deref(),
2394        ),
2395        (
2396            "x_dns_prefetch_control",
2397            cfg.x_dns_prefetch_control.as_deref(),
2398        ),
2399        (
2400            "strict_transport_security",
2401            cfg.strict_transport_security.as_deref(),
2402        ),
2403    ];
2404
2405    for (field, value) in fields {
2406        let Some(v) = value else { continue };
2407        if v.is_empty() {
2408            continue;
2409        }
2410        if let Err(err) = HeaderValue::from_str(v) {
2411            return Err(McpxError::Config(format!(
2412                "invalid security_headers.{field}: {err}"
2413            )));
2414        }
2415    }
2416
2417    if let Some(v) = cfg.strict_transport_security.as_deref()
2418        && !v.is_empty()
2419        && v.to_ascii_lowercase().contains("preload")
2420    {
2421        return Err(McpxError::Config(format!(
2422            "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2423             HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2424        )));
2425    }
2426
2427    Ok(())
2428}
2429
/// Append RFC 6749 §5.1 / RFC 6750 §5.4 cache and `Vary` headers required
/// on OAuth token-issuing responses.
///
/// `Cache-Control: no-store, max-age=0` is normally applied globally by
/// [`security_headers_middleware`]. Operators can, however, override or omit
/// that header through `SecurityHeadersConfig`, so this middleware does not
/// rely on it exclusively. It sets:
///
/// - `Cache-Control: no-store, max-age=0` -- only when the response carries
///   no `Cache-Control` header at this point. RFC 6749 §5.1 requires
///   `no-store` on responses containing tokens. An existing value (from the
///   handler or another layer) is never overwritten here.
/// - `Pragma: no-cache` -- mandated by RFC 6749 §5.1 for HTTP/1.0 caches.
/// - `Vary: Authorization` -- mandated by RFC 6750 §5.4 for endpoints
///   whose response depends on the `Authorization` header.
///
/// Applied only to the OAuth proxy token-class endpoints (`/token`,
/// `/register`, `/introspect`, `/revoke`). `Vary` is appended (not
/// inserted) so any `Vary` value already present (e.g. `Accept-Encoding`
/// from a compression layer, or `Origin` from a CORS layer) is preserved.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut resp = next.run(req).await;
    let headers = resp.headers_mut();
    // Fill the gap left when the global Cache-Control header was omitted,
    // without clobbering a value someone else deliberately set.
    if !headers.contains_key(header::CACHE_CONTROL) {
        headers.insert(
            header::CACHE_CONTROL,
            HeaderValue::from_static("no-store, max-age=0"),
        );
    }
    headers.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
    headers.append(header::VARY, HeaderValue::from_static("Authorization"));
    resp
}
2457
2458/// Per the MCP spec: if the Origin header is present and its value is not in
2459/// the allowed list, respond with 403 Forbidden. Requests without an Origin
2460/// header are allowed through (e.g. non-browser clients like curl, SDKs).
2461async fn origin_check_middleware(
2462    allowed: Arc<[String]>,
2463    log_request_headers: bool,
2464    req: Request<Body>,
2465    next: Next,
2466) -> axum::response::Response {
2467    let method = req.method().clone();
2468    let path = req.uri().path().to_owned();
2469
2470    log_incoming_request(&method, &path, req.headers(), log_request_headers);
2471
2472    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2473        let origin_str = origin.to_str().unwrap_or("");
2474        if !allowed.iter().any(|a| a == origin_str) {
2475            tracing::warn!(
2476                origin = origin_str,
2477                %method,
2478                %path,
2479                allowed = ?&*allowed,
2480                "rejected request: Origin not allowed"
2481            );
2482            return (
2483                axum::http::StatusCode::FORBIDDEN,
2484                "Forbidden: Origin not allowed",
2485            )
2486                .into_response();
2487        }
2488    }
2489    next.run(req).await
2490}
2491
2492/// Emit a DEBUG log for an incoming request, optionally including the full
2493/// (redacted) header set.
2494fn log_incoming_request(
2495    method: &axum::http::Method,
2496    path: &str,
2497    headers: &axum::http::HeaderMap,
2498    log_request_headers: bool,
2499) {
2500    if log_request_headers {
2501        tracing::debug!(
2502            %method,
2503            %path,
2504            headers = %format_request_headers_for_log(headers),
2505            "incoming request"
2506        );
2507    } else {
2508        tracing::debug!(%method, %path, "incoming request");
2509    }
2510}
2511
2512fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2513    headers
2514        .iter()
2515        .map(|(k, v)| {
2516            let name = k.as_str();
2517            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2518                format!("{name}: [REDACTED]")
2519            } else {
2520                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2521            }
2522        })
2523        .collect::<Vec<_>>()
2524        .join(", ")
2525}
2526
2527// -- stdio transport --
2528
2529/// Serve an MCP server over stdin/stdout (stdio transport).
2530///
2531/// # Security warnings
2532///
2533/// - **No authentication**: the parent process has full, unrestricted access.
2534/// - **No RBAC**: all tools are available regardless of policy.
2535/// - **No TLS**: messages travel over OS pipes in plaintext.
2536/// - **Single client**: only the parent process can connect.
2537/// - **No Origin validation**: not applicable to stdio.
2538///
2539/// Use this only when the MCP client spawns the server as a trusted subprocess
2540/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
2541/// use `serve()` (Streamable HTTP) instead.
2542///
2543/// # Errors
2544///
2545/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
2546/// transport disconnects unexpectedly.
2547// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
2548// macro expansion in this 18-line function (info/warn/info + two matches).
2549// There is nothing meaningful to extract; the allow stays.
2550#[allow(clippy::cognitive_complexity)]
2551pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2552where
2553    H: ServerHandler + 'static,
2554{
2555    use rmcp::ServiceExt as _;
2556
2557    tracing::info!("stdio transport: serving on stdin/stdout");
2558    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2559
2560    let transport = rmcp::transport::io::stdio();
2561
2562    let service = handler
2563        .serve(transport)
2564        .await
2565        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2566
2567    if let Err(e) = service.waiting().await {
2568        tracing::warn!(error = %e, "stdio session ended with error");
2569    }
2570    tracing::info!("stdio session ended");
2571    Ok(())
2572}
2573
// Unit tests for the HTTP transport helpers: config defaults/validation,
// health endpoints, Origin and security-header middleware, OAuth token
// cache headers, and the tower layer stack.
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr,
        deprecated,
        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
    )]
    use std::sync::Arc;

    use axum::{
        body::Body,
        http::{Request, StatusCode, header},
        response::IntoResponse,
    };
    use http_body_util::BodyExt;
    use tower::ServiceExt as _;

    use super::*;

    // -- McpServerConfig --

    #[test]
    fn server_config_new_defaults() {
        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
        assert_eq!(cfg.name, "test-server");
        assert_eq!(cfg.version, "1.0.0");
        assert!(cfg.tls_cert_path.is_none());
        assert!(cfg.tls_key_path.is_none());
        assert!(cfg.auth.is_none());
        assert!(cfg.rbac.is_none());
        assert!(cfg.allowed_origins.is_empty());
        assert!(cfg.tool_rate_limit.is_none());
        assert!(cfg.readiness_check.is_none());
        assert_eq!(cfg.max_request_body, 1024 * 1024);
        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
        assert!(!cfg.log_request_headers);
    }

    #[test]
    fn validate_consumes_and_proves() {
        // Valid config -> Validated wrapper, original is consumed.
        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        let validated = cfg.validate().expect("valid config");
        // Deref gives read-only access to inner fields.
        assert_eq!(validated.name, "test-server");
        // into_inner recovers the raw value.
        let raw = validated.into_inner();
        assert_eq!(raw.name, "test-server");

        // Invalid config (zero max_request_body) -> Err.
        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
        bad.max_request_body = 0;
        assert!(bad.validate().is_err(), "zero body cap must fail validate");
    }

    #[test]
    fn derive_allowed_hosts_includes_public_host() {
        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
        assert!(
            hosts.iter().any(|h| h == "mcp.example.com"),
            "public_url host must be allowed"
        );
    }

    #[test]
    fn derive_allowed_hosts_includes_bind_authority() {
        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1"),
            "bind host must be allowed"
        );
        assert!(
            hosts.iter().any(|h| h == "127.0.0.1:8080"),
            "bind authority must be allowed"
        );
    }

    // -- healthz --

    #[tokio::test]
    async fn healthz_returns_ok_json() {
        let resp = healthz().await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "ok");
        assert!(
            json.get("name").is_none(),
            "healthz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "healthz must not expose version"
        );
    }

    // -- readyz --

    #[tokio::test]
    async fn readyz_returns_ok_when_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["ready"], true);
        assert!(
            json.get("name").is_none(),
            "readyz must not expose server name"
        );
        assert!(
            json.get("version").is_none(),
            "readyz must not expose version"
        );
        assert_eq!(json["db"], "connected");
    }

    #[tokio::test]
    async fn readyz_returns_503_when_not_ready() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
        let resp = readyz(check).await.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn readyz_returns_503_when_ready_missing() {
        let check: ReadinessCheck =
            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
        let resp = readyz(check).await.into_response();
        // Missing "ready" field defaults to false -> 503
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    // -- origin_check_middleware --

    /// Build a test router with origin check middleware and a simple handler.
    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
        let allowed: Arc<[String]> = Arc::from(origins);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let a = Arc::clone(&allowed);
                origin_check_middleware(a, log_request_headers, req, next)
            }))
    }

    #[tokio::test]
    async fn origin_allowed_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://localhost:3000")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_rejected_returns_403() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn no_origin_header_passes() {
        let app = origin_router(vec!["http://localhost:3000".into()], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn empty_allowlist_rejects_any_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder()
            .uri("/test")
            .header(header::ORIGIN, "http://anything.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn empty_allowlist_passes_without_origin() {
        let app = origin_router(vec![], false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[test]
    fn format_request_headers_redacts_sensitive_values() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
        headers.insert("cookie", "sid=abc".parse().unwrap());
        headers.insert("proxy-authorization", "Basic proxy-secret".parse().unwrap());
        headers.insert("x-request-id", "req-123".parse().unwrap());

        let out = format_request_headers_for_log(&headers);
        assert!(out.contains("authorization: [REDACTED]"));
        assert!(out.contains("cookie: [REDACTED]"));
        assert!(out.contains("proxy-authorization: [REDACTED]"));
        assert!(out.contains("x-request-id: req-123"));
        // None of the sensitive values may leak into the log line.
        assert!(!out.contains("secret-token"));
        assert!(!out.contains("sid=abc"));
        assert!(!out.contains("proxy-secret"));
    }

    // -- security_headers_middleware --

    /// Router with the security-header middleware using the default config.
    fn security_router(is_tls: bool) -> axum::Router {
        security_router_with(is_tls, SecurityHeadersConfig::default())
    }

    /// Router with the security-header middleware using a caller-supplied config.
    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
        let cfg = Arc::new(cfg);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let c = Arc::clone(&cfg);
                security_headers_middleware(is_tls, c, req, next)
            }))
    }

    #[tokio::test]
    async fn security_headers_set_on_response() {
        let app = security_router(false);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let h = resp.headers();
        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
        assert_eq!(
            h.get("cross-origin-resource-policy").unwrap(),
            "same-origin"
        );
        assert_eq!(
            h.get("cross-origin-embedder-policy").unwrap(),
            "require-corp"
        );
        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
        assert!(
            h.get("permissions-policy")
                .unwrap()
                .to_str()
                .unwrap()
                .contains("camera=()"),
            "permissions-policy must restrict browser features"
        );
        assert_eq!(
            h.get("content-security-policy").unwrap(),
            "default-src 'none'; frame-ancestors 'none'"
        );
        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
        // No HSTS when TLS is off.
        assert!(h.get("strict-transport-security").is_none());
    }

    #[tokio::test]
    async fn hsts_set_when_tls_enabled() {
        let app = security_router(true);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let hsts = resp.headers().get("strict-transport-security").unwrap();
        assert!(
            hsts.to_str().unwrap().contains("max-age=63072000"),
            "HSTS must set 2-year max-age"
        );
    }

    // -- SecurityHeadersConfig validation + override semantics --

    /// Build a minimal config with a custom SecurityHeadersConfig and
    /// drive it through `check()`. Returns the result so individual
    /// tests can assert on success or specific error messages.
    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
        let cfg =
            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
        cfg.check()
    }

    #[test]
    fn security_headers_config_default_validates() {
        check_with_security_headers(SecurityHeadersConfig::default())
            .expect("default SecurityHeadersConfig must validate");
    }

    #[test]
    fn security_headers_config_validate_accepts_empty_string() {
        // All twelve fields explicitly set to "" -> omit-everything mode.
        let h = SecurityHeadersConfig {
            x_content_type_options: Some(String::new()),
            x_frame_options: Some(String::new()),
            cache_control: Some(String::new()),
            referrer_policy: Some(String::new()),
            cross_origin_opener_policy: Some(String::new()),
            cross_origin_resource_policy: Some(String::new()),
            cross_origin_embedder_policy: Some(String::new()),
            permissions_policy: Some(String::new()),
            x_permitted_cross_domain_policies: Some(String::new()),
            content_security_policy: Some(String::new()),
            x_dns_prefetch_control: Some(String::new()),
            strict_transport_security: Some(String::new()),
        };
        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
    }

    #[test]
    fn security_headers_config_validate_rejects_bad_value() {
        // 0x07 (BEL) is not a valid HTTP header value char.
        let h = SecurityHeadersConfig {
            referrer_policy: Some("\u{0007}".into()),
            ..SecurityHeadersConfig::default()
        };
        let err = check_with_security_headers(h)
            .expect_err("control char in referrer_policy must reject");
        let msg = err.to_string();
        assert!(
            msg.contains("referrer_policy"),
            "error must name the offending field, got: {msg}"
        );
    }

    #[test]
    fn security_headers_config_validate_rejects_hsts_preload() {
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
            ..SecurityHeadersConfig::default()
        };
        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
        let msg = err.to_string();
        assert!(
            msg.contains("strict_transport_security"),
            "error must name the field, got: {msg}"
        );
        assert!(
            msg.to_lowercase().contains("preload"),
            "error must mention `preload`, got: {msg}"
        );
    }

    #[test]
    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
        // Case-insensitive match.
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=600; PRELOAD".into()),
            ..SecurityHeadersConfig::default()
        };
        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
    }

    #[tokio::test]
    async fn security_headers_override_honored() {
        // Override X-Frame-Options to SAMEORIGIN.
        let h = SecurityHeadersConfig {
            x_frame_options: Some("SAMEORIGIN".into()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let xfo = resp.headers().get("x-frame-options").unwrap();
        assert_eq!(xfo, "SAMEORIGIN");
    }

    #[tokio::test]
    async fn security_headers_empty_string_omits() {
        // Empty string on referrer-policy -> header absent.
        let h = SecurityHeadersConfig {
            referrer_policy: Some(String::new()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        assert!(
            resp.headers().get("referrer-policy").is_none(),
            "Some(\"\") must omit the header"
        );
        // Other defaults should still be present.
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn security_headers_hsts_only_when_tls() {
        // HSTS override is irrelevant when TLS is off.
        let h = SecurityHeadersConfig {
            strict_transport_security: Some("max-age=600".into()),
            ..SecurityHeadersConfig::default()
        };
        let app = security_router_with(false, h);
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert!(
            resp.headers().get("strict-transport-security").is_none(),
            "HSTS must remain absent on plaintext deployments even with override"
        );
    }

    // -- oauth_token_cache_headers_middleware --

    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn oauth_token_cache_headers_set_pragma_and_vary() {
        let app = axum::Router::new()
            .route("/token", axum::routing::post(|| async { "{}" }))
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            ));
        let req = Request::builder()
            .method("POST")
            .uri("/token")
            .body(Body::from("{}"))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let h = resp.headers();
        assert_eq!(
            h.get("pragma").unwrap(),
            "no-cache",
            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
        );
        let vary_values: Vec<String> = h
            .get_all("vary")
            .iter()
            .filter_map(|v| v.to_str().ok().map(str::to_owned))
            .collect();
        assert!(
            vary_values
                .iter()
                .any(|v| v.eq_ignore_ascii_case("Authorization")),
            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
        );
    }

    #[cfg(feature = "oauth")]
    #[tokio::test]
    async fn oauth_token_cache_headers_preserve_existing_vary() {
        // Simulates a handler/layer that already set `Vary: Accept-Encoding`
        // (e.g. compression). Our middleware must APPEND, not REPLACE.
        let app = axum::Router::new()
            .route(
                "/token",
                axum::routing::post(|| async {
                    axum::response::Response::builder()
                        .header("vary", "Accept-Encoding")
                        .body(axum::body::Body::from("{}"))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            ));
        let req = Request::builder()
            .method("POST")
            .uri("/token")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        let vary: Vec<String> = resp
            .headers()
            .get_all("vary")
            .iter()
            .filter_map(|v| v.to_str().ok().map(str::to_owned))
            .collect();
        assert!(
            vary.iter().any(|v| v.contains("Accept-Encoding")),
            "must preserve pre-existing Vary value, got {vary:?}"
        );
        assert!(
            vary.iter().any(|v| v.contains("Authorization")),
            "must append Authorization to Vary, got {vary:?}"
        );
    }

    // -- version endpoint --

    #[test]
    fn version_payload_contains_expected_fields() {
        let v = version_payload("my-server", "1.2.3");
        assert_eq!(v["name"], "my-server");
        assert_eq!(v["version"], "1.2.3");
        assert!(v["build_git_sha"].is_string());
        assert!(v["build_timestamp"].is_string());
        assert!(v["rust_version"].is_string());
        assert!(v["mcpx_version"].is_string());
    }

    // -- concurrency limit layer --

    #[tokio::test]
    async fn concurrency_limit_layer_composes_and_serves() {
        // We only assert the layer stack compiles and a single request
        // below the cap still succeeds. True back-pressure behaviour
        // requires a live HTTP server and is covered by integration tests.
        let app = axum::Router::new()
            .route("/ok", axum::routing::get(|| async { "ok" }))
            .layer(
                tower::ServiceBuilder::new()
                    .layer(axum::error_handling::HandleErrorLayer::new(
                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
                    ))
                    .layer(tower::load_shed::LoadShedLayer::new())
                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
            );
        let resp = app
            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // -- compression layer --

    #[tokio::test]
    async fn compression_layer_gzip_encodes_response() {
        use tower_http::compression::Predicate as _;

        let big_body = "a".repeat(4096);
        let app = axum::Router::new()
            .route(
                "/big",
                axum::routing::get(move || {
                    let body = big_body.clone();
                    async move { body }
                }),
            )
            .layer(
                tower_http::compression::CompressionLayer::new()
                    .gzip(true)
                    .br(true)
                    .compress_when(
                        tower_http::compression::DefaultPredicate::new()
                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
                    ),
            );

        let req = Request::builder()
            .uri("/big")
            .header(header::ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
            "gzip"
        );
    }
}